Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a datasource field to the DesignWorkflow object #968

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.9.0"
__version__ = "3.10.0"
124 changes: 115 additions & 9 deletions src/citrine/informatics/data_sources.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
"""Tools for working with Descriptors."""
from abc import abstractmethod
from typing import Type, List, Mapping, Optional, Union
from uuid import UUID
from warnings import warn

from citrine._serialization import properties
from citrine._serialization.polymorphic_serializable import PolymorphicSerializable
from citrine._serialization.serializable import Serializable
from citrine.informatics.descriptors import Descriptor
from citrine.resources.file_link import FileLink
from citrine.resources.gemtables import GemTable

__all__ = ['DataSource',
'CSVDataSource',
'GemTableDataSource',
'ExperimentDataSourceRef']
__all__ = [
'DataSource',
'CSVDataSource',
'GemTableDataSource',
'ExperimentDataSourceRef',
'SnapshotDataSource',
]


class DataSource(PolymorphicSerializable['DataSource']):
Expand All @@ -28,19 +34,43 @@ def __eq__(self, other):
else:
return False

@classmethod
def _subclass_list(cls) -> List[Type[Serializable]]:
    """Return the concrete DataSource subclasses used for polymorphic dispatch.

    Note: the first parameter is renamed from ``self`` to ``cls`` — this is a
    classmethod, so the implicit first argument is the class, not an instance.
    """
    return [CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource]

@classmethod
def get_type(cls, data) -> Type[Serializable]:
"""Return the subtype."""
if "type" not in data:
raise ValueError("Can only get types from dicts with a 'type' key")
types: List[Type[Serializable]] = [
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef
]
res = next((x for x in types if x.typ == data["type"]), None)
res = next((x for x in cls._subclass_list() if x.typ == data["type"]), None)
if res is None:
raise ValueError("Unrecognized type: {}".format(data["type"]))
raise ValueError(f"Unrecognized type: {data['type']}")
return res

@property
@abstractmethod
def _data_source_type(self) -> str:
"""The data source type string, which is the leading term of the data_source_id."""

@classmethod
def from_data_source_id(cls, data_source_id: str) -> "DataSource":
    """Build a DataSource from a datasource_id."""
    # The id is "::"-delimited; the leading term selects the subclass and the
    # remaining terms are handed to that subclass's builder.
    prefix, *remainder = data_source_id.split("::")
    for subclass in cls._subclass_list():
        if subclass._data_source_type == prefix:
            return subclass._data_source_id_builder(*remainder)
    raise ValueError(f"Unrecognized type: {prefix}")

@classmethod
@abstractmethod
def _data_source_id_builder(cls, *args) -> "DataSource":
"""Build a DataSource based on a parsed data_source_id."""

@abstractmethod
def to_data_source_id(self) -> str:
"""Generate the data_source_id for this DataSource."""


class CSVDataSource(Serializable['CSVDataSource'], DataSource):
"""A data source based on a CSV file stored on the data platform.
Expand All @@ -65,6 +95,8 @@ class CSVDataSource(Serializable['CSVDataSource'], DataSource):
properties.String, properties.Object(Descriptor), "column_definitions")
identifiers = properties.Optional(properties.List(properties.String), "identifiers")

_data_source_type = "csv"

def __init__(self,
*,
file_link: FileLink,
Expand All @@ -74,6 +106,21 @@ def __init__(self,
self.column_definitions = column_definitions
self.identifiers = identifiers

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Rebuild a CSVDataSource from the parsed terms of a data_source_id."""
    url, filename = args[0], args[1]
    # TODO Figure out how to populate the column definitions
    warn("A CSVDataSource was derived from a data_source_id "
         "but is missing its column_definitions and identities",
         UserWarning)
    link = FileLink(url=url, filename=filename)
    return CSVDataSource(file_link=link, column_definitions={})

def to_data_source_id(self) -> str:
    """Generate the data_source_id for this DataSource."""
    parts = [self._data_source_type, self.file_link.url, self.file_link.filename]
    return "::".join(parts)


class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
"""A data source based on a GEM Table hosted on the data platform.
Expand All @@ -92,13 +139,35 @@ class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
table_id = properties.UUID("table_id")
table_version = properties.Integer("table_version")

_data_source_type = "gemd"

def __init__(self,
*,
table_id: UUID,
table_version: Union[int, str]):
self.table_id: UUID = table_id
self.table_version: Union[int, str] = table_version

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Rebuild a GemTableDataSource from the parsed terms of a data_source_id."""
    raw_id, raw_version = args[0], args[1]
    # The version stays a string here; the property layer accepts int or str.
    return GemTableDataSource(table_id=UUID(raw_id), table_version=raw_version)

def to_data_source_id(self) -> str:
    """Generate the data_source_id for this DataSource."""
    return "::".join([self._data_source_type, str(self.table_id), str(self.table_version)])

@classmethod
def from_gemtable(cls, table: GemTable) -> "GemTableDataSource":
    """Generate a DataSource that corresponds to a GemTable.

    Parameters
    ----------
    table: GemTable
        The GemTable object to reference

    """
    uid, version = table.uid, table.version
    return GemTableDataSource(table_id=uid, table_version=version)


