Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multimodal input #961

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions examples/pydantic_ai_examples/chat_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,12 @@ def to_chat_message(m: ModelMessage) -> ChatMessage:
first_part = m.parts[0]
if isinstance(m, ModelRequest):
if isinstance(first_part, UserPromptPart):
return {
'role': 'user',
'timestamp': first_part.timestamp.isoformat(),
'content': first_part.content,
}
if isinstance(first_part.content, str):
return {
'role': 'user',
'timestamp': first_part.timestamp.isoformat(),
'content': first_part.content,
}
elif isinstance(m, ModelResponse):
if isinstance(first_part, TextPart):
return {
Expand Down
18 changes: 18 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from rich.pretty import pprint

from pydantic_ai import Agent
from pydantic_ai.messages import ImageUrl

# NOTE(review): demo script at the repo root — probably should live under
# examples/ (or be dropped) before merge.

# Public sample image used by the Gemini docs.
image_url = 'https://goo.gle/instrument-img'

agent = Agent(model='google-gla:gemini-2.0-flash-exp')

# Multimodal prompt: a sequence mixing plain text with an image URL.
output = agent.run_sync(
    [
        "What's in the image?",
        ImageUrl(url=image_url),
    ]
)
pprint(output)
41 changes: 41 additions & 0 deletions multimodal.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Multimodal Support

## Kind

### Image

- **Claude** supports only base64 encoded images.
- https://docs.anthropic.com/en/docs/build-with-claude/vision#example-multiple-images
- **Groq** supports both url and base64 encoded images.
- https://console.groq.com/docs/vision
- **Mistral** supports both url and base64 encoded images.
- https://docs.mistral.ai/capabilities/vision/
- **OpenAI** supports both url and base64 encoded images.
- https://platform.openai.com/docs/guides/vision

### Audio

- **OpenAI** supports base64 encoded audio.
- https://platform.openai.com/docs/guides/audio?example=audio-in

### Video

- **VertexAI (Gemini)** supports urls.
- https://ai.google.dev/gemini-api/docs/vision

### Documents

#### PDF

- **Claude** supports only base64 encoded PDFs.
- https://docs.anthropic.com/en/docs/build-with-claude/pdf-support

```python
@dataclass
class DocumentPart:  # draft sketch: base64 PDF part (Claude-style document support)

data: str
"""The base64 encoded data of the document part."""

media_type: Literal['application/pdf']
```
9 changes: 6 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

user_deps: DepsT

prompt: str
prompt: str | _messages.ImageUrl | Sequence[str | _messages.ImageUrl]
new_message_index: int

model: models.Model
Expand All @@ -109,7 +109,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

@dataclasses.dataclass
class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
user_prompt: str
user_prompt: str | Sequence[_messages.UserContent]

system_prompts: tuple[str, ...]
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
Expand All @@ -135,7 +135,10 @@ async def _get_first_message(
return next_message

async def _prepare_messages(
self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
self,
user_prompt: str | Sequence[_messages.UserContent],
message_history: list[_messages.ModelMessage] | None,
run_context: RunContext[DepsT],
) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
try:
ctx_messages = get_captured_run_messages()
Expand Down
17 changes: 10 additions & 7 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def __init__(
@overload
async def run(
self,
user_prompt: str,
user_prompt: str | Sequence[_messages.UserContent],
*,
result_type: None = None,
message_history: list[_messages.ModelMessage] | None = None,
Expand All @@ -235,7 +235,7 @@ async def run(
@overload
async def run(
self,
user_prompt: str,
user_prompt: str | Sequence[_messages.UserContent],
*,
result_type: type[RunResultDataT],
message_history: list[_messages.ModelMessage] | None = None,
Expand All @@ -249,7 +249,7 @@ async def run(

async def run(
self,
user_prompt: str,
user_prompt: str | Sequence[_messages.UserContent],
*,
result_type: type[RunResultDataT] | None = None,
message_history: list[_messages.ModelMessage] | None = None,
Expand Down Expand Up @@ -466,7 +466,7 @@ async def main():
@overload
def run_sync(
self,
user_prompt: str,
user_prompt: str | Sequence[_messages.UserContent],
*,
message_history: list[_messages.ModelMessage] | None = None,
model: models.Model | models.KnownModelName | None = None,
Expand All @@ -480,7 +480,7 @@ def run_sync(
@overload
def run_sync(
self,
user_prompt: str,
user_prompt: str | Sequence[_messages.UserContent],
*,
result_type: type[RunResultDataT] | None,
message_history: list[_messages.ModelMessage] | None = None,
Expand All @@ -494,7 +494,7 @@ def run_sync(

def run_sync(
self,
user_prompt: str,
user_prompt: str | Sequence[_messages.UserContent],
*,
result_type: type[RunResultDataT] | None = None,
message_history: list[_messages.ModelMessage] | None = None,
Expand Down Expand Up @@ -538,7 +538,8 @@ def run_sync(
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
return asyncio.get_event_loop().run_until_complete(
event_loop = asyncio.new_event_loop()
result = event_loop.run_until_complete(
self.run(
user_prompt,
result_type=result_type,
Expand All @@ -551,6 +552,8 @@ def run_sync(
infer_name=False,
)
)
event_loop.close()
return result

@overload
def run_stream(
Expand Down
91 changes: 90 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations as _annotations

import uuid
from collections.abc import Sequence
from dataclasses import dataclass, field, replace
from datetime import datetime
from typing import Annotated, Any, Literal, Union, cast, overload

import pydantic
import pydantic_core
from typing_extensions import TypeAlias

from ._utils import now_utc as _now_utc
from .exceptions import UnexpectedModelBehavior
Expand All @@ -32,6 +34,93 @@ class SystemPromptPart:
"""Part type identifier, this is available on all parts as a discriminator."""


@dataclass
class AudioUrl:
    """A URL to an audio file."""

    url: str
    """The URL of the audio file."""

    # Hyphenated to match `ImageUrl`'s 'image-url' discriminator style
    # (was 'audio_url', the only underscore-style discriminator in this module).
    type: Literal['audio-url'] = 'audio-url'
    """Type identifier, this is available on all parts as a discriminator."""

    @property
    def media_type(self) -> AudioMediaType:
        """Return the media type of the audio file, based on the url.

        The extension check is case-insensitive, so e.g. `SONG.MP3` is recognized.

        Raises:
            ValueError: If the URL does not end in a recognized audio extension.
        """
        lowered = self.url.lower()
        if lowered.endswith('.mp3'):
            return 'audio/mpeg'
        elif lowered.endswith('.wav'):
            return 'audio/wav'
        else:
            raise ValueError(f'Unknown audio file extension: {self.url}')


@dataclass
class ImageUrl:
    """A URL to an image."""

    url: str
    """The URL of the image."""

    type: Literal['image-url'] = 'image-url'
    """Type identifier, this is available on all parts as a discriminator."""

    @property
    def media_type(self) -> ImageMediaType:
        """Return the media type of the image, based on the url.

        The extension check is case-insensitive and accepts the common
        `.jpeg` spelling as well as `.jpg`.

        Raises:
            ValueError: If the URL does not end in a recognized image extension.
        """
        lowered = self.url.lower()
        for ext, media_type in (
            ('.jpg', 'image/jpeg'),
            ('.jpeg', 'image/jpeg'),  # same format as .jpg
            ('.png', 'image/png'),
            ('.gif', 'image/gif'),
            ('.webp', 'image/webp'),
        ):
            if lowered.endswith(ext):
                return media_type
        raise ValueError(f'Unknown image file extension: {self.url}')


# Closed sets of MIME types handled by this module's media classes
# (`AudioUrl.media_type`, `ImageUrl.media_type`, `BinaryContent.media_type`).
# Defined after the classes that reference them; this is safe because the
# file's `from __future__ import annotations` makes annotations lazy.
AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']


@dataclass
class BinaryContent:
    """Binary content, e.g. an audio or image file."""

    data: bytes
    """The binary data."""

    media_type: AudioMediaType | ImageMediaType | str
    """The media type of the binary data."""

    type: Literal['binary'] = 'binary'
    """Type identifier, this is available on all parts as a discriminator."""

    @property
    def is_audio(self) -> bool:
        """Whether the media type denotes audio content (`audio/*`)."""
        return self.media_type.startswith('audio/')

    @property
    def is_image(self) -> bool:
        """Whether the media type denotes image content (`image/*`)."""
        return self.media_type.startswith('image/')

    @property
    def audio_format(self) -> Literal['mp3', 'wav']:
        """Return the audio format given the media type.

        Raises:
            ValueError: If the media type is not a known audio type.
        """
        fmt = {'audio/mpeg': 'mp3', 'audio/wav': 'wav'}.get(self.media_type)
        if fmt is None:
            raise ValueError(f'Unknown audio media type: {self.media_type}')
        return fmt


UserContent: TypeAlias = str | ImageUrl | AudioUrl | BinaryContent


@dataclass
class UserPromptPart:
"""A user prompt, generally written by the end user.
Expand All @@ -40,7 +129,7 @@ class UserPromptPart:
[`Agent.run_sync`][pydantic_ai.Agent.run_sync], and [`Agent.run_stream`][pydantic_ai.Agent.run_stream].
"""

content: str
content: str | Sequence[UserContent]
"""The content of the prompt."""

timestamp: datetime = field(default_factory=_now_utc)
Expand Down
49 changes: 46 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from __future__ import annotations as _annotations

import io
from collections.abc import AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload

from anthropic.types import ImageBlockParam
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never

from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
BinaryContent,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -214,7 +218,7 @@ async def _messages_create(
if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

system_prompt, anthropic_messages = self._map_message(messages)
system_prompt, anthropic_messages = await self._map_message(messages)

return await self.client.messages.create(
max_tokens=model_settings.get('max_tokens', 1024),
Expand Down Expand Up @@ -266,7 +270,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
return tools

def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
system_prompt: str = ''
anthropic_messages: list[MessageParam] = []
Expand All @@ -276,7 +280,7 @@ def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageP
if isinstance(part, SystemPromptPart):
system_prompt += part.content
elif isinstance(part, UserPromptPart):
anthropic_messages.append(MessageParam(role='user', content=part.content))
anthropic_messages.append(await _map_user_prompt(part))
elif isinstance(part, ToolReturnPart):
anthropic_messages.append(
MessageParam(
Expand Down Expand Up @@ -368,6 +372,45 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
)


async def _map_user_prompt(part: UserPromptPart) -> MessageParam:
    """Map a `UserPromptPart` to an Anthropic `MessageParam`.

    Plain-string content passes through unchanged; otherwise each item in the
    sequence is converted to the matching Anthropic content block. Image URLs
    are downloaded and inlined, since the Anthropic API only accepts base64
    image sources.

    Raises:
        ValueError: If an item has an unsupported media type or content type.
    """
    import base64  # local import: only needed for multimodal content

    part_content: str | list[ImageBlockParam | TextBlockParam]
    if isinstance(part.content, str):
        part_content = part.content
    else:
        part_content = []
        for item in part.content:
            if isinstance(item, str):
                part_content.append(TextBlockParam(text=item, type='text'))
            elif isinstance(item, BinaryContent):
                if item.media_type not in ('image/jpeg', 'image/png', 'image/gif', 'image/webp'):
                    # TODO(Marcelo): Replace for a better exception?
                    raise ValueError('Unsupported media type for image')
                image_block = ImageBlockParam(
                    source={
                        # Pass an explicit base64 string — unambiguous across
                        # anthropic SDK versions (a raw bytes buffer is not).
                        'data': base64.b64encode(item.data).decode('ascii'),
                        'media_type': item.media_type,
                        'type': 'base64',
                    },
                    type='image',
                )
                part_content.append(image_block)
            elif isinstance(item, ImageUrl):
                response = await cached_async_http_client().get(item.url)
                response.raise_for_status()
                # Prefer the server-reported content type; fall back to the
                # URL-extension guess instead of assuming JPEG for everything.
                media_type = response.headers.get('content-type', '').split(';')[0] or item.media_type
                image_block = ImageBlockParam(
                    source={
                        'data': base64.b64encode(response.content).decode('ascii'),
                        'media_type': media_type,
                        'type': 'base64',
                    },
                    type='image',
                )
                part_content.append(image_block)
            else:
                raise ValueError(f'Unsupported content type: {type(item)}')
    return MessageParam(role='user', content=part_content)


@dataclass
class AnthropicStreamedResponse(StreamedResponse):
"""Implementation of `StreamedResponse` for Anthropic models."""
Expand Down
Loading
Loading