diff --git a/examples/rag/collection.jsonl b/examples/rag/collection.jsonl new file mode 100644 index 0000000..da5bf31 --- /dev/null +++ b/examples/rag/collection.jsonl @@ -0,0 +1,5 @@ +{"id":0,"contents":"The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.\n"} +{"id":1,"contents":"The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.\n"} +{"id":2,"contents":"Essay on The Manhattan Project - The Manhattan Project The Manhattan Project was to see if making an atomic bomb possible. The success of this project would forever change the world forever making it known that something this powerful can be manmade.\n"} +{"id":3,"contents":"The Manhattan Project was the name for a project conducted during World War II, to develop the first atomic bomb. It refers specifically to the period of the project from 194 \u00e2\u0080\u00a6 2-1946 under the control of the U.S. Army Corps of Engineers, under the administration of General Leslie R. Groves.\n"} +{"id":4,"contents":"versions of each volume as well as complementary websites. The first website\u00e2\u0080\u0093The Manhattan Project: An Interactive History\u00e2\u0080\u0093is available on the Office of History and Heritage Resources website, http:\/\/www.cfo. doe.gov\/me70\/history. The Office of History and Heritage Resources and the National Nuclear Security\n"} diff --git a/examples/rag/rag_example.ipynb b/examples/rag/rag_example.ipynb new file mode 100644 index 0000000..ace5b13 --- /dev/null +++ b/examples/rag/rag_example.ipynb @@ -0,0 +1,323 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8ca8370c", + "metadata": { + "id": "8ca8370c" + }, + "source": [ + "# Retrieval Augmented Generation\n", + "\n", + "This is an example of RAG for the dataset provided by the user.\n", + "```bash\n", + "pip install -r requirements.txt\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "sWXiPzk8FOye", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sWXiPzk8FOye", + "outputId": "ad957d34-a9bf-4a5a-abcb-3f62872c7ccd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" + ] + } + ], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "RakGybmyFhub", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "RakGybmyFhub", + "outputId": "dcc61dfb-e818-4cbe-8df4-2f402d441814" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.2.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.13.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.10.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch) (2.19.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.2.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.4.99)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: pyserini in /usr/local/lib/python3.10/dist-packages (0.24.0)\n", + "Requirement already satisfied: Cython>=0.29.21 in /usr/local/lib/python3.10/dist-packages (from pyserini) (3.0.9)\n", + "Requirement already satisfied: numpy>=1.18.1 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.25.2)\n", + "Requirement already satisfied: pandas>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.5.3)\n", + "Requirement already satisfied: pyjnius>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.6.1)\n", + "Requirement already satisfied: scikit-learn>=0.22.1 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.2.2)\n", + "Requirement already satisfied: scipy>=1.4.1 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.11.4)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from pyserini) (4.66.2)\n", + "Requirement already satisfied: transformers>=4.6.0 in /usr/local/lib/python3.10/dist-packages (from pyserini) (4.38.2)\n", + "Requirement already satisfied: sentencepiece>=0.1.95 in /usr/local/lib/python3.10/dist-packages (from pyserini) (0.1.99)\n", + "Requirement already satisfied: nmslib>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from pyserini) (2.1.1)\n", + "Requirement already satisfied: onnxruntime>=1.8.1 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.17.1)\n", + "Requirement already satisfied: lightgbm>=3.3.2 in /usr/local/lib/python3.10/dist-packages (from pyserini) (4.1.0)\n", + "Requirement already satisfied: spacy>=3.2.1 in /usr/local/lib/python3.10/dist-packages (from pyserini) (3.7.4)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from pyserini) (6.0.1)\n", + "Requirement already satisfied: openai>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from pyserini) (1.14.1)\n", + "Requirement already satisfied: tiktoken>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pyserini) (0.6.0)\n", + "Requirement already satisfied: pybind11<2.6.2 in /usr/local/lib/python3.10/dist-packages (from nmslib>=2.1.1->pyserini) (2.6.1)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from nmslib>=2.1.1->pyserini) (5.9.5)\n", + "Requirement already satisfied: coloredlogs in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.8.1->pyserini) (15.0.1)\n", + "Requirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.8.1->pyserini) (24.3.7)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.8.1->pyserini) (24.0)\n", + "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.8.1->pyserini) (3.20.3)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.8.1->pyserini) (1.12)\n", + "Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from openai>=1.0.0->pyserini) (3.7.1)\n", + "Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from openai>=1.0.0->pyserini) (1.7.0)\n", + "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from openai>=1.0.0->pyserini) (0.27.0)\n", + "Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from openai>=1.0.0->pyserini) (2.6.4)\n", + "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from openai>=1.0.0->pyserini) (1.3.1)\n", + "Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from openai>=1.0.0->pyserini) (4.10.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.4.0->pyserini) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.4.0->pyserini) (2023.4)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.22.1->pyserini) (1.3.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.22.1->pyserini) (3.3.0)\n", + "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (3.0.12)\n", + "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (1.0.5)\n", + "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (1.0.10)\n", + "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (2.0.8)\n", + "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (3.0.9)\n", + "Requirement already satisfied: thinc<8.3.0,>=8.2.2 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (8.2.3)\n", + "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (1.1.2)\n", + "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (2.4.8)\n", + "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (2.0.10)\n", + "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (0.3.4)\n", + "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (0.9.0)\n", + "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (6.4.0)\n", + "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (2.31.0)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (3.1.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (67.7.2)\n", + "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy>=3.2.1->pyserini) (3.3.0)\n", + "Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken>=0.4.0->pyserini) (2023.12.25)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers>=4.6.0->pyserini) (3.13.1)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.6.0->pyserini) (0.20.3)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.6.0->pyserini) (0.15.2)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.6.0->pyserini) (0.4.2)\n", + "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai>=1.0.0->pyserini) (3.6)\n", + "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->openai>=1.0.0->pyserini) (1.2.0)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai>=1.0.0->pyserini) (2024.2.2)\n", + "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->openai>=1.0.0->pyserini) (1.0.4)\n", + "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->openai>=1.0.0->pyserini) (0.14.0)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers>=4.6.0->pyserini) (2023.6.0)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai>=1.0.0->pyserini) (0.6.0)\n", + "Requirement already satisfied: pydantic-core==2.16.3 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->openai>=1.0.0->pyserini) (2.16.3)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas>=1.4.0->pyserini) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=3.2.1->pyserini) (3.3.2)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=3.2.1->pyserini) (2.0.7)\n", + "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy>=3.2.1->pyserini) (0.7.11)\n", + "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.3.0,>=8.2.2->spacy>=3.2.1->pyserini) (0.1.4)\n", + "Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.10.0,>=0.3.0->spacy>=3.2.1->pyserini) (8.1.7)\n", + "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from weasel<0.4.0,>=0.1.0->spacy>=3.2.1->pyserini) (0.16.0)\n", + "Requirement already satisfied: humanfriendly>=9.1 in /usr/local/lib/python3.10/dist-packages (from coloredlogs->onnxruntime>=1.8.1->pyserini) (10.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->spacy>=3.2.1->pyserini) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime>=1.8.1->pyserini) (1.3.0)\n", + "Requirement already satisfied: faiss-cpu in /usr/local/lib/python3.10/dist-packages (1.8.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from faiss-cpu) (1.25.2)\n" + ] + } + ], + "source": [ + "!pip install torch\n", + "!pip install pyserini\n", + "!pip install faiss-cpu" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8ppn-rp_GWVm", + "metadata": { + "id": "8ppn-rp_GWVm" + }, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('drive/MyDrive')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "29563e5d-41b0-4f89-8d8b-a54b40f8dfb7", + "metadata": { + "id": "29563e5d-41b0-4f89-8d8b-a54b40f8dfb7" + }, + "outputs": [], + "source": [ + "from llments.lm.base.hugging_face import HuggingFaceLM\n", + "from llments.lm.rag import RAGLanguageModel\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "device = 'cuda:0' # change to 'mps' if you have a mac, or 'cuda:0' if you have an NVIDIA GPU" + ] + }, + { + "cell_type": "markdown", + "id": "d0022efe", + "metadata": { + "id": "d0022efe" + }, + "source": [ + "## Encode the Documents file provided in jsonl format\n", + "\n", + "The following code generates an encoding for Documents." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "931a71d1", + "metadata": { + "id": "931a71d1" + }, + "outputs": [], + "source": [ + "language_model = None" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5b8e218f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5b8e218f", + "outputId": "17bfc02b-9cc0-4495-e9d5-ee9ad37771c3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating Datastore...\n", + "Initializing the document encoder ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n", + "5it [00:00, 13653.33it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building the index ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1/1 [00:01<00:00, 1.87s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Index creation completed sucessfully!\n", + "Datastore creation completed successfully!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "ragLM = RAGLanguageModel(base=language_model, document_path='collection.jsonl', index_path='drive/MyDrive/llments/examples/rag/msmarco_index', index_encoder='facebook/contriever', fields=['text'], to_faiss=True, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "cf155543", + "metadata": { + "id": "cf155543" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/rag/requirements.txt b/examples/rag/requirements.txt new file mode 100644 index 0000000..1b86476 --- /dev/null +++ b/examples/rag/requirements.txt @@ -0,0 +1,5 @@ +numpy +pandas +torch +pyserini +faiss-cpu \ No newline at end of file diff --git a/llments/datastore/datastore.py b/llments/datastore/datastore.py index aa941ee..88e128b 100644 --- a/llments/datastore/datastore.py +++ b/llments/datastore/datastore.py @@ -1,2 +1,81 @@ +from pyserini.encode import JsonlRepresentationWriter, FaissRepresentationWriter, JsonlCollectionIterator +from pyserini.encode import DprDocumentEncoder, TctColBertDocumentEncoder, AnceDocumentEncoder, AggretrieverDocumentEncoder, AutoDocumentEncoder, CosDprDocumentEncoder +from pyserini.encode import UniCoilDocumentEncoder +from pyserini.encode import OpenAIDocumentEncoder, OPENAI_API_RETRY_DELAY + class Datastore: - ... + def __init__(self, input_jsonl, output_dir, encoder, fields, device): + self.input_jsonl = input_jsonl + self.output_dir = output_dir + self.encoder = encoder + self.device = device + self.fields = fields + + def encode(self, delimiter="\n", docid_field=None, batch_size=64, max_length=256, dimension=768, + prefix=None, pooling='cls', l2_norm=False, to_faiss=False, use_openai=False, rate_limit=3500): + encoder_class_map = { + "dpr": DprDocumentEncoder, + "tct_colbert": TctColBertDocumentEncoder, + "aggretriever": AggretrieverDocumentEncoder, + "ance": AnceDocumentEncoder, + "sentence-transformers": AutoDocumentEncoder, + "unicoil": UniCoilDocumentEncoder, + "openai-api": OpenAIDocumentEncoder, + "cosdpr": CosDprDocumentEncoder, + "auto": AutoDocumentEncoder, + } + + encoder_class = None + + for class_keyword in encoder_class_map: + if class_keyword in self.encoder.lower(): + encoder_class = encoder_class_map[class_keyword] + break + + # if none of the class keyword was matched, use the AutoDocumentEncoder + if encoder_class is None: + encoder_class = AutoDocumentEncoder + + if "sentence-transformers" in self.encoder: + pooling = 'mean' + l2_norm = True + elif "contriever" in self.encoder: + pooling = 'mean' + l2_norm = False + elif "auto" in self.encoder: + pooling = 'cls' + l2_norm = False + + print("Initializing the document encoder ...") + + if encoder_class == AutoDocumentEncoder: + encoder_instance = encoder_class(model_name=self.encoder, device=self.device, pooling=pooling, l2_norm=l2_norm, prefix=prefix) + else: + encoder_instance = encoder_class(model_name = self.encoder, device = self.device) + + if to_faiss: + embedding_writer = FaissRepresentationWriter(self.output_dir, dimension=dimension) + else: + embedding_writer = JsonlRepresentationWriter(self.output_dir) + + collection_iterator = JsonlCollectionIterator(self.input_jsonl, self.fields, docid_field, delimiter) + + if use_openai: + batch_size = int(rate_limit / (60 / OPENAI_API_RETRY_DELAY)) + + print("Building the index ...") + + with embedding_writer: + for batch_info in collection_iterator(batch_size): + texts = batch_info['text'] + titles = batch_info['title'] if 'title' in self.fields else None + expands = batch_info['expand'] if 'expand' in self.fields else None + fp16 = False + max_length = max_length + add_sep = False + + embeddings = encoder_instance.encode(texts=texts, titles=titles, expands=expands, fp16=fp16, max_length=max_length, add_sep=add_sep) + batch_info['vector'] = embeddings + embedding_writer.write(batch_info, self.fields) + + print("\nIndex creation completed sucessfully!") diff --git a/llments/lm/rag.py b/llments/lm/rag.py index bf83a80..8b4cd35 100644 --- a/llments/lm/rag.py +++ b/llments/lm/rag.py @@ -3,13 +3,35 @@ class RAGLanguageModel(LanguageModel): - def __init__(self, base: LanguageModel, datastore: Datastore): + def __init__(self, base: LanguageModel, document_path: str, index_path: str, index_encoder: str, fields: list, + to_faiss: bool, device: str, delimiter="\n", docid_field=None, batch_size=64, max_length=256, + dimension=768, prefix=None, pooling='cls', l2_norm=False, use_openai=False, rate_limit=3500): """Apply retrieval-augmented generation over a datastore. Args: base: The language model to be modified. + document_path: The path to the document file + index_path: The path to store the generated index + index_encoder: The type of document encoder + fields: The document fields to be encoded + to_faiss: Store as a FAISS index + device: The device to be used for encoding + delimiter: Delimiter for document separation + docid_field: Field in the document containing document id + batch_size: Batch size for encoding + max_length: Maximum length of the input sequence + dimension: Dimensionality of the encoding + prefix: Prefix to add to each document + pooling: Pooling strategy for document encoding + l2_norm: Whether to apply L2 normalization Returns: LanguageModel: The enhanced language model. """ - raise NotImplementedError("This is not implemented yet.") + print("Creating Datastore...") + pyserini_encoder = Datastore(document_path, index_path, index_encoder, fields, device) + pyserini_encoder.encode(delimiter=delimiter, docid_field=docid_field, batch_size=batch_size, max_length=max_length, + dimension=dimension, prefix=prefix, pooling=pooling, l2_norm=l2_norm, to_faiss=to_faiss, + use_openai=use_openai, rate_limit=rate_limit) + print("Datastore creation completed successfully!") +