Fix issues with biencoder
krypticmouse committed Jul 5, 2024
1 parent fa335db commit cabfbf9
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions pirate/retrievers/bi_encoder.py
@@ -1,3 +1,5 @@
+import torch
+
 from loguru import logger
 from typing import Optional
 from sentence_transformers import SentenceTransformer
@@ -18,6 +20,7 @@ def __init__(
     ):
         super().__init__(model_name, *args, **kwargs)
 
+        self.model_name = model_name
         self.indexed_corpus = None
         self.corpus = None
         self.list_of_passages = None
@@ -40,28 +43,36 @@ def index(self, corpus: Passages, *args, **kwargs):
     def rank_passages(self, queries: Queries, top_k: Optional[int] = None, *args, **kwargs) -> Ranking:
         if self.indexed_corpus is None or self.corpus is None:
             raise ValueError("Index not built. Please call the index method first.")
 
         ranking_list = []
-        for i, query_id in enumerate(queries):
+        for i, query_id in tqdm(enumerate(queries), total=len(queries)):
             query_embedding = self.encode(queries[query_id], *args, **kwargs)
 
-            scores = self.similarity(query_embedding, self.indexed_corpus)
+            # Ensure query_embedding is a PyTorch tensor
+            if not isinstance(query_embedding, torch.Tensor):
+                query_embedding = torch.tensor(query_embedding)
+
+            # Ensure indexed_corpus is a PyTorch tensor
+            if not isinstance(self.indexed_corpus, torch.Tensor):
+                self.indexed_corpus = torch.tensor(self.indexed_corpus)
+
+            # Reshape query_embedding to match the dimensions
+            query_embedding = query_embedding.reshape(1, -1)
+
+            # Calculate cosine similarity
+            scores = torch.nn.functional.cosine_similarity(query_embedding, self.indexed_corpus)
 
             score_array = scores.cpu().numpy().flatten()
 
             if top_k is not None:
                 top_k_indices = score_array.argsort()[-top_k:][::-1]
                 top_k_scores = score_array[top_k_indices]
                 top_k_doc_ids = [self.index_id_lookup[i] for i in top_k_indices]
 
                 for j, doc_id in enumerate(top_k_doc_ids):
-                    ranking_list.append([query_id, doc_id, top_k_scores[j], j])
+                    ranking_list.append([query_id, doc_id, j, top_k_scores[j]])
             else:
                 sorted_indices = score_array.argsort()[::-1]
 
                 for j, doc_id in enumerate(sorted_indices):
-                    ranking_list.append([query_id, self.index_id_lookup[doc_id], score_array[doc_id], j])
+                    ranking_list.append([query_id, self.index_id_lookup[doc_id], j, score_array[doc_id]])
 
         ranking = Ranking(ranking_list)
 
         return ranking
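
For context, the scoring change at the heart of this commit can be reproduced in isolation. The sketch below is not part of the commit: it mimics the diff's tensor coercion, reshaping, and cosine-similarity scoring on toy data, and prints ranking rows in the new [query_id, doc_id, rank, score] order. The index_id_lookup mapping and the embeddings here are illustrative assumptions, not values from the repository.

import torch

# Toy corpus: four passage embeddings of dimension 3, keyed by row position.
index_id_lookup = {0: "doc_a", 1: "doc_b", 2: "doc_c", 3: "doc_d"}
indexed_corpus = [[0.1, 0.9, 0.0],
                  [0.8, 0.1, 0.1],
                  [0.0, 0.2, 0.9],
                  [0.5, 0.5, 0.0]]
query_embedding = [0.1, 0.8, 0.1]  # stand-in for the output of self.encode(...)

# Coerce to tensors, as the fix does: the encoder may return lists or numpy arrays.
if not isinstance(query_embedding, torch.Tensor):
    query_embedding = torch.tensor(query_embedding)
if not isinstance(indexed_corpus, torch.Tensor):
    indexed_corpus = torch.tensor(indexed_corpus)

# Reshape to (1, dim) so cosine_similarity broadcasts over the corpus rows.
query_embedding = query_embedding.reshape(1, -1)
scores = torch.nn.functional.cosine_similarity(query_embedding, indexed_corpus)

# Rows now carry [query_id, doc_id, rank, score] (rank before score),
# matching the reordered append calls in the diff.
score_array = scores.cpu().numpy().flatten()
top_k = 2
top_k_indices = score_array.argsort()[-top_k:][::-1]
for rank, idx in enumerate(top_k_indices):
    print(["q1", index_id_lookup[idx], rank, float(score_array[idx])])

Coercing everything to torch.Tensor up front is what lets cosine_similarity broadcast a (1, dim) query against an (n, dim) corpus regardless of whether the encoder returned a list or a numpy array.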
