Commit de06e56 (1 parent: 42a764b)
File tree: 1 file changed, +15 -2 lines changed
Columns: original file line number | diff line number | diff line change
26
26
except ImportError :
27
27
_HAVE_CUMAP = False
28
28
29
+ try :
30
+ from cuml .cluster import HDBSCAN as cuHDBSCAN
31
+
32
+ _HAVE_CUHDBSCAN = True
33
+ except ImportError :
34
+ _HAVE_CUHDBSCAN = False
35
+
29
36
try :
30
37
import hnswlib
31
38
@@ -1369,13 +1376,19 @@ def compute_topics(self,
1369
1376
'metric' : 'euclidean' ,
1370
1377
'cluster_selection_method' : 'eom' }
1371
1378
1372
- cluster = hdbscan .HDBSCAN (** hdbscan_args ).fit (umap_embedding )
1379
+ if gpu_hdbscan and _HAVE_CUHDBSCAN :
1380
+ cluster = cuHDBSCAN (** hdbscan_args )
1381
+ labels = cluster .fit_predict (umap_embedding )
1382
+
1383
+ else :
1384
+ cluster = hdbscan .HDBSCAN (** hdbscan_args ).fit (umap_embedding )
1385
+ labels = cluster .labels_
1373
1386
1374
1387
# calculate topic vectors from dense areas of documents
1375
1388
logger .info ('Finding topics' )
1376
1389
1377
1390
# create topic vectors
1378
- self ._create_topic_vectors (cluster . labels_ )
1391
+ self ._create_topic_vectors (labels )
1379
1392
1380
1393
# deduplicate topics
1381
1394
self ._deduplicate_topics (topic_merge_delta )
You can’t perform that action at this time.
0 commit comments