feat(ingest): improve extract-sql-agg-log command #12803

Merged
merged 3 commits on Mar 6, 2025
Changes from 2 commits
91 changes: 72 additions & 19 deletions metadata-ingestion/src/datahub/cli/check_cli.py
@@ -5,7 +5,8 @@
 import pprint
 import shutil
 import tempfile
-from typing import Dict, List, Optional, Union
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
 
 import click
 
@@ -20,7 +21,10 @@
 from datahub.ingestion.source.source_registry import source_registry
 from datahub.ingestion.transformer.transform_registry import transform_registry
 from datahub.telemetry import telemetry
-from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
+from datahub.utilities.file_backed_collections import (
+    ConnectionWrapper,
+    FileBackedDict,
+)
 
 logger = logging.getLogger(__name__)

@@ -391,29 +395,78 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
         raise e
 
 
+def _jsonify(data: Any) -> Any:
+    if dataclasses.is_dataclass(data):
+        # dataclasses.asdict() is recursive, which is not what we want.

Contributor: Why don't we want dataclasses' recursion if we're doing the recursion here as well?

Collaborator (Author): Updated the comment to clarify.

+        # We're doing the recursion manually here, so we don't want
+        # dataclasses to do it too.
+        return {
+            f.name: _jsonify(getattr(data, f.name)) for f in dataclasses.fields(data)
+        }
+    elif isinstance(data, list):
+        return [_jsonify(item) for item in data]
+    elif isinstance(data, dict):
+        return {_jsonify(k): _jsonify(v) for k, v in data.items()}
+    elif isinstance(data, datetime):
+        return data.isoformat()
+    else:
+        return data
+
+
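An illustrative aside on the review thread above, not part of the PR (the Query and Run dataclasses are hypothetical stand-ins): dataclasses.asdict() would recurse into nested dataclasses on its own but leave values such as datetime untouched, whereas the manual recursion in _jsonify visits every nested value itself, so nested timestamps come out as ISO strings.

import dataclasses
from datetime import datetime, timezone


@dataclasses.dataclass
class Query:
    text: str
    timestamp: datetime


@dataclasses.dataclass
class Run:
    queries: list


run = Run(queries=[Query("select 1", datetime(2025, 3, 6, tzinfo=timezone.utc))])

# dataclasses.asdict() recurses on its own, but leaves the nested datetime
# untouched, so json.dumps() on the result would still need default=str:
print(dataclasses.asdict(run))
# {'queries': [{'text': 'select 1', 'timestamp': datetime.datetime(2025, 3, 6, 0, 0, tzinfo=datetime.timezone.utc)}]}

# _jsonify() above walks the fields itself, so the nested datetime is
# converted along the way:
# {'queries': [{'text': 'select 1', 'timestamp': '2025-03-06T00:00:00+00:00'}]}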
 @check.command()
-@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
-@click.option("--output", type=click.Path())
-def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
+@click.argument("db-file", type=click.Path(exists=True, dir_okay=False))
+def extract_sql_agg_log(db_file: str) -> None:
     """Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
 
-    from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
+    if pathlib.Path(db_file).suffix != ".db":
+        raise click.UsageError("DB file must be a sqlite db")
+
+    output_dir = pathlib.Path(db_file).with_suffix("")
+    output_dir.mkdir(exist_ok=True)
+
+    shared_connection = ConnectionWrapper(pathlib.Path(db_file))
+
+    tables: List[str] = [
+        row[0]
+        for row in shared_connection.execute(
+            """\
+SELECT
+    name
+FROM
+    sqlite_schema
+WHERE
+    type ='table' AND
+    name NOT LIKE 'sqlite_%';
+""",
+            parameters={},
+        )
+    ]
+    logger.info(f"Extracting {len(tables)} tables from {db_file}: {tables}")
+
+    for table in tables:
+        table_output_path = output_dir / f"{table}.json"
+        if table_output_path.exists():
+            logger.info(f"Skipping {table_output_path} because it already exists")
+            continue
 
-    assert dataclasses.is_dataclass(LoggedQuery)
+        # Some of the tables might actually be FileBackedList. Because
+        # the list is built on top of the FileBackedDict, we don't
+        # need to distinguish between the two cases.
 
-    shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
-    query_log = FileBackedList[LoggedQuery](
-        shared_connection=shared_connection, tablename="stored_queries"
-    )
-    logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
-    queries = [dataclasses.asdict(query) for query in query_log]
+        table_data: FileBackedDict[Any] = FileBackedDict(
+            shared_connection=shared_connection, tablename=table
+        )
 
-    if output:
-        with open(output, "w") as f:
-            json.dump(queries, f, indent=2, default=str)
-        logger.info(f"Extracted {len(queries)} queries to {output}")
-    else:
-        click.echo(json.dumps(queries, indent=2))
+        data = {}
+        with click.progressbar(
+            table_data.items(), length=len(table_data), label=f"Extracting {table}"
+        ) as items:
+            for k, v in items:
+                data[k] = _jsonify(v)
+
+        with open(table_output_path, "w") as f:
+            json.dump(data, f, indent=2, default=str)
+        logger.info(f"Extracted {len(data)} entries to {table_output_path}")
 
 
 @check.command()
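An illustrative sketch, not part of the PR, of why the command can treat every table uniformly and why the golden file for test_basic_lineage_query_log.json below changes from a JSON array to an object: per the comment above, a FileBackedList is built on top of a FileBackedDict, so its backing table can be read back as a FileBackedDict whose keys are the stored item indices (hence the top-level "0" key). The helper name is hypothetical; the constructor calls mirror the ones in the diff.

import pathlib
from typing import Any

from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict


def dump_stored_query_keys(db_file: str) -> list:
    # "stored_queries" is the table the SqlParsingAggregator uses for its query log.
    conn = ConnectionWrapper(pathlib.Path(db_file))
    table: FileBackedDict[Any] = FileBackedDict(
        shared_connection=conn, tablename="stored_queries"
    )
    # The keys are the former list indices; they become the top-level keys
    # of stored_queries.json ("0", "1", ...).
    return [key for key, _ in table.items()]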
@@ -1,10 +1,10 @@
-[
-  {
+{
+  "0": {
     "query": "create table foo as select a, b from bar",
     "session_id": null,
     "timestamp": null,
     "user": null,
     "default_db": "dev",
     "default_schema": "public"
   }
-]
+}
37 changes: 23 additions & 14 deletions metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py
@@ -41,14 +41,13 @@ def _ts(ts: int) -> datetime:
     return datetime.fromtimestamp(ts, tz=timezone.utc)
 
 
-@freeze_time(FROZEN_TIME)
-def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+def make_basic_aggregator(store: bool = False) -> SqlParsingAggregator:
     aggregator = SqlParsingAggregator(
         platform="redshift",
         generate_lineage=True,
         generate_usage_statistics=False,
         generate_operations=False,
-        query_log=QueryLogSetting.STORE_ALL,
+        query_log=QueryLogSetting.STORE_ALL if store else QueryLogSetting.DISABLED,
     )
 
     aggregator.add_observed_query(
@@ -59,26 +58,36 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
         )
     )
 
+    return aggregator
+
+
+@freeze_time(FROZEN_TIME)
+def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    aggregator = make_basic_aggregator()
     mcps = list(aggregator.gen_metadata())
 
     check_goldens_stream(
         outputs=mcps,
         golden_path=RESOURCE_DIR / "test_basic_lineage.json",
     )
 
-    # This test also validates the query log storage functionality.
+
+@freeze_time(FROZEN_TIME)
+def test_aggregator_dump(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    # Validates the query log storage + extraction functionality.
+    aggregator = make_basic_aggregator(store=True)
     aggregator.close()
+
     query_log_db = aggregator.report.query_log_path
-    query_log_json = tmp_path / "query_log.json"
-    run_datahub_cmd(
-        [
-            "check",
-            "extract-sql-agg-log",
-            str(query_log_db),
-            "--output",
-            str(query_log_json),
-        ]
-    )
+    assert query_log_db is not None
+
+    run_datahub_cmd(["check", "extract-sql-agg-log", query_log_db])
+
+    output_json_dir = pathlib.Path(query_log_db).with_suffix("")
+    assert (
+        len(list(output_json_dir.glob("*.json"))) > 5
+    )  # 5 is arbitrary, but should have at least a couple tables
+    query_log_json = output_json_dir / "stored_queries.json"
     mce_helpers.check_golden_file(
         pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
     )
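For reference, a small sketch of the path handling the new test relies on, with hypothetical paths (not part of the PR): the command strips the .db suffix to derive the output directory and writes one <table>.json file per table, which is what the glob("*.json") assertion above counts.

import pathlib

# Hypothetical location of the aggregator's query log; in the test it comes
# from aggregator.report.query_log_path.
db_file = pathlib.Path("/tmp/aggregator/query_log.db")

output_dir = db_file.with_suffix("")                  # /tmp/aggregator/query_log
stored_queries = output_dir / "stored_queries.json"   # one JSON file per table

print(output_dir, stored_queries)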