Skip to content

Commit

Permalink
makefile support, pyright patches, ruff fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
krypticmouse committed May 16, 2024
1 parent a87da05 commit 19b605f
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 13 deletions.
30 changes: 30 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@

# Define variables
PYTHON = python3
PIP = pip3
RUFF = ruff
PYRIGHT = pyright
PYTEST = pytest
MKDOCS = mkdocs

# Define targets
.PHONY: all lint type-check test docs clean

all: lint type-check test docs

lint:
$(RUFF) check .

type-check:
$(PYRIGHT) .

test:
$(PYTEST) tests/

docs:
$(MKDOCS) build

clean:
rm -rf __pycache__
rm -rf .pytest_cache
rm -rf site
2 changes: 1 addition & 1 deletion pirate/data/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def to_csv(self, path: str):
Args:
path: The path to the file where the ranking will be saved.
"""
self.data.write_ndjson(path, include_header=False, seperator=",")
self.data.write_ndjson(path)

def __getitem__(self, key):
""" Return the value of the key. """
Expand Down
18 changes: 14 additions & 4 deletions pirate/miner/score_thresh_neg.py → pirate/miner/hard_neg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pirate.data.triples import Triples
from pirate.models import (
Encoder,
ScoreThresholdMinerParams
HardNegativesMinerParams
)
from pirate.retrievers import (
BaseRetriever,
Expand All @@ -12,10 +12,14 @@
CrossEncoder
)

class ScoreThresholdMiner:
def __init__(self, sampling_params: ScoreThresholdMinerParams):
class HardNegativesMiner:
def __init__(self, sampling_params: HardNegativesMinerParams):
self.sampling_params = sampling_params

self.passages = self.sampling_params.passages
self.queries = self.sampling_params.queries
self.triples = self.sampling_params.triples

self.encoder = self._get_model()

def _get_model(self) -> BaseRetriever:
Expand Down Expand Up @@ -44,4 +48,10 @@ def mine(
) -> Triples:
rankings = self.encoder.rank(self.passages, self.queries, self.sampling_params.top_k)

triples = []
triples_list = []
for query_id in self.queries:
for i, passage_id in enumerate(self.passages):
pass

triples = Triples(triples_list)
return triples
Empty file added pirate/miner/in_batch_neg.py
Empty file.
4 changes: 2 additions & 2 deletions pirate/models/mining.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ class MiningParams(BaseModel):
queries: Queries
triples: Triples

class ScoreThresholdMinerParams(MiningParams):
threshold: float = 0.8
class HardNegativesMiner(MiningParams):
passage_threshold: float
top_k: Optional[int] = None

model: Union[BaseRetriever, Encoder]
Expand Down
18 changes: 16 additions & 2 deletions pirate/retrievers/base.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,16 @@
class BaseRetriever:
pass
from abc import ABC, abstractmethod

from pirate.data import (
Passages,
Ranking
)


class BaseRetriever(ABC):
@abstractmethod
def index(self, corpus: Passages):
pass

@abstractmethod
def rank(self, *args, **kwargs) -> Ranking:
pass
5 changes: 2 additions & 3 deletions pirate/retrievers/bm25.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import numpy as np

from rank_bm25 import (
BM25,
BM25Okapi,
BM25L,
BM25Plus
)
from loguru import logger
from typing import Optional, Callable
from typing import Any, Optional, Callable

from pirate.data.ranking import Ranking

Expand All @@ -27,7 +26,7 @@ def __init__(self, model: Encoder, tokenizer: Optional[Callable] = None):
self.indexed_corpus = None
self.corpus = None

def _get_model(self, model: Encoder) -> BM25:
def _get_model(self, model: Encoder) -> Any:
match model:
case Encoder.BM25:
return BM25Okapi
Expand Down
Empty file added pirate/samplers/__init__.py
Empty file.
Empty file added pirate/samplers/random.py
Empty file.
2 changes: 1 addition & 1 deletion tests/retrievers/test_bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_bm25_rank(sample_passages, sample_queries):
def test_invalid_index_corpus():
retriever = BM25Retriever(Encoder.BM25)
with pytest.raises(ValueError):
retriever.index(["invalid", "corpus", "type"])
retriever.index(["invalid", "corpus", "type"]) # type: ignore

def test_rank_before_index():
retriever = BM25Retriever(Encoder.BM25)
Expand Down

0 comments on commit 19b605f

Please sign in to comment.