Skip to content

Commit

Permalink
use pipe unions where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Oct 10, 2024
1 parent 22fa033 commit 11d567d
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 36 deletions.
16 changes: 8 additions & 8 deletions pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations as _annotations

from inspect import Parameter, Signature, signature
from typing import Any, Callable, Literal, TypedDict, Union, cast, get_origin
from typing import Any, Callable, Literal, TypedDict, cast, get_origin

from _griffe.enumerations import DocstringSectionKind
from _griffe.models import Docstring, Object as GriffeObject
Expand All @@ -31,9 +31,9 @@ class FunctionSchema(TypedDict):
json_schema: ObjectJsonSchema
takes_info: bool
# if not None, the function takes a single by that name (besides potentially `info`)
single_arg_name: Union[str, None]
single_arg_name: str | None
positional_fields: list[str]
var_positional_field: Union[str, None]
var_positional_field: str | None


def function_schema(function: Callable[..., Any]) -> FunctionSchema:
Expand All @@ -55,10 +55,10 @@ def function_schema(function: Callable[..., Any]) -> FunctionSchema:

type_hints = _typing_extra.get_function_type_hints(function)

var_kwargs_schema: Union[core_schema.CoreSchema, None] = None
var_kwargs_schema: core_schema.CoreSchema | None = None
fields: dict[str, core_schema.TypedDictField] = {}
positional_fields: list[str] = []
var_positional_field: Union[str, None] = None
var_positional_field: str | None = None
errors: list[str] = []
decorators = _decorators.DecoratorInfos()
description, field_descriptions = _doc_descriptions(function, sig)
Expand Down Expand Up @@ -129,10 +129,10 @@ def function_schema(function: Callable[..., Any]) -> FunctionSchema:

def _build_schema(
fields: dict[str, core_schema.TypedDictField],
var_kwargs_schema: Union[core_schema.CoreSchema, None],
var_kwargs_schema: core_schema.CoreSchema | None,
gen_schema: _generate_schema.GenerateSchema,
core_config: core_schema.CoreConfig,
) -> tuple[core_schema.CoreSchema, Union[str, None]]:
) -> tuple[core_schema.CoreSchema, str | None]:
"""Generate a typed dict schema for function parameters.
Args:
Expand Down Expand Up @@ -163,7 +163,7 @@ def _build_schema(


def _doc_descriptions(
func: Callable[..., Any], sig: Signature, *, style: Union[DocstringStyle, None] = None
func: Callable[..., Any], sig: Signature, *, style: DocstringStyle | None = None
) -> tuple[str, dict[str, str]]:
"""Extract the function description and parameter descriptions from a function's docstring.
Expand Down
6 changes: 4 additions & 2 deletions pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations as _annotations

import asyncio
from dataclasses import dataclass, is_dataclass
from functools import partial
Expand Down Expand Up @@ -102,14 +104,14 @@ def __init__(self, *, left: _Left) -> None: ...
@overload
def __init__(self, *, right: _Right) -> None: ...

def __init__(self, *, left: Union[_Left, None] = None, right: Union[_Right, None] = None) -> None:
def __init__(self, *, left: _Left | None = None, right: _Right | None = None) -> None:
if (left is not None and right is not None) or (left is None and right is None):
raise TypeError('Either must have exactly one value')
self._left = left
self._right = right

@property
def left(self) -> Union[_Left, None]:
def left(self) -> _Left | None:
return self._left

@property
Expand Down
20 changes: 10 additions & 10 deletions pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ class Agent(Generic[ResultData, AgentContext]):

def __init__(
self,
model: Union[_models.Model, KnownModelName, None] = None,
model: _models.Model | KnownModelName | None = None,
response_type: type[_result.ResultData] = str,
*,
system_prompt: Union[str, Sequence[str]] = (),
system_prompt: str | Sequence[str] = (),
retrievers: Sequence[_r.Retriever[AgentContext, Any]] = (),
context: AgentContext = None,
retries: int = 1,
response_schema_name: str = 'final_response',
response_schema_description: str = 'The final response',
response_retries: Union[int, None] = None,
response_retries: int | None = None,
):
self._model = _models.infer_model(model) if model is not None else None

Expand All @@ -55,8 +55,8 @@ async def run(
self,
user_prompt: str,
*,
message_history: Union[list[_messages.Message], None] = None,
model: Union[_models.Model, KnownModelName, None] = None,
message_history: list[_messages.Message] | None = None,
model: _models.Model | KnownModelName | None = None,
) -> _result.RunResult[_result.ResultData]:
"""Run the agent with a user prompt in async mode.
Expand Down Expand Up @@ -101,8 +101,8 @@ def run_sync(
self,
user_prompt: str,
*,
message_history: Union[list[_messages.Message], None] = None,
model: Union[_models.Model, KnownModelName, None] = None,
message_history: list[_messages.Message] | None = None,
model: _models.Model | KnownModelName | None = None,
) -> _result.RunResult[_result.ResultData]:
"""Run the agent with a user prompt synchronously.
Expand Down Expand Up @@ -132,14 +132,14 @@ def retriever(self, func: _r.RetrieverFunc[AgentContext, _r.P], /) -> _r.Retriev

