Skip to content

Commit

Permalink
add cross encoder and bi encoders implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
krypticmouse committed May 30, 2024
1 parent 22aff92 commit a63fe50
Show file tree
Hide file tree
Showing 18 changed files with 337 additions and 192 deletions.
Empty file.
2 changes: 1 addition & 1 deletion pirate/chains/base.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
class BaseChain:
class Chain:
pass
6 changes: 3 additions & 3 deletions pirate/chains/mine_chain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pirate.models.chains import MineChainConfig
from pirate.chains.base import Chain

class MineChain:
def __init__(self, config: MineChainConfig):
class MineChain(Chain):
def __init__(self, ):
pass
29 changes: 28 additions & 1 deletion pirate/data/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, ranking: Union[str, List]):
"""
self.load(ranking)


def load(self, ranking: Union[str, List]):
"""
Load ranking from a file or a list.
Expand All @@ -41,6 +42,7 @@ def load(self, ranking: Union[str, List]):
else:
raise NotImplementedError(f"Type {type(ranking)} not supported")


def save(self, path: str):
"""
Save the ranking to a file.
Expand All @@ -60,6 +62,7 @@ def save(self, path: str):
else:
raise NotImplementedError(f"Extension {ext} not supported")


def get_passage_groups(self, qid: str) -> pl.DataFrame:
"""
Get the passage groups for a given query ID.
Expand All @@ -71,6 +74,22 @@ def get_passage_groups(self, qid: str) -> pl.DataFrame:
A DataFrame with the passage groups for the given query ID.
"""
return self.data.filter(pl.col("qid") == qid).sort("rank")


def filter_by_score(self, threshold: float) -> "Ranking":
"""
Filter the ranking by score.
Args:
threshold: The threshold score.
Returns:
A Ranking with the rows that have a score greater than the threshold.
"""
df = self.data.filter(pl.col("score") > threshold)

return Ranking(df.rows())


def _from_json(self, path: str):
"""
Expand All @@ -81,6 +100,7 @@ def _from_json(self, path: str):
"""
self.data = pl.read_ndjson(path, schema={"qid": pl.String, "pid": pl.String, "rank": pl.Int32, "score": pl.Float64})


def _from_csv(self, path: str):
"""
Load ranking from a CSV file.
Expand All @@ -90,6 +110,7 @@ def _from_csv(self, path: str):
"""
self.data = pl.read_csv(path, columns=["qid", "pid", "rank", "score"])


def _from_list(self, ranking: List):
"""
Load ranking from a list.
Expand All @@ -99,6 +120,7 @@ def _from_list(self, ranking: List):
"""
self.data = pl.DataFrame(ranking, schema=["qid", "pid", "rank", "score"])


def _to_json(self, path: str):
"""
Save the ranking to a JSON file.
Expand All @@ -108,6 +130,7 @@ def _to_json(self, path: str):
"""
self.data.write_ndjson(path)


def _to_csv(self, path: str):
"""
Save the ranking to a CSV file.
Expand All @@ -117,10 +140,12 @@ def _to_csv(self, path: str):
"""
self.data.write_csv(path)


def __getitem__(self, key):
""" Return the value of the key. """
return self.data[key]


def __repr__(self):
""" Return the string representation of the Ranking object. """
string = textwrap.dedent(
Expand All @@ -135,10 +160,12 @@ def __repr__(self):

return string


def __len__(self):
""" Return the number of rows in the ranking. """
return len(self.data)


def __iter__(self):
""" Return an iterator over the ranking. """
return iter(self.data)
return iter(self.data)
56 changes: 44 additions & 12 deletions pirate/miner/base.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,40 @@
from typing import List
import random

from typing import List, Any
from abc import ABC, abstractmethod

from pirate.data import Triples
from pirate.models import Encoder
from pirate.data.triples import Triples
from pirate.retrievers import (
BaseRetriever,
BM25Retriever,
BiEncoder,
CrossEncoder
)

class BaseMiner(ABC):
"""
BaseMiner is an abstract base class that provides methods for mining data.
"""

def __init__(
self,
triples: Triples,
*args,
**kwargs
):
def __init__(self, mining_params: Any):
"""
Initialize the BaseMiner object.
Args:
data: The data to be mined.
"""
pass
self.mining_params = mining_params


@abstractmethod
def mine(
self,
num_negs_per_pair: int = 1,
exclude_pairs: List[List[str]] = []
):
exclude_pairs: List[List[str]] = [],
*args,
**kwargs
) -> Triples:
"""
Mine negative samples from the data.
Expand All @@ -38,4 +45,29 @@ def mine(
Returns:
A list of new triples.
"""
pass
pass


def _seed(self):
if self.mining_params.seed is not None:
random.seed(self.mining_params.seed)


def _get_model(self) -> BaseRetriever:
model = self.mining_params.model

match model:
case Encoder.BM25 | Encoder.BM25L | Encoder.BM25PLUS:
model = BM25Retriever(model)
case Encoder.BIENCODER:
raise ValueError("BiEncoder not supported for mining yet.")
model = BiEncoder(model)
case Encoder.CROSSENCODER:
raise ValueError("CrossEncoder not supported for mining yet.")
model = CrossEncoder(model)
case BaseRetriever():
model = model
case _:
raise ValueError("Invalid model.")

return model
75 changes: 75 additions & 0 deletions pirate/miner/hard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import random
from typing import List, Optional

from tqdm import tqdm

from pirate.miner.base import BaseMiner
from pirate.data import (
Passages,
Queries,
Triples
)
from pirate.models import (
Sampling,
HardMinerParams
)

class HardMiner(BaseMiner):
def __init__(self, mining_params: HardMinerParams):
super().__init__(mining_params)

self.mining_params = mining_params
self.triples = self.mining_params.triples

assert self.triples is not None, "Triples must be provided."
assert len(self.triples) > 0, "Triples must not be empty."
assert max([len(i) for i in self.triples]) == 2, "Triples must be in the pair format [qid, pid]."

passage_dict = {}
query_dict = {}
for qid, pid in self.triples:
assert qid in self.mining_params.queries, f"Query {qid} not found."
assert pid in self.mining_params.passages, f"Passage {pid} not found."

passage_dict[pid] = self.mining_params.passages[pid]
query_dict[qid] = self.mining_params.queries[qid]

self.passages = Passages(passage_dict)
self.queries = Queries(query_dict)

self.encoder = self._get_model()
self._seed()


def mine(
self,
num_negs_per_pair: int = 1,
exclude_pairs: Optional[List[List[str]]] = None
) -> Triples:
self.encoder.index(self.passages)
rankings = self.encoder.rank_passages(self.queries, self.mining_params.top_k)

if self.mining_params.score_threshold:
rankings = rankings.filter_by_score(self.mining_params.score_threshold)

triples_list = []
for qid, pos_pid in tqdm(self.triples, desc="Mining hard negatives", total=len(self.triples), disable=self.mining_params.verbose):
if exclude_pairs and [qid, pos_pid] in exclude_pairs:
continue

passage_groups = rankings.get_passage_groups(qid)["pid"].to_list()

passage_sample_set = []
match self.mining_params.sampling:
case Sampling.RANDOM:
passage_sample_set = passage_groups
case Sampling.RTOP_K:
passage_sample_set = passage_groups[:self.mining_params.top_k] if self.mining_params.top_k else passage_groups

random_negative_passages = random.sample(passage_sample_set, num_negs_per_pair)

for neg_pid in random_negative_passages:
triples_list.append([qid, pos_pid, neg_pid])

triples = Triples(triples_list)
return triples
100 changes: 0 additions & 100 deletions pirate/miner/hard_neg.py

This file was deleted.

Loading

0 comments on commit a63fe50

Please sign in to comment.