Skip to content

Commit ede810b

Browse files
committed
Address DW with branch root ID and version.
Now that the design workflow APIs support branch root ID and version explicitly, the SDK should use them. This cuts down on excess requests, resulting in much faster executions.
1 parent 0e454fe commit ede810b

File tree

8 files changed

+62
-127
lines changed

8 files changed

+62
-127
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.6.0"
1+
__version__ = "3.7.0"

src/citrine/informatics/workflows/design_workflow.py

+6-29
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,17 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata):
3232
predictor_id = properties.Optional(properties.UUID, 'predictor_id')
3333
predictor_version = properties.Optional(
3434
properties.Union([properties.Integer(), properties.String()]), 'predictor_version')
35-
_branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id')
35+
branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id')
36+
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
37+
branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version')
38+
""":Optional[int]: Version number of the branch that contains this workflow."""
3639

3740
status_description = properties.String('status_description', serializable=False)
3841
""":str: more detailed description of the workflow's status"""
3942
typ = properties.String('type', default='DesignWorkflow', deserializable=False)
4043

41-
_branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id',
42-
serializable=False, deserializable=False)
43-
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
44-
_branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version',
45-
serializable=False, deserializable=False)
46-
""":Optional[int]: Version number of the branch that contains this workflow."""
44+
_branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id',
45+
serializable=False)
4746

