diff --git a/pirate/data/__init__.py b/pirate/data/__init__.py index e69de29..697c41e 100644 --- a/pirate/data/__init__.py +++ b/pirate/data/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseData +from .passages import Passages +from .queries import Queries +from .triples import Triples +from .ranking import Ranking \ No newline at end of file diff --git a/pirate/data/ranking.py b/pirate/data/ranking.py index 667f609..2e29dc6 100644 --- a/pirate/data/ranking.py +++ b/pirate/data/ranking.py @@ -67,7 +67,7 @@ def from_json(self, path: str): Args: path: The path to the JSON file from which the ranking will be loaded. """ - self.data = pl.read_json(path, columns=["qid", "pid", "rank", "score"]) + self.data = pl.read_json(path, schema={"qid": pl.String, "pid": pl.String, "rank": pl.Int32, "score": pl.Float64}) def from_csv(self, path: str): """ diff --git a/pirate/data/triple.py b/pirate/data/triples.py similarity index 92% rename from pirate/data/triple.py rename to pirate/data/triples.py index 5534399..f3b5d36 100644 --- a/pirate/data/triple.py +++ b/pirate/data/triples.py @@ -86,7 +86,7 @@ def from_csv(self, path: str) -> List[List[str]]: with open(path, "r") as f: return [[item.strip() for item in line.split(",")] for line in f] - def to_json(self, path: str): + def to_json(self, path: str) -> None: """ Save the triples to a JSON file. @@ -94,10 +94,10 @@ def to_json(self, path: str): path: The path to the file where the triples will be saved. """ with open(path, "w") as f: - for qid, ppid, npid in self.triples: - f.write(json.dumps([qid, ppid, npid]) + "\n") + for triple in self.triples: + f.write(json.dumps(list(triple)) + "\n") - def to_csv(self, path: str): + def to_csv(self, path: str) -> None: """ Save the triples to a CSV file. @@ -105,8 +105,8 @@ def to_csv(self, path: str): path: The path to the file where the triples will be saved. """ with open(path, "w") as f: - for qid, ppid, npid in self.triples: - f.write(f"{qid},{ppid},{npid}\n") + for triple in self.triples: + f.write(f"{','.join(triple)}\n") def __getitem__(self, index: int) -> List[str]: """ Return the triple at the given index. """ diff --git a/pirate/miner/score_thresh_neg.py b/pirate/miner/score_thresh_neg.py index ee5f41a..f965594 100644 --- a/pirate/miner/score_thresh_neg.py +++ b/pirate/miner/score_thresh_neg.py @@ -1,7 +1,10 @@ from typing import List -from pirate.models import Encoder -from pirate.models.mining import MiningParams +from pirate.data.triples import Triples +from pirate.models import ( + Encoder, + ScoreThresholdMinerParams +) from pirate.retrievers import ( BaseRetriever, BM25Retriever, @@ -10,20 +13,22 @@ ) class ScoreThresholdMiner: - def __init__(self, sampling_params: MiningParams): + def __init__(self, sampling_params: ScoreThresholdMinerParams): self.sampling_params = sampling_params - self.encoder = self.get_model() + self.encoder = self._get_model() - def get_model(self) -> BaseRetriever: + def _get_model(self) -> BaseRetriever: model = self.sampling_params.model match model: case Encoder.BM25 | Encoder.BM25L | Encoder.BM25PLUS: model = BM25Retriever(model) case Encoder.BIENCODER: + raise ValueError("BiEncoder not supported for mining yet.") model = BiEncoder(model) case Encoder.CROSSENCODER: + raise ValueError("CrossEncoder not supported for mining yet.") model = CrossEncoder(model) case BaseRetriever(): model = model @@ -36,7 +41,7 @@ def mine( self, num_negs_per_pair: int = 1, exclude_pairs: List[List[str]] = [] - ): + ) -> Triples: rankings = self.encoder.rank(self.passages, self.queries, self.sampling_params.top_k) - \ No newline at end of file + triples = [] \ No newline at end of file diff --git a/pirate/models/__init__.py b/pirate/models/__init__.py index 4192f6a..021fb37 100644 --- a/pirate/models/__init__.py +++ b/pirate/models/__init__.py @@ -1 +1,2 @@ -from .types import * \ No newline at end of file +from .types import * +from .mining import * \ No newline at end of file diff --git a/pirate/models/mining.py b/pirate/models/mining.py index d851819..a4680ed 100644 --- a/pirate/models/mining.py +++ b/pirate/models/mining.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel from typing import Optional, Union +from pydantic import BaseModel, ConfigDict from pirate.models import Encoder from pirate.models.types import Sampling @@ -11,6 +11,8 @@ ) class MiningParams(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + passages: Passages queries: Queries triples: Triples diff --git a/pirate/retrievers/__init__.py b/pirate/retrievers/__init__.py index e69de29..1aeb114 100644 --- a/pirate/retrievers/__init__.py +++ b/pirate/retrievers/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseRetriever +from .bm25 import BM25Retriever +from .bi_encoder import BiEncoder +from .cross_encoder import CrossEncoder \ No newline at end of file diff --git a/pirate/retrievers/bm25.py b/pirate/retrievers/bm25.py index 95f977a..6c865e2 100644 --- a/pirate/retrievers/bm25.py +++ b/pirate/retrievers/bm25.py @@ -52,7 +52,7 @@ def index(self, corpus: Passages): logger.info("Finished indexing corpus.") def rank(self, queries: Queries, top_k: Optional[int] = None) -> Ranking: - if self.indexed_corpus is None: + if self.indexed_corpus is None or self.corpus is None: raise ValueError("Index not built. Please call the index method first.") tokenized_queries = [self.tokenizer(queries[query_id]) for query_id in queries] @@ -74,7 +74,7 @@ def rank(self, queries: Queries, top_k: Optional[int] = None) -> Ranking: sorted_indices = score_array.argsort()[::-1] for j, doc_id in enumerate(sorted_indices): - ranking_list.append([query_id, self.corpus[doc_id], score_array[doc_id], j]) + ranking_list.append([query_id, self.index_id_lookup[doc_id], score_array[doc_id], j]) ranking = Ranking(ranking_list)