Skip to content

Commit

Permalink
Add support for ImageUrl for Anthropic
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Feb 21, 2025
1 parent 70a34b2 commit c5fea8b
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,14 @@ def __init__(
client to use, if provided, `api_key` and `http_client` must be `None`.
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
"""
self._http_client = http_client or cached_async_http_client()
self._model_name = model_name
if anthropic_client is not None:
assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
self.client = anthropic_client
elif http_client is not None:
self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
else:
self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
self.client = AsyncAnthropic(api_key=api_key, http_client=self._http_client)

async def request(
self,
Expand Down Expand Up @@ -218,7 +217,7 @@ async def _messages_create(
if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

system_prompt, anthropic_messages = self._map_message(messages)
system_prompt, anthropic_messages = await self._map_message(messages)

return await self.client.messages.create(
max_tokens=model_settings.get('max_tokens', 1024),
Expand Down Expand Up @@ -270,7 +269,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
return tools

def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
system_prompt: str = ''
anthropic_messages: list[MessageParam] = []
Expand All @@ -280,7 +279,7 @@ def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageP
if isinstance(part, SystemPromptPart):
system_prompt += part.content
elif isinstance(part, UserPromptPart):
anthropic_messages.append(_map_user_prompt(part))
anthropic_messages.append(await _map_user_prompt(part, self._http_client))
elif isinstance(part, ToolReturnPart):
anthropic_messages.append(
MessageParam(
Expand Down Expand Up @@ -372,7 +371,7 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
)


def _map_user_prompt(part: UserPromptPart) -> MessageParam:
async def _map_user_prompt(part: UserPromptPart, http_client: AsyncHTTPClient) -> MessageParam:
part_content: str | list[ImageBlockParam | TextBlockParam]
if isinstance(part.content, str):
part_content = part.content
Expand All @@ -395,7 +394,19 @@ def _map_user_prompt(part: UserPromptPart) -> MessageParam:
)
part_content.append(image_block)
elif isinstance(item, ImageUrl):
raise ValueError('ImageUrl is not supported in Anthropic')
response = await http_client.get(item.url)
response.raise_for_status()
image_block = ImageBlockParam(
source={
'data': io.BytesIO(response.content),
'media_type': 'image/jpeg',
'type': 'base64',
},
type='image',
)
part_content.append(image_block)
else:
raise ValueError(f'Unsupported content type: {type(item)}')
return MessageParam(role='user', content=part_content)


Expand Down

0 comments on commit c5fea8b

Please sign in to comment.