4847
def __init__(self,
4948
name: str,
@@ -68,25 +67,3 @@ def design_executions(self) -> DesignExecutionCollection:
6867
raise AttributeError('Cannot initialize execution without project reference!')
6968
return DesignExecutionCollection(
7069
project_id=self.project_id, session=self._session, workflow_id=self.uid)
71-
72-
@property
73-
def branch_root_id(self):
74-
"""Retrieve the root ID of the branch this workflow is on."""
75-
return self._branch_root_id
76-
77-
@branch_root_id.setter
78-
def branch_root_id(self, value):
79-
"""Set the root ID of the branch this workflow is on."""
80-
self._branch_root_id = value
81-
self._branch_id = None
82-
83-
@property
84-
def branch_version(self):
85-
"""Retrieve the version of the branch this workflow is on."""
86-
return self._branch_version
87-
88-
@branch_version.setter
89-
def branch_version(self, value):
90-
"""Set the version of the branch this workflow is on."""
91-
self._branch_version = value
92-
self._branch_id = None

src/citrine/resources/design_workflow.py

+10-37
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from citrine._rest.collection import Collection
66
from citrine._session import Session
7-
from citrine.exceptions import NotFound
87
from citrine.informatics.workflows import DesignWorkflow
98
from citrine.resources.response import Response
109
from functools import partial
@@ -31,25 +30,6 @@ def __init__(self,
3130
self.branch_root_id = branch_root_id
3231
self.branch_version = branch_version
3332

34-
def _resolve_branch_root_and_version(self, workflow):
35-
from citrine.resources.branch import BranchCollection
36-
37-
workflow_copy = deepcopy(workflow)
38-
bc = BranchCollection(self.project_id, self.session)
39-
branch = bc.get_by_version_id(version_id=workflow_copy._branch_id)
40-
workflow_copy._branch_root_id = branch.root_id
41-
workflow_copy._branch_version = branch.version
42-
return workflow_copy
43-
44-
def _resolve_branch_id(self, root_id, version):
45-
from citrine.resources.branch import BranchCollection
46-
47-
if root_id and version:
48-
bc = BranchCollection(self.project_id, self.session)
49-
branch = bc.get(root_id=root_id, version=version)
50-
return branch.uid
51-
return None
52-
5333
def register(self, model: DesignWorkflow) -> DesignWorkflow:
5434
"""
5535
Upload a new design workflow.
@@ -77,15 +57,15 @@ def register(self, model: DesignWorkflow) -> DesignWorkflow:
7757
'project.design_workflows.register().')
7858
raise RuntimeError(msg)
7959
else:
80-
# branch_id is in the body of design workflow endpoints, so it must be serialized.
81-
# This means the collection branch_id might not match the workflow branch_id. The
82-
# collection should win out, since the user is explicitly referencing the branch
83-
# represented by this collection.
84-
# To avoid modifying the parameter, and to ensure the only change is the branch_id, we
85-
# deepcopy, modify, then register it.
60+
# branch_root_id and branch_version are in the body of design workflow endpoints, so
61+
# they must be serialized. This means the collection fields might not match the
62+
# workflow fields. The collection should win out, since the user is explicitly
63+
# referencing the branch represented by this collection.
64+
# To avoid modifying the parameter, and to ensure the only changes are the
65+
# branch_root_id and branch_version, we deepcopy, modify, then register it.
8666
model_copy = deepcopy(model)
87-
model_copy._branch_id = self._resolve_branch_id(self.branch_root_id,
88-
self.branch_version)
67+
model_copy.branch_root_id = self.branch_root_id
68+
model_copy.branch_version = self.branch_version
8969
return super().register(model_copy)
9070

9171
def build(self, data: dict) -> DesignWorkflow:
@@ -104,7 +84,6 @@ def build(self, data: dict) -> DesignWorkflow:
10484
10585
"""
10686
workflow = DesignWorkflow.build(data)
107-
workflow = self._resolve_branch_root_and_version(workflow)
10887
workflow._session = self.session
10988
workflow.project_id = self.project_id
11089
return workflow
@@ -137,13 +116,6 @@ def update(self, model: DesignWorkflow) -> DesignWorkflow:
137116
raise ValueError('Cannot update a design workflow unless its branch_root_id and '
138117
'branch_version are set.')
139118

140-
try:
141-
model._branch_id = self._resolve_branch_id(model.branch_root_id,
142-
model.branch_version)
143-
except NotFound:
144-
raise ValueError('Cannot update a design workflow unless its branch_root_id and '
145-
'branch_version exists.')
146-
147119
# If executions have already been done, warn about future behavior change
148120
executions = model.design_executions.list()
149121
if next(executions, None) is not None:
@@ -197,7 +169,8 @@ def _fetch_page(self,
197169
additional_params: Optional[dict] = None,
198170
) -> Tuple[Iterable[dict], str]:
199171
params = additional_params or {}
200-
params["branch"] = self._resolve_branch_id(self.branch_root_id, self.branch_version)
172+
params["branch_root_id"] = self.branch_root_id
173+
params["branch_version"] = self.branch_version
201174
return super()._fetch_page(path=path,
202175
fetch_func=fetch_func,
203176
page=page,

src/citrine/seeding/find_or_create.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ def create_or_update(*,
175175
# Locally created design workflows likely won't have a branch ID but
176176
# need one to be updated.
177177
if isinstance(old_resource, DesignWorkflow):
178-
new_resource._branch_root_id = old_resource.branch_root_id
179-
new_resource._branch_version = old_resource.branch_version
178+
new_resource.branch_root_id = old_resource.branch_root_id
179+
new_resource.branch_version = old_resource.branch_version
180180
return collection.update(new_resource)
181181
else:
182182
logger.info("Registering new module: {}".format(resource.name))

tests/resources/test_branch.py

+23
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,29 @@ def test_branch_get(session, collection, branch_path):
107107
assert session.last_call == FakeCall(method='GET', path=branch_path, params={'page': 1, 'per_page': 1, 'root': root_id, 'version': version})
108108

109109

110+
def test_branch_get_not_found(session, collection, branch_path):
111+
# Given
112+
session.set_response({"response": []})
113+
114+
# When
115+
with pytest.raises(NotFound):
116+
collection.get(root_id=uuid.uuid4(), version=1)
117+
118+
119+
def test_branch_get_by_version_id(session, collection, branch_path):
120+
# Given
121+
branch_data = BranchDataFactory()
122+
version_id = branch_data['id']
123+
session.set_response(branch_data)
124+
125+
# When
126+
branch = collection.get_by_version_id(version_id=version_id)
127+
128+
# Then
129+
assert session.num_calls == 1
130+
assert session.last_call == FakeCall(method='GET', path=f"{branch_path}/{version_id}")
131+
132+
110133
def test_branch_list(session, collection, branch_path):
111134
# Given
112135
branch_count = 5

tests/resources/test_design_workflows.py

+18-46
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def collection(branch_data, collection_without_branch) -> DesignWorkflowCollecti
4343

4444
@pytest.fixture
4545
def workflow(collection, branch_data, design_workflow_dict) -> DesignWorkflow:
46-
design_workflow_dict["branch_id"] = branch_data["id"]
46+
design_workflow_dict["branch_root_id"] = branch_data["metadata"]["root_id"]
47+
design_workflow_dict["branch_version"] = branch_data["metadata"]["version"]
4748

4849
collection.session.set_response(branch_data)
4950
workflow = collection.build(design_workflow_dict)
@@ -71,12 +72,6 @@ def workflow_path(collection, workflow=None):
7172
path = f'{path}/{workflow.uid}'
7273
return path
7374

74-
def branches_path(collection, branch_id=None):
75-
path = f'/projects/{collection.project_id}/branches'
76-
if branch_id:
77-
path = f'{path}/{branch_id}'
78-
return path
79-
8075
def assert_workflow(actual, expected, *, include_branch=False):
8176
assert actual.name == expected.name
8277
assert actual.description == expected.description
@@ -86,7 +81,7 @@ def assert_workflow(actual, expected, *, include_branch=False):
8681
assert actual.predictor_version == expected.predictor_version
8782
assert actual.project_id == expected.project_id
8883
if include_branch:
89-
assert actual.branch_id == expected.branch_id
84+
assert actual._branch_id == expected._branch_id
9085
assert actual.branch_root_id == expected.branch_root_id
9186
assert actual.branch_version == expected.branch_version
9287

@@ -99,29 +94,22 @@ def test_basic_methods(workflow, collection, design_workflow_dict):
9994
@pytest.mark.parametrize("optional_args", all_combination_lengths(OPTIONAL_ARGS))
10095
def test_register(session, branch_data, workflow_minimal, collection, optional_args):
10196
workflow = workflow_minimal
102-
branch_id = branch_data['id']
103-
branch_data_get_resp = {"response": [branch_data]}
104-
branch_data_get_params = {
105-
'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version
106-
}
97+
branch_root_id = branch_data['metadata']['root_id']
98+
branch_version = branch_data['metadata']['version']
10799

108100
# Set a random value for all optional args selected for this run.
109101
for name, factory in optional_args:
110102
setattr(workflow, name, factory())
111103

112104
# Given
113-
post_dict = {**workflow.dump(), "branch_id": str(branch_id)}
114-
session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data)
105+
post_dict = {**workflow.dump(), "branch_root_id": str(branch_root_id), "branch_version": branch_version}
106+
session.set_responses({**post_dict, 'status_description': 'status'})
115107

116108
# When
117109
new_workflow = collection.register(workflow)
118110

119111
# Then
120-
assert session.calls == [
121-
FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params),
122-
FakeCall(method='POST', path=workflow_path(collection), json=post_dict),
123-
FakeCall(method='GET', path=branches_path(collection, branch_id)),
124-
]
112+
assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)]
125113

