Skip to content

Commit

Permalink
Remove the AgentModel class
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Jan 29, 2025
1 parent ef98cec commit 3d74599
Show file tree
Hide file tree
Showing 8 changed files with 3 additions and 123 deletions.
1 change: 0 additions & 1 deletion docs/api/models/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
members:
- KnownModelName
- Model
- AgentModel
- AbstractToolDefinition
- StreamedResponse
- ALLOW_MODEL_REQUESTS
Expand Down
4 changes: 2 additions & 2 deletions docs/api/models/vertexai.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

Custom interface to the `*-aiplatform.googleapis.com` API for Gemini models.

This model uses [`GeminiAgentModel`][pydantic_ai.models.gemini.GeminiAgentModel] with just the URL and auth method
changed from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel], it relies on the VertexAI
This model inherits from [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] with just the URL and auth method
changed, it relies on the VertexAI
[`generateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/generateContent)
and
[`streamGenerateContent`](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamGenerateContent)
Expand Down
3 changes: 1 addition & 2 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,8 @@ agent = Agent(model)

To implement support for models not already supported, you will need to subclass the [`Model`][pydantic_ai.models.Model] abstract base class.

This in turn will require you to implement the following other abstract base classes:
For streaming, you'll also need to implement the following abstract base class:

* [`AgentModel`][pydantic_ai.models.AgentModel]
* [`StreamedResponse`][pydantic_ai.models.StreamedResponse]

The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).
Expand Down
23 changes: 0 additions & 23 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,6 @@ def __init__(
else:
self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

# async def agent_model(
# self,
# *,
# function_tools: list[ToolDefinition],
# allow_text_result: bool,
# result_tools: list[ToolDefinition],
# ) -> AgentModel:
# return AnthropicAgentModel(
# self.client,
# self.model_name,
# allow_text_result,
# tools,
# )

def name(self) -> str:
return f'anthropic:{self.model_name}'

Expand All @@ -159,15 +145,6 @@ def _map_tool_definition(f: ToolDefinition) -> ToolParam:
'input_schema': f.parameters_json_schema,
}

# @dataclass
# class AnthropicAgentModel(AgentModel):
# """Implementation of `AgentModel` for Anthropic models."""
#
# client: AsyncAnthropic
# model_name: AnthropicModelName
# allow_text_result: bool
# tools: list[ToolParam]

def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[ToolParam]:
tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools]
if agent_request_config.result_tools:
Expand Down
13 changes: 0 additions & 13 deletions pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,6 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
self.function = function
self.stream_function = stream_function

# async def agent_model(
# self,
# *,
# function_tools: list[ToolDefinition],
# allow_text_result: bool,
# result_tools: list[ToolDefinition],
# ) -> AgentModel:
# return FunctionAgentModel(
# self.function,
# self.stream_function,
# AgentInfo(function_tools, allow_text_result, result_tools, None),
# )

def name(self) -> str:
function_name = self.function.__name__ if self.function is not None else ''
stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
Expand Down
26 changes: 0 additions & 26 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,6 @@ def __init__(
else:
self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())

# async def agent_model(
# self,
# *,
# function_tools: list[ToolDefinition],
# allow_text_result: bool,
# result_tools: list[ToolDefinition],
# ) -> AgentModel:
# check_allow_model_requests()
# # tools = [self._map_tool_definition(r) for r in function_tools]
# # if result_tools:
# # tools += [self._map_tool_definition(r) for r in result_tools]
# return GroqAgentModel(
# self.client,
# self.model_name,
# allow_text_result,
# tools,
# )

def name(self) -> str:
return f'groq:{self.model_name}'

Expand All @@ -144,14 +126,6 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
},
}

# @dataclass
# class GroqAgentModel(AgentModel):
# """Implementation of `AgentModel` for Groq models."""
#
# client: AsyncGroq
# model_name: str
# allow_text_result: bool
# tools: list[chat.ChatCompletionToolParam]
def _get_tools(self, agent_request_config: AgentRequestConfig) -> list[chat.ChatCompletionToolParam]:
tools = [self._map_tool_definition(r) for r in agent_request_config.function_tools]
if agent_request_config.result_tools:
Expand Down
27 changes: 0 additions & 27 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,36 +132,9 @@ def __init__(
api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())

# async def agent_model(
# self,
# *,
# function_tools: list[ToolDefinition],
# allow_text_result: bool,
# result_tools: list[ToolDefinition],
# ) -> AgentModel:
# """Create an agent model, this is called for each step of an agent run from Pydantic AI call."""
# check_allow_model_requests()
# return MistralAgentModel(
# self.client,
# self.model_name,
# allow_text_result,
# function_tools,
# result_tools,
# )

def name(self) -> str:
return f'mistral:{self.model_name}'

# @dataclass
# class MistralAgentModel(AgentModel):
# """Implementation of `AgentModel` for Mistral models."""
#
# client: Mistral
# model_name: MistralModelName
# allow_text_result: bool
# function_tools: list[ToolDefinition]
# result_tools: list[ToolDefinition]

async def request(
self,
messages: list[ModelMessage],
Expand Down
29 changes: 0 additions & 29 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,6 @@ def __init__(
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
self.system_prompt_role = system_prompt_role

# async def agent_model(
# self,
# *,
# function_tools: list[ToolDefinition],
# allow_text_result: bool,
# result_tools: list[ToolDefinition],
# ) -> AgentModel:
# check_allow_model_requests()
# tools = [self._map_tool_definition(r) for r in function_tools]
# if result_tools:
# tools += [self._map_tool_definition(r) for r in result_tools]
# return OpenAIAgentModel(
# self.client,
# self.model_name,
# allow_text_result,
# tools,
# self.system_prompt_role,
# )

def name(self) -> str:
return f'openai:{self.model_name}'

Expand All @@ -145,16 +126,6 @@ def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
},
}

# @dataclass
# class OpenAIAgentModel(AgentModel):
# """Implementation of `AgentModel` for OpenAI models."""
#
# client: AsyncOpenAI
# model_name: OpenAIModelName
# allow_text_result: bool
# tools: list[chat.ChatCompletionToolParam]
# system_prompt_role: OpenAISystemPromptRole | None

async def request(
self,
messages: list[ModelMessage],
Expand Down

0 comments on commit 3d74599

Please sign in to comment.