Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify compound ingest operation to wait for table build completion #970

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.10.0"
__version__ = "3.11.0"
65 changes: 58 additions & 7 deletions src/citrine/jobs/job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from gemd.enumeration.base_enumeration import BaseEnumeration
from logging import getLogger
from time import time, sleep
from typing import Union
from uuid import UUID
from warnings import warn

from citrine._rest.resource import Resource
from citrine._serialization.properties import Set as PropertySet, String, Object
Expand All @@ -23,6 +25,16 @@ class JobSubmissionResponse(Resource['JobSubmissionResponse']):
""":UUID: job id of the job submission request"""


class JobStatus(BaseEnumeration):
"""The valid status codes for a job."""

SUBMITTED = "Submitted"
PENDING = "Pending"
RUNNING = "Running"
SUCCESS = "Success"
FAILURE = "Failure"


class TaskNode(Resource['TaskNode']):
"""Individual task status.

Expand All @@ -33,14 +45,29 @@ class TaskNode(Resource['TaskNode']):
""":str: unique identification number for the job task"""
task_type = properties.String("task_type")
""":str: the type of task running"""
status = properties.String("status")
""":str: The last reported status of this particular task.
One of "Submitted", "Pending", "Running", "Success", or "Failure"."""
_status = properties.String("status")
dependencies = PropertySet(String(), "dependencies")
""":Set[str]: all the tasks that this task is dependent on"""
failure_reason = properties.Optional(String(), "failure_reason")
""":str: if a task has failed, the failure reason will be in this parameter"""

@property
def status(self) -> Union[JobStatus, str]:
"""The last reported status of this particular task."""
if resolved := JobStatus.from_str(self._status, exception=False):
return resolved
else:
return self._status

@status.setter
def status(self, value: Union[JobStatus, str]) -> None:
if JobStatus.from_str(value, exception=False) is None:
warn(
f"{value} is not a recognized JobStatus; this will become an error as of v4.0.0.",
DeprecationWarning
)
self._status = value


class JobStatusResponse(Resource['JobStatusResponse']):
"""A response to a job status check.
Expand All @@ -50,13 +77,37 @@ class JobStatusResponse(Resource['JobStatusResponse']):

job_type = properties.String("job_type")
""":str: the type of job for this status report"""
status = properties.String("status")
_status = properties.String("status")
""":str: The status of the job. One of "Running", "Success", or "Failure"."""
tasks = properties.List(Object(TaskNode), "tasks")
""":List[TaskNode]: all of the constituent task required to complete this job"""
output = properties.Optional(properties.Mapping(String, String), 'output')
""":Optional[dict[str, str]]: job output properties and results"""

@property
def status(self) -> Union[JobStatus, str]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since BaseEnumeration has str as a superclass, returning a JobStatus would not lead to a change in behavior, so this is not an API break.

"""The last reported status of this particular task."""
if resolved := JobStatus.from_str(self._status, exception=False):
return resolved
else:
return self._status

@status.setter
def status(self, value: Union[JobStatus, str]) -> None:
if resolved := JobStatus.from_str(value, exception=False):
if resolved not in [JobStatus.RUNNING, JobStatus.SUCCESS, JobStatus.FAILURE]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this support any valid JobStatus value? If so, you can use list(JobStatus) to ensure it's resilient against typos and future code changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. That's strange, that the Response can't take the other statuses. Not even Pending?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what's documented. This ought to just be a read only object, so maybe we don't care about the fact that the far end will only ever return a subset.

warn(
f"{value} is not a valid JobStatus for a JobStatusResponse; "
f"this will become an error as of v4.0.0.",
DeprecationWarning
)
else:
warn(
f"{value} is not a recognized JobStatus; this will become an error as of v4.0.0.",
DeprecationWarning
)
self._status = value


