Skip to content

Commit

Permalink
adds aggregation, visualization
Browse files Browse the repository at this point in the history
black formatting
  • Loading branch information
bwalsh committed Dec 18, 2024
1 parent 33db5f1 commit ddb0647
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 16 deletions.
57 changes: 51 additions & 6 deletions fhir_query/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import json
import logging
import sqlite3
import sys
from nested_lookup import nested_lookup
import tempfile
from collections import defaultdict
from typing import Generator, Any, Optional
from typing import Any, Optional, Callable

import httpx
from dotty_dict import dotty
Expand Down Expand Up @@ -47,15 +47,23 @@ def _initialize_table(self) -> None:
CREATE TABLE IF NOT EXISTS resources (
id VARCHAR NOT NULL,
resource_type VARCHAR NOT NULL,
key VARCHAR NOT NULL,
resource JSON NOT NULL,
PRIMARY KEY (id, resource_type)
)
"""
"""
)
self.connection.execute(
"""
CREATE INDEX IF NOT EXISTS idx_resources_key
ON resources (key)
"""
)

def add(self, resource: dict[str, Any]) -> None:
"""
Add a resource to the 'resources' table.
Add a resource to the 'resources' table.
:param resource: A dictionary with 'id', 'resourceType', and other fields.
"""
if "id" not in resource or "resourceType" not in resource:
Expand All @@ -65,10 +73,15 @@ def add(self, resource: dict[str, Any]) -> None:
with self.connection:
self.connection.execute(
"""
INSERT INTO resources (id, resource_type, resource)
VALUES (?, ?, ?)
INSERT INTO resources (id, resource_type, key, resource)
VALUES (?, ?, ?, ?)
""",
(resource["id"], resource["resourceType"], json.dumps(resource)),
(
resource["id"],
resource["resourceType"],
f'{resource["resourceType"]}/{resource["id"]}',
json.dumps(resource),
),
)
self.adds_counters[resource["resourceType"]] += 1
except sqlite3.IntegrityError as e:
Expand Down Expand Up @@ -132,6 +145,38 @@ def close(self) -> None:
"""
self.connection.close()

def aggregate(self) -> dict:
"""Aggregate metadata counts resourceType(count)-count->resourceType(count)."""

nested_dict: Callable[[], defaultdict[str, defaultdict]] = lambda: defaultdict(defaultdict)

count_resource_types = self.count_resource_types()

summary = nested_dict()

for resource_type in count_resource_types:
resources = self.all_resources(resource_type)
for _ in resources:

if "count" not in summary[resource_type]:
summary[resource_type]["count"] = 0
summary[resource_type]["count"] += 1

refs = nested_lookup("reference", _)
for ref in refs:
# A codeable reference is an object with a codeable concept and a reference
if isinstance(ref, dict):
ref = ref["reference"]
ref_resource_type = ref.split("/")[0]
if "references" not in summary[resource_type]:
summary[resource_type]["references"] = nested_dict()
dst = summary[resource_type]["references"][ref_resource_type]
if "count" not in dst:
dst["count"] = 0
dst["count"] += 1

return summary


class GraphDefinitionRunner(ResourceDB):
"""
Expand Down
48 changes: 44 additions & 4 deletions fhir_query/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,28 @@
import sys

import click
from click_default_group import DefaultGroup
import yaml
from fhir.resources.graphdefinition import GraphDefinition
from halo import Halo

from fhir_query import GraphDefinitionRunner, setup_logging
from fhir_query.visualizer import visualize_aggregation


@click.command()
@click.group(cls=DefaultGroup, default="main")
def cli():
"""Run FHIR GraphDefinition traversal."""
pass


