Skip to content

Commit

Permalink
addressing PR comments - moving knn_vector_store to KnnRecallParamSou…
Browse files Browse the repository at this point in the history
…rce and passing it as arg
  • Loading branch information
pmpailis committed Nov 24, 2023
1 parent c226899 commit 134baa1
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions dense_vector/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
logger = logging.getLogger(__name__)


def load_query_vectors(queries_file):
def load_query_vectors(queries_file) -> Dict[int, List[float]]:
if not (os.path.exists(queries_file) and os.path.isfile(queries_file)):
raise ValueError(f"Provided queries file '{queries_file}' does not exist or is not a file")
query_vectors: Dict[int, List[float]]
Expand Down Expand Up @@ -70,6 +70,8 @@ async def load_exact_neighbors(self, index: str, query_id: str, max_size: int, r
return await extract_exact_neighbors(self._query_vectors[query_id], index, max_size, self._vector_field, request_cache, client)

def get_query_vectors(self) -> Dict[int, List[float]]:
if len(self._query_vectors) == 0:
raise ValueError("Query vectors have not been initialized.")
return self._query_vectors

@classmethod
Expand Down Expand Up @@ -151,10 +153,11 @@ def __init__(self, track, params, **kwargs):
self._cache = params.get("cache", False)
self._params = params
self.infinite = True
cwd = os.path.dirname(__file__)
self._queries_file = os.path.join(cwd, "queries.json")
self._vector_field = "vector"
self._target_k = 1_000
cwd = os.path.dirname(__file__)
queries_file: str = os.path.join(cwd, "queries.json")
vector_field: str = "vector"
self._knn_vector_store: KnnVectorStore = KnnVectorStore.get_instance(queries_file, vector_field)

def partition(self, partition_index, total_partitions):
return self
Expand All @@ -165,9 +168,8 @@ def params(self):
"cache": self._params.get("cache", False),
"size": self._params.get("k", 10),
"num_candidates": self._params.get("num-candidates", 100),
"queries_file": self._queries_file,
"vector_field": self._vector_field,
"target_k": self._target_k,
"knn_vector_store": self._knn_vector_store,
}


Expand All @@ -180,14 +182,12 @@ async def __call__(self, es, params):
num_candidates = params["num_candidates"]
index = params["index"]
request_cache = params["cache"]
queries_file = params["queries_file"]
vector_field = params["vector_field"]
target_k = max(params["target_k"], k)
recall_total = 0
exact_total = 0
min_recall = k

knn_vector_store: KnnVectorStore = KnnVectorStore.get_instance(queries_file, vector_field)
knn_vector_store: KnnVectorStore = params["knn_vector_store"]
for query_id, query_vector in knn_vector_store.get_query_vectors().items():
knn_result = await es.search(
body={
Expand Down

0 comments on commit 134baa1

Please sign in to comment.