Skip to content

Commit

Permalink
Coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Mar 8, 2025
1 parent 26ca466 commit 2d365fc
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def setup_method(self):

# Initialise the HypergraphKHopLifting class
self.lifting = MoGMSTLifting(min_components=3, random_state=0)
self.lifting2 = MoGMSTLifting(random_state=0)

def test_find_mog(self):
"""Test the find_mog method."""
Expand All @@ -49,6 +50,12 @@ def test_find_mog(self):
and labels[3] != labels[6]
and labels[0] != labels[6]
), "Labels have not been assigned correctly"

labels, num_components, means = self.lifting2.find_mog(
self.data.clone().x.numpy()
)

assert num_components == 4, "Wrong number of components"

def test_lift_topology(self):
"""Test the lift_topology method."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ class MoGMSTLifting(PointCloud2HypergraphLifting):
Parameters
----------
min_components : int
min_components : int or None, optional
The minimum number of components for the Mixture of Gaussians model.
max_components : int
It needs to be at least 1 (default: None).
max_components : int or None, optional
The maximum number of components for the Mixture of Gaussians model.
random_state : int
The random state for the Mixture of Gaussians model.
It needs to be greater or equal than min_components (default: None).
random_state : int, optional
The random state for the Mixture of Gaussians model (default: None).
**kwargs : optional
Additional arguments for the class.
"""
Expand All @@ -37,19 +39,6 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
if min_components is not None:
assert min_components > 0, (
"Minimum number of components should be at least 1"
)
if max_components is not None:
assert max_components > 0, (
"Maximum number of components should be at least 1"
)
if min_components is not None and max_components is not None:
assert min_components <= max_components, (
"Minimum number of components must be lower or equal to the"
" maximum number of components."
)
self.min_components = min_components
self.max_components = max_components
self.random_state = random_state
Expand Down Expand Up @@ -121,6 +110,9 @@ def find_mog(self, data) -> tuple[np.ndarray, int, np.ndarray]:
tuple[np.ndarray, int, np.ndarray]
The labels of the data, the number of components and the means of the components.
"""
possible_num_components = [
self.min_components if self.min_components is not None else 1
]
if self.min_components is not None and self.max_components is not None:
possible_num_components = range(
self.min_components, self.max_components + 1
Expand All @@ -129,20 +121,6 @@ def find_mog(self, data) -> tuple[np.ndarray, int, np.ndarray]:
possible_num_components = [
2**i for i in range(1, int(np.log2(data.shape[0] / 2)) + 1)
]
else:
if self.min_components is not None:
num_components = self.min_components
elif self.max_components is not None:
num_components = self.max_components
else:
# Cannot happen
num_components = 1

gm = GaussianMixture(
n_components=num_components, random_state=self.random_state
)
labels = gm.fit_predict(data)
return labels, num_components, gm.means_

best_score = float("inf")
best_labels = None
Expand Down

0 comments on commit 2d365fc

Please sign in to comment.