Skip to content

Commit

Permalink
Testing generate args (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Oct 13, 2024
1 parent ff834b5 commit 7365101
Show file tree
Hide file tree
Showing 12 changed files with 474 additions and 58 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ test:
testcov: test
@echo "building coverage html"
@uv run coverage html --show-contexts
@uv run coverage report

.PHONY: all
all: format lint typecheck test
4 changes: 1 addition & 3 deletions demos/parse_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from devtools import debug
from pydantic import BaseModel

from pydantic_ai import Agent
Expand All @@ -11,7 +10,6 @@ class MyModel(BaseModel):

agent = Agent('openai:gpt-4o', response_type=MyModel, deps=None)

# debug(agent.result_schema.json_schema)
result = agent.run_sync('The windy city in the US of A.')

debug(result.response)
print(result.response)
11 changes: 10 additions & 1 deletion pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P])
field_info,
decorators,
)
td_schema['metadata'] = {'is_model_like': is_model_like(annotation)}
extra_metadata = {'is_model_like': is_model_like(annotation)}
if metadata := td_schema.get('metadata'):
metadata.update(extra_metadata)
else:
td_schema['metadata'] = extra_metadata
if p.kind == Parameter.POSITIONAL_ONLY:
positional_fields.append(field_name)
elif p.kind == Parameter.VAR_POSITIONAL:
Expand All @@ -130,6 +134,11 @@ def function_schema(either_function: _r.RetrieverEitherFunc[_r.AgentDeps, _r.P])
# PluggableSchemaValidator is api compat with SchemaValidator
schema_validator = cast(SchemaValidator, schema_validator)
json_schema = GenerateJsonSchema().generate(schema)

# instead of passing `description` through in core_schema, we just add it here
if description:
json_schema = {'description': description} | json_schema

return FunctionSchema(
description=description,
validator=schema_validator,
Expand Down
15 changes: 10 additions & 5 deletions pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,16 @@ def is_model_like(type_: Any) -> bool:
)


class ObjectJsonSchema(TypedDict):
type: Literal['object']
title: str
properties: dict[str, JsonSchemaValue]
required: list[str]
ObjectJsonSchema = TypedDict(
'ObjectJsonSchema',
{
'type': Literal['object'],
'title': str,
'properties': dict[str, JsonSchemaValue],
'required': list[str],
'$defs': dict[str, Any],
},
)


