Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
luisfpereira committed Jan 16, 2025
1 parent f4010fa commit 27962fc
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 85 deletions.
8 changes: 6 additions & 2 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from topobenchmark.transforms.liftings import (
CellCycleLifting,
Graph2CellLiftingTransform,
Graph2SimplicialLiftingTransform,
SimplicialCliqueLifting,
)

Expand Down Expand Up @@ -148,7 +150,9 @@ def sg1_clique_lifted(simple_graph_1):
torch_geometric.data.Data
A simple graph data object with a clique lifting.
"""
lifting_signed = SimplicialCliqueLifting(complex_dim=3, signed=True)
lifting_signed = Graph2SimplicialLiftingTransform(
SimplicialCliqueLifting(complex_dim=3), signed=True
)
data = lifting_signed(simple_graph_1)
data.batch_0 = "null"
return data
Expand All @@ -168,7 +172,7 @@ def sg1_cell_lifted(simple_graph_1):
torch_geometric.data.Data
A simple graph data object with a cell lifting.
"""
lifting = CellCycleLifting()
lifting = Graph2CellLiftingTransform(CellCycleLifting())
data = lifting(simple_graph_1)
data.batch_0 = "null"
return data
Expand Down
59 changes: 37 additions & 22 deletions test/nn/backbones/simplicial/test_sccnn.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,53 @@
"""Unit tests for SCCNN"""

import torch
from torch_geometric.utils import get_laplacian
from ...._utils.nn_module_auto_test import NNModuleAutoTest
from topobenchmark.nn.backbones.simplicial import SCCNNCustom
from topobenchmark.transforms.liftings.graph2simplicial import (
from topobenchmark.transforms.liftings import (
Graph2SimplicialLiftingTransform,
SimplicialCliqueLifting,
)

from ...._utils.nn_module_auto_test import NNModuleAutoTest


def test_SCCNNCustom(simple_graph_1):
lifting_signed = SimplicialCliqueLifting(
complex_dim=3, signed=True
)
lifting_signed = Graph2SimplicialLiftingTransform(
SimplicialCliqueLifting(complex_dim=3), signed=True
)
data = lifting_signed(simple_graph_1)
out_dim = 4
conv_order = 1
sc_order = 3
laplacian_all = (
data.hodge_laplacian_0,
data.down_laplacian_1,
data.up_laplacian_1,
data.down_laplacian_2,
data.up_laplacian_2,
)
data.hodge_laplacian_0,
data.down_laplacian_1,
data.up_laplacian_1,
data.down_laplacian_2,
data.up_laplacian_2,
)
incidence_all = (data.incidence_1, data.incidence_2)
expected_shapes = [(data.x.shape[0], out_dim), (data.x_1.shape[0], out_dim), (data.x_2.shape[0], out_dim)]
expected_shapes = [
(data.x.shape[0], out_dim),
(data.x_1.shape[0], out_dim),
(data.x_2.shape[0], out_dim),
]

auto_test = NNModuleAutoTest([
{
"module" : SCCNNCustom,
"init": ((data.x.shape[1], data.x_1.shape[1], data.x_2.shape[1]), (out_dim, out_dim, out_dim), conv_order, sc_order),
"forward": ((data.x, data.x_1, data.x_2), laplacian_all, incidence_all),
"assert_shape": expected_shapes
},
])
auto_test = NNModuleAutoTest(
[
{
"module": SCCNNCustom,
"init": (
(data.x.shape[1], data.x_1.shape[1], data.x_2.shape[1]),
(out_dim, out_dim, out_dim),
conv_order,
sc_order,
),
"forward": (
(data.x, data.x_1, data.x_2),
laplacian_all,
incidence_all,
),
"assert_shape": expected_shapes,
},
]
)
auto_test.run()
46 changes: 18 additions & 28 deletions test/nn/wrappers/cell/test_cell_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
"""Unit tests for cell model wrappers"""

