more examples, switch ToolRetry to RetryPrompt
samuelcolvin committed Oct 22, 2024
1 parent 51373c4 commit 4bd30db
Showing 18 changed files with 316 additions and 140 deletions.
20 changes: 16 additions & 4 deletions examples/README.md
@@ -14,33 +14,45 @@ uv run -m examples.<example_module_name>

### `pydantic_model.py`

(Demonstrates: custom `result_type`)

Simple example of using Pydantic AI to construct a Pydantic model from a text input.

```bash
uv run -m examples.pydantic_model
uv run --extra examples -m examples.pydantic_model
```

This example uses `openai:gpt-4o` by default, but it works well with other models, e.g. you can run it
with Gemini using:

```bash
PYDANTIC_AI_MODEL=gemini-1.5-pro uv run -m examples.pydantic_model
PYDANTIC_AI_MODEL=gemini-1.5-pro uv run --extra examples -m examples.pydantic_model
```

(or `PYDANTIC_AI_MODEL=gemini-1.5-flash...`)
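
For orientation, here is a minimal sketch of the custom `result_type` pattern this example demonstrates; the model fields and prompt below are illustrative assumptions rather than the exact contents of `pydantic_model.py`:

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class MyModel(BaseModel):
    # hypothetical fields chosen for illustration
    city: str
    country: str


agent = Agent('openai:gpt-4o', result_type=MyModel)
result = agent.run_sync('The windy city in the US of A.')
print(result.response)  # a validated MyModel instance
```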

### `sql_gen.py`

(Demonstrates: custom `result_type`, dynamic system prompt, result validation, agent deps)

Example demonstrating how to use Pydantic AI to generate SQL queries based on user input.

```bash
uv run -m examples.sql_gen
uv run --extra examples -m examples.sql_gen
```

or to use a custom prompt:

```bash
uv run --extra examples -m examples.sql_gen "find me whatever"
```

This example uses `gemini-1.5-flash` by default since Gemini is good at single-shot queries.
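
The result validation is the most interesting part. Below is a compact, simplified sketch of the pattern; the model name and the trimmed-down validator body are assumptions, and the full `sql_gen.py` diff further down shows the real version, which also runs `EXPLAIN` against Postgres:

```python
from dataclasses import dataclass

import asyncpg

from pydantic_ai import Agent, CallContext, ModelRetry


@dataclass
class Response:
    sql_query: str


@dataclass
class Deps:
    conn: asyncpg.Connection


agent: Agent[Deps, Response] = Agent('gemini-1.5-flash', result_type=Response)


@agent.result_validator
async def validate_result(ctx: CallContext[Deps], result: Response) -> Response:
    # reject anything that isn't a SELECT; raising ModelRetry sends the error
    # back to the LLM so it can try again (the full example also runs EXPLAIN)
    if not result.sql_query.lower().startswith('select'):
        raise ModelRetry('Please return a SELECT query')
    return result
```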

### `weather.py`

(Demonstrates: retrievers, multiple retrievers, agent deps)

Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question.

In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities,
@@ -54,7 +66,7 @@ To run this example properly, you'll need two extra API keys:
**(Note if either key is missing, the code will fall back to dummy data.)**

```bash
uv run -m examples.weather
uv run --extra examples -m examples.weather
```
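
Here is a minimal sketch of the retriever-plus-deps pattern the example is built around; the dependency fields, model choice, and request details are assumptions, and the `weather.py` diff below shows the real retrievers:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

from httpx import AsyncClient

from pydantic_ai import Agent, CallContext


@dataclass
class Deps:
    client: AsyncClient
    weather_api_key: str | None
    geo_api_key: str | None


weather_agent: Agent[Deps, str] = Agent('openai:gpt-4o')


@weather_agent.retriever_context
async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[str, Any]:
    """Simplified stand-in for the real retriever, which calls the tomorrow.io API."""
    if ctx.deps.weather_api_key is None:
        # mirror the note above: fall back to dummy data when the key is missing
        return {'temperature': '21 °C', 'description': 'Sunny'}
    r = await ctx.deps.client.get(
        'https://api.tomorrow.io/v4/weather/realtime',
        params={
            'apikey': ctx.deps.weather_api_key,  # parameter name assumed
            'location': f'{lat},{lng}',
            'units': 'metric',
        },
    )
    r.raise_for_status()
    return r.json()['data']['values']
```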

This example uses `openai:gpt-4o` by default. Gemini seems to be unable to handle the multiple tool
9 changes: 4 additions & 5 deletions examples/pydantic_model.py
@@ -2,21 +2,20 @@
Run with:
uv run -m examples.pydantic_model
uv run --extra examples -m examples.pydantic_model
"""

import os
from typing import cast

import logfire
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.agent import KnownModelName

# if you don't want to use logfire, just comment out these lines
import logfire

logfire.configure()
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure(send_to_logfire='if-token-present')


class MyModel(BaseModel):
134 changes: 98 additions & 36 deletions examples/sql_gen.py
@@ -2,60 +2,79 @@
Run with:
uv run -m examples.sql_gen
uv run --extra examples -m examples.sql_gen
"""

import asyncio
import os
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import cast
from datetime import date
from typing import Annotated, Any, cast

import asyncpg
import logfire
from annotated_types import MinLen
from devtools import debug

from pydantic_ai import Agent
from pydantic_ai import Agent, CallContext, ModelRetry
from pydantic_ai.agent import KnownModelName

# if you don't want to use logfire, just comment out these lines
import logfire

# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure()

system_prompt = """\
Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request.
CREATE TABLE records AS (
start_timestamp timestamp with time zone,
created_at timestamp with time zone,
DB_SCHEMA = """
CREATE TABLE IF NOT EXISTS records (
created_at timestamptz,
start_timestamp timestamptz,
end_timestamp timestamptz,
trace_id text,
span_id text,
parent_span_id text,
kind span_kind,
end_timestamp timestamp with time zone,
level smallint,
level log_level,
span_name text,
message text,
attributes_json_schema text,
attributes jsonb,
tags text[],
otel_links jsonb,
otel_events jsonb,
is_exception boolean,
otel_status_code status_code,
otel_status_message text,
otel_scope_name text,
otel_scope_version text,
otel_scope_attributes jsonb,
service_namespace text,
service_name text,
service_version text,
service_instance_id text,
process_pid integer
service_name text
);
"""


@dataclass
class Response:
sql_query: Annotated[str, MinLen(1)]


@dataclass
class Deps:
conn: asyncpg.Connection


model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash'))
agent: Agent[Deps, Response] = Agent(model, result_type=Response)


@agent.system_prompt
async def system_prompt() -> str:
return f"""\
Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request.
today's date = 2024-10-09
{DB_SCHEMA}
today's date = {date.today()}
Example
request: show me records where foobar is false
response: SELECT * FROM records WHERE attributes->>'foobar' = false'
response: SELECT * FROM records WHERE attributes->>'foobar' = false
Example
request: show me records where attributes include the key "foobar"
response: SELECT * FROM records WHERE attributes ? 'foobar'
Example
request: show me records from yesterday
response: SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'
@@ -65,16 +84,59 @@
"""


@dataclass
class Response:
sql_query: str
@agent.result_validator
async def validate_result(ctx: CallContext[Deps], result: Response) -> Response:
result.sql_query = result.sql_query.replace('\\', '')
lower_query = result.sql_query.lower()
if not lower_query.startswith('select'):
        raise ModelRetry('Please return a SELECT query')

try:
await ctx.deps.conn.execute(f'EXPLAIN {result.sql_query}')
except asyncpg.exceptions.PostgresError as e:
raise ModelRetry(f'Invalid query: {e}') from e
else:
return result

model = cast(KnownModelName, os.getenv('PYDANTIC_AI_MODEL', 'gemini-1.5-flash'))
agent = Agent(model, result_type=Response, system_prompt=system_prompt, deps=None, retries=2)

async def main():
if len(sys.argv) == 1:
prompt = 'show me logs from yesterday, with level "error"'
else:
prompt = sys.argv[1]

if __name__ == '__main__':
with debug.timer('SQL Generation'):
result = agent.run_sync('show me logs from yesterday, with level "error"')
async with database_connect('postgresql://postgres@localhost', 'pydantic_ai_sql_gen') as conn:
deps = Deps(conn)
result = await agent.run(prompt, deps=deps)
debug(result.response.sql_query)


# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
@asynccontextmanager
async def database_connect(server_dsn: str, database: str) -> AsyncGenerator[Any, None]:
with logfire.span('check and create DB'):
conn = await asyncpg.connect(server_dsn)
try:
db_exists = await conn.fetchval('SELECT 1 FROM pg_database WHERE datname = $1', database)
if not db_exists:
await conn.execute(f'CREATE DATABASE {database}')
finally:
await conn.close()

conn = await asyncpg.connect(f'{server_dsn}/{database}')
try:
with logfire.span('create schema'):
async with conn.transaction():
if not db_exists:
await conn.execute(
"CREATE TYPE log_level AS ENUM ('debug', 'info', 'warning', 'error', 'critical')"
)
await conn.execute(DB_SCHEMA)
yield conn
finally:
await conn.close()


if __name__ == '__main__':
asyncio.run(main())
32 changes: 19 additions & 13 deletions examples/weather.py
@@ -6,24 +6,23 @@
Run with:
uv run -m examples.weather
uv run --extra examples -m examples.weather
"""

import asyncio
import os
from dataclasses import dataclass
from typing import Any, cast

import logfire
from devtools import debug
from httpx import AsyncClient

from pydantic_ai import Agent, CallContext, ModelRetry
from pydantic_ai.agent import KnownModelName

# if you don't want to use logfire, just comment out these lines
import logfire

logfire.configure()
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire set up
logfire.configure(send_to_logfire='if-token-present')


@dataclass
@@ -53,12 +52,16 @@ async def get_lat_lng(ctx: CallContext[Deps], location_description: str) -> dict
'q': location_description,
'api_key': ctx.deps.geo_api_key,
}
r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params)
r.raise_for_status()
data = r.json()
if not data:
with logfire.span('calling geocode API', params=params) as span:
r = await ctx.deps.client.get('https://geocode.maps.co/search', params=params)
r.raise_for_status()
data = r.json()
span.set_attribute('response', data)

if data:
return {'lat': data[0]['lat'], 'lng': data[0]['lon']}
else:
raise ModelRetry('Could not find the location')
return {'lat': data[0]['lat'], 'lng': data[0]['lon']}


@weather_agent.retriever_context
@@ -79,9 +82,12 @@ async def get_weather(ctx: CallContext[Deps], lat: float, lng: float) -> dict[st
'location': f'{lat},{lng}',
'units': 'metric',
}
r = await ctx.deps.client.get('https://api.tomorrow.io/v4/weather/realtime', params=params)
r.raise_for_status()
data = r.json()
with logfire.span('calling weather API', params=params) as span:
r = await ctx.deps.client.get('https://api.tomorrow.io/v4/weather/realtime', params=params)
r.raise_for_status()
data = r.json()
span.set_attribute('response', data)

values = data['data']['values']
# https://docs.tomorrow.io/reference/data-layers-weather-codes
code_lookup = {
17 changes: 8 additions & 9 deletions pydantic_ai/_result.py
@@ -33,15 +33,15 @@ def __post_init__(self):
self._is_async = inspect.iscoroutinefunction(self.function)

async def validate(
self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall
self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall | None
) -> ResultData:
"""Validate a result but calling the function.
Args:
            result: The result data after Pydantic validation of the message content.
deps: The agent dependencies.
retry: The current retry number.
tool_call: The original tool call message.
tool_call: The original tool call message, `None` if there was no tool call.
Returns:
Result of either the validated result data (ok) or a retry message (Err).
@@ -59,11 +59,10 @@ async def validate(
function = cast(Callable[[Any], ResultData], self.function)
result_data = await _utils.run_in_executor(function, *args)
except ModelRetry as r:
m = messages.ToolRetry(
tool_name=tool_call.tool_name,
content=r.message,
tool_id=tool_call.tool_id,
)
m = messages.RetryPrompt(content=r.message)
if tool_call is not None:
m.tool_name = tool_call.tool_name
m.tool_id = tool_call.tool_id
raise ToolRetryError(m) from r
else:
return result_data
@@ -72,7 +71,7 @@
class ToolRetryError(Exception):
"""Internal exception used to indicate a signal a `ToolRetry` message should be returned to the LLM."""

def __init__(self, tool_retry: messages.ToolRetry):
def __init__(self, tool_retry: messages.RetryPrompt):
self.tool_retry = tool_retry
super().__init__()

@@ -127,7 +126,7 @@ def validate(self, tool_call: messages.ToolCall) -> ResultData:
else:
result = self.type_adapter.validate_python(tool_call.args.args_object)
except ValidationError as e:
m = messages.ToolRetry(
m = messages.RetryPrompt(
tool_name=tool_call.tool_name,
content=e.errors(include_url=False),
tool_id=tool_call.tool_id,
4 changes: 2 additions & 2 deletions pydantic_ai/_retriever.py
@@ -104,13 +104,13 @@ def _call_args(self, deps: AgentDeps, args_dict: dict[str, Any]) -> tuple[list[A

def _on_error(
self, content: list[pydantic_core.ErrorDetails] | str, call_message: messages.ToolCall
) -> messages.ToolRetry:
) -> messages.RetryPrompt:
self._current_retry += 1
if self._current_retry > self.max_retries:
# TODO custom error with details of the retriever
raise
else:
return messages.ToolRetry(
return messages.RetryPrompt(
tool_name=call_message.tool_name,
content=content,
tool_id=call_message.tool_id,
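
The common thread in `_result.py` and `_retriever.py` above is the switch from `ToolRetry` to `RetryPrompt`, with the tool fields now optional so a result validator that runs without a tool call can also trigger a retry. The sketch below shows the message shape implied by those call sites; the field types are inferred from usage in this diff, and any defaults or extra fields in the real `pydantic_ai.messages.RetryPrompt` are assumptions:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


@dataclass
class RetryPrompt:
    """Inferred shape only; see `pydantic_ai.messages` for the real definition."""

    # either Pydantic validation errors (a list of error dicts) or a plain message
    content: list[dict[str, Any]] | str
    # populated when the retry relates to a specific tool call, otherwise left as None
    tool_name: str | None = None
    tool_id: str | None = None
```

This is what makes the `tool_call: messages.ToolCall | None` signature change work: `RetryPrompt(content=...)` is constructed first, and the tool fields are filled in only when a tool call is actually present.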