def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
Expand Down
10 changes: 5 additions & 5 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Agent(Generic[AgentDeps, ResultData]):
__slots__ = (
'_model',
'result_schema',
'_allow_plain_message',
'_allow_plain_response',
'_system_prompts',
'_retrievers',
'_default_retries',
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(
response_schema_description,
response_retries if response_retries is not None else retries,
)
self._allow_plain_message = self.result_schema is None or self.result_schema.allow_plain_message
self._allow_plain_response = self.result_schema is None or self.result_schema.allow_plain_response

self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
self._retrievers: dict[str, _r.Retriever[AgentDeps, Any]] = {r_.name: r_ for r_ in retrievers}
Expand Down Expand Up @@ -98,10 +98,10 @@ async def run(

messages.append(_messages.UserPrompt(user_prompt))

functions: list[_models.AbstractRetrieverDefinition] = list(self._retrievers.values())
functions: list[_models.AbstractToolDefinition] = list(self._retrievers.values())
if self.result_schema is not None:
functions.append(self.result_schema)
agent_model = model_.agent_model(self._allow_plain_message, functions)
agent_model = model_.agent_model(self._allow_plain_response, functions)

for retriever in self._retrievers.values():
retriever.reset()
Expand Down Expand Up @@ -217,7 +217,7 @@ async def _handle_model_response(
messages.append(llm_message)
if llm_message.role == 'llm-response':
# plain string response
if self._allow_plain_message:
if self._allow_plain_response:
return _utils.Some(cast(ResultData, llm_message.content))
else:
messages.append(_messages.PlainResponseForbidden())
Expand Down
16 changes: 12 additions & 4 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class Model(ABC):
"""Abstract class for a model."""

@abstractmethod
def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
"""Create an agent model."""
def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
"""Create an agent model.
Args:
allow_plain_response: Whether plain text final response is permitted.
tools: The tools available to the agent.
"""
raise NotImplementedError()


Expand All @@ -47,8 +52,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
raise TypeError(f'Invalid model: {model}')


class AbstractRetrieverDefinition(Protocol):
"""Abstract definition of a retriever/function/tool."""
class AbstractToolDefinition(Protocol):
"""Abstract definition of a function/tool.
These are generally retrievers, but can also include the response function if one exists.
"""

name: str
description: str
Expand Down
24 changes: 13 additions & 11 deletions pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
from typing import TYPE_CHECKING, Protocol

from ..messages import LLMMessage, Message
from . import AbstractRetrieverDefinition, AgentModel, Model
from . import AbstractToolDefinition, AgentModel, Model

if TYPE_CHECKING:
from .._utils import ObjectJsonSchema


class FunctionDef(Protocol):
def __call__(
self, messages: list[Message], allow_plain_message: bool, retrievers: dict[str, RetrieverDescription], /
self, messages: list[Message], allow_plain_response: bool, tools: dict[str, ToolDescription], /
) -> LLMMessage: ...


@dataclass
class RetrieverDescription:
class ToolDescription:
name: str
description: str
json_schema: ObjectJsonSchema
Expand All @@ -30,19 +30,21 @@ class FunctionModel(Model):

function: FunctionDef

def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
return TestAgentModel(
def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
return FunctionAgentModel(
self.function,
allow_plain_message,
{r.name: RetrieverDescription(r.name, r.description, r.json_schema) for r in retrievers},
allow_plain_response,
{r.name: ToolDescription(r.name, r.description, r.json_schema) for r in tools},
)


@dataclass
class TestAgentModel(AgentModel):
class FunctionAgentModel(AgentModel):
__test__ = False

function: FunctionDef
allow_plain_message: bool
retrievers: dict[str, RetrieverDescription]
allow_plain_response: bool
tools: dict[str, ToolDescription]

async def request(self, messages: list[Message]) -> LLMMessage:
return self.function(messages, self.allow_plain_message, self.retrievers)
return self.function(messages, self.allow_plain_response, self.tools)
14 changes: 7 additions & 7 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
LLMResponse,
Message,
)
from . import AbstractRetrieverDefinition, AgentModel, Model
from . import AbstractToolDefinition, AgentModel, Model


class OpenAIModel(Model):
Expand All @@ -26,20 +26,20 @@ def __init__(self, model_name: ChatModel, *, api_key: str | None = None, client:
self.model_name: ChatModel = model_name
self.client = client or cached_async_client(api_key)

def agent_model(self, allow_plain_message: bool, retrievers: list[AbstractRetrieverDefinition]) -> AgentModel:
def agent_model(self, allow_plain_response: bool, tools: list[AbstractToolDefinition]) -> AgentModel:
return OpenAIAgentModel(
self.client,
self.model_name,
allow_plain_message,
[map_retriever_definition(t) for t in retrievers],
allow_plain_response,
[map_tool_definition(t) for t in tools],
)


@dataclass
class OpenAIAgentModel(AgentModel):
client: AsyncClient
model_name: ChatModel
allow_plain_message: bool
allow_plain_response: bool
tools: list[ChatCompletionToolParam]

async def request(self, messages: list[Message]) -> LLMMessage:
Expand All @@ -66,7 +66,7 @@ async def completions_create(self, messages: list[Message]) -> ChatCompletion:
# standalone function to make it easier to override
if not self.tools:
tool_choice: Literal['none', 'required', 'auto'] = 'none'
elif not self.allow_plain_message:
elif not self.allow_plain_response:
tool_choice = 'required'
else:
tool_choice = 'auto'
Expand All @@ -87,7 +87,7 @@ def cached_async_client(api_key: str) -> AsyncClient:
return AsyncClient(api_key=api_key)


def map_retriever_definition(f: AbstractRetrieverDefinition) -> ChatCompletionToolParam:
def map_tool_definition(f: AbstractToolDefinition) -> ChatCompletionToolParam:
return {
'type': 'function',
'function': {
Expand Down
Loading

0 comments on commit 7365101

Please sign in to comment.