Gemini drop default from JSON Schema #20

Merged: 1 commit, Oct 24, 2024
.gitignore: 1 addition & 0 deletions

@@ -9,3 +9,4 @@ __pycache__
 env*/
 /TODO.md
 /postgres-data/
+.DS_Store
examples/rag.py: 40 additions & 59 deletions

@@ -3,7 +3,10 @@
 Run pgvector with:
 
     mkdir postgres-data
-    docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17
+    docker run --rm -e POSTGRES_PASSWORD=postgres \
+        -p 54320:5432 \
+        -v `pwd`/postgres-data:/var/lib/postgresql/data \
+        pgvector/pgvector:pg17
 
 Build the search DB with:

@@ -72,7 +75,7 @@ async def retrieve(context: CallContext[Deps], search_query: str) -> str:
         'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8',
         embedding_json,
     )
-    return '\n\n'.join(f'# {row['title']}\nDocumentation URL:{row['url']}\n\n{row['content']}\n' for row in rows)
+    return '\n\n'.join(f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n' for row in rows)


async def run_agent(question: str):
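Context for the quote swap above (and the matching `slugify` fix further down in this file): before Python 3.12, an f-string expression cannot reuse the quote character that delimits the string itself, so the nested single quotes were a syntax error on older interpreters. A minimal illustration, assuming nothing beyond the stdlib:

row = {'title': 'Gemini'}
# print(f'# {row['title']}')  # SyntaxError on Python < 3.12
print(f'# {row["title"]}')  # works on all supported versions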
@@ -82,7 +85,7 @@ async def run_agent(question: str):

     logfire.info('Asking "{question}"', question=question)
 
-    async with database_connect() as pool:
+    async with database_connect(False) as pool:
         deps = Deps(openai=openai, pool=pool)
         answer = await agent.run(question, deps=deps)
         print(answer.response)
@@ -93,7 +96,10 @@
 ############################################################################################
 
 # JSON document from https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992
-DOCS_JSON = 'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
+DOCS_JSON = (
+    'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/'
+    '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
+)
 
 
 async def build_search_db():
@@ -111,62 +117,37 @@ async def build_search_db():
             async with pool.acquire() as conn:
                 async with conn.transaction():
                     await conn.execute(DB_SCHEMA)
-        await insert_doc_sections(openai, pool, sections)
-
-
-async def insert_doc_sections(openai: AsyncOpenAI, pool: asyncpg.Pool, sections: list[DocsSection]):
-    """Insert all docs sections into postgres, OpenAI is used to generate embeddings for each section.
-
-    `asyncio.Queue` is used to perform the embedding creating and insertion concurrently.
-    """
-    queue = asyncio.Queue()
-
-    for section in sections:
-        queue.put_nowait(section)
-
-    async def worker():
-        while True:
-            s = cast(DocsSection, await queue.get())
-            try:
-                with logfire.span('inserting {queue_size=} {url=}', queue_size=queue.qsize(), url=s.url()):
-                    await insert_doc_section(openai, pool, s)
-            except Exception:
-                logfire.exception('Error inserting {url=}', url=s.url())
-                raise
-            finally:
-                queue.task_done()
-
-    with logfire.span('inserting doc sections'):
-        tasks = [asyncio.create_task(worker()) for _ in range(10)]
-        await queue.join()
-        for task in tasks:
-            task.cancel()
-        await asyncio.gather(*tasks, return_exceptions=True)
-
-
-async def insert_doc_section(openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection) -> bool:
-    url = section.url()
-    exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url)
-    if exists:
-        logfire.info('Skipping {url=}', url=url)
-        return False
-
-    with logfire.span('create embedding for {url=}', url=url):
-        embedding = await openai.embeddings.create(
-            input=section.embedding_content(),
-            model='text-embedding-3-small',
-        )
-    assert len(embedding.data) == 1, f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
-    embedding = embedding.data[0].embedding
-    embedding_json = pydantic_core.to_json(embedding).decode()
-    await pool.execute(
-        'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
-        url,
-        section.title,
-        section.content,
-        embedding_json,
-    )
-    return True
+        sem = asyncio.Semaphore(10)
+        async with asyncio.TaskGroup() as tg:
+            for section in sections:
+                tg.create_task(insert_doc_section(sem, openai, pool, section))
+
+
+async def insert_doc_section(
+    sem: asyncio.Semaphore, openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection
+) -> None:
+    async with sem:
+        url = section.url()
+        exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url)
+        if exists:
+            logfire.info('Skipping {url=}', url=url)
+            return
+
+        with logfire.span('create embedding for {url=}', url=url):
+            embedding = await openai.embeddings.create(
+                input=section.embedding_content(),
+                model='text-embedding-3-small',
+            )
+        assert len(embedding.data) == 1, f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
+        embedding = embedding.data[0].embedding
+        embedding_json = pydantic_core.to_json(embedding).decode()
+        await pool.execute(
+            'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
+            url,
+            section.title,
+            section.content,
+            embedding_json,
+        )


@dataclass
@@ -180,7 +161,7 @@ class DocsSection:

     def url(self) -> str:
         url_path = re.sub(r'\.md$', '', self.path)
-        return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, '-')}'
+        return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}'
 
     def embedding_content(self) -> str:
         return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content))
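For reviewers unfamiliar with the new concurrency shape: the queue-and-worker pool is replaced by one task per section, throttled by a semaphore. A self-contained sketch of the same idiom under the same limit of 10, with a hypothetical `process` coroutine standing in for the real embedding-and-insert work:

import asyncio


async def process(sem: asyncio.Semaphore, item: int) -> None:
    async with sem:  # at most 10 bodies run at once
        await asyncio.sleep(0.1)  # placeholder for the embedding + INSERT
        print(f'processed {item}')


async def main() -> None:
    sem = asyncio.Semaphore(10)
    # TaskGroup (Python 3.11+) awaits every task and, on failure, cancels
    # the rest, so no manual queue.join()/task.cancel() bookkeeping is needed.
    async with asyncio.TaskGroup() as tg:
        for item in range(100):
            tg.create_task(process(sem, item))


asyncio.run(main())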
pydantic_ai/models/gemini.py: 1 addition & 0 deletions

@@ -422,6 +422,7 @@ def simplify(self) -> dict[str, Any]:

     def _simplify(self, schema: dict[str, Any], allow_ref: bool) -> None:
         schema.pop('title', None)
+        schema.pop('default', None)
         if ref := schema.pop('$ref', None):
             if not allow_ref:
                 raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
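The effect of the added `schema.pop('default', None)`, sketched in isolation (a standalone stand-in for `_simplify`, ignoring its `$ref` handling), using the `Location` schema from the tests below:

from typing import Any


def simplify(schema: dict[str, Any]) -> None:
    # Mirror the two pops: the Gemini function-calling schema dialect
    # accepts neither 'title' nor (after this PR) 'default'.
    schema.pop('title', None)
    schema.pop('default', None)
    for value in schema.values():
        if isinstance(value, dict):
            simplify(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    simplify(item)


schema = {
    'properties': {
        'lat': {'title': 'Lat', 'type': 'number'},
        'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'},
    },
    'required': ['lat'],
    'title': 'Location',
    'type': 'object',
}
simplify(schema)
print(schema)
# {'properties': {'lat': {'type': 'number'}, 'lng': {'type': 'number'}}, 'required': ['lat'], 'type': 'object'}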
tests/models/test_gemini.py: 5 additions & 6 deletions

@@ -172,7 +172,7 @@ def test_require_response_tool():
 def test_json_def_replaced():
     class Location(BaseModel):
         lat: float
-        lng: float
+        lng: float = 1.1
 
     class Locations(BaseModel):
         locations: list[Location]
@@ -184,9 +184,9 @@ class Locations(BaseModel):
             'Location': {
                 'properties': {
                     'lat': {'title': 'Lat', 'type': 'number'},
-                    'lng': {'title': 'Lng', 'type': 'number'},
+                    'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'},
                 },
-                'required': ['lat', 'lng'],
+                'required': ['lat'],
                 'title': 'Location',
                 'type': 'object',
             }
@@ -219,7 +219,7 @@ class Locations(BaseModel):
                     'lat': {'type': 'number'},
                     'lng': {'type': 'number'},
                 },
-                'required': ['lat', 'lng'],
+                'required': ['lat'],
                 'type': 'object',
             },
             'type': 'array',
@@ -267,8 +267,7 @@ class Locations(BaseModel):
                 'type': 'object',
             },
             {'type': 'null'},
-        ],
-        'default': None,
+        ]
     }
 },
 'type': 'object',
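Why the test fixtures changed shape: giving `lng` a default makes pydantic drop it from `required` and emit a `default` key in the generated JSON Schema, which is exactly the key the simplifier must now scrub before the schema reaches Gemini. A quick check with the pydantic v2 API:

from pydantic import BaseModel


class Location(BaseModel):
    lat: float
    lng: float = 1.1


print(Location.model_json_schema())
# {'properties': {'lat': {'title': 'Lat', 'type': 'number'},
#                 'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'}},
#  'required': ['lat'], 'title': 'Location', 'type': 'object'}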