Merge branch 'master' into urn-validation-3
david-leifker authored Dec 2, 2024
2 parents 739c639 + 2f20c52 commit 81ee794
Showing 25 changed files with 166 additions and 167 deletions.
@@ -8,6 +8,7 @@
import random
import signal
import subprocess
import textwrap
import time
from typing import Any, Iterator, Sequence

@@ -110,6 +111,48 @@ def _wait_for_dag_finish(
raise NotReadyError(f"DAG has not finished yet: {dag_run['state']}")


def _dump_dag_logs(airflow_instance: AirflowInstance, dag_id: str) -> None:
# Get the dag run info
res = airflow_instance.session.get(
f"{airflow_instance.airflow_url}/api/v1/dags/{dag_id}/dagRuns", timeout=5
)
res.raise_for_status()
dag_run = res.json()["dag_runs"][0]
dag_run_id = dag_run["dag_run_id"]

# List the tasks in the dag run
res = airflow_instance.session.get(
f"{airflow_instance.airflow_url}/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances",
timeout=5,
)
res.raise_for_status()
task_instances = res.json()["task_instances"]

# Sort tasks by start_date to maintain execution order
task_instances.sort(key=lambda x: x["start_date"] or "")

print(f"\nTask execution order for DAG {dag_id}:")
for task in task_instances:
task_id = task["task_id"]
state = task["state"]
try_number = task.get("try_number", 1)

task_header = f"Task: {task_id} (State: {state}; Try: {try_number})"

# Get logs for the task's latest try number
try:
res = airflow_instance.session.get(
f"{airflow_instance.airflow_url}/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}"
f"/taskInstances/{task_id}/logs/{try_number}",
params={"full_content": "true"},
timeout=5,
)
res.raise_for_status()
print(f"\n=== {task_header} ===\n{textwrap.indent(res.text, ' ')}")
except Exception as e:
print(f"Failed to fetch logs for {task_header}: {e}")


@contextlib.contextmanager
def _run_airflow(
tmp_path: pathlib.Path,
@@ -377,6 +420,11 @@ def test_airflow_plugin(
print("Sleeping for a few seconds to let the plugin finish...")
time.sleep(10)

try:
_dump_dag_logs(airflow_instance, dag_id)
except Exception as e:
print(f"Failed to dump DAG logs: {e}")

if dag_id == DAG_TO_SKIP_INGESTION:
# Verify that no MCPs were generated.
assert not os.path.exists(airflow_instance.metadata_file)
2 changes: 1 addition & 1 deletion metadata-ingestion-modules/airflow-plugin/tox.ini
@@ -59,6 +59,6 @@ commands =
[testenv:py310-airflow24]
extras = dev,integration-tests,plugin-v2,test-airflow24

[testenv:py310-airflow{26,27,28},py311-airflow{29,210}]
[testenv:py3{10,11}-airflow{26,27,28,29,210}]
extras = dev,integration-tests,plugin-v2
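The consolidated factor expression now covers the full cross product of Python and Airflow versions. A quick illustrative sketch (plain Python, not part of the change) of the environment names tox generates from py3{10,11}-airflow{26,27,28,29,210}:

from itertools import product

# Hypothetical expansion of the generative env name in tox.ini.
pythons = ["10", "11"]
airflows = ["26", "27", "28", "29", "210"]
envs = [f"py3{py}-airflow{af}" for py, af in product(pythons, airflows)]
print(envs)  # 10 environments, from py310-airflow26 through py311-airflow210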

@@ -121,7 +121,7 @@ def fqn(self) -> str:
return (
self.qualified_name
or self.id
or Urn.create_from_string(self.urn).get_entity_id()[0]
or Urn.from_string(self.urn).get_entity_id()[0]
)

@validator("urn", pre=True, always=True)
2 changes: 1 addition & 1 deletion metadata-ingestion/src/datahub/cli/put_cli.py
@@ -105,7 +105,7 @@ def platform(
"""

if name.startswith(f"urn:li:{DataPlatformUrn.ENTITY_TYPE}"):
platform_urn = DataPlatformUrn.create_from_string(name)
platform_urn = DataPlatformUrn.from_string(name)
platform_name = platform_urn.get_entity_id_as_string()
else:
platform_name = name.lower()
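For context on the create_from_string -> from_string rename applied throughout this commit, a small illustrative sketch of the renamed API; the URN value is made up and the import path mirrors the one used elsewhere in this diff:

from datahub.metadata.urns import DataPlatformUrn

# Illustrative platform URN; any "urn:li:dataPlatform:..." string parses the same way.
platform_urn = DataPlatformUrn.from_string("urn:li:dataPlatform:snowflake")
print(platform_urn.get_entity_id_as_string())  # -> "snowflake"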
@@ -45,7 +45,7 @@ def _get_owner_urn(maybe_urn: str) -> str:

def _abort_if_non_existent_urn(graph: DataHubGraph, urn: str, operation: str) -> None:
try:
parsed_urn: Urn = Urn.create_from_string(urn)
parsed_urn: Urn = Urn.from_string(urn)
entity_type = parsed_urn.get_type()
except Exception:
click.secho(f"Provided urn {urn} does not seem valid", fg="red")
43 changes: 43 additions & 0 deletions metadata-ingestion/src/datahub/emitter/mcp_patch_builder.py
@@ -1,17 +1,21 @@
import json
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union

from datahub.emitter.aspect import JSON_PATCH_CONTENT_TYPE
from datahub.emitter.serialization_helper import pre_json_transform
from datahub.metadata.schema_classes import (
AuditStampClass,
ChangeTypeClass,
EdgeClass,
GenericAspectClass,
KafkaAuditHeaderClass,
MetadataChangeProposalClass,
SystemMetadataClass,
)
from datahub.metadata.urns import Urn
from datahub.utilities.urns.urn import guess_entity_type


@@ -89,3 +93,42 @@ def build(self) -> Iterable[MetadataChangeProposalClass]:
)
for aspect_name, patches in self.patches.items()
]

@classmethod
def _mint_auditstamp(cls, message: Optional[str] = None) -> AuditStampClass:
"""
Creates an AuditStampClass instance with the current timestamp and other default values.
Args:
message: The message associated with the audit stamp (optional).
Returns:
An instance of AuditStampClass.
"""
return AuditStampClass(
time=int(time.time() * 1000.0),
actor="urn:li:corpuser:datahub",
message=message,
)

@classmethod
def _ensure_urn_type(
cls, entity_type: str, edges: List[EdgeClass], context: str
) -> None:
"""
Ensures that the destination URNs in the given edges have the specified entity type.
Args:
entity_type: The entity type to check against.
edges: A list of Edge objects.
context: The context or description of the operation.
Raises:
ValueError: If any of the destination URNs is not of the specified entity type.
"""
for e in edges:
urn = Urn.from_string(e.destinationUrn)
if not urn.entity_type == entity_type:
raise ValueError(
f"{context}: {e.destinationUrn} is not of type {entity_type}"
)
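A brief illustrative check equivalent to what _ensure_urn_type applies to each edge; the dataset URN below is a placeholder:

from datahub.metadata.schema_classes import EdgeClass
from datahub.metadata.urns import Urn

edge = EdgeClass(
    destinationUrn="urn:li:dataset:(urn:li:dataPlatform:snowflake,db.schema.table,PROD)"
)
# _ensure_urn_type runs this comparison for every edge and raises ValueError,
# naming the offending URN, when the entity type does not match.
assert Urn.from_string(edge.destinationUrn).entity_type == "dataset"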
@@ -190,7 +190,7 @@ def from_string_name(cls, ref: str) -> "BigQueryTableRef":
@classmethod
def from_urn(cls, urn: str) -> "BigQueryTableRef":
"""Raises: ValueError if urn is not a valid BigQuery table URN."""
dataset_urn = DatasetUrn.create_from_string(urn)
dataset_urn = DatasetUrn.from_string(urn)
split = dataset_urn.name.rsplit(".", 3)
if len(split) == 3:
project, dataset, table = split
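Illustratively (the URN below is made up), the name of a fully qualified BigQuery dataset URN splits into project, dataset, and table:

from datahub.metadata.urns import DatasetUrn

dataset_urn = DatasetUrn.from_string(
    "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-project.my_dataset.my_table,PROD)"
)
print(dataset_urn.name.rsplit(".", 3))  # -> ['my-project', 'my_dataset', 'my_table']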
@@ -653,7 +653,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

is_resource_row: bool = not row["subresource"]
entity_urn = row["resource"]
entity_type = Urn.create_from_string(row["resource"]).get_type()
entity_type = Urn.from_string(row["resource"]).get_type()

term_associations: List[
GlossaryTermAssociationClass
@@ -227,7 +227,7 @@ def collapse_name(name: str, collapse_urns: CollapseUrns) -> str:
def collapse_urn(urn: str, collapse_urns: CollapseUrns) -> str:
if len(collapse_urns.urns_suffix_regex) == 0:
return urn
urn_obj = DatasetUrn.create_from_string(urn)
urn_obj = DatasetUrn.from_string(urn)
name = collapse_name(name=urn_obj.get_dataset_name(), collapse_urns=collapse_urns)
data_platform_urn = urn_obj.get_data_platform_urn()
return str(
@@ -104,7 +104,7 @@ def __init__(
def delete_entity(self, urn: str) -> None:
assert self.ctx.graph

entity_urn = Urn.create_from_string(urn)
entity_urn = Urn.from_string(urn)
self.report.num_soft_deleted_entity_removed += 1
self.report.num_soft_deleted_entity_removed_by_type[entity_urn.entity_type] = (
self.report.num_soft_deleted_entity_removed_by_type.get(
@@ -57,7 +57,11 @@
convert_to_cardinality,
)
from datahub.ingestion.source.sql.sql_report import SQLSourceReport
from datahub.metadata.com.linkedin.pegasus2avro.schema import EditableSchemaMetadata
from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.com.linkedin.pegasus2avro.schema import (
EditableSchemaMetadata,
NumberType,
)
from datahub.metadata.schema_classes import (
DatasetFieldProfileClass,
DatasetProfileClass,
@@ -361,6 +365,8 @@ class _SingleDatasetProfiler(BasicDatasetProfilerBase):
platform: str
env: str

column_types: Dict[str, str] = dataclasses.field(default_factory=dict)

def _get_columns_to_profile(self) -> List[str]:
if not self.config.any_field_level_metrics_enabled():
return []
@@ -374,6 +380,7 @@ def _get_columns_to_profile(self) -> List[str]:

for col_dict in self.dataset.columns:
col = col_dict["name"]
self.column_types[col] = str(col_dict["type"])
# We expect the allow/deny patterns to specify '<table_pattern>.<column_pattern>'
if not self.config._allow_deny_patterns.allowed(
f"{self.dataset_name}.{col}"
@@ -430,6 +437,21 @@ def _get_column_type(self, column_spec: _SingleColumnSpec, column: str) -> None:
self.dataset, column
)

if column_spec.type_ == ProfilerDataType.UNKNOWN:
try:
datahub_field_type = resolve_sql_type(
self.column_types[column], self.dataset.engine.dialect.name.lower()
)
except Exception as e:
logger.debug(
f"Error resolving sql type {self.column_types[column]}: {e}"
)
datahub_field_type = None
if datahub_field_type is None:
return
if isinstance(datahub_field_type, NumberType):
column_spec.type_ = ProfilerDataType.NUMERIC

@_run_with_query_combiner
def _get_column_cardinality(
self, column_spec: _SingleColumnSpec, column: str
16 changes: 14 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py
@@ -276,7 +276,6 @@ def resolve_vertica_modified_type(type_string: str) -> Any:
return VERTICA_SQL_TYPES_MAP[type_string]


# see https://docs.snowflake.com/en/sql-reference/intro-summary-data-types.html
SNOWFLAKE_TYPES_MAP: Dict[str, Any] = {
"NUMBER": NumberType,
"DECIMAL": NumberType,
@@ -312,6 +311,18 @@ def resolve_vertica_modified_type(type_string: str) -> Any:
"GEOGRAPHY": None,
}


def resolve_snowflake_modified_type(type_string: str) -> Any:
# Match types with precision and scale, e.g., 'DECIMAL(38,0)'
match = re.match(r"([a-zA-Z_]+)\(\d+,\s*\d+\)", type_string)
if match:
modified_type_base = match.group(1) # Extract the base type
return SNOWFLAKE_TYPES_MAP.get(modified_type_base, None)

# Fallback for types without precision/scale
return SNOWFLAKE_TYPES_MAP.get(type_string, None)


# see https://github.com/googleapis/python-bigquery-sqlalchemy/blob/main/sqlalchemy_bigquery/_types.py#L32
BIGQUERY_TYPES_MAP: Dict[str, Any] = {
"STRING": StringType,
@@ -380,6 +391,7 @@ def resolve_vertica_modified_type(type_string: str) -> Any:
"row": RecordType,
"map": MapType,
"array": ArrayType,
"json": RecordType,
}

# https://docs.aws.amazon.com/athena/latest/ug/data-types.html
@@ -490,7 +502,7 @@ def resolve_sql_type(
TypeClass = resolve_vertica_modified_type(column_type)
elif platform == "snowflake":
# Snowflake types are uppercase, so we check that.
TypeClass = _merged_mapping.get(column_type.upper())
TypeClass = resolve_snowflake_modified_type(column_type.upper())

if TypeClass:
return TypeClass()
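A short illustrative sketch of what this change means for Snowflake (the type strings are examples; the import paths follow ones already used in this diff):

from datahub.ingestion.source.sql.sql_types import resolve_sql_type
from datahub.metadata.com.linkedin.pegasus2avro.schema import NumberType

# Parameterized Snowflake types such as "NUMBER(38, 0)" previously missed the
# plain "NUMBER"/"DECIMAL" entries in SNOWFLAKE_TYPES_MAP; the new helper strips
# the precision/scale suffix before the lookup.
resolved = resolve_sql_type("NUMBER(38, 0)", "snowflake")
assert isinstance(resolved, NumberType)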
@@ -74,7 +74,7 @@ def handle_end_of_stream(
logger.debug("Generating tags")

for tag_association in self.processed_tags.values():
tag_urn = TagUrn.create_from_string(tag_association.tag)
tag_urn = TagUrn.from_string(tag_association.tag)
mcps.append(
MetadataChangeProposalWrapper(
entityUrn=tag_urn.urn(),
@@ -100,7 +100,7 @@ def transform(
)
if transformed_aspect:
# for end of stream records, we modify the workunit-id
structured_urn = Urn.create_from_string(urn)
structured_urn = Urn.from_string(urn)
simple_name = "-".join(structured_urn.get_entity_id())
record_metadata = envelope.metadata.copy()
record_metadata.update(
@@ -42,7 +42,7 @@ def get_entity_name(assertion: BaseEntityAssertion) -> Tuple[str, str, str]:
if qualified_name is not None:
parts = qualified_name.split(".")
else:
urn_id = Urn.create_from_string(assertion.entity).entity_ids[1]
urn_id = Urn.from_string(assertion.entity).entity_ids[1]
parts = urn_id.split(".")
if len(parts) > 3:
parts = parts[-3:]
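For an illustrative dataset URN (made up here), entity_ids[1] is the dot-separated dataset name that the code above splits and trims to its last three parts:

from datahub.metadata.urns import Urn

urn = Urn.from_string(
    "urn:li:dataset:(urn:li:dataPlatform:snowflake,long_tail.adoption.pets,PROD)"
)
# entity_ids -> ["urn:li:dataPlatform:snowflake", "long_tail.adoption.pets", "PROD"]
print(urn.entity_ids[1].split("."))  # -> ['long_tail', 'adoption', 'pets']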