Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add corpus label for identifying groups of documents #269

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,24 @@ class SourceStorage(Component, abc.ABC):
__ragna_protocol_methods__ = ["store", "retrieve"]

@abc.abstractmethod
def store(self, documents: list[Document]) -> None:
def store(self, documents: list[Document], corpus_id: str) -> None:
"""Store content of documents.

Args:
documents: Documents to store.
corpus_id: A unique identifier for the corpus the documents belong to.
"""
...

@abc.abstractmethod
def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
def retrieve(
self, documents: list[Document], corpus_id: str, prompt: str
) -> list[Source]:
"""Retrieve sources for a given prompt.

Args:
documents: Documents to retrieve sources from.
corpus_id: A unique identifier for the corpus the documents belong to.
prompt: Prompt to retrieve sources for.

Returns:
Expand Down
40 changes: 12 additions & 28 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def chat(
self,
*,
documents: Iterable[Any],
corpus_id: str,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
**params: Any,
Expand All @@ -72,13 +73,15 @@ def chat(
Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
corpus_id: A unique identifier for the corpus the documents belong to.
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
corpus_id=corpus_id,
source_storage=source_storage,
assistant=assistant,
**params,
Expand Down Expand Up @@ -111,6 +114,7 @@ class Chat:

async with rag.chat(
documents=[path],
corpus_id="fake_corpus",
source_storage=ragna.core.RagnaDemoSourceStorage,
assistant=ragna.core.RagnaDemoAssistant,
) as chat:
Expand All @@ -121,6 +125,7 @@ class Chat:
rag: The RAG workflow this chat is associated with.
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
corpus_id: A unique identifier for the corpus the documents belong to.
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
Expand All @@ -131,13 +136,15 @@ def __init__(
rag: Rag,
*,
documents: Iterable[Any],
corpus_id: str,
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.corpus_id = corpus_id
self.source_storage = self._rag._load_component(source_storage)
self.assistant = self._rag._load_component(assistant)

Expand All @@ -147,7 +154,6 @@ def __init__(
self.params = params
self._unpacked_params = self._unpack_chat_params(params)

self._prepared = False
self._messages: list[Message] = []

async def prepare(self) -> Message:
Expand All @@ -158,21 +164,8 @@ async def prepare(self) -> Message:

Returns:
Welcome message.

Raises:
ragna.core.RagnaException: If chat is already
[`prepare`][ragna.core.Chat.prepare]d.
"""
if self._prepared:
raise RagnaException(
"Chat is already prepared",
chat=self,
http_status_code=400,
detail=RagnaException.EVENT,
)

await self._run(self.source_storage.store, self.documents)
self._prepared = True
await self._run(self.source_storage.store, self.documents, self.corpus_id)

