Skip to content

Commit f5f7538

Browse files
committed
PR feedback
1 parent e777535 commit f5f7538

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

src/citrine/informatics/data_sources.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import abstractmethod
33
from typing import Type, List, Mapping, Optional, Union
44
from uuid import UUID
5+
from warnings import warn
56

67
from citrine._serialization import properties
78
from citrine._serialization.polymorphic_serializable import PolymorphicSerializable
@@ -30,15 +31,16 @@ def __eq__(self, other):
3031
else:
3132
return False
3233

34+
@classmethod
35+
def _subclass_list(self) -> List[Type[Serializable]]:
36+
return [CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource]
37+
3338
@classmethod
3439
def get_type(cls, data) -> Type[Serializable]:
3540
"""Return the subtype."""
3641
if "type" not in data:
3742
raise ValueError("Can only get types from dicts with a 'type' key")
38-
types: List[Type[Serializable]] = [
39-
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource
40-
]
41-
res = next((x for x in types if x.typ == data["type"]), None)
43+
res = next((x for x in cls._subclass_list() if x.typ == data["type"]), None)
4244
if res is None:
4345
raise ValueError("Unrecognized type: {}".format(data["type"]))
4446
return res
@@ -52,10 +54,7 @@ def _data_source_type(self) -> str:
5254
def from_data_source_id(cls, data_source_id: str) -> "DataSource":
5355
"""Build a DataSource from a datasource_id."""
5456
terms = data_source_id.split("::")
55-
types: List[Type[Serializable]] = [
56-
CSVDataSource, GemTableDataSource, ExperimentDataSourceRef, SnapshotDataSource
57-
]
58-
res = next((x for x in types if x._data_source_type == terms[0]), None)
57+
res = next((x for x in cls._subclass_list() if x._data_source_type == terms[0]), None)
5958
if res is None:
6059
raise ValueError("Unrecognized type: {}".format(terms[0]))
6160
return res._data_source_id_builder(*terms[1:])
@@ -107,16 +106,17 @@ def __init__(self,
107106
@classmethod
108107
def _data_source_id_builder(cls, *args) -> DataSource:
109108
# TODO Figure out how to populate the column definitions
109+
warn("A CSVDataSource was derived from a data_source_id "
110+
"but is missing its column_definitions and identities",
111+
UserWarning)
110112
return CSVDataSource(
111113
file_link=FileLink(url=args[0], filename=args[1]),
112114
column_definitions={}
113115
)
114116

115117
def to_data_source_id(self) -> str:
116118
"""Generate the data_source_id for this DataSource."""
117-
return "::".join(
118-
str(x) for x in [self._data_source_type, self.file_link.url, self.file_link.filename]
119-
)
119+
return f"{self._data_source_type}::{self.file_link.url}::{self.file_link.filename}"
120120

121121

122122
class GemTableDataSource(Serializable['GemTableDataSource'], DataSource):
@@ -151,9 +151,7 @@ def _data_source_id_builder(cls, *args) -> DataSource:
151151

152152
def to_data_source_id(self) -> str:
153153
"""Generate the data_source_id for this DataSource."""
154-
return "::".join(
155-
str(x) for x in [self._data_source_type, self.table_id, self.table_version]
156-
)
154+
return f"{self._data_source_type}::{self.table_id}::{self.table_version}"
157155

158156
@classmethod
159157
def from_gemtable(cls, table: GemTable) -> "GemTableDataSource":
@@ -192,7 +190,7 @@ def _data_source_id_builder(cls, *args) -> DataSource:
192190

193191
def to_data_source_id(self) -> str:
194192
"""Generate the data_source_id for this DataSource."""
195-
return "::".join(str(x) for x in [self._data_source_type, self.datasource_id])
193+
return f"{self._data_source_type}::{self.datasource_id}"
196194

197195

198196
class SnapshotDataSource(Serializable['SnapshotDataSource'], DataSource):
@@ -219,4 +217,4 @@ def _data_source_id_builder(cls, *args) -> DataSource:
219217

220218
def to_data_source_id(self) -> str:
221219
"""Generate the data_source_id for this DataSource."""
222-
return "::".join(str(x) for x in [self._data_source_type, self.snapshot_id])
220+
return f"{self._data_source_type}::{self.snapshot_id}"

tests/informatics/test_data_source.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ def test_invalid_deser():
4949
def test_data_source_id(data_source):
5050
if isinstance(data_source, CSVDataSource):
5151
# TODO: There's no obvious way to recover the column_definitions & identifiers from the ID
52-
transformed = DataSource.from_data_source_id(data_source.to_data_source_id())
52+
with pytest.warns(UserWarning):
53+
transformed = DataSource.from_data_source_id(data_source.to_data_source_id())
5354
assert isinstance(data_source, CSVDataSource)
5455
assert transformed.file_link == data_source.file_link
5556
else:

0 commit comments

Comments
 (0)