Skip to content

Commit

Permalink
Fix tutorial_lifting
Browse files Browse the repository at this point in the history
  • Loading branch information
luisfpereira committed Jan 17, 2025
1 parent 27962fc commit 1f5ad56
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 72 deletions.
24 changes: 23 additions & 1 deletion topobenchmark/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,32 @@

from .data_manipulations import DATA_MANIPULATIONS
from .feature_liftings import FEATURE_LIFTINGS
from .liftings import LIFTINGS
from .liftings import (
GRAPH2CELL_LIFTINGS,
GRAPH2HYPERGRAPH_LIFTINGS,
GRAPH2SIMPLICIAL_LIFTINGS,
LIFTINGS,
)

TRANSFORMS = {
**LIFTINGS,
**FEATURE_LIFTINGS,
**DATA_MANIPULATIONS,
}


_map_lifting_type_to_dict = {
"graph2cell": GRAPH2CELL_LIFTINGS,
"graph2hypergraph": GRAPH2HYPERGRAPH_LIFTINGS,
"graph2simplicial": GRAPH2SIMPLICIAL_LIFTINGS,
}


def add_lifting_map(LiftingMap, lifting_type, name=None):
if name is None:
name = LiftingMap.__name__

liftings_dict = _map_lifting_type_to_dict[lifting_type]

for dict_ in (liftings_dict, LIFTINGS, TRANSFORMS):
dict_[name] = LiftingMap
31 changes: 13 additions & 18 deletions topobenchmark/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,29 @@

import torch_geometric

from topobenchmark.transforms import LIFTINGS, TRANSFORMS
from topobenchmark.transforms import (
LIFTINGS,
TRANSFORMS,
_map_lifting_type_to_dict,
)
from topobenchmark.transforms.liftings import (
GRAPH2CELL_LIFTINGS,
GRAPH2HYPERGRAPH_LIFTINGS,
GRAPH2SIMPLICIAL_LIFTINGS,
Graph2CellLiftingTransform,
Graph2HypergraphLiftingTransform,
Graph2SimplicialLiftingTransform,
LiftingTransform,
)

_map_lifting_types = {
"graph2cell": (GRAPH2CELL_LIFTINGS, Graph2CellLiftingTransform),
"graph2hypergraph": (
GRAPH2HYPERGRAPH_LIFTINGS,
Graph2HypergraphLiftingTransform,
),
"graph2simplicial": (
GRAPH2SIMPLICIAL_LIFTINGS,
Graph2SimplicialLiftingTransform,
),
_map_lifting_type_to_transform = {
"graph2cell": Graph2CellLiftingTransform,
"graph2hypergraph": Graph2HypergraphLiftingTransform,
"graph2simplicial": Graph2SimplicialLiftingTransform,
}


def _map_lifting_name(lifting_name):
for liftings_dict, Transform in _map_lifting_types.values():
def _map_lifting_to_transform(lifting_name):
for key, liftings_dict in _map_lifting_type_to_dict.items():
if lifting_name in liftings_dict:
return Transform
return _map_lifting_type_to_transform[key]

return LiftingTransform

Expand Down Expand Up @@ -71,7 +66,7 @@ def __init__(self, transform_name, **kwargs):
transform = TRANSFORMS[transform_name](**kwargs)
else:
LiftingMap_ = TRANSFORMS[transform_name]
Transform = _map_lifting_name(transform_name)
Transform = _map_lifting_to_transform(transform_name)
lifting_map_kwargs, transform_kwargs = _route_lifting_kwargs(
kwargs, LiftingMap_, Transform
)
Expand Down
1 change: 1 addition & 0 deletions topobenchmark/transforms/liftings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Graph2ComplexLiftingTransform,
Graph2HypergraphLiftingTransform,
Graph2SimplicialLiftingTransform,
LiftingMap,
LiftingTransform,
)
from .graph2cell import GRAPH2CELL_LIFTINGS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def lift(self, domain):
for set_k_simplices in simplices:
simplicial_complex.add_simplices_from(list(set_k_simplices))

# because Complex pads unexisting dimensions with empty matrices
# because ComplexData pads unexisting dimensions with empty matrices
simplicial_complex.practical_dim = self.complex_dim