welcome = Message(
content="How can I help you with the documents?",
Expand All @@ -186,24 +179,15 @@ async def answer(self, prompt: str) -> Message:

Returns:
Answer.

Raises:
ragna.core.RagnaException: If chat is not
[`prepare`][ragna.core.Chat.prepare]d.
"""
if not self._prepared:
raise RagnaException(
"Chat is not prepared",
chat=self,
http_status_code=400,
detail=RagnaException.EVENT,
)

prompt = Message(content=prompt, role=MessageRole.USER)
self._messages.append(prompt)

sources = await self._run(
self.source_storage.retrieve, self.documents, prompt.content
self.source_storage.retrieve,
self.documents,
self.corpus_id,
prompt.content,
)
answer = Message(
content=await self._run(self.assistant.answer, prompt.content, sources),
Expand Down
24 changes: 16 additions & 8 deletions ragna/deploy/_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def upload_document(
return document

def schema_to_core_chat(
session: database.Session, *, user: str, chat: schemas.Chat
session: database.Session, *, user: str, chat: schemas.Chat, corpus_id: str
) -> ragna.core.Chat:
core_chat = rag.chat(
documents=[
Expand All @@ -166,6 +166,7 @@ def schema_to_core_chat(
)
for document in chat.metadata.documents
],
corpus_id=corpus_id,
source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type]
assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type]
user=user,
Expand All @@ -178,7 +179,6 @@ def schema_to_core_chat(
# if we implement a chat history feature, i.e. passing past messages to
# the assistant, this becomes crucial.
core_chat._messages = []
core_chat._prepared = chat.prepared

return core_chat

Expand All @@ -192,7 +192,7 @@ async def create_chat(

# Although we don't need the actual ragna.core.Chat object here,
# we use it to validate the documents and metadata.
schema_to_core_chat(session, user=user, chat=chat)
schema_to_core_chat(session, user=user, chat=chat, corpus_id="fake-corpus")

database.add_chat(session, user=user, chat=chat)
return chat
Expand All @@ -208,31 +208,39 @@ async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat:
return database.get_chat(session, user=user, id=id)

@app.post("/chats/{id}/prepare")
async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:
async def prepare_chat(
user: UserDependency, id: uuid.UUID, corpus_id: str
) -> schemas.Message:
with get_session() as session:
chat = database.get_chat(session, user=user, id=id)

core_chat = schema_to_core_chat(session, user=user, chat=chat)
core_chat = schema_to_core_chat(
session, user=user, chat=chat, corpus_id=corpus_id
)

welcome = schemas.Message.from_core(await core_chat.prepare())

chat.prepared = True
chat.messages.append(welcome)
database.update_chat(session, user=user, chat=chat)

return welcome

@app.post("/chats/{id}/answer")
async def answer(
user: UserDependency, id: uuid.UUID, prompt: str
user: UserDependency,
id: uuid.UUID,
prompt: str,
corpus_id: str,
) -> schemas.Message:
with get_session() as session:
chat = database.get_chat(session, user=user, id=id)
chat.messages.append(
schemas.Message(content=prompt, role=ragna.core.MessageRole.USER)
)

core_chat = schema_to_core_chat(session, user=user, chat=chat)
core_chat = schema_to_core_chat(
session, user=user, chat=chat, corpus_id=corpus_id
)

answer = schemas.Message.from_core(await core_chat.answer(prompt))

Expand Down
3 changes: 0 additions & 3 deletions ragna/deploy/_api/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None:
source_storage=chat.metadata.source_storage,
assistant=chat.metadata.assistant,
params=chat.metadata.params,
prepared=chat.prepared,
)
)
session.commit()
Expand Down Expand Up @@ -130,7 +129,6 @@ def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat:
params=chat.params, # type: ignore[arg-type]
),
messages=messages,
prepared=chat.prepared,
)


Expand Down Expand Up @@ -206,7 +204,6 @@ def _schema_to_orm_message(
def update_chat(session: Session, user: str, chat: schemas.Chat) -> None:
orm_chat = _get_orm_chat(session, user=user, id=chat.id)

orm_chat.prepared = chat.prepared
orm_chat.messages = [
_schema_to_orm_message(session, chat_id=chat.id, message=message)
for message in chat.messages
Expand Down
1 change: 0 additions & 1 deletion ragna/deploy/_api/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class Chat(Base):
assistant = Column(types.String)
params = Column(types.JSON)
messages = relationship("Message", cascade="all, delete")
prepared = Column(types.Boolean)


source_message_association_table = Table(
Expand Down
1 change: 0 additions & 1 deletion ragna/deploy/_api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,3 @@ class Chat(BaseModel):
id: uuid.UUID = Field(default_factory=uuid.uuid4)
metadata: ChatMetadata
messages: list[Message] = Field(default_factory=list)
prepared: bool = False
14 changes: 12 additions & 2 deletions ragna/deploy/_ui/api_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,15 @@ async def get_chats(self):
return json_data

async def answer(self, chat_id, prompt):
# FIXME: Add UI support for providing an actual corpus ID; the chat ID is sent as a placeholder for now.
return self.improve_message(
(
await self.client.post(
f"/chats/{chat_id}/answer",
params={"prompt": prompt},
params={
"prompt": prompt,
"corpus_id": chat_id,
},
timeout=None,
)
)
Expand Down Expand Up @@ -98,7 +102,13 @@ async def start_and_prepare(
)
chat = response.raise_for_status().json()

response = await self.client.post(f"/chats/{chat['id']}/prepare", timeout=None)
# FIXME: Add UI support for providing an actual corpus ID; the chat ID is sent as a placeholder for now.
response = await self.client.post(
f"/chats/{chat['id']}/prepare",
params={"corpus_id": chat["id"]},
timeout=None,
)

response.raise_for_status()

return chat["id"]
Expand Down
11 changes: 5 additions & 6 deletions ragna/source_storages/_chroma.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import uuid

import ragna
from ragna.core import (
Document,
Source,
)
from ragna.core import Document, Source

from ._vector_database import VectorDatabaseSourceStorage

Expand Down Expand Up @@ -36,13 +33,14 @@ def __init__(self) -> None:
def store(
self,
documents: list[Document],
corpus_id: str,
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
collection = self._client.create_collection(
str(chat_id), embedding_function=self._embedding_function
corpus_id, embedding_function=self._embedding_function
)

ids = []
Expand Down Expand Up @@ -73,14 +71,15 @@ def store(
def retrieve(
self,
documents: list[Document],
corpus_id: str,
prompt: str,
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
collection = self._client.get_collection(
str(chat_id), embedding_function=self._embedding_function
corpus_id, embedding_function=self._embedding_function
)

result = collection.query(
Expand Down
17 changes: 12 additions & 5 deletions ragna/source_storages/_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ def display_name(cls) -> str:
return "Ragna/DemoSourceStorage"

def __init__(self) -> None:
self._storage: dict[uuid.UUID, list[Source]] = {}
self._storage: dict[str, list[Source]] = {}

def store(self, documents: list[Document], *, chat_id: uuid.UUID) -> None:
self._storage[chat_id] = [
def store(
self, documents: list[Document], corpus_id: str, *, chat_id: uuid.UUID
) -> None:
self._storage[corpus_id] = [
Source(
id=str(uuid.uuid4()),
document=document,
Expand All @@ -37,6 +39,11 @@ def store(self, documents: list[Document], *, chat_id: uuid.UUID) -> None:
]

def retrieve(
self, documents: list[Document], prompt: str, *, chat_id: uuid.UUID
self,
documents: list[Document],
corpus_id: str,
prompt: str,
*,
chat_id: uuid.UUID,
) -> list[Source]:
return self._storage[chat_id]
return self._storage[corpus_id]
6 changes: 4 additions & 2 deletions ragna/source_storages/_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ def __init__(self) -> None:
def store(
self,
documents: list[Document],
corpus_id: str,
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
chunk_overlap: int = 250,
) -> None:
table = self._db.create_table(name=str(chat_id), schema=self._schema)
table = self._db.create_table(name=corpus_id, schema=self._schema)

for document in documents:
for chunk in self._chunk_pages(
Expand Down Expand Up @@ -87,13 +88,14 @@ def store(
def retrieve(
self,
documents: list[Document],
corpus_id: str,
prompt: str,
*,
chat_id: uuid.UUID,
chunk_size: int = 500,
num_tokens: int = 1024,
) -> list[Source]:
table = self._db.open_table(str(chat_id))
table = self._db.open_table(corpus_id)

# We cannot retrieve source by a maximum number of tokens. Thus, we estimate how
# many sources we have to query. We overestimate by a factor of two to avoid
Expand Down
Loading
Loading