Skip to content

Commit

Permalink
Rag example (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 24, 2024
1 parent c417c8d commit 20975a6
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 3 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ __pycache__
/.coverage
env*/
/TODO.md
/postgres-data/
38 changes: 38 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,41 @@ uv run --extra examples -m examples.weather

This example uses `openai:gpt-4o` by default. Gemini seems to be unable to handle the multiple tool
calls.

### `rag.py`

(Demonstrates: retrievers, agent deps, RAG search)

RAG search example. This demo allows you to ask questions of the [logfire](https://pydantic.dev/logfire) documentation.

This is done by creating a database containing each section of the markdown documentation, then registering
the search tool as a retriever with the Pydantic AI agent.

Logic for extracting sections from markdown files and a JSON file with that data is available in
[this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992).

[PostgreSQL with pgvector](https://github.com/pgvector/pgvector) is used as the search database.

The easiest way to download and run pgvector is using Docker:

```bash
mkdir postgres-data
docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17
```

We run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running.
We also mount the postgresql `data` directory locally to persist the data if you need to stop and restart the container.

With that running, we can then build the search database with (**WARNING**: this requires `OPENAI_API_KEY` and will call the OpenAI embedding API around 300 times to generate embeddings for each section of the documentation):

```bash
uv run --extra examples -m examples.rag build
```

(Note building the database doesn't use Pydantic AI right now, instead it uses the OpenAI SDK directly.)

You can then ask the agent a question with:

```bash
uv run --extra examples -m examples.rag search "How do I configure logfire to work with FastAPI?"
```
252 changes: 252 additions & 0 deletions examples/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
"""RAG example with pydantic-ai.
Run pgvector with:
mkdir postgres-data
docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17
Build the search DB with:
uv run --extra examples -m examples.rag build
Ask the agent a question with:
uv run --extra examples -m examples.rag search "How do I configure logfire to work with FastAPI?"
"""

from __future__ import annotations as _annotations

import asyncio
import os
import re
import sys
import unicodedata
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import cast

import asyncpg
import httpx
import logfire
import pydantic_core
from openai import AsyncOpenAI
from pydantic import TypeAdapter
from typing_extensions import AsyncGenerator

from pydantic_ai import CallContext
from pydantic_ai.agent import Agent, KnownModelName

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_asyncpg()


@dataclass
class Deps:
    """Dependencies handed to the agent run and read by its retriever."""

    # OpenAI client; `retrieve` uses it to create the embedding for the search query.
    openai: AsyncOpenAI
    # asyncpg connection pool; `retrieve` uses it to query the `doc_sections` table.
    pool: asyncpg.Pool


# Model name comes from the PYDANTIC_AI_MODEL environment variable, defaulting to 'openai:gpt-4o'.
# `cast` only informs the type checker; the value is not validated here.
model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'openai:gpt-4o'))
# Agent whose dependencies are `Deps`; the `retrieve` function below is registered on it.
agent: Agent[Deps, str] = Agent(model)


@agent.retriever_context
async def retrieve(context: CallContext[Deps], search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        context: The call context.
        search_query: The search query.
    """
    with logfire.span('create embedding for {search_query=}', search_query=search_query):
        response = await context.deps.openai.embeddings.create(
            input=search_query,
            model='text-embedding-3-small',
        )

    assert len(response.data) == 1, f'Expected 1 embedding, got {len(response.data)}, doc query: {search_query!r}'
    # The embedding is passed to postgres as a JSON array string.
    embedding_json = pydantic_core.to_json(response.data[0].embedding).decode()
    # `<->` orders rows by distance to the query embedding; only the 8 closest sections are returned.
    rows = await context.deps.pool.fetch(
        'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8',
        embedding_json,
    )
    # Double quotes inside the replacement fields: reusing the f-string's own quote character
    # there is a SyntaxError before Python 3.12.
    return '\n\n'.join(f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n' for row in rows)


async def run_agent(question: str):
    """Run the agent on `question` (RAG based question answering) and print its response."""
    client = AsyncOpenAI()
    logfire.instrument_openai(client)

    logfire.info('Asking "{question}"', question=question)

    async with database_connect() as pool:
        # The retriever reads both the OpenAI client and the pool from these deps.
        answer = await agent.run(question, deps=Deps(openai=client, pool=pool))
    print(answer.response)


############################################################################################
# The rest of this file is dedicated to preparing the search database, and some utilities. #
############################################################################################

# JSON document from https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992
# NOTE(review): the raw URL embeds a revision hash, so it presumably pins one fixed version of the file.
DOCS_JSON = 'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'


async def build_search_db():
    """Build the search database.

    Downloads the docs sections JSON, creates the database and schema if needed,
    then inserts every section together with its embedding.
    """
    async with httpx.AsyncClient() as http:
        resp = await http.get(DOCS_JSON)
        resp.raise_for_status()
    sections = sessions_ta.validate_json(resp.content)

    client = AsyncOpenAI()
    logfire.instrument_openai(client)

    async with database_connect(True) as pool:
        with logfire.span('create schema'):
            async with pool.acquire() as conn, conn.transaction():
                await conn.execute(DB_SCHEMA)
        await insert_doc_sections(client, pool, sections)


async def insert_doc_sections(openai: AsyncOpenAI, pool: asyncpg.Pool, sections: list[DocsSection]):
    """Insert all docs sections into postgres, OpenAI is used to generate embeddings for each section.

    An `asyncio.Queue` is drained by 10 workers so embedding creation and insertion run concurrently.

    Args:
        openai: Client used to create the embeddings.
        pool: Connection pool used for the inserts.
        sections: The sections to insert.

    Raises:
        Exception: The first error raised by a worker, re-raised once every worker has stopped.
    """
    queue: asyncio.Queue[DocsSection] = asyncio.Queue()
    for section in sections:
        queue.put_nowait(section)

    async def worker():
        # The queue is fully populated before any worker starts, so an empty queue
        # means there is no work left and the worker can return.
        while True:
            try:
                s = queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            try:
                with logfire.span('inserting {queue_size=} {url=}', queue_size=queue.qsize(), url=s.url()):
                    await insert_doc_section(openai, pool, s)
            except Exception:
                logfire.exception('Error inserting {url=}', url=s.url())
                raise

    with logfire.span('inserting doc sections'):
        tasks = [asyncio.create_task(worker()) for _ in range(10)]
        # Wait on the workers themselves rather than `queue.join()`: a worker that raises stops
        # consuming, so if every worker failed `join()` would wait forever on the unprocessed items.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        # Surface failures to the caller instead of discarding them.
        for result in results:
            if isinstance(result, BaseException):
                raise result


async def insert_doc_section(openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection) -> bool:
    """Embed a single docs section and store it, unless a row with its URL already exists.

    Returns:
        `True` if a row was inserted, `False` if the section was skipped.
    """
    section_url = section.url()
    already_stored = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', section_url)
    if already_stored:
        logfire.info('Skipping {url=}', url=section_url)
        return False

    with logfire.span('create embedding for {url=}', url=section_url):
        response = await openai.embeddings.create(
            input=section.embedding_content(),
            model='text-embedding-3-small',
        )
    assert len(response.data) == 1, f'Expected 1 embedding, got {len(response.data)}, doc section: {section}'
    # The embedding is passed to postgres as a JSON array string.
    vector_json = pydantic_core.to_json(response.data[0].embedding).decode()
    await pool.execute(
        'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
        section_url,
        section.title,
        section.content,
        vector_json,
    )
    return True


@dataclass
class DocsSection:
id: int
parent: int | None
path: str
level: int
title: str
content: str

def url(self) -> str:
url_path = re.sub(r'\.md$', '', self.path)
return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, '-')}'

def embedding_content(self) -> str:
return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content))


# Validates the downloaded docs JSON into a `list[DocsSection]` (used by `build_search_db`).
# NOTE(review): the name reads like a typo for `sections_ta`; renaming it requires updating `build_search_db` too.
sessions_ta = TypeAdapter(list[DocsSection])


# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(create_db: bool = False) -> AsyncGenerator[asyncpg.Pool, None]:
    """Yield an asyncpg pool connected to the example database, closing it on exit.

    Args:
        create_db: If true, first connect to the server itself and create the
            database when it does not exist yet.
    """
    server_dsn = 'postgresql://postgres:postgres@localhost:54320'
    database = 'pydantic_ai_rag'

    if create_db:
        with logfire.span('check and create DB'):
            admin_conn = await asyncpg.connect(server_dsn)
            try:
                found = await admin_conn.fetchval('SELECT 1 FROM pg_database WHERE datname = $1', database)
                if not found:
                    await admin_conn.execute(f'CREATE DATABASE {database}')
            finally:
                await admin_conn.close()

    pool = await asyncpg.create_pool(f'{server_dsn}/{database}')
    try:
        yield pool
    finally:
        await pool.close()


# Schema executed by `build_search_db`. Every statement uses IF NOT EXISTS.
# The hnsw index is built with `vector_l2_ops`; `retrieve` orders rows with the `<->` operator.
DB_SCHEMA = """
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS doc_sections (
id serial PRIMARY KEY,
url text NOT NULL UNIQUE,
title text NOT NULL,
content text NOT NULL,
-- text-embedding-3-small returns a vector of 1536 floats
embedding vector(1536) NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_doc_sections_embedding ON doc_sections USING hnsw (embedding vector_l2_ops);
"""


def slugify(value: str, separator: str, unicode: bool = False) -> str:
    """Slugify a string, to make it URL friendly.

    Args:
        value: The text to slugify.
        separator: String that replaces each run of whitespace and existing separator characters.
        unicode: If false (the default), non-ASCII characters are reduced to ASCII or dropped.

    Returns:
        The lower-cased slug.
    """
    # Adapted from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38
    if not unicode:
        # Replace Extended Latin characters with ASCII, i.e. `žlutý` => `zluty`
        value = unicodedata.normalize('NFKD', value)
        value = value.encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value).strip().lower()
    # Escape the separator so characters that are special inside a character class
    # (such as `^` or `]`) are matched literally instead of changing the pattern.
    return re.sub(rf'[{re.escape(separator)}\s]+', separator, value)


if __name__ == '__main__':
    # Command line entry point: `build` populates the search DB, `search [question]` asks the agent.
    action = sys.argv[1] if len(sys.argv) > 1 else None
    if action == 'build':
        asyncio.run(build_search_db())
    elif action == 'search':
        if len(sys.argv) >= 3:
            # Join all remaining arguments so an unquoted multi-word question is used in full
            # rather than being silently replaced by the default question.
            q = ' '.join(sys.argv[2:])
        else:
            q = 'How do I configure logfire to work with FastAPI?'
        asyncio.run(run_agent(q))
    else:
        print('uv run --extra examples -m examples.rag build|search', file=sys.stderr)
        sys.exit(1)
1 change: 1 addition & 0 deletions examples/sql_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure()
logfire.instrument_asyncpg()

DB_SCHEMA = """
CREATE TABLE IF NOT EXISTS records (
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ logfire = [
]
examples = [
"asyncpg>=0.30.0",
"logfire>=1.2.0",
"logfire[asyncpg]>=1.2.0",
]

[tool.uv]
Expand Down
23 changes: 21 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 20975a6

Please sign in to comment.