126114
assert new_workflow.branch_root_id == collection.branch_root_id
127115
assert new_workflow.branch_version == collection.branch_version
@@ -133,23 +121,18 @@ def test_register_conflicting_branches(session, branch_data, workflow, collectio
133121
old_branch_root_id = uuid.uuid4()
134122
workflow.branch_root_id = old_branch_root_id
135123
assert workflow.branch_root_id != collection.branch_root_id
124+
125+
new_branch_root_id = str(branch_data["metadata"]["root_id"])
126+
new_branch_version = branch_data["metadata"]["version"]
136127

137-
branch_data_get_resp = {"response": [branch_data]}
138-
branch_data_get_params = {
139-
'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version
140-
}
141-
post_dict = {**workflow.dump(), "branch_id": str(branch_data["id"])}
142-
session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data)
128+
post_dict = {**workflow.dump(), "branch_root_id": new_branch_root_id, "branch_version": new_branch_version}
129+
session.set_responses({**post_dict, 'status_description': 'status'})
143130

144131
# When
145132
new_workflow = collection.register(workflow)
146133

147134
# Then
148-
assert session.calls == [
149-
FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params),
150-
FakeCall(method='POST', path=workflow_path(collection), json=post_dict),
151-
FakeCall(method='GET', path=branches_path(collection, branch_data["id"])),
152-
]
135+
assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)]
153136

154137
assert workflow.branch_root_id == old_branch_root_id
155138
assert new_workflow.branch_root_id == collection.branch_root_id
@@ -180,10 +163,10 @@ def test_delete(collection):
180163

181164

182165
def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollection):
183-
branch_data_get_resp = {"response": [branch_data]}
184-
branch_id = uuid.UUID(branch_data['id'])
166+
branch_root_id = uuid.UUID(branch_data['metadata']['root_id'])
167+
branch_version = branch_data['metadata']['version']
185168

186-
collection.session.set_responses(branch_data_get_resp, {"response": []})
169+
collection.session.set_responses({"response": []})
187170