import torch
from torch_geometric.utils import get_laplacian
from ...._utils.nn_module_auto_test import NNModuleAutoTest
from ...._utils.flow_mocker import FlowMocker
from unittest.mock import MagicMock
from topomodelx.nn.cell.ccxn import CCXN
from topomodelx.nn.cell.cwn import CWN

from topobenchmark.nn.backbones.cell.cccn import CCCN
from topobenchmark.nn.wrappers import (
AbstractWrapper,
CCCNWrapper,
CANWrapper,
CCXNWrapper,
CWNWrapper
CWNWrapper,
)
from topomodelx.nn.cell.can import CAN
from topomodelx.nn.cell.ccxn import CCXN
from topomodelx.nn.cell.cwn import CWN
from topobenchmark.nn.backbones.cell.cccn import CCCN
from unittest.mock import MagicMock


class TestCellWrappers:
Expand All @@ -27,11 +18,9 @@ def test_CCCNWrapper(self, sg1_clique_lifted):
num_cell_dimensions = 2

wrapper = CCCNWrapper(
CCCN(
data.x_1.shape[1]
),
out_channels=out_channels,
num_cell_dimensions=num_cell_dimensions
CCCN(data.x_1.shape[1]),
out_channels=out_channels,
num_cell_dimensions=num_cell_dimensions,
)
out = wrapper(data)

Expand All @@ -44,11 +33,9 @@ def test_CCXNWrapper(self, sg1_cell_lifted):
num_cell_dimensions = 2

wrapper = CCXNWrapper(
CCXN(
data.x_0.shape[1], data.x_1.shape[1], out_channels
),
out_channels=out_channels,
num_cell_dimensions=num_cell_dimensions
CCXN(data.x_0.shape[1], data.x_1.shape[1], out_channels),
out_channels=out_channels,
num_cell_dimensions=num_cell_dimensions,
)
out = wrapper(data)

Expand All @@ -63,13 +50,16 @@ def test_CWNWrapper(self, sg1_cell_lifted):

wrapper = CWNWrapper(
CWN(
data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1], hid_channels, 2
),
out_channels=out_channels,
num_cell_dimensions=num_cell_dimensions
data.x_0.shape[1],
data.x_1.shape[1],
data.x_2.shape[1],
hid_channels,
2,
),
out_channels=out_channels,
num_cell_dimensions=num_cell_dimensions,
)
out = wrapper(data)

for key in ["labels", "batch_0", "x_0", "x_1", "x_2"]:
assert key in out

62 changes: 32 additions & 30 deletions test/nn/wrappers/simplicial/test_SCCNNWrapper.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
"""Unit tests for simplicial model wrappers"""

import torch
from torch_geometric.utils import get_laplacian
from ...._utils.nn_module_auto_test import NNModuleAutoTest
from ...._utils.flow_mocker import FlowMocker
from topobenchmark.nn.backbones.simplicial import SCCNNCustom
from topomodelx.nn.simplicial.san import SAN
from topomodelx.nn.simplicial.scn2 import SCN2
from topomodelx.nn.simplicial.sccn import SCCN
from topomodelx.nn.simplicial.scn2 import SCN2

from topobenchmark.nn.backbones.simplicial import SCCNNCustom
from topobenchmark.nn.wrappers import (
SCCNWrapper,
SCCNNWrapper,
SANWrapper,
SCNWrapper
SCCNNWrapper,
SCCNWrapper,
SCNWrapper,
)


class TestSimplicialWrappers:
"""Test simplicial model wrappers."""

