Skip to content

Commit 3bd0d98

Browse files
Bugs in tests introduced by stratification fixes cleaned
1 parent f6a1345 commit 3bd0d98

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/test_clustering.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ def test_additional_clustering():
7272
s_dataset.cluster_similarity = similarity
7373
s_dataset.cluster_distance = None
7474
s_dataset.classes = {0: 0}
75-
s_dataset.stratification = None
75+
s_dataset.cluster_stratification = {n: np.array([0]) for n in names}
7676
d_dataset = DataSet()
7777
d_dataset.cluster_names = names
7878
d_dataset.cluster_map = base_map
7979
d_dataset.cluster_weights = weights
8080
d_dataset.cluster_similarity = None
8181
d_dataset.cluster_distance = distance
8282
d_dataset.classes = {0: 0}
83-
d_dataset.stratification = None
83+
d_dataset.cluster_stratification = {n: np.array([0]) for n in names}
8484

8585
s_dataset = additional_clustering(s_dataset, n_clusters=5, linkage="average")
8686
assert len(s_dataset.cluster_names) == 5
@@ -121,6 +121,7 @@ def test_force_clustering():
121121
[0.1, 0.2, 0.4, 0.6, 1]
122122
])
123123
dataset.classes = {0: 0}
124+
dataset.cluster_stratification = {"cluster1": np.array([0]), "cluster2": np.array([0]), "cluster3": np.array([0]), "cluster4": np.array([0]), "5": np.array([0])}
124125
dataset.num_clusters = 3
125126

126127
# Call the force_clustering function
@@ -339,8 +340,10 @@ def test_clustering(algo):
339340
weights={k: 1 for k in seqs.keys()},
340341
location=base / "pdbbind_clean.fasta",
341342
similarity=algo,
342-
args=check_cdhit_arguments("") if algo == CDHIT else check_mmseqs_arguments(""),
343+
stratification={k: 0 for k in seqs.keys()},
344+
class_oh=np.eye(1),
343345
classes={0: 0},
346+
args=check_cdhit_arguments("") if algo == CDHIT else check_mmseqs_arguments(""),
344347
),
345348
num_clusters=50,
346349
linkage="average",

0 commit comments

Comments
 (0)