fix: update automatch #5

Closed · wants to merge 3 commits

Changes from all commits

22 changes: 13 additions & 9 deletions .github/workflows/test.yml
@@ -37,27 +37,31 @@ jobs:
        shell: bash
        run: |
          pip install -r requirements.txt
-          pip install --extra-index-url https://pypi.org/simple --no-cache-dir pytest codecov-cli==0.4.0
+          pip install --extra-index-url https://pypi.org/simple --no-cache-dir pytest "codecov-cli>=0.4.1"

      - name: Run unittest
        shell: bash
        run: python -m unittest discover -s ./tests -p 'test_*.py'

-      - name: Statistics
+      - name: Codecov startup
        if: success()
        run: |
          codecovcli create-commit
          codecovcli create-report
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
      - name: Static Analysis
        run: |
          codecovcli static-analysis --token=${CODECOV_STATIC_TOKEN} \
            --folders-to-exclude .artifacts \
            --folders-to-exclude .github \
            --folders-to-exclude .venv \
            --folders-to-exclude static \
            --folders-to-exclude bin
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
          CODECOV_STATIC_TOKEN: ${{ secrets.CODECOV_STATIC_TOKEN }}

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v4.0.1
        if: always()
        continue-on-error: true
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          slug: longxingtan/open-retrievals

  docs:
    name: Test docs build
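
As context for the `Run unittest` step above: `discover` collects only files matching `test_*.py` under `./tests`. A minimal, hypothetical `tests/test_placeholder.py` (not part of this PR) that the pattern would pick up:

```python
# tests/test_placeholder.py -- hypothetical file; any tests/test_*.py matches the discover pattern
import unittest


class TestPlaceholder(unittest.TestCase):
    def test_truth(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```
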
2 changes: 2 additions & 0 deletions .gitignore
@@ -100,3 +100,5 @@ coverage.xml
/weights/checkpoint.index
/conda/*
**/.pdf
+/encode.py
+/.build/*
103 changes: 54 additions & 49 deletions README.md
@@ -46,10 +46,13 @@ pip install peft
pip install open-retrievals
```

-**With conda**
-```shell
-conda install open-retrievals -c conda-forge
-```
+[//]: # (**With conda**)
+
+[//]: # (```shell)
+
+[//]: # (conda install open-retrievals -c conda-forge)
+
+[//]: # (```)


## Usage
@@ -69,49 +72,39 @@ print(sentence_embeddings)
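
The collapsed context above ends the quick-start embedding example. For readers skimming the diff, a minimal hypothetical version of that example, assuming an `encode`-style helper similar to sentence-transformers (the model name and method signature are assumptions, not confirmed by this PR):

```python
from retrievals import AutoModelForEmbedding

# hypothetical quick-start sketch; model name and encode() are assumptions
model = AutoModelForEmbedding("sentence-transformers/all-MiniLM-L6-v2", pooling_method="mean")
sentence_embeddings = model.encode(["Hello world", "open-retrievals quick start"])
print(sentence_embeddings)
```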

**Finetune transformers by contrastive learning**
```python
-from retrievals import AutoModelForEmbedding, AutoModelForMatch, RetrievalTrainer, PairCollator
+from dataclasses import dataclass, field
+import transformers
+from transformers import AutoTokenizer
+from retrievals import AutoModelForEmbedding, AutoModelForMatch, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
+from retrievals.data import RetrievalDataset, RerankDataset

-train_dataset = RetrievalDataset(topic_df, tokenizer, CFG.max_len, aug=False)
+tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, use_fast=False)
+train_dataset = RetrievalDataset(args=data_args)

-train_loader = DataLoader(
-    train_dataset,
-    batch_size=CFG.batch_size,
-    shuffle=True,
-    num_workers=CFG.num_workers,
-    pin_memory=False,
-    drop_last=True,
-)
-
-loss_fn = ArcFaceAdaptiveMarginLoss(
-    criterion=cross_entropy,
-    in_features=768,
-    out_features=CFG.num_classes,
-    scale=CFG.arcface_scale,
-    margin=CFG.arcface_margin,
-)
-model = AutoModelForEmbedding(CFG.MODEL_NAME, pooling_method="cls", loss_fn=loss_fn)
+model = AutoModelForEmbedding(
+    model_args.model_name_or_path,
+    pooling_method="cls"
+)
+optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)
+lr_scheduler = get_scheduler(optimizer, num_train_steps=int(len(train_dataset) / 2 * 1))

-optimizer = get_optimizer(model, lr=CFG.learning_rate)
-scheduler = get_scheduler(
-    optimizer=optimizer, cfg=CFG, total_steps=len(train_dataset)
-)
-trainer = CustomTrainer(model, device="cuda", apex=CFG.apex)
-trainer.train(
-    train_loader=train_loader,
-    criterion=None,
-    optimizer=optimizer,
-    epochs=CFG.epochs,
-    scheduler=scheduler,
-    dynamic_margin=True,
-)
-torch.save(model.state_dict(), CFG.output_dir + f"model_{CFG.exp_id}.pth")
+trainer = RetrievalTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    data_collator=TripletCollator(tokenizer, max_length=data_args.query_max_len),
+    loss_fn=TripletLoss(),
+)
+trainer.optimizer = optimizer
+trainer.scheduler = lr_scheduler
+trainer.train()
```
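
Note that the new snippet references `model_args`, `data_args`, and `training_args` without defining them; they are expected to come from an argument parser. A minimal, hypothetical stand-in (these field names mirror the snippet and are assumptions, not the library's documented API):

```python
from dataclasses import dataclass

from transformers import TrainingArguments


@dataclass
class ModelArguments:
    # hypothetical container for the fields the snippet reads from model_args
    model_name_or_path: str = "bert-base-uncased"
    tokenizer_name: str = "bert-base-uncased"


@dataclass
class DataArguments:
    # hypothetical container for the fields the snippet reads from data_args
    train_data: str = "train.jsonl"
    query_max_len: int = 64


model_args = ModelArguments()
data_args = DataArguments()
training_args = TrainingArguments(output_dir="./outputs", num_train_epochs=1)
```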

**Finetune LLM for embedding by contrastive learning**
```python
# just change the model in AutoModelForEmbedding
from retrievals import AutoModelForEmbedding

model = AutoModelForEmbedding('llama', pooling_method='last', query_instruction='')
@@ -137,26 +130,38 @@ retrieval_model.query(method='knn')

```
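
Tying this to the `match_auto.py` changes below: once query and passage embeddings exist, the updated `AutoModelForMatch` (whose default method is now `cosine`) can be exercised end to end. A small sketch with random stand-in embeddings, assuming the PR branch is installed:

```python
import torch

from retrievals import AutoModelForMatch

# random stand-ins for real query/passage embeddings: 4 queries, 100 passages, dim 768
query_embed = torch.randn(4, 768)
passage_embed = torch.randn(100, 768)

matcher = AutoModelForMatch(method="cosine")
dists, indices = matcher.similarity_search(query_embed, passage_embed, top_k=5)
print(indices.shape)  # top-5 passage indices per query
```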

-**RAG with LangChain**
-
-- Prerequisites
-```shell
-pip install langchain
-```
-
-- Server
-```python
-
-```
-
-**RAG with LLamaIndex**
-```shell
-pip install llamaindex
-```
-
-```python
-
-```
+[//]: # (**RAG with LangChain**)
+
+[//]: # ()
+[//]: # (- Prerequisites)
+
+[//]: # (```shell)
+
+[//]: # (pip install langchain)
+
+[//]: # (```)
+
+[//]: # ()
+[//]: # (- Server)
+
+[//]: # (```python)
+
+[//]: # ()
+[//]: # (```)
+
+[//]: # (**RAG with LLamaIndex**)
+
+[//]: # (```shell)
+
+[//]: # (pip install llamaindex)
+
+[//]: # (```)
+
+[//]: # ()
+[//]: # (```python)
+
+[//]: # ()
+[//]: # (```)


## Reference & Acknowledge
34 changes: 21 additions & 13 deletions src/retrievals/models/match_auto.py
@@ -5,33 +5,31 @@
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
-from torch import nn
from tqdm import tqdm

logger = logging.getLogger(__name__)


-class AutoModelForMatch(nn.Module):
-    def __init__(self, method="knn") -> None:
+class AutoModelForMatch(object):
+    def __init__(self, method="cosine") -> None:
        super().__init__()
        self.method = method

    def similarity_search(
-        self, query_embed, passage_embed=None, top_k=1, query_chunk_size=5, corpus_chunk_size=17, **kwargs
+        self, query_embed: torch.Tensor, passage_embed: torch.Tensor, top_k: int = 1, batch_size: int = 0, **kwargs
    ):
        """Search the top_k nearest passages for each query embedding."""
        if self.method == "knn":
            neighbors_model = NearestNeighbors(n_neighbors=top_k, metric="cosine", n_jobs=-1)
            neighbors_model.fit(passage_embed)
            dists, indices = neighbors_model.kneighbors(query_embed)
            return dists, indices

        elif self.method == "cosine":
-            dists, indices = cosine_similarity_search(query_embed, passage_embed, top_k=top_k)
+            # cosine_similarity_search returns (indices, values), so unpack in that order
+            indices, dists = cosine_similarity_search(query_embed, passage_embed, top_k=top_k, batch_size=batch_size)
            return dists, indices

        else:
-            raise ValueError
+            raise ValueError(f"similarity_search supports only 'cosine' and 'knn', got {self.method!r}")

    def faiss_search(
        self,
@@ -55,20 +53,27 @@ def get_rerank_df(self):


def cosine_similarity_search(
-    query_embed,
-    passage_embed,
+    query_embed: torch.Tensor,
+    passage_embed: torch.Tensor,
    top_k: int = 1,
    batch_size: int = 512,
    penalty: bool = True,
    temperature: float = 0,
+    convert_to_numpy: bool = True,
):
-    chunk = batch_size
+    if len(query_embed.size()) == 1:
+        query_embed = query_embed.view(1, -1)
+    assert query_embed.size()[1] == passage_embed.size()[1], (
+        "query_embed and passage_embed must have the same embedding dimension, "
+        f"received query {query_embed.size()} and passage {passage_embed.size()}"
+    )
+    chunk = batch_size if batch_size > 0 else len(query_embed)
    embeddings_chunks = query_embed.split(chunk)

    vals = []
    inds = []
    for idx in range(len(embeddings_chunks)):
-        cos_sim_chunk = torch.mm(embeddings_chunks[idx], passage_embed.transpose(0, 1))
+        cos_sim_chunk = torch.matmul(embeddings_chunks[idx], passage_embed.transpose(0, 1))
        cos_sim_chunk = torch.nan_to_num(cos_sim_chunk, nan=0.0)
        # if penalty:
        #     pen = ((contents["old_source_count"].values==0) & (contents["old_nonsource_count"].values==1))
@@ -81,8 +86,11 @@ def cosine_similarity_search(
        vals_chunk, inds_chunk = torch.topk(cos_sim_chunk, k=top_k, dim=1)
        vals.append(vals_chunk[:, :].detach().cpu())
        inds.append(inds_chunk[:, :].detach().cpu())
-    vals = torch.cat(vals).detach().cpu().numpy()
-    inds = torch.cat(inds).detach().cpu().numpy()
+    vals = torch.cat(vals)
+    inds = torch.cat(inds)
+    if convert_to_numpy:
+        vals = vals.numpy()
+        inds = inds.numpy()
    return inds, vals
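
A quick sanity check of the batching and `convert_to_numpy` behavior added above; the import path and the normalization step are illustrative assumptions:

```python
import torch

from retrievals.models.match_auto import cosine_similarity_search

# L2-normalize so the dot products inside the search are true cosine similarities
queries = torch.nn.functional.normalize(torch.randn(8, 384), dim=1)
passages = torch.nn.functional.normalize(torch.randn(1000, 384), dim=1)

inds, vals = cosine_similarity_search(queries, passages, top_k=3, batch_size=4)
print(inds.shape, vals.shape)  # (8, 3) each: top-3 passage indices and similarities per query
```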

