diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 661ef089..9e23b4c5 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -4,6 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass, field, replace from datetime import datetime +from mimetypes import guess_type from typing import Annotated, Any, Literal, Union, cast, overload import pydantic @@ -80,8 +81,28 @@ def media_type(self) -> ImageMediaType: raise ValueError(f'Unknown image file extension: {self.url}') +@dataclass +class DocumentUrl: + """The URL of the document.""" + + url: str + """The URL of the document.""" + + kind: Literal['document-url'] = 'document-url' + """Type identifier, this is available on all parts as a discriminator.""" + + @property + def media_type(self) -> str: + """Return the media type of the document, based on the url.""" + type_, _ = guess_type(self.url) + if type_ is None: + raise RuntimeError(f'Unknown document file extension: {self.url}') + return type_ + + AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg'] ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp'] +DocumentMediaType: TypeAlias = Literal['application/pdf', 'text/plain'] @dataclass @@ -91,7 +112,7 @@ class BinaryContent: data: bytes """The binary data.""" - media_type: AudioMediaType | ImageMediaType | str + media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str """The media type of the binary data.""" kind: Literal['binary'] = 'binary' @@ -107,6 +128,11 @@ def is_image(self) -> bool: """Return `True` if the media type is an image type.""" return self.media_type.startswith('image/') + @property + def is_document(self) -> bool: + """Return `True` if the media type is a document type.""" + return self.media_type in {'application/pdf', 'text/plain'} + @property def audio_format(self) -> Literal['mp3', 'wav']: """Return the audio format given the media type.""" @@ -118,7 +144,7 @@ def audio_format(self) -> Literal['mp3', 'wav']: raise ValueError(f'Unknown audio media type: {self.media_type}') -UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent' +UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | BinaryContent' @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index a71bddbd..928c2e77 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -8,6 +8,7 @@ from json import JSONDecodeError, loads as json_loads from typing import Any, Literal, Union, cast, overload +from anthropic.types import DocumentBlockParam from httpx import AsyncClient as AsyncHTTPClient from typing_extensions import assert_never @@ -15,6 +16,7 @@ from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( BinaryContent, + DocumentUrl, ImageUrl, ModelMessage, ModelRequest, @@ -41,10 +43,12 @@ try: from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream from anthropic.types import ( + Base64PDFSourceParam, ImageBlockParam, Message as AnthropicMessage, MessageParam, MetadataParam, + PlainTextSourceParam, RawContentBlockDeltaEvent, RawContentBlockStartEvent, RawContentBlockStopEvent, @@ -322,7 +326,9 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me return system_prompt, anthropic_messages @staticmethod - async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]: + async def _map_user_prompt( + part: UserPromptPart, + ) -> AsyncGenerator[ImageBlockParam | TextBlockParam | DocumentBlockParam]: if isinstance(part.content, str): yield TextBlockParam(text=part.content, type='text') else: @@ -344,6 +350,25 @@ async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockPar source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'}, type='image', ) + elif isinstance(item, DocumentUrl): + response = await cached_async_http_client().get(item.url) + response.raise_for_status() + if item.media_type == 'application/pdf': + yield DocumentBlockParam( + source=Base64PDFSourceParam( + data=io.BytesIO(response.content), + media_type=item.media_type, + type='base64', + ), + type='document', + ) + elif item.media_type == 'text/plain': + yield DocumentBlockParam( + source=PlainTextSourceParam(data=response.text, media_type=item.media_type, type='text'), + type='document', + ) + else: + raise RuntimeError(f'Unsupported media type: {item.media_type}') else: raise RuntimeError(f'Unsupported content type: {type(item)}') diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 5d843e08..b89c9ddc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -19,6 +19,7 @@ from ..messages import ( AudioUrl, BinaryContent, + DocumentUrl, ImageUrl, ModelMessage, ModelRequest, @@ -319,7 +320,7 @@ async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]: elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type)) - elif isinstance(item, (AudioUrl, ImageUrl)): + elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)): try: content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type)) except ValueError: