Skip to content

Commit

Permalink
feat(openapi-ingestion): implement openapi ingestion
Browse files Browse the repository at this point in the history
* enabled by default
  • Loading branch information
david-leifker committed Mar 3, 2025
1 parent a19edde commit 8b008a4
Show file tree
Hide file tree
Showing 4 changed files with 391 additions and 12 deletions.
161 changes: 151 additions & 10 deletions metadata-ingestion/src/datahub/emitter/rest_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import json
import logging
import os
from collections import defaultdict
from dataclasses import dataclass
from json.decoder import JSONDecodeError
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -31,6 +33,7 @@
ConfigurationError,
OperationalError,
)
from datahub.emitter.aspect import JSON_CONTENT_TYPE
from datahub.emitter.generic_emitter import Emitter
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.request_helper import make_curl_command
Expand Down Expand Up @@ -143,10 +146,31 @@ def build_session(self) -> requests.Session:
return session


@dataclass
class _Chunk:
items: List[str]
total_bytes: int = 0

def add_item(self, item: str) -> bool:
item_bytes = len(item.encode())
if not self.items: # Always add at least one item even if over byte limit
self.items.append(item)
self.total_bytes += item_bytes
return True
self.items.append(item)
self.total_bytes += item_bytes
return True

@staticmethod
def join(chunk: "_Chunk") -> str:
return "[" + ",".join(chunk.items) + "]"


class DataHubRestEmitter(Closeable, Emitter):
_gms_server: str
_token: Optional[str]
_session: requests.Session
_openapi_ingestion: bool

