@@ -72,15 +72,15 @@ def test_additional_clustering():
72
72
s_dataset .cluster_similarity = similarity
73
73
s_dataset .cluster_distance = None
74
74
s_dataset .classes = {0 : 0 }
75
- s_dataset .stratification = None
75
+ s_dataset .cluster_stratification = { n : np . array ([ 0 ]) for n in names }
76
76
d_dataset = DataSet ()
77
77
d_dataset .cluster_names = names
78
78
d_dataset .cluster_map = base_map
79
79
d_dataset .cluster_weights = weights
80
80
d_dataset .cluster_similarity = None
81
81
d_dataset .cluster_distance = distance
82
82
d_dataset .classes = {0 : 0 }
83
- d_dataset .stratification = None
83
+ d_dataset .cluster_stratification = { n : np . array ([ 0 ]) for n in names }
84
84
85
85
s_dataset = additional_clustering (s_dataset , n_clusters = 5 , linkage = "average" )
86
86
assert len (s_dataset .cluster_names ) == 5
@@ -121,6 +121,7 @@ def test_force_clustering():
121
121
[0.1 , 0.2 , 0.4 , 0.6 , 1 ]
122
122
])
123
123
dataset .classes = {0 : 0 }
124
+ dataset .cluster_stratification = {"cluster1" : np .array ([0 ]), "cluster2" : np .array ([0 ]), "cluster3" : np .array ([0 ]), "cluster4" : np .array ([0 ]), "5" : np .array ([0 ])}
124
125
dataset .num_clusters = 3
125
126
126
127
# Call the force_clustering function
@@ -339,8 +340,10 @@ def test_clustering(algo):
339
340
weights = {k : 1 for k in seqs .keys ()},
340
341
location = base / "pdbbind_clean.fasta" ,
341
342
similarity = algo ,
342
- args = check_cdhit_arguments ("" ) if algo == CDHIT else check_mmseqs_arguments ("" ),
343
+ stratification = {k : 0 for k in seqs .keys ()},
344
+ class_oh = np .eye (1 ),
343
345
classes = {0 : 0 },
346
+ args = check_cdhit_arguments ("" ) if algo == CDHIT else check_mmseqs_arguments ("" ),
344
347
),
345
348
num_clusters = 50 ,
346
349
linkage = "average" ,
0 commit comments