Skip to content

Commit

Permalink
Add support for Groq
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Feb 21, 2025
1 parent c5fea8b commit 1dfd9b1
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 17 deletions.
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ class BinaryContent:
type: Literal['binary'] = 'binary'
"""Type identifier, this is available on all parts as a discriminator."""

@property
def is_audio(self) -> bool:
    """Whether this binary content holds audio data (media type `audio/*`)."""
    return self.media_type[:6] == 'audio/'

@property
def is_image(self) -> bool:
    """Whether this binary content holds image data (media type `image/*`)."""
    return self.media_type[:6] == 'image/'
Expand Down
43 changes: 35 additions & 8 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import base64
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
Expand All @@ -13,6 +14,8 @@
from .. import UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
BinaryContent,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
Expand All @@ -38,7 +41,7 @@
try:
from groq import NOT_GIVEN, AsyncGroq, AsyncStream
from groq.types import chat
from groq.types.chat import ChatCompletion, ChatCompletionChunk
from groq.types.chat.chat_completion_content_part_image_param import ImageURL
except ImportError as _import_error:
raise ImportError(
'Please install `groq` to use the Groq model, '
Expand Down Expand Up @@ -163,7 +166,7 @@ async def _completions_create(
stream: Literal[True],
model_settings: GroqModelSettings,
model_request_parameters: ModelRequestParameters,
) -> AsyncStream[ChatCompletionChunk]:
) -> AsyncStream[chat.ChatCompletionChunk]:
pass

@overload
Expand All @@ -182,7 +185,7 @@ async def _completions_create(
stream: bool,
model_settings: GroqModelSettings,
model_request_parameters: ModelRequestParameters,
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
) -> chat.ChatCompletion | AsyncStream[chat.ChatCompletionChunk]:
tools = self._get_tools(model_request_parameters)
# standalone function to make it easier to override
if not tools:
Expand Down Expand Up @@ -224,7 +227,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
return ModelResponse(items, model_name=response.model, timestamp=timestamp)

async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
peekable_response = _utils.PeekableAsyncStream(response)
first_chunk = await peekable_response.peek()
Expand Down Expand Up @@ -293,7 +296,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio
if isinstance(part, SystemPromptPart):
yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
elif isinstance(part, UserPromptPart):
yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
yield _map_user_prompt(part)
elif isinstance(part, ToolReturnPart):
yield chat.ChatCompletionToolMessageParam(
role='tool',
Expand All @@ -316,7 +319,7 @@ class GroqStreamedResponse(StreamedResponse):
"""Implementation of `StreamedResponse` for Groq models."""

_model_name: GroqModelName
_response: AsyncIterable[ChatCompletionChunk]
_response: AsyncIterable[chat.ChatCompletionChunk]
_timestamp: datetime

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
Expand Down Expand Up @@ -355,9 +358,9 @@ def timestamp(self) -> datetime:
return self._timestamp


def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> usage.Usage:
response_usage = None
if isinstance(completion, ChatCompletion):
if isinstance(completion, chat.ChatCompletion):
response_usage = completion.usage
elif completion.x_groq is not None:
response_usage = completion.x_groq.usage
Expand All @@ -370,3 +373,27 @@ def _map_usage(completion: ChatCompletionChunk | ChatCompletion) -> usage.Usage:
response_tokens=response_usage.completion_tokens,
total_tokens=response_usage.total_tokens,
)


