diff --git a/.gitignore b/.gitignore index 3714644e0..11f2837c1 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ __pycache__ env*/ /TODO.md /postgres-data/ +.DS_Store diff --git a/examples/rag.py b/examples/rag.py index 5340a1c59..a7bc7a4dd 100644 --- a/examples/rag.py +++ b/examples/rag.py @@ -3,7 +3,10 @@ Run pgvector with: mkdir postgres-data - docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17 + docker run --rm -e POSTGRES_PASSWORD=postgres \ + -p 54320:5432 \ + -v `pwd`/postgres-data:/var/lib/postgresql/data \ + pgvector/pgvector:pg17 Build the search DB with: @@ -72,7 +75,7 @@ async def retrieve(context: CallContext[Deps], search_query: str) -> str: 'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8', embedding_json, ) - return '\n\n'.join(f'# {row['title']}\nDocumentation URL:{row['url']}\n\n{row['content']}\n' for row in rows) + return '\n\n'.join(f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n' for row in rows) async def run_agent(question: str): @@ -82,7 +85,7 @@ async def run_agent(question: str): logfire.info('Asking "{question}"', question=question) - async with database_connect() as pool: + async with database_connect(False) as pool: deps = Deps(openai=openai, pool=pool) answer = await agent.run(question, deps=deps) print(answer.response) @@ -93,7 +96,10 @@ async def run_agent(question: str): ############################################################################################ # JSON document from https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992 -DOCS_JSON = 'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json' +DOCS_JSON = ( + 'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/' + '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json' +) async def build_search_db(): @@ -111,62 +117,37 @@ async def build_search_db(): async with pool.acquire() as conn: async with conn.transaction(): await conn.execute(DB_SCHEMA) - await insert_doc_sections(openai, pool, sections) - -async def insert_doc_sections(openai: AsyncOpenAI, pool: asyncpg.Pool, sections: list[DocsSection]): - """Insert all docs sections into postgres, OpenAI is used to generate embeddings for each section. - - `asyncio.Queue` is used to perform the embedding creating and insertion concurrently. - """ - queue = asyncio.Queue() - - for section in sections: - queue.put_nowait(section) - - async def worker(): - while True: - s = cast(DocsSection, await queue.get()) - try: - with logfire.span('inserting {queue_size=} {url=}', queue_size=queue.qsize(), url=s.url()): - await insert_doc_section(openai, pool, s) - except Exception: - logfire.exception('Error inserting {url=}', url=s.url()) - raise - finally: - queue.task_done() - - with logfire.span('inserting doc sections'): - tasks = [asyncio.create_task(worker()) for _ in range(10)] - await queue.join() - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - - -async def insert_doc_section(openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection) -> bool: - url = section.url() - exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url) - if exists: - logfire.info('Skipping {url=}', url=url) - return False - - with logfire.span('create embedding for {url=}', url=url): - embedding = await openai.embeddings.create( - input=section.embedding_content(), - model='text-embedding-3-small', + sem = asyncio.Semaphore(10) + async with asyncio.TaskGroup() as tg: + for section in sections: + tg.create_task(insert_doc_section(sem, openai, pool, section)) + + +async def insert_doc_section( + sem: asyncio.Semaphore, openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection +) -> None: + async with sem: + url = section.url() + exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url) + if exists: + logfire.info('Skipping {url=}', url=url) + + with logfire.span('create embedding for {url=}', url=url): + embedding = await openai.embeddings.create( + input=section.embedding_content(), + model='text-embedding-3-small', + ) + assert len(embedding.data) == 1, f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}' + embedding = embedding.data[0].embedding + embedding_json = pydantic_core.to_json(embedding).decode() + await pool.execute( + 'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)', + url, + section.title, + section.content, + embedding_json, ) - assert len(embedding.data) == 1, f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}' - embedding = embedding.data[0].embedding - embedding_json = pydantic_core.to_json(embedding).decode() - await pool.execute( - 'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)', - url, - section.title, - section.content, - embedding_json, - ) - return True @dataclass @@ -180,7 +161,7 @@ class DocsSection: def url(self) -> str: url_path = re.sub(r'\.md$', '', self.path) - return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, '-')}' + return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}' def embedding_content(self) -> str: return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content)) diff --git a/pydantic_ai/models/gemini.py b/pydantic_ai/models/gemini.py index eaf818577..4ec7e8be1 100644 --- a/pydantic_ai/models/gemini.py +++ b/pydantic_ai/models/gemini.py @@ -422,6 +422,7 @@ def simplify(self) -> dict[str, Any]: def _simplify(self, schema: dict[str, Any], allow_ref: bool) -> None: schema.pop('title', None) + schema.pop('default', None) if ref := schema.pop('$ref', None): if not allow_ref: raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 85aaad8cf..2b1c5f610 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -172,7 +172,7 @@ def test_require_response_tool(): def test_json_def_replaced(): class Location(BaseModel): lat: float - lng: float + lng: float = 1.1 class Locations(BaseModel): locations: list[Location] @@ -184,9 +184,9 @@ class Locations(BaseModel): 'Location': { 'properties': { 'lat': {'title': 'Lat', 'type': 'number'}, - 'lng': {'title': 'Lng', 'type': 'number'}, + 'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'}, }, - 'required': ['lat', 'lng'], + 'required': ['lat'], 'title': 'Location', 'type': 'object', } @@ -219,7 +219,7 @@ class Locations(BaseModel): 'lat': {'type': 'number'}, 'lng': {'type': 'number'}, }, - 'required': ['lat', 'lng'], + 'required': ['lat'], 'type': 'object', }, 'type': 'array', @@ -267,8 +267,7 @@ class Locations(BaseModel): 'type': 'object', }, {'type': 'null'}, - ], - 'default': None, + ] } }, 'type': 'object',