From 004f7708a854d36dcf5d0efcefd0b6d3b1b4da6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Fri, 7 Mar 2025 19:42:10 -0800 Subject: [PATCH] Coverage --- ..._discrete_configuration_complex_lifting.py | 25 +++++++++++++++++-- .../discrete_configuration_complex_lifting.py | 10 +++----- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/test/transforms/liftings/graph2cell/test_discrete_configuration_complex_lifting.py b/test/transforms/liftings/graph2cell/test_discrete_configuration_complex_lifting.py index e9ab74ee..b79a2aa5 100644 --- a/test/transforms/liftings/graph2cell/test_discrete_configuration_complex_lifting.py +++ b/test/transforms/liftings/graph2cell/test_discrete_configuration_complex_lifting.py @@ -55,8 +55,29 @@ def test_discrete_configuration_complex_lifting_lift_topology(mock_data): assert "adjacency_0" in lifted_topology assert "adjacency_1" in lifted_topology -if __name__ == "__main__": - pytest.main() +class TestDiscreteConfigurationComplexLifting: + """Test the DiscreteConfigurationComplexLifting class.""" + + def setup_method(self): + """Initialise the DiscreteConfigurationComplexLifting class.""" + self.lifting_concat = DiscreteConfigurationComplexLifting(k=2, complex_dim=2, feature_aggregation="concat") + self.lifting_sum = DiscreteConfigurationComplexLifting(k=2, complex_dim=2, feature_aggregation="sum") + self.lifting_mean = DiscreteConfigurationComplexLifting(k=2, complex_dim=2, feature_aggregation="mean") + + def test_lift_topology(self, simple_graph_1): + """Test the lift_topology method. + + Parameters + ---------- + simple_graph_1 : Data + A simple graph used for testing. + """ + data = simple_graph_1 + + assert self.lifting_concat.forward(data.clone()).incidence_1.shape[1] == 156, "Something is wrong with incidence_1." + assert self.lifting_sum.forward(data.clone()).incidence_1.shape[1] == 156, "Something is wrong with incidence_1." + assert self.lifting_mean.forward(data.clone()).incidence_1.shape[1] == 156, "Something is wrong with incidence_1." + # import torch diff --git a/topobench/transforms/liftings/graph2cell/discrete_configuration_complex_lifting.py b/topobench/transforms/liftings/graph2cell/discrete_configuration_complex_lifting.py index de9881cf..f53b5aa3 100644 --- a/topobench/transforms/liftings/graph2cell/discrete_configuration_complex_lifting.py +++ b/topobench/transforms/liftings/graph2cell/discrete_configuration_complex_lifting.py @@ -49,10 +49,9 @@ def __init__( ): self.k = k self.complex_dim = kwargs["complex_dim"] - if feature_aggregation not in ["mean", "sum", "concat"]: - raise ValueError( - "feature_aggregation must be one of 'mean', 'sum', 'concat'" - ) + assert feature_aggregation in ["mean", "sum", "concat"], ( + "Feature_aggregation must be one of 'mean', 'sum', 'concat'" + ) self.feature_aggregation = feature_aggregation super().__init__(preserve_edge_attr=preserve_edge_attr, **kwargs) @@ -91,8 +90,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: The lifted topology. """ G = self._generate_graph_from_data(data) - if G.is_directed(): - raise ValueError("Directed Graphs are not supported.") + assert not G.is_directed(), "Directed Graphs are not supported." Configuration = generate_configuration_class( G, self.feature_aggregation, self.contains_edge_attr