Skip to content

Commit f53a021

Browse files
committed
Add a datasource field to the DesignWorkflow object
1 parent 510152d commit f53a021

File tree

9 files changed

+277
-85
lines changed

9 files changed

+277
-85
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.8.0"
1+
__version__ = "3.9.0"

src/citrine/informatics/data_sources.py

+106-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Tools for working with Data Sources."""
2+
from abc import abstractmethod
23
from typing import Type, List, Mapping, Optional, Union
34
from uuid import UUID
45

@@ -7,6 +8,7 @@
78
from citrine._serialization.serializable import Serializable
89
from citrine.informatics.descriptors import Descriptor
910
from citrine.resources.file_link import FileLink
11+
from citrine.resources.gemtables import GemTable
1012

1113
__all__ = ['DataSource',
1214
'CSVDataSource',
@@ -34,13 +36,39 @@ def get_type(cls, data) -> Type[Serializable]:
3436
if "type" not in data:
3537
raise ValueError("Can only get types from dicts with a 'type' key")
3638
types: List[Type[Serializable]] = [
37-
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef
39+
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource
3840
]
3941
res = next((x for x in types if x.typ == data["type"]), None)
4042
if res is None:
4143
raise ValueError("Unrecognized type: {}".format(data["type"]))
4244
return res
4345

46+
@property
@abstractmethod
def _data_source_type(self) -> str:
    """Type tag that forms the first "::"-separated term of a data_source_id."""
50+
51+
@classmethod
def from_data_source_id(cls, data_source_id: str) -> "DataSource":
    """Build a DataSource from a data_source_id string.

    The ID is split on "::". The first term selects the concrete DataSource type
    (matched against each type's ``_data_source_type``) and the remaining terms
    are handed to that type's ``_data_source_id_builder``.

    Parameters
    ----------
    data_source_id: str
        An ID of the form ``<type>::<term>[::<term>...]``

    Returns
    -------
    DataSource
        The data source built by the matching type

    Raises
    ------
    ValueError
        If the leading term matches no known type, or if the ID has too few
        terms for the matching type to be built.

    """
    # str.split always yields at least one element, so this unpacking is safe.
    type_term, *builder_args = data_source_id.split("::")
    types: List[Type[Serializable]] = [
        CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource
    ]
    res = next((x for x in types if x._data_source_type == type_term), None)
    if res is None:
        raise ValueError("Unrecognized type: {}".format(type_term))
    try:
        return res._data_source_id_builder(*builder_args)
    except IndexError as err:
        # The builders index positionally into the parsed terms, so a truncated
        # ID would otherwise surface as a bare IndexError. Report it the same
        # way as any other ID that cannot be parsed.
        raise ValueError(
            "Malformed data_source_id: {!r}".format(data_source_id)
        ) from err
62+
63+
@classmethod
@abstractmethod
def _data_source_id_builder(cls, *args) -> "DataSource":
    """Construct a DataSource from the terms that follow the type in a data_source_id."""
67+
68+
@abstractmethod
def to_data_source_id(self) -> str:
    """Return the data_source_id string that identifies this DataSource."""
71+
4472

4573
class CSVDataSource(Serializable['CSVDataSource'], DataSource):
4674
"""A data source based on a CSV file stored on the data platform.
@@ -65,6 +93,8 @@ class CSVDataSource(Serializable['CSVDataSource'], DataSource):
6593
properties.String, properties.Object(Descriptor), "column_definitions")
6694
identifiers = properties.Optional(properties.List(properties.String), "identifiers")
6795

