Support multimodal inputs (#971)
Co-authored-by: David Montague <35119617+dmontagu@users.noreply.github.com>
Kludex and dmontagu authored Feb 25, 2025
1 parent 4cd2603 commit 96be03d
Showing 40 changed files with 8,814 additions and 168 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -45,5 +45,6 @@ repos:
     rev: v2.3.0
     hooks:
       - id: codespell
+        args: ['--skip', 'tests/models/cassettes/*']
         additional_dependencies:
           - tomli
55 changes: 55 additions & 0 deletions docs/input.md
@@ -0,0 +1,55 @@
# Image and Audio Input

Some LLMs are now capable of understanding both audio and image content.

## Image Input

!!! info
    Some models do not support image input. Please check the model's documentation to confirm whether it supports image input.

If you have a direct URL for the image, you can use [`ImageUrl`][pydantic_ai.ImageUrl]:

```py {title="main.py" test="skip" lint="skip"}
from pydantic_ai import Agent, ImageUrl

image_url = ImageUrl(url='https://iili.io/3Hs4FMg.png')

agent = Agent(model='openai:gpt-4o')
result = agent.run_sync(
    [
        'What company is this logo from?',
        image_url,
    ]
)
print(result.data)
#> This is the logo for Pydantic, a data validation and settings management library in Python.
```

If you have the image locally, you can also use [`BinaryContent`][pydantic_ai.BinaryContent]:

```py {title="main.py" test="skip" lint="skip"}
import httpx

from pydantic_ai import Agent, BinaryContent

image_response = httpx.get('https://iili.io/3Hs4FMg.png') # Pydantic logo

agent = Agent(model='openai:gpt-4o')
result = agent.run_sync(
    [
        'What company is this logo from?',
        BinaryContent(data=image_response.content, media_type='image/png'),  # (1)!
    ]
)
print(result.data)
#> This is the logo for Pydantic, a data validation and settings management library in Python.
```

1. To ensure the example is runnable we download this image from the web, but you can also use `Path().read_bytes()` to read a local file's contents.
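For example, reading a local image with `Path.read_bytes()` might look like the following sketch; `logo.png` is a hypothetical file name standing in for any image on disk:

```py {title="local_image.py" test="skip" lint="skip"}
from pathlib import Path

from pydantic_ai import Agent, BinaryContent

agent = Agent(model='openai:gpt-4o')
result = agent.run_sync(
    [
        'What company is this logo from?',
        # 'logo.png' is a hypothetical local file; any image on disk works here
        BinaryContent(data=Path('logo.png').read_bytes(), media_type='image/png'),
    ]
)
print(result.data)
```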

## Audio Input

!!! info
    Some models do not support audio input. Please check the model's documentation to confirm whether it supports audio input.

You can provide audio input using either [`AudioUrl`][pydantic_ai.AudioUrl] or [`BinaryContent`][pydantic_ai.BinaryContent]. The process is analogous to the examples above.
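For example, both styles of audio input might look like the following sketch; the URL and file name are hypothetical, and `openai:gpt-4o-audio-preview` is assumed here to be an audio-capable model:

```py {title="audio_input.py" test="skip" lint="skip"}
from pathlib import Path

from pydantic_ai import Agent, AudioUrl, BinaryContent

agent = Agent(model='openai:gpt-4o-audio-preview')  # assumption: an audio-capable model

# Audio referenced by URL (hypothetical address)
result = agent.run_sync(
    [
        'Summarize what is said in this recording.',
        AudioUrl(url='https://example.com/dialogue.mp3'),
    ]
)
print(result.data)

# Audio passed as local bytes ('dialogue.mp3' is a hypothetical file)
result = agent.run_sync(
    [
        'Summarize what is said in this recording.',
        BinaryContent(data=Path('dialogue.mp3').read_bytes(), media_type='audio/mpeg'),
    ]
)
print(result.data)
```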
1 change: 1 addition & 0 deletions examples/pydantic_ai_examples/chat_app.py
@@ -89,6 +89,7 @@ def to_chat_message(m: ModelMessage) -> ChatMessage:
     first_part = m.parts[0]
     if isinstance(m, ModelRequest):
         if isinstance(first_part, UserPromptPart):
+            assert isinstance(first_part.content, str)
             return {
                 'role': 'user',
                 'timestamp': first_part.timestamp.isoformat(),
101 changes: 51 additions & 50 deletions mkdocs.yml
@@ -16,52 +16,53 @@ nav:
   - contributing.md
   - troubleshooting.md
   - Documentation:
-    - agents.md
-    - models.md
-    - dependencies.md
-    - tools.md
-    - results.md
-    - message-history.md
-    - testing-evals.md
-    - logfire.md
-    - multi-agent-applications.md
-    - graph.md
+    - agents.md
+    - models.md
+    - dependencies.md
+    - tools.md
+    - results.md
+    - message-history.md
+    - testing-evals.md
+    - logfire.md
+    - multi-agent-applications.md
+    - graph.md
+    - input.md
   - Examples:
-    - examples/index.md
-    - examples/pydantic-model.md
-    - examples/weather-agent.md
-    - examples/bank-support.md
-    - examples/sql-gen.md
-    - examples/flight-booking.md
-    - examples/rag.md
-    - examples/stream-markdown.md
-    - examples/stream-whales.md
-    - examples/chat-app.md
-    - examples/question-graph.md
+    - examples/index.md
+    - examples/pydantic-model.md
+    - examples/weather-agent.md
+    - examples/bank-support.md
+    - examples/sql-gen.md
+    - examples/flight-booking.md
+    - examples/rag.md
+    - examples/stream-markdown.md
+    - examples/stream-whales.md
+    - examples/chat-app.md
+    - examples/question-graph.md
   - API Reference:
-    - api/agent.md
-    - api/tools.md
-    - api/result.md
-    - api/messages.md
-    - api/exceptions.md
-    - api/settings.md
-    - api/usage.md
-    - api/format_as_xml.md
-    - api/models/base.md
-    - api/models/openai.md
-    - api/models/anthropic.md
-    - api/models/cohere.md
-    - api/models/gemini.md
-    - api/models/vertexai.md
-    - api/models/groq.md
-    - api/models/mistral.md
-    - api/models/test.md
-    - api/models/function.md
-    - api/pydantic_graph/graph.md
-    - api/pydantic_graph/nodes.md
-    - api/pydantic_graph/state.md
-    - api/pydantic_graph/mermaid.md
-    - api/pydantic_graph/exceptions.md
+    - api/agent.md
+    - api/tools.md
+    - api/result.md
+    - api/messages.md
+    - api/exceptions.md
+    - api/settings.md
+    - api/usage.md
+    - api/format_as_xml.md
+    - api/models/base.md
+    - api/models/openai.md
+    - api/models/anthropic.md
+    - api/models/cohere.md
+    - api/models/gemini.md
+    - api/models/vertexai.md
+    - api/models/groq.md
+    - api/models/mistral.md
+    - api/models/test.md
+    - api/models/function.md
+    - api/pydantic_graph/graph.md
+    - api/pydantic_graph/nodes.md
+    - api/pydantic_graph/state.md
+    - api/pydantic_graph/mermaid.md
+    - api/pydantic_graph/exceptions.md

extra:
# hide the "Made with Material for MkDocs" message
@@ -100,12 +101,12 @@ theme:
     - content.code.copy
     - content.code.select
    - navigation.path
-    # - navigation.expand
+    # - navigation.expand
     - navigation.indexes
     - navigation.sections
     - navigation.tracking
     - toc.follow
-    # - navigation.tabs # don't use navbar tabs
+    # - navigation.tabs # don't use navbar tabs
   logo: "img/logo-white.svg"
   favicon: "favicon.ico"

@@ -151,7 +152,7 @@ markdown_extensions:
       emoji_generator: !!python/name:material.extensions.emoji.to_svg
       options:
         custom_icons:
-          - docs/.overrides/.icons
+          - docs/.overrides/.icons
   - pymdownx.tabbed:
       alternate_style: true
   - pymdownx.tasklist:
@@ -190,6 +191,6 @@ plugins:
   # waiting for https://github.com/encode/httpx/discussions/3091#discussioncomment-11205594

 hooks:
-  - 'docs/.hooks/main.py'
-  - 'docs/.hooks/build_llms_txt.py'
-  - 'docs/.hooks/algolia.py'
+  - "docs/.hooks/main.py"
+  - "docs/.hooks/build_llms_txt.py"
+  - "docs/.hooks/algolia.py"
14 changes: 11 additions & 3 deletions pydantic_ai_slim/pydantic_ai/__init__.py
@@ -2,22 +2,30 @@

 from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
 from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
+from .messages import AudioUrl, BinaryContent, ImageUrl
 from .tools import RunContext, Tool

 __all__ = (
+    '__version__',
+    # agent
     'Agent',
     'EndStrategy',
     'HandleResponseNode',
     'ModelRequestNode',
     'UserPromptNode',
     'capture_run_messages',
-    'RunContext',
-    'Tool',
+    # exceptions
     'AgentRunError',
     'ModelRetry',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',
-    '__version__',
+    # messages
+    'ImageUrl',
+    'AudioUrl',
+    'BinaryContent',
+    # tools
+    'Tool',
+    'RunContext',
 )
 __version__ = version('pydantic_ai_slim')
11 changes: 7 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -3,7 +3,7 @@
 import asyncio
 import dataclasses
 from abc import ABC
-from collections.abc import AsyncIterator, Iterator
+from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
@@ -89,7 +89,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     user_deps: DepsT

-    prompt: str
+    prompt: str | Sequence[_messages.UserContent]
     new_message_index: int

     model: models.Model
@@ -109,7 +109,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

 @dataclasses.dataclass
 class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], NodeRunEndT], ABC):
-    user_prompt: str
+    user_prompt: str | Sequence[_messages.UserContent]

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
@@ -135,7 +135,10 @@ async def _get_first_message(
         return next_message

     async def _prepare_messages(
-        self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[DepsT]
+        self,
+        user_prompt: str | Sequence[_messages.UserContent],
+        message_history: list[_messages.ModelMessage] | None,
+        run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
         try:
             ctx_messages = get_captured_run_messages()
20 changes: 10 additions & 10 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -220,7 +220,7 @@ def __init__(
     @overload
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -235,7 +235,7 @@ async def run(
     @overload
     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -249,7 +249,7 @@

     async def run(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -313,7 +313,7 @@ async def main():
     @contextmanager
     def iter(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -466,7 +466,7 @@ async def main():
     @overload
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
@@ -480,7 +480,7 @@ def run_sync(
     @overload
     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -494,7 +494,7 @@

     def run_sync(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -555,7 +555,7 @@ def run_sync(
     @overload
     def run_stream(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -570,7 +570,7 @@ def run_stream(
     @overload
     def run_stream(
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -585,7 +585,7 @@ def run_stream(
     @asynccontextmanager
     async def run_stream(  # noqa C901
         self,
-        user_prompt: str,
+        user_prompt: str | Sequence[_messages.UserContent],
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
