diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 00fec9e..7e08782 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -121,7 +121,7 @@ def test_force_clustering(): [0.1, 0.2, 0.4, 0.6, 1] ]) dataset.classes = {0: 0} - dataset.cluster_stratification = {"cluster1": np.array([0]), "cluster2": np.array([0]), "cluster3": np.array([0]), "cluster4": np.array([0]), "5": np.array([0])} + dataset.cluster_stratification = {"cluster1": np.array([0]), "cluster2": np.array([0]), "cluster3": np.array([0]), "cluster4": np.array([0]), "cluster5": np.array([0])} dataset.num_clusters = 3 # Call the force_clustering function