diff --git a/pydantic_ai/_pydantic.py b/pydantic_ai/_pydantic.py index 9fabcd84d..58c98f5ab 100644 --- a/pydantic_ai/_pydantic.py +++ b/pydantic_ai/_pydantic.py @@ -6,7 +6,7 @@ from __future__ import annotations as _annotations from inspect import Parameter, Signature, signature -from typing import Any, Callable, Literal, TypedDict, Union, cast, get_origin +from typing import Any, Callable, Literal, TypedDict, cast, get_origin from _griffe.enumerations import DocstringSectionKind from _griffe.models import Docstring, Object as GriffeObject @@ -31,9 +31,9 @@ class FunctionSchema(TypedDict): json_schema: ObjectJsonSchema takes_info: bool # if not None, the function takes a single by that name (besides potentially `info`) - single_arg_name: Union[str, None] + single_arg_name: str | None positional_fields: list[str] - var_positional_field: Union[str, None] + var_positional_field: str | None def function_schema(function: Callable[..., Any]) -> FunctionSchema: @@ -55,10 +55,10 @@ def function_schema(function: Callable[..., Any]) -> FunctionSchema: type_hints = _typing_extra.get_function_type_hints(function) - var_kwargs_schema: Union[core_schema.CoreSchema, None] = None + var_kwargs_schema: core_schema.CoreSchema | None = None fields: dict[str, core_schema.TypedDictField] = {} positional_fields: list[str] = [] - var_positional_field: Union[str, None] = None + var_positional_field: str | None = None errors: list[str] = [] decorators = _decorators.DecoratorInfos() description, field_descriptions = _doc_descriptions(function, sig) @@ -129,10 +129,10 @@ def function_schema(function: Callable[..., Any]) -> FunctionSchema: def _build_schema( fields: dict[str, core_schema.TypedDictField], - var_kwargs_schema: Union[core_schema.CoreSchema, None], + var_kwargs_schema: core_schema.CoreSchema | None, gen_schema: _generate_schema.GenerateSchema, core_config: core_schema.CoreConfig, -) -> tuple[core_schema.CoreSchema, Union[str, None]]: +) -> tuple[core_schema.CoreSchema, str | None]: """Generate a typed dict schema for function parameters. Args: @@ -163,7 +163,7 @@ def _build_schema( def _doc_descriptions( - func: Callable[..., Any], sig: Signature, *, style: Union[DocstringStyle, None] = None + func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None ) -> tuple[str, dict[str, str]]: """Extract the function description and parameter descriptions from a function's docstring. diff --git a/pydantic_ai/_utils.py b/pydantic_ai/_utils.py index c86bd8b01..5521a6e40 100644 --- a/pydantic_ai/_utils.py +++ b/pydantic_ai/_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations as _annotations + import asyncio from dataclasses import dataclass, is_dataclass from functools import partial @@ -102,14 +104,14 @@ def __init__(self, *, left: _Left) -> None: ... @overload def __init__(self, *, right: _Right) -> None: ... - def __init__(self, *, left: Union[_Left, None] = None, right: Union[_Right, None] = None) -> None: + def __init__(self, *, left: _Left | None = None, right: _Right | None = None) -> None: if (left is not None and right is not None) or (left is None and right is None): raise TypeError('Either must have exactly one value') self._left = left self._right = right @property - def left(self) -> Union[_Left, None]: + def left(self) -> _Left | None: return self._left @property diff --git a/pydantic_ai/agent.py b/pydantic_ai/agent.py index 021969004..15b1c32e8 100644 --- a/pydantic_ai/agent.py +++ b/pydantic_ai/agent.py @@ -22,16 +22,16 @@ class Agent(Generic[ResultData, AgentContext]): def __init__( self, - model: Union[_models.Model, KnownModelName, None] = None, + model: _models.Model | KnownModelName | None = None, response_type: type[_result.ResultData] = str, *, - system_prompt: Union[str, Sequence[str]] = (), + system_prompt: str | Sequence[str] = (), retrievers: Sequence[_r.Retriever[AgentContext, Any]] = (), context: AgentContext = None, retries: int = 1, response_schema_name: str = 'final_response', response_schema_description: str = 'The final response', - response_retries: Union[int, None] = None, + response_retries: int | None = None, ): self._model = _models.infer_model(model) if model is not None else None @@ -55,8 +55,8 @@ async def run( self, user_prompt: str, *, - message_history: Union[list[_messages.Message], None] = None, - model: Union[_models.Model, KnownModelName, None] = None, + message_history: list[_messages.Message] | None = None, + model: _models.Model | KnownModelName | None = None, ) -> _result.RunResult[_result.ResultData]: """Run the agent with a user prompt in async mode. @@ -101,8 +101,8 @@ def run_sync( self, user_prompt: str, *, - message_history: Union[list[_messages.Message], None] = None, - model: Union[_models.Model, KnownModelName, None] = None, + message_history: list[_messages.Message] | None = None, + model: _models.Model | KnownModelName | None = None, ) -> _result.RunResult[_result.ResultData]: """Run the agent with a user prompt synchronously. @@ -132,14 +132,14 @@ def retriever(self, func: _r.RetrieverFunc[AgentContext, _r.P], /) -> _r.Retriev @overload def retriever( - self, /, *, retries: Union[int, None] = None + self, /, *, retries: int | None = None ) -> Callable[ [_r.RetrieverFunc[AgentContext, _r.P]], _r.Retriever[AgentContext, _r.P], ]: ... def retriever( - self, func: Union[_r.RetrieverFunc[AgentContext, _r.P], None] = None, /, *, retries: Union[int, None] = None + self, func: _r.RetrieverFunc[AgentContext, _r.P] | None = None, /, *, retries: int | None = None ) -> Any: """Decorator to register a retriever function.""" if func is None: @@ -153,7 +153,7 @@ def retriever_decorator(func_: _r.RetrieverFunc[AgentContext, _r.P]) -> _r.Retri return self._register_retriever(func, retries) def _register_retriever( - self, func: _r.RetrieverFunc[AgentContext, _r.P], retries: Union[int, None] + self, func: _r.RetrieverFunc[AgentContext, _r.P], retries: int | None ) -> _r.Retriever[AgentContext, _r.P]: retries_ = retries if retries is not None else self._default_retries retriever = _r.Retriever[AgentContext, _r.P].build(func, retries_) diff --git a/pydantic_ai/messages.py b/pydantic_ai/messages.py index 11eaf2088..4c69244fb 100644 --- a/pydantic_ai/messages.py +++ b/pydantic_ai/messages.py @@ -1,3 +1,5 @@ +from __future__ import annotations as _annotations + import json from dataclasses import dataclass, field from datetime import datetime @@ -37,7 +39,7 @@ def llm_response(self) -> str: class FunctionRetry: function_id: str function_name: str - content: Union[list[pydantic_core.ErrorDetails], str] + content: list[pydantic_core.ErrorDetails] | str timestamp: datetime = field(default_factory=datetime.now) role: Literal['function-retry'] = 'function-retry' diff --git a/pydantic_ai/models/__init__.py b/pydantic_ai/models/__init__.py index aab7320e2..c3118b898 100644 --- a/pydantic_ai/models/__init__.py +++ b/pydantic_ai/models/__init__.py @@ -6,7 +6,7 @@ from __future__ import annotations as _annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Protocol, Union +from typing import TYPE_CHECKING, Protocol from ..messages import LLMMessage, Message @@ -35,7 +35,7 @@ async def request(self, messages: list[Message]) -> LLMMessage: # TODO streamed response -def infer_model(model: Union[Model, KnownModelName]) -> Model: +def infer_model(model: Model | KnownModelName) -> Model: """Infer the model from the name.""" if isinstance(model, Model): return model diff --git a/pydantic_ai/models/openai.py b/pydantic_ai/models/openai.py index 6c572cfad..addc2785c 100644 --- a/pydantic_ai/models/openai.py +++ b/pydantic_ai/models/openai.py @@ -1,7 +1,9 @@ +from __future__ import annotations as _annotations + from dataclasses import dataclass from datetime import datetime from functools import cache -from typing import Literal, Union, assert_never +from typing import Literal, assert_never from openai import AsyncClient from openai.types import ChatModel @@ -18,9 +20,7 @@ class OpenAIModel(Model): - def __init__( - self, model_name: ChatModel, *, api_key: Union[str, None] = None, client: Union[AsyncClient, None] = None - ): + def __init__(self, model_name: ChatModel, *, api_key: str | None = None, client: AsyncClient | None = None): if model_name not in ChatModel.__args__: raise ValueError(f'Invalid model name: {model_name}') self.model_name: ChatModel = model_name diff --git a/pydantic_ai/result.py b/pydantic_ai/result.py index 5f4c74d03..5cd5bb596 100644 --- a/pydantic_ai/result.py +++ b/pydantic_ai/result.py @@ -1,6 +1,8 @@ +from __future__ import annotations as _annotations + from collections.abc import AsyncIterable from dataclasses import dataclass -from typing import Any, Generic, TypedDict, TypeVar, Union +from typing import Any, Generic, TypedDict, TypeVar from pydantic import TypeAdapter, ValidationError from typing_extensions import Self @@ -65,7 +67,7 @@ class ResultSchema(Generic[ResultData]): _current_retry: int = 0 @classmethod - def build(cls, response_type: type[ResultData], name: str, description: str, retries: int) -> Union[Self, None]: + def build(cls, response_type: type[ResultData], name: str, description: str, retries: int) -> Self | None: """Build a ResponseModel dataclass from a response type.""" if response_type is str: return None diff --git a/pydantic_ai/retrievers.py b/pydantic_ai/retrievers.py index b67e4e925..81743bf63 100644 --- a/pydantic_ai/retrievers.py +++ b/pydantic_ai/retrievers.py @@ -47,9 +47,9 @@ class Retriever(Generic[AgentContext, P]): function: RetrieverFunc[AgentContext, P] is_async: bool takes_info: bool - single_arg_name: Union[str, None] + single_arg_name: str | None positional_fields: list[str] - var_positional_field: Union[str, None] + var_positional_field: str | None validator: SchemaValidator json_schema: _utils.ObjectJsonSchema max_retries: int @@ -117,7 +117,7 @@ def _call_args(self, context: AgentContext, args_dict: dict[str, Any]) -> tuple[ return args, args_dict def _on_error( - self, content: Union[list[pydantic_core.ErrorDetails], str], call_message: messages.FunctionCall + self, content: list[pydantic_core.ErrorDetails] | str, call_message: messages.FunctionCall ) -> messages.FunctionRetry: self._current_retry += 1 if self._current_retry > self.max_retries: diff --git a/pyproject.toml b/pyproject.toml index f0a6f2943..ef61fec88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,10 +91,6 @@ convention = "google" docstring-code-format = true quote-style = "single" -[tool.ruff.lint.pyupgrade] -# Preserve types, even if a file imports `from __future__ import annotations`. -keep-runtime-typing = true - [tool.pyright] typeCheckingMode = "strict" reportUnnecessaryTypeIgnoreComment = true