def __init__(
self,
Expand All @@ -162,6 +186,7 @@ def __init__(
ca_certificate_path: Optional[str] = None,
client_certificate_path: Optional[str] = None,
disable_ssl_verification: bool = False,
openapi_ingestion: bool = False,
):
if not gms_server:
raise ConfigurationError("gms server is required")
Expand All @@ -174,9 +199,13 @@ def __init__(
self._gms_server = fixup_gms_url(gms_server)
self._token = token
self.server_config: Dict[str, Any] = {}

self._openapi_ingestion = openapi_ingestion
self._session = requests.Session()

logger.debug(
f"Using {'OpenAPI' if openapi_ingestion else 'Restli'} for ingestion."
)

headers = {
"X-RestLi-Protocol-Version": "2.0.0",
"X-DataHub-Py-Cli-Version": nice_version_name(),
Expand Down Expand Up @@ -264,6 +293,43 @@ def to_graph(self) -> "DataHubGraph":

return DataHubGraph.from_emitter(self)

def _to_openapi_request(
self,
mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
async_flag: Optional[bool] = None,
async_default: bool = False,
) -> Optional[Tuple[str, List[Dict[str, Any]]]]:
if mcp.aspect and mcp.aspectName:
resolved_async_flag = (
async_flag if async_flag is not None else async_default
)
url = f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"

if isinstance(mcp, MetadataChangeProposalWrapper):
aspect_value = pre_json_transform(
mcp.to_obj(simplified_structure=True)
)["aspect"]["json"]
else:
obj = mcp.aspect.to_obj()
if obj.get("value") and obj.get("contentType") == JSON_CONTENT_TYPE:
obj = json.loads(obj["value"])
aspect_value = pre_json_transform(obj)
return (
url,
[
{
"urn": mcp.entityUrn,
mcp.aspectName: {
"value": aspect_value,
"systemMetadata": mcp.systemMetadata.to_obj()
if mcp.systemMetadata
else None,
},
}
],
)
return None

def emit(
self,
item: Union[
Expand Down Expand Up @@ -317,18 +383,24 @@ def emit_mcp(
mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
async_flag: Optional[bool] = None,
) -> None:
url = f"{self._gms_server}/aspects?action=ingestProposal"
ensure_has_system_metadata(mcp)

mcp_obj = pre_json_transform(mcp.to_obj())
payload_dict = {"proposal": mcp_obj}
if self._openapi_ingestion:
request = self._to_openapi_request(mcp, async_flag, async_default=False)
if request:
self._emit_generic(request[0], payload=request[1])
else:
url = f"{self._gms_server}/aspects?action=ingestProposal"

if async_flag is not None:
payload_dict["async"] = "true" if async_flag else "false"
mcp_obj = pre_json_transform(mcp.to_obj())
payload_dict = {"proposal": mcp_obj}

payload = json.dumps(payload_dict)
if async_flag is not None:
payload_dict["async"] = "true" if async_flag else "false"

self._emit_generic(url, payload)
payload = json.dumps(payload_dict)

self._emit_generic(url, payload)

def emit_mcps(
self,
Expand All @@ -337,10 +409,75 @@ def emit_mcps(
) -> int:
if _DATAHUB_EMITTER_TRACE:
logger.debug(f"Attempting to emit MCP batch of size {len(mcps)}")
url = f"{self._gms_server}/aspects?action=ingestProposalBatch"

for mcp in mcps:
ensure_has_system_metadata(mcp)

if self._openapi_ingestion:
return self._emit_openapi_mcps(mcps, async_flag)
else:
return self._emit_restli_mcps(mcps, async_flag)

def _emit_openapi_mcps(
self,
mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
async_flag: Optional[bool] = None,
) -> int:
"""
1. Grouping MCPs by their entity URL
2. Breaking down large batches into smaller chunks based on both:
* Total byte size (INGEST_MAX_PAYLOAD_BYTES)
* Maximum number of items (BATCH_INGEST_MAX_PAYLOAD_LENGTH)
The Chunk class encapsulates both the items and their byte size tracking
Serializing the items only once with json.dumps(request[1]) and reusing that
The chunking logic handles edge cases (always accepting at least one item per chunk)
The joining logic is efficient with a simple string concatenation
:param mcps: metadata change proposals to transmit
:param async_flag: the mode
:return: number of requests
"""
# group by entity url
batches: Dict[str, List[_Chunk]] = defaultdict(
lambda: [_Chunk(items=[])]
) # Initialize with one empty Chunk

for mcp in mcps:
request = self._to_openapi_request(mcp, async_flag, async_default=True)
if request:
current_chunk = batches[request[0]][-1] # Get the last chunk
# Only serialize once
serialized_item = json.dumps(request[1][0])
item_bytes = len(serialized_item.encode())

# If adding this item would exceed max_bytes, create a new chunk
# Unless the chunk is empty (always add at least one item)
if current_chunk.items and (
current_chunk.total_bytes + item_bytes > INGEST_MAX_PAYLOAD_BYTES
or len(current_chunk.items) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
):
new_chunk = _Chunk(items=[])
batches[request[0]].append(new_chunk)
current_chunk = new_chunk

current_chunk.add_item(serialized_item)

responses = []
for url, chunks in batches.items():
for chunk in chunks:
response = self._emit_generic(url, payload=_Chunk.join(chunk))
responses.append(response)

return len(responses)

def _emit_restli_mcps(
self,
mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
async_flag: Optional[bool] = None,
) -> int:
url = f"{self._gms_server}/aspects?action=ingestProposalBatch"

mcp_objs = [pre_json_transform(mcp.to_obj()) for mcp in mcps]

# As a safety mechanism, we need to make sure we don't exceed the max payload size for GMS.
Expand Down Expand Up @@ -392,7 +529,10 @@ def emit_usage(self, usageStats: UsageAggregation) -> None:
payload = json.dumps(snapshot)
self._emit_generic(url, payload)

def _emit_generic(self, url: str, payload: str) -> None:
def _emit_generic(self, url: str, payload: Union[str, Any]) -> requests.Response:
if not isinstance(payload, str):
payload = json.dumps(payload)

curl_command = make_curl_command(self._session, "POST", url, payload)
payload_size = len(payload)
if payload_size > INGEST_MAX_PAYLOAD_BYTES:
Expand All @@ -408,6 +548,7 @@ def _emit_generic(self, url: str, payload: str) -> None:
try:
response = self._session.post(url, data=payload)
response.raise_for_status()
return response
except HTTPError as e:
try:
info: Dict = response.json()
Expand Down
15 changes: 15 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
)


class RestSinkEndpoint(ConfigEnum):
RESTLI = auto()
OPENAPI = auto()


class RestSinkMode(ConfigEnum):
SYNC = auto()
ASYNC = auto()
Expand All @@ -64,8 +69,15 @@ class RestSinkMode(ConfigEnum):
)


_DEFAULT_REST_SINK_ENDPOINT = pydantic.parse_obj_as(
RestSinkEndpoint,
os.getenv("DATAHUB_REST_SINK_DEFAULT_ENDPOINT", RestSinkEndpoint.RESTLI),
)


class DatahubRestSinkConfig(DatahubClientConfig):
mode: RestSinkMode = _DEFAULT_REST_SINK_MODE
endpoint: RestSinkEndpoint = _DEFAULT_REST_SINK_ENDPOINT

# These only apply in async modes.
max_threads: pydantic.PositiveInt = _DEFAULT_REST_SINK_MAX_THREADS
Expand Down Expand Up @@ -172,6 +184,9 @@ def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
ca_certificate_path=config.ca_certificate_path,
client_certificate_path=config.client_certificate_path,
disable_ssl_verification=config.disable_ssl_verification,
openapi_ingestion=True
if config.endpoint == RestSinkEndpoint.OPENAPI
else False,
)

@property
Expand Down
Loading

0 comments on commit 8b008a4

Please sign in to comment.