diff --git a/metadata-ingestion/docs/sources/vertexai/vertexai_pre.md b/metadata-ingestion/docs/sources/vertexai/vertexai_pre.md
index 98047482299a49..73c9fb4454a2c4 100644
--- a/metadata-ingestion/docs/sources/vertexai/vertexai_pre.md
+++ b/metadata-ingestion/docs/sources/vertexai/vertexai_pre.md
@@ -3,7 +3,7 @@
 #### Credential to access to GCP
 
 1. Follow the section on credentials to access Vertex AI [GCP docs](https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to).
 
-#### Create a service account in the Extractor Project
+#### Create a service account and assign roles
 
 1. Setup a ServiceAccount as per [GCP docs](https://cloud.google.com/iam/docs/creating-managing-service-accounts#iam-service-accounts-create-console) and assign the previously created role to this service account.
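For reference, a minimal sketch of how such a service-account credential is typically wired into the source config. This is illustrative only: it assumes the top-level field is named credential (as in DataHub's other GCP sources), and the placeholder values and the _credentials_path behavior are taken from the extended test_vertexai_config_init further down.

# Hypothetical sketch, not the documented API: VertexAIConfig materializes the
# credential as a temp JSON file and injects the top-level project_id into it.
import json

from datahub.ingestion.source.vertexai import VertexAIConfig

config = VertexAIConfig(
    project_id="test-project",
    region="us-west2",
    credential={  # assumed field name; values are placeholders
        "private_key_id": "test-key-id",
        "private_key": "-----BEGIN PRIVATE KEY-----\ntest-private-key\n-----END PRIVATE KEY-----\n",
        "client_email": "test-email@test-project.iam.gserviceaccount.com",
        "client_id": "test-client-id",
    },
)

assert config._credentials_path is not None
with open(config._credentials_path) as fp:
    content = json.load(fp)
assert content["project_id"] == "test-project"  # injected from the top-level config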
""" registered_models = self.client.Model.list() for model in registered_models: # create work unit for Model Group (= Model in VertexAI) - yield from self._get_ml_group_workunit(model) + yield from self._gen_ml_group_workunits(model) model_versions = model.versioning_registry.list_versions() for model_version in model_versions: - # create work unit for Training Job (if Model has reference to Training Job) - if self._has_training_job(model): - logger.info( - f"Ingesting a training job for a model: {model_version.model_display_name}" - ) - if model.training_job: - yield from self._get_data_process_properties_workunits( - model.training_job - ) - # create work unit for Model (= Model Version in VertexAI) logger.info( f"Ingesting a model (name: {model.display_name} id:{model.name})" ) - yield from self._get_ml_model_endpoint_workunit( + yield from self._gen_ml_model_endpoint_workunits( model=model, model_version=model_version ) - def _get_training_jobs_workunit(self) -> Iterable[MetadataWorkUnit]: + def _get_training_jobs_workunits(self) -> Iterable[MetadataWorkUnit]: """ Fetches training jobs from Vertex AI and generates corresponding work units. This method retrieves various types of training jobs from Vertex AI, including @@ -263,16 +229,16 @@ def _get_training_jobs_workunit(self) -> Iterable[MetadataWorkUnit]: for class_name in class_names: logger.info(f"Fetching a list of {class_name}s from VertexAI server") for job in getattr(self.client, class_name).list(): - yield from self._get_training_job_workunit(job) + yield from self._get_training_job_workunits(job) - def _get_training_job_workunit( + def _get_training_job_workunits( self, job: VertexAiResourceNoun ) -> Iterable[MetadataWorkUnit]: - yield from self._get_data_process_properties_workunits(job) - yield from self._get_job_output_workunit(job) - yield from self._get_job_input_workunit(job) + yield from self._generate_data_process_workunits(job) + yield from self._get_job_output_workunits(job) + yield from self._get_job_input_workunits(job) - def _get_ml_group_workunit( + def _gen_ml_group_workunits( self, model: Model, ) -> Iterable[MetadataWorkUnit]: @@ -308,7 +274,7 @@ def _make_ml_model_group_urn(self, model: Model) -> str: ) return urn - def _get_data_process_properties_workunits( + def _generate_data_process_workunits( self, job: VertexAiResourceNoun ) -> Iterable[MetadataWorkUnit]: """ @@ -382,7 +348,7 @@ def _search_model_version( return version return None - def _get_job_output_workunit( + def _get_job_output_workunits( self, job: VertexAiResourceNoun ) -> Iterable[MetadataWorkUnit]: """ @@ -408,7 +374,7 @@ def _get_job_output_workunit( f" found a training job: {job.display_name} generated " f"a model (name:{model.display_name} id:{model_version_str})" ) - yield from self._get_ml_model_endpoint_workunit( + yield from self._gen_ml_model_endpoint_workunits( model, model_version, job_urn ) @@ -437,7 +403,7 @@ def _search_dataset(self, dataset_id: str) -> Optional[VertexAiResourceNoun]: return self.datasets.get(dataset_id) - def _get_dataset_workunit( + def _get_dataset_workunits( self, dataset_urn: str, ds: VertexAiResourceNoun ) -> Iterable[MetadataWorkUnit]: """ @@ -449,7 +415,9 @@ def _get_dataset_workunit( aspects.append( DatasetPropertiesClass( name=self._make_vertexai_dataset_name(ds.name), - created=TimeStampClass(time=int(ds.create_time.timestamp() * 1000)), + created=TimeStampClass(time=int(ds.create_time.timestamp() * 1000)) + if ds.create_time + else None, description=f"Dataset: {ds.display_name}", customProperties={ 
"displayName": ds.display_name, @@ -467,7 +435,7 @@ def _get_dataset_workunit( MetadataChangeProposalWrapper.construct_many(dataset_urn, aspects=aspects) ) - def _get_job_input_workunit( + def _get_job_input_workunits( self, job: VertexAiResourceNoun ) -> Iterable[MetadataWorkUnit]: """ @@ -489,9 +457,9 @@ def _get_job_input_workunit( ) if dataset_id: - yield from self._get_data_process_input_workunit(job, dataset_id) + yield from self._gen_input_dataset_workunits(job, dataset_id) - def _get_data_process_input_workunit( + def _gen_input_dataset_workunits( self, job: VertexAiResourceNoun, dataset_id: str ) -> Iterable[MetadataWorkUnit]: """ @@ -510,7 +478,7 @@ def _get_data_process_input_workunit( dataset = self._search_dataset(dataset_id) if dataset_id else None if dataset: - yield from self._get_dataset_workunit(dataset_urn=dataset_urn, ds=dataset) + yield from self._get_dataset_workunits(dataset_urn=dataset_urn, ds=dataset) # Create URN of Training Job job_id = self._make_vertexai_job_name(entity_id=job.name) mcp = MetadataChangeProposalWrapper( @@ -522,7 +490,7 @@ def _get_data_process_input_workunit( ) yield from auto_workunit([mcp]) - def _get_endpoint_workunit( + def _gen_endpoint_workunits( self, endpoint: Endpoint, model: Model, model_version: VersionInfo ) -> Iterable[MetadataWorkUnit]: endpoint_urn = builder.make_ml_model_deployment_urn( @@ -532,38 +500,30 @@ def _get_endpoint_workunit( ), env=self.config.env, ) - deployment_aspect = MLModelDeploymentPropertiesClass( - description=model.description, - createdAt=int(endpoint.create_time.timestamp() * 1000), - version=VersionTagClass(versionTag=str(model_version.version_id)), - customProperties={"displayName": endpoint.display_name}, - ) - mcps = [] - mcps.append( - MetadataChangeProposalWrapper( - entityUrn=endpoint_urn, aspect=deployment_aspect + aspects: List[_Aspect] = list() + aspects.append( + MLModelDeploymentPropertiesClass( + description=model.description, + createdAt=int(endpoint.create_time.timestamp() * 1000), + version=VersionTagClass(versionTag=str(model_version.version_id)), + customProperties={"displayName": endpoint.display_name}, ) ) - mcps.append( - MetadataChangeProposalWrapper( - entityUrn=endpoint_urn, - aspect=ContainerClass( - container=self._get_project_container().as_urn(), - ), + aspects.append( + ContainerClass( + container=self._get_project_container().as_urn(), ) ) - mcps.append( - MetadataChangeProposalWrapper( - entityUrn=endpoint_urn, aspect=SubTypesClass(typeNames=["Endpoint"]) - ) - ) + aspects.append(SubTypesClass(typeNames=["Endpoint"])) - yield from auto_workunit(mcps) + yield from auto_workunit( + MetadataChangeProposalWrapper.construct_many(endpoint_urn, aspects=aspects) + ) - def _get_ml_model_endpoint_workunit( + def _gen_ml_model_endpoint_workunits( self, model: Model, model_version: VersionInfo, @@ -577,13 +537,13 @@ def _get_ml_model_endpoint_workunit( endpoint_urn = None if endpoint: - yield from self._get_endpoint_workunit(endpoint, model, model_version) + yield from self._gen_endpoint_workunits(endpoint, model, model_version) - yield from self._get_ml_model_properties_workunit( + yield from self._gen_ml_model_workunits( model, model_version, training_job_urn, endpoint_urn ) - def _get_ml_model_properties_workunit( + def _gen_ml_model_workunits( self, model: Model, model_version: VersionInfo, diff --git a/metadata-ingestion/tests/integration/vertexai/test_vertexai_source.py b/metadata-ingestion/tests/integration/vertexai/test_vertexai.py similarity index 95% rename from 
diff --git a/metadata-ingestion/tests/unit/test_vertexai_source.py b/metadata-ingestion/tests/unit/test_vertexai_source.py
index 4891850ce69b28..70692a7ac4860e 100644
--- a/metadata-ingestion/tests/unit/test_vertexai_source.py
+++ b/metadata-ingestion/tests/unit/test_vertexai_source.py
@@ -1,3 +1,5 @@
+import contextlib
+import json
 from datetime import datetime
 from typing import List
 from unittest.mock import MagicMock, patch
@@ -29,6 +31,9 @@
     SubTypesClass,
 )
 
+PROJECT_ID = "acryl-poc"
+REGION = "us-west2"
+
 
 @pytest.fixture
 def mock_model() -> Model:
@@ -79,13 +84,13 @@ def mock_training_job() -> VertexAiResourceNoun:
 
 @pytest.fixture
 def mock_dataset() -> VertexAiResourceNoun:
-    mock_training_job = MagicMock(spec=VertexAiResourceNoun)
-    mock_training_job.name = "mock_dataset"
-    mock_training_job.create_time = timestamp_pb2.Timestamp().GetCurrentTime()
-    mock_training_job.update_time = timestamp_pb2.Timestamp().GetCurrentTime()
-    mock_training_job.display_name = "mock_dataset_display_name"
-    mock_training_job.description = "mock_dataset_description"
-    return mock_training_job
+    mock_dataset = MagicMock(spec=VertexAiResourceNoun)
+    mock_dataset.name = "mock_dataset"
+    mock_dataset.create_time = timestamp_pb2.Timestamp().GetCurrentTime()
+    mock_dataset.update_time = timestamp_pb2.Timestamp().GetCurrentTime()
+    mock_dataset.display_name = "mock_dataset_display_name"
+    mock_dataset.description = "mock_dataset_description"
+    return mock_dataset
 
 
 @pytest.fixture
@@ -109,26 +114,10 @@ def mock_endpoint() -> Endpoint:
 
 
 @pytest.fixture
-def project_id() -> str:
-    """
-    Replace with your GCP Project ID
-    """
-    return "acryl-poc"
-
-
-@pytest.fixture
-def region() -> str:
-    """
-    Replace with your GCP region s
-    """
-    return "us-west2"
-
-
-@pytest.fixture
-def source(project_id: str, region: str) -> VertexAISource:
+def source() -> VertexAISource:
     return VertexAISource(
         ctx=PipelineContext(run_id="vertexai-source-test"),
-        config=VertexAIConfig(project_id=project_id, region=region),
+        config=VertexAIConfig(project_id=PROJECT_ID, region=REGION),
     )
 
 
@@ -185,7 +174,7 @@ def test_get_ml_model_workunits(
     assert hasattr(mock_list, "return_value")  # this check needed to go ground lint
     mock_list.return_value = mock_models
 
-    wcs = [wc for wc in source._get_ml_model_workunits()]
+    wcs = [wc for wc in source._get_ml_models_workunits()]
     assert len(wcs) == 2
 
     # aspect is MLModelGroupPropertiesClass
@@ -209,9 +198,7 @@
 def test_get_ml_model_properties_workunit(
     source: VertexAISource, mock_model: Model, model_version: VersionInfo
 ) -> None:
-    wu = [
-        wu for wu in source._get_ml_model_properties_workunit(mock_model, model_version)
-    ]
+    wu = [wu for wu in source._gen_ml_model_workunits(mock_model, model_version)]
     assert len(wu) == 1
     assert hasattr(wu[0].metadata, "aspect")
     aspect = wu[0].metadata.aspect
@@ -231,7 +218,7 @@ def test_get_endpoint_workunit(
     mock_model: Model,
     model_version: VersionInfo,
 ) -> None:
-    for wu in source._get_endpoint_workunit(mock_endpoint, mock_model, model_version):
+    for wu in source._gen_endpoint_workunits(mock_endpoint, mock_model, model_version):
         assert hasattr(wu.metadata, "aspect")
         aspect = wu.metadata.aspect
         if isinstance(aspect, MLModelDeploymentPropertiesClass):
@@ -240,12 +227,17 @@
             assert aspect.customProperties == {
                 "displayName": mock_endpoint.display_name
             }
             assert aspect.createdAt == int(mock_endpoint.create_time.timestamp() * 1000)
+        elif isinstance(aspect, ContainerClass):
+            assert aspect.container == source._get_project_container().as_urn()
+
+        elif isinstance(aspect, SubTypesClass):
+            assert aspect.typeNames == ["Endpoint"]
 
 
 def test_get_data_process_properties_workunit(
     source: VertexAISource, mock_training_job: VertexAiResourceNoun
 ) -> None:
-    for wu in source._get_data_process_properties_workunits(mock_training_job):
+    for wu in source._generate_data_process_workunits(mock_training_job):
         assert hasattr(wu.metadata, "aspect")
         aspect = wu.metadata.aspect
         if isinstance(aspect, DataProcessInstancePropertiesClass):
@@ -256,51 +248,37 @@
             assert aspect.externalUrl == source._make_job_external_url(
                 mock_training_job
             )
+            assert (
+                aspect.customProperties["displayName"] == mock_training_job.display_name
+            )
         elif isinstance(aspect, SubTypesClass):
             assert "Training Job" in aspect.typeNames
 
 
-@patch("google.cloud.aiplatform.datasets.TextDataset.list")
-@patch("google.cloud.aiplatform.datasets.TabularDataset.list")
-@patch("google.cloud.aiplatform.datasets.ImageDataset.list")
-@patch("google.cloud.aiplatform.datasets.TimeSeriesDataset.list")
-@patch("google.cloud.aiplatform.datasets.VideoDataset.list")
 def test_get_data_process_input_workunit(
-    mock_text_list: List[VertexAiResourceNoun],
-    mock_tabular_list: List[VertexAiResourceNoun],
-    mock_image_list: List[VertexAiResourceNoun],
-    mock_time_series_list: List[VertexAiResourceNoun],
-    mock_video_list: List[VertexAiResourceNoun],
     source: VertexAISource,
     mock_training_job: VertexAiResourceNoun,
 ) -> None:
-    # Mocking all the dataset list
-    assert hasattr(
-        mock_text_list, "return_value"
-    )  # this check needed to go ground lint
-    mock_text_list.return_value = []
-    assert hasattr(
-        mock_tabular_list, "return_value"
-    )  # this check needed to go ground lint
-    mock_tabular_list.return_value = []
-    assert hasattr(
-        mock_video_list, "return_value"
-    )  # this check needed to go ground lint
-    mock_video_list.return_value = []
-    assert hasattr(
-        mock_time_series_list, "return_value"
-    )  # this check needed to go ground lint
-    mock_time_series_list.return_value = []
-    assert hasattr(
-        mock_image_list, "return_value"
-    )  # this check needed to go ground lint
-    mock_image_list.return_value = []
-
-    for wu in source._get_data_process_input_workunit(mock_training_job, "12345"):
-        assert hasattr(wu.metadata, "aspect")
-        aspect = wu.metadata.aspect
-        assert isinstance(aspect, DataProcessInstanceInputClass)
-        assert len(aspect.inputs) == 1
+    with contextlib.ExitStack() as exit_stack:
+        for func_to_mock in [
+            "google.cloud.aiplatform.init",
+            "google.cloud.aiplatform.datasets.TextDataset.list",
+            "google.cloud.aiplatform.datasets.TabularDataset.list",
+            "google.cloud.aiplatform.datasets.ImageDataset.list",
+            "google.cloud.aiplatform.datasets.TimeSeriesDataset.list",
+            "google.cloud.aiplatform.datasets.VideoDataset.list",
+        ]:
+            mock = exit_stack.enter_context(patch(func_to_mock))
+            if func_to_mock == "google.cloud.aiplatform.CustomJob.list":
+                mock.return_value = [mock_training_job]
+            else:
+                mock.return_value = []
+
+        for wu in source._gen_input_dataset_workunits(mock_training_job, "12345"):
+            assert hasattr(wu.metadata, "aspect")
+            aspect = wu.metadata.aspect
+            assert isinstance(aspect, DataProcessInstanceInputClass)
+            assert len(aspect.inputs) == 1
 
 
 def test_vertexai_config_init():
@@ -345,76 +323,100 @@
         == "https://www.googleapis.com/oauth2/v1/certs"
     )
 
+    assert config._credentials_path is not None
+    with open(config._credentials_path, "r") as file:
+        content = json.loads(file.read())
+        assert content["project_id"] == "test-project"
+        assert content["private_key_id"] == "test-key-id"
+        assert content["private_key_id"] == "test-key-id"
+        assert (
+            content["private_key"]
+            == "-----BEGIN PRIVATE KEY-----\ntest-private-key\n-----END PRIVATE KEY-----\n"
+        )
+        assert (
+            content["client_email"] == "test-email@test-project.iam.gserviceaccount.com"
+        )
+        assert content["client_id"] == "test-client-id"
+        assert content["auth_uri"] == "https://accounts.google.com/o/oauth2/auth"
+        assert content["token_uri"] == "https://oauth2.googleapis.com/token"
+        assert (
+            content["auth_provider_x509_cert_url"]
+            == "https://www.googleapis.com/oauth2/v1/certs"
+        )
+
 
-@patch("google.cloud.aiplatform.CustomJob.list")
-@patch("google.cloud.aiplatform.CustomTrainingJob.list")
-@patch("google.cloud.aiplatform.CustomContainerTrainingJob.list")
-@patch("google.cloud.aiplatform.CustomPythonPackageTrainingJob.list")
-@patch("google.cloud.aiplatform.AutoMLTabularTrainingJob.list")
-@patch("google.cloud.aiplatform.AutoMLTextTrainingJob.list")
-@patch("google.cloud.aiplatform.AutoMLImageTrainingJob.list")
-@patch("google.cloud.aiplatform.AutoMLVideoTrainingJob.list")
-@patch("google.cloud.aiplatform.AutoMLForecastingTrainingJob.list")
 def test_get_training_jobs_workunit(
-    mock_automl_forecasting_job_list: List[VertexAiResourceNoun],
-    mock_automl_video_job_list: List[VertexAiResourceNoun],
-    mock_automl_image_list: List[VertexAiResourceNoun],
-    mock_automl_text_job_list: List[VertexAiResourceNoun],
-    mock_automl_tabular_job_list: List[VertexAiResourceNoun],
-    mock_custom_python_job_list: List[VertexAiResourceNoun],
-    mock_custom_container_job_list: List[VertexAiResourceNoun],
-    mock_custom_training_job_list: List[VertexAiResourceNoun],
-    mock_custom_job_list: List[VertexAiResourceNoun],
     source: VertexAISource,
     mock_training_job: VertexAiResourceNoun,
     mock_training_automl_job: AutoMLTabularTrainingJob,
 ) -> None:
-    assert hasattr(mock_custom_job_list, "return_value")
-    mock_custom_job_list.return_value = [mock_training_job]
-    assert hasattr(mock_custom_training_job_list, "return_value")
-    mock_custom_training_job_list.return_value = []
-    assert hasattr(mock_custom_container_job_list, "return_value")
-    mock_custom_container_job_list.return_value = []
-    assert hasattr(mock_custom_python_job_list, "return_value")
-    mock_custom_python_job_list.return_value = []
-    assert hasattr(mock_automl_tabular_job_list, "return_value")
-    mock_automl_tabular_job_list.return_value = [mock_training_automl_job]
-    assert hasattr(mock_automl_text_job_list, "return_value")
-    mock_automl_text_job_list.return_value = []
-    assert hasattr(mock_automl_image_list, "return_value")
-    mock_automl_image_list.return_value = []
-    assert hasattr(mock_automl_video_job_list, "return_value")
-    mock_automl_video_job_list.return_value = []
-    assert hasattr(mock_automl_forecasting_job_list, "return_value")
-    mock_automl_forecasting_job_list.return_value = []
-
-    container_key = ProjectIdKey(
-        project_id=source.config.project_id, platform=source.platform
+    with contextlib.ExitStack() as exit_stack:
+        for func_to_mock in [
+            "google.cloud.aiplatform.init",
+            "google.cloud.aiplatform.CustomJob.list",
+            "google.cloud.aiplatform.CustomTrainingJob.list",
+            "google.cloud.aiplatform.CustomContainerTrainingJob.list",
+            "google.cloud.aiplatform.CustomPythonPackageTrainingJob.list",
+            "google.cloud.aiplatform.AutoMLTabularTrainingJob.list",
+            "google.cloud.aiplatform.AutoMLImageTrainingJob.list",
+            "google.cloud.aiplatform.AutoMLTextTrainingJob.list",
+            "google.cloud.aiplatform.AutoMLVideoTrainingJob.list",
+            "google.cloud.aiplatform.AutoMLForecastingTrainingJob.list",
+        ]:
+            mock = exit_stack.enter_context(patch(func_to_mock))
+            if func_to_mock == "google.cloud.aiplatform.CustomJob.list":
+                mock.return_value = [mock_training_job]
+            else:
+                mock.return_value = []
+
+        container_key = ProjectIdKey(
+            project_id=source.config.project_id, platform=source.platform
+        )
+
+        """
+        Test the retrieval of training jobs work units from Vertex AI.
+        This function mocks customJob and AutoMLTabularTrainingJob,
+        and verifies the properties of the work units
+        """
+        for wc in source._get_training_jobs_workunits():
+            assert hasattr(wc.metadata, "aspect")
+            aspect = wc.metadata.aspect
+            if isinstance(aspect, DataProcessInstancePropertiesClass):
+                assert (
+                    aspect.name
+                    == f"{source.config.project_id}.job.{mock_training_job.name}"
+                    or f"{source.config.project_id}.job.{mock_training_automl_job.name}"
+                )
+                assert (
+                    aspect.customProperties["displayName"]
+                    == mock_training_job.display_name
+                    or mock_training_automl_job.display_name
+                )
+            if isinstance(aspect, SubTypesClass):
+                assert aspect.typeNames == ["Training Job"]
+
+            if isinstance(aspect, ContainerClass):
+                assert aspect.container == container_key.as_urn()
+
+
+def test_get_dataset_workunit(
+    mock_dataset: VertexAiResourceNoun, source: VertexAISource
+) -> None:
+    dataset_urn = builder.make_dataset_urn(
+        platform=source.platform,
+        name=mock_dataset.name,
+        env=source.config.env,
     )
-
-    """
-    Test the retrieval of training jobs work units from Vertex AI.
-    This function mocks customJob and AutoMLTabularTrainingJob,
-    and verifies the properties of the work units
-    """
-    for wc in source._get_training_jobs_workunit():
-        assert hasattr(wc.metadata, "aspect")
-        aspect = wc.metadata.aspect
+    for wu in source._get_dataset_workunits(dataset_urn=dataset_urn, ds=mock_dataset):
+        assert hasattr(wu.metadata, "aspect")
+        aspect = wu.metadata.aspect
         if isinstance(aspect, DataProcessInstancePropertiesClass):
-            assert (
-                aspect.name
-                == f"{source.config.project_id}.job.{mock_training_job.name}"
-                or f"{source.config.project_id}.job.{mock_training_automl_job.name}"
-            )
-            assert (
-                aspect.customProperties["displayName"] == mock_training_job.display_name
-                or mock_training_automl_job.display_name
-            )
-        if isinstance(aspect, SubTypesClass):
-            assert aspect.typeNames == ["Training Job"]
-
-        if isinstance(aspect, ContainerClass):
-            assert aspect.container == container_key.as_urn()
+            assert aspect.name == f"{source._make_vertexai_job_name(mock_dataset.name)}"
+            assert aspect.customProperties["displayName"] == mock_dataset.display_name
+        elif isinstance(aspect, ContainerClass):
+            assert aspect.container == source._get_project_container().as_urn()
+        elif isinstance(aspect, SubTypesClass):
+            assert aspect.typeNames == ["Dataset"]
 
 
 def test_make_model_external_url(mock_model: Model, source: VertexAISource) -> None:
@@ -442,7 +444,7 @@ def test_real_model_workunit(
     Disabled as default
     Use real model registered in the Vertex AI Model Registry
     """
-    for wu in source._get_ml_model_properties_workunit(
+    for wu in source._gen_ml_model_workunits(
         model=real_model, model_version=model_version
     ):
         assert hasattr(wu.metadata, "aspect")
@@ -459,7 +461,7 @@
 def test_real_get_data_process_properties(
     source: VertexAISource, real_autoML_tabular_job: _TrainingJob
 ) -> None:
-    for wu in source._get_data_process_properties_workunits(real_autoML_tabular_job):
+    for wu in source._generate_data_process_workunits(real_autoML_tabular_job):
         assert hasattr(wu.metadata, "aspect")
         aspect = wu.metadata.aspect
         if isinstance(aspect, DataProcessInstancePropertiesClass):
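Finally, a hypothetical end-to-end sketch of the renamed dataset helper, mirroring the unit-test fixtures above; the MagicMock dataset and run_id are made up for illustration, and a plain datetime stands in for the protobuf timestamp so that create_time.timestamp() behaves the same.

from datetime import datetime
from unittest.mock import MagicMock

import datahub.emitter.mce_builder as builder
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.source.vertexai import VertexAIConfig, VertexAISource
from google.cloud.aiplatform.base import VertexAiResourceNoun

source = VertexAISource(
    ctx=PipelineContext(run_id="vertexai-example"),
    config=VertexAIConfig(project_id="acryl-poc", region="us-west2"),
)

# Fake dataset shaped like the mock_dataset fixture.
ds = MagicMock(spec=VertexAiResourceNoun)
ds.name = "mock_dataset"
ds.display_name = "mock_dataset_display_name"
ds.description = "mock_dataset_description"
ds.create_time = datetime.now()

dataset_urn = builder.make_dataset_urn(
    platform=source.platform, name=ds.name, env=source.config.env
)
for wu in source._get_dataset_workunits(dataset_urn=dataset_urn, ds=ds):
    print(wu.metadata.aspect)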