return simplicial_complex
124 changes: 72 additions & 52 deletions tutorials/tutorial_lifting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
"\n",
"import lightning as pl\n",
"import networkx as nx\n",
"import hydra\n",
"import torch_geometric\n",
"from omegaconf import OmegaConf\n",
"from topomodelx.nn.simplicial.scn2 import SCN2\n",
"from toponetx.classes import SimplicialComplex\n",
Expand All @@ -72,8 +70,8 @@
"from topobenchmark.nn.readouts import PropagateSignalDown\n",
"from topobenchmark.nn.wrappers.simplicial import SCNWrapper\n",
"from topobenchmark.optimizer import TBOptimizer\n",
"from topobenchmark.transforms.liftings.graph2simplicial import (\n",
" Graph2SimplicialLifting,\n",
"from topobenchmark.transforms.liftings import (\n",
" LiftingMap,\n",
")"
]
},
Expand Down Expand Up @@ -101,14 +99,17 @@
" \"data_domain\": \"graph\",\n",
" \"data_type\": \"TUDataset\",\n",
" \"data_name\": \"MUTAG\",\n",
" \"data_dir\": \"./data/MUTAG/\"}\n",
" \"data_dir\": \"./data/MUTAG/\",\n",
"}\n",
"\n",
"\n",
"transform_config = { \"clique_lifting\":\n",
" {\"_target_\": \"__main__.SimplicialCliquesLEQLifting\",\n",
" \"transform_name\": \"SimplicialCliquesLEQLifting\",\n",
" \"transform_type\": \"lifting\",\n",
" \"complex_dim\": 3,}\n",
"transform_config = {\n",
" \"clique_lifting\": {\n",
" \"_target_\": \"topobenchmark.transforms.data_transform.DataTransform\",\n",
" \"transform_name\": \"SimplicialCliquesLEQLifting\",\n",
" \"transform_type\": \"lifting\",\n",
" \"complex_dim\": 3,\n",
" }\n",
"}\n",
"\n",
"split_config = {\n",
Expand Down Expand Up @@ -138,21 +139,19 @@
"}\n",
"\n",
"loss_config = {\n",
" \"dataset_loss\": \n",
" {\n",
" \"task\": \"classification\", \n",
" \"loss_type\": \"cross_entropy\"\n",
" }\n",
" \"dataset_loss\": {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n",
"}\n",
"\n",
"evaluator_config = {\"task\": \"classification\",\n",
" \"num_classes\": out_channels,\n",
" \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n",
"evaluator_config = {\n",
" \"task\": \"classification\",\n",
" \"num_classes\": out_channels,\n",
" \"metrics\": [\"accuracy\", \"precision\", \"recall\"],\n",
"}\n",
"\n",
"optimizer_config = {\"optimizer_id\": \"Adam\",\n",
" \"parameters\":\n",
" {\"lr\": 0.001,\"weight_decay\": 0.0005}\n",
" }\n",
"optimizer_config = {\n",
" \"optimizer_id\": \"Adam\",\n",
" \"parameters\": {\"lr\": 0.001, \"weight_decay\": 0.0005},\n",
"}\n",
"\n",
"\n",
"loader_config = OmegaConf.create(loader_config)\n",
Expand All @@ -174,6 +173,7 @@
"def wrapper(**factory_kwargs):\n",
" def factory(backbone):\n",
" return SCNWrapper(backbone, **factory_kwargs)\n",
"\n",
" return factory"
]
},
Expand All @@ -197,28 +197,30 @@
"metadata": {},
"outputs": [],
"source": [
"class SimplicialCliquesLEQLifting(Graph2SimplicialLifting):\n",
"class SimplicialCliquesLEQLifting(LiftingMap):\n",
" r\"\"\"Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n",
" \n",
" Args:\n",
" kwargs (optional): Additional arguments for the class.\n",
" \"\"\"\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
" def __init__(self, complex_dim=2):\n",
" super().__init__()\n",
" self.complex_dim = complex_dim\n",
"\n",
"\n",
" def lift_topology(self, data: torch_geometric.data.Data) -> dict:\n",
" def lift(self, domain) -> dict:\n",
" r\"\"\"Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n",
"\n",
" Args:\n",
" data (torch_geometric.data.Data): The input data to be lifted.\n",
" Returns:\n",
" dict: The lifted topology.\n",
" \"\"\"\n",
" graph = self._generate_graph_from_data(data)\n",
" graph = domain\n",
"\n",
" simplicial_complex = SimplicialComplex(graph)\n",
" cliques = nx.find_cliques(graph)\n",
" \n",
" simplices: list[set[tuple[Any, ...]]] = [set() for _ in range(2, self.complex_dim + 1)]\n",
"\n",
" simplices: list[set[tuple[Any, ...]]] = [\n",
" set() for _ in range(2, self.complex_dim + 1)\n",
" ]\n",
" for clique in cliques:\n",
" if len(clique) <= self.complex_dim + 1:\n",
" for i in range(2, self.complex_dim + 1):\n",
Expand All @@ -227,8 +229,11 @@
"\n",
" for set_k_simplices in simplices:\n",
" simplicial_complex.add_simplices_from(list(set_k_simplices))\n",
" \n",
" # because ComplexData pads unexisting dimensions with empty matrices\n",
" simplicial_complex.practical_dim = self.complex_dim\n",
"\n",
" return self._get_lifted_topology(simplicial_complex, graph)\n"
" return simplicial_complex"
]
},
{
Expand All @@ -251,9 +256,9 @@
"metadata": {},
"outputs": [],
"source": [
"from topobenchmark.transforms import TRANSFORMS\n",
"from topobenchmark.transforms import add_lifting_map\n",
"\n",
"TRANSFORMS[\"SimplicialCliquesLEQLifting\"] = SimplicialCliquesLEQLifting"
"add_lifting_map(SimplicialCliquesLEQLifting, \"graph2simplicial\")"
]
},
{
Expand All @@ -275,8 +280,12 @@
"dataset, dataset_dir = graph_loader.load()\n",
"\n",
"preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n",
"dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n",
"datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)"
"dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(\n",
" split_config\n",
")\n",
"datamodule = TBDataloader(\n",
" dataset_train, dataset_val, dataset_test, batch_size=32\n",
")"
]
},
{
Expand All @@ -299,12 +308,19 @@
"metadata": {},
"outputs": [],
"source": [
"backbone = SCN2(in_channels_0=dim_hidden,in_channels_1=dim_hidden,in_channels_2=dim_hidden)\n",
"backbone = SCN2(\n",
" in_channels_0=dim_hidden,\n",
" in_channels_1=dim_hidden,\n",
" in_channels_2=dim_hidden,\n",
")\n",
"backbone_wrapper = wrapper(**wrapper_config)\n",
"\n",
"readout = PropagateSignalDown(**readout_config)\n",
"loss = TBLoss(**loss_config)\n",
"feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n",
"feature_encoder = AllCellFeatureEncoder(\n",
" in_channels=[in_channels, in_channels, in_channels],\n",
" out_channels=dim_hidden,\n",
")\n",
"\n",
"evaluator = TBEvaluator(**evaluator_config)\n",
"optimizer = TBOptimizer(**optimizer_config)"
Expand All @@ -316,14 +332,16 @@
"metadata": {},
"outputs": [],
"source": [
"model = TBModel(backbone=backbone,\n",
" backbone_wrapper=backbone_wrapper,\n",
" readout=readout,\n",
" loss=loss,\n",
" feature_encoder=feature_encoder,\n",
" evaluator=evaluator,\n",
" optimizer=optimizer,\n",
" compile=False,)"
"model = TBModel(\n",
" backbone=backbone,\n",
" backbone_wrapper=backbone_wrapper,\n",
" readout=readout,\n",
" loss=loss,\n",
" feature_encoder=feature_encoder,\n",
" evaluator=evaluator,\n",
" optimizer=optimizer,\n",
" compile=False,\n",
")"
]
},
{
Expand Down Expand Up @@ -386,7 +404,9 @@
],
"source": [
"# Increase the number of epochs to get better results\n",
"trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n",
"trainer = pl.Trainer(\n",
" max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False\n",
")\n",
"\n",
"trainer.fit(model, datamodule)\n",
"train_metrics = trainer.callback_metrics"
Expand Down Expand Up @@ -415,9 +435,9 @@
}
],
"source": [
"print(' Training metrics\\n', '-'*26)\n",
"print(\" Training metrics\\n\", \"-\" * 26)\n",
"for key in train_metrics:\n",
" print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
" print(\"{:<21s} {:>5.4f}\".format(key + \":\", train_metrics[key].item()))"
]
},
{
Expand Down Expand Up @@ -505,9 +525,9 @@
}
],
"source": [
"print(' Testing metrics\\n', '-'*25)\n",
"print(\" Testing metrics\\n\", \"-\" * 25)\n",
"for key in test_metrics:\n",
" print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
" print(\"{:<20s} {:>5.4f}\".format(key + \":\", test_metrics[key].item()))"
]
},
{
Expand Down

0 comments on commit 1f5ad56

Please sign in to comment.