@cli.command()
@click.option("--fhir-base-url", required=True, help="Base URL of the FHIR server.")
@click.option("--graph-definition-id", help="ID of the GraphDefinition.")
@click.option("--graph-definition-file-path", help="Path to the GraphDefinition JSON file.")
@click.option("--start-resource-type", required=True, help="ResourceType to start traversal.")
@click.option("--start-resource-id", required=True, help="ID of the starting resource.")
@click.option("--db_path", default="/tmp/fhir-graph.sqlite", help="path to of sqlite db")
@click.option("--db-path", default="/tmp/fhir-graph.sqlite", help="path to of sqlite db")
@click.option(
"--dry-run",
is_flag=True,
Expand All @@ -37,7 +45,7 @@ def main(
dry_run: bool,
log_file: str,
) -> None:
"""Main function to run the FHIR GraphDefinition traversal."""
"""Run FHIR GraphDefinition traversal."""

setup_logging(debug, log_file)

Expand Down Expand Up @@ -84,5 +92,37 @@ async def run_runner() -> None:
raise e


@cli.command(name="visualize")
@click.option("--db-path", default="/tmp/fhir-graph.sqlite", help="path to sqlite db")
@click.option("--output-path", default="/tmp/fhir-graph.html", help="path output html")
def visualize(db_path: str, output_path: str) -> None:
"""Visualize the aggregation results."""
from fhir_query import ResourceDB

try:
db = ResourceDB(db_path=db_path)
visualize_aggregation(db.aggregate(), output_path)
except Exception as e:
logging.error(f"Error: {e}", exc_info=True)
click.echo(f"Error: {e}", file=sys.stderr)
# raise e


@cli.command(name="summarize")
@click.option("--db-path", default="/tmp/fhir-graph.sqlite", help="path to sqlite db")
def summarize(db_path: str) -> None:
"""Summarize the aggregation results."""
from fhir_query import ResourceDB

try:
db = ResourceDB(db_path=db_path)
yaml.dump(json.loads(json.dumps(db.aggregate())), sys.stdout, default_flow_style=False)

except Exception as e:
logging.error(f"Error: {e}", exc_info=True)
click.echo(f"Error: {e}", file=sys.stderr)
# raise e


if __name__ == "__main__":
main()
cli()
45 changes: 45 additions & 0 deletions fhir_query/visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pyvis.network import Network

from fhir_query import ResourceDB


def _container():
"""Create a pyvis container."""
return Network(notebook=True, cdn_resources="in_line") # filter_menu=True, select_menu=True


def _load(net: Network, aggregation: dict) -> Network:
"""Load the aggregation into the visualization network."""
# add vertices
for resource_type, _ in aggregation.items():
assert "count" in _, _
net.add_node(resource_type, label=f"{resource_type}/{_['count']}")
# add edges
for resource_type, _ in aggregation.items():
for ref in _.get("references", {}):
count = _["references"][ref]["count"]
if resource_type not in net.get_nodes():
net.add_node(resource_type, label=f"{resource_type}/?")
if ref not in net.get_nodes():
net.add_node(ref, label=f"{ref}/?")
net.add_edge(resource_type, ref, title=count, value=count)
return net


def visualize_aggregation(aggregation: dict, output_path: str) -> None:
"""Visualize the aggregation."""
# Load it into a pyvis
net = _load(_container(), aggregation)
net.save_graph(str(output_path))
net.show_buttons(filter_=["physics"])


def create_network_graph(db_path: str, output_path: str) -> None:
"""Render metadata as a network graph into output_path.
\b
db_path: The directory path to the db.
output_path: The path to save the network graph.
"""
db = ResourceDB(db_path=db_path)
visualize_aggregation(db.aggregate(), output_path)
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ filterwarnings =
ignore::DeprecationWarning:halo.*
markers =
asyncio: asyncio mark
addopts = --cov=fhir_query --cov-report=term-missing
; addopts = --cov=fhir_query --cov-report=term-missing
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ httpx

fhir.resources==8.0.0b4
dotty-dict
nested_lookup
pyvis
click_default_group
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def parse_requirements(filename: str) -> list[str]:
install_requires=parse_requirements("requirements.txt"),
entry_points={
"console_scripts": [
"fq=fhir_query:cli.main",
"fq=fhir_query:cli.cli",
],
},
)
27 changes: 24 additions & 3 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,40 @@
from click.testing import CliRunner
from fhir_query.cli import main
from fhir_query.cli import cli


def test_default_option() -> None:
"""Test default option."""
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
print(result.output)
output = result.output
for _ in ["main*", "summarize", "visualize"]:
assert _ in output, f"Expected {_} in {output}"


def test_help_option() -> None:
"""Test help option."""
runner = CliRunner()
result = runner.invoke(main, "--help")
result = runner.invoke(cli, ["main", "--help"])
output = result.output
print(output)
assert "Usage:" in output
assert "--fhir-base-url" in output
assert "--graph-definition-id" in output
assert "--graph-definition-file-path" in output
assert "--start-resource-type" in output
assert "--start-resource-id" in output
assert "--db_path" in output
assert "--db-path" in output
assert "--dry-run" in output
assert "--debug" in output
assert "--log-file" in output


def test_visualize_help() -> None:
"""Test visualizer help."""
runner = CliRunner()
result = runner.invoke(cli, ["visualize", "--help"])
output = result.output
assert "Usage:" in output
assert "--db-path" in output
assert "--output-path" in output
14 changes: 13 additions & 1 deletion tests/unit/test_mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from click.testing import CliRunner

from fhir_query.cli import main
from fhir_query.visualizer import visualize_aggregation


@pytest.mark.usefixtures("mock_fhir_server")
Expand Down Expand Up @@ -36,7 +37,7 @@ def test_runner(tmp_path: str) -> None:
"ResearchStudy",
"--start-resource-id",
"123",
"--db_path",
"--db-path",
f"{tmp_path}/fhir-query.sqlite",
"--graph-definition-file-path",
"tests/fixtures/GraphDefinition.yaml",
Expand All @@ -59,3 +60,14 @@ def test_runner(tmp_path: str) -> None:

db = ResourceDB(f"{tmp_path}/fhir-query.sqlite")
assert db.count_resource_types() == {"Patient": 3, "Specimen": 3}

aggregated = db.aggregate()
assert sorted(aggregated.keys()) == ["Patient", "Specimen"]
assert aggregated["Patient"]["count"] == 3
assert aggregated["Specimen"]["count"] == 3
assert aggregated["Specimen"]["references"]["Patient"]["count"] == 3

visualize_aggregation(aggregated, f"{tmp_path}/fhir-query.html")
assert pathlib.Path(f"{tmp_path}/fhir-query.html").exists()
# to see the visualization, cp to tmp
# shutil.copy(f"{tmp_path}/fhir-query.html", "/tmp/fhir-query.html")

0 comments on commit ddb0647

Please sign in to comment.