Skip to content

Commit

Permalink
allow arbitrary types, change file names, deflauth import patch
Browse files Browse the repository at this point in the history
  • Loading branch information
krypticmouse committed May 15, 2024
1 parent f29e03a commit a87da05
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 18 deletions.
5 changes: 5 additions & 0 deletions pirate/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .base import BaseData
from .passages import Passages
from .queries import Queries
from .triples import Triples
from .ranking import Ranking
2 changes: 1 addition & 1 deletion pirate/data/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def from_json(self, path: str):
Args:
path: The path to the JSON file from which the ranking will be loaded.
"""
self.data = pl.read_json(path, columns=["qid", "pid", "rank", "score"])
self.data = pl.read_json(path, schema={"qid": pl.String, "pid": pl.String, "rank": pl.Int32, "score": pl.Float64})

def from_csv(self, path: str):
"""
Expand Down
12 changes: 6 additions & 6 deletions pirate/data/triple.py → pirate/data/triples.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,27 +86,27 @@ def from_csv(self, path: str) -> List[List[str]]:
with open(path, "r") as f:
return [[item.strip() for item in line.split(",")] for line in f]

def to_json(self, path: str):
def to_json(self, path: str) -> None:
"""
Save the triples to a JSON file.
Args:
path: The path to the file where the triples will be saved.
"""
with open(path, "w") as f:
for qid, ppid, npid in self.triples:
f.write(json.dumps([qid, ppid, npid]) + "\n")
for triple in self.triples:
f.write(json.dumps(list(triple)) + "\n")

def to_csv(self, path: str):
def to_csv(self, path: str) -> None:
"""
Save the triples to a CSV file.
Args:
path: The path to the file where the triples will be saved.
"""
with open(path, "w") as f:
for qid, ppid, npid in self.triples:
f.write(f"{qid},{ppid},{npid}\n")
for triple in self.triples:
f.write(f"{','.join(triple)}\n")

def __getitem__(self, index: int) -> List[str]:
""" Return the triple at the given index. """
Expand Down
19 changes: 12 additions & 7 deletions pirate/miner/score_thresh_neg.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import List

from pirate.models import Encoder
from pirate.models.mining import MiningParams
from pirate.data.triples import Triples
from pirate.models import (
Encoder,
ScoreThresholdMinerParams
)
from pirate.retrievers import (
BaseRetriever,
BM25Retriever,
Expand All @@ -10,20 +13,22 @@
)

class ScoreThresholdMiner:
def __init__(self, sampling_params: MiningParams):
def __init__(self, sampling_params: ScoreThresholdMinerParams):
self.sampling_params = sampling_params

self.encoder = self.get_model()
self.encoder = self._get_model()

def get_model(self) -> BaseRetriever:
def _get_model(self) -> BaseRetriever:
model = self.sampling_params.model

match model:
case Encoder.BM25 | Encoder.BM25L | Encoder.BM25PLUS:
model = BM25Retriever(model)
case Encoder.BIENCODER:
raise ValueError("BiEncoder not supported for mining yet.")
model = BiEncoder(model)
case Encoder.CROSSENCODER:
raise ValueError("CrossEncoder not supported for mining yet.")
model = CrossEncoder(model)
case BaseRetriever():
model = model
Expand All @@ -36,7 +41,7 @@ def mine(
self,
num_negs_per_pair: int = 1,
exclude_pairs: List[List[str]] = []
):
) -> Triples:
rankings = self.encoder.rank(self.passages, self.queries, self.sampling_params.top_k)


triples = []
3 changes: 2 additions & 1 deletion pirate/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .types import *
from .types import *
from .mining import *
4 changes: 3 additions & 1 deletion pirate/models/mining.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from typing import Optional, Union
from pydantic import BaseModel, ConfigDict

from pirate.models import Encoder
from pirate.models.types import Sampling
Expand All @@ -11,6 +11,8 @@
)

class MiningParams(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

passages: Passages
queries: Queries
triples: Triples
Expand Down
4 changes: 4 additions & 0 deletions pirate/retrievers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import BaseRetriever
from .bm25 import BM25Retriever
from .bi_encoder import BiEncoder
from .cross_encoder import CrossEncoder
4 changes: 2 additions & 2 deletions pirate/retrievers/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def index(self, corpus: Passages):
logger.info("Finished indexing corpus.")

def rank(self, queries: Queries, top_k: Optional[int] = None) -> Ranking:
if self.indexed_corpus is None:
if self.indexed_corpus is None or self.corpus is None:
raise ValueError("Index not built. Please call the index method first.")

tokenized_queries = [self.tokenizer(queries[query_id]) for query_id in queries]
Expand All @@ -74,7 +74,7 @@ def rank(self, queries: Queries, top_k: Optional[int] = None) -> Ranking:
sorted_indices = score_array.argsort()[::-1]

for j, doc_id in enumerate(sorted_indices):
ranking_list.append([query_id, self.corpus[doc_id], score_array[doc_id], j])
ranking_list.append([query_id, self.index_id_lookup[doc_id], score_array[doc_id], j])

ranking = Ranking(ranking_list)

Expand Down

0 comments on commit a87da05

Please sign in to comment.