feat(ingestion-tracing): implement ingestion integration with tracing api
commit b8f315b (1 parent: 7bee19c)
Showing 8 changed files with 1,502 additions and 16 deletions.
metadata-ingestion/src/datahub/emitter/openapi_emitter.py
188 changes: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
import json
import logging
from collections import defaultdict
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, List, Optional, Sequence, Union

from requests import Response

from datahub.cli.cli_utils import ensure_has_system_metadata
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.openapi_tracer import OpenAPITrace
from datahub.emitter.response_helper import extract_trace_data
from datahub.emitter.rest_emitter import (
    _DATAHUB_EMITTER_TRACE,
    BATCH_INGEST_MAX_PAYLOAD_LENGTH,
    INGEST_MAX_PAYLOAD_BYTES,
    DataHubRestEmitter,
)
from datahub.emitter.serialization_helper import pre_json_transform
from datahub.metadata.com.linkedin.pegasus2avro.mxe import (
    MetadataChangeProposal,
)

logger = logging.getLogger(__name__)

@dataclass
class Chunk:
    items: List[str]
    total_bytes: int = 0

    def add_item(self, item: str) -> bool:
        # A chunk always accepts the item it is given, even if that single item
        # exceeds the byte limit; the caller decides when to start a new chunk.
        self.items.append(item)
        self.total_bytes += len(item.encode())
        return True

    @staticmethod
    def join(chunk: "Chunk") -> str:
        return "[" + ",".join(chunk.items) + "]"

class DataHubOpenApiEmitter(DataHubRestEmitter, OpenAPITrace):
    def __init__(
        self,
        gms_server: str,
        token: Optional[str] = None,
        timeout_sec: Optional[float] = None,
        connect_timeout_sec: Optional[float] = None,
        read_timeout_sec: Optional[float] = None,
        retry_status_codes: Optional[List[int]] = None,
        retry_methods: Optional[List[str]] = None,
        retry_max_times: Optional[int] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        ca_certificate_path: Optional[str] = None,
        client_certificate_path: Optional[str] = None,
        disable_ssl_verification: bool = False,
        default_trace_mode: bool = False,
    ):
        super().__init__(
            gms_server,
            token,
            timeout_sec,
            connect_timeout_sec,
            read_timeout_sec,
            retry_status_codes,
            retry_methods,
            retry_max_times,
            extra_headers,
            ca_certificate_path,
            client_certificate_path,
            disable_ssl_verification,
            default_trace_mode,
        )

    def _emit_generic(self, url: str, payload: dict) -> Response:
        return super()._emit_generic(url, payload=json.dumps(payload))

    def _to_request(
        self,
        mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
        async_flag: Optional[bool] = None,
        async_default: bool = False,
    ):
        resolved_async_flag = async_flag if async_flag is not None else async_default
        url = f"{self._gms_server}/openapi/v3/entity/{mcp.entityType}?async={'true' if resolved_async_flag else 'false'}"
        ensure_has_system_metadata(mcp)
        aspect_value = pre_json_transform(mcp.aspect.to_obj())
        return (
            url,
            [
                {
                    "urn": mcp.entityUrn,
                    mcp.aspectName: {
                        "value": aspect_value,
                        "systemMetadata": mcp.systemMetadata.to_obj(),
                    },
                }
            ],
        )

    def emit_mcp(
        self,
        mcp: Union[MetadataChangeProposal, MetadataChangeProposalWrapper],
        async_flag: Optional[bool] = None,
        trace_flag: Optional[bool] = None,
        trace_timeout: Optional[timedelta] = timedelta(seconds=3600),
    ) -> None:
        request = self._to_request(mcp, async_flag)

        response = self._emit_generic(request[0], payload=request[1])

        if self._should_trace(async_flag, trace_flag):
            trace_data = extract_trace_data(response) if response else None
            if trace_data:
                self.await_status([trace_data], trace_timeout)

    def emit_mcps(
        self,
        mcps: Sequence[Union[MetadataChangeProposal, MetadataChangeProposalWrapper]],
        async_flag: Optional[bool] = None,
        trace_flag: Optional[bool] = None,
        trace_timeout: Optional[timedelta] = timedelta(seconds=3600),
    ) -> int:
""" | ||
1. Grouping MCPs by their entity URL | ||
2. Breaking down large batches into smaller chunks based on both: | ||
* Total byte size (INGEST_MAX_PAYLOAD_BYTES) | ||
* Maximum number of items (BATCH_INGEST_MAX_PAYLOAD_LENGTH) | ||
The Chunk class encapsulates both the items and their byte size tracking | ||
Serializing the items only once with json.dumps(request[1]) and reusing that | ||
The chunking logic handles edge cases (always accepting at least one item per chunk) | ||
The joining logic is efficient with a simple string concatenation | ||
:param mcps: metadata change proposals to transmit | ||
:param async_flag: the mode | ||
:return: | ||
""" | ||
        if _DATAHUB_EMITTER_TRACE:
            logger.debug(f"Attempting to emit MCP batch of size {len(mcps)}")

        # Group by entity URL; each URL starts with one empty Chunk
        batches = defaultdict(lambda: [Chunk(items=[])])

        for mcp in mcps:
            request = self._to_request(mcp, async_flag, async_default=True)
            current_chunk = batches[request[0]][-1]  # Get the last chunk
            # Only serialize once
            serialized_item = json.dumps(request[1][0])
            item_bytes = len(serialized_item.encode())

            # If adding this item would exceed the byte or item limit, create a
            # new chunk -- unless the chunk is empty (always add at least one item)
            if current_chunk.items and (
                current_chunk.total_bytes + item_bytes > INGEST_MAX_PAYLOAD_BYTES
                or len(current_chunk.items) >= BATCH_INGEST_MAX_PAYLOAD_LENGTH
            ):
                new_chunk = Chunk(items=[])
                batches[request[0]].append(new_chunk)
                current_chunk = new_chunk

            current_chunk.add_item(serialized_item)

        responses = []
        for url, chunks in batches.items():
            for chunk in chunks:
                response = super()._emit_generic(url, payload=Chunk.join(chunk))
                responses.append(response)

        if self._should_trace(async_flag, trace_flag, async_default=True):
            trace_data = []
            for response in responses:
                data = extract_trace_data(response) if response else None
                if data is not None:
                    trace_data.append(data)

            if trace_data:
                self.await_status(trace_data, trace_timeout)

        return len(responses)
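
For context, a minimal usage sketch of the new emitter follows. It is illustrative only and not part of the diff: the GMS URL, token, dataset URNs, and the Status aspect are assumptions chosen for the example, while the class and method names come from the file above.

from datahub.emitter.mce_builder import make_dataset_urn
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.openapi_emitter import DataHubOpenApiEmitter
from datahub.metadata.schema_classes import StatusClass

# Assumption: a GMS instance is reachable at this address; token is optional.
emitter = DataHubOpenApiEmitter(gms_server="http://localhost:8080", token=None)

mcps = [
    MetadataChangeProposalWrapper(
        entityUrn=make_dataset_urn(platform="hive", name=f"db.table_{i}", env="PROD"),
        aspect=StatusClass(removed=False),  # example aspect only
    )
    for i in range(3)
]

# Batched, asynchronous ingestion. With trace_flag=True the emitter extracts
# trace data from each response and polls the tracing API (await_status, below)
# until the writes are confirmed in primary and search storage.
request_count = emitter.emit_mcps(mcps, async_flag=True, trace_flag=True)
print(f"sent {request_count} chunked request(s)")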
metadata-ingestion/src/datahub/emitter/openapi_tracer.py
97 changes: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
import logging
import time
from datetime import datetime, timedelta
from typing import List

from datahub.configuration.common import (
    OperationalError,
)
from datahub.emitter.response_helper import TraceData

logger = logging.getLogger(__name__)

PENDING_STATUS = "PENDING"
INITIAL_BACKOFF = 1.0  # Start with 1 second
MAX_BACKOFF = 300.0  # Cap at 5 minutes
BACKOFF_FACTOR = 2.0  # Double the wait time each attempt


class OpenAPITrace:
    def await_status(
        self,
        trace_data: List[TraceData],
        trace_timeout: timedelta,
    ) -> None:
        """Verify the status of asynchronous write operations.

        Args:
            trace_data: List of trace data to verify
            trace_timeout: Maximum time to wait for verification.

        Raises:
            OperationalError: If verification fails or times out
        """
        try:
            if not trace_data:
                logger.debug("No trace data to verify")
                return

            start_time = datetime.now()

            for trace in trace_data:
                current_backoff = INITIAL_BACKOFF

                while trace.data:
                    if datetime.now() - start_time > trace_timeout:
                        raise OperationalError(
                            f"Timeout waiting for async write completion after {trace_timeout.total_seconds()} seconds"
                        )

                    base_url = f"{self._gms_server}/openapi/v1/trace/write"
                    url = f"{base_url}/{trace.trace_id}?onlyIncludeErrors=false&detailed=true"

                    response = self._emit_generic(url, payload=trace.data)
                    json_data = response.json()

                    for urn, aspects in json_data.items():
                        for aspect_name, aspect_status in aspects.items():
                            if not aspect_status["success"]:
                                error_msg = (
                                    f"Unable to validate async write to DataHub GMS: "
                                    f"Persistence failure for URN '{urn}' aspect '{aspect_name}'. "
                                    f"Status: {aspect_status}"
                                )
                                raise OperationalError(error_msg, aspect_status)

                            primary_storage = aspect_status["primaryStorage"][
                                "writeStatus"
                            ]
                            search_storage = aspect_status["searchStorage"][
                                "writeStatus"
                            ]

                            # Remove resolved statuses
                            if (
                                primary_storage != PENDING_STATUS
                                and search_storage != PENDING_STATUS
                            ):
                                trace.data[urn].remove(aspect_name)

                        # Remove urns with all statuses resolved
                        if not trace.data[urn]:
                            trace.data.pop(urn)

                    # Adjust backoff based on response
                    if trace.data:
                        # If we still have pending items, increase backoff
                        current_backoff = min(
                            current_backoff * BACKOFF_FACTOR, MAX_BACKOFF
                        )
                        logger.debug(
                            f"Waiting {current_backoff} seconds before next check"
                        )
                        time.sleep(current_backoff)

        except Exception as e:
            logger.error(f"Error during status verification: {str(e)}")
            raise
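
To make the polling behaviour concrete, the following self-contained sketch (not part of the commit) reproduces the wait schedule implied by the constants above and by the doubling-before-sleep loop in await_status; the helper name backoff_schedule is hypothetical.

from typing import List

INITIAL_BACKOFF = 1.0  # seconds, as defined above
MAX_BACKOFF = 300.0
BACKOFF_FACTOR = 2.0


def backoff_schedule(polls: int) -> List[float]:
    # Mirrors await_status: the backoff is doubled *before* sleeping,
    # so the first wait is 2 seconds, then 4, 8, ... capped at 300.
    waits = []
    current = INITIAL_BACKOFF
    for _ in range(polls):
        current = min(current * BACKOFF_FACTOR, MAX_BACKOFF)
        waits.append(current)
    return waits


print(backoff_schedule(10))
# [2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 300.0, 300.0]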