Skip to content

Commit

Permalink
numpy dependency set to optional, setup.py version set to 0.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
c7nw3r committed Aug 23, 2024
1 parent 3a0fcd3 commit 75aedc8
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 10 deletions.
3 changes: 2 additions & 1 deletion localsearch/__util__/array_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import List, Dict

import numpy as np
from numpy import matmul
from numpy.linalg import norm

from localsearch import ScoredDocument


def cosine_similarity(a, b) -> float:
import numpy as np

a = np.expand_dims(a, axis=0) if len(a.shape) == 1 else a
b = np.expand_dims(b, axis=0) if len(b.shape) == 1 else b
a_norm = a / norm(a, ord=2, axis=1, keepdims=True)
Expand Down
2 changes: 1 addition & 1 deletion localsearch/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import List, Optional

import numpy as np
from tqdm import tqdm

from localsearch.__spi__.model import RankedDocument, ScoredDocument, Documents
Expand Down Expand Up @@ -31,6 +30,7 @@ def search(
query: str,
config: SearchConfig = SearchConfig()
) -> List[RankedDocument]:
import numpy as np

results = flatten([reader.search_by_text(query) for reader in self.readers])
# results = unique(results, lambda x: x.document.id)
Expand Down
5 changes: 1 addition & 4 deletions localsearch/searcher/annoy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from pathlib import Path
from typing import List, Literal, Optional, Union

import numpy as np

from localsearch.__spi__ import Document, Encoder, IndexedDocument, ScoredDocument
from localsearch.__spi__.types import Searcher
from localsearch.__util__.array_utils import cosine_similarity
Expand Down Expand Up @@ -45,12 +43,11 @@ def __init__(self, config: AnnoyConfig, encoder: Encoder):
raise ValueError("no annoy library found, please install localsearch[annoy]")

def search_by_text(self, text: str, n: Optional[int] = None) -> List[ScoredDocument]:
import numpy as np
vector = self.encoder(text)
indices = self.index.get_nns_by_vector(vector, n or self.config.n, self.config.search_k)
vectors = [self.index.get_item_vector(i) for i in indices]
scores = [cosine_similarity(np.array(item), vector) for item in vectors]
# if not self.config.raw_data_dir:
# indices = [self.id_map[e] for e in indices]

documents = [self._read_document(idx) for idx in indices]
return [ScoredDocument(s, d) for s, d in zip(scores, documents)]
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ stop-words
simplemma
pysbd==0.3.4
networkx==3.1
numpy
numpy==1.26.4
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
numpy
tqdm
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

long_description = Path(__file__).with_name("README.md").read_text()

version = "0.2.0"
version = "0.2.1"

setup(
name='localsearch',
Expand Down Expand Up @@ -46,7 +46,8 @@
"tantivy@git+https://github.com/leftshiftone/tantivy-py.git#egg=tantivy",
"stop-words==2018.7.23",
"simplemma==0.9.1",
"pysbd==0.3.4"
"pysbd==0.3.4",
"numpy==1.26.4"
],
'networkx': [
"networkx==3.1"
Expand Down

0 comments on commit 75aedc8

Please sign in to comment.