Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple prepare from chat #263

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions examples/corpus.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from dotenv import load_dotenv\n",
"\n",
"load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from ragna import Corpus, Rag, assistants, source_storages\n",
"\n",
"rag = Rag()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"document_path = \"demo_document.txt\"\n",
"\n",
"with open(document_path, \"w\") as file:\n",
" file.write(\"This is a test document\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"corpus = Corpus(name=\"demo_corpus\", documents=[document_path], prepared=False)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/nenb/nicks_projects/ragna/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"chat = rag.chat(\n",
" corpus=corpus,\n",
" source_storage=source_storages.LanceDB,\n",
" assistant=assistants.Gpt35Turbo16k,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Message(content='How can I help you with the documents?', role=<MessageRole.SYSTEM: 'system'>, sources=[])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.prepare()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"This document is a test document.\n"
]
}
],
"source": [
"print(await chat.answer(\"What is this document?\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ragna-dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion ragna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@

from . import assistants, core, deploy, source_storages
from ._utils import local_root
from .core import Rag
from .core import Corpus, Rag
9 changes: 3 additions & 6 deletions ragna/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"Assistant",
"Chat",
"Component",
"Corpus",
"Document",
"DocumentHandler",
"DocumentUploadParameters",
Expand All @@ -20,16 +21,12 @@
"TxtDocumentHandler",
]

from ._utils import (
EnvVarRequirement,
PackageRequirement,
RagnaException,
Requirement,
)
from ._utils import EnvVarRequirement, PackageRequirement, RagnaException, Requirement

# isort: split

