Commit

Merge remote-tracking branch 'origin/master'
c7nw3r committed Sep 5, 2023
2 parents 9174912 + bea8b67 commit 76eb6de
Showing 3 changed files with 24 additions and 19 deletions.
localsearch/pipeline.py (6 changes: 3 additions & 3 deletions)
@@ -82,12 +82,12 @@ def add(self, docs: Documents, batch_size: int | None = None, verbose: bool = Fa
         for idx in tqdm(idxs, total=len(idxs)) if verbose else idxs:
             docs_batch = docs[idx: idx+batch_size]

-            for idx, doc in enumerate(docs_batch, self._get_start_idx()):
-                write_json(Path(self._raw_data_dir) / f"{idx}.json", asdict(doc))
-
             for writer in self._writers:
                 writer.append(docs_batch)

+            for idx, doc in enumerate(docs_batch, self._get_start_idx()):
+                write_json(Path(self._raw_data_dir) / f"{idx}.json", asdict(doc))
+
     def _get_start_idx(self) -> int:
         idxs = [
             int(fn.removesuffix(".json")) for fn in os.listdir(self._raw_data_dir)
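
For context on the numbering that enumerate(docs_batch, self._get_start_idx()) relies on: each document is written to <raw_data_dir>/<idx>.json, and the start index is derived from the files already on disk. Below is a minimal standalone sketch of that scheme, assuming _get_start_idx returns one past the highest existing index; the directory name, helper names, and document shape are illustrative, not the project's code.

# Standalone sketch of the <idx>.json numbering scheme (illustrative, not the project's code).
# Assumes the start index is one past the highest <n>.json already on disk.
import json
import os
from pathlib import Path

raw_data_dir = Path("raw_data")  # illustrative directory
raw_data_dir.mkdir(exist_ok=True)


def get_start_idx() -> int:
    idxs = [
        int(fn.removesuffix(".json"))
        for fn in os.listdir(raw_data_dir)
        if fn.endswith(".json")
    ]
    return max(idxs) + 1 if idxs else 0


def write_batch(docs_batch: list) -> None:
    # Mirrors the moved loop: number documents starting from the current start index.
    for idx, doc in enumerate(docs_batch, get_start_idx()):
        (raw_data_dir / f"{idx}.json").write_text(json.dumps(doc))


write_batch([{"text": "first"}, {"text": "second"}])  # writes 0.json, 1.json
write_batch([{"text": "third"}])                      # writes 2.json
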
localsearch/searcher/annoy_search.py (34 changes: 20 additions & 14 deletions)
@@ -1,21 +1,23 @@
 import os
-from dataclasses import dataclass, field, asdict
-from typing import List, Optional, Union, Literal, Dict
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import Dict, List, Literal, Optional, Union

 import numpy as np

-from localsearch.__spi__ import IndexedDocument, Encoder, ScoredDocument, Document
+from localsearch.__spi__ import Document, Encoder, IndexedDocument, ScoredDocument
 from localsearch.__spi__.types import Searcher
 from localsearch.__util__.array_utils import cosine_similarity
-from localsearch.__util__.io_utils import read_json, write_json, grep, list_files, delete_file, delete_folder
+from localsearch.__util__.io_utils import delete_file, delete_folder, grep, list_files, read_json, write_json


 @dataclass
 class AnnoyConfig:
     path: str
-    raw_data_dir: Optional[str] = None
+    raw_data_dir: Optional[str] = None  # TODO: better handling in AnnoySearch methods
     n: int = 5
-    k: int = 5
+    n_trees: int = 10
+    search_k: int = -1  # defaults to n_trees * n
     index_name: Optional[str] = "annoy"
     index_fields: Optional[List[str]] = field(default_factory=lambda: ["text"])
     metric: Literal["angular", "euclidean", "manhattan", "hamming", "dot"] = "euclidean"
@@ -48,10 +50,11 @@ def __init__(self, config: AnnoyConfig, encoder: Encoder):

     def read(self, text: str, n: Optional[int] = None) -> List[ScoredDocument]:
         vector = self.encoder(text)
-        indices = self.index.get_nns_by_vector(vector, n or self.config.n, search_k=self.config.k)
+        indices = self.index.get_nns_by_vector(vector, n or self.config.n, self.config.search_k)
         vectors = [self.index.get_item_vector(i) for i in indices]
         scores = [cosine_similarity(np.array(item), vector) for item in vectors]
-        indices = [self.id_map[e] for e in indices]
+        if not self.config.raw_data_dir:
+            indices = [self.id_map[e] for e in indices]

         documents = [self._read_document(idx) for idx in indices]
         return [ScoredDocument(s, d) for s, d in zip(scores, documents)]
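
The read() hunk above keeps Annoy's neighbour ordering but re-scores each neighbour by cosine similarity against the query vector. A minimal numpy illustration of that scoring step follows; the cosine_similarity helper here is a local stand-in, not localsearch's implementation.

# Stand-in illustration of the cosine re-scoring in read() (not localsearch's helper).
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


query = np.array([0.1, 0.9, 0.3])
neighbour_vectors = [[0.1, 0.8, 0.4], [0.9, 0.1, 0.0]]  # as returned by get_item_vector
scores = [cosine_similarity(np.array(v), query) for v in neighbour_vectors]
print(scores)  # higher value = closer to the query
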
@@ -101,21 +104,25 @@ def _rebuild(self):
         new_index = self.AnnoyIndex(self.encoder.get_output_dim(), self.config.metric)
         folder = self.config.raw_data_dir if self.config.raw_data_dir else self.path.replace(".ann", "")
         for path in list_files(folder, recursive=True):
-            if path.endswith(".json"):
-                idx = int(path.replace(".json", "").split("_")[1])
+            path = Path(path)
+            if path.suffix == ".json":
+                idx = int(path.stem) if self.config.raw_data_dir else int(path.stem.split("_")[1])
                 vector = self.index.get_item_vector(idx)
                 new_index.add_item(idx, vector)

         os.remove(self.path)
         self.index = new_index

     def _save(self):
-        self.index.build(10)
+        self.index.build(self.config.n_trees)
         self.index.save(self.path)

     def _read_document(self, idx: str) -> IndexedDocument:
-        folder = self.config.raw_data_dir if self.config.raw_data_dir else self.path.replace(".ann", "")
-        return IndexedDocument(**read_json(grep(folder, idx)), index=self.config.index_name)
+        if folder := self.config.raw_data_dir:
+            return IndexedDocument(**read_json(Path(folder) / f"{idx}.json"), index=self.config.index_name)
+        else:
+            folder = self.path.replace(".ann", "")
+            return IndexedDocument(**read_json(grep(folder, idx)), index=self.config.index_name)

     def search_by_source(self, source: str, n: Optional[int] = None) -> List[Document]:
         folder = self.config.raw_data_dir if self.config.raw_data_dir else self.path.replace(".ann", "")
@@ -128,4 +135,3 @@ def remove_by_source(self, source: str):

         self._rebuild()
         self._save()
-
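
For reference, the new n_trees and search_k config fields map onto Annoy's own build-time and query-time parameters, as used by _save() and read() above. A minimal standalone sketch against the annoy package follows; the dimension, vectors, and file name are illustrative.

# Standalone sketch of the Annoy parameters behind AnnoyConfig.n_trees and
# AnnoyConfig.search_k (illustrative, not the project's code). search_k=-1 lets
# Annoy default to n_trees * n candidates per query.
from annoy import AnnoyIndex

dim = 3  # illustrative embedding dimension
index = AnnoyIndex(dim, "euclidean")
index.add_item(0, [0.1, 0.9, 0.3])
index.add_item(1, [0.9, 0.1, 0.0])
index.build(10)            # n_trees: more trees means a bigger index and better recall
index.save("example.ann")  # illustrative path

neighbours = index.get_nns_by_vector([0.1, 0.8, 0.4], 2, search_k=-1)
print(neighbours)          # item ids ordered by distance
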

localsearch/source_repo/utils.py (3 changes: 1 addition & 2 deletions)
@@ -50,11 +50,10 @@ def get_full_context(
     source: StructuredSource,
     chars_before: int,
     chars_after: int,
-    source_id_field: str = "source_id",
     source_part_field: str = "source_part",
     text_start_idx_field: str = "text_start_idx",
     doc_title_prefix: str = "Document title:",
-    section_title_prefix: str = "Section_title:"
+    section_title_prefix: str = "Section title:"
 ) -> str:

     source_part: int = result.document.fields[source_part_field]

0 comments on commit 76eb6de
