Skip to content

Commit

Permalink
Gemini model (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 17, 2024
1 parent 1f7a174 commit 2350965
Show file tree
Hide file tree
Showing 28 changed files with 1,322 additions and 384 deletions.
26 changes: 12 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: Set up Python 3.12
run: uv python install 3.12
- run: uv python install 3.12

- name: Install dependencies
run: uv sync --python 3.12 --frozen
Expand All @@ -44,8 +42,7 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true

Expand Down Expand Up @@ -77,14 +74,17 @@ jobs:
merge-multiple: true
path: coverage

- run: pip install coverage[toml] --break-system-packages
- run: coverage combine coverage
- run: coverage xml
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- run: uv run --frozen coverage combine coverage
# - run: uv run --frozen coverage xml
# - uses: codecov/codecov-action@v4
# with:
# token: ${{ secrets.CODECOV_TOKEN }}
# file: ./coverage.xml
- run: coverage report --fail-under 70
- run: uv run --frozen coverage report --fail-under 70

# https://github.com/marketplace/actions/alls-green#why used for branch protection checks
check:
Expand All @@ -109,13 +109,11 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
- uses: astral-sh/setup-uv@v3
with:
enable-cache: true

- name: Set up Python 3.12
run: uv python install 3.12
- run: uv python install 3.12

- name: check GITHUB_REF matches package version
uses: samuelcolvin/check-python-version@v4.1
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ typecheck: typecheck-pyright
.PHONY: test # Run tests and collect coverage data
test:
uv run coverage run -m pytest
@uv run coverage report

.PHONY: testcov # Run tests and generate a coverage report
testcov: test
@echo "building coverage html"
@uv run coverage html
@uv run coverage report

.PHONY: all
all: format lint typecheck testcov
6 changes: 3 additions & 3 deletions demos/parse_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ class MyModel(BaseModel):

agent = Agent('openai:gpt-4o', result_type=MyModel, deps=None)

result = agent.run_sync('The windy city in the US of A.')

print(result.response)
if __name__ == '__main__':
result = agent.run_sync('The windy city in the US of A.')
print(result.response)
63 changes: 63 additions & 0 deletions demos/sql_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from dataclasses import dataclass

from pydantic_ai import Agent
from devtools import debug

system_prompt = """\
Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request.
CREATE TABLE records AS (
start_timestamp timestamp with time zone,
created_at timestamp with time zone,
trace_id text,
span_id text,
parent_span_id text,
kind span_kind,
end_timestamp timestamp with time zone,
level smallint,
span_name text,
message text,
attributes_json_schema text,
attributes jsonb,
tags text[],
otel_links jsonb,
otel_events jsonb,
is_exception boolean,
otel_status_code status_code,
otel_status_message text,
otel_scope_name text,
otel_scope_version text,
otel_scope_attributes jsonb,
service_namespace text,
service_name text,
service_version text,
service_instance_id text,
process_pid integer
);
today's date = 2024-10-09
Example
request: show me records where foobar is false
response: SELECT * FROM records WHERE attributes->>'foobar' = false'
Example
request: show me records from yesterday
response: SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'
Example
request: show me error records with the tag "foobar"
response: SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)
"""


@dataclass
class Response:
sql_query: str


agent = Agent('gemini-1.5-flash', result_type=Response, system_prompt=system_prompt, deps=None)


if __name__ == '__main__':
with debug.timer('SQL Generation'):
result = agent.run_sync('show me logs from yesterday, with level "error"')
debug(result.response.sql_query)
12 changes: 4 additions & 8 deletions demos/weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@

from pydantic_ai import Agent

weather_agent: Agent[None, str] = Agent('openai:gpt-4o')


@weather_agent.system_prompt
def system_prompt():
return 'Be concise, reply with one sentence.'
weather_agent: Agent[None, str] = Agent('openai:gpt-4o', system_prompt='Be concise, reply with one sentence.')


@weather_agent.retriever_plain
Expand All @@ -32,5 +27,6 @@ async def get_whether(lat: float, lng: float):
return 'Sunny'


result = weather_agent.run_sync('What is the weather like in West London and in Wiltshire?')
debug(result)
if __name__ == '__main__':
result = weather_agent.run_sync('What is the weather like in West London and in Wiltshire?')
debug(result)
4 changes: 2 additions & 2 deletions pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .agent import Agent
from .call import CallContext, Retry
from .shared import AgentError, CallContext, Retry, UnexpectedModelBehaviour, UserError

__all__ = 'Agent', 'CallContext', 'Retry'
__all__ = 'Agent', 'AgentError', 'CallContext', 'Retry', 'UnexpectedModelBehaviour', 'UserError'
4 changes: 2 additions & 2 deletions pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

if TYPE_CHECKING:
from . import _retriever
from .call import AgentDeps
from .shared import AgentDeps


__all__ = ('function_schema',)
Expand Down Expand Up @@ -228,7 +228,7 @@ def _infer_docstring_style(doc: str) -> DocstringStyle:


def _is_call_ctx(annotation: Any) -> bool:
from .call import CallContext
from .shared import CallContext

return annotation is CallContext or (
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is CallContext
Expand Down
75 changes: 75 additions & 0 deletions pydantic_ai/_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations as _annotations

import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Generic, Union, cast

from . import _utils, messages
from .result import ResultData
from .shared import AgentDeps, CallContext, Retry

# A function that always takes `ResultData` and returns `ResultData`,
# but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
# Usage `ResultValidator[AgentDeps, ResultData]`
ResultValidatorFunc = Union[
Callable[[CallContext[AgentDeps], ResultData], ResultData],
Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]],
Callable[[ResultData], ResultData],
Callable[[ResultData], Awaitable[ResultData]],
]


