Commit 5472929
fixed review comments, adding unit test cases
ryota-cloud committed Mar 3, 2025
1 parent 0b6b7db commit 5472929
Showing 4 changed files with 186 additions and 225 deletions.
2 changes: 1 addition & 1 deletion metadata-ingestion/docs/sources/vertexai/vertexai_pre.md
@@ -3,7 +3,7 @@
#### Credentials to access GCP
1. Follow the section on credentials to access Vertex AI [GCP docs](https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to).

-#### Create a service account in the Extractor Project
+#### Create a service account and assign roles

1. Setup a ServiceAccount as per [GCP docs](https://cloud.google.com/iam/docs/creating-managing-service-accounts#iam-service-accounts-create-console)
and assign the previously created role to this service account.
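For reference, a quick way to confirm that Application Default Credentials resolve before running ingestion — a minimal sketch, assuming `google-auth` is installed and ADC is configured as described above:

```python
# Sketch: verify Application Default Credentials resolve before ingestion.
import google.auth

credentials, project_id = google.auth.default()
print(f"ADC resolved; default project: {project_id}")
```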
134 changes: 47 additions & 87 deletions metadata-ingestion/src/datahub/ingestion/source/vertexai.py
@@ -85,10 +85,11 @@ class GCPCredential(ConfigModel):

_fix_private_key_newlines = pydantic_multiline_string("private_key")

-def create_credential_temp_file(self, project_id: str) -> str:
+def create_credential_temp_file(self, project_id: Optional[str] = None) -> str:
# Adding project_id from the top level config
configs = self.dict()
-configs["project_id"] = project_id
+if project_id:
+    configs["project_id"] = project_id
with tempfile.NamedTemporaryFile(delete=False) as fp:
cred_json = json.dumps(configs, indent=4, separators=(",", ": "))
fp.write(cred_json.encode())
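The effect of making `project_id` optional is that the key is only injected into the temporary credentials file when a value is supplied. A standalone sketch of that logic (a hypothetical helper mirroring the diff, not the source's exact method):

```python
import json
import tempfile
from typing import Optional


def create_credential_temp_file(configs: dict, project_id: Optional[str] = None) -> str:
    # Only inject project_id when the caller supplies one; otherwise the
    # credential JSON is written out as-is.
    if project_id:
        configs["project_id"] = project_id
    with tempfile.NamedTemporaryFile(delete=False) as fp:
        fp.write(json.dumps(configs, indent=4, separators=(",", ": ")).encode())
        return fp.name
```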
@@ -168,14 +169,14 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
"""

# Ingest Project
-yield from self._get_project_workunits()
+yield from self._gen_project_workunits()
# Fetch and Ingest Models and Model Versions from Model Registry
-yield from self._get_ml_model_workunits()
+yield from self._get_ml_models_workunits()
# Fetch and Ingest Training Jobs
-yield from self._get_training_jobs_workunit()
+yield from self._get_training_jobs_workunits()
# TODO Fetch Experiments and Experiment Runs

-def _get_project_workunits(self) -> Iterable[MetadataWorkUnit]:
+def _gen_project_workunits(self) -> Iterable[MetadataWorkUnit]:
container_key = ProjectIdKey(
project_id=self.config.project_id, platform=self.platform
)
@@ -186,60 +187,25 @@ def _get_project_workunits(self) -> Iterable[MetadataWorkUnit]:
sub_types=["Project"],
)
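For context, a sketch of how a project container is emitted with datahub's `mcp_builder` helpers — names mirror the diff, and the API usage is an assumption based on the datahub package:

```python
from datahub.emitter.mcp_builder import ProjectIdKey, gen_containers


def gen_project_workunits(project_id: str):
    # One container per GCP project; work units are yielded lazily.
    container_key = ProjectIdKey(project_id=project_id, platform="vertexai")
    yield from gen_containers(
        container_key=container_key,
        name=project_id,
        sub_types=["Project"],
    )
```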

-def _has_training_job(self, model: Model) -> bool:
-"""
-Validate Model Has Valid Training Job
-"""
-job = model.training_job
-if not job:
-return False
-
-try:
-# when model has a ref to a training job, the field is sometimes not accessible and RuntimeError is thrown when accessed
-# if RuntimeError is not thrown, it is valid and we proceed
-name = job.name
-logger.debug(
-(
-f"can fetch training job name: {name} for model: (name:{model.display_name} id:{model.name})"
-)
-)
-return True
-except RuntimeError:
-logger.debug(
-f"cannot fetch training job name, not valid for model (name:{model.display_name} id:{model.name})"
-)
-
-return False

-def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]:
+def _get_ml_models_workunits(self) -> Iterable[MetadataWorkUnit]:
"""
Fetch List of Models in Model Registry and generate a corresponding work unit.
"""
registered_models = self.client.Model.list()
for model in registered_models:
# create work unit for Model Group (= Model in VertexAI)
-yield from self._get_ml_group_workunit(model)
+yield from self._gen_ml_group_workunits(model)
model_versions = model.versioning_registry.list_versions()
for model_version in model_versions:
-# create work unit for Training Job (if Model has reference to Training Job)
-if self._has_training_job(model):
-logger.info(
-f"Ingesting a training job for a model: {model_version.model_display_name}"
-)
-if model.training_job:
-yield from self._get_data_process_properties_workunits(
-model.training_job
-)

# create work unit for Model (= Model Version in VertexAI)
logger.info(
f"Ingesting a model (name: {model.display_name} id:{model.name})"
)
-yield from self._get_ml_model_endpoint_workunit(
+yield from self._gen_ml_model_endpoint_workunits(
model=model, model_version=model_version
)
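The VertexAI-to-DataHub mapping in this loop (a VertexAI Model becomes an ML model group; each Model Version becomes an ML model) can be sketched with datahub's URN builders — an illustration under the assumption that `mce_builder` exposes these helpers; the names are made up:

```python
import datahub.emitter.mce_builder as builder

# A VertexAI Model maps to a DataHub ML model group...
group_urn = builder.make_ml_model_group_urn(
    platform="vertexai", group_name="my_project.model.resnet", env="PROD"
)
# ...and each of its versions maps to a DataHub ML model.
model_urn = builder.make_ml_model_urn(
    platform="vertexai", model_name="my_project.model.resnet_1", env="PROD"
)
```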

-def _get_training_jobs_workunit(
+def _get_training_jobs_workunits(
"""
Fetches training jobs from Vertex AI and generates corresponding work units.
This method retrieves various types of training jobs from Vertex AI, including
@@ -263,16 +229,16 @@ def _get_training_jobs_workunit(self) -> Iterable[MetadataWorkUnit]:
for class_name in class_names:
logger.info(f"Fetching a list of {class_name}s from VertexAI server")
for job in getattr(self.client, class_name).list():
-yield from self._get_training_job_workunit(job)
+yield from self._get_training_job_workunits(job)
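The listing loop above relies on the aiplatform SDK exposing one class per job type, each with a `.list()` classmethod. A trimmed sketch of the pattern, assuming `google-cloud-aiplatform` is installed and initialized:

```python
from google.cloud import aiplatform


def list_training_jobs():
    # Each job type is a class on the aiplatform module; .list() pages
    # through all jobs of that type in the configured project/region.
    for class_name in ("CustomJob", "CustomTrainingJob", "AutoMLTabularTrainingJob"):
        for job in getattr(aiplatform, class_name).list():
            yield job
```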

-def _get_training_job_workunit(
+def _get_training_job_workunits(
self, job: VertexAiResourceNoun
) -> Iterable[MetadataWorkUnit]:
-yield from self._get_data_process_properties_workunits(job)
-yield from self._get_job_output_workunit(job)
-yield from self._get_job_input_workunit(job)
+yield from self._generate_data_process_workunits(job)
+yield from self._get_job_output_workunits(job)
+yield from self._get_job_input_workunits(job)

-def _get_ml_group_workunit(
+def _gen_ml_group_workunits(
self,
model: Model,
) -> Iterable[MetadataWorkUnit]:
@@ -308,7 +274,7 @@ def _make_ml_model_group_urn(self, model: Model) -> str:
)
return urn

-def _get_data_process_properties_workunits(
+def _generate_data_process_workunits(
self, job: VertexAiResourceNoun
) -> Iterable[MetadataWorkUnit]:
"""
@@ -382,7 +348,7 @@ def _search_model_version(
return version
return None

-def _get_job_output_workunit(
+def _get_job_output_workunits(
self, job: VertexAiResourceNoun
) -> Iterable[MetadataWorkUnit]:
"""
@@ -408,7 +374,7 @@ def _get_job_output_workunit(
f" found a training job: {job.display_name} generated "
f"a model (name:{model.display_name} id:{model_version_str})"
)
-yield from self._get_ml_model_endpoint_workunit(
+yield from self._gen_ml_model_endpoint_workunits(
model, model_version, job_urn
)

@@ -437,7 +403,7 @@ def _search_dataset(self, dataset_id: str) -> Optional[VertexAiResourceNoun]:

return self.datasets.get(dataset_id)

-def _get_dataset_workunit(
+def _get_dataset_workunits(
self, dataset_urn: str, ds: VertexAiResourceNoun
) -> Iterable[MetadataWorkUnit]:
"""
@@ -449,7 +415,9 @@ def _get_dataset_workunit(
aspects.append(
DatasetPropertiesClass(
name=self._make_vertexai_dataset_name(ds.name),
-created=TimeStampClass(time=int(ds.create_time.timestamp() * 1000)),
+created=TimeStampClass(time=int(ds.create_time.timestamp() * 1000))
+if ds.create_time
+else None,
description=f"Dataset: {ds.display_name}",
customProperties={
"displayName": ds.display_name,
Expand All @@ -467,7 +435,7 @@ def _get_dataset_workunit(
MetadataChangeProposalWrapper.construct_many(dataset_urn, aspects=aspects)
)

-def _get_job_input_workunit(
+def _get_job_input_workunits(
self, job: VertexAiResourceNoun
) -> Iterable[MetadataWorkUnit]:
"""
@@ -489,9 +457,9 @@ def _get_job_input_workunit(
)

if dataset_id:
-yield from self._get_data_process_input_workunit(job, dataset_id)
+yield from self._gen_input_dataset_workunits(job, dataset_id)

-def _get_data_process_input_workunit(
+def _gen_input_dataset_workunits(
self, job: VertexAiResourceNoun, dataset_id: str
) -> Iterable[MetadataWorkUnit]:
"""
@@ -510,7 +478,7 @@ def _get_data_process_input_workunit(

dataset = self._search_dataset(dataset_id) if dataset_id else None
if dataset:
-yield from self._get_dataset_workunit(dataset_urn=dataset_urn, ds=dataset)
+yield from self._get_dataset_workunits(dataset_urn=dataset_urn, ds=dataset)
# Create URN of Training Job
job_id = self._make_vertexai_job_name(entity_id=job.name)
mcp = MetadataChangeProposalWrapper(
@@ -522,7 +490,7 @@ def _get_data_process_input_workunit(
)
yield from auto_workunit([mcp])

-def _get_endpoint_workunit(
+def _gen_endpoint_workunits(
self, endpoint: Endpoint, model: Model, model_version: VersionInfo
) -> Iterable[MetadataWorkUnit]:
endpoint_urn = builder.make_ml_model_deployment_urn(
@@ -532,38 +500,30 @@ def _get_endpoint_workunit(
),
env=self.config.env,
)
-deployment_aspect = MLModelDeploymentPropertiesClass(
-description=model.description,
-createdAt=int(endpoint.create_time.timestamp() * 1000),
-version=VersionTagClass(versionTag=str(model_version.version_id)),
-customProperties={"displayName": endpoint.display_name},
-)
-
-mcps = []
-mcps.append(
-MetadataChangeProposalWrapper(
-entityUrn=endpoint_urn, aspect=deployment_aspect
+aspects: List[_Aspect] = list()
+aspects.append(
+MLModelDeploymentPropertiesClass(
+description=model.description,
+createdAt=int(endpoint.create_time.timestamp() * 1000),
+version=VersionTagClass(versionTag=str(model_version.version_id)),
+customProperties={"displayName": endpoint.display_name},
)
)

-mcps.append(
-MetadataChangeProposalWrapper(
-entityUrn=endpoint_urn,
-aspect=ContainerClass(
-container=self._get_project_container().as_urn(),
-),
+aspects.append(
+ContainerClass(
+container=self._get_project_container().as_urn(),
)
)

-mcps.append(
-MetadataChangeProposalWrapper(
-entityUrn=endpoint_urn, aspect=SubTypesClass(typeNames=["Endpoint"])
-)
-)
+aspects.append(SubTypesClass(typeNames=["Endpoint"]))

-yield from auto_workunit(mcps)
+yield from auto_workunit(
+MetadataChangeProposalWrapper.construct_many(endpoint_urn, aspects=aspects)
+)
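The refactor above replaces per-aspect `MetadataChangeProposalWrapper` construction with a single `construct_many` call over a list of aspects. A sketch of the pattern, assuming datahub's emitter API; the URN is illustrative:

```python
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.metadata.schema_classes import SubTypesClass

# Illustrative URN; real ones come from builder.make_ml_model_deployment_urn.
endpoint_urn = "urn:li:mlModelDeployment:(urn:li:dataPlatform:vertexai,example,PROD)"
aspects = [SubTypesClass(typeNames=["Endpoint"])]

# construct_many builds one MCP per non-None aspect against the same URN.
work_units = list(
    auto_workunit(
        MetadataChangeProposalWrapper.construct_many(endpoint_urn, aspects=aspects)
    )
)
```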

-def _get_ml_model_endpoint_workunit(
+def _gen_ml_model_endpoint_workunits(
self,
model: Model,
model_version: VersionInfo,
@@ -577,13 +537,13 @@ def _get_ml_model_endpoint_workunit(
endpoint_urn = None

if endpoint:
-yield from self._get_endpoint_workunit(endpoint, model, model_version)
+yield from self._gen_endpoint_workunits(endpoint, model, model_version)

-yield from self._get_ml_model_properties_workunit(
+yield from self._gen_ml_model_workunits(
model, model_version, training_job_urn, endpoint_urn
)

-def _get_ml_model_properties_workunit(
+def _gen_ml_model_workunits(
self,
model: Model,
model_version: VersionInfo,
@@ -17,6 +17,7 @@
PROJECT_ID = "test-project-id"
REGION = "us-west2"


@pytest.fixture
def sink_file_path(tmp_path: Path) -> str:
return str(tmp_path / "vertexai_source_mcps.json")
@@ -81,9 +82,8 @@ def test_vertexai_source_ingestion(
mock_models: List[Model],
mock_training_jobs: List[VertexAiResourceNoun],
) -> None:
-mocks = {}
with contextlib.ExitStack() as exit_stack:
-for path_to_mock in [
+for func_to_mock in [
"google.cloud.aiplatform.init",
"google.cloud.aiplatform.Model.list",
"google.cloud.aiplatform.datasets.TextDataset.list",
@@ -101,12 +101,11 @@ def test_vertexai_source_ingestion(
"google.cloud.aiplatform.AutoMLVideoTrainingJob.list",
"google.cloud.aiplatform.AutoMLForecastingTrainingJob.list",
]:
-mock = exit_stack.enter_context(patch(path_to_mock))
-if path_to_mock == "google.cloud.aiplatform.Model.list":
+mock = exit_stack.enter_context(patch(func_to_mock))
+if func_to_mock == "google.cloud.aiplatform.Model.list":
mock.return_value = mock_models
else:
mock.return_value = []
-mocks[path_to_mock] = mock

golden_file_path = (
pytestconfig.rootpath
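The test's mocking pattern, in isolation: `contextlib.ExitStack` keeps an arbitrary number of `patch` contexts active for one `with` block and unwinds them in reverse order on exit. A self-contained sketch with an abbreviated target list (assumes `google-cloud-aiplatform` is importable):

```python
import contextlib
from unittest.mock import patch

TARGETS = [
    "google.cloud.aiplatform.init",
    "google.cloud.aiplatform.Model.list",
]


def run_with_mocks(mock_models: list) -> None:
    with contextlib.ExitStack() as exit_stack:
        for target in TARGETS:
            mock = exit_stack.enter_context(patch(target))
            # Only Model.list returns data; every other patched call is a no-op.
            mock.return_value = (
                mock_models if target == "google.cloud.aiplatform.Model.list" else []
            )
        # ... exercise the ingestion source here; all patches are active ...
    # Leaving the with-block restores every patched attribute.
```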