def _poll_for_job_completion(session: Session,
job: Union[JobSubmissionResponse, UUID, str],
Expand Down Expand Up @@ -102,7 +153,7 @@ def _poll_for_job_completion(session: Session,
while True:
response = session.get_resource(path=path, params=params)
status: JobStatusResponse = JobStatusResponse.build(response)
if status.status in ['Success', 'Failure']:
if status.status in [JobStatus.SUCCESS, JobStatus.FAILURE]:
break
elif time() - start_time < timeout:
logger.info(
Expand All @@ -115,12 +166,12 @@ def _poll_for_job_completion(session: Session,
f'Note job on server is unaffected by this timeout.')
logger.debug('Last status: {}'.format(status.dump()))
raise PollingTimeoutError('Job {} timed out.'.format(job_id))
if status.status == 'Failure':
if status.status == JobStatus.FAILURE:
logger.debug(f'Job terminated with Failure status: {status.dump()}')
if raise_errors:
failure_reasons = []
for task in status.tasks:
if task.status == 'Failure':
if task.status == JobStatus.FAILURE:
logger.error(f'Task {task.id} failed with reason "{task.failure_reason}"')
failure_reasons.append(task.failure_reason)
raise JobFailureError(
Expand Down
21 changes: 16 additions & 5 deletions src/citrine/resources/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ class Ingestion(Resource['Ingestion']):
raise_errors = properties.Optional(properties.Boolean(), 'raise_errors', default=True)

@property
@deprecated(deprecated_in='3.11.0', removed_in='4.0.0',
details="The project_id attribute is deprecated since "
"dataset access is now controlled through teams.")
def project_id(self) -> Optional[UUID]:
"""[DEPRECATED] The project ID associated with this ingest."""
return self._project_id
Expand Down Expand Up @@ -300,15 +303,15 @@ def build_objects_async(self,
if not build_table:
project_id = None
elif project is None:
if self.project_id is None:
if self._project_id is None:
raise ValueError("Building a table requires a target project.")
else:
warn(
"Building a table with an implicit project is deprecated "
"and will be removed in v4. Please pass a project explicitly.",
DeprecationWarning
)
project_id = self.project_id
project_id = self._project_id
elif isinstance(project, Project):
project_id = project.uid
elif isinstance(project, UUID):
Expand Down Expand Up @@ -365,18 +368,26 @@ def poll_for_job_completion(self,
if polling_delay is not None:
kwargs["polling_delay"] = polling_delay

_poll_for_job_completion(
build_job_status = _poll_for_job_completion(
session=self.session,
team_id=self.team_id,
job=job,
raise_errors=False, # JobFailureError doesn't contain the error
**kwargs
)
if build_job_status.output is not None and "table_build_job_id" in build_job_status.output:
_poll_for_job_completion(
session=self.session,
team_id=self.team_id,
job=build_job_status.output["table_build_job_id"],
raise_errors=False, # JobFailureError doesn't contain the error
**kwargs
)
return self.status()

def status(self) -> IngestionStatus:
"""
[ALPHA] Retrieve the status of the ingestion from platform.
[ALPHA] Retrieve the status of the ingestion from platform.

Returns
----------
Expand Down Expand Up @@ -438,7 +449,7 @@ def poll_for_job_completion(self,

def status(self) -> IngestionStatus:
"""
[ALPHA] Retrieve the status of the ingestion from platform.
[ALPHA] Retrieve the status of the ingestion from platform.

Returns
----------
Expand Down
39 changes: 39 additions & 0 deletions tests/jobs/test_deprecations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from citrine.jobs.job import JobStatus, JobStatusResponse, TaskNode
import pytest
import warnings

from tests.utils.factories import TaskNodeDataFactory, JobStatusResponseDataFactory

def test_status_response_status():
status_response = JobStatusResponse.build(JobStatusResponseDataFactory(failure=True))
assert status_response.status == JobStatus.FAILURE

with pytest.deprecated_call():
status_response.status = 'Failed'
with warnings.catch_warnings():
warnings.simplefilter("error")
assert not isinstance(status_response.status, JobStatus)

with pytest.deprecated_call():
status_response.status = JobStatus.PENDING
with warnings.catch_warnings():
warnings.simplefilter("error")
assert status_response.status == JobStatus.PENDING

with warnings.catch_warnings():
warnings.simplefilter("error")
status_response.status = JobStatus.SUCCESS
assert status_response.status == JobStatus.SUCCESS

def test_task_node_status():
status_response = TaskNode.build(TaskNodeDataFactory(failure=True))
assert status_response.status == JobStatus.FAILURE

with pytest.deprecated_call():
status_response.status = 'Failed'
assert not isinstance(status_response.status, JobStatus)

with warnings.catch_warnings():
warnings.simplefilter("error")
status_response.status = JobStatus.SUCCESS
assert status_response.status == JobStatus.SUCCESS
4 changes: 1 addition & 3 deletions tests/jobs/test_waiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import time

from citrine.informatics.executions.design_execution import DesignExecution
from citrine.informatics.executions.predictor_evaluation_execution import (
PredictorEvaluationExecution)
from citrine.jobs.waiting import (
wait_for_asynchronous_object,
wait_while_executing,
Expand Down Expand Up @@ -53,7 +51,7 @@ def test_wait_while_validating_timeout(sleep_mock, time_mock):
module.in_progress.return_value = True
collection.get.return_value = module

with pytest.raises(ConditionTimeoutError) as exceptio:
with pytest.raises(ConditionTimeoutError):
wait_while_validating(collection=collection, module=module, timeout=1.0)

@mock.patch('time.sleep', return_value=None)
Expand Down
41 changes: 14 additions & 27 deletions tests/resources/test_file_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from citrine.resources.ingestion import Ingestion, IngestionCollection
from citrine.exceptions import NotFound

from tests.utils.factories import FileLinkDataFactory, _UploaderFactory
from tests.utils.factories import (
FileLinkDataFactory, _UploaderFactory, JobStatusResponseDataFactory,
IngestionStatusResponseDataFactory, IngestFilesResponseDataFactory, JobSubmissionResponseDataFactory
)
from tests.utils.session import FakeSession, FakeS3Client, FakeCall, FakeRequestResponseApiError


Expand Down Expand Up @@ -536,31 +539,15 @@ def test_ingest(collection: FileCollection, session):
good_file2 = collection.build({"filename": "also.csv", "id": str(uuid4()), "version": str(uuid4())})
bad_file = FileLink(filename="bad.csv", url="http://files.com/input.csv")

ingest_create_resp = {
"team_id": str(uuid4()),
"dataset_id": str(uuid4()),
"ingestion_id": str(uuid4())
}
job_id_resp = {
'job_id': str(uuid4())
}
job_status_resp = {
'job_id': job_id_resp['job_id'],
'job_type': 'create-gemd-objects',
'status': 'Success',
'tasks': [{'id': f'create-gemd-objects-{uuid4()}',
'task_type': 'create-gemd-objects-task',
'status': 'Success',
'dependencies': [],
'failure_reason': None}],
'output': {}
}
ingest_status_resp = {
"ingestion_id": ingest_create_resp["ingestion_id"],
"status": "ingestion_created",
"errors": [],
}
session.set_responses(ingest_create_resp, job_id_resp, job_status_resp, ingest_status_resp)
ingest_files_resp = IngestFilesResponseDataFactory()
job_id_resp = JobSubmissionResponseDataFactory()
job_status_resp = JobStatusResponseDataFactory(
job_id=job_id_resp['job_id'],
job_type='create-gemd-objects',
)
ingest_status_resp = IngestionStatusResponseDataFactory()

session.set_responses(ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp)
collection.ingest([good_file1, good_file2])

with pytest.raises(ValueError, match=bad_file.url):
Expand All @@ -572,7 +559,7 @@ def test_ingest(collection: FileCollection, session):
with pytest.raises(ValueError):
collection.ingest([good_file1], build_table=True)

session.set_responses(ingest_create_resp, job_id_resp, job_status_resp, ingest_status_resp)
session.set_responses(ingest_files_resp, job_id_resp, job_status_resp, ingest_status_resp)
coll_with_project_id = FileCollection(team_id=uuid4(), dataset_id=uuid4(), session=session)
coll_with_project_id.project_id = uuid4()
with pytest.deprecated_call():
Expand Down
14 changes: 7 additions & 7 deletions tests/resources/test_gemd_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from citrine._utils.functions import format_escaped_url

from tests.utils.factories import MaterialRunDataFactory, MaterialSpecDataFactory
from tests.utils.factories import JobSubmissionResponseFactory
from tests.utils.factories import JobSubmissionResponseDataFactory
from tests.utils.session import FakeSession, FakeCall


Expand Down Expand Up @@ -409,7 +409,7 @@ def test_async_update(gemd_collection, session):
'output': {}
}

session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp)
session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp)

# This returns None on successful update with wait.
gemd_collection.async_update(obj, wait_for_response=True)
Expand All @@ -423,7 +423,7 @@ def test_async_update_and_no_dataset_id(gemd_collection, session):
uids={'id': str(uuid4())}
)

session.set_response(JobSubmissionResponseFactory())
session.set_response(JobSubmissionResponseDataFactory())
gemd_collection.dataset_id = None

with pytest.raises(RuntimeError):
Expand All @@ -444,7 +444,7 @@ def test_async_update_timeout(gemd_collection, session):
'output': {}
}

session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp)
session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp)

with pytest.raises(PollingTimeoutError):
gemd_collection.async_update(obj, wait_for_response=True,
Expand All @@ -465,7 +465,7 @@ def test_async_update_and_wait(gemd_collection, session):
'output': {}
}

session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp)
session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp)

# This returns None on successful update with wait.
gemd_collection.async_update(obj, wait_for_response=True)
Expand All @@ -485,7 +485,7 @@ def test_async_update_and_wait_failure(gemd_collection, session):
'output': {}
}

session.set_responses(JobSubmissionResponseFactory(), fake_job_status_resp)
session.set_responses(JobSubmissionResponseDataFactory(), fake_job_status_resp)

with pytest.raises(JobFailureError):
gemd_collection.async_update(obj, wait_for_response=True)
Expand All @@ -499,7 +499,7 @@ def test_async_update_with_no_wait(gemd_collection, session):
uids={'id': str(uuid4())}
)

session.set_response(JobSubmissionResponseFactory())
session.set_response(JobSubmissionResponseDataFactory())
job_id = gemd_collection.async_update(obj, wait_for_response=False)
assert job_id is not None

Expand Down
Loading