Skip to content

Commit 4df43a2

Browse files
committed
query_documents and query_topics fix
1 parent 5ce24bb commit 4df43a2

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

top2vec/Top2Vec.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,12 @@ def _embed_documents(self, train_corpus):
543543

544544
return document_vectors
545545

546+
def _embed_query(self, query):
547+
self._check_import_status()
548+
self._check_model_status()
549+
550+
return self._l2_normalize(np.array(self.embed(query)[0]))
551+
546552
def _set_document_vectors(self, document_vectors):
547553
if self.embedding_model == 'doc2vec':
548554
self.model.docvecs.vectors_docs = document_vectors
@@ -1663,7 +1669,7 @@ def query_documents(self, query, num_docs, return_documents=True, use_index=Fals
16631669
self._validate_num_docs(num_docs)
16641670

16651671
if self.embedding_model != "doc2vec":
1666-
query_vec = self._embed_documents(query)[0]
1672+
query_vec = self._embed_query(query)
16671673

16681674
else:
16691675

@@ -1740,7 +1746,7 @@ def query_topics(self, query, num_topics, reduced=False, tokenizer=None):
17401746
self._validate_query(query)
17411747

17421748
if self.embedding_model != "doc2vec":
1743-
query_vec = self._embed_documents(query)[0]
1749+
query_vec = self._embed_query(query)
17441750

17451751
else:
17461752

0 commit comments

Comments
 (0)