188171
lst = list(collection.list_archived(per_page=10))
189172
assert len(lst) == 0
@@ -192,7 +175,7 @@ def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollecti
192175
assert collection.session.last_call == FakeCall(
193176
method='GET',
194177
path=expected_path,
195-
params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch': branch_id},
178+
params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch_root_id': branch_root_id, 'branch_version': branch_version},
196179
json=None
197180
)
198181

@@ -213,17 +196,10 @@ def test_missing_project(design_workflow_dict):
213196

214197
def test_update(session, branch_data, workflow, collection_without_branch):
215198
# Given
216-
branch_data_get_resp = {"response": [branch_data]}
217-
branch_data_get_params = {
218-
'page': 1, 'per_page': 1, 'root': str(workflow.branch_root_id), 'version': workflow.branch_version
219-
}
220-
221199
post_dict = workflow.dump()
222200
session.set_responses(
223-
branch_data_get_resp,
224201
{"per_page": 1, "next": "", "response": []},
225202
{**post_dict, 'status_description': 'status'},
226-
branch_data
227203
)
228204

229205
# When
@@ -232,20 +208,16 @@ def test_update(session, branch_data, workflow, collection_without_branch):
232208
# Then
233209
executions_path = f'/projects/{collection_without_branch.project_id}/design-workflows/{workflow.uid}/executions'
234210
assert session.calls == [
235-
FakeCall(method='GET', path=branches_path(collection_without_branch), params=branch_data_get_params),
236211
FakeCall(method='GET', path=executions_path, params={'page': 1, 'per_page': 100}),
237212
FakeCall(method='PUT', path=workflow_path(collection_without_branch, workflow), json=post_dict),
238-
FakeCall(method='GET', path=branches_path(collection_without_branch, branch_data["id"])),
239213
]
240214
assert_workflow(new_workflow, workflow)
241215

242216

243217
def test_update_failure_with_existing_execution(session, branch_data, workflow, collection_without_branch, design_execution_dict):
244-
branch_data_get_resp = {"response": [branch_data]}
245218
workflow.branch_root_id = uuid.uuid4()
246219
post_dict = workflow.dump()
247220
session.set_responses(
248-
branch_data_get_resp,
249221
{"per_page": 1, "next": "", "response": [design_execution_dict]},
250222
{**post_dict, 'status_description': 'status'})
251223

tests/resources/test_workflow.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,15 @@ def test_build_design_workflow(session, basic_design_workflow_data):
6262

6363
def test_list_workflows(session, basic_design_workflow_data):
6464
#Given
65-
branch_data = BranchDataFactory()
66-
branch_data_get_resp = {"response": [branch_data]}
67-
session.set_response(branch_data)
68-
6965
workflow_collection = DesignWorkflowCollection(project_id=uuid.uuid4(), session=session)
70-
session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20}, branch_data)
66+
session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20})
7167

7268
# When
7369
workflows = list(workflow_collection.list(per_page=20))
7470

7571
# Then
7672
expected_design_call = FakeCall(method='GET', path='/projects/{}/modules'.format(workflow_collection.project_id),
7773
params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'})
78-
assert 2 == session.num_calls
74+
assert 1 == session.num_calls
7975
assert len(workflows) == 1
8076
assert isinstance(workflows[0], DesignWorkflow)

tests/seeding/test_find_or_create.py

-6
Original file line numberDiff line numberDiff line change
@@ -353,17 +353,11 @@ def test_create_or_update_unique_found_design_workflow(session):
353353
dw2_dict = DesignWorkflowDataFactory(branch_root_id=root_id, branch_version=version)
354354
dw3_dict = DesignWorkflowDataFactory()
355355
session.set_responses(
356-
# Build (setup)
357-
branch_data, # Find the model's branch root ID and version
358356
# List
359-
{"response": [branch_data]}, # Find the collection's branch version ID
360357
{"response": [dw1_dict, dw2_dict, dw3_dict]}, # Return the design workflows
361-
branch_data, branch_data, branch_data, # Lookup the branch root ID and version of each design workflow.
362358
# Update
363-
{"response": [branch_data]}, # Lookup the module's branch version ID
364359
{"response": []}, # Check if there are any executions
365360
dw2_dict, # Return the updated design workflow
366-
branch_data # Lookup the updated design workflow branch root ID and version
367361
)
368362

369363
collection = LocalDesignWorkflowCollection(project_id=uuid4(), session=session, branch_root_id=root_id, branch_version=version)

0 commit comments

Comments
 (0)