96+
_data_source_type = "csv"
97+
6898
def __init__(self,
6999
*,
70100
file_link: FileLink,
@@ -74,6 +104,20 @@ def __init__(self,
74104
self.column_definitions = column_definitions
75105
self.identifiers = identifiers
76106

107+
@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    # The parsed ID terms are the file URL followed by the filename.
    # TODO Figure out how to populate the column definitions
    url, filename = args[0], args[1]
    link = FileLink(url=url, filename=filename)
    return CSVDataSource(file_link=link, column_definitions={})
114+
115+
def to_data_source_id(self) -> str:
    """Return this data source's ID: its type, file URL and filename joined by "::"."""
    terms = (self._data_source_type, self.file_link.url, self.file_link.filename)
    return "::".join(map(str, terms))
120+
77121

78122
class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
79123
"""A data source based on a GEM Table hosted on the data platform.
@@ -92,13 +136,37 @@ class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
92136
table_id = properties.UUID("table_id")
93137
table_version = properties.Integer("table_version")
94138

139+
_data_source_type = "gemd"
140+
95141
def __init__(self, *, table_id: UUID, table_version: Union[int, str]):
    # Both arguments are keyword-only and stored unchanged on the instance.
    self.table_id: UUID = table_id
    self.table_version: Union[int, str] = table_version
101147

148+
@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    # The parsed ID terms are the table's UUID string followed by its version.
    table_id, table_version = UUID(args[0]), args[1]
    return GemTableDataSource(table_id=table_id, table_version=table_version)
151+
152+
def to_data_source_id(self) -> str:
    """Return this data source's ID: its type, table ID and table version joined by "::"."""
    terms = (self._data_source_type, self.table_id, self.table_version)
    return "::".join(map(str, terms))
157+
158+
@classmethod
def from_gemtable(cls, table: GemTable) -> "GemTableDataSource":
    """Create the data source that points at a given GemTable.

    Parameters
    ----------
    table: GemTable
        The table to reference; its ``uid`` and ``version`` are copied over

    """
    table_id, table_version = table.uid, table.version
    return GemTableDataSource(table_id=table_id, table_version=table_version)
169+
102170

103171
class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSource):
104172
"""A reference to a data source based on an experiment result hosted on the data platform.
@@ -113,5 +181,42 @@ class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSourc
113181
typ = properties.String('type', default='experiments_data_source', deserializable=False)
114182
datasource_id = properties.UUID("datasource_id")
115183

184+
_data_source_type = "experiments"
185+
116186
def __init__(self, *, datasource_id: UUID):
    # Keyword-only; the ID is stored unchanged on the instance.
    self.datasource_id: UUID = datasource_id
188+
189+
@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    # The first parsed ID term is the experiment data source's UUID string.
    datasource_id = UUID(args[0])
    return ExperimentDataSourceRef(datasource_id=datasource_id)
192+
193+
def to_data_source_id(self) -> str:
    """Return this data source's ID: its type and datasource ID joined by "::"."""
    terms = (self._data_source_type, self.datasource_id)
    return "::".join(map(str, terms))
196+
197+
198+
class SnapshotDataSource(Serializable['SnapshotDataSource'], DataSource):
    """A data source that refers to a Snapshot on the data platform.

    Parameters
    ----------
    snapshot_id: UUID
        Unique identifier for the Snapshot Data Source

    """

    typ = properties.String('type', default='snapshot_data_source', deserializable=False)
    snapshot_id = properties.UUID("snapshot_id")

    # Leading term of this type's data_source_id.
    _data_source_type = "snapshot"

    def __init__(self, *, snapshot_id: UUID):
        self.snapshot_id = snapshot_id

    @classmethod
    def _data_source_id_builder(cls, *args) -> DataSource:
        # The first parsed ID term is the snapshot's UUID string.
        snapshot_id = UUID(args[0])
        return SnapshotDataSource(snapshot_id=snapshot_id)

    def to_data_source_id(self) -> str:
        """Return this data source's ID: its type and snapshot ID joined by "::"."""
        terms = (self._data_source_type, self.snapshot_id)
        return "::".join(map(str, terms))

src/citrine/informatics/workflows/design_workflow.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from citrine._rest.resource import Resource
55
from citrine._serialization import properties
6+
from citrine.informatics.data_sources import DataSource
67
from citrine.informatics.workflows.workflow import Workflow
78
from citrine.resources.design_execution import DesignExecutionCollection
89
from citrine._rest.ai_resource_metadata import AIResourceMetadata
@@ -31,11 +32,12 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata):
3132
design_space_id = properties.Optional(properties.UUID, 'design_space_id')
3233
predictor_id = properties.Optional(properties.UUID, 'predictor_id')
3334
predictor_version = properties.Optional(
34-
properties.Union([properties.Integer(), properties.String()]), 'predictor_version')
35+
properties.Union([properties.Integer, properties.String]), 'predictor_version')
3536
branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id')
3637
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
3738
branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version')
3839
""":Optional[int]: Version number of the branch that contains this workflow."""
40+
data_source = properties.Optional(properties.Object(DataSource), "data_source")
3941

4042
status_description = properties.String('status_description', serializable=False)
4143
""":str: more detailed description of the workflow's status"""
@@ -50,20 +52,54 @@ def __init__(self,
5052
design_space_id: Optional[UUID] = None,
5153
predictor_id: Optional[UUID] = None,
5254
predictor_version: Optional[Union[int, str]] = None,
55+
data_source: Optional[DataSource] = None,
5356
description: Optional[str] = None):
5457
self.name = name
5558
self.design_space_id = design_space_id
5659
self.predictor_id = predictor_id
5760
self.predictor_version = predictor_version
61+
self.data_source = data_source
5862
self.description = description
5963

6064
def __str__(self):
    # {!r} renders the workflow name through repr(), so it appears quoted.
    template = '<DesignWorkflow {!r}>'
    return template.format(self.name)
6266

67+
@classmethod
def _pre_build(cls, data: dict) -> dict:
    """Turn a "data_source_id" entry into a dumped "data_source" before building.

    ``data`` is modified in place and returned. The "data_source_id" key is always
    removed; "data_source" is only written when the ID was present and not None.
    """
    data_source_id = data.pop("data_source_id", None)
    if data_source_id is None:
        return data
    data["data_source"] = DataSource.from_data_source_id(data_source_id).dump()
    return data
74+
75+
def _post_dump(self, data: dict) -> dict:
    """Replace the dumped "data_source" entry with a "data_source_id" after dumping.

    ``data`` is modified in place and returned. "data_source_id" is always written,
    and is None when there was no "data_source" entry (or it was None).
    """
    dumped_source = data.pop("data_source", None)
    if dumped_source is None:
        data["data_source_id"] = None
        return data
    data["data_source_id"] = DataSource.build(dumped_source).to_data_source_id()
    return data
83+
6384
@property
def design_executions(self) -> DesignExecutionCollection:
    """Collection of this workflow's executions; requires a project reference."""
    project_id = getattr(self, 'project_id', None)
    if project_id is None:
        raise AttributeError('Cannot initialize execution without project reference!')
    return DesignExecutionCollection(
        project_id=project_id, session=self._session, workflow_id=self.uid)
91+
92+
@property
def data_source_id(self) -> Optional[str]:
    """The data_source_id string of the workflow's data source, or None if it has none."""
    source = self.data_source
    return None if source is None else source.to_data_source_id()

@data_source_id.setter
def data_source_id(self, value: Optional[str]):
    # Assigning None clears data_source; assigning an ID replaces data_source
    # with the DataSource parsed from that ID.
    self.data_source = None if value is None else DataSource.from_data_source_id(value)

tests/conftest.py

-13
Original file line numberDiff line numberDiff line change
@@ -949,16 +949,3 @@ def predictor_evaluation_workflow_dict(generic_entity, example_cv_evaluator_dict
949949
"evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict]
950950
})
951951
return ret
952-
953-
@pytest.fixture
954-
def design_workflow_dict(generic_entity):
955-
ret = generic_entity.copy()
956-
ret.update({
957-
"name": "Example Design Workflow",
958-
"description": "A description! Of the Design Workflow! So you know what it's for!",
959-
"design_space_id": str(uuid.uuid4()),
960-
"predictor_id": str(uuid.uuid4()),
961-
"predictor_version": random.randint(1, 10),
962-
"branch_id": str(uuid.uuid4()),
963-
})
964-
return ret

tests/informatics/test_data_source.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33

44
import pytest
55

6-
from citrine.informatics.data_sources import DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource
7-
from citrine.informatics.descriptors import RealDescriptor, FormulationDescriptor
6+
from citrine.informatics.data_sources import (
7+
DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource, SnapshotDataSource
8+
)
9+
from citrine.informatics.descriptors import RealDescriptor
810
from citrine.resources.file_link import FileLink
11+
from citrine.resources.gemtables import GemTable
912