from ._document import (
Corpus,
Document,
DocumentHandler,
DocumentUploadParameters,
Expand Down
10 changes: 5 additions & 5 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pydantic
import pydantic.utils

from ._document import Document
from ._document import Corpus, Document
from ._utils import RequirementsMixin, merge_models


Expand Down Expand Up @@ -100,20 +100,20 @@ class SourceStorage(Component, abc.ABC):
__ragna_protocol_methods__ = ["store", "retrieve"]

@abc.abstractmethod
def store(self, documents: list[Document]) -> None:
def store(self, corpus: Corpus) -> None:
"""Store content of documents.

Args:
documents: Documents to store.
corpus: Corpus to store.
"""
...

@abc.abstractmethod
def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
def retrieve(self, corpus: Corpus, prompt: str) -> list[Source]:
"""Retrieve sources for a given prompt.

Args:
documents: Documents to retrieve sources from.
corpus: Corpus to retrieve sources from.
prompt: Prompt to retrieve sources for.

Returns:
Expand Down
43 changes: 42 additions & 1 deletion ragna/core/_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Iterator,
Optional,
Type,
TypeVar,
Union,
)

import jwt
from pydantic import BaseModel
Expand Down Expand Up @@ -267,3 +276,35 @@ def extract_pages(self, document: Document) -> Iterator[Page]:
) as document:
for number, page in enumerate(document, 1):
yield Page(text=page.get_text(sort=True), number=number)


class Corpus:
    """Named collection of documents.

    Args:
        documents: Items that make up the corpus. They are normalized with
            `parse_documents` when the corpus is created.
        name: Name of the corpus.
        prepared: Whether the corpus is already prepared, i.e. stored in a
            vector database.

    Attributes:
        documents: Documents in the corpus.
        name: Name of the corpus.
        prepared: Whether the corpus is prepared, i.e. stored in a vector
            database.
    """

    def __init__(self, documents: Iterable[Any], *, name: str, prepared: bool):
        # Parsing happens eagerly, so anything parse_documents raises surfaces
        # here at construction time rather than later during a chat.
        self.documents = parse_documents(documents)
        self.name = name
        self.prepared = prepared

    def __repr__(self) -> str:
        # A readable repr matters because corpora end up in notebook output
        # and in the context attached to exceptions.
        return (
            f"{type(self).__name__}(name={self.name!r}, "
            f"documents={self.documents!r}, prepared={self.prepared!r})"
        )


def parse_documents(documents: Iterable[Any]) -> list[Document]:
    """Normalize arbitrary items into a list of readable documents.

    Args:
        documents: Items to normalize. Any item that is not already a
            `Document` is passed to `LocalDocument.from_path`.

    Returns:
        The resulting documents, in input order.

    Raises:
        RagnaException: If a document reports that it is not readable.
    """
    parsed = []
    for item in documents:
        candidate = (
            item if isinstance(item, Document) else LocalDocument.from_path(item)
        )
        if candidate.is_readable():
            parsed.append(candidate)
            continue

        # Fail on the first unreadable document instead of skipping it.
        raise RagnaException(
            "Document not readable",
            document=candidate,
            http_status_code=404,
        )
    return parsed
43 changes: 12 additions & 31 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Awaitable,
Callable,
Generic,
Iterable,
Optional,
Type,
TypeVar,
Expand All @@ -21,7 +20,7 @@
import pydantic

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
from ._document import Corpus
from ._utils import RagnaException, default_user, merge_models

T = TypeVar("T")
Expand Down Expand Up @@ -62,23 +61,22 @@ def _load_component(self, component: Union[Type[C], C]) -> C:
def chat(
self,
*,
documents: Iterable[Any],
corpus: Corpus,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].

Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
corpus: Corpus to use.
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
corpus=corpus,
source_storage=source_storage,
assistant=assistant,
**params,
Expand Down Expand Up @@ -119,8 +117,7 @@ class Chat:

Args:
rag: The RAG workflow this chat is associated with.
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
corpus: Corpus to use.
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
Expand All @@ -130,14 +127,14 @@ def __init__(
self,
rag: Rag,
*,
documents: Iterable[Any],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to be able to also pass in regular paths here. While Corpus is a useful abstraction, it is unnecessary for the base case. I would move the _parse_documents function that you moved to Corpus back here and let it return a Corpus in case we encounter a list of paths.

corpus: Corpus,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.corpus = corpus
self.source_storage = self._rag._load_component(source_storage)
self.assistant = self._rag._load_component(assistant)

Expand All @@ -163,16 +160,16 @@ async def prepare(self) -> Message:
ragna.core.RagnaException: If chat is already
[`prepare`][ragna.core.Chat.prepare]d.
"""
if self._prepared:
if self.corpus.prepared:
raise RagnaException(
"Chat is already prepared",
chat=self,
http_status_code=400,
detail=RagnaException.EVENT,
)

await self._run(self.source_storage.store, self.documents)
self._prepared = True
await self._run(self.source_storage.store, self.corpus)
self.corpus.prepared = True

welcome = Message(
content="How can I help you with the documents?",
Expand All @@ -191,7 +188,7 @@ async def answer(self, prompt: str) -> Message:
ragna.core.RagnaException: If chat is not
[`prepare`][ragna.core.Chat.prepare]d.
"""
if not self._prepared:
if not self.corpus.prepared:
raise RagnaException(
"Chat is not prepared",
chat=self,
Expand All @@ -203,7 +200,7 @@ async def answer(self, prompt: str) -> Message:
self._messages.append(prompt)

sources = await self._run(
self.source_storage.retrieve, self.documents, prompt.content
self.source_storage.retrieve, self.corpus, prompt.content
)
answer = Message(
content=await self._run(self.assistant.answer, prompt.content, sources),
Expand All @@ -221,22 +218,6 @@ async def answer(self, prompt: str) -> Message:

return answer

def _parse_documents(self, documents: Iterable[Any]) -> list[Document]:
documents_ = []
for document in documents:
if not isinstance(document, Document):
document = LocalDocument.from_path(document)

if not document.is_readable():
raise RagnaException(
"Document not readable",
document=document,
http_status_code=404,
)

documents_.append(document)
return documents_

def _unpack_chat_params(
self, params: dict[str, Any]
) -> dict[Callable, dict[str, Any]]:
Expand Down
Loading
Loading