diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index acb8a336..ab2dcb58 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -96,10 +96,12 @@ class BinaryContent: type: Literal['binary'] = 'binary' """Type identifier, this is available on all parts as a discriminator.""" + @property def is_audio(self) -> bool: """Return `True` if the media type is an audio type.""" return self.media_type.startswith('audio/') + @property def is_image(self) -> bool: """Return `True` if the media type is an image type.""" return self.media_type.startswith('image/') diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index b5829627..fd97daa8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import base64 from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -13,6 +14,8 @@ from .. import UnexpectedModelBehavior, _utils, usage from .._utils import guard_tool_call_id as _guard_tool_call_id from ..messages import ( + BinaryContent, + ImageUrl, ModelMessage, ModelRequest, ModelResponse, @@ -38,7 +41,7 @@ try: from groq import NOT_GIVEN, AsyncGroq, AsyncStream from groq.types import chat - from groq.types.chat import ChatCompletion, ChatCompletionChunk + from groq.types.chat.chat_completion_content_part_image_param import ImageURL except ImportError as _import_error: raise ImportError( 'Please install `groq` to use the Groq model, ' @@ -163,7 +166,7 @@ async def _completions_create( stream: Literal[True], model_settings: GroqModelSettings, model_request_parameters: ModelRequestParameters, - ) -> AsyncStream[ChatCompletionChunk]: + ) -> AsyncStream[chat.ChatCompletionChunk]: pass @overload @@ -182,7 +185,7 @@ async def _completions_create( stream: bool, model_settings: GroqModelSettings, model_request_parameters: ModelRequestParameters, - ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: + ) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]: tools = self._get_tools(model_request_parameters) # standalone function to make it easier to override if not tools: @@ -224,7 +227,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)) return ModelResponse(items, model_name=response.model, timestamp=timestamp) - async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse: + async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse: """Process a streamed response, and prepare a streaming response to return.""" peekable_response = _utils.PeekableAsyncStream(response) first_chunk = await peekable_response.peek() @@ -293,7 +296,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio if isinstance(part, SystemPromptPart): yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content) elif isinstance(part, UserPromptPart): - yield chat.ChatCompletionUserMessageParam(role='user', content=part.content) + yield _map_user_prompt(part) elif isinstance(part, ToolReturnPart): yield chat.ChatCompletionToolMessageParam( role='tool', @@ -316,7 +319,7 @@ class GroqStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for Groq models.""" _model_name: GroqModelName - _response: AsyncIterable[ChatCompletionChunk] + _response: AsyncIterable[chat.ChatCompletionChunk] _timestamp: datetime async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: @@ -355,9 +358,9 @@ def timestamp(self) -> datetime: return self._timestamp -def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage: +def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage: response_usage = None - if isinstance(completion, ChatCompletion): + if isinstance(completion, chat.ChatCompletion): response_usage = completion.usage elif completion.x_groq is not None: response_usage = completion.x_groq.usage @@ -370,3 +373,27 @@ def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage: response_tokens=response_usage.completion_tokens, total_tokens=response_usage.total_tokens, ) + + +def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam: + content: str | list[chat.ChatCompletionContentPartParam] + if isinstance(part.content, str): + content = part.content + else: + content = [] + for item in part.content: + if isinstance(item, str): + content.append(chat.ChatCompletionContentPartTextParam(text=item, type='text')) + elif isinstance(item, ImageUrl): + image_url = ImageURL(url=item.url) + content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) + elif isinstance(item, BinaryContent): + base64_encoded = base64.b64encode(item.data).decode('utf-8') + if item.is_image: + image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}') + content.append(chat.ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) + else: + raise ValueError('BinaryContent is not an image') + else: + raise ValueError(f'Unsupported content type: {type(item)}') + return chat.ChatCompletionUserMessageParam(role='user', content=content) diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index b641b7f5..211e221f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import base64 import os from collections.abc import AsyncIterable, AsyncIterator, Iterable from contextlib import asynccontextmanager @@ -15,6 +16,8 @@ from .. import UnexpectedModelBehavior, _utils from .._utils import now_utc as _now_utc from ..messages import ( + BinaryContent, + ImageUrl, ModelMessage, ModelRequest, ModelResponse, @@ -53,6 +56,8 @@ from mistralai.models import ( ChatCompletionResponse as MistralChatCompletionResponse, CompletionEvent as MistralCompletionEvent, + ImageURL as MistralImageURL, + ImageURLChunk as MistralImageURLChunk, Messages as MistralMessages, Tool as MistralTool, ToolCall as MistralToolCall, @@ -423,7 +428,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]: if isinstance(part, SystemPromptPart): yield MistralSystemMessage(content=part.content) elif isinstance(part, UserPromptPart): - yield MistralUserMessage(content=part.content) + yield _map_user_prompt(part) elif isinstance(part, ToolReturnPart): yield MistralToolMessage( tool_call_id=part.tool_call_id, @@ -620,3 +625,26 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None result = None return result + + +def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage: + content: str | list[MistralContentChunk] + if isinstance(part.content, str): + content = part.content + else: + content = [] + for item in part.content: + if isinstance(item, str): + content.append(MistralTextChunk(text=item)) + elif isinstance(item, ImageUrl): + content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url))) + elif isinstance(item, BinaryContent): + base64_encoded = base64.b64encode(item.data).decode('utf-8') + if item.is_image: + image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}') + content.append(MistralImageURLChunk(image_url=image_url, type='image_url')) + else: + raise ValueError('BinaryContent is not an image') + else: + raise ValueError(f'Unsupported content type: {type(item)}') + return MistralUserMessage(content=content) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 67242f8a..652135c2 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -421,10 +421,10 @@ def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessagePara content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) elif isinstance(item, BinaryContent): base64_encoded = base64.b64encode(item.data).decode('utf-8') - if item.is_image(): + if item.is_image: image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}') content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')) - elif item.is_audio(): + elif item.is_audio: audio_url = InputAudio(data=base64_encoded, format=item.audio_format) content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio_url, type='input_audio')) else: diff --git a/uv.lock b/uv.lock index e53976e9..51d0ce01 100644 --- a/uv.lock +++ b/uv.lock @@ -1143,7 +1143,7 @@ wheels = [ [[package]] name = "groq" -version = "0.12.0" +version = "0.18.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1153,9 +1153,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/87/3e/f3861c5386adf1145465b01ca548cb35d683d07dbc9a13a06d7d1352da6d/groq-0.12.0.tar.gz", hash = "sha256:569229e2dadfc428b0df3d2987407691a4e3bc035b5849a65ef4909514a4605e", size = 107684 } +sdist = { url = "https://files.pythonhosted.org/packages/40/8c/e72c164474a88dfed6c7327ad53cb87ff11566b74b3a76d41dc7b94fc51c/groq-0.18.0.tar.gz", hash = "sha256:8e2ccfea406d68b3525af4b7c0e321fcb3d2a73fc60bb70b4156e6cd88c72f03", size = 117322 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/13/8489c3df2047abed5acb4d3b188eedda7dc42d6c1fa4d853c22192de1115/groq-0.12.0-py3-none-any.whl", hash = "sha256:e8aa1529f82a01b2d15394b7ea242af9ee9387f65bdd1b91ce9a10f5a911dac1", size = 108852 }, + { url = "https://files.pythonhosted.org/packages/b0/6c/5a53d632b44ef7655ac8d9b34432e13160917f9307c94b1467efd34e336e/groq-0.18.0-py3-none-any.whl", hash = "sha256:81d5ac00057a45d8ce559d23ab5d3b3893011d1f12c35187ab35a9182d826ea6", size = 121911 }, ] [[package]] @@ -1579,7 +1579,7 @@ wheels = [ [[package]] name = "mistralai" -version = "1.2.5" +version = "1.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "eval-type-backport" }, @@ -1589,9 +1589,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "typing-inspect" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/55/34/95efe73fd3cd0d5f3f0198b2bfc570dfe485aa5045100aa97fa176dcb653/mistralai-1.2.5.tar.gz", hash = "sha256:05d4130f79704e3b19c0b6320944a348547879fce6894feeb72d9e9d0ee65151", size = 132348 } +sdist = { url = "https://files.pythonhosted.org/packages/16/9d/aba193fdfe0fc7403efa380189143d965becfb1bc7df3230e5c7664f8c53/mistralai-1.5.0.tar.gz", hash = "sha256:fd94bc93bc25aad9c6dd8005b1a0bc4ba1250c6b3fbf855a49936989cc6e5c0d", size = 131647 } wheels = [ - { url = "https://files.pythonhosted.org/packages/85/08/279a3afe0b319c283ae6d1ee8d42c606855093579e93e51cce2f6ced91a7/mistralai-1.2.5-py3-none-any.whl", hash = "sha256:5f0ef2680ead0329569111face1bf2ff7c67c454d43aa0e21324a8faf6c3ab22", size = 260045 }, + { url = "https://files.pythonhosted.org/packages/58/e7/7147c75c383a975c58c33f8e7ee7dbbb0e7390fbcb1ecd321f63e4c73efd/mistralai-1.5.0-py3-none-any.whl", hash = "sha256:9372537719f87bd6f9feef4747d0bf1f4fbe971f8c02945ca4b4bf3c94571c97", size = 271559 }, ] [[package]]