From 3d74599f75db5d9b791fa12f96f858874ffdcc31 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 28 Jan 2025 17:55:32 -0700 Subject: [PATCH] Remove the AgentModel class --- docs/api/models/base.md | 1 - docs/api/models/vertexai.md | 4 +-- docs/models.md | 3 +- .../pydantic_ai/models/anthropic.py | 23 --------------- .../pydantic_ai/models/function.py | 13 --------- pydantic_ai_slim/pydantic_ai/models/groq.py | 26 ----------------- .../pydantic_ai/models/mistral.py | 27 ----------------- pydantic_ai_slim/pydantic_ai/models/openai.py | 29 ------------------- 8 files changed, 3 insertions(+), 123 deletions(-) diff --git a/docs/api/models/base.md b/docs/api/models/base.md index bf72de7e6..24fcb9bb8 100644 --- a/docs/api/models/base.md +++ b/docs/api/models/base.md @@ -5,7 +5,6 @@ members: - KnownModelName - Model - - AgentModel - AbstractToolDefinition - StreamedResponse - ALLOW_MODEL_REQUESTS diff --git a/docs/api/models/vertexai.md b/docs/api/models/vertexai.md index d59968c79..0c4d48f0c 100644 --- a/docs/api/models/vertexai.md +++ b/docs/api/models/vertexai.md @@ -2,8 +2,8 @@ Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models. -This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method -changed from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel], it relies on the VertexAI +This model inherits from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] with just the URL and auth method +changed, it relies on the VertexAI [`generateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent) and [`streamGenerateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent) diff --git a/docs/models.md b/docs/models.md index 76fc53dae..945043491 100644 --- a/docs/models.md +++ b/docs/models.md @@ -515,9 +515,8 @@ agent = Agent(model) To implement support for models not already supported, you will need to subclass the [`Model`][pydantic_ai.models.Model] abstract base class. -This in turn will require you to implement the following other abstract base classes: +For streaming, you'll also need to implement the following abstract base class: -* [`AgentModel`][pydantic_ai.models.AgentModel] * [`StreamedResponse`][pydantic_ai.models.StreamedResponse] The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py). diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 32a4d8b56..e7d164843 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -134,20 +134,6 @@ def __init__( else: self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client()) - # async def agent_model( - # self, - # *, - # function_tools: list[ToolDefinition], - # allow_text_result: bool, - # result_tools: list[ToolDefinition], - # ) -> AgentModel: - # return AnthropicAgentModel( - # self.client, - # self.model_name, - # allow_text_result, - # tools, - # ) - def name(self) -> str: return f'anthropic:{self.model_name}' @@ -159,15 +145,6 @@ def _map_tool_definition(f: ToolDefinition) -> ToolParam: 'input_schema': f.parameters_json_schema, } - # @dataclass - # class AnthropicAgentModel(AgentModel): - # """Implementation of `AgentModel` for Anthropic models.""" - # - # client: AsyncAnthropic - # model_name: AnthropicModelName - # allow_text_result: bool - # tools: list[ToolParam] - def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[ToolParam]: tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools] if agent_request_config.result_tools: diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index b8deab572..56792dda9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -63,19 +63,6 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre self.function = function self.stream_function = stream_function - # async def agent_model( - # self, - # *, - # function_tools: list[ToolDefinition], - # allow_text_result: bool, - # result_tools: list[ToolDefinition], - # ) -> AgentModel: - # return FunctionAgentModel( - # self.function, - # self.stream_function, - # AgentInfo(function_tools, allow_text_result, result_tools, None), - # ) - def name(self) -> str: function_name = self.function.__name__ if self.function is not None else '' stream_function_name = self.stream_function.__name__ if self.stream_function is not None else '' diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index a250f80aa..1c833210d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -112,24 +112,6 @@ def __init__( else: self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client()) - # async def agent_model( - # self, - # *, - # function_tools: list[ToolDefinition], - # allow_text_result: bool, - # result_tools: list[ToolDefinition], - # ) -> AgentModel: - # check_allow_model_requests() - # # tools = [self._map_tool_definition(r) for r in function_tools] - # # if result_tools: - # # tools += [self._map_tool_definition(r) for r in result_tools] - # return GroqAgentModel( - # self.client, - # self.model_name, - # allow_text_result, - # tools, - # ) - def name(self) -> str: return f'groq:{self.model_name}' @@ -144,14 +126,6 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: }, } - # @dataclass - # class GroqAgentModel(AgentModel): - # """Implementation of `AgentModel` for Groq models.""" - # - # client: AsyncGroq - # model_name: str - # allow_text_result: bool - # tools: list[chat.ChatCompletionToolParam] def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[chat.ChatCompletionToolParam]: tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools] if agent_request_config.result_tools: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 49cbdf6d5..49a9eea30 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -132,36 +132,9 @@ def __init__( api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client()) - # async def agent_model( - # self, - # *, - # function_tools: list[ToolDefinition], - # allow_text_result: bool, - # result_tools: list[ToolDefinition], - # ) -> AgentModel: - # """Create an agent model, this is called for each step of an agent run from Pydantic AI call.""" - # check_allow_model_requests() - # return MistralAgentModel( - # self.client, - # self.model_name, - # allow_text_result, - # function_tools, - # result_tools, - # ) - def name(self) -> str: return f'mistral:{self.model_name}' - # @dataclass - # class MistralAgentModel(AgentModel): - # """Implementation of `AgentModel` for Mistral models.""" - # - # client: Mistral - # model_name: MistralModelName - # allow_text_result: bool - # function_tools: list[ToolDefinition] - # result_tools: list[ToolDefinition] - async def request( self, messages: list[ModelMessage], diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index bebec05a8..75347dc4d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -112,25 +112,6 @@ def __init__( self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client()) self.system_prompt_role = system_prompt_role - # async def agent_model( - # self, - # *, - # function_tools: list[ToolDefinition], - # allow_text_result: bool, - # result_tools: list[ToolDefinition], - # ) -> AgentModel: - # check_allow_model_requests() - # tools = [self._map_tool_definition(r) for r in function_tools] - # if result_tools: - # tools += [self._map_tool_definition(r) for r in result_tools] - # return OpenAIAgentModel( - # self.client, - # self.model_name, - # allow_text_result, - # tools, - # self.system_prompt_role, - # ) - def name(self) -> str: return f'openai:{self.model_name}' @@ -145,16 +126,6 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: }, } - # @dataclass - # class OpenAIAgentModel(AgentModel): - # """Implementation of `AgentModel` for OpenAI models.""" - # - # client: AsyncOpenAI - # model_name: OpenAIModelName - # allow_text_result: bool - # tools: list[chat.ChatCompletionToolParam] - # system_prompt_role: OpenAISystemPromptRole | None - async def request( self, messages: list[ModelMessage],