@overload
def retriever(
self, /, *, retries: Union[int, None] = None
self, /, *, retries: int | None = None
) -> Callable[
[_r.RetrieverFunc[AgentContext, _r.P]],
_r.Retriever[AgentContext, _r.P],
]: ...

def retriever(
self, func: Union[_r.RetrieverFunc[AgentContext, _r.P], None] = None, /, *, retries: Union[int, None] = None
self, func: _r.RetrieverFunc[AgentContext, _r.P] | None = None, /, *, retries: int | None = None
) -> Any:
"""Decorator to register a retriever function."""
if func is None:
Expand All @@ -153,7 +153,7 @@ def retriever_decorator(func_: _r.RetrieverFunc[AgentContext, _r.P]) -> _r.Retri
return self._register_retriever(func, retries)

def _register_retriever(
self, func: _r.RetrieverFunc[AgentContext, _r.P], retries: Union[int, None]
self, func: _r.RetrieverFunc[AgentContext, _r.P], retries: int | None
) -> _r.Retriever[AgentContext, _r.P]:
retries_ = retries if retries is not None else self._default_retries
retriever = _r.Retriever[AgentContext, _r.P].build(func, retries_)
Expand Down
4 changes: 3 additions & 1 deletion pydantic_ai/messages.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations as _annotations

import json
from dataclasses import dataclass, field
from datetime import datetime
Expand Down Expand Up @@ -37,7 +39,7 @@ def llm_response(self) -> str:
class FunctionRetry:
function_id: str
function_name: str
content: Union[list[pydantic_core.ErrorDetails], str]
content: list[pydantic_core.ErrorDetails] | str
timestamp: datetime = field(default_factory=datetime.now)
role: Literal['function-retry'] = 'function-retry'

Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations as _annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Protocol, Union
from typing import TYPE_CHECKING, Protocol

from ..messages import LLMMessage, Message

Expand Down Expand Up @@ -35,7 +35,7 @@ async def request(self, messages: list[Message]) -> LLMMessage:
# TODO streamed response


def infer_model(model: Union[Model, KnownModelName]) -> Model:
def infer_model(model: Model | KnownModelName) -> Model:
"""Infer the model from the name."""
if isinstance(model, Model):
return model
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass
from datetime import datetime
from functools import cache
from typing import Literal, Union, assert_never
from typing import Literal, assert_never

from openai import AsyncClient
from openai.types import ChatModel
Expand All @@ -18,9 +20,7 @@


class OpenAIModel(Model):
def __init__(
self, model_name: ChatModel, *, api_key: Union[str, None] = None, client: Union[AsyncClient, None] = None
):
def __init__(self, model_name: ChatModel, *, api_key: str | None = None, client: AsyncClient | None = None):
if model_name not in ChatModel.__args__:
raise ValueError(f'Invalid model name: {model_name}')
self.model_name: ChatModel = model_name
Expand Down
6 changes: 4 additions & 2 deletions pydantic_ai/result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations as _annotations

from collections.abc import AsyncIterable
from dataclasses import dataclass
from typing import Any, Generic, TypedDict, TypeVar, Union
from typing import Any, Generic, TypedDict, TypeVar

from pydantic import TypeAdapter, ValidationError
from typing_extensions import Self
Expand Down Expand Up @@ -65,7 +67,7 @@ class ResultSchema(Generic[ResultData]):
_current_retry: int = 0

@classmethod
def build(cls, response_type: type[ResultData], name: str, description: str, retries: int) -> Union[Self, None]:
def build(cls, response_type: type[ResultData], name: str, description: str, retries: int) -> Self | None:
"""Build a ResponseModel dataclass from a response type."""
if response_type is str:
return None
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class Retriever(Generic[AgentContext, P]):
function: RetrieverFunc[AgentContext, P]
is_async: bool
takes_info: bool
single_arg_name: Union[str, None]
single_arg_name: str | None
positional_fields: list[str]
var_positional_field: Union[str, None]
var_positional_field: str | None
validator: SchemaValidator
json_schema: _utils.ObjectJsonSchema
max_retries: int
Expand Down Expand Up @@ -117,7 +117,7 @@ def _call_args(self, context: AgentContext, args_dict: dict[str, Any]) -> tuple[
return args, args_dict

def _on_error(
self, content: Union[list[pydantic_core.ErrorDetails], str], call_message: messages.FunctionCall
self, content: list[pydantic_core.ErrorDetails] | str, call_message: messages.FunctionCall
) -> messages.FunctionRetry:
self._current_retry += 1
if self._current_retry > self.max_retries:
Expand Down
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ convention = "google"
docstring-code-format = true
quote-style = "single"

[tool.ruff.lint.pyupgrade]
# Preserve types, even if a file imports `from __future__ import annotations`.
keep-runtime-typing = true

[tool.pyright]
typeCheckingMode = "strict"
reportUnnecessaryTypeIgnoreComment = true
Expand Down

0 comments on commit 11d567d

Please sign in to comment.