From 63f2213c4cbdf3511c2cb9f2be8f93a28d7c3efa Mon Sep 17 00:00:00 2001
From: Aryamanz29
Date: Sun, 5 May 2024 00:44:33 +0530
Subject: [PATCH 1/6] DVX-405: Adds support for scheduling workflow runs

---
 pyatlan/client/constants.py               |  29 +-
 pyatlan/client/workflow.py                | 350 +++++++++++++++++++---
 pyatlan/model/workflow.py                 |  32 +-
 tests/integration/test_client.py          |   9 -
 tests/integration/test_workflow_client.py | 250 ++++++++++++++++
 tests/unit/constants.py                   |  72 +++++
 tests/unit/test_workflow_client.py        | 316 +++++++++++++------
 7 files changed, 912 insertions(+), 146 deletions(-)
 create mode 100644 tests/integration/test_workflow_client.py

diff --git a/pyatlan/client/constants.py b/pyatlan/client/constants.py
index bc889d84c..13df5702f 100644
--- a/pyatlan/client/constants.py
+++ b/pyatlan/client/constants.py
@@ -408,9 +408,34 @@ WORKFLOW_RUN = API(
     WORKFLOW_RUN_API, HTTPMethod.POST, HTTPStatus.OK, endpoint=EndPoint.HERACLES
 )
-WORKFLOW_UPDATE_API = "workflows"
+WORKFLOW_API = "workflows"
 WORKFLOW_UPDATE = API(
-    WORKFLOW_UPDATE_API + "/{workflow_name}",
+    WORKFLOW_API + "/{workflow_name}",
+    HTTPMethod.POST,
+    HTTPStatus.OK,
+    endpoint=EndPoint.HERACLES,
+)
+WORKFLOW_ARCHIVE = API(
+    WORKFLOW_API + "/{workflow_name}/archive",
+    HTTPMethod.POST,
+    HTTPStatus.OK,
+    endpoint=EndPoint.HERACLES,
+)
+WORKFLOW_SCHEDULE_RUN = "runs"
+GET_ALL_SCHEDULE_RUNS = API(
+    WORKFLOW_SCHEDULE_RUN + "/cron",
+    HTTPMethod.GET,
+    HTTPStatus.OK,
+    endpoint=EndPoint.HERACLES,
+)
+GET_SCHEDULE_RUN = API(
+    WORKFLOW_SCHEDULE_RUN + "/cron/{workflow_name}",
+    HTTPMethod.GET,
+    HTTPStatus.OK,
+    endpoint=EndPoint.HERACLES,
+)
+STOP_WORKFLOW_RUN = API(
+    WORKFLOW_SCHEDULE_RUN + "/{workflow_run_id}/stop",
     HTTPMethod.POST,
     HTTPStatus.OK,
     endpoint=EndPoint.HERACLES,
 )
diff --git a/pyatlan/client/workflow.py b/pyatlan/client/workflow.py
index bff8a91f5..3bbc8be43 100644
--- a/pyatlan/client/workflow.py
+++ b/pyatlan/client/workflow.py
@@ -4,10 +4,14 @@
 from time import sleep
 from typing import List, Optional, Union, overload
 
-from pydantic.v1 import validate_arguments
+from pydantic.v1 import ValidationError, parse_obj_as, validate_arguments
 
 from pyatlan.client.common import ApiCaller
 from pyatlan.client.constants import (
+    GET_ALL_SCHEDULE_RUNS,
+    GET_SCHEDULE_RUN,
+    STOP_WORKFLOW_RUN,
+    WORKFLOW_ARCHIVE,
     WORKFLOW_INDEX_RUN_SEARCH,
     WORKFLOW_INDEX_SEARCH,
     WORKFLOW_RERUN,
@@ -22,6 +26,8 @@
     Workflow,
     WorkflowResponse,
     WorkflowRunResponse,
+    WorkflowSchedule,
+    WorkflowScheduleResponse,
     WorkflowSearchRequest,
     WorkflowSearchResponse,
     WorkflowSearchResult,
@@ -37,6 +43,9 @@ class WorkflowClient:
     directly but can be obtained through the workflow property of AtlanClient.
     """
 
+    _WORKFLOW_RUN_SCHEDULE = "orchestration.atlan.com/schedule"
+    _WORKFLOW_RUN_TIMEZONE = "orchestration.atlan.com/timezone"
+
     def __init__(self, client: ApiCaller):
         if not isinstance(client, ApiCaller):
             raise ErrorCode.INVALID_PARAMETER_TYPE.exception_with_parameters(
@@ -44,34 +53,16 @@ def __init__(self, client: ApiCaller):
         )
         self._client = client
 
-    @validate_arguments
-    def find_by_type(
-        self, prefix: WorkflowPackage, max_results: int = 10
-    ) -> List[WorkflowSearchResult]:
-        """
-        Find workflows based on their type (prefix). Note: Only workflows that have been run will be found.
-
-        :param prefix: name of the specific workflow to find (for example CONNECTION_DELETE)
-        :param max_results: the maximum number of results to retrieve
-        :returns: the list of workflows of the provided type, with the most-recently created first
-        :raises ValidationError: If the provided prefix is invalid workflow package
-        :raises AtlanError: on any API communication issue
-        """
-        query = Bool(
-            filter=[
-                NestedQuery(
-                    query=Prefix(field="metadata.name.keyword", value=prefix.value),
-                    path="metadata",
-                )
-            ]
-        )
-        request = WorkflowSearchRequest(query=query, size=max_results)
-        raw_json = self._client._call_api(
-            WORKFLOW_INDEX_SEARCH,
-            request_obj=request,
-        )
-        response = WorkflowSearchResponse(**raw_json)
-        return response.hits.hits or []
+    @staticmethod
+    def _parse_response(raw_json, response_type):
+        try:
+            if isinstance(raw_json, List):
+                return parse_obj_as(List[response_type], raw_json)
+            return parse_obj_as(response_type, raw_json)
+        except ValidationError as err:
+            raise ErrorCode.JSON_ERROR.exception_with_parameters(
+                raw_json, 200, str(err)
+            ) from err
 
     @validate_arguments
     def _find_latest_run(self, workflow_name: str) -> Optional[WorkflowSearchResult]:
@@ -92,7 +83,7 @@ def _find_latest_run(self, workflow_name: str) -> Optional[WorkflowSearchResult]
                 )
             ]
         )
-        response = self._find_run(query)
+        response = self._find_runs(query, size=1)
         return results[0] if (results := response.hits.hits) else None
 
     @validate_arguments
@@ -116,7 +107,7 @@ def _find_current_run(self, workflow_name: str) -> Optional[WorkflowSearchResult]
                 )
             ]
         )
-        response = self._find_run(query, size=50)
+        response = self._find_runs(query, size=50)
         if results := response.hits.hits:
             for result in results:
                 if result.status in {
@@ -126,14 +117,87 @@ def _find_current_run(self, workflow_name: str) -> Optional[WorkflowSearchResult]
                     return result
         return None
 
-    def _find_run(self, query: Query, size: int = 1) -> WorkflowSearchResponse:
-        request = WorkflowSearchRequest(query=query, size=size)
+    def _find_runs(
+        self,
+        query: Query,
+        from_: int = 0,
+        size: int = 100,
+    ) -> WorkflowSearchResponse:
+        """
+        Retrieve existing workflow runs.
+
+        :param query: query object to filter workflow runs.
+        :param from_: starting index of the search results (default: `0`).
+        :param size: maximum number of search results to return (default: `100`).
+        :returns: a response containing the matching workflow runs.
+        :raises AtlanError: on any API communication issue
+        """
+        request = WorkflowSearchRequest(query=query, from_=from_, size=size)
         raw_json = self._client._call_api(
             WORKFLOW_INDEX_RUN_SEARCH,
             request_obj=request,
         )
         return WorkflowSearchResponse(**raw_json)
 
+    def _add_schedule(
+        self,
+        workflow: Workflow,
+        workflow_schedule: WorkflowSchedule,
+    ):
+        """
+        Adds required schedule parameters to the workflow object.
+        """
+        workflow.metadata.annotations and workflow.metadata.annotations.update(
+            {
+                self._WORKFLOW_RUN_SCHEDULE: workflow_schedule.cron_schedule,
+                self._WORKFLOW_RUN_TIMEZONE: workflow_schedule.timezone,
+            }
+        )
+
+    @validate_arguments
+    def find_by_type(
+        self, prefix: WorkflowPackage, max_results: int = 10
+    ) -> List[WorkflowSearchResult]:
+        """
+        Find workflows based on their type (prefix).
+        Note: Only workflows that have been run will be found.
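+
+        Example (illustrative; assumes `client` is an authenticated AtlanClient):
+
+            results = client.workflow.find_by_type(prefix=WorkflowPackage.SNOWFLAKE)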
+
+        :param prefix: name of the specific workflow to find (for example CONNECTION_DELETE)
+        :param max_results: the maximum number of results to retrieve
+        :returns: the list of workflows of the provided type, with the most-recently created first
+        :raises ValidationError: If the provided prefix is an invalid workflow package
+        :raises AtlanError: on any API communication issue
+        """
+        query = Bool(
+            filter=[
+                NestedQuery(
+                    query=Prefix(field="metadata.name.keyword", value=prefix.value),
+                    path="metadata",
+                )
+            ]
+        )
+        request = WorkflowSearchRequest(query=query, size=max_results)
+        raw_json = self._client._call_api(
+            WORKFLOW_INDEX_SEARCH,
+            request_obj=request,
+        )
+        response = WorkflowSearchResponse(**raw_json)
+        return response.hits.hits or []
+
+    def _handle_workflow_types(self, workflow):
+        if isinstance(workflow, WorkflowPackage):
+            if results := self.find_by_type(workflow):
+                detail = results[0].source
+            else:
+                raise ErrorCode.NO_PRIOR_RUN_AVAILABLE.exception_with_parameters(
+                    workflow
+                )
+        elif isinstance(workflow, WorkflowSearchResult):
+            detail = workflow.source
+        else:
+            detail = workflow
+        return detail
+
     @overload
     def rerun(
         self, workflow: WorkflowPackage, idempotent: bool = False
     ) -> WorkflowRunResponse: ...
 
@@ -158,7 +222,9 @@ def rerun(
         idempotent: bool = False,
     ) -> WorkflowRunResponse:
         """
-        Rerun the workflow immediately. Note: this must be a workflow that was previously run.
+        Rerun the workflow immediately.
+        Note: this must be a workflow that was previously run.
+
         :param workflow: The workflow to rerun.
         :param idempotent: If `True`, the workflow will only be rerun if it is not already currently running
         :returns: the details of the workflow run (if `idempotent`, will return details of the already-running workflow)
         :raises ValidationError: If the provided workflow is invalid
         :raises InvalidRequestException: If no prior runs are available for the provided workflow
         :raises AtlanError: on any API communication issue
         """
-        if isinstance(workflow, WorkflowPackage):
-            if results := self.find_by_type(workflow):
-                detail = results[0].source
-            else:
-                raise ErrorCode.NO_PRIOR_RUN_AVAILABLE.exception_with_parameters(
-                    workflow
-                )
-        elif isinstance(workflow, WorkflowSearchResult):
-            detail = workflow.source
-        else:
-            detail = workflow
-
+        detail = self._handle_workflow_types(workflow)
         if idempotent and detail.metadata.name:
             # Introducing a delay before checking the current workflow run
             # since it takes some time to start or stop
@@ -203,7 +258,9 @@ def rerun(
         return WorkflowRunResponse(**raw_json)
 
     @validate_arguments
-    def run(self, workflow: Workflow) -> WorkflowResponse:
+    def run(
+        self, workflow: Workflow, workflow_schedule: Optional[WorkflowSchedule] = None
+    ) -> WorkflowResponse:
         """
         Run the Atlan workflow with a specific configuration.
 
@@ -213,10 +270,16 @@ def run(self, workflow: Workflow) -> WorkflowResponse:
         Consider using the "rerun()" method instead to re-execute an existing workflow.
 
         :param workflow: The workflow to run.
+        :param workflow_schedule: (Optional) a WorkflowSchedule object containing:
+            - A cron schedule expression, e.g: `5 4 * * *`.
+            - The time zone for the cron schedule, e.g: `Europe/Paris`.
+
         :returns: Details of the workflow run.
         :raises ValidationError: If the provided `workflow` is invalid.
         :raises AtlanError: on any API communication issue.
         """
+        if workflow_schedule:
+            self._add_schedule(workflow, workflow_schedule)
         raw_json = self._client._call_api(
             WORKFLOW_RUN,
             request_obj=workflow,
         )
@@ -244,7 +307,8 @@ def monitor(
         self, workflow_response: WorkflowResponse, logger: Optional[Logger] = None
     ) -> Optional[AtlanWorkflowPhase]:
         """
-        Monitor the status of the workflow's run,
+        Monitor the status of the workflow's run.
+
         :param workflow_response: The workflow_response returned from running the workflow
         :param logger: the logger to log status information
             (logging.INFO for summary info. logging:DEBUG for detail info)
@@ -274,3 +338,193 @@ def monitor(
 
     def _get_run_details(self, name: str) -> Optional[WorkflowSearchResult]:
         return self._find_latest_run(workflow_name=name)
+
+    @validate_arguments
+    def get_runs(
+        self,
+        workflow_name: str,
+        workflow_phase: AtlanWorkflowPhase,
+        from_: int = 0,
+        size: int = 100,
+    ) -> Optional[List[WorkflowSearchResult]]:
+        """
+        Retrieve all runs of the given workflow, filtered by phase.
+
+        :param workflow_name: name of the workflow as displayed
+            in the UI (e.g: `atlan-snowflake-miner-1714638976`).
+        :param workflow_phase: phase of the given workflow (e.g: Succeeded, Running, Failed, etc).
+        :param from_: starting index of the search results (default: `0`).
+        :param size: maximum number of search results to return (default: `100`).
+        :returns: a list of runs of the given workflow.
+        :raises AtlanError: on any API communication issue.
+        """
+        query = Bool(
+            must=[
+                NestedQuery(
+                    query=Term(
+                        field="spec.workflowTemplateRef.name.keyword",
+                        value=workflow_name,
+                    ),
+                    path="spec",
+                )
+            ],
+            filter=[Term(field="status.phase.keyword", value=workflow_phase.value)],
+        )
+        response = self._find_runs(query)
+        return results if (results := response.hits.hits) else None
+
+    @validate_arguments
+    def stop(
+        self,
+        workflow_run_id: str,
+    ) -> WorkflowRunResponse:
+        """
+        Stop the provided, running workflow.
+
+        :param workflow_run_id: identifier of the specific workflow run
+            to stop, e.g: `atlan-snowflake-miner-1714638976-9wfxz`.
+        :returns: details of the stopped workflow run.
+        :raises AtlanError: on any API communication issue.
+        """
+        raw_json = self._client._call_api(
+            STOP_WORKFLOW_RUN.format_path({"workflow_run_id": workflow_run_id}),
+        )
+        return self._parse_response(raw_json, WorkflowRunResponse)
+
+    @validate_arguments
+    def delete(
+        self,
+        workflow_name: str,
+    ) -> None:
+        """
+        Archive (delete) the provided workflow.
+
+        :param workflow_name: name of the workflow as displayed
+            in the UI (e.g: `atlan-snowflake-miner-1714638976`).
+        :raises AtlanError: on any API communication issue.
+        """
+        self._client._call_api(
+            WORKFLOW_ARCHIVE.format_path({"workflow_name": workflow_name}),
+        )
+
+    @overload
+    def add_schedule(
+        self, workflow: WorkflowResponse, workflow_schedule: WorkflowSchedule
+    ) -> WorkflowResponse: ...
+
+    @overload
+    def add_schedule(
+        self, workflow: WorkflowPackage, workflow_schedule: WorkflowSchedule
+    ) -> WorkflowResponse: ...
+
+    @overload
+    def add_schedule(
+        self, workflow: WorkflowSearchResult, workflow_schedule: WorkflowSchedule
+    ) -> WorkflowResponse: ...
+
+    @overload
+    def add_schedule(
+        self, workflow: WorkflowSearchResultDetail, workflow_schedule: WorkflowSchedule
+    ) -> WorkflowResponse: ...
+
+    @validate_arguments
+    def add_schedule(
+        self,
+        workflow: Union[
+            WorkflowResponse,
+            WorkflowPackage,
+            WorkflowSearchResult,
+            WorkflowSearchResultDetail,
+        ],
+        workflow_schedule: WorkflowSchedule,
+    ) -> WorkflowResponse:
+        """
+        Add a schedule for an existing workflow run.
+
+        :param workflow: existing workflow run to schedule.
+        :param workflow_schedule: a WorkflowSchedule object containing:
+            - A cron schedule expression, e.g: `5 4 * * *`.
+            - The time zone for the cron schedule, e.g: `Europe/Paris`.
+
+        :returns: a scheduled workflow.
+        :raises AtlanError: on any API communication issue.
+        """
+
+        workflow_to_update = self._handle_workflow_types(workflow)
+        self._add_schedule(workflow_to_update, workflow_schedule)
+        raw_json = self._client._call_api(
+            WORKFLOW_UPDATE.format_path(
+                {"workflow_name": workflow_to_update.metadata.name}
+            ),
+            request_obj=workflow_to_update,
+        )
+        return WorkflowResponse(**raw_json)
+
+    @overload
+    def remove_schedule(self, workflow: WorkflowResponse) -> WorkflowResponse: ...
+
+    @overload
+    def remove_schedule(self, workflow: WorkflowPackage) -> WorkflowResponse: ...
+
+    @overload
+    def remove_schedule(self, workflow: WorkflowSearchResult) -> WorkflowResponse: ...
+
+    @overload
+    def remove_schedule(
+        self, workflow: WorkflowSearchResultDetail
+    ) -> WorkflowResponse: ...
+
+    @validate_arguments
+    def remove_schedule(
+        self,
+        workflow: Union[
+            WorkflowResponse,
+            WorkflowPackage,
+            WorkflowSearchResult,
+            WorkflowSearchResultDetail,
+        ],
+    ) -> WorkflowResponse:
+        """
+        Remove a scheduled run from an existing workflow run.
+
+        :param workflow: existing workflow run to remove the schedule from.
+        :returns: the updated workflow (with its schedule removed).
+        :raises AtlanError: on any API communication issue.
+        """
+        workflow_to_update = self._handle_workflow_types(workflow)
+        workflow_to_update.metadata.annotations and workflow_to_update.metadata.annotations.pop(
+            self._WORKFLOW_RUN_SCHEDULE, None
+        )
+        raw_json = self._client._call_api(
+            WORKFLOW_UPDATE.format_path(
+                {"workflow_name": workflow_to_update.metadata.name}
+            ),
+            request_obj=workflow_to_update,
+        )
+        return WorkflowResponse(**raw_json)
+
+    @validate_arguments
+    def get_all_scheduled_runs(self) -> List[WorkflowScheduleResponse]:
+        """
+        Retrieve all scheduled runs for workflows.
+
+        :returns: a list of scheduled workflow runs.
+        :raises AtlanError: on any API communication issue.
+        """
+        raw_json = self._client._call_api(GET_ALL_SCHEDULE_RUNS)
+        return self._parse_response(raw_json.get("items"), WorkflowScheduleResponse)
+
+    @validate_arguments
+    def get_scheduled_run(self, workflow_name: str) -> WorkflowScheduleResponse:
+        """
+        Retrieve the existing scheduled run for a workflow.
+
+        :param workflow_name: name of the workflow as displayed
+            in the UI (e.g: `atlan-snowflake-miner-1714638976`).
+        :returns: the scheduled workflow run.
+        :raises AtlanError: on any API communication issue.
+        """
+        raw_json = self._client._call_api(
+            GET_SCHEDULE_RUN.format_path({"workflow_name": f"{workflow_name}-cron"}),
+        )
+        return self._parse_response(raw_json, WorkflowScheduleResponse)
diff --git a/pyatlan/model/workflow.py b/pyatlan/model/workflow.py
index 2503bdc79..063917738 100644
--- a/pyatlan/model/workflow.py
+++ b/pyatlan/model/workflow.py
@@ -92,7 +92,7 @@ class WorkflowSearchResultStatus(AtlanObject):
     resources_duration: Optional[Dict[str, int]] = Field(default=None)
     startedAt: Optional[str] = Field(default=None)
     stored_templates: Any = Field(default=None)
-    storedWorkflowTemplateSpec: Any = Field(default=None)
+    stored_workflow_template_spec: Any = Field(default=None)
     synchronization: Optional[Dict[str, Any]] = Field(default=None)
 
 
@@ -148,13 +148,41 @@ def __init__(__pydantic_self__, **data: Any) -> None:
 class WorkflowResponse(AtlanObject):
     metadata: WorkflowMetadata
     spec: WorkflowSpec
-    payload: Optional[List[Any]] = Field(default=None)
+    payload: Optional[List[Any]] = Field(default_factory=list)
 
 
 class WorkflowRunResponse(WorkflowResponse):
     status: WorkflowSearchResultStatus
 
 
+class WorkflowSchedule(AtlanObject):
+    timezone: str
+    cron_schedule: str
+
+
+class WorkflowScheduleSpec(AtlanObject):
+    schedule: Optional[str] = Field(default=None)
+    timezone: Optional[str] = Field(default=None)
+    workflow_spec: Optional[WorkflowSpec] = Field(default=None)
+    concurrency_policy: Optional[str] = Field(default=None)
+    starting_deadline_seconds: Optional[int] = Field(default=None)
+    successful_jobs_history_limit: Optional[int] = Field(default=None)
+    failed_jobs_history_limit: Optional[int] = Field(default=None)
+
+
+class WorkflowScheduleStatus(AtlanObject):
+    active: Optional[Any] = Field(default=None)
+    conditions: Optional[Any] = Field(default=None)
+    last_scheduled_time: Optional[str] = Field(default=None)
+
+
+class WorkflowScheduleResponse(AtlanObject):
+    metadata: Optional[WorkflowMetadata] = Field(default=None)
+    spec: Optional[WorkflowScheduleSpec] = Field(default=None)
+    status: Optional[WorkflowScheduleStatus] = Field(default=None)
+    workflow_metadata: Optional[WorkflowMetadata] = Field(default=None)
+
+
 class WorkflowSearchRequest(AtlanObject):
     from_: int = Field(0, alias="from")
     size: int = 10
diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py
index 0515018df..114e3628b 100644
--- a/tests/integration/test_client.py
+++ b/tests/integration/test_client.py
@@ -30,7 +30,6 @@
     CertificateStatus,
     SortOrder,
     UTMTags,
-    WorkflowPackage,
 )
 from pyatlan.model.fluent_search import FluentSearch
 from pyatlan.model.search import DSL, Bool, IndexSearchRequest, SortItem, Term
@@ -548,14 +547,6 @@ def test_glossary_category_remove_announcement(
     _test_remove_announcement(client, category, AtlasGlossaryCategory, glossary.guid)
 
 
-def test_workflow_find_by_type(client: AtlanClient):
-    results = client.workflow.find_by_type(
-        prefix=WorkflowPackage.SNOWFLAKE, max_results=10
-    )
-    assert results
-    assert len(results) >= 1
-
-
 def test_audit_find_by_user(
     client: AtlanClient, current_user: UserMinimalResponse, audit_info: AuditInfo
 ):
diff --git a/tests/integration/test_workflow_client.py b/tests/integration/test_workflow_client.py
new file mode 100644
index 000000000..043621a23
--- /dev/null
+++ b/tests/integration/test_workflow_client.py
@@ -0,0 +1,250 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright 2024 Atlan Pte. Ltd.
+import time
+from typing import Generator
+
+import pytest
+
+from pyatlan.client.atlan import AtlanClient
+from pyatlan.client.workflow import WorkflowClient
+from pyatlan.model.assets import Connection
+from pyatlan.model.enums import AtlanConnectorType, AtlanWorkflowPhase, WorkflowPackage
+from pyatlan.model.packages.snowflake_miner import SnowflakeMiner
+from pyatlan.model.workflow import WorkflowResponse, WorkflowSchedule
+from tests.integration.client import TestId, delete_asset
+from tests.integration.connection_test import create_connection
+
+MODULE_NAME = TestId.make_unique("WorkflowClient")
+WORKFLOW_TEMPLATE_REF = "workflowTemplateRef"
+WORKFLOW_SCHEDULE_SCHEDULE = "45 4 * * *"
+WORKFLOW_SCHEDULE_TIMEZONE = "Asia/Kolkata"
+WORKFLOW_SCHEDULE_UPDATED_1 = "45 5 * * *"
+WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_1 = "Europe/Paris"
+WORKFLOW_SCHEDULE_UPDATED_2 = "45 6 * * *"
+WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_2 = "Europe/London"
+WORKFLOW_SCHEDULE_UPDATED_3 = "45 7 * * *"
+WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_3 = "Europe/Dublin"
+
+
+@pytest.fixture(scope="module")
+def connection(client: AtlanClient) -> Generator[Connection, None, None]:
+    connection = create_connection(
+        client=client, name=MODULE_NAME, connector_type=AtlanConnectorType.SNOWFLAKE
+    )
+    yield connection
+    delete_asset(client, guid=connection.guid, asset_type=Connection)
+
+
+def delete_workflow(client: AtlanClient, workflow_name: str) -> None:
+    client.workflow.delete(workflow_name=workflow_name)
+
+
+@pytest.fixture(scope="module")
+def workflow(
+    client: AtlanClient, connection: Connection
+) -> Generator[WorkflowResponse, None, None]:
+    assert connection and connection.qualified_name
+    miner = (
+        SnowflakeMiner(connection_qualified_name=connection.qualified_name)
+        .s3(
+            s3_bucket="test-s3-bucket",
+            s3_prefix="test-s3-prefix",
+            s3_bucket_region="test-s3-bucket-region",
+            sql_query_key="TEST_QUERY",
+            default_database_key="TEST_SNOWFLAKE",
+            default_schema_key="TEST_SCHEMA",
+            session_id_key="TEST_SESSION_ID",
+        )
+        .popularity_window(days=15)
+        .native_lineage(enabled=True)
+        .custom_config(config={"test": True, "feature": 1234})
+        .to_workflow()
+    )
+    schedule = WorkflowSchedule(
+        cron_schedule=WORKFLOW_SCHEDULE_SCHEDULE, timezone=WORKFLOW_SCHEDULE_TIMEZONE
+    )
+    workflow = client.workflow.run(miner, workflow_schedule=schedule)
+    assert workflow
+    # Adding some delay to make sure
+    # the workflow run is indexed in ES.
+    time.sleep(30)
+    yield workflow
+    assert workflow.metadata.name
+    delete_workflow(client, workflow.metadata.name)
+
+
+def test_workflow_find_by_type(client: AtlanClient):
+    results = client.workflow.find_by_type(
+        prefix=WorkflowPackage.SNOWFLAKE, max_results=10
+    )
+    assert results
+    assert len(results) >= 1
+
+
+def test_workflow_get_runs_and_stop(client: AtlanClient, workflow: WorkflowResponse):
+    # Retrieve the latest workflow run
+    assert workflow and workflow.metadata.name
+    runs = client.workflow.get_runs(
+        workflow_name=workflow.metadata.name, workflow_phase=AtlanWorkflowPhase.RUNNING
+    )
+    assert runs
+    assert len(runs) == 1
+    run = runs[0]
+    assert run and run.id
+    assert workflow.metadata.name and (workflow.metadata.name in run.id)
+
+    # Stop the running workflow
+    run_response = client.workflow.stop(workflow_run_id=run.id)
+    assert run_response
+    assert (
+        run_response.status and run_response.status.phase == AtlanWorkflowPhase.RUNNING
+    )
+    assert (
+        run_response.status.stored_workflow_template_spec
+        and run_response.status.stored_workflow_template_spec.get(
+            WORKFLOW_TEMPLATE_REF
+        ).get("name")
+        == workflow.metadata.name
+    )
+
+    # Test workflow monitoring
+    workflow_status = client.workflow.monitor(workflow_response=workflow)
+    assert workflow_status == AtlanWorkflowPhase.FAILED
+
+
+def test_workflow_get_all_scheduled_runs(
+    client: AtlanClient, workflow: WorkflowResponse
+):
+    found = False
+    runs = client.workflow.get_all_scheduled_runs()
+
+    assert workflow and workflow.metadata.name
+    scheduled_workflow_name = f"{workflow.metadata.name}-cron"
+    assert runs and len(runs) >= 1
+
+    for run in runs:
+        if run.metadata and run.metadata.name == scheduled_workflow_name:
+            found = True
+            break
+
+    if not found:
+        pytest.fail(
+            f"Unable to find scheduled run for workflow: {workflow.metadata.name}"
+        )
+
+
+def _assert_scheduled_run(client: AtlanClient, workflow: WorkflowResponse):
+    assert workflow and workflow.metadata.name
+    scheduled_workflow = client.workflow.get_scheduled_run(
+        workflow_name=workflow.metadata.name
+    )
+    scheduled_workflow_name = f"{workflow.metadata.name}-cron"
+    assert (
+        scheduled_workflow
+        and scheduled_workflow.metadata
+        and scheduled_workflow.metadata.name == scheduled_workflow_name
+    )
+
+
+def test_workflow_get_scheduled_run(client: AtlanClient, workflow: WorkflowResponse):
+    _assert_scheduled_run(client, workflow)
+
+
+def _assert_add_schedule(workflow, scheduled_workflow, schedule, timezone):
+    assert scheduled_workflow
+    assert scheduled_workflow.metadata
+    assert scheduled_workflow.metadata.name == workflow.metadata.name
+    assert scheduled_workflow.metadata.annotations
+    assert (
+        scheduled_workflow.metadata.annotations.get(
+            WorkflowClient._WORKFLOW_RUN_SCHEDULE
+        )
+        == schedule
+    )
+    assert (
+        scheduled_workflow.metadata.annotations.get(
+            WorkflowClient._WORKFLOW_RUN_TIMEZONE
+        )
+        == timezone
+    )
+
+
+def _assert_remove_schedule(response, workflow):
+    assert response
+    assert response.metadata.annotations
+    assert response.metadata.name == workflow.metadata.name
+    assert WorkflowClient._WORKFLOW_RUN_TIMEZONE in response.metadata.annotations
+    assert WorkflowClient._WORKFLOW_RUN_SCHEDULE not in response.metadata.annotations
+
+
+def test_workflow_add_remove_schedule(client: AtlanClient, workflow: WorkflowResponse):
+    schedule = WorkflowSchedule(
+        cron_schedule=WORKFLOW_SCHEDULE_UPDATED_1,
+        timezone=WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_1,
+    )
+
+    # NOTE: This method will overwrite the existing workflow run schedule
+    # Try to update the schedule, with a `WorkflowResponse` object
+    scheduled_workflow = client.workflow.add_schedule(
+        workflow=workflow, workflow_schedule=schedule
+    )
+
+    _assert_add_schedule(
+        workflow,
+        scheduled_workflow,
+        WORKFLOW_SCHEDULE_UPDATED_1,
+        WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_1,
+    )
+    # Make sure scheduled run exists
+    _assert_scheduled_run(client, workflow)
+    # Now remove the scheduled run
+    response = client.workflow.remove_schedule(workflow)
+    _assert_remove_schedule(response, workflow)
+
+    # Try to update the schedule again, with a `WorkflowSearchResult` object
+    existing_workflow = client.workflow.find_by_type(
+        prefix=WorkflowPackage.SNOWFLAKE_MINER
+    )[0]
+    assert existing_workflow
+    assert existing_workflow.source.metadata.name == workflow.metadata.name
+
+    schedule = WorkflowSchedule(
+        cron_schedule=WORKFLOW_SCHEDULE_UPDATED_2,
+        timezone=WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_2,
+    )
+    scheduled_workflow = client.workflow.add_schedule(
+        workflow=existing_workflow, workflow_schedule=schedule
+    )
+
+    _assert_add_schedule(
+        workflow,
+        scheduled_workflow,
+        WORKFLOW_SCHEDULE_UPDATED_2,
+        WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_2,
+    )
+    # Make sure scheduled run exists
+    _assert_scheduled_run(client, workflow)
+    # Now remove the scheduled run
+    response = client.workflow.remove_schedule(workflow)
+    _assert_remove_schedule(response, workflow)
+
+    # Try to update the schedule again, with a `WorkflowPackage` object
+    schedule = WorkflowSchedule(
+        cron_schedule=WORKFLOW_SCHEDULE_UPDATED_3,
+        timezone=WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_3,
+    )
+    scheduled_workflow = client.workflow.add_schedule(
+        workflow=WorkflowPackage.SNOWFLAKE_MINER, workflow_schedule=schedule
+    )
+
+    _assert_add_schedule(
+        workflow,
+        scheduled_workflow,
+        WORKFLOW_SCHEDULE_UPDATED_3,
+        WORKFLOW_SCHEDULE_TIMEZONE_UPDATED_3,
+    )
+    # Make sure scheduled run exists
+    _assert_scheduled_run(client, workflow)
+    # Now remove the scheduled run
+    response = client.workflow.remove_schedule(workflow)
+    _assert_remove_schedule(response, workflow)
diff --git a/tests/unit/constants.py b/tests/unit/constants.py
index 9c404c32f..ce239260d 100644
--- a/tests/unit/constants.py
+++ b/tests/unit/constants.py
@@ -1,4 +1,11 @@
 from pyatlan.model.assets import AtlasGlossary
+from pyatlan.model.enums import AtlanWorkflowPhase
+from pyatlan.model.workflow import (
+    WorkflowMetadata,
+    WorkflowResponse,
+    WorkflowSchedule,
+    WorkflowSpec,
+)
 
 TEST_ASSET_CLIENT_METHODS = {
     "find_personas_by_name": [
@@ -466,6 +473,71 @@
     ],
 }
 
+TEST_WORKFLOW_CLIENT_METHODS = {
+    "run": [
+        (["abc"], "value is not a valid dict"),
+        ([None], "none is not an allowed value"),
+    ],
+    "rerun": [
+        (["abc"], "value is not a valid enumeration member"),
+        ([None], "none is not an allowed value"),
+    ],
+    "update": [
+        (["abc"], "value is not a valid dict"),
+        ([None], "none is not an allowed value"),
+    ],
+    "find_by_type": [
+        (["abc"], "value is not a valid enumeration member"),
+        ([None], "none is not an allowed value"),
+    ],
+    "monitor": [
+        (["abc", "test-logger"], "value is not a valid dict"),
+        (
+            [
+                WorkflowResponse(metadata=WorkflowMetadata(), spec=WorkflowSpec()),
+                "test-logger",
+            ],
+            "instance of Logger expected",
+        ),
+        ([None, "test-logger"], "none is not an allowed value"),
+    ],
+    "get_runs": [
+        ([[123], AtlanWorkflowPhase.RUNNING, 123, 456], "str type expected"),
+        ([None, AtlanWorkflowPhase.RUNNING, 123, 456], "none is not an allowed value"),
+    ],
+    "stop": [
+        ([[123]], "str type expected"),
+        ([None], "none is not an allowed value"),
+    ],
+    "delete": [
+        ([[123]], "str type expected"),
+        ([None], "none is not an allowed value"),
+    ],
+    "add_schedule": [
+        (
+            [[123], WorkflowSchedule(timezone="atlan", cron_schedule="*")],
+            "value is not a valid dict",
+        ),
+        (
+            [[123], WorkflowSchedule(timezone="atlan", cron_schedule="*")],
+            "value is not a valid enumeration member",
+        ),
+        (
+            [None, WorkflowSchedule(timezone="atlan", cron_schedule="*")],
+            "none is not an allowed value",
+        ),
+    ],
+    "remove_schedule": [
+        ([[123]], "value is not a valid dict"),
+        ([[123]], "value is not a valid enumeration member"),
+        ([None], "none is not an allowed value"),
+    ],
+    "get_scheduled_run": [
+        ([[123]], "str type expected"),
+        ([None], "none is not an allowed value"),
+    ],
+}
+
 APPLICABLE_GLOSSARIES = "applicable_glossaries"
 APPLICABLE_CONNECTIONS = "applicable_connections"
diff --git a/tests/unit/test_workflow_client.py b/tests/unit/test_workflow_client.py
index 091e9297c..456a37ab8 100644
--- a/tests/unit/test_workflow_client.py
+++ b/tests/unit/test_workflow_client.py
@@ -5,6 +5,7 @@
 import pytest
 from pydantic.v1 import ValidationError
 
+from pyatlan.client.atlan import AtlanClient
 from pyatlan.client.common import ApiCaller
 from pyatlan.client.constants import WORKFLOW_INDEX_SEARCH
 from pyatlan.client.workflow import WorkflowClient
@@ -16,6 +17,10 @@
     WorkflowMetadata,
     WorkflowResponse,
     WorkflowRunResponse,
+    WorkflowSchedule,
+    WorkflowScheduleResponse,
+    WorkflowScheduleSpec,
+    WorkflowScheduleStatus,
     WorkflowSearchHits,
     WorkflowSearchRequest,
     WorkflowSearchResponse,
@@ -24,6 +29,13 @@
     WorkflowSearchResultStatus,
     WorkflowSpec,
 )
+from tests.unit.constants import TEST_WORKFLOW_CLIENT_METHODS
+
+
+@pytest.fixture(autouse=True)
+def set_env(monkeypatch):
+    monkeypatch.setenv("ATLAN_BASE_URL", "https://test.atlan.com")
+    monkeypatch.setenv("ATLAN_API_KEY", "test-api-key")
 
 
 @pytest.fixture()
@@ -102,7 +114,7 @@ def rerun_response_with_idempotent(
 
 
 @pytest.fixture()
-def run_response() -> WorkflowResponse:
+def workflow_response() -> WorkflowResponse:
     return WorkflowResponse(
         metadata=WorkflowMetadata(name="name", namespace="namespace"),
         spec=WorkflowSpec(),
@@ -110,6 +122,35 @@ def run_response() -> WorkflowResponse:
     )
 
 
+@pytest.fixture()
+def workflow_run_response() -> WorkflowRunResponse:
+    return WorkflowRunResponse(
+        metadata=WorkflowMetadata(name="name", namespace="namespace"),
+        spec=WorkflowSpec(),
+        payload=[PackageParameter(parameter="test-param", type="test-type", body={})],
+        status=WorkflowSearchResultStatus(phase=AtlanWorkflowPhase.RUNNING),
+    )
+
+
+@pytest.fixture()
+def schedule() -> WorkflowSchedule:
+    return WorkflowSchedule(timezone="Europe/Paris", cron_schedule="45 4 * * *")
+
+
+@pytest.fixture()
+def schedule_response() -> WorkflowScheduleResponse:
+    return WorkflowScheduleResponse(
+        spec=WorkflowScheduleSpec(),
+        metadata=WorkflowMetadata(name="name", namespace="namespace"),
+        workflow_metadata=WorkflowMetadata(name="name", namespace="namespace"),
+        status=WorkflowScheduleStatus(
+            active="test-active",
+            conditions="test-conditions",
+            last_scheduled_time="test-last-scheduled-time",
+        ),
+    )
+
+
 @pytest.fixture()
 def update_response() -> WorkflowResponse:
     return WorkflowResponse(
@@ -127,19 +168,12 @@ def test_init_when_wrong_class_raises_exception(api_caller):
         WorkflowClient(api_caller)
 
 
-@pytest.mark.parametrize(
-    "prefix, error_msg",
-    [
-        ["abc", "value is not a valid enumeration member"],
-        [None, "none is not an allowed value"],
-    ],
-)
-def test_find_by_type_when_given_wrong_parameters_raises_validation_error(
-    prefix, error_msg, client: WorkflowClient
-):
-    with pytest.raises(ValidationError) as err:
-        client.find_by_type(prefix=prefix)
-    assert error_msg in str(err.value)
+@pytest.mark.parametrize("method, params", TEST_WORKFLOW_CLIENT_METHODS.items())
+def test_workflow_client_methods_validation_error(method, params):
+    client_method = getattr(AtlanClient().workflow, method)
+    for param_values, error_msg in params:
+        with pytest.raises(ValidationError, match=error_msg):
+            client_method(*param_values)
 
 
 def test_find_by_type(client: WorkflowClient, mock_api_caller):
@@ -154,21 +188,6 @@ def test_find_by_type(client: WorkflowClient, mock_api_caller):
     )
 
 
-@pytest.mark.parametrize(
-    "workflow, error_msg",
-    [
-        ["abc", "value is not a valid enumeration member"],
-        [None, "none is not an allowed value"],
-    ],
-)
-def test_re_run_when_given_wrong_parameter_raises_validation_error(
-    workflow, error_msg, client: WorkflowClient
-):
-    with pytest.raises(ValidationError) as err:
-        client.rerun(workflow=workflow)
-    assert error_msg in str(err.value)
-
-
 def test_re_run_when_given_workflowpackage_with_no_prior_runs_raises_invalid_request_error(
     client: WorkflowClient, mock_api_caller
 ):
@@ -194,6 +213,8 @@ def test_re_run_when_given_workflowpackage(
     ]
 
     assert client.rerun(WorkflowPackage.FIVETRAN) == rerun_response
+    assert mock_api_caller._call_api.call_count == 2
+    mock_api_caller.reset_mock()
 
 
 def test_re_run_when_given_workflowsearchresultdetail(
@@ -205,6 +226,8 @@ def test_re_run_when_given_workflowsearchresultdetail(
     mock_api_caller._call_api.return_value = rerun_response.dict()
 
     assert client.rerun(workflow=search_result_detail) == rerun_response
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
 def test_re_run_when_given_workflowsearchresult(
@@ -216,6 +239,8 @@ def test_re_run_when_given_workflowsearchresult(
     mock_api_caller._call_api.return_value = rerun_response.dict()
 
     assert client.rerun(workflow=search_result) == rerun_response
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
 def test_re_run_when_given_workflowpackage_with_idempotent(
@@ -228,13 +253,14 @@ def test_re_run_when_given_workflowpackage_with_idempotent(
     mock_api_caller._call_api.side_effect = [
         search_response.dict(),
         search_response.dict(),
-        rerun_response_with_idempotent.dict(),
     ]
 
     assert (
         client.rerun(WorkflowPackage.FIVETRAN, idempotent=True)
         == rerun_response_with_idempotent
     )
+    assert mock_api_caller._call_api.call_count == 2
+    mock_api_caller.reset_mock()
 
 
 def test_re_run_when_given_workflowsearchresultdetail_with_idempotent(
@@ -245,15 +271,14 @@ def test_re_run_when_given_workflowsearchresultdetail_with_idempotent(
     search_result_detail: WorkflowSearchResultDetail,
     rerun_response_with_idempotent: WorkflowRunResponse,
 ):
-    mock_api_caller._call_api.side_effect = [
-        search_response.dict(),
-        rerun_response_with_idempotent.dict(),
-    ]
+    mock_api_caller._call_api.return_value = search_response.dict()
 
     assert (
         client.rerun(workflow=search_result_detail, idempotent=True)
         == rerun_response_with_idempotent
     )
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
 def test_re_run_when_given_workflowsearchresult_with_idempotent(
@@ -264,43 +289,22 @@ def test_re_run_when_given_workflowsearchresult_with_idempotent(
     search_result: WorkflowSearchResult,
     rerun_response_with_idempotent: WorkflowRunResponse,
 ):
-    mock_api_caller._call_api.side_effect = [
-        search_response.dict(),
-        rerun_response_with_idempotent.dict(),
-    ]
+    mock_api_caller._call_api.return_value = search_response.dict()
 
     assert (
         client.rerun(workflow=search_result, idempotent=True)
         == rerun_response_with_idempotent
     )
-
-
-@pytest.mark.parametrize(
-    "workflow_response, logger, error_msg",
-    [
-        ["abc", "test-logger", "value is not a valid dict"],
-        [
-            WorkflowResponse(metadata=WorkflowMetadata(), spec=WorkflowSpec()),
-            "test-logger",
-            "instance of Logger expected",
-        ],
-        [None, "test-logger", "none is not an allowed value"],
-    ],
-)
-def test_monitor_when_given_wrong_parameter_raises_validation_error(
-    workflow_response, logger, error_msg, client: WorkflowClient
-):
-    with pytest.raises(ValidationError) as err:
-        client.monitor(workflow_response, logger=logger)
-    assert error_msg in str(err.value)
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
 def test_run_when_given_workflow(
     client: WorkflowClient,
     mock_api_caller,
-    run_response: WorkflowResponse,
+    workflow_response: WorkflowResponse,
 ):
-    mock_api_caller._call_api.return_value = run_response.dict()
+    mock_api_caller._call_api.return_value = workflow_response.dict()
     response = client.run(
         Workflow(
             metadata=WorkflowMetadata(name="name", namespace="namespace"),
             spec=WorkflowSpec(),
@@ -310,22 +314,31 @@ def test_run_when_given_workflow(
             ],
         )  # type: ignore[call-arg]
     )
-    assert response == run_response
+    assert response == workflow_response
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
-@pytest.mark.parametrize(
-    "workflow, error_msg",
-    [
-        ["abc", "value is not a valid dict"],
-        [None, "none is not an allowed value"],
-    ],
-)
-def test_run_when_given_wrong_parameter_raises_validation_error(
-    workflow, error_msg, client: WorkflowClient
+def test_run_when_given_workflow_with_schedule(
+    client: WorkflowClient,
+    schedule: WorkflowSchedule,
+    mock_api_caller,
+    workflow_response: WorkflowResponse,
 ):
-    with pytest.raises(ValidationError) as err:
-        client.run(workflow)
-    assert error_msg in str(err.value)
+    mock_api_caller._call_api.return_value = workflow_response.dict()
+    response = client.run(
+        Workflow(
+            metadata=WorkflowMetadata(name="name", namespace="namespace"),
+            spec=WorkflowSpec(),
+            payload=[
+                PackageParameter(parameter="test-param", type="test-type", body={})
+            ],
+        ),  # type: ignore[call-arg]
+        workflow_schedule=schedule,
+    )
+    assert response == workflow_response
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
 def test_update_when_given_workflow(
@@ -337,18 +350,151 @@ def test_update_when_given_workflow(
     mock_api_caller._call_api.return_value = update_response.dict()
     response = client.update(workflow=search_result.to_workflow())
     assert response == update_response
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
 
 
-@pytest.mark.parametrize(
-    "workflow, error_msg",
-    [
-        ["abc", "value is not a valid dict"],
-        [None, "none is not an allowed value"],
-    ],
-)
-def test_update_when_given_wrong_parameter_raises_validation_error(
-    workflow, error_msg, client: WorkflowClient
+def test_workflow_get_runs(
+    client: WorkflowClient,
+    mock_api_caller,
+    search_response: WorkflowSearchResponse,
+):
-    with pytest.raises(ValidationError) as err:
-        client.update(workflow=workflow)
-    assert error_msg in str(err.value)
+    mock_api_caller._call_api.return_value = search_response.dict()
+    response = client.get_runs(
+        workflow_name="test-workflow",
+        workflow_phase=AtlanWorkflowPhase.RUNNING,
+    )
+
+    assert response == search_response.hits.hits
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
+
+
+def test_workflow_stop(
+    client: WorkflowClient,
+    mock_api_caller,
+    workflow_run_response: WorkflowRunResponse,
+):
+    mock_api_caller._call_api.return_value = workflow_run_response.dict()
+    response = client.stop(workflow_run_id="test-workflow-run-id")
+
+    assert response == WorkflowRunResponse(**workflow_run_response.dict())
+    assert mock_api_caller._call_api.call_count == 1
+    mock_api_caller.reset_mock()
+
+
+def test_workflow_delete(client: WorkflowClient, mock_api_caller):
+    mock_api_caller._call_api.return_value = None
+    assert not client.delete(workflow_name="test-workflow")
+
+
+def test_workflow_add_schedule(
+    client: WorkflowClient,
+    schedule: WorkflowSchedule,
+    workflow_response: WorkflowResponse,
+    search_response: WorkflowSearchResponse,
+    search_result: WorkflowSearchResult,
+    mock_api_caller,
+):
+    # Workflow response
+    mock_api_caller._call_api.side_effect = [
+        workflow_response.dict(),
+    ]
+    response = client.add_schedule(
+        workflow=workflow_response, workflow_schedule=schedule
+    )
+
+    assert mock_api_caller._call_api.call_count == 1
+    assert response == WorkflowResponse(**workflow_response.dict())
+    mock_api_caller.reset_mock()
+
+    # Workflow package
+    mock_api_caller._call_api.side_effect = [
+        search_response.dict(),
+        workflow_response.dict(),
+    ]
+    response = client.add_schedule(
+        workflow=WorkflowPackage.FIVETRAN, workflow_schedule=schedule
+    )
+
+    assert mock_api_caller._call_api.call_count == 2
+    assert response == WorkflowResponse(**workflow_response.dict())
+    mock_api_caller.reset_mock()
+
+    # Workflow search result
+    mock_api_caller._call_api.side_effect = [workflow_response.dict()]
+    response = client.add_schedule(workflow=search_result, workflow_schedule=schedule)
+
+    assert mock_api_caller._call_api.call_count == 1
+    assert response == WorkflowResponse(**workflow_response.dict())
+    mock_api_caller.reset_mock()
+
+
+def test_workflow_remove_schedule(
+    client: WorkflowClient,
+    workflow_response: WorkflowResponse,
+    search_response: WorkflowSearchResponse,
+    search_result: WorkflowSearchResult,
+    mock_api_caller,
+):
+    # Workflow response
+    mock_api_caller._call_api.side_effect = [
+        workflow_response.dict(),
+    ]
+    response = client.remove_schedule(workflow=workflow_response)
+
+    assert mock_api_caller._call_api.call_count == 1
+    assert response == WorkflowResponse(**workflow_response.dict())
+    mock_api_caller.reset_mock()
+
+    # Workflow package
+    mock_api_caller._call_api.side_effect = [
+        search_response.dict(),
+        workflow_response.dict(),
+    ]
+    response = client.remove_schedule(workflow=WorkflowPackage.FIVETRAN)
+
+    assert mock_api_caller._call_api.call_count == 2
+    assert response == WorkflowResponse(**workflow_response.dict())
+    mock_api_caller.reset_mock()
+
+    # Workflow search result
+    mock_api_caller._call_api.side_effect = [workflow_response.dict()]
+    response = client.remove_schedule(workflow=search_result)
+
+    assert mock_api_caller._call_api.call_count == 1
+    assert response == WorkflowResponse(**workflow_response.dict())
+    mock_api_caller.reset_mock()
+
+
+def test_workflow_get_all_scheduled_runs(
+    client: WorkflowClient,
+    workflow_response: WorkflowResponse,
+    search_response: WorkflowSearchResponse,
+    search_result: WorkflowSearchResult,
+    schedule_response: WorkflowScheduleResponse,
+    mock_api_caller,
+):
+    mock_api_caller._call_api.return_value = {"items": [schedule_response]}
+    response = client.get_all_scheduled_runs()
+
+    assert mock_api_caller._call_api.call_count == 1
+    assert response and len(response) == 1
+    assert response[0] == WorkflowScheduleResponse(**schedule_response.dict())
+    mock_api_caller.reset_mock()
+
+
+def test_workflow_get_scheduled_run(
+    client: WorkflowClient,
+    workflow_response: WorkflowResponse,
+    search_response: WorkflowSearchResponse,
+    search_result: WorkflowSearchResult,
+    schedule_response: WorkflowScheduleResponse,
+    mock_api_caller,
+):
+    mock_api_caller._call_api.return_value = schedule_response
+    response = client.get_scheduled_run(workflow_name="test-workflow")
+
+    assert mock_api_caller._call_api.call_count == 1
+    assert response == WorkflowScheduleResponse(**schedule_response.dict())
+    mock_api_caller.reset_mock()

From d4233205b1afc64c1901cb79fde897ea728515ee Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 6 May 2024 20:57:07 +0000
Subject: [PATCH 2/6] Bump jinja2 from 3.1.3 to 3.1.4 in /docs

Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](https://github.com/pallets/jinja/compare/3.1.3...3.1.4)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot]
---
 docs/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/requirements.txt b/docs/requirements.txt
index df9a5f3d7..fcca8ff99 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -2,6 +2,6 @@ sphinx~=7.2.6
 furo~=2024.1.29
 requests>=2.24
 pydantic~=2.6.1
-jinja2==3.1.3
+jinja2==3.1.4
 networkx==3.1
 tenacity==8.2.3

From 5b591e19bca3a7943d0d4252e3c64fa222e24af2 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 6 May 2024 20:55:55 +0000
Subject: [PATCH 3/6] Bump jinja2 from 3.1.3 to 3.1.4

Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](https://github.com/pallets/jinja/compare/3.1.3...3.1.4)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot]
---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index a2175d67c..cfcdee362 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 requests>=2.24
 pydantic~=2.6.1
-jinja2==3.1.3
+jinja2==3.1.4
 networkx>=3.1
 tenacity==8.2.3

From ac85f6ad6a83b5f737012bd65eafb9e774733170 Mon Sep 17 00:00:00 2001
From: Ernest Hill
Date: Wed, 8 May 2024 11:27:03 +0300
Subject: [PATCH 4/6] Initial implementation of DVX-430

---
 pyatlan/client/asset.py            | 12 ++++++++++--
 tests/integration/glossary_test.py | 14 +++++++++++---
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/pyatlan/client/asset.py b/pyatlan/client/asset.py
index f7ed782eb..11a1fc092 100644
--- a/pyatlan/client/asset.py
+++ b/pyatlan/client/asset.py
@@ -1535,7 +1535,10 @@ def find_term_by_name(
     # the issue below is fixed or when we switch to pydantic v2
     # https://github.com/pydantic/pydantic/issues/2901
     def get_hierarchy(
-        self, glossary: AtlasGlossary, attributes: Optional[List[AtlanField]] = None
+        self,
+        glossary: AtlasGlossary,
+        attributes: Optional[List[AtlanField]] = None,
+        related_attributes: Optional[List[AtlanField]] = None,
     ) -> CategoryHierarchy:
         """
         Retrieve category hierarchy in this Glossary, in a traversable form. You can traverse in either depth_first
@@ -1546,6 +1549,7 @@ def get_hierarchy(
 
         :param glossary: the glossary to retrieve the category hierarchy for
         :param attributes: attributes to retrieve for each category in the hierarchy
+        :param related_attributes: attributes to retrieve for each related asset in the hierarchy
         :returns: a traversable category hierarchy
         """
         from pyatlan.model.fluent_search import FluentSearch
@@ -1554,6 +1558,8 @@ def get_hierarchy(
             raise ErrorCode.GLOSSARY_MISSING_QUALIFIED_NAME.exception_with_parameters()
         if attributes is None:
             attributes = []
+        if related_attributes is None:
+            related_attributes = []
         top_categories: Set[str] = set()
         category_dict: Dict[str, AtlasGlossaryCategory] = {}
         search = (
@@ -1565,7 +1571,9 @@ def get_hierarchy(
             .sort(AtlasGlossaryCategory.NAME.order(SortOrder.ASCENDING))
         )
         for field in attributes:
-            search.include_on_results(field)
+            search = search.include_on_results(field)
+        for field in related_attributes:
+            search = search.include_on_relations(field)
         request = search.to_request()
         response = self.search(request)
         for category in filter(
diff --git a/tests/integration/glossary_test.py b/tests/integration/glossary_test.py
index 3000a7644..8ff9648f0 100644
--- a/tests/integration/glossary_test.py
+++ b/tests/integration/glossary_test.py
@@ -588,7 +588,6 @@ def test_find_category_by_name_qn_guid_correctly_populated(
     mid1a_term: AtlasGlossaryTerm,
     mid2a_category: AtlasGlossaryCategory,
 ):
-
     category = client.asset.find_category_by_name(
         name=mid1a_category.name,
         glossary_name=hierarchy_glossary.name,
@@ -672,7 +671,11 @@ def test_hierarchy(
     leaf2ba_category: AtlasGlossaryCategory,
 ):
     sleep(10)
-    hierarchy = client.asset.get_hierarchy(glossary=hierarchy_glossary)
+    hierarchy = client.asset.get_hierarchy(
+        glossary=hierarchy_glossary,
+        attributes=[AtlasGlossaryCategory.TERMS],
+        related_attributes=[AtlasGlossaryTerm.NAME],
+    )
 
     root_categories = hierarchy.root_categories
 
@@ -682,9 +685,14 @@ def test_hierarchy(
     assert root_categories[1].name
     assert "top" in root_categories[0].name
     assert "top" in root_categories[1].name
     assert hierarchy.get_category(top1_category.guid)
+    category_without_terms = hierarchy.get_category(top1_category.guid)
+    assert 0 == len(category_without_terms.terms)
     assert hierarchy.get_category(mid1a_category.guid)
+    category_with_term = hierarchy.get_category(mid1a_category.guid)
+    assert category_with_term.terms
+    assert 1 == len(category_with_term.terms)
+    assert f"mid1a_{TERM_NAME1}" == category_with_term.terms[0].name
     assert hierarchy.get_category(leaf1aa_category.guid)
     assert hierarchy.get_category(leaf1ab_category.guid)
     assert hierarchy.get_category(mid1b_category.guid)

From 0d531c095dc544c9ba7a3e4e1f286d33adb3b1c3 Mon Sep 17 00:00:00 2001
From: Aryamanz29
Date: Wed, 8 May 2024 14:43:42 +0530
Subject: [PATCH 5/6] Req-changes: Use `any()` method for checking the workflow run found

---
 tests/integration/test_workflow_client.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/tests/integration/test_workflow_client.py b/tests/integration/test_workflow_client.py
index 043621a23..94a99f175 100644
--- a/tests/integration/test_workflow_client.py
+++ b/tests/integration/test_workflow_client.py
@@ -115,17 +115,15 @@ def test_workflow_get_all_scheduled_runs(
     client: AtlanClient, workflow: WorkflowResponse
 ):
-    found = False
     runs = client.workflow.get_all_scheduled_runs()
 
     assert workflow and workflow.metadata.name
     scheduled_workflow_name = f"{workflow.metadata.name}-cron"
     assert runs and len(runs) >= 1
 
-    for run in runs:
-        if run.metadata and run.metadata.name == scheduled_workflow_name:
-            found = True
-            break
+    found = any(
+        run.metadata and run.metadata.name == scheduled_workflow_name for run in runs
+    )
 
     if not found:
         pytest.fail(
             f"Unable to find scheduled run for workflow: {workflow.metadata.name}"
         )

From 11ab0962561747fb42d52892a98c5ace3902eb4e Mon Sep 17 00:00:00 2001
From: Ernest Hill
Date: Wed, 8 May 2024 13:06:45 +0300
Subject: [PATCH 6/6] Fix mypy violation

---
 tests/integration/glossary_test.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/integration/glossary_test.py b/tests/integration/glossary_test.py
index 8ff9648f0..77444976b 100644
--- a/tests/integration/glossary_test.py
+++ b/tests/integration/glossary_test.py
@@ -687,6 +687,7 @@ def test_hierarchy(
     assert "top" in root_categories[1].name
     assert hierarchy.get_category(top1_category.guid)
     category_without_terms = hierarchy.get_category(top1_category.guid)
+    assert category_without_terms.terms is not None
     assert 0 == len(category_without_terms.terms)
     assert hierarchy.get_category(mid1a_category.guid)
     category_with_term = hierarchy.get_category(mid1a_category.guid)
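
Usage sketch: a minimal, end-to-end illustration of the scheduling APIs added
in PATCH 1/6. This is not part of the patches themselves; it assumes an
authenticated AtlanClient and a previously-run Snowflake miner workflow, and
the cron expression, time zone, and package choice are placeholders.

    from pyatlan.client.atlan import AtlanClient
    from pyatlan.model.enums import WorkflowPackage
    from pyatlan.model.workflow import WorkflowSchedule

    client = AtlanClient()

    # Attach (or overwrite) a schedule on an existing workflow,
    # located here by its package type.
    schedule = WorkflowSchedule(cron_schedule="45 4 * * *", timezone="Europe/Paris")
    scheduled = client.workflow.add_schedule(
        workflow=WorkflowPackage.SNOWFLAKE_MINER, workflow_schedule=schedule
    )

    # Inspect the scheduled (cron) run, then remove the schedule again.
    assert scheduled.metadata.name
    cron_run = client.workflow.get_scheduled_run(workflow_name=scheduled.metadata.name)
    client.workflow.remove_schedule(scheduled)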