def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
    """Convert a `UserPromptPart` into a Groq chat `user` message.

    A plain-string prompt is forwarded unchanged; a sequence prompt is mapped
    item by item to Groq content parts. Groq accepts only text and image
    content here, so any other item raises.

    Raises:
        ValueError: If an item is non-image binary content or of an
            unsupported type.
    """
    if isinstance(part.content, str):
        return chat.ChatCompletionUserMessageParam(role='user', content=part.content)

    pieces: list[chat.ChatCompletionContentPartParam] = []
    for piece in part.content:
        if isinstance(piece, str):
            pieces.append(chat.ChatCompletionContentPartTextParam(text=piece, type='text'))
        elif isinstance(piece, ImageUrl):
            pieces.append(
                chat.ChatCompletionContentPartImageParam(image_url=ImageURL(url=piece.url), type='image_url')
            )
        elif isinstance(piece, BinaryContent):
            if not piece.is_image:
                raise ValueError('BinaryContent is not an image')
            encoded = base64.b64encode(piece.data).decode('utf-8')
            data_url = ImageURL(url=f'data:{piece.media_type};base64,{encoded}')
            pieces.append(chat.ChatCompletionContentPartImageParam(image_url=data_url, type='image_url'))
        else:
            raise ValueError(f'Unsupported content type: {type(piece)}')
    return chat.ChatCompletionUserMessageParam(role='user', content=pieces)
30 changes: 29 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import base64
import os
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from contextlib import asynccontextmanager
Expand All @@ -15,6 +16,8 @@
from .. import UnexpectedModelBehavior, _utils
from .._utils import now_utc as _now_utc
from ..messages import (
BinaryContent,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
Expand Down Expand Up @@ -53,6 +56,8 @@
from mistralai.models import (
ChatCompletionResponse as MistralChatCompletionResponse,
CompletionEvent as MistralCompletionEvent,
ImageURL as MistralImageURL,
ImageURLChunk as MistralImageURLChunk,
Messages as MistralMessages,
Tool as MistralTool,
ToolCall as MistralToolCall,
Expand Down Expand Up @@ -423,7 +428,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
if isinstance(part, SystemPromptPart):
yield MistralSystemMessage(content=part.content)
elif isinstance(part, UserPromptPart):
yield MistralUserMessage(content=part.content)
yield _map_user_prompt(part)
elif isinstance(part, ToolReturnPart):
yield MistralToolMessage(
tool_call_id=part.tool_call_id,
Expand Down Expand Up @@ -620,3 +625,26 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
result = None

return result


def _map_user_prompt(part: UserPromptPart) -> MistralUserMessage:
    """Convert a `UserPromptPart` into a Mistral user message.

    A plain-string prompt is forwarded unchanged; a sequence prompt is mapped
    item by item to Mistral content chunks. Mistral accepts only text and
    image content here, so any other item raises.

    Args:
        part: The user prompt part to convert.

    Returns:
        The equivalent `MistralUserMessage`.

    Raises:
        ValueError: If an item is non-image binary content or of an
            unsupported type.
    """
    content: str | list[MistralContentChunk]
    if isinstance(part.content, str):
        content = part.content
    else:
        content = []
        for item in part.content:
            if isinstance(item, str):
                content.append(MistralTextChunk(text=item))
            elif isinstance(item, ImageUrl):
                # Pass `type` explicitly, consistent with the BinaryContent branch below.
                content.append(MistralImageURLChunk(image_url=MistralImageURL(url=item.url), type='image_url'))
            elif isinstance(item, BinaryContent):
                # Reject non-image data before doing the (useless) base64 work.
                if not item.is_image:
                    raise ValueError('BinaryContent is not an image')
                base64_encoded = base64.b64encode(item.data).decode('utf-8')
                image_url = MistralImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                content.append(MistralImageURLChunk(image_url=image_url, type='image_url'))
            else:
                raise ValueError(f'Unsupported content type: {type(item)}')
    return MistralUserMessage(content=content)
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,10 @@ def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessagePara
content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
elif isinstance(item, BinaryContent):
base64_encoded = base64.b64encode(item.data).decode('utf-8')
if item.is_image():
if item.is_image:
image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
elif item.is_audio():
elif item.is_audio:
audio_url = InputAudio(data=base64_encoded, format=item.audio_format)
content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio_url, type='input_audio'))
else:
Expand Down
12 changes: 6 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1dfd9b1

Please sign in to comment.