Skip to content

Commit

Permalink
Coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Mar 8, 2025
1 parent 649b8ea commit 004f770
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,29 @@ def test_discrete_configuration_complex_lifting_lift_topology(mock_data):
assert "adjacency_0" in lifted_topology
assert "adjacency_1" in lifted_topology

if __name__ == "__main__":
pytest.main()
class TestDiscreteConfigurationComplexLifting:
"""Test the DiscreteConfigurationComplexLifting class."""

def setup_method(self):
"""Initialise the DiscreteConfigurationComplexLifting class."""
self.lifting_concat = DiscreteConfigurationComplexLifting(k=2, complex_dim=2, feature_aggregation="concat")
self.lifting_sum = DiscreteConfigurationComplexLifting(k=2, complex_dim=2, feature_aggregation="sum")
self.lifting_mean = DiscreteConfigurationComplexLifting(k=2, complex_dim=2, feature_aggregation="mean")

def test_lift_topology(self, simple_graph_1):
"""Test the lift_topology method.
Parameters
----------
simple_graph_1 : Data
A simple graph used for testing.
"""
data = simple_graph_1

assert self.lifting_concat.forward(data.clone()).incidence_1.shape[1] == 156, "Something is wrong with incidence_1."
assert self.lifting_sum.forward(data.clone()).incidence_1.shape[1] == 156, "Something is wrong with incidence_1."
assert self.lifting_mean.forward(data.clone()).incidence_1.shape[1] == 156, "Something is wrong with incidence_1."



# import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ def __init__(
):
self.k = k
self.complex_dim = kwargs["complex_dim"]
if feature_aggregation not in ["mean", "sum", "concat"]:
raise ValueError(
"feature_aggregation must be one of 'mean', 'sum', 'concat'"
)
assert feature_aggregation in ["mean", "sum", "concat"], (
"Feature_aggregation must be one of 'mean', 'sum', 'concat'"
)
self.feature_aggregation = feature_aggregation
super().__init__(preserve_edge_attr=preserve_edge_attr, **kwargs)

Expand Down Expand Up @@ -91,8 +90,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
The lifted topology.
"""
G = self._generate_graph_from_data(data)
if G.is_directed():
raise ValueError("Directed Graphs are not supported.")
assert not G.is_directed(), "Directed Graphs are not supported."

Configuration = generate_configuration_class(
G, self.feature_aggregation, self.contains_edge_attr
Expand Down

0 comments on commit 004f770

Please sign in to comment.