From 2d365fcd7e6dc65933d60a96f0908ca160b27cbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Fri, 7 Mar 2025 19:18:38 -0800 Subject: [PATCH] Coverage --- .../test_mogmst_lifting.py | 7 ++++ .../pointcloud2hypergraph/mogmst_lifting.py | 40 +++++-------------- 2 files changed, 16 insertions(+), 31 deletions(-) diff --git a/test/transforms/liftings/pointcloud2hypergraph/test_mogmst_lifting.py b/test/transforms/liftings/pointcloud2hypergraph/test_mogmst_lifting.py index c380d0d2..bb748f48 100644 --- a/test/transforms/liftings/pointcloud2hypergraph/test_mogmst_lifting.py +++ b/test/transforms/liftings/pointcloud2hypergraph/test_mogmst_lifting.py @@ -32,6 +32,7 @@ def setup_method(self): # Initialise the HypergraphKHopLifting class self.lifting = MoGMSTLifting(min_components=3, random_state=0) + self.lifting2 = MoGMSTLifting(random_state=0) def test_find_mog(self): """Test the find_mog method.""" @@ -49,6 +50,12 @@ def test_find_mog(self): and labels[3] != labels[6] and labels[0] != labels[6] ), "Labels have not been assigned correctly" + + labels, num_components, means = self.lifting2.find_mog( + self.data.clone().x.numpy() + ) + + assert num_components == 4, "Wrong number of components" def test_lift_topology(self): """Test the lift_topology method.""" diff --git a/topobench/transforms/liftings/pointcloud2hypergraph/mogmst_lifting.py b/topobench/transforms/liftings/pointcloud2hypergraph/mogmst_lifting.py index b47ae789..eae32305 100644 --- a/topobench/transforms/liftings/pointcloud2hypergraph/mogmst_lifting.py +++ b/topobench/transforms/liftings/pointcloud2hypergraph/mogmst_lifting.py @@ -19,12 +19,14 @@ class MoGMSTLifting(PointCloud2HypergraphLifting): Parameters ---------- - min_components : int + min_components : int or None, optional The minimum number of components for the Mixture of Gaussians model. - max_components : int + It needs to be at least 1 (default: None). + max_components : int or None, optional The maximum number of components for the Mixture of Gaussians model. - random_state : int - The random state for the Mixture of Gaussians model. + It needs to be greater or equal than min_components (default: None). + random_state : int, optional + The random state for the Mixture of Gaussians model (default: None). **kwargs : optional Additional arguments for the class. """ @@ -37,19 +39,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - if min_components is not None: - assert min_components > 0, ( - "Minimum number of components should be at least 1" - ) - if max_components is not None: - assert max_components > 0, ( - "Maximum number of components should be at least 1" - ) - if min_components is not None and max_components is not None: - assert min_components <= max_components, ( - "Minimum number of components must be lower or equal to the" - " maximum number of components." - ) self.min_components = min_components self.max_components = max_components self.random_state = random_state @@ -121,6 +110,9 @@ def find_mog(self, data) -> tuple[np.ndarray, int, np.ndarray]: tuple[np.ndarray, int, np.ndarray] The labels of the data, the number of components and the means of the components. """ + possible_num_components = [ + self.min_components if self.min_components is not None else 1 + ] if self.min_components is not None and self.max_components is not None: possible_num_components = range( self.min_components, self.max_components + 1 @@ -129,20 +121,6 @@ def find_mog(self, data) -> tuple[np.ndarray, int, np.ndarray]: possible_num_components = [ 2**i for i in range(1, int(np.log2(data.shape[0] / 2)) + 1) ] - else: - if self.min_components is not None: - num_components = self.min_components - elif self.max_components is not None: - num_components = self.max_components - else: - # Cannot happen - num_components = 1 - - gm = GaussianMixture( - n_components=num_components, random_state=self.random_state - ) - labels = gm.fit_predict(data) - return labels, num_components, gm.means_ best_score = float("inf") best_labels = None