Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DocumentUrl and support document via BinaryContent #987

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Sequence
from dataclasses import dataclass, field, replace
from datetime import datetime
from mimetypes import guess_type
from typing import Annotated, Any, Literal, Union, cast, overload

import pydantic
Expand Down Expand Up @@ -80,8 +81,28 @@ def media_type(self) -> ImageMediaType:
raise ValueError(f'Unknown image file extension: {self.url}')


@dataclass
class DocumentUrl:
"""The URL of the document."""

url: str
"""The URL of the document."""

kind: Literal['document-url'] = 'document-url'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def media_type(self) -> str:
"""Return the media type of the document, based on the url."""
type_, _ = guess_type(self.url)
if type_ is None:
raise RuntimeError(f'Unknown document file extension: {self.url}')
return type_


AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
DocumentMediaType: TypeAlias = Literal['application/pdf', 'text/plain']


@dataclass
Expand All @@ -91,7 +112,7 @@ class BinaryContent:
data: bytes
"""The binary data."""

media_type: AudioMediaType | ImageMediaType | str
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
"""The media type of the binary data."""

kind: Literal['binary'] = 'binary'
Expand All @@ -107,6 +128,11 @@ def is_image(self) -> bool:
"""Return `True` if the media type is an image type."""
return self.media_type.startswith('image/')

@property
def is_document(self) -> bool:
"""Return `True` if the media type is a document type."""
return self.media_type in {'application/pdf', 'text/plain'}

@property
def audio_format(self) -> Literal['mp3', 'wav']:
"""Return the audio format given the media type."""
Expand All @@ -118,7 +144,7 @@ def audio_format(self) -> Literal['mp3', 'wav']:
raise ValueError(f'Unknown audio media type: {self.media_type}')


UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | BinaryContent'
UserContent: TypeAlias = 'str | ImageUrl | AudioUrl | DocumentUrl | BinaryContent'


@dataclass
Expand Down
27 changes: 26 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload

from anthropic.types import DocumentBlockParam
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelRequest,
Expand All @@ -41,10 +43,12 @@
try:
from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
from anthropic.types import (
Base64PDFSourceParam,
ImageBlockParam,
Message as AnthropicMessage,
MessageParam,
MetadataParam,
PlainTextSourceParam,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
Expand Down Expand Up @@ -322,7 +326,9 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
return system_prompt, anthropic_messages

@staticmethod
async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockParam | TextBlockParam]:
async def _map_user_prompt(
part: UserPromptPart,
) -> AsyncGenerator[ImageBlockParam | TextBlockParam | DocumentBlockParam]:
if isinstance(part.content, str):
yield TextBlockParam(text=part.content, type='text')
else:
Expand All @@ -344,6 +350,25 @@ async def _map_user_prompt(part: UserPromptPart) -> AsyncGenerator[ImageBlockPar
source={'data': io.BytesIO(response.content), 'media_type': 'image/jpeg', 'type': 'base64'},
type='image',
)
elif isinstance(item, DocumentUrl):
response = await cached_async_http_client().get(item.url)
response.raise_for_status()
if item.media_type == 'application/pdf':
yield DocumentBlockParam(
source=Base64PDFSourceParam(
data=io.BytesIO(response.content),
media_type=item.media_type,
type='base64',
),
type='document',
)
elif item.media_type == 'text/plain':
yield DocumentBlockParam(
source=PlainTextSourceParam(data=response.text, media_type=item.media_type, type='text'),
type='document',
)
else:
raise RuntimeError(f'Unsupported media type: {item.media_type}')
else:
raise RuntimeError(f'Unsupported content type: {type(item)}')

Expand Down
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..messages import (
AudioUrl,
BinaryContent,
DocumentUrl,
ImageUrl,
ModelMessage,
ModelRequest,
Expand Down Expand Up @@ -319,7 +320,7 @@ async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
elif isinstance(item, BinaryContent):
base64_encoded = base64.b64encode(item.data).decode('utf-8')
content.append(_GeminiInlineDataPart(data=base64_encoded, mime_type=item.media_type))
elif isinstance(item, (AudioUrl, ImageUrl)):
elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)):
try:
content.append(_GeminiFileDataData(file_uri=item.url, mime_type=item.media_type))
except ValueError:
Expand Down
Loading