Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a datasource field to the DesignWorkflow object #968

Merged
merged 5 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.8.0"
__version__ = "3.9.0"
107 changes: 106 additions & 1 deletion src/citrine/informatics/data_sources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tools for working with Descriptors."""
from abc import abstractmethod
from typing import Type, List, Mapping, Optional, Union
from uuid import UUID

Expand All @@ -7,6 +8,7 @@
from citrine._serialization.serializable import Serializable
from citrine.informatics.descriptors import Descriptor
from citrine.resources.file_link import FileLink
from citrine.resources.gemtables import GemTable

__all__ = ['DataSource',
'CSVDataSource',
Expand Down Expand Up @@ -34,13 +36,39 @@ def get_type(cls, data) -> Type[Serializable]:
if "type" not in data:
raise ValueError("Can only get types from dicts with a 'type' key")
types: List[Type[Serializable]] = [
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource
]
res = next((x for x in types if x.typ == data["type"]), None)
if res is None:
raise ValueError("Unrecognized type: {}".format(data["type"]))
return res

@property
@abstractmethod
def _data_source_type(self) -> str:
    """The data source type string, which is the leading term of the data_source_id.

    ``from_data_source_id`` compares this value against the first ``::``-separated
    term of an id to choose which data source class to build.
    """

@classmethod
def from_data_source_id(cls, data_source_id: str) -> "DataSource":
"""Build a DataSource from a datasource_id."""
terms = data_source_id.split("::")
types: List[Type[Serializable]] = [
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: maybe lift this list to the class level, as it's duplicated above.

res = next((x for x in types if x._data_source_type == terms[0]), None)
if res is None:
raise ValueError("Unrecognized type: {}".format(terms[0]))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: f-string.

return res._data_source_id_builder(*terms[1:])

@classmethod
@abstractmethod
def _data_source_id_builder(cls, *args) -> "DataSource":
    """Build a DataSource based on a parsed data_source_id.

    Parameters
    ----------
    *args
        The ``::``-separated terms of the data_source_id that follow the leading
        type term, in their original order.

    """

@abstractmethod
def to_data_source_id(self) -> str:
    """Generate the data_source_id for this DataSource.

    Returns
    -------
    str
        The identifier string; ``from_data_source_id`` expects its terms to be
        separated by ``::`` with the data source type as the first term.

    """


class CSVDataSource(Serializable['CSVDataSource'], DataSource):
"""A data source based on a CSV file stored on the data platform.
Expand All @@ -65,6 +93,8 @@ class CSVDataSource(Serializable['CSVDataSource'], DataSource):
properties.String, properties.Object(Descriptor), "column_definitions")
identifiers = properties.Optional(properties.List(properties.String), "identifiers")

_data_source_type = "csv"

def __init__(self,
*,
file_link: FileLink,
Expand All @@ -74,6 +104,20 @@ def __init__(self,
self.column_definitions = column_definitions
self.identifiers = identifiers

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Build a CSVDataSource from the parsed terms of a data_source_id.

    A data_source_id only encodes the file's URL and filename, so the result is an
    incomplete representation: ``column_definitions`` is empty and ``identifiers``
    is left at its default. A ``UserWarning`` is emitted to make that visible.

    Parameters
    ----------
    *args
        The terms following the type term: the file URL, then the filename.

    Returns
    -------
    DataSource
        A CSVDataSource pointing at the file, without column definitions.

    """
    import warnings

    # TODO Figure out how to populate the column definitions
    warnings.warn(
        "A CSVDataSource was derived from a data_source_id "
        "but is missing its column_definitions and identifiers",
        UserWarning,
        stacklevel=2,
    )
    return CSVDataSource(
        file_link=FileLink(url=args[0], filename=args[1]),
        column_definitions={}
    )
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no obvious method to identify what the column definitions and identifiers would be from just the DataSourceId.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be too disruptive to issue a warning here, so the user knows that the data source they're getting is incomplete? Actually, maybe the warning should come when they access the DesignWorkflow.data_source field, if it's a CSVDataSource with no column_definition.

Either way, I'd like us to alert them that this isn't a completely accurate representation.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, using a warning makes sense. I don't think people are using CSVDataSources anymore, but...


def to_data_source_id(self) -> str:
    """Return the data_source_id: the type term, file URL and filename joined by ``::``."""
    link = self.file_link
    return f"{self._data_source_type}::{link.url}::{link.filename}"


class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
"""A data source based on a GEM Table hosted on the data platform.
Expand All @@ -92,13 +136,37 @@ class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
table_id = properties.UUID("table_id")
table_version = properties.Integer("table_version")

_data_source_type = "gemd"

def __init__(self,
             *,
             table_id: UUID,
             table_version: Union[int, str]):
    """Store the identifier and version of the referenced table."""
    self.table_id: UUID = table_id
    self.table_version: Union[int, str] = table_version

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Build a GemTableDataSource from the parsed terms of a data_source_id.

    The first term is parsed as the table's UUID; the second is passed through
    unchanged as the table version.
    """
    parsed_table_id = UUID(args[0])
    version_term = args[1]
    return GemTableDataSource(table_id=parsed_table_id, table_version=version_term)

def to_data_source_id(self) -> str:
    """Return the data_source_id: the type term, table id and table version joined by ``::``."""
    return f"{self._data_source_type}::{self.table_id}::{self.table_version}"
Copy link
Collaborator

@anoto-moniz anoto-moniz Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: what's the benefit of this over an f-string?

f"{self._data_source_type}::{self.table_id}::{self.table_version}"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thought when writing was that we have a series of identifiers joined by :: in all cases, but yes, it ended up not great. I thought of creating a get_identifiers method that got joined at the parent level, but then that seemed a bit absurd. I'll swap to f-string.


@classmethod
def from_gemtable(cls, table: GemTable) -> "GemTableDataSource":
    """Create the data source that points at a given GemTable.

    Parameters
    ----------
    table: GemTable
        The table to reference; its ``uid`` and ``version`` become the data
        source's ``table_id`` and ``table_version``.

    """
    table_id = table.uid
    table_version = table.version
    return GemTableDataSource(table_id=table_id, table_version=table_version)


class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSource):
"""A reference to a data source based on an experiment result hosted on the data platform.
Expand All @@ -113,5 +181,42 @@ class ExperimentDataSourceRef(Serializable['ExperimentDataSourceRef'], DataSourc
typ = properties.String('type', default='experiments_data_source', deserializable=False)
datasource_id = properties.UUID("datasource_id")

_data_source_type = "experiments"

def __init__(self, *, datasource_id: UUID):
    """Store the identifier of the referenced data source."""
    self.datasource_id: UUID = datasource_id

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Build an ExperimentDataSourceRef, parsing the first id term as its UUID."""
    parsed_id = UUID(args[0])
    return ExperimentDataSourceRef(datasource_id=parsed_id)

def to_data_source_id(self) -> str:
    """Return the data_source_id: the type term and datasource id joined by ``::``."""
    return f"{self._data_source_type}::{self.datasource_id}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same nitpick as above.



class SnapshotDataSource(Serializable['SnapshotDataSource'], DataSource):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will presumably be useful soon, since it's the next step after Multistep Materials tables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, Bill and I were just talking about what's needed to properly support training, etc on snapshots.

"""A reference to a data source based on a Snapshot on the data platform.

Parameters
----------
snapshot_id: UUID
Unique identifier for the Snapshot Data Source

"""

typ = properties.String('type', default='snapshot_data_source', deserializable=False)
snapshot_id = properties.UUID("snapshot_id")

_data_source_type = "snapshot"

def __init__(self, *, snapshot_id: UUID):
    """Store the identifier of the referenced snapshot."""
    self.snapshot_id: UUID = snapshot_id

@classmethod
def _data_source_id_builder(cls, *args) -> DataSource:
    """Build a SnapshotDataSource, parsing the first id term as its UUID."""
    parsed_id = UUID(args[0])
    return SnapshotDataSource(snapshot_id=parsed_id)

def to_data_source_id(self) -> str:
    """Return the data_source_id: the type term and snapshot id joined by ``::``."""
    return f"{self._data_source_type}::{self.snapshot_id}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

38 changes: 37 additions & 1 deletion src/citrine/informatics/workflows/design_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from citrine._rest.resource import Resource
from citrine._serialization import properties
from citrine.informatics.data_sources import DataSource
from citrine.informatics.workflows.workflow import Workflow
from citrine.resources.design_execution import DesignExecutionCollection
from citrine._rest.ai_resource_metadata import AIResourceMetadata
Expand Down Expand Up @@ -31,11 +32,12 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata):
design_space_id = properties.Optional(properties.UUID, 'design_space_id')
predictor_id = properties.Optional(properties.UUID, 'predictor_id')
predictor_version = properties.Optional(
properties.Union([properties.Integer(), properties.String()]), 'predictor_version')
properties.Union([properties.Integer, properties.String]), 'predictor_version')
branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id')
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version')
""":Optional[int]: Version number of the branch that contains this workflow."""
data_source = properties.Optional(properties.Object(DataSource), "data_source")

status_description = properties.String('status_description', serializable=False)
""":str: more detailed description of the workflow's status"""
Expand All @@ -50,20 +52,54 @@ def __init__(self,
design_space_id: Optional[UUID] = None,
predictor_id: Optional[UUID] = None,
predictor_version: Optional[Union[int, str]] = None,
data_source: Optional[DataSource] = None,
description: Optional[str] = None):
self.name = name
self.design_space_id = design_space_id
self.predictor_id = predictor_id
self.predictor_version = predictor_version
self.data_source = data_source
self.description = description

def __str__(self):
    """Render the workflow as ``<DesignWorkflow 'name'>``, using the repr of its name."""
    return f'<DesignWorkflow {self.name!r}>'

@classmethod
def _pre_build(cls, data: dict) -> dict:
    """Replace a serialized ``data_source_id`` with a ``data_source`` entry before building.

    The ``data_source_id`` key is always removed from ``data``. When its value is not
    None, the id is parsed into a DataSource and that object's dump is stored under
    ``data_source``. The dict is modified in place and returned.
    """
    raw_id = data.pop("data_source_id", None)
    if raw_id is None:
        return data
    data["data_source"] = DataSource.from_data_source_id(raw_id).dump()
    return data

def _post_dump(self, data: dict) -> dict:
    """Replace the dumped ``data_source`` entry with a ``data_source_id`` after dumping.

    The ``data_source`` key is always removed from ``data``. ``data_source_id`` is set
    to the id generated from the rebuilt DataSource, or to None when there was no
    data source. The dict is modified in place and returned.
    """
    dumped_source = data.pop("data_source", None)
    if dumped_source is None:
        data["data_source_id"] = None
        return data
    data["data_source_id"] = DataSource.build(dumped_source).to_data_source_id()
    return data

Comment on lines +67 to +83
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This uses the existing serde hooks to filter out data_source_id fields and insert DataSource objects.

@property
def design_executions(self) -> DesignExecutionCollection:
    """Return a resource representing all visible executions of this workflow.

    Raises
    ------
    AttributeError
        If this workflow has no ``project_id`` attribute, or it is None.

    """
    project_id = getattr(self, 'project_id', None)
    if project_id is None:
        raise AttributeError('Cannot initialize execution without project reference!')
    return DesignExecutionCollection(
        project_id=project_id, session=self._session, workflow_id=self.uid)

@property
def data_source_id(self) -> Optional[str]:
    """The data_source_id of the workflow's data source, or None if it has none.

    Assigning a string parses it into a DataSource and stores that as
    ``data_source``; assigning None clears ``data_source``.
    """
    source = self.data_source
    return None if source is None else source.to_data_source_id()

@data_source_id.setter
def data_source_id(self, value: Optional[str]):
    # The id is never stored directly; it is always derived from data_source.
    self.data_source = None if value is None else DataSource.from_data_source_id(value)
13 changes: 0 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,16 +949,3 @@ def predictor_evaluation_workflow_dict(generic_entity, example_cv_evaluator_dict
"evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict]
})
return ret

@pytest.fixture
def design_workflow_dict(generic_entity):
ret = generic_entity.copy()
ret.update({
"name": "Example Design Workflow",
"description": "A description! Of the Design Workflow! So you know what it's for!",
"design_space_id": str(uuid.uuid4()),
"predictor_id": str(uuid.uuid4()),
"predictor_version": random.randint(1, 10),
"branch_id": str(uuid.uuid4()),
})
return ret
Comment on lines -952 to -964
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having two different DesignWorkflow test objects (a fixture & a Factory) made testing on the updated object complicated.

31 changes: 28 additions & 3 deletions tests/informatics/test_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@

import pytest

from citrine.informatics.data_sources import DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource
from citrine.informatics.descriptors import RealDescriptor, FormulationDescriptor
from citrine.informatics.data_sources import (
DataSource, CSVDataSource, ExperimentDataSourceRef, GemTableDataSource, SnapshotDataSource
)
from citrine.informatics.descriptors import RealDescriptor
from citrine.resources.file_link import FileLink
from citrine.resources.gemtables import GemTable

from tests.utils.factories import GemTableDataFactory

@pytest.fixture(params=[
CSVDataSource(file_link=FileLink("foo.spam", "http://example.com"),
Expand All @@ -15,7 +19,8 @@
GemTableDataSource(table_id=uuid.uuid4(), table_version=1),
GemTableDataSource(table_id=uuid.uuid4(), table_version="2"),
GemTableDataSource(table_id=uuid.uuid4(), table_version="2"),
ExperimentDataSourceRef(datasource_id=uuid.uuid4())
ExperimentDataSourceRef(datasource_id=uuid.uuid4()),
SnapshotDataSource(snapshot_id=uuid.uuid4())
])
def data_source(request):
return request.param
Expand All @@ -39,3 +44,23 @@ def test_invalid_deser():

with pytest.raises(ValueError):
DataSource.build({"type": "foo"})


def test_data_source_id(data_source):
    """Check that a data source survives a round trip through its data_source_id."""
    if isinstance(data_source, CSVDataSource):
        # TODO: There's no obvious way to recover the column_definitions & identifiers from the ID
        transformed = DataSource.from_data_source_id(data_source.to_data_source_id())
        # Check the rebuilt object, not the input: the input's type is already
        # established by the enclosing branch.
        assert isinstance(transformed, CSVDataSource)
        assert transformed.file_link == data_source.file_link
    else:
        assert data_source == DataSource.from_data_source_id(data_source.to_data_source_id())

def test_from_gem_table():
    """A data source made from a GemTable must carry that table's uid and version."""
    gem_table = GemTable.build(GemTableDataFactory())
    derived = GemTableDataSource.from_gemtable(gem_table)
    assert derived.table_id == gem_table.uid
    assert derived.table_version == gem_table.version

def test_invalid_data_source_id():
    """An id whose leading type term is unrecognized must raise ValueError."""
    unknown_id = f"Undefined::{uuid.uuid4()}"
    with pytest.raises(ValueError):
        DataSource.from_data_source_id(unknown_id)
16 changes: 6 additions & 10 deletions tests/informatics/test_workflows.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
"""Tests for citrine.informatics.workflows."""
import json
from multiprocessing.reduction import register
from uuid import uuid4, UUID

import mock
import pytest

from citrine._session import Session
from citrine.informatics.design_candidate import DesignMaterial, DesignVariable, DesignCandidate, ChemicalFormula, \
from citrine.informatics.design_candidate import DesignMaterial, DesignCandidate, ChemicalFormula, \
MeanAndStd, TopCategories, Mixture, MolecularStructure
from citrine.informatics.executions import DesignExecution
from citrine.informatics.predict_request import PredictRequest
from citrine.informatics.workflows import DesignWorkflow, Workflow
from citrine.informatics.workflows import DesignWorkflow
from citrine.resources.design_execution import DesignExecutionCollection
from citrine.resources.design_workflow import DesignWorkflowCollection

from tests.utils.factories import BranchDataFactory
from tests.utils.factories import BranchDataFactory, DesignWorkflowDataFactory
from tests.utils.session import FakeSession, FakeCall


Expand Down Expand Up @@ -48,10 +46,8 @@ def execution_collection(session) -> DesignExecutionCollection:


@pytest.fixture
def design_workflow(collection, design_workflow_dict) -> DesignWorkflow:
workflow = collection.build(design_workflow_dict)
collection.session.calls.clear()
return workflow
def design_workflow(collection) -> DesignWorkflow:
return collection.build(DesignWorkflowDataFactory(register=True))

@pytest.fixture
def design_execution(execution_collection, design_execution_dict) -> DesignExecution:
Expand Down
Loading
Loading