def test_SCCNNWrapper(self, sg1_clique_lifted):
"""Test SCCNNWrapper.
Parameters
----------
sg1_clique_lifted : torch_geometric.data.Data
Expand All @@ -30,12 +28,17 @@ def test_SCCNNWrapper(self, sg1_clique_lifted):
out_dim = 4
conv_order = 1
sc_order = 3
init_args = (data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), (out_dim, out_dim, out_dim), conv_order, sc_order
init_args = (
(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]),
(out_dim, out_dim, out_dim),
conv_order,
sc_order,
)

wrapper = SCCNNWrapper(
SCCNNCustom(*init_args),
out_channels=out_dim,
num_cell_dimensions=3
SCCNNCustom(*init_args),
out_channels=out_dim,
num_cell_dimensions=3,
)
out = wrapper(data)
# Assert keys in output
Expand All @@ -44,20 +47,20 @@ def test_SCCNNWrapper(self, sg1_clique_lifted):

def test_SANWarpper(self, sg1_clique_lifted):
"""Test SANWarpper.
Parameters
----------
sg1_clique_lifted : torch_geometric.data.Data
A fixture of simple graph 1 lifted with SimlicialCliqueLifting
A fixture of simple graph 1 lifted with SimlicialCliqueLifting
"""
data = sg1_clique_lifted
out_dim = data.x_0.shape[1]
hidden_channels = data.x_0.shape[1]

wrapper = SANWrapper(
SAN(data.x_0.shape[1], hidden_channels),
out_channels=out_dim,
num_cell_dimensions=3
SAN(data.x_0.shape[1], hidden_channels),
out_channels=out_dim,
num_cell_dimensions=3,
)
out = wrapper(data)
# Assert keys in output
Expand All @@ -66,19 +69,19 @@ def test_SANWarpper(self, sg1_clique_lifted):

def test_SCNWrapper(self, sg1_clique_lifted):
"""Test SCNWrapper.
Parameters
----------
sg1_clique_lifted : torch_geometric.data.Data
A fixture of simple graph 1 lifted with SimlicialCliqueLifting
A fixture of simple graph 1 lifted with SimlicialCliqueLifting
"""
data = sg1_clique_lifted
out_dim = data.x_0.shape[1]

wrapper = SCNWrapper(
SCN2(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]),
out_channels=out_dim,
num_cell_dimensions=3
SCN2(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]),
out_channels=out_dim,
num_cell_dimensions=3,
)
out = wrapper(data)
# Assert keys in output
Expand All @@ -87,23 +90,22 @@ def test_SCNWrapper(self, sg1_clique_lifted):

def test_SCCNWrapper(self, sg1_clique_lifted):
"""Test SCCNWrapper.
Parameters
----------
sg1_clique_lifted : torch_geometric.data.Data
A fixture of simple graph 1 lifted with SimlicialCliqueLifting
A fixture of simple graph 1 lifted with SimlicialCliqueLifting
"""
data = sg1_clique_lifted
out_dim = data.x_0.shape[1]
max_rank = 2

wrapper = SCCNWrapper(
SCCN(data.x_0.shape[1], max_rank),
out_channels=out_dim,
num_cell_dimensions=3
SCCN(data.x_0.shape[1], max_rank),
out_channels=out_dim,
num_cell_dimensions=3,
)
out = wrapper(data)
# Assert keys in output
for key in ["labels", "batch_0", "x_0", "x_1", "x_2"]:
assert key in out

7 changes: 4 additions & 3 deletions topobenchmark/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,17 @@ def _map_lifting_name(lifting_name):
return LiftingTransform


def _route_lifting_kwargs(kwargs, LiftingMap):
def _route_lifting_kwargs(kwargs, LiftingMap, Transform):
lifting_map_sign = inspect.signature(LiftingMap)
transform_sign = inspect.signature(Transform)

lifting_map_kwargs = {}
transform_kwargs = {}

for key, value in kwargs.items():
if key in lifting_map_sign.parameters:
lifting_map_kwargs[key] = value
else:
elif key in transform_sign.parameters:
transform_kwargs[key] = value

return lifting_map_kwargs, transform_kwargs
Expand All @@ -72,7 +73,7 @@ def __init__(self, transform_name, **kwargs):
LiftingMap_ = TRANSFORMS[transform_name]
Transform = _map_lifting_name(transform_name)
lifting_map_kwargs, transform_kwargs = _route_lifting_kwargs(
kwargs, LiftingMap_
kwargs, LiftingMap_, Transform
)

lifting_map = LiftingMap_(**lifting_map_kwargs)
Expand Down

0 comments on commit 27962fc

Please sign in to comment.