13+
from tests.utils.factories import GemTableDataFactory
1014

1115
@pytest.fixture(params=[
1216
CSVDataSource(file_link=FileLink("foo.spam", "http://example.com"),
@@ -15,7 +19,8 @@
1519
GemTableDataSource(table_id=uuid.uuid4(), table_version=1),
1620
GemTableDataSource(table_id=uuid.uuid4(), table_version="2"),
1721
GemTableDataSource(table_id=uuid.uuid4(), table_version="2"),
18-
ExperimentDataSourceRef(datasource_id=uuid.uuid4())
22+
ExperimentDataSourceRef(datasource_id=uuid.uuid4()),
23+
SnapshotDataSource(snapshot_id=uuid.uuid4())
1924
])
2025
def data_source(request):
2126
return request.param
@@ -39,3 +44,23 @@ def test_invalid_deser():
3944

4045
with pytest.raises(ValueError):
4146
DataSource.build({"type": "foo"})
47+
48+
49+
def test_data_source_id(data_source):
    """A data source should survive a round trip through its data_source_id."""
    transformed = DataSource.from_data_source_id(data_source.to_data_source_id())
    if isinstance(data_source, CSVDataSource):
        # TODO: There's no obvious way to recover the column_definitions & identifiers from the ID
        # so only the rebuilt type and the file link are compared. The type check is on
        # `transformed`: `data_source` is already known to be a CSVDataSource in this branch.
        assert isinstance(transformed, CSVDataSource)
        assert transformed.file_link == data_source.file_link
    else:
        assert data_source == transformed
57+
58+
def test_from_gem_table():
    """from_gemtable should copy the table's uid and version into the data source."""
    gem_table = GemTable.build(GemTableDataFactory())
    source = GemTableDataSource.from_gemtable(gem_table)
    assert source.table_id == gem_table.uid
    assert source.table_version == gem_table.version
63+
64+
def test_invalid_data_source_id():
    """An ID whose leading type term is unknown should be rejected with a ValueError."""
    unknown_id = f"Undefined::{uuid.uuid4()}"
    with pytest.raises(ValueError):
        DataSource.from_data_source_id(unknown_id)

tests/informatics/test_workflows.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
"""Tests for citrine.informatics.workflows."""
2-
import json
2+
from multiprocessing.reduction import register
33
from uuid import uuid4, UUID
44

5-
import mock
65
import pytest
76

8-
from citrine._session import Session
9-
from citrine.informatics.design_candidate import DesignMaterial, DesignVariable, DesignCandidate, ChemicalFormula, \
7+
from citrine.informatics.design_candidate import DesignMaterial, DesignCandidate, ChemicalFormula, \
108
MeanAndStd, TopCategories, Mixture, MolecularStructure
119
from citrine.informatics.executions import DesignExecution
1210
from citrine.informatics.predict_request import PredictRequest
13-
from citrine.informatics.workflows import DesignWorkflow, Workflow
11+
from citrine.informatics.workflows import DesignWorkflow
1412
from citrine.resources.design_execution import DesignExecutionCollection
1513
from citrine.resources.design_workflow import DesignWorkflowCollection
1614

17-
from tests.utils.factories import BranchDataFactory
15+
from tests.utils.factories import BranchDataFactory, DesignWorkflowDataFactory
1816
from tests.utils.session import FakeSession, FakeCall
1917

2018

@@ -48,10 +46,8 @@ def execution_collection(session) -> DesignExecutionCollection:
4846

4947

5048
@pytest.fixture
51-
def design_workflow(collection, design_workflow_dict) -> DesignWorkflow:
52-
workflow = collection.build(design_workflow_dict)
53-
collection.session.calls.clear()
54-
return workflow
49+
def design_workflow(collection) -> DesignWorkflow:
50+
return collection.build(DesignWorkflowDataFactory(register=True))
5551

5652
@pytest.fixture
5753
def design_execution(execution_collection, design_execution_dict) -> DesignExecution:

0 commit comments

Comments
 (0)