class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSource):
"""A reference to a data source based on an experiment result hosted on the data platform.
Expand All @@ -113,5 +182,42 @@ class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSourc
typ = properties.String('type', default='experiments_data_source', deserializable=False)
datasource_id = properties.UUID("datasource_id")

_data_source_type = "experiments"

def __init__(self, *, datasource_id: UUID):
self.datasource_id: UUID = datasource_id

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Rebuild an ExperimentDataSourceRef from the parsed terms of a data_source_id."""
    raw_uuid = args[0]
    return ExperimentDataSourceRef(datasource_id=UUID(raw_uuid))

def to_data_source_id(self) -> str:
    """Generate the data_source_id for this DataSource."""
    return "::".join([self._data_source_type, str(self.datasource_id)])


class SnapshotDataSource(Serializable['SnapshotDataSource'], DataSource):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will presumably be useful soon, since it's the next step after Multistep Materials tables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, Bill and I were just talking about what's needed to properly support training, etc on snapshots.

"""A reference to a data source based on a Snapshot on the data platform.

Parameters
----------
snapshot_id: UUID
Unique identifier for the Snapshot Data Source

"""

typ = properties.String('type', default='snapshot_data_source', deserializable=False)
snapshot_id = properties.UUID("snapshot_id")

_data_source_type = "snapshot"

def __init__(self, *, snapshot_id: UUID):
    """Initialize the reference with the snapshot's unique identifier."""
    # Annotated to match the sibling DataSource __init__ implementations.
    self.snapshot_id: UUID = snapshot_id

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Rebuild a SnapshotDataSource from the parsed terms of a data_source_id."""
    raw_uuid = args[0]
    return SnapshotDataSource(snapshot_id=UUID(raw_uuid))

def to_data_source_id(self) -> str:
    """Generate the data_source_id for this DataSource."""
    return "::".join([self._data_source_type, str(self.snapshot_id)])
38 changes: 37 additions & 1 deletion src/citrine/informatics/workflows/design_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from citrine._rest.resource import Resource
from citrine._serialization import properties
from citrine.informatics.data_sources import DataSource
from citrine.informatics.workflows.workflow import Workflow
from citrine.resources.design_execution import DesignExecutionCollection
from citrine._rest.ai_resource_metadata import AIResourceMetadata
Expand Down Expand Up @@ -31,11 +32,12 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata):
design_space_id = properties.Optional(properties.UUID, 'design_space_id')
predictor_id = properties.Optional(properties.UUID, 'predictor_id')
predictor_version = properties.Optional(
properties.Union([properties.Integer(), properties.String()]), 'predictor_version')
properties.Union([properties.Integer, properties.String]), 'predictor_version')
branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id')
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version')
""":Optional[int]: Version number of the branch that contains this workflow."""
data_source = properties.Optional(properties.Object(DataSource), "data_source")

status_description = properties.String('status_description', serializable=False)
""":str: more detailed description of the workflow's status"""
Expand All @@ -50,20 +52,54 @@ def __init__(self,
design_space_id: Optional[UUID] = None,
predictor_id: Optional[UUID] = None,
predictor_version: Optional[Union[int, str]] = None,
data_source: Optional[DataSource] = None,
description: Optional[str] = None):
self.name = name
self.design_space_id = design_space_id
self.predictor_id = predictor_id
self.predictor_version = predictor_version
self.data_source = data_source
self.description = description

def __str__(self):
return '<DesignWorkflow {!r}>'.format(self.name)

@classmethod
def _pre_build(cls, data: dict) -> dict:
    """Run data modification before building."""
    # The API returns a flat data_source_id; expand it into the serialized
    # DataSource object that the data_source property expects.
    raw_id = data.pop("data_source_id", None)
    if raw_id is None:
        return data
    source = DataSource.from_data_source_id(raw_id)
    data["data_source"] = source.dump()
    return data

def _post_dump(self, data: dict) -> dict:
    """Run data modification after dumping."""
    # Inverse of _pre_build: collapse the serialized DataSource back into the
    # flat data_source_id field the API expects (always present, possibly None).
    serialized = data.pop("data_source", None)
    source_id = None
    if serialized is not None:
        source_id = DataSource.build(serialized).to_data_source_id()
    data["data_source_id"] = source_id
    return data

Comment on lines +67 to +83
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This uses the existing serde hooks to filter out data_source_id fields and insert DataSource objects.

@property
def design_executions(self) -> DesignExecutionCollection:
    """Return a resource representing all visible executions of this workflow."""
    project = getattr(self, 'project_id', None)
    if project is None:
        raise AttributeError('Cannot initialize execution without project reference!')
    return DesignExecutionCollection(project_id=project,
                                     session=self._session,
                                     workflow_id=self.uid)

@property
def data_source_id(self) -> Optional[str]:
    """A resource referencing the workflow's data source."""
    source = self.data_source
    return None if source is None else source.to_data_source_id()