@dataclass
class ResultValidator(Generic[AgentDeps, ResultData]):
function: ResultValidatorFunc[AgentDeps, ResultData]
_takes_ctx: bool = False
_is_async: bool = False

def __post_init__(self):
self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
self._is_async = inspect.iscoroutinefunction(self.function)

async def validate(
self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall
) -> ResultData:
"""Validate a result but calling the function.
Args:
result: The result data after Pydantic validation the message content.
deps: The agent dependencies.
retry: The current retry number.
tool_call: The original tool call message.
Returns:
Result of either the validated result data (ok) or a retry message (Err).
"""
if self._takes_ctx:
args = CallContext(deps, retry), result
else:
args = (result,)

try:
if self._is_async:
function = cast(Callable[[Any], Awaitable[ResultData]], self.function)
result_data = await function(*args)
else:
function = cast(Callable[[Any], ResultData], self.function)
result_data = await _utils.run_in_executor(function, *args)
except Retry as r:
m = messages.ToolRetry(
tool_name=tool_call.tool_name,
content=r.message,
tool_id=tool_call.tool_id,
)
raise ToolRetryError(m) from r
else:
return result_data


class ToolRetryError(Exception):
"""Internal exception used to indicate a signal a `ToolRetry` message should be returned to the LLM"""

def __init__(self, tool_retry: messages.ToolRetry):
self.tool_retry = tool_retry
super().__init__()
7 changes: 5 additions & 2 deletions pydantic_ai/_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Concatenate, ParamSpec

from . import _pydantic, _utils, messages
from .call import AgentDeps, CallContext, Retry
from .shared import AgentDeps, CallContext, Retry

# retrieval function parameters
P = ParamSpec('P')
Expand Down Expand Up @@ -64,7 +64,10 @@ def reset(self) -> None:
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
"""Run the retriever function asynchronously."""
try:
args_dict = self.validator.validate_json(message.arguments)
if isinstance(message.args, messages.ArgsJson):
args_dict = self.validator.validate_json(message.args.args_json)
else:
args_dict = self.validator.validate_python(message.args.args_object)
except ValidationError as e:
return self._on_error(e.errors(include_url=False), message)

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Generic, Union, cast

from . import _utils
from .call import AgentDeps, CallContext
from .shared import AgentDeps, CallContext

# A function that may or maybe not take `CallInfo` as an argument, and may or may not be async.
# Usage `SystemPromptFunc[AgentDeps]`
Expand Down
19 changes: 4 additions & 15 deletions pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,11 @@
from dataclasses import dataclass, is_dataclass
from functools import partial
from types import GenericAlias
from typing import (
Any,
Callable,
Generic,
Literal,
TypedDict,
TypeVar,
Union,
cast,
get_args,
overload,
)
from typing import Any, Callable, Generic, Literal, TypeVar, Union, cast, get_args, overload

from pydantic import BaseModel
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import ParamSpec, TypeAlias, is_typeddict
from typing_extensions import NotRequired, ParamSpec, TypeAlias, TypedDict, is_typeddict

_P = ParamSpec('_P')
_R = TypeVar('_R')
Expand Down Expand Up @@ -59,8 +48,8 @@ def is_model_like(type_: Any) -> bool:
'type': Literal['object'],
'title': str,
'properties': dict[str, JsonSchemaValue],
'required': list[str],
'$defs': dict[str, Any],
'required': NotRequired[list[str]],
'$defs': NotRequired[dict[str, Any]],
},
)

Expand Down
Loading

0 comments on commit 2350965

Please sign in to comment.