Skip to content

Commit

Permalink
Gemini drop "default" from JSON Schema
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 24, 2024
1 parent 20975a6 commit 39d124a
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ __pycache__
env*/
/TODO.md
/postgres-data/
.DS_Store
99 changes: 40 additions & 59 deletions examples/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
Run pgvector with:
mkdir postgres-data
docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 -v `pwd`/postgres-data:/var/lib/postgresql/data pgvector/pgvector:pg17
docker run --rm -e POSTGRES_PASSWORD=postgres \
-p 54320:5432 \
-v `pwd`/postgres-data:/var/lib/postgresql/data \
pgvector/pgvector:pg17
Build the search DB with:
Expand Down Expand Up @@ -72,7 +75,7 @@ async def retrieve(context: CallContext[Deps], search_query: str) -> str:
'SELECT url, title, content FROM doc_sections ORDER BY embedding <-> $1 LIMIT 8',
embedding_json,
)
return '\n\n'.join(f'# {row['title']}\nDocumentation URL:{row['url']}\n\n{row['content']}\n' for row in rows)
return '\n\n'.join(f'# {row["title"]}\nDocumentation URL:{row["url"]}\n\n{row["content"]}\n' for row in rows)


async def run_agent(question: str):
Expand All @@ -82,7 +85,7 @@ async def run_agent(question: str):

logfire.info('Asking "{question}"', question=question)

async with database_connect() as pool:
async with database_connect(False) as pool:
deps = Deps(openai=openai, pool=pool)
answer = await agent.run(question, deps=deps)
print(answer.response)
Expand All @@ -93,7 +96,10 @@ async def run_agent(question: str):
############################################################################################

# JSON document from https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992
DOCS_JSON = 'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
DOCS_JSON = (
'https://gist.githubusercontent.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/'
'80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
)


async def build_search_db():
Expand All @@ -111,62 +117,37 @@ async def build_search_db():
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(DB_SCHEMA)
await insert_doc_sections(openai, pool, sections)


async def insert_doc_sections(openai: AsyncOpenAI, pool: asyncpg.Pool, sections: list[DocsSection]):
    """Insert all docs sections into postgres, OpenAI is used to generate embeddings for each section.

    `asyncio.Queue` is used to perform the embedding creation and insertion concurrently:
    every section is queued up front and 10 worker tasks drain the queue.

    A section that fails to insert is logged and skipped; the remaining sections are still processed.

    Args:
        openai: Client passed through to `insert_doc_section` for each section.
        pool: Connection pool passed through to `insert_doc_section` for each section.
        sections: The docs sections to insert.
    """
    queue = asyncio.Queue()

    for section in sections:
        queue.put_nowait(section)

    async def worker():
        while True:
            s = cast(DocsSection, await queue.get())
            try:
                with logfire.span('inserting {queue_size=} {url=}', queue_size=queue.qsize(), url=s.url()):
                    await insert_doc_section(openai, pool, s)
            except Exception:
                # Log and keep consuming. Re-raising here would end this worker task; once every
                # worker had failed, nothing would drain the queue and `queue.join()` below would
                # wait forever.
                logfire.exception('Error inserting {url=}', url=s.url())
            finally:
                # Always mark the item done so `queue.join()` can complete.
                queue.task_done()

    with logfire.span('inserting doc sections'):
        tasks = [asyncio.create_task(worker()) for _ in range(10)]
        await queue.join()
        # The workers loop forever, so they have to be cancelled once the queue is drained.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def insert_doc_section(openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection) -> bool:
url = section.url()
exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url)
if exists:
logfire.info('Skipping {url=}', url=url)
return False

with logfire.span('create embedding for {url=}', url=url):
embedding = await openai.embeddings.create(
input=section.embedding_content(),
model='text-embedding-3-small',
sem = asyncio.Semaphore(10)
async with asyncio.TaskGroup() as tg:
for section in sections:
tg.create_task(insert_doc_section(sem, openai, pool, section))


async def insert_doc_section(
    sem: asyncio.Semaphore, openai: AsyncOpenAI, pool: asyncpg.Pool, section: DocsSection
) -> None:
    """Embed one docs section and insert it into `doc_sections`, unless its URL is already stored.

    Args:
        sem: Semaphore held for the whole body of this call.
        openai: Client used to create the embedding for the section.
        pool: Connection pool used for the existence check and the insert.
        section: The docs section to embed and store.
    """
    async with sem:
        url = section.url()
        exists = await pool.fetchval('SELECT 1 FROM doc_sections WHERE url = $1', url)
        if exists:
            logfire.info('Skipping {url=}', url=url)
            # Stop here: falling through would request another embedding and attempt a second
            # insert for a URL that is already in the table.
            return

        with logfire.span('create embedding for {url=}', url=url):
            embedding = await openai.embeddings.create(
                input=section.embedding_content(),
                model='text-embedding-3-small',
            )
        assert len(embedding.data) == 1, f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
        embedding = embedding.data[0].embedding
        # The embedding is passed to postgres as a JSON string.
        embedding_json = pydantic_core.to_json(embedding).decode()
        await pool.execute(
            'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
            url,
            section.title,
            section.content,
            embedding_json,
        )
assert len(embedding.data) == 1, f'Expected 1 embedding, got {len(embedding.data)}, doc section: {section}'
embedding = embedding.data[0].embedding
embedding_json = pydantic_core.to_json(embedding).decode()
await pool.execute(
'INSERT INTO doc_sections (url, title, content, embedding) VALUES ($1, $2, $3, $4)',
url,
section.title,
section.content,
embedding_json,
)
return True


@dataclass
Expand All @@ -180,7 +161,7 @@ class DocsSection:

def url(self) -> str:
url_path = re.sub(r'\.md$', '', self.path)
return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, '-')}'
return f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}'

def embedding_content(self) -> str:
    """Return the text that gets embedded: a path line, a title line and the content, separated by blank lines."""
    parts = [f'path: {self.path}', f'title: {self.title}', self.content]
    return '\n\n'.join(parts)
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ def simplify(self) -> dict[str, Any]:

def _simplify(self, schema: dict[str, Any], allow_ref: bool) -> None:
schema.pop('title', None)
schema.pop('default', None)
if ref := schema.pop('$ref', None):
if not allow_ref:
raise shared.UserError('Recursive `$ref`s in JSON Schema are not supported by Gemini')
Expand Down
11 changes: 5 additions & 6 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_require_response_tool():
def test_json_def_replaced():
class Location(BaseModel):
lat: float
lng: float
lng: float = 1.1

class Locations(BaseModel):
locations: list[Location]
Expand All @@ -184,9 +184,9 @@ class Locations(BaseModel):
'Location': {
'properties': {
'lat': {'title': 'Lat', 'type': 'number'},
'lng': {'title': 'Lng', 'type': 'number'},
'lng': {'default': 1.1, 'title': 'Lng', 'type': 'number'},
},
'required': ['lat', 'lng'],
'required': ['lat'],
'title': 'Location',
'type': 'object',
}
Expand Down Expand Up @@ -219,7 +219,7 @@ class Locations(BaseModel):
'lat': {'type': 'number'},
'lng': {'type': 'number'},
},
'required': ['lat', 'lng'],
'required': ['lat'],
'type': 'object',
},
'type': 'array',
Expand Down Expand Up @@ -267,8 +267,7 @@ class Locations(BaseModel):
'type': 'object',
},
{'type': 'null'},
],
'default': None,
]
}
},
'type': 'object',
Expand Down

0 comments on commit 39d124a

Please sign in to comment.