Commit 298a651
fix test
hsheth2 committed Mar 6, 2025
1 parent f770fef commit 298a651
Showing 3 changed files with 39 additions and 22 deletions.
18 changes: 13 additions & 5 deletions metadata-ingestion/src/datahub/cli/check_cli.py
@@ -397,7 +397,12 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:

def _jsonify(data: Any) -> Any:
if dataclasses.is_dataclass(data):
return dataclasses.asdict(data)
# dataclasses.asdict() is recursive, which is not what we want.
# We're doing the recursion manually here, so we don't want
# dataclasses to do it too.
return {
f.name: _jsonify(getattr(data, f.name)) for f in dataclasses.fields(data)
}
elif isinstance(data, list):
return [_jsonify(item) for item in data]
elif isinstance(data, dict):
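
A tiny standalone sketch of the field-wise recursion the new _jsonify uses; the Point and Segment dataclasses are invented for illustration and are not part of this commit.

import dataclasses
from typing import Any


@dataclasses.dataclass
class Point:  # invented example type
    x: int
    y: int


@dataclasses.dataclass
class Segment:  # invented example type
    label: str
    points: list[Point]


def _jsonify_sketch(data: Any) -> Any:
    if dataclasses.is_dataclass(data):
        # Recurse one field at a time, so nested values keep flowing
        # through this function rather than through dataclasses.asdict().
        return {
            f.name: _jsonify_sketch(getattr(data, f.name))
            for f in dataclasses.fields(data)
        }
    elif isinstance(data, list):
        return [_jsonify_sketch(item) for item in data]
    else:
        return data


print(_jsonify_sketch(Segment("s1", [Point(0, 0), Point(1, 2)])))
# -> {'label': 's1', 'points': [{'x': 0, 'y': 0}, {'x': 1, 'y': 2}]}
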
@@ -417,9 +422,7 @@ def extract_sql_agg_log(db_file: str) -> None:
raise click.UsageError("DB file must be a sqlite db")

output_dir = pathlib.Path(db_file).with_suffix("")
if output_dir.exists():
raise click.UsageError(f"Output directory {output_dir} already exists")
output_dir.mkdir()
output_dir.mkdir(exist_ok=True)

shared_connection = ConnectionWrapper(pathlib.Path(db_file))

@@ -441,9 +444,15 @@ def extract_sql_agg_log(db_file: str) -> None:
logger.info(f"Extracting {len(tables)} tables from {db_file}: {tables}")

for table in tables:
table_output_path = output_dir / f"{table}.json"
if table_output_path.exists():
logger.info(f"Skipping {table_output_path} because it already exists")
continue

# Some of the tables might actually be FileBackedList. Because
# the list is built on top of the FileBackedDict, we don't
# need to distinguish between the two cases.

table_data: FileBackedDict[Any] = FileBackedDict(
shared_connection=shared_connection, tablename=table
)
@@ -455,7 +464,6 @@ def extract_sql_agg_log(db_file: str) -> None:
for k, v in items:
data[k] = _jsonify(v)

table_output_path = output_dir / f"{table}.json"
with open(table_output_path, "w") as f:
json.dump(data, f, indent=2, default=str)
logger.info(f"Extracted {len(data)} entries to {table_output_path}")
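
Taken together, the changes above make the extraction command safe to re-run. A minimal sketch of the new directory handling, assuming illustrative file and table names (the real table names come from the aggregator's sqlite file):

import pathlib

db_file = "aggregator_query_log.db"                 # hypothetical input path
output_dir = pathlib.Path(db_file).with_suffix("")  # -> aggregator_query_log/
output_dir.mkdir(exist_ok=True)                     # no longer an error if the directory exists

for table in ["stored_queries", "another_table"]:   # placeholder table names
    table_output_path = output_dir / f"{table}.json"
    if table_output_path.exists():
        # On a re-run, tables that were already dumped are skipped.
        continue
    table_output_path.write_text("{}")              # stand-in for the real JSON dump
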
6 changes: 3 additions & 3 deletions test_basic_lineage_query_log.json
@@ -1,10 +1,10 @@
-[
-{
+{
+  "0": {
     "query": "create table foo as select a, b from bar",
     "session_id": null,
     "timestamp": null,
     "user": null,
     "default_db": "dev",
     "default_schema": "public"
   }
-]
+}
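
The golden file's new shape follows from the FileBackedList comment in check_cli.py above: the list is stored on top of a FileBackedDict, so dumping the underlying table yields entries keyed by stringified index. A conceptual sketch only, not the real FileBackedList implementation:

import json

stored_queries = ["create table foo as select a, b from bar"]
as_table = {str(i): query for i, query in enumerate(stored_queries)}
print(json.dumps(as_table, indent=2))
# {
#   "0": "create table foo as select a, b from bar"
# }
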
37 changes: 23 additions & 14 deletions metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py
@@ -41,14 +41,13 @@ def _ts(ts: int) -> datetime:
return datetime.fromtimestamp(ts, tz=timezone.utc)


@freeze_time(FROZEN_TIME)
def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
def make_basic_aggregator(store: bool = False) -> SqlParsingAggregator:
aggregator = SqlParsingAggregator(
platform="redshift",
generate_lineage=True,
generate_usage_statistics=False,
generate_operations=False,
query_log=QueryLogSetting.STORE_ALL,
query_log=QueryLogSetting.STORE_ALL if store else QueryLogSetting.DISABLED,
)

aggregator.add_observed_query(
@@ -59,26 +58,36 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
)
)

return aggregator


@freeze_time(FROZEN_TIME)
def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
aggregator = make_basic_aggregator()
mcps = list(aggregator.gen_metadata())

check_goldens_stream(
outputs=mcps,
golden_path=RESOURCE_DIR / "test_basic_lineage.json",
)

# This test also validates the query log storage functionality.

@freeze_time(FROZEN_TIME)
def test_aggregator_dump(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
# Validates the query log storage + extraction functionality.
aggregator = make_basic_aggregator(store=True)
aggregator.close()

query_log_db = aggregator.report.query_log_path
query_log_json = tmp_path / "query_log.json"
run_datahub_cmd(
[
"check",
"extract-sql-agg-log",
str(query_log_db),
"--output",
str(query_log_json),
]
)
assert query_log_db is not None

run_datahub_cmd(["check", "extract-sql-agg-log", query_log_db])

output_json_dir = pathlib.Path(query_log_db).with_suffix("")
assert (
len(list(output_json_dir.glob("*.json"))) > 5
) # 5 is arbitrary, but should have at least a couple tables
query_log_json = output_json_dir / "stored_queries.json"
mce_helpers.check_golden_file(
pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
)
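
For reference, a hand-run sketch of the flow the new test_aggregator_dump exercises; the db path is illustrative, and the output layout follows the check_cli.py changes above:

# Run the CLI against the aggregator's query log, e.g.:
#   datahub check extract-sql-agg-log query_log.db
import json
import pathlib

db_file = pathlib.Path("query_log.db")  # hypothetical value of aggregator.report.query_log_path
output_dir = db_file.with_suffix("")    # the CLI writes one <table>.json per table in here

with open(output_dir / "stored_queries.json") as f:
    stored_queries = json.load(f)

# Entries are keyed by stringified index, matching the updated golden file.
print(stored_queries["0"]["query"])     # "create table foo as select a, b from bar"
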
