diff --git a/metadata-ingestion/src/datahub/cli/check_cli.py b/metadata-ingestion/src/datahub/cli/check_cli.py
index 6b3124fc37393a..b47dfdaf95e9f5 100644
--- a/metadata-ingestion/src/datahub/cli/check_cli.py
+++ b/metadata-ingestion/src/datahub/cli/check_cli.py
@@ -5,7 +5,8 @@
 import pprint
 import shutil
 import tempfile
-from typing import Dict, List, Optional, Union
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
 
 import click
 
@@ -20,7 +21,10 @@
 from datahub.ingestion.source.source_registry import source_registry
 from datahub.ingestion.transformer.transform_registry import transform_registry
 from datahub.telemetry import telemetry
-from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedList
+from datahub.utilities.file_backed_collections import (
+    ConnectionWrapper,
+    FileBackedDict,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -391,29 +395,78 @@ def test_path_spec(config: str, input: str, path_spec_key: str) -> None:
         raise e
 
 
+def _jsonify(data: Any) -> Any:
+    if dataclasses.is_dataclass(data):
+        # dataclasses.asdict() recurses on its own, but it would not route
+        # nested values back through _jsonify (e.g. for datetime handling),
+        # so we do the recursion manually here instead.
+        return {
+            f.name: _jsonify(getattr(data, f.name)) for f in dataclasses.fields(data)
+        }
+    elif isinstance(data, list):
+        return [_jsonify(item) for item in data]
+    elif isinstance(data, dict):
+        return {_jsonify(k): _jsonify(v) for k, v in data.items()}
+    elif isinstance(data, datetime):
+        return data.isoformat()
+    else:
+        return data
+
+
 @check.command()
-@click.argument("query-log-file", type=click.Path(exists=True, dir_okay=False))
-@click.option("--output", type=click.Path())
-def extract_sql_agg_log(query_log_file: str, output: Optional[str]) -> None:
+@click.argument("db-file", type=click.Path(exists=True, dir_okay=False))
+def extract_sql_agg_log(db_file: str) -> None:
     """Convert a sqlite db generated by the SqlParsingAggregator into a JSON."""
 
-    from datahub.sql_parsing.sql_parsing_aggregator import LoggedQuery
+    if pathlib.Path(db_file).suffix != ".db":
+        raise click.UsageError("DB file must be a sqlite db")
+
+    output_dir = pathlib.Path(db_file).with_suffix("")
+    output_dir.mkdir(exist_ok=True)
+
+    shared_connection = ConnectionWrapper(pathlib.Path(db_file))
+
+    tables: List[str] = [
+        row[0]
+        for row in shared_connection.execute(
+            """\
+SELECT
+    name
+FROM
+    sqlite_schema
+WHERE
+    type ='table' AND
+    name NOT LIKE 'sqlite_%';
+""",
+            parameters={},
+        )
+    ]
+    logger.info(f"Extracting {len(tables)} tables from {db_file}: {tables}")
+
+    for table in tables:
+        table_output_path = output_dir / f"{table}.json"
+        if table_output_path.exists():
+            logger.info(f"Skipping {table_output_path} because it already exists")
+            continue
 
-    assert dataclasses.is_dataclass(LoggedQuery)
 
+        # Some of the tables might actually be FileBackedLists. Because the
+        # list is built on top of FileBackedDict, we don't need to
+        # distinguish between the two cases here.
-    shared_connection = ConnectionWrapper(pathlib.Path(query_log_file))
-    query_log = FileBackedList[LoggedQuery](
-        shared_connection=shared_connection, tablename="stored_queries"
-    )
-    logger.info(f"Extracting {len(query_log)} queries from {query_log_file}")
-    queries = [dataclasses.asdict(query) for query in query_log]
+        table_data: FileBackedDict[Any] = FileBackedDict(
+            shared_connection=shared_connection, tablename=table
+        )
 
-    if output:
-        with open(output, "w") as f:
-            json.dump(queries, f, indent=2, default=str)
-        logger.info(f"Extracted {len(queries)} queries to {output}")
-    else:
-        click.echo(json.dumps(queries, indent=2))
+        data = {}
+        with click.progressbar(
+            table_data.items(), length=len(table_data), label=f"Extracting {table}"
+        ) as items:
+            for k, v in items:
+                data[k] = _jsonify(v)
+
+        with open(table_output_path, "w") as f:
+            json.dump(data, f, indent=2, default=str)
+        logger.info(f"Extracted {len(data)} entries to {table_output_path}")
 
 
 @check.command()
diff --git a/metadata-ingestion/tests/unit/sql_parsing/aggregator_goldens/test_basic_lineage_query_log.json b/metadata-ingestion/tests/unit/sql_parsing/aggregator_goldens/test_basic_lineage_query_log.json
index e8e72bf25d3039..770e3220630235 100644
--- a/metadata-ingestion/tests/unit/sql_parsing/aggregator_goldens/test_basic_lineage_query_log.json
+++ b/metadata-ingestion/tests/unit/sql_parsing/aggregator_goldens/test_basic_lineage_query_log.json
@@ -1,5 +1,5 @@
-[
-  {
+{
+  "0": {
     "query": "create table foo as select a, b from bar",
     "session_id": null,
     "timestamp": null,
@@ -7,4 +7,4 @@
     "default_db": "dev",
     "default_schema": "public"
   }
-]
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py
index 6768a431eb9f88..19e28a67ba061d 100644
--- a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py
+++ b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py
@@ -41,14 +41,13 @@ def _ts(ts: int) -> datetime:
     return datetime.fromtimestamp(ts, tz=timezone.utc)
 
 
-@freeze_time(FROZEN_TIME)
-def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+def make_basic_aggregator(store: bool = False) -> SqlParsingAggregator:
     aggregator = SqlParsingAggregator(
         platform="redshift",
         generate_lineage=True,
         generate_usage_statistics=False,
         generate_operations=False,
-        query_log=QueryLogSetting.STORE_ALL,
+        query_log=QueryLogSetting.STORE_ALL if store else QueryLogSetting.DISABLED,
     )
 
     aggregator.add_observed_query(
@@ -59,6 +58,12 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> N
         )
     )
 
+    return aggregator
+
+
+@freeze_time(FROZEN_TIME)
+def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    aggregator = make_basic_aggregator()
     mcps = list(aggregator.gen_metadata())
 
     check_goldens_stream(
@@ -66,19 +71,23 @@ def test_basic_lineage(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> N
         golden_path=RESOURCE_DIR / "test_basic_lineage.json",
     )
 
-    # This test also validates the query log storage functionality.
+
+@freeze_time(FROZEN_TIME)
+def test_aggregator_dump(pytestconfig: pytest.Config, tmp_path: pathlib.Path) -> None:
+    # Validates the query log storage + extraction functionality.
+    aggregator = make_basic_aggregator(store=True)
     aggregator.close()
+
     query_log_db = aggregator.report.query_log_path
-    query_log_json = tmp_path / "query_log.json"
-    run_datahub_cmd(
-        [
-            "check",
-            "extract-sql-agg-log",
-            str(query_log_db),
-            "--output",
-            str(query_log_json),
-        ]
-    )
+    assert query_log_db is not None
+
+    run_datahub_cmd(["check", "extract-sql-agg-log", query_log_db])
+
+    output_json_dir = pathlib.Path(query_log_db).with_suffix("")
+    assert (
+        len(list(output_json_dir.glob("*.json"))) > 5
+    )  # 5 is arbitrary, but the dump should contain at least a handful of tables
+    query_log_json = output_json_dir / "stored_queries.json"
     mce_helpers.check_golden_file(
         pytestconfig, query_log_json, RESOURCE_DIR / "test_basic_lineage_query_log.json"
    )
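
For reference, a minimal usage sketch of the reworked command (not part of the patch). The CLI invocation and the output-directory convention (db filename with the .db suffix stripped, one <table>.json per sqlite table) follow the diff above; the db path is a hypothetical example, and in practice it would come from the aggregator's report, as in the test.

import json
import pathlib
import subprocess

# Hypothetical path to a SqlParsingAggregator sqlite db; substitute your own.
db_file = pathlib.Path("/tmp/datahub/sql_aggregator.db")

# Equivalent to: datahub check extract-sql-agg-log /tmp/datahub/sql_aggregator.db
subprocess.run(["datahub", "check", "extract-sql-agg-log", str(db_file)], check=True)

# The command writes one JSON file per table into a sibling directory named
# after the db file with its suffix stripped.
output_dir = db_file.with_suffix("")
for table_json in sorted(output_dir.glob("*.json")):
    entries = json.loads(table_json.read_text())
    print(f"{table_json.name}: {len(entries)} entries")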