@data_source_id.setter
def data_source_id(self, value: Optional[str]):
    """Set the underlying data_source by parsing a data_source_id (None clears it)."""
    self.data_source = None if value is None else DataSource.from_data_source_id(value)
13 changes: 0 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,16 +949,3 @@ def predictor_evaluation_workflow_dict(generic_entity, example_cv_evaluator_dict
"evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict]
})
return ret

@pytest.fixture
def design_workflow_dict(generic_entity):
ret = generic_entity.copy()
ret.update({
"name": "Example Design Workflow",
"description": "A description! Of the Design Workflow! So you know what it's for!",
"design_space_id": str(uuid.uuid4()),
"predictor_id": str(uuid.uuid4()),
"predictor_version": random.randint(1, 10),
"branch_id": str(uuid.uuid4()),
})
return ret
Comment on lines -952 to -964
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having two different DesignWorkflow test objects (a fixture & a Factory) made testing on the updated object complicated.

32 changes: 29 additions & 3 deletions tests/informatics/test_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

import pytest

from citrine.informatics.data_sources import DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, FormulationDescriptor
from citrine.informatics.data_sources import (
DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource, SnapshotDataSource
)
from citrine.informatics.descriptors import RealDescriptor
from citrine.resources.file_link import FileLink
from citrine.resources.gemtables import GemTable

from tests.utils.factories import GemTableDataFactory

@pytest.fixture(params=[
CSVDataSource(file_link=FileLink("foo.spam", "http://example.com"),
Expand All @@ -15,7 +19,8 @@
GemTableDataSource(table_id=uuid.uuid4(), table_version=1),
GemTableDataSource(table_id=uuid.uuid4(), table_version="2"),
GemTableDataSource(table_id=uuid.uuid4(), table_version="2"),
ExperimentDataSourceRef(datasource_id=uuid.uuid4())
ExperimentDataSourceRef(datasource_id=uuid.uuid4()),
SnapshotDataSource(snapshot_id=uuid.uuid4())
])
def data_source(request):
return request.param
Expand All @@ -39,3 +44,24 @@ def test_invalid_deser():

with pytest.raises(ValueError):
DataSource.build({"type": "foo"})


def test_data_source_id(data_source):
    """Each DataSource should round-trip through its data_source_id string."""
    if isinstance(data_source, CSVDataSource):
        # TODO: There's no obvious way to recover the column_definitions & identifiers from the ID
        with pytest.warns(UserWarning):
            transformed = DataSource.from_data_source_id(data_source.to_data_source_id())
        # Bug fix: assert on the round-tripped object; asserting on data_source
        # was a tautology given the enclosing isinstance check.
        assert isinstance(transformed, CSVDataSource)
        assert transformed.file_link == data_source.file_link
    else:
        assert data_source == DataSource.from_data_source_id(data_source.to_data_source_id())

def test_from_gem_table():
    """A DataSource built from a GemTable mirrors the table's uid and version."""
    table = GemTable.build(GemTableDataFactory())
    source = GemTableDataSource.from_gemtable(table)
    assert (source.table_id, source.table_version) == (table.uid, table.version)

def test_invalid_data_source_id():
    """An unknown data_source_id prefix raises ValueError."""
    bogus_id = f"Undefined::{uuid.uuid4()}"
    with pytest.raises(ValueError):
        DataSource.from_data_source_id(bogus_id)
16 changes: 6 additions & 10 deletions tests/informatics/test_workflows.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
"""Tests for citrine.informatics.workflows."""
import json
from multiprocessing.reduction import register
from uuid import uuid4, UUID

import mock
import pytest

from citrine._session import Session
from citrine.informatics.design_candidate import DesignMaterial, DesignVariable, DesignCandidate, ChemicalFormula, \
from citrine.informatics.design_candidate import DesignMaterial, DesignCandidate, ChemicalFormula, \
MeanAndStd, TopCategories, Mixture, MolecularStructure
from citrine.informatics.executions import DesignExecution
from citrine.informatics.predict_request import PredictRequest
from citrine.informatics.workflows import DesignWorkflow, Workflow
from citrine.informatics.workflows import DesignWorkflow
from citrine.resources.design_execution import DesignExecutionCollection
from citrine.resources.design_workflow import DesignWorkflowCollection

from tests.utils.factories import BranchDataFactory
from tests.utils.factories import BranchDataFactory, DesignWorkflowDataFactory
from tests.utils.session import FakeSession, FakeCall


Expand Down Expand Up @@ -48,10 +46,8 @@ def execution_collection(session) -> DesignExecutionCollection:


@pytest.fixture
def design_workflow(collection, design_workflow_dict) -> DesignWorkflow:
workflow = collection.build(design_workflow_dict)
collection.session.calls.clear()
return workflow
def design_workflow(collection) -> DesignWorkflow:
return collection.build(DesignWorkflowDataFactory(register=True))

@pytest.fixture
def design_execution(execution_collection, design_execution_dict) -> DesignExecution:
Expand Down
Loading