diff --git a/metadata-ingestion/scripts/avro_codegen.py b/metadata-ingestion/scripts/avro_codegen.py index 2fe2729349944..c895e5fabfd37 100644 --- a/metadata-ingestion/scripts/avro_codegen.py +++ b/metadata-ingestion/scripts/avro_codegen.py @@ -323,6 +323,7 @@ def annotate_aspects(aspects: List[dict], schema_class_file: Path) -> None: for aspect in ASPECT_CLASSES }} +from typing import Literal from typing_extensions import TypedDict class AspectBag(TypedDict, total=False): @@ -332,6 +333,13 @@ class AspectBag(TypedDict, total=False): KEY_ASPECTS: Dict[str, Type[_Aspect]] = {{ {f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}': {aspect['name']}Class" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))} }} + +ENTITY_TYPE_NAMES: List[str] = [ + {f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}'" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))} +] +EntityTypeName = Literal[ + {f",{newline} ".join(f"'{aspect['Aspect']['keyForEntity']}'" for aspect in aspects if aspect["Aspect"].get("keyForEntity"))} +] """ ) @@ -346,7 +354,7 @@ def write_urn_classes(key_aspects: List[dict], urn_dir: Path) -> None: code = """ # This file contains classes corresponding to entity URNs. -from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union +from typing import ClassVar, List, Optional, Type, TYPE_CHECKING, Union, Literal import functools from deprecated.sphinx import deprecated as _sphinx_deprecated @@ -672,7 +680,7 @@ def generate_urn_class(entity_type: str, key_aspect: dict) -> str: from datahub.metadata.schema_classes import {key_aspect_class} class {class_name}(_SpecificUrn): - ENTITY_TYPE: ClassVar[str] = "{entity_type}" + ENTITY_TYPE: ClassVar[Literal["{entity_type}"]] = "{entity_type}" _URN_PARTS: ClassVar[int] = {arg_count} def __init__(self, {init_args}, *, _allow_coercion: bool = True) -> None: diff --git a/metadata-ingestion/src/datahub/ingestion/graph/client.py b/metadata-ingestion/src/datahub/ingestion/graph/client.py index 69968ecb726f2..33b8f52a6532a 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/client.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/client.py @@ -16,6 +16,7 @@ List, Literal, Optional, + Sequence, Tuple, Type, Union, @@ -42,8 +43,8 @@ ) from datahub.ingestion.graph.entity_versioning import EntityVersioningAPI from datahub.ingestion.graph.filters import ( + RawSearchFilterRule, RemovedStatusFilter, - SearchFilterRule, generate_filter, ) from datahub.ingestion.source.state.checkpoint import Checkpoint @@ -105,7 +106,7 @@ class RelatedEntity: via: Optional[str] = None -def _graphql_entity_type(entity_type: str) -> str: +def entity_type_to_graphql(entity_type: str) -> str: """Convert the entity types into GraphQL "EntityType" enum values.""" # Hard-coded special cases. @@ -797,13 +798,13 @@ def _bulk_fetch_schema_info_by_filter( container: Optional[str] = None, status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED, batch_size: int = 100, - extraFilters: Optional[List[SearchFilterRule]] = None, + extraFilters: Optional[List[RawSearchFilterRule]] = None, ) -> Iterable[Tuple[str, "GraphQLSchemaMetadata"]]: """Fetch schema info for datasets that match all of the given filters. :return: An iterable of (urn, schema info) tuple that match the filters. """ - types = [_graphql_entity_type("dataset")] + types = [entity_type_to_graphql("dataset")] # Add the query default of * if no query is specified. 
query = query or "*" @@ -865,7 +866,7 @@ def _bulk_fetch_schema_info_by_filter( def get_urns_by_filter( self, *, - entity_types: Optional[List[str]] = None, + entity_types: Optional[Sequence[str]] = None, platform: Optional[str] = None, platform_instance: Optional[str] = None, env: Optional[str] = None, @@ -873,8 +874,8 @@ def get_urns_by_filter( container: Optional[str] = None, status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED, batch_size: int = 10000, - extraFilters: Optional[List[SearchFilterRule]] = None, - extra_or_filters: Optional[List[Dict[str, List[SearchFilterRule]]]] = None, + extraFilters: Optional[List[RawSearchFilterRule]] = None, + extra_or_filters: Optional[List[Dict[str, List[RawSearchFilterRule]]]] = None, ) -> Iterable[str]: """Fetch all urns that match all of the given filters. @@ -965,8 +966,8 @@ def get_results_by_filter( container: Optional[str] = None, status: RemovedStatusFilter = RemovedStatusFilter.NOT_SOFT_DELETED, batch_size: int = 10000, - extra_and_filters: Optional[List[SearchFilterRule]] = None, - extra_or_filters: Optional[List[Dict[str, List[SearchFilterRule]]]] = None, + extra_and_filters: Optional[List[RawSearchFilterRule]] = None, + extra_or_filters: Optional[List[Dict[str, List[RawSearchFilterRule]]]] = None, extra_source_fields: Optional[List[str]] = None, skip_cache: bool = False, ) -> Iterable[dict]: @@ -1109,7 +1110,8 @@ def _scroll_across_entities( f"Scrolling to next scrollAcrossEntities page: {scroll_id}" ) - def _get_types(self, entity_types: Optional[List[str]]) -> Optional[List[str]]: + @classmethod + def _get_types(cls, entity_types: Optional[Sequence[str]]) -> Optional[List[str]]: types: Optional[List[str]] = None if entity_types is not None: if not entity_types: @@ -1117,7 +1119,9 @@ def _get_types(self, entity_types: Optional[List[str]]) -> Optional[List[str]]: "entity_types cannot be an empty list; use None for all entities" ) - types = [_graphql_entity_type(entity_type) for entity_type in entity_types] + types = [ + entity_type_to_graphql(entity_type) for entity_type in entity_types + ] return types def get_latest_pipeline_checkpoint( diff --git a/metadata-ingestion/src/datahub/ingestion/graph/filters.py b/metadata-ingestion/src/datahub/ingestion/graph/filters.py index 588090ec56727..01ef3d5a248ee 100644 --- a/metadata-ingestion/src/datahub/ingestion/graph/filters.py +++ b/metadata-ingestion/src/datahub/ingestion/graph/filters.py @@ -1,3 +1,4 @@ +import dataclasses import enum from typing import Any, Dict, List, Optional @@ -7,7 +8,31 @@ ) from datahub.utilities.urns.urn import guess_entity_type -SearchFilterRule = Dict[str, Any] +RawSearchFilterRule = Dict[str, Any] + + +@dataclasses.dataclass +class SearchFilterRule: + field: str + condition: str # TODO: convert to an enum + values: List[str] + negated: bool = False + + def to_raw(self) -> RawSearchFilterRule: + return { + "field": self.field, + "condition": self.condition, + "values": self.values, + "negated": self.negated, + } + + def negate(self) -> "SearchFilterRule": + return SearchFilterRule( + field=self.field, + condition=self.condition, + values=self.values, + negated=not self.negated, + ) class RemovedStatusFilter(enum.Enum): @@ -29,9 +54,9 @@ def generate_filter( env: Optional[str], container: Optional[str], status: RemovedStatusFilter, - extra_filters: Optional[List[SearchFilterRule]], - extra_or_filters: Optional[List[SearchFilterRule]] = None, -) -> List[Dict[str, List[SearchFilterRule]]]: + extra_filters: Optional[List[RawSearchFilterRule]], + 
extra_or_filters: Optional[List[RawSearchFilterRule]] = None, +) -> List[Dict[str, List[RawSearchFilterRule]]]: """ Generate a search filter based on the provided parameters. :param platform: The platform to filter by. @@ -43,30 +68,32 @@ def generate_filter( :param extra_or_filters: Extra OR filters to apply. These are combined with the AND filters using an OR at the top level. """ - and_filters: List[SearchFilterRule] = [] + and_filters: List[RawSearchFilterRule] = [] # Platform filter. if platform: - and_filters.append(_get_platform_filter(platform)) + and_filters.append(_get_platform_filter(platform).to_raw()) # Platform instance filter. if platform_instance: - and_filters.append(_get_platform_instance_filter(platform, platform_instance)) + and_filters.append( + _get_platform_instance_filter(platform, platform_instance).to_raw() + ) # Browse path v2 filter. if container: - and_filters.append(_get_container_filter(container)) + and_filters.append(_get_container_filter(container).to_raw()) # Status filter. status_filter = _get_status_filter(status) if status_filter: - and_filters.append(status_filter) + and_filters.append(status_filter.to_raw()) # Extra filters. if extra_filters: and_filters += extra_filters - or_filters: List[Dict[str, List[SearchFilterRule]]] = [{"and": and_filters}] + or_filters: List[Dict[str, List[RawSearchFilterRule]]] = [{"and": and_filters}] # Env filter if env: @@ -89,7 +116,7 @@ def generate_filter( return or_filters -def _get_env_filters(env: str) -> List[SearchFilterRule]: +def _get_env_filters(env: str) -> List[RawSearchFilterRule]: # The env filter is a bit more tricky since it's not always stored # in the same place in ElasticSearch. return [ @@ -125,19 +152,19 @@ def _get_status_filter(status: RemovedStatusFilter) -> Optional[SearchFilterRule # removed field is simply not present in the ElasticSearch document. Ideally this # would be a "removed" : "false" filter, but that doesn't work. Instead, we need to # use a negated filter. - return { - "field": "removed", - "values": ["true"], - "condition": "EQUAL", - "negated": True, - } + return SearchFilterRule( + field="removed", + values=["true"], + condition="EQUAL", + negated=True, + ) elif status == RemovedStatusFilter.ONLY_SOFT_DELETED: - return { - "field": "removed", - "values": ["true"], - "condition": "EQUAL", - } + return SearchFilterRule( + field="removed", + values=["true"], + condition="EQUAL", + ) elif status == RemovedStatusFilter.ALL: # We don't need to add a filter for this case. 
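
The new `SearchFilterRule` dataclass replaces the old raw-dict alias (now `RawSearchFilterRule`), and `generate_filter` converts back with `.to_raw()` wherever the legacy dict shape is still required. A minimal sketch of the round-trip, using only the fields defined in this file:

```python
from datahub.ingestion.graph.filters import SearchFilterRule

rule = SearchFilterRule(
    field="platform.keyword",
    condition="EQUAL",
    values=["urn:li:dataPlatform:snowflake"],
)

# to_raw() produces the dict shape that the search endpoints still expect.
assert rule.to_raw() == {
    "field": "platform.keyword",
    "condition": "EQUAL",
    "values": ["urn:li:dataPlatform:snowflake"],
    "negated": False,
}

# negate() returns a copy with the negated flag flipped; the _Not filter
# compilation in search_filters.py relies on this.
assert rule.negate().to_raw()["negated"] is True
```
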
@@ -152,11 +179,11 @@ def _get_container_filter(container: str) -> SearchFilterRule: if guess_entity_type(container) != "container": raise ValueError(f"Invalid container urn: {container}") - return { - "field": "browsePathV2", - "values": [container], - "condition": "CONTAIN", - } + return SearchFilterRule( + field="browsePathV2", + values=[container], + condition="CONTAIN", + ) def _get_platform_instance_filter( @@ -171,16 +198,16 @@ def _get_platform_instance_filter( if guess_entity_type(platform_instance) != "dataPlatformInstance": raise ValueError(f"Invalid data platform instance urn: {platform_instance}") - return { - "field": "platformInstance", - "values": [platform_instance], - "condition": "EQUAL", - } + return SearchFilterRule( + field="platformInstance", + condition="EQUAL", + values=[platform_instance], + ) def _get_platform_filter(platform: str) -> SearchFilterRule: - return { - "field": "platform.keyword", - "values": [make_data_platform_urn(platform)], - "condition": "EQUAL", - } + return SearchFilterRule( + field="platform.keyword", + condition="EQUAL", + values=[make_data_platform_urn(platform)], + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py index 9966d333fdc17..7ff14fee8c38c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py +++ b/metadata-ingestion/src/datahub/ingestion/source/cassandra/cassandra.py @@ -59,9 +59,9 @@ UpstreamLineageClass, ViewPropertiesClass, ) -from datahub.sdk._entity import Entity from datahub.sdk.container import Container from datahub.sdk.dataset import Dataset +from datahub.sdk.entity import Entity logger = logging.getLogger(__name__) diff --git a/metadata-ingestion/src/datahub/sdk/__init__.py b/metadata-ingestion/src/datahub/sdk/__init__.py index ec7ecf4ce0688..09227ee1a241b 100644 --- a/metadata-ingestion/src/datahub/sdk/__init__.py +++ b/metadata-ingestion/src/datahub/sdk/__init__.py @@ -20,6 +20,7 @@ from datahub.sdk.container import Container from datahub.sdk.dataset import Dataset from datahub.sdk.main_client import DataHubClient +from datahub.sdk.search_filters import Filter, FilterDsl # We want to print out the warning if people do `from datahub.sdk import X`. # But we don't want to print out warnings if they're doing a more direct diff --git a/metadata-ingestion/src/datahub/sdk/_all_entities.py b/metadata-ingestion/src/datahub/sdk/_all_entities.py index 04c5fb6045ae6..67834adb634bc 100644 --- a/metadata-ingestion/src/datahub/sdk/_all_entities.py +++ b/metadata-ingestion/src/datahub/sdk/_all_entities.py @@ -1,8 +1,8 @@ from typing import Dict, List, Type -from datahub.sdk._entity import Entity from datahub.sdk.container import Container from datahub.sdk.dataset import Dataset +from datahub.sdk.entity import Entity # TODO: Is there a better way to declare this? 
ENTITY_CLASSES_LIST: List[Type[Entity]] = [ diff --git a/metadata-ingestion/src/datahub/sdk/_shared.py b/metadata-ingestion/src/datahub/sdk/_shared.py index 89fbbde5ce88c..b061fd7aa63a7 100644 --- a/metadata-ingestion/src/datahub/sdk/_shared.py +++ b/metadata-ingestion/src/datahub/sdk/_shared.py @@ -36,8 +36,8 @@ TagUrn, Urn, ) -from datahub.sdk._entity import Entity from datahub.sdk._utils import add_list_unique, remove_list_unique +from datahub.sdk.entity import Entity from datahub.utilities.urns.error import InvalidUrnError if TYPE_CHECKING: diff --git a/metadata-ingestion/src/datahub/sdk/container.py b/metadata-ingestion/src/datahub/sdk/container.py index ec4d6521c6088..d2c449a6d2166 100644 --- a/metadata-ingestion/src/datahub/sdk/container.py +++ b/metadata-ingestion/src/datahub/sdk/container.py @@ -16,7 +16,6 @@ ContainerUrn, Urn, ) -from datahub.sdk._entity import Entity, ExtraAspectsType from datahub.sdk._shared import ( DomainInputType, HasContainer, @@ -33,6 +32,7 @@ make_time_stamp, parse_time_stamp, ) +from datahub.sdk.entity import Entity, ExtraAspectsType from datahub.utilities.sentinels import Auto, auto diff --git a/metadata-ingestion/src/datahub/sdk/dataset.py b/metadata-ingestion/src/datahub/sdk/dataset.py index 2fff5adf25009..c367aa79cbcc9 100644 --- a/metadata-ingestion/src/datahub/sdk/dataset.py +++ b/metadata-ingestion/src/datahub/sdk/dataset.py @@ -18,7 +18,6 @@ from datahub.ingestion.source.sql.sql_types import resolve_sql_type from datahub.metadata.urns import DatasetUrn, SchemaFieldUrn, Urn from datahub.sdk._attribution import is_ingestion_attribution -from datahub.sdk._entity import Entity, ExtraAspectsType from datahub.sdk._shared import ( DatasetUrnOrStr, DomainInputType, @@ -39,6 +38,7 @@ parse_time_stamp, ) from datahub.sdk._utils import add_list_unique, remove_list_unique +from datahub.sdk.entity import Entity, ExtraAspectsType from datahub.utilities.sentinels import Unset, unset SchemaFieldInputType: TypeAlias = Union[ diff --git a/metadata-ingestion/src/datahub/sdk/_entity.py b/metadata-ingestion/src/datahub/sdk/entity.py similarity index 97% rename from metadata-ingestion/src/datahub/sdk/_entity.py rename to metadata-ingestion/src/datahub/sdk/entity.py index f5887e4e0fb80..f50c86f4b2b0d 100644 --- a/metadata-ingestion/src/datahub/sdk/_entity.py +++ b/metadata-ingestion/src/datahub/sdk/entity.py @@ -56,6 +56,10 @@ def _init_from_graph(self, current_aspects: models.AspectBag) -> Self: @abc.abstractmethod def get_urn_type(cls) -> Type[_SpecificUrn]: ... 
+ @classmethod + def entity_type_name(cls) -> str: + return cls.get_urn_type().ENTITY_TYPE + @property def urn(self) -> _SpecificUrn: return self._urn diff --git a/metadata-ingestion/src/datahub/sdk/entity_client.py b/metadata-ingestion/src/datahub/sdk/entity_client.py index 99dc7f9a280ab..bcb9f2798ffcf 100644 --- a/metadata-ingestion/src/datahub/sdk/entity_client.py +++ b/metadata-ingestion/src/datahub/sdk/entity_client.py @@ -14,10 +14,10 @@ Urn, ) from datahub.sdk._all_entities import ENTITY_CLASSES -from datahub.sdk._entity import Entity from datahub.sdk._shared import UrnOrStr from datahub.sdk.container import Container from datahub.sdk.dataset import Dataset +from datahub.sdk.entity import Entity if TYPE_CHECKING: from datahub.sdk.main_client import DataHubClient diff --git a/metadata-ingestion/src/datahub/sdk/main_client.py b/metadata-ingestion/src/datahub/sdk/main_client.py index beab1dea92808..ef58007a46290 100644 --- a/metadata-ingestion/src/datahub/sdk/main_client.py +++ b/metadata-ingestion/src/datahub/sdk/main_client.py @@ -7,6 +7,7 @@ from datahub.ingestion.graph.config import DatahubClientConfig from datahub.sdk.entity_client import EntityClient from datahub.sdk.resolver_client import ResolverClient +from datahub.sdk.search_client import SearchClient class DataHubClient: @@ -39,6 +40,8 @@ def __init__( self._graph = graph + # TODO: test connection + @classmethod def from_env(cls) -> "DataHubClient": """Initialize a DataHubClient from the environment variables or ~/.datahubenv file. @@ -69,5 +72,8 @@ def entities(self) -> EntityClient: def resolve(self) -> ResolverClient: return ResolverClient(self) - # TODO: search client + @property + def search(self) -> SearchClient: + return SearchClient(self) + # TODO: lineage client diff --git a/metadata-ingestion/src/datahub/sdk/resolver_client.py b/metadata-ingestion/src/datahub/sdk/resolver_client.py index dae2f61a918dd..63a6026411992 100644 --- a/metadata-ingestion/src/datahub/sdk/resolver_client.py +++ b/metadata-ingestion/src/datahub/sdk/resolver_client.py @@ -9,6 +9,7 @@ DomainUrn, GlossaryTermUrn, ) +from datahub.sdk.search_filters import Filter, FilterDsl as F if TYPE_CHECKING: from datahub.sdk.main_client import DataHubClient @@ -38,37 +39,28 @@ def user( self, *, name: Optional[str] = None, email: Optional[str] = None ) -> CorpUserUrn: filter_explanation: str - filters = [] + filter: Filter if name is not None: if email is not None: raise SdkUsageError("Cannot specify both name and email for auto_user") - # TODO: do we filter on displayName or fullName? + # We're filtering on both fullName and displayName. It's not clear + # what the right behavior is here. 
filter_explanation = f"with name {name}" - filters.append( - { - "field": "fullName", - "values": [name], - "condition": "EQUAL", - } + filter = F.or_( + F.custom_filter("fullName", "EQUAL", [name]), + F.custom_filter("displayName", "EQUAL", [name]), ) elif email is not None: filter_explanation = f"with email {email}" - filters.append( - { - "field": "email", - "values": [email], - "condition": "EQUAL", - } - ) + filter = F.custom_filter("email", "EQUAL", [email]) else: raise SdkUsageError("Must specify either name or email for auto_user") - users = list( - self._graph.get_urns_by_filter( - entity_types=[CorpUserUrn.ENTITY_TYPE], - extraFilters=filters, - ) + filter = F.and_( + F.entity_type(CorpUserUrn.ENTITY_TYPE), + filter, ) + users = list(self._client.search.get_urns(filter=filter)) if len(users) == 0: # TODO: In auto methods, should we just create the user/domain/etc if it doesn't exist? raise ItemNotFoundError(f"User {filter_explanation} not found") @@ -82,15 +74,11 @@ def user( def term(self, *, name: str) -> GlossaryTermUrn: # TODO: Add some limits on the graph fetch terms = list( - self._graph.get_urns_by_filter( - entity_types=[GlossaryTermUrn.ENTITY_TYPE], - extraFilters=[ - { - "field": "id", - "values": [name], - "condition": "EQUAL", - } - ], + self._client.search.get_urns( + filter=F.and_( + F.entity_type(GlossaryTermUrn.ENTITY_TYPE), + F.custom_filter("name", "EQUAL", [name]), + ), ) ) if len(terms) == 0: diff --git a/metadata-ingestion/src/datahub/sdk/search_client.py b/metadata-ingestion/src/datahub/sdk/search_client.py new file mode 100644 index 0000000000000..bed298b694498 --- /dev/null +++ b/metadata-ingestion/src/datahub/sdk/search_client.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, +) + +from datahub.ingestion.graph.filters import RawSearchFilterRule +from datahub.metadata.urns import Urn +from datahub.sdk.search_filters import Filter + +if TYPE_CHECKING: + from datahub.sdk.main_client import DataHubClient + + +def compile_filters( + filter: Optional[Filter], +) -> Optional[List[Dict[str, List[RawSearchFilterRule]]]]: + # TODO: Not every filter type is supported for every entity type. + # If we can detect issues with the filters at compile time, we should + # raise an error. + + if filter is None: + return None + + initial_filters = filter.compile() + return [ + {"and": [rule.to_raw() for rule in andClause["and"]]} + for andClause in initial_filters + ] + + +class SearchClient: + def __init__(self, client: DataHubClient): + self._client = client + + def get_urns( + self, + query: Optional[str] = None, + filter: Optional[Filter] = None, + ) -> Iterable[Urn]: + # TODO: Add better limit / pagination support. 
+ for urn in self._client._graph.get_urns_by_filter( + query=query, + extra_or_filters=compile_filters(filter), + ): + yield Urn.from_string(urn) diff --git a/metadata-ingestion/src/datahub/sdk/search_filters.py b/metadata-ingestion/src/datahub/sdk/search_filters.py new file mode 100644 index 0000000000000..5c5116b181ac0 --- /dev/null +++ b/metadata-ingestion/src/datahub/sdk/search_filters.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +import abc +from typing import ( + Any, + List, + Sequence, + TypedDict, + Union, +) + +import pydantic + +from datahub.configuration.common import ConfigModel +from datahub.configuration.pydantic_migration_helpers import PYDANTIC_VERSION_2 +from datahub.ingestion.graph.client import entity_type_to_graphql +from datahub.ingestion.graph.filters import SearchFilterRule +from datahub.metadata.schema_classes import EntityTypeName +from datahub.metadata.urns import DataPlatformUrn, DomainUrn + +_AndSearchFilterRule = TypedDict( + "_AndSearchFilterRule", {"and": List[SearchFilterRule]} +) +_OrFilters = List[_AndSearchFilterRule] + + +class _BaseFilter(ConfigModel): + class Config: + # We can't wrap this in a TYPE_CHECKING block because the pydantic plugin + # doesn't recognize it properly. So unfortunately we'll need to live + # with the deprecation warning w/ pydantic v2. + allow_population_by_field_name = True + if PYDANTIC_VERSION_2: + populate_by_name = True + + @abc.abstractmethod + def compile(self) -> _OrFilters: + pass + + +def _flexible_entity_type_to_graphql(entity_type: str) -> str: + if entity_type.upper() == entity_type: + # Assume that we were passed a graphql EntityType enum value, + # so no conversion is needed. + return entity_type + return entity_type_to_graphql(entity_type) + + +class _EntityTypeFilter(_BaseFilter): + entity_type: List[str] = pydantic.Field( + description="The entity type to filter on. Can be 'dataset', 'chart', 'dashboard', 'corpuser', etc.", + ) + + def _build_rule(self) -> SearchFilterRule: + return SearchFilterRule( + field="_entityType", + condition="EQUAL", + values=[_flexible_entity_type_to_graphql(t) for t in self.entity_type], + ) + + def compile(self) -> _OrFilters: + return [{"and": [self._build_rule()]}] + + +class _EntitySubtypeFilter(_BaseFilter): + entity_type: str + entity_subtype: str = pydantic.Field( + description="The entity subtype to filter on. Can be 'Table', 'View', 'Source', etc. depending on the native platform's concepts.", + ) + + def compile(self) -> _OrFilters: + rules = [ + SearchFilterRule( + field="_entityType", + condition="EQUAL", + values=[_flexible_entity_type_to_graphql(self.entity_type)], + ), + SearchFilterRule( + field="typeNames", + condition="EQUAL", + values=[self.entity_subtype], + ), + ] + return [{"and": rules}] + + +class _PlatformFilter(_BaseFilter): + platform: List[str] + # TODO: Add validator to convert string -> list of strings + + @pydantic.validator("platform", each_item=True) + def validate_platform(cls, v: str) -> str: + # Subtle - we use the constructor instead of the from_string method + # because coercion is acceptable here. 
+ return str(DataPlatformUrn(v)) + + def _build_rule(self) -> SearchFilterRule: + return SearchFilterRule( + field="platform.keyword", + condition="EQUAL", + values=self.platform, + ) + + def compile(self) -> _OrFilters: + return [{"and": [self._build_rule()]}] + + +class _DomainFilter(_BaseFilter): + domain: List[str] + + @pydantic.validator("domain", each_item=True) + def validate_domain(cls, v: str) -> str: + return str(DomainUrn.from_string(v)) + + def _build_rule(self) -> SearchFilterRule: + return SearchFilterRule( + field="domains", + condition="EQUAL", + values=self.domain, + ) + + def compile(self) -> _OrFilters: + return [{"and": [self._build_rule()]}] + + +class _EnvFilter(_BaseFilter): + # Note that not all entity types have an env (e.g. dashboards / charts). + # If the env filter is specified, these will be excluded. + env: List[str] + + def compile(self) -> _OrFilters: + return [ + # For most entity types, we look at the origin field. + { + "and": [ + SearchFilterRule( + field="origin", + condition="EQUAL", + values=self.env, + ), + ] + }, + # For containers, we now have an "env" property as of + # https://github.com/datahub-project/datahub/pull/11214 + # Prior to this, we put "env" in the customProperties. But we're + # not bothering with that here. + { + "and": [ + SearchFilterRule( + field="env", + condition="EQUAL", + values=self.env, + ), + ] + }, + ] + + +class _CustomCondition(_BaseFilter): + """Represents a single field condition""" + + field: str + condition: str + values: List[str] + + def compile(self) -> _OrFilters: + rule = SearchFilterRule( + field=self.field, + condition=self.condition, + values=self.values, + ) + return [{"and": [rule]}] + + +class _And(_BaseFilter): + """Represents an AND conjunction of filters""" + + and_: Sequence["Filter"] = pydantic.Field(alias="and") + # TODO: Add validator to ensure that the "and" field is not empty + + def compile(self) -> _OrFilters: + # The "and" operator must be implemented by doing a Cartesian product + # of the OR clauses. 
+ # Example 1: + # (A or B) and (C or D) -> + # (A and C) or (A and D) or (B and C) or (B and D) + # Example 2: + # (A or B) and (C or D) and (E or F) -> + # (A and C and E) or (A and C and F) or (A and D and E) or (A and D and F) or + # (B and C and E) or (B and C and F) or (B and D and E) or (B and D and F) + + # Start with the first filter's OR clauses + result = self.and_[0].compile() + + # For each subsequent filter + for filter in self.and_[1:]: + new_result = [] + # Get its OR clauses + other_clauses = filter.compile() + + # Create Cartesian product + for existing_clause in result: + for other_clause in other_clauses: + # Merge the AND conditions from both clauses + new_result.append(self._merge_ands(existing_clause, other_clause)) + + result = new_result + + return result + + @classmethod + def _merge_ands( + cls, a: _AndSearchFilterRule, b: _AndSearchFilterRule + ) -> _AndSearchFilterRule: + return { + "and": [ + *a["and"], + *b["and"], + ] + } + + +class _Or(_BaseFilter): + """Represents an OR conjunction of filters""" + + or_: Sequence["Filter"] = pydantic.Field(alias="or") + # TODO: Add validator to ensure that the "or" field is not empty + + def compile(self) -> _OrFilters: + merged_filter = [] + for filter in self.or_: + merged_filter.extend(filter.compile()) + return merged_filter + + +class _Not(_BaseFilter): + """Represents a NOT filter""" + + not_: "Filter" = pydantic.Field(alias="not") + + @pydantic.validator("not_", pre=False) + def validate_not(cls, v: "Filter") -> "Filter": + inner_filter = v.compile() + if len(inner_filter) != 1: + raise ValueError( + "Cannot negate a filter with multiple OR clauses [not yet supported]" + ) + return v + + def compile(self) -> _OrFilters: + # TODO: Eventually we'll want to implement a full DNF normalizer. + # https://en.wikipedia.org/wiki/Disjunctive_normal_form#Conversion_to_DNF + + inner_filter = self.not_.compile() + assert len(inner_filter) == 1 # validated above + + # ¬(A and B) -> (¬A) OR (¬B) + and_filters = inner_filter[0]["and"] + final_filters: _OrFilters = [] + for rule in and_filters: + final_filters.append({"and": [rule.negate()]}) + + return final_filters + + +# TODO: With pydantic 2, we can use a RootModel with a +# discriminated union to make the error messages more informative. +Filter = Union[ + _And, + _Or, + _Not, + _EntityTypeFilter, + _EntitySubtypeFilter, + _PlatformFilter, + _DomainFilter, + _EnvFilter, + _CustomCondition, +] + + +# Required to resolve forward references to "Filter" +if PYDANTIC_VERSION_2: + _And.model_rebuild() # type: ignore + _Or.model_rebuild() # type: ignore + _Not.model_rebuild() # type: ignore +else: + _And.update_forward_refs() + _Or.update_forward_refs() + _Not.update_forward_refs() + + +def load_filters(obj: Any) -> Filter: + if PYDANTIC_VERSION_2: + return pydantic.TypeAdapter(Filter).validate_python(obj) # type: ignore + else: + return pydantic.parse_obj_as(Filter, obj) # type: ignore + + +# We need FilterDsl for two reasons: +# 1. To provide wrapper methods around lots of filters while avoid bloating the +# yaml spec. +# 2. Pydantic models in general don't support positional arguments, making the +# calls feel repetitive (e.g. Platform(platform=...)). +# See https://github.com/pydantic/pydantic/issues/6792 +# We also considered using dataclasses / pydantic dataclasses, but +# ultimately decided that they didn't quite suit our requirements, +# particularly with regards to the field aliases for and/or/not. 
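
Given the `compile()` semantics above, the interplay between AND, OR, and the two-clause env filter is easiest to see with a small sketch (using the `FilterDsl` helpers defined just below):

```python
from datahub.sdk.search_filters import FilterDsl as F

# F.env("PROD") compiles to two OR clauses (origin vs. env), so AND-ing it
# with a single-platform filter yields 2 x 1 = 2 top-level OR clauses.
combined = F.and_(F.env("PROD"), F.platform("snowflake"))
or_clauses = combined.compile()

assert len(or_clauses) == 2
assert [rule.field for rule in or_clauses[0]["and"]] == ["origin", "platform.keyword"]
assert [rule.field for rule in or_clauses[1]["and"]] == ["env", "platform.keyword"]
```
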
+class FilterDsl: + @staticmethod + def and_(*args: "Filter") -> _And: + return _And(and_=list(args)) + + @staticmethod + def or_(*args: "Filter") -> _Or: + return _Or(or_=list(args)) + + @staticmethod + def not_(arg: "Filter") -> _Not: + return _Not(not_=arg) + + @staticmethod + def entity_type( + entity_type: Union[EntityTypeName, Sequence[EntityTypeName]], + ) -> _EntityTypeFilter: + return _EntityTypeFilter( + entity_type=( + [entity_type] if isinstance(entity_type, str) else list(entity_type) + ) + ) + + @staticmethod + def entity_subtype(entity_type: str, subtype: str) -> _EntitySubtypeFilter: + return _EntitySubtypeFilter( + entity_type=entity_type, + entity_subtype=subtype, + ) + + @staticmethod + def platform(platform: Union[str, List[str]], /) -> _PlatformFilter: + return _PlatformFilter( + platform=[platform] if isinstance(platform, str) else platform + ) + + # TODO: Add a platform_instance filter + + @staticmethod + def domain(domain: Union[str, List[str]], /) -> _DomainFilter: + return _DomainFilter(domain=[domain] if isinstance(domain, str) else domain) + + @staticmethod + def env(env: Union[str, List[str]], /) -> _EnvFilter: + return _EnvFilter(env=[env] if isinstance(env, str) else env) + + @staticmethod + def has_custom_property(key: str, value: str) -> _CustomCondition: + return _CustomCondition( + field="customProperties", + condition="EQUAL", + values=[f"{key}={value}"], + ) + + # TODO: Add a soft-deletion status filter + # TODO: add a container / browse path filter + # TODO add shortcut for custom filters + + @staticmethod + def custom_filter( + field: str, condition: str, values: List[str] + ) -> _CustomCondition: + return _CustomCondition( + field=field, + condition=condition, + values=values, + ) diff --git a/metadata-ingestion/tests/test_helpers/sdk_v2_helpers.py b/metadata-ingestion/tests/test_helpers/sdk_v2_helpers.py index 2736f925a4371..4b162b49e2a5e 100644 --- a/metadata-ingestion/tests/test_helpers/sdk_v2_helpers.py +++ b/metadata-ingestion/tests/test_helpers/sdk_v2_helpers.py @@ -1,6 +1,6 @@ import pathlib -from datahub.sdk._entity import Entity +from datahub.sdk.entity import Entity from tests.test_helpers import mce_helpers diff --git a/metadata-ingestion/tests/unit/sdk/test_client.py b/metadata-ingestion/tests/unit/sdk/test_client.py index 16795ef8c7f81..047c2d2df37e0 100644 --- a/metadata-ingestion/tests/unit/sdk/test_client.py +++ b/metadata-ingestion/tests/unit/sdk/test_client.py @@ -3,7 +3,7 @@ from datahub.ingestion.graph.client import ( DatahubClientConfig, DataHubGraph, - _graphql_entity_type, + entity_type_to_graphql, ) from datahub.metadata.schema_classes import CorpUserEditableInfoClass @@ -26,20 +26,22 @@ def test_get_aspect(mock_test_connection): assert editable is not None -def test_graphql_entity_types(): +def test_graphql_entity_types() -> None: # FIXME: This is a subset of all the types, but it's enough to get us ok coverage. 
- assert _graphql_entity_type("domain") == "DOMAIN" - assert _graphql_entity_type("dataset") == "DATASET" - assert _graphql_entity_type("dashboard") == "DASHBOARD" - assert _graphql_entity_type("chart") == "CHART" - - assert _graphql_entity_type("corpuser") == "CORP_USER" - assert _graphql_entity_type("corpGroup") == "CORP_GROUP" - - assert _graphql_entity_type("dataFlow") == "DATA_FLOW" - assert _graphql_entity_type("dataJob") == "DATA_JOB" - assert _graphql_entity_type("glossaryNode") == "GLOSSARY_NODE" - assert _graphql_entity_type("glossaryTerm") == "GLOSSARY_TERM" - - assert _graphql_entity_type("dataHubExecutionRequest") == "EXECUTION_REQUEST" + known_mappings = { + "domain": "DOMAIN", + "dataset": "DATASET", + "dashboard": "DASHBOARD", + "chart": "CHART", + "corpuser": "CORP_USER", + "corpGroup": "CORP_GROUP", + "dataFlow": "DATA_FLOW", + "dataJob": "DATA_JOB", + "glossaryNode": "GLOSSARY_NODE", + "glossaryTerm": "GLOSSARY_TERM", + "dataHubExecutionRequest": "EXECUTION_REQUEST", + } + + for entity_type, graphql_type in known_mappings.items(): + assert entity_type_to_graphql(entity_type) == graphql_type diff --git a/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py b/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py new file mode 100644 index 0000000000000..16819bb2d7fb4 --- /dev/null +++ b/metadata-ingestion/tests/unit/sdk_v2/test_search_client.py @@ -0,0 +1,214 @@ +from io import StringIO + +import pytest +import yaml +from pydantic import ValidationError + +from datahub.ingestion.graph.filters import SearchFilterRule +from datahub.sdk.search_client import compile_filters +from datahub.sdk.search_filters import Filter, FilterDsl as F, load_filters +from datahub.utilities.urns.error import InvalidUrnError + + +def test_filters_simple() -> None: + yaml_dict = {"platform": ["snowflake", "bigquery"]} + filter_obj: Filter = load_filters(yaml_dict) + assert filter_obj == F.platform(["snowflake", "bigquery"]) + assert filter_obj.compile() == [ + { + "and": [ + SearchFilterRule( + field="platform.keyword", + condition="EQUAL", + values=[ + "urn:li:dataPlatform:snowflake", + "urn:li:dataPlatform:bigquery", + ], + ) + ] + } + ] + + +def test_filters_and() -> None: + yaml_dict = { + "and": [ + {"env": ["PROD"]}, + {"platform": ["snowflake", "bigquery"]}, + ] + } + filter_obj: Filter = load_filters(yaml_dict) + assert filter_obj == F.and_( + F.env("PROD"), + F.platform(["snowflake", "bigquery"]), + ) + platform_rule = SearchFilterRule( + field="platform.keyword", + condition="EQUAL", + values=[ + "urn:li:dataPlatform:snowflake", + "urn:li:dataPlatform:bigquery", + ], + ) + assert filter_obj.compile() == [ + { + "and": [ + SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]), + platform_rule, + ] + }, + { + "and": [ + SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]), + platform_rule, + ] + }, + ] + + +def test_filters_complex() -> None: + yaml_dict = yaml.safe_load( + StringIO("""\ +and: + - env: [PROD] + - or: + - platform: [ snowflake, bigquery ] + - and: + - platform: [postgres] + - not: + domain: [urn:li:domain:analytics] + - field: customProperties + condition: EQUAL + values: ["dbt_unique_id=source.project.name"] +""") + ) + filter_obj: Filter = load_filters(yaml_dict) + assert filter_obj == F.and_( + F.env("PROD"), + F.or_( + F.platform(["snowflake", "bigquery"]), + F.and_( + F.platform("postgres"), + F.not_(F.domain("urn:li:domain:analytics")), + ), + F.has_custom_property("dbt_unique_id", "source.project.name"), + ), + ) + 
warehouse_rule = SearchFilterRule( + field="platform.keyword", + condition="EQUAL", + values=["urn:li:dataPlatform:snowflake", "urn:li:dataPlatform:bigquery"], + ) + postgres_rule = SearchFilterRule( + field="platform.keyword", + condition="EQUAL", + values=["urn:li:dataPlatform:postgres"], + ) + domain_rule = SearchFilterRule( + field="domains", + condition="EQUAL", + values=["urn:li:domain:analytics"], + negated=True, + ) + custom_property_rule = SearchFilterRule( + field="customProperties", + condition="EQUAL", + values=["dbt_unique_id=source.project.name"], + ) + + # There's one OR clause in the original filter with 3 clauses, + # and one hidden in the env filter with 2 clauses. + # The final result should have 3 * 2 = 6 OR clauses. + assert filter_obj.compile() == [ + { + "and": [ + SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]), + warehouse_rule, + ], + }, + { + "and": [ + SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]), + postgres_rule, + domain_rule, + ], + }, + { + "and": [ + SearchFilterRule(field="origin", condition="EQUAL", values=["PROD"]), + custom_property_rule, + ], + }, + { + "and": [ + SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]), + warehouse_rule, + ], + }, + { + "and": [ + SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]), + postgres_rule, + domain_rule, + ], + }, + { + "and": [ + SearchFilterRule(field="env", condition="EQUAL", values=["PROD"]), + custom_property_rule, + ], + }, + ] + + +def test_invalid_filter() -> None: + with pytest.raises(InvalidUrnError): + F.domain("marketing") + + +def test_unsupported_not() -> None: + env_filter = F.env("PROD") + with pytest.raises( + ValidationError, + match="Cannot negate a filter with multiple OR clauses", + ): + F.not_(env_filter) + + +def test_compile_filters() -> None: + filter = F.and_(F.env("PROD"), F.platform("snowflake")) + expected_filters = [ + { + "and": [ + { + "field": "origin", + "condition": "EQUAL", + "values": ["PROD"], + "negated": False, + }, + { + "field": "platform.keyword", + "condition": "EQUAL", + "values": ["urn:li:dataPlatform:snowflake"], + "negated": False, + }, + ] + }, + { + "and": [ + { + "field": "env", + "condition": "EQUAL", + "values": ["PROD"], + "negated": False, + }, + { + "field": "platform.keyword", + "condition": "EQUAL", + "values": ["urn:li:dataPlatform:snowflake"], + "negated": False, + }, + ] + }, + ] + assert compile_filters(filter) == expected_filters
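
Taken together, these changes let SDK callers express search filters without hand-building raw dict rules. A minimal end-to-end usage sketch, assuming a DataHub instance reachable via the usual environment variables or ~/.datahubenv:

```python
from datahub.sdk import DataHubClient, FilterDsl as F

client = DataHubClient.from_env()

# Find all PROD Snowflake datasets; get_urns() yields parsed Urn objects.
snowflake_prod_datasets = F.and_(
    F.entity_type("dataset"),
    F.platform("snowflake"),
    F.env("PROD"),
)
for urn in client.search.get_urns(filter=snowflake_prod_datasets):
    print(urn)
```
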