diff --git a/fhir_query/__init__.py b/fhir_query/__init__.py index ba023f3..0cc58eb 100644 --- a/fhir_query/__init__.py +++ b/fhir_query/__init__.py @@ -3,10 +3,10 @@ import json import logging import sqlite3 -import sys +from nested_lookup import nested_lookup import tempfile from collections import defaultdict -from typing import Generator, Any, Optional +from typing import Any, Optional, Callable import httpx from dotty_dict import dotty @@ -47,15 +47,23 @@ def _initialize_table(self) -> None: CREATE TABLE IF NOT EXISTS resources ( id VARCHAR NOT NULL, resource_type VARCHAR NOT NULL, + key VARCHAR NOT NULL, resource JSON NOT NULL, PRIMARY KEY (id, resource_type) ) - """ + """ + ) + self.connection.execute( + """ + CREATE INDEX IF NOT EXISTS idx_resources_key + ON resources (key) + """ ) def add(self, resource: dict[str, Any]) -> None: """ Add a resource to the 'resources' table. + Add a resource to the 'resources' table. :param resource: A dictionary with 'id', 'resourceType', and other fields. """ if "id" not in resource or "resourceType" not in resource: @@ -65,10 +73,15 @@ def add(self, resource: dict[str, Any]) -> None: with self.connection: self.connection.execute( """ - INSERT INTO resources (id, resource_type, resource) - VALUES (?, ?, ?) + INSERT INTO resources (id, resource_type, key, resource) + VALUES (?, ?, ?, ?) """, - (resource["id"], resource["resourceType"], json.dumps(resource)), + ( + resource["id"], + resource["resourceType"], + f'{resource["resourceType"]}/{resource["id"]}', + json.dumps(resource), + ), ) self.adds_counters[resource["resourceType"]] += 1 except sqlite3.IntegrityError as e: @@ -132,6 +145,38 @@ def close(self) -> None: """ self.connection.close() + def aggregate(self) -> dict: + """Aggregate metadata counts resourceType(count)-count->resourceType(count).""" + + nested_dict: Callable[[], defaultdict[str, defaultdict]] = lambda: defaultdict(defaultdict) + + count_resource_types = self.count_resource_types() + + summary = nested_dict() + + for resource_type in count_resource_types: + resources = self.all_resources(resource_type) + for _ in resources: + + if "count" not in summary[resource_type]: + summary[resource_type]["count"] = 0 + summary[resource_type]["count"] += 1 + + refs = nested_lookup("reference", _) + for ref in refs: + # A codeable reference is an object with a codeable concept and a reference + if isinstance(ref, dict): + ref = ref["reference"] + ref_resource_type = ref.split("/")[0] + if "references" not in summary[resource_type]: + summary[resource_type]["references"] = nested_dict() + dst = summary[resource_type]["references"][ref_resource_type] + if "count" not in dst: + dst["count"] = 0 + dst["count"] += 1 + + return summary + class GraphDefinitionRunner(ResourceDB): """ diff --git a/fhir_query/cli.py b/fhir_query/cli.py index b56e1fe..0ca9f81 100644 --- a/fhir_query/cli.py +++ b/fhir_query/cli.py @@ -4,20 +4,28 @@ import sys import click +from click_default_group import DefaultGroup import yaml from fhir.resources.graphdefinition import GraphDefinition from halo import Halo from fhir_query import GraphDefinitionRunner, setup_logging +from fhir_query.visualizer import visualize_aggregation -@click.command() +@click.group(cls=DefaultGroup, default="main") +def cli(): + """Run FHIR GraphDefinition traversal.""" + pass + + +@cli.command() @click.option("--fhir-base-url", required=True, help="Base URL of the FHIR server.") @click.option("--graph-definition-id", help="ID of the GraphDefinition.") @click.option("--graph-definition-file-path", help="Path to the GraphDefinition JSON file.") @click.option("--start-resource-type", required=True, help="ResourceType to start traversal.") @click.option("--start-resource-id", required=True, help="ID of the starting resource.") -@click.option("--db_path", default="/tmp/fhir-graph.sqlite", help="path to of sqlite db") +@click.option("--db-path", default="/tmp/fhir-graph.sqlite", help="path to of sqlite db") @click.option( "--dry-run", is_flag=True, @@ -37,7 +45,7 @@ def main( dry_run: bool, log_file: str, ) -> None: - """Main function to run the FHIR GraphDefinition traversal.""" + """Run FHIR GraphDefinition traversal.""" setup_logging(debug, log_file) @@ -84,5 +92,37 @@ async def run_runner() -> None: raise e +@cli.command(name="visualize") +@click.option("--db-path", default="/tmp/fhir-graph.sqlite", help="path to sqlite db") +@click.option("--output-path", default="/tmp/fhir-graph.html", help="path output html") +def visualize(db_path: str, output_path: str) -> None: + """Visualize the aggregation results.""" + from fhir_query import ResourceDB + + try: + db = ResourceDB(db_path=db_path) + visualize_aggregation(db.aggregate(), output_path) + except Exception as e: + logging.error(f"Error: {e}", exc_info=True) + click.echo(f"Error: {e}", file=sys.stderr) + # raise e + + +@cli.command(name="summarize") +@click.option("--db-path", default="/tmp/fhir-graph.sqlite", help="path to sqlite db") +def summarize(db_path: str) -> None: + """Summarize the aggregation results.""" + from fhir_query import ResourceDB + + try: + db = ResourceDB(db_path=db_path) + yaml.dump(json.loads(json.dumps(db.aggregate())), sys.stdout, default_flow_style=False) + + except Exception as e: + logging.error(f"Error: {e}", exc_info=True) + click.echo(f"Error: {e}", file=sys.stderr) + # raise e + + if __name__ == "__main__": - main() + cli() diff --git a/fhir_query/visualizer.py b/fhir_query/visualizer.py new file mode 100644 index 0000000..f0cd732 --- /dev/null +++ b/fhir_query/visualizer.py @@ -0,0 +1,45 @@ +from pyvis.network import Network + +from fhir_query import ResourceDB + + +def _container(): + """Create a pyvis container.""" + return Network(notebook=True, cdn_resources="in_line") # filter_menu=True, select_menu=True + + +def _load(net: Network, aggregation: dict) -> Network: + """Load the aggregation into the visualization network.""" + # add vertices + for resource_type, _ in aggregation.items(): + assert "count" in _, _ + net.add_node(resource_type, label=f"{resource_type}/{_['count']}") + # add edges + for resource_type, _ in aggregation.items(): + for ref in _.get("references", {}): + count = _["references"][ref]["count"] + if resource_type not in net.get_nodes(): + net.add_node(resource_type, label=f"{resource_type}/?") + if ref not in net.get_nodes(): + net.add_node(ref, label=f"{ref}/?") + net.add_edge(resource_type, ref, title=count, value=count) + return net + + +def visualize_aggregation(aggregation: dict, output_path: str) -> None: + """Visualize the aggregation.""" + # Load it into a pyvis + net = _load(_container(), aggregation) + net.save_graph(str(output_path)) + net.show_buttons(filter_=["physics"]) + + +def create_network_graph(db_path: str, output_path: str) -> None: + """Render metadata as a network graph into output_path. + + \b + db_path: The directory path to the db. + output_path: The path to save the network graph. + """ + db = ResourceDB(db_path=db_path) + visualize_aggregation(db.aggregate(), output_path) diff --git a/pytest.ini b/pytest.ini index c5d9aab..2a8c41a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,4 +4,4 @@ filterwarnings = ignore::DeprecationWarning:halo.* markers = asyncio: asyncio mark -addopts = --cov=fhir_query --cov-report=term-missing \ No newline at end of file +; addopts = --cov=fhir_query --cov-report=term-missing \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9440607..f4425e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,6 @@ httpx fhir.resources==8.0.0b4 dotty-dict +nested_lookup +pyvis +click_default_group \ No newline at end of file diff --git a/setup.py b/setup.py index 4ddfa98..fb949ce 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ def parse_requirements(filename: str) -> list[str]: install_requires=parse_requirements("requirements.txt"), entry_points={ "console_scripts": [ - "fq=fhir_query:cli.main", + "fq=fhir_query:cli.cli", ], }, ) diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index fc7790b..95de108 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -1,19 +1,40 @@ from click.testing import CliRunner -from fhir_query.cli import main +from fhir_query.cli import cli + + +def test_default_option() -> None: + """Test default option.""" + runner = CliRunner() + result = runner.invoke(cli, ["--help"]) + print(result.output) + output = result.output + for _ in ["main*", "summarize", "visualize"]: + assert _ in output, f"Expected {_} in {output}" def test_help_option() -> None: """Test help option.""" runner = CliRunner() - result = runner.invoke(main, "--help") + result = runner.invoke(cli, ["main", "--help"]) output = result.output + print(output) assert "Usage:" in output assert "--fhir-base-url" in output assert "--graph-definition-id" in output assert "--graph-definition-file-path" in output assert "--start-resource-type" in output assert "--start-resource-id" in output - assert "--db_path" in output + assert "--db-path" in output assert "--dry-run" in output assert "--debug" in output assert "--log-file" in output + + +def test_visualize_help() -> None: + """Test visualizer help.""" + runner = CliRunner() + result = runner.invoke(cli, ["visualize", "--help"]) + output = result.output + assert "Usage:" in output + assert "--db-path" in output + assert "--output-path" in output diff --git a/tests/unit/test_mock_server.py b/tests/unit/test_mock_server.py index 2fc4f9f..864d196 100644 --- a/tests/unit/test_mock_server.py +++ b/tests/unit/test_mock_server.py @@ -6,6 +6,7 @@ from click.testing import CliRunner from fhir_query.cli import main +from fhir_query.visualizer import visualize_aggregation @pytest.mark.usefixtures("mock_fhir_server") @@ -36,7 +37,7 @@ def test_runner(tmp_path: str) -> None: "ResearchStudy", "--start-resource-id", "123", - "--db_path", + "--db-path", f"{tmp_path}/fhir-query.sqlite", "--graph-definition-file-path", "tests/fixtures/GraphDefinition.yaml", @@ -59,3 +60,14 @@ def test_runner(tmp_path: str) -> None: db = ResourceDB(f"{tmp_path}/fhir-query.sqlite") assert db.count_resource_types() == {"Patient": 3, "Specimen": 3} + + aggregated = db.aggregate() + assert sorted(aggregated.keys()) == ["Patient", "Specimen"] + assert aggregated["Patient"]["count"] == 3 + assert aggregated["Specimen"]["count"] == 3 + assert aggregated["Specimen"]["references"]["Patient"]["count"] == 3 + + visualize_aggregation(aggregated, f"{tmp_path}/fhir-query.html") + assert pathlib.Path(f"{tmp_path}/fhir-query.html").exists() + # to see the visualization, cp to tmp + # shutil.copy(f"{tmp_path}/fhir-query.html", "/tmp/fhir-query.html")