lint fix
chakru-r committed Mar 3, 2025
1 parent a1db703 commit 7c5a222
Showing 2 changed files with 80 additions and 41 deletions.
112 changes: 74 additions & 38 deletions metadata-ingestion/src/datahub/api/entities/dataset/dataset.py
@@ -42,7 +42,12 @@
TagAssociationClass,
UpstreamClass,
)
from datahub.metadata.urns import DataPlatformUrn, StructuredPropertyUrn, TagUrn, GlossaryTermUrn
from datahub.metadata.urns import (
DataPlatformUrn,
GlossaryTermUrn,
StructuredPropertyUrn,
TagUrn,
)
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.utilities.urns.dataset_urn import DatasetUrn

@@ -69,8 +74,8 @@ class SchemaFieldSpecification(StrictModel):
description: Union[None, str] = None
doc: Union[None, str] = None # doc is an alias for description
label: Optional[str] = None
created: Optional[dict] = None
lastModified: Optional[dict] = None
created: Optional[AuditStampClass] = None
lastModified: Optional[AuditStampClass] = None
recursive: bool = False
globalTags: Optional[List[str]] = None
glossaryTerms: Optional[List[str]] = None
@@ -137,9 +142,11 @@ def from_schema_field(
globalTags=[TagUrn(tag.tag).name for tag in schema_field.globalTags.tags]
if schema_field.globalTags
else None,
glossaryTerms=[GlossaryTermUrn(term.urn).name for term in schema_field.glossaryTerms.terms]
if
schema_field.glossaryTerms
glossaryTerms=[
GlossaryTermUrn(term.urn).name
for term in schema_field.glossaryTerms.terms
]
if schema_field.glossaryTerms
else None,
isPartitioningKey=schema_field.isPartitioningKey,
jsonProps=(
@@ -154,7 +161,7 @@ def either_id_or_urn_must_be_filled_out(cls, v, values):
return v

@root_validator(pre=True)
def sync_description_and_doc(cls, values) -> dict:
def sync_description_and_doc(cls, values: Dict) -> dict:
"""Synchronize doc and description fields if one is provided but not the other."""
description = values.get("description")
doc = values.get("doc")
@@ -178,7 +185,7 @@ def get_datahub_type(self) -> models.SchemaFieldDataTypeClass:
"bytes",
"fixed",
]
type = self.type.lower()
type: Optional[str] = self.type.lower() if self.type is not None else None
if type not in set(get_args(PrimitiveType)):
raise ValueError(f"Type {self.type} is not a valid primitive type")

Expand Down Expand Up @@ -237,7 +244,7 @@ def dict(self, **kwargs):

self.simplify_structured_properties()

return super().dict(exclude=exclude, exclude_defaults=True, **kwargs)
return super().dict(exclude=exclude, exclude_defaults=True, **kwargs) # type: ignore[misc]

def model_dump(self, **kwargs):
"""Custom model_dump to handle YAML serialization properly."""
@@ -259,7 +266,7 @@ def model_dump(self, **kwargs):
if field_urn.field_path == self.id:
exclude.add("urn")

return super().model_dump(exclude=exclude, exclude_defaults=True, **kwargs)
return super().model_dump(exclude=exclude, exclude_defaults=True, **kwargs) # type: ignore[misc]


class SchemaSpecification(BaseModel):
@@ -365,8 +372,9 @@ def _mint_owner(self, owner: Union[str, Ownership]) -> OwnerClass:
@staticmethod
def get_patch_builder(urn: str) -> DatasetPatchBuilder:
return DatasetPatchBuilder(urn)

def patch_builder(self) -> DatasetPatchBuilder:
assert self.urn is not None # Validator fills this, assert to tell mypy.
return DatasetPatchBuilder(self.urn)

@classmethod
@@ -384,7 +392,7 @@ def from_yaml(cls, file: str) -> Iterable["Dataset"]:
def entity_references(self) -> List[str]:
urn_prefix = f"{StructuredPropertyUrn.URN_PREFIX}:{StructuredPropertyUrn.LI_DOMAIN}:{StructuredPropertyUrn.ENTITY_TYPE}"
references = []
if self.schema_metadata:
if self.schema_metadata and self.schema_metadata.fields:
for field in self.schema_metadata.fields:
if field.structured_properties:
references.extend(
@@ -439,6 +447,7 @@ def generate_mcp(
raise ValueError(
"Either all fields must have type information or none of them should"
)

if all_fields_type_info_present:
update_technical_schema = True
else:
@@ -453,20 +462,39 @@
hash="",
fields=[
SchemaFieldClass(
fieldPath=field.id,
fieldPath=field.id, # type: ignore[arg-type]
type=field.get_datahub_type(),
nativeDataType=field.nativeDataType or field.type,
nativeDataType=field.nativeDataType or field.type, # type: ignore[arg-type]
nullable=field.nullable,
description=field.description,
label=field.label,
created=field.created,
lastModified=field.lastModified,
recursive=field.recursive,
globalTags=field.globalTags,
glossaryTerms=field.glossaryTerms,
globalTags=GlobalTagsClass(
tags=[
TagAssociationClass(tag=make_tag_urn(tag))
for tag in field.globalTags
]
)
if field.globalTags is not None
else None,
glossaryTerms=GlossaryTermsClass(
terms=[
GlossaryTermAssociationClass(
urn=make_term_urn(term)
)
for term in field.glossaryTerms
],
auditStamp=self._mint_auditstamp("yaml"),
)
if field.glossaryTerms is not None
else None,
isPartOfKey=field.isPartOfKey,
isPartitioningKey=field.isPartitioningKey,
jsonProps=field.jsonProps,
jsonProps=json.dumps(field.jsonProps)
if field.jsonProps is not None
else None,
)
for field in self.schema_metadata.fields
],
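Below is a minimal standalone sketch of the conversion this reworked generate_mcp hunk performs: plain tag and term name strings from the field spec are wrapped into typed GlobalTagsClass / GlossaryTermsClass aspects. The import paths follow common DataHub usage and the sample names ("pii", "Classification.Sensitive") are illustrative assumptions, not taken from the diff.

# Sketch only: wraps name strings into typed aspects, mirroring the new
# generate_mcp logic. Import paths assumed from typical DataHub usage.
from datahub.emitter.mce_builder import make_tag_urn, make_term_urn
from datahub.metadata.schema_classes import (
    AuditStampClass,
    GlobalTagsClass,
    GlossaryTermAssociationClass,
    GlossaryTermsClass,
    TagAssociationClass,
)

tag_names = ["pii", "gold"]                # illustrative inputs
term_names = ["Classification.Sensitive"]  # illustrative inputs

global_tags = GlobalTagsClass(
    tags=[TagAssociationClass(tag=make_tag_urn(name)) for name in tag_names]
)
glossary_terms = GlossaryTermsClass(
    terms=[
        GlossaryTermAssociationClass(urn=make_term_urn(name))
        for name in term_names
    ],
    auditStamp=AuditStampClass(time=0, actor="urn:li:corpuser:datahub"),
)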
@@ -732,7 +760,8 @@ def from_datahub(cls, graph: DataHubGraph, urn: str) -> "Dataset":
else:
structured_properties_map[sp.propertyUrn] = sp.values

from datahub.metadata.urns import TagUrn, GlossaryTermUrn
from datahub.metadata.urns import GlossaryTermUrn, TagUrn

return Dataset( # type: ignore[arg-type]
id=dataset_urn.name,
platform=platform_urn.platform_name,
@@ -750,7 +779,9 @@ def from_datahub(cls, graph: DataHubGraph, urn: str) -> "Dataset":
schema=Dataset._schema_from_schema_metadata(graph, urn),
tags=[TagUrn(tag.tag).name for tag in tags.tags] if tags else None,
glossary_terms=(
[GlossaryTermUrn(term.urn).name for term in glossary_terms.terms] if glossary_terms else None
[GlossaryTermUrn(term.urn).name for term in glossary_terms.terms]
if glossary_terms
else None
),
owners=yaml_owners,
properties=(
@@ -785,11 +816,12 @@ def dict(self, **kwargs):
if "fields" in schema_data and isinstance(schema_data["fields"], list):
# Process each field using its custom dict method
processed_fields = []
for field in self.schema_metadata.fields:
if field:
# Use dict method for Pydantic v1
processed_field = field.dict(**kwargs)
processed_fields.append(processed_field)
if self.schema_metadata and self.schema_metadata.fields:
for field in self.schema_metadata.fields:
if field:
# Use dict method for Pydantic v1
processed_field = field.dict(**kwargs)
processed_fields.append(processed_field)

# Replace the fields in the result with the processed ones
schema_data["fields"] = processed_fields
@@ -812,10 +844,10 @@ def model_dump(self, **kwargs):
# Check which method exists in the parent class
if hasattr(super(), "model_dump"):
# For Pydantic v2
result = super().model_dump(exclude=exclude, **kwargs)
result = super().model_dump(exclude=exclude, **kwargs) # type: ignore[misc]
elif hasattr(super(), "dict"):
# For Pydantic v1
result = super().dict(exclude=exclude, **kwargs)
result = super().dict(exclude=exclude, **kwargs) # type: ignore[misc]
else:
# Fallback to __dict__ if neither method exists
result = {k: v for k, v in self.__dict__.items() if k not in exclude}
@@ -828,16 +860,19 @@ def model_dump(self, **kwargs):
if "fields" in schema_data and isinstance(schema_data["fields"], list):
# Process each field using its custom model_dump
processed_fields = []
for field in self.schema_metadata.fields:
if field:
# Call the appropriate serialization method on each field
if hasattr(field, "model_dump"):
processed_field = field.model_dump(**kwargs)
elif hasattr(field, "dict"):
processed_field = field.dict(**kwargs)
else:
processed_field = {k: v for k, v in field.__dict__.items()}
processed_fields.append(processed_field)
if self.schema_metadata and self.schema_metadata.fields:
for field in self.schema_metadata.fields:
if field:
# Call the appropriate serialization method on each field
if hasattr(field, "model_dump"):
processed_field = field.model_dump(**kwargs)
elif hasattr(field, "dict"):
processed_field = field.dict(**kwargs)
else:
processed_field = {
k: v for k, v in field.__dict__.items()
}
processed_fields.append(processed_field)

# Replace the fields in the result with the processed ones
schema_data["fields"] = processed_fields
@@ -860,7 +895,7 @@ def to_yaml(
# Set up ruamel.yaml for preserving comments
yaml_handler = YAML(typ="rt") # round-trip mode
yaml_handler.default_flow_style = False
yaml_handler.preserve_quotes = True
yaml_handler.preserve_quotes = True # type: ignore[assignment]
yaml_handler.indent(mapping=2, sequence=2, offset=0)

if file.exists():
@@ -933,12 +968,13 @@ def to_yaml(


def _update_dict_preserving_comments(
target: Dict, source: Dict, optional_fields: List[str] = ["urn"]
target: Dict, source: Dict, optional_fields: Optional[List[str]] = None
) -> None:
"""
Updates a target dictionary with values from source, preserving comments and structure.
This modifies the target dictionary in-place.
"""
optional_fields = optional_fields or ["urn"]
# For each key in the source dict
for key, value in source.items():
if key in target:
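One non-cosmetic change in this file is the _update_dict_preserving_comments signature: the mutable list default ["urn"] becomes Optional[List[str]] = None, with the fallback moved into the body. A minimal sketch of the pitfall that style rule guards against follows; the function and argument names are illustrative, not from the diff.

from typing import List, Optional

def append_item(item: str, bucket: Optional[List[str]] = None) -> List[str]:
    # With "bucket: List[str] = []" the same list object would be shared by
    # every call that omits the argument; defaulting to None and creating the
    # list inside the function gives each call a fresh one.
    bucket = bucket if bucket is not None else []
    bucket.append(item)
    return bucket

# append_item("a") and a later append_item("b") each return a one-element list
# instead of accumulating into a shared default.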
9 changes: 6 additions & 3 deletions metadata-ingestion/src/datahub/cli/specific/dataset_cli.py
@@ -3,7 +3,7 @@
import os
import shutil
from pathlib import Path
from typing import Set, Tuple
from typing import List, Set, Tuple

import click
from click_default_group import DefaultGroup
@@ -130,7 +137,7 @@ def file(lintcheck: bool, lintfix: bool, file: str) -> None:
shutil.copyfile(file, temp_path)

# Run the linting
datasets = Dataset.from_yaml(temp_path)
datasets = Dataset.from_yaml(temp.name)
for dataset in datasets:
dataset.to_yaml(temp_path)

@@ -173,10 +173,13 @@ def file(lintcheck: bool, lintfix: bool, file: str) -> None:
def sync(file: str, to_datahub: bool) -> None:
"""Sync a Dataset file to/from DataHub"""

failures = []
failures: List[str] = []
with get_default_graph() as graph:
datasets = Dataset.from_yaml(file)
for dataset in datasets:
assert (
dataset.urn is not None
) # Validator should have ensured this is filled. Tell mypy it's not None
if to_datahub:
missing_entity_references = [
entity_reference
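Both files also add assert ... is not None before using Optional attributes that a validator has already populated (patch_builder in dataset.py and the sync loop above). A small sketch of that mypy narrowing pattern, with hypothetical names:

from typing import Optional

def resolve_urn(raw: Optional[str]) -> str:
    # Upstream validation is assumed to have filled this in; the assert makes
    # the invariant visible to mypy so the str-only .strip() call type-checks.
    assert raw is not None
    return raw.strip()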
