Skip to content

Commit

Permalink
rename and split shared.py (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Nov 4, 2024
1 parent f1bb5b5 commit ea4f999
Show file tree
Hide file tree
Showing 19 changed files with 151 additions and 141 deletions.
3 changes: 2 additions & 1 deletion pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from importlib.metadata import version

from .agent import Agent
from .shared import AgentError, CallContext, ModelRetry, UnexpectedModelBehaviour, UserError
from .call_typing import CallContext
from .exceptions import AgentError, ModelRetry, UnexpectedModelBehaviour, UserError

__all__ = 'Agent', 'AgentError', 'CallContext', 'ModelRetry', 'UnexpectedModelBehaviour', 'UserError', '__version__'
__version__ = version('pydantic_ai')
6 changes: 3 additions & 3 deletions pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

if TYPE_CHECKING:
from . import _retriever
from .shared import AgentDeps
from .call_typing import AgentDeps


__all__ = 'function_schema', 'LazyTypeAdapter'
Expand Down Expand Up @@ -118,7 +118,7 @@ def function_schema(either_function: _retriever.RetrieverEitherFunc[AgentDeps, _
var_positional_field = field_name

if errors:
from .shared import UserError
from .exceptions import UserError

error_details = '\n '.join(errors)
raise UserError(f'Error generating schema for {function.__qualname__}:\n {error_details}')
Expand Down Expand Up @@ -307,7 +307,7 @@ def _infer_docstring_style(doc: str) -> DocstringStyle:


def _is_call_ctx(annotation: Any) -> bool:
from .shared import CallContext
from .call_typing import CallContext

return annotation is CallContext or (
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is CallContext
Expand Down
4 changes: 3 additions & 1 deletion pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from typing_extensions import Self, TypeAliasType, TypedDict

from . import _utils, messages
from .call_typing import AgentDeps, CallContext
from .exceptions import ModelRetry
from .messages import LLMToolCalls, ToolCall
from .shared import AgentDeps, CallContext, ModelRetry, ResultData
from .result import ResultData

# A function that always takes `ResultData` and returns `ResultData`,
# but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai/_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from typing_extensions import Concatenate, ParamSpec

from . import _pydantic, _utils, messages
from .shared import AgentDeps, CallContext, ModelRetry
from .call_typing import AgentDeps, CallContext
from .exceptions import ModelRetry

# retrieval function parameters
P = ParamSpec('P')
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Generic, Union, cast

from . import _utils
from .shared import AgentDeps, CallContext
from .call_typing import AgentDeps, CallContext

# A function that may or maybe not take `CallContext` as an argument, and may or may not be async.
# Usage `SystemPromptFunc[AgentDeps]`
Expand Down
23 changes: 12 additions & 11 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from pydantic import ValidationError
from typing_extensions import assert_never

from . import _result, _retriever as _r, _system_prompt, _utils, messages as _messages, models, shared
from .shared import AgentDeps, ResultData
from . import _result, _retriever as _r, _system_prompt, _utils, exceptions, messages as _messages, models, result
from .call_typing import AgentDeps
from .result import ResultData

__all__ = 'Agent', 'KnownModelName'
KnownModelName = Literal[
Expand Down Expand Up @@ -74,7 +75,7 @@ async def run(
message_history: list[_messages.Message] | None = None,
model: models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
) -> shared.RunResult[ResultData]:
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt in async mode.
Args:
Expand All @@ -92,7 +93,7 @@ async def run(
model_ = self.model
custom_model = None
else:
raise shared.UserError('`model` must be set either when creating the agent or when calling it.')
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

if deps is None:
deps = self._default_deps
Expand All @@ -115,7 +116,7 @@ async def run(
for retriever in self._retrievers.values():
retriever.reset()

cost = shared.Cost()
cost = result.Cost()

with _logfire.span(
'agent run {prompt=}', prompt=user_prompt, agent=self, custom_model=custom_model, model_name=model_.name()
Expand All @@ -139,17 +140,17 @@ async def run(
run_span.set_attribute('cost', cost)
handle_span.set_attribute('result', left.value)
handle_span.message = 'handle model response -> final result'
return shared.RunResult(left.value, cost, messages, new_message_index)
return result.RunResult(left.value, cost, messages, new_message_index)
else:
tool_responses = either.right
handle_span.set_attribute('tool_responses', tool_responses)
response_msgs = ' '.join(m.role for m in tool_responses)
handle_span.message = f'handle model response -> {response_msgs}'
messages.extend(tool_responses)
except (ValidationError, shared.UnexpectedModelBehaviour) as e:
except (ValidationError, exceptions.UnexpectedModelBehaviour) as e:
run_span.set_attribute('messages', messages)
# noinspection PyTypeChecker
raise shared.AgentError(messages, model_) from e
raise exceptions.AgentError(messages, model_) from e

def run_sync(
self,
Expand All @@ -158,7 +159,7 @@ def run_sync(
message_history: list[_messages.Message] | None = None,
model: models.Model | KnownModelName | None = None,
deps: AgentDeps | None = None,
) -> shared.RunResult[ResultData]:
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt synchronously.
This is a convenience method that wraps `self.run` with `asyncio.run()`.
Expand Down Expand Up @@ -295,7 +296,7 @@ async def _handle_model_response(
retriever = self._retrievers.get(call.tool_name)
if retriever is None:
# should this be a retry error?
raise shared.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
raise exceptions.UnexpectedModelBehaviour(f'Unknown function name: {call.tool_name!r}')
coros.append(retriever.run(deps, call))
new_messages = await asyncio.gather(*coros)
return _utils.Either(right=new_messages)
Expand All @@ -312,7 +313,7 @@ async def _validate_result(
def _incr_result_retry(self) -> None:
self._current_result_retry += 1
if self._current_result_retry > self._max_result_retries:
raise shared.UnexpectedModelBehaviour(
raise exceptions.UnexpectedModelBehaviour(
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
)

Expand Down
17 changes: 17 additions & 0 deletions pydantic_ai/call_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass
from typing import Generic, TypeVar

__all__ = 'AgentDeps', 'CallContext'

AgentDeps = TypeVar('AgentDeps')


@dataclass
class CallContext(Generic[AgentDeps]):
"""Information about the current call."""

deps: AgentDeps
retry: int
tool_name: str | None
83 changes: 2 additions & 81 deletions pydantic_ai/shared.py → pydantic_ai/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations as _annotations

import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING

from pydantic import ValidationError

Expand All @@ -11,85 +10,7 @@
if TYPE_CHECKING:
from .models import Model

__all__ = (
'AgentDeps',
'ResultData',
'Cost',
'RunResult',
'ModelRetry',
'CallContext',
'AgentError',
'UserError',
'UnexpectedModelBehaviour',
)

AgentDeps = TypeVar('AgentDeps')
ResultData = TypeVar('ResultData')


@dataclass
class Cost:
"""Cost of a request or run."""

request_tokens: int | None = None
response_tokens: int | None = None
total_tokens: int | None = None
details: dict[str, int] | None = None

def __add__(self, other: Cost) -> Cost:
counts: dict[str, int] = {}
for field in 'request_tokens', 'response_tokens', 'total_tokens':
self_value = getattr(self, field)
other_value = getattr(other, field)
if self_value is not None or other_value is not None:
counts[field] = (self_value or 0) + (other_value or 0)

details = self.details.copy() if self.details is not None else None
if other.details is not None:
details = details or {}
for key, value in other.details.items():
details[key] = details.get(key, 0) + value

return Cost(**counts, details=details or None)


@dataclass
class RunResult(Generic[ResultData]):
"""Result of a run."""

response: ResultData
cost: Cost
_all_messages: list[messages.Message]
_new_message_index: int

def all_messages(self) -> list[messages.Message]:
"""Return the history of messages."""
# this is a method to be consistent with the other methods
return self._all_messages

def all_messages_json(self) -> bytes:
"""Return the history of messages as JSON bytes."""
return messages.MessagesTypeAdapter.dump_json(self.all_messages())

def new_messages(self) -> list[messages.Message]:
"""Return new messages associated with this run.
System prompts and any messages from older runs are excluded.
"""
return self.all_messages()[self._new_message_index :]

def new_messages_json(self) -> bytes:
"""Return new messages from [new_messages][] as JSON bytes."""
return messages.MessagesTypeAdapter.dump_json(self.new_messages())


@dataclass
class CallContext(Generic[AgentDeps]):
"""Information about the current call."""

deps: AgentDeps
retry: int
tool_name: str | None
__all__ = 'ModelRetry', 'AgentError', 'UserError', 'UnexpectedModelBehaviour'


class ModelRetry(Exception):
Expand Down
10 changes: 5 additions & 5 deletions pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class UserPrompt:
role: Literal['user'] = 'user'


return_value_object = _pydantic.LazyTypeAdapter(dict[str, Any])
tool_return_value_object = _pydantic.LazyTypeAdapter(dict[str, Any])


@dataclass
Expand All @@ -43,14 +43,14 @@ def model_response_str(self) -> str:
if isinstance(self.content, str):
return self.content
else:
content = return_value_object.validate_python(self.content)
return return_value_object.dump_json(content).decode()
content = tool_return_value_object.validate_python(self.content)
return tool_return_value_object.dump_json(content).decode()

def model_response_object(self) -> dict[str, Any]:
if isinstance(self.content, str):
return {'return_value': self.content}
else:
return return_value_object.validate_python(self.content)
return tool_return_value_object.validate_python(self.content)


@dataclass
Expand Down Expand Up @@ -113,4 +113,4 @@ class LLMToolCalls:
LLMMessage = Union[LLMResponse, LLMToolCalls]
Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, LLMMessage]

MessagesTypeAdapter = pydantic.TypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
4 changes: 2 additions & 2 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from .._utils import ObjectJsonSchema
from ..agent import KnownModelName
from ..shared import Cost
from ..result import Cost


class Model(ABC):
Expand Down Expand Up @@ -73,7 +73,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
# noinspection PyTypeChecker
return GeminiModel(model) # pyright: ignore[reportArgumentType]
else:
from ..shared import UserError
from ..exceptions import UserError

raise UserError(f'Unknown model: {model}')

Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing_extensions import TypeAlias

from .. import shared
from .. import result
from ..messages import LLMMessage, Message
from . import AbstractToolDefinition, AgentModel, Model

Expand Down Expand Up @@ -58,5 +58,5 @@ class FunctionAgentModel(AgentModel):
function: FunctionDef
agent_info: AgentInfo

async def request(self, messages: list[Message]) -> tuple[LLMMessage, shared.Cost]:
return self.function(messages, self.agent_info), shared.Cost()
async def request(self, messages: list[Message]) -> tuple[LLMMessage, result.Cost]:
return self.function(messages, self.agent_info), result.Cost()
Loading

0 comments on commit ea4f999

Please sign in to comment.