-
Notifications
You must be signed in to change notification settings - Fork 165
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d50c8c5
commit 66b5a63
Showing
8 changed files
with
512 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
99 changes: 99 additions & 0 deletions
99
py/plugins/google-ai/src/genkit/plugins/google_ai/google_ai.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
from collections.abc import Callable | ||
from typing import Optional | ||
|
||
import pydantic | ||
from genkit.core import plugin_abc | ||
from genkit.core.action import ActionKind | ||
from genkit.core.registry import Registry | ||
from genkit.core.typing import ( | ||
GenerateRequest, | ||
GenerateResponse, | ||
Message, | ||
Role, | ||
TextPart, | ||
) | ||
from genkit.plugins.google_ai.models import models | ||
from google import genai | ||
|
||
|
||
def googleai_name(name: str) -> str: | ||
"""Create a Google AI action name. | ||
Args: | ||
name: Base name for the action. | ||
Returns: | ||
The fully qualified Google AI action name. | ||
""" | ||
return f'googleai/{name}' | ||
|
||
|
||
class GoogleAiPluginOptions(pydantic.BaseModel): | ||
api_key: Optional[str] = None | ||
# TODO: implement all authentication methods | ||
# project: Optional[str] = None, | ||
# location: Optional[str] = None | ||
# TODO: implement http options | ||
# api_version: Optional[str] = None | ||
# base_url: Optional[str] = None | ||
|
||
|
||
class GoogleAi(plugin_abc.Plugin): | ||
def __init__(self, plugin_params: GoogleAiPluginOptions | None = None): | ||
api_key = ( | ||
plugin_params.api_key | ||
if plugin_params and plugin_params.api_key | ||
else os.getenv('GEMINI_API_KEY') | ||
) | ||
if not api_key: | ||
raise ValueError( | ||
'Gemini api key should be passed in plugin params ' | ||
'or as a GEMINI_API_KEY environment variable' | ||
) | ||
self._client = genai.client.Client(api_key=api_key) | ||
|
||
def initialize(self, registry: Registry): | ||
for name, model in models.SUPPORTED_MODELS.items(): | ||
model_metadata = { | ||
'model': { | ||
'supports': model.supports.model_dump(), | ||
} | ||
} | ||
|
||
registry.register_action( | ||
kind=ActionKind.MODEL, | ||
name=googleai_name(name), | ||
fn=self._create_callback(name), | ||
metadata=model_metadata, | ||
) | ||
|
||
def _create_callback( | ||
self, model: str | ||
) -> Callable[[GenerateRequest], GenerateResponse]: | ||
async def model_callback(request: GenerateRequest) -> GenerateResponse: | ||
reqest_msgs: list[genai.types.Content] = [] | ||
for msg in request.messages: | ||
message_parts: list[genai.types.Part] = [] | ||
for p in msg.content: | ||
message_parts.append( | ||
genai.types.Part.from_text(text=p.root.text) | ||
) | ||
reqest_msgs.append( | ||
genai.types.Content(parts=message_parts, role=msg.role) | ||
) | ||
response = await self._client.aio.models.generate_content( | ||
model=model, contents=reqest_msgs | ||
) | ||
|
||
return GenerateResponse( | ||
message=Message( | ||
role=Role.MODEL, | ||
content=[TextPart(text=response.text)], | ||
) | ||
) | ||
|
||
return model_callback |
112 changes: 112 additions & 0 deletions
112
py/plugins/google-ai/src/genkit/plugins/google_ai/models/models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright 2025 Google LLC | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from enum import StrEnum | ||
|
||
from genkit.core.typing import ModelInfo, Supports | ||
|
||
gemini10Pro = ModelInfo( | ||
label='Google AI - Gemini Pro', | ||
versions=['gemini-pro', 'gemini-1.0-pro-latest', 'gemini-1.0-pro-001'], | ||
supports=Supports( | ||
multiturn=True, | ||
media=False, | ||
tools=True, | ||
tool_choice=True, | ||
system_role=True, | ||
constrained='no-tools', | ||
), | ||
) | ||
|
||
gemini15Pro = ModelInfo( | ||
label='Google AI - Gemini 1.5 Pro', | ||
versions=[ | ||
'gemini-1.5-pro-latest', | ||
'gemini-1.5-pro-001', | ||
'gemini-1.5-pro-002', | ||
], | ||
supports=Supports( | ||
multiturn=True, | ||
media=True, | ||
tools=True, | ||
tool_choice=True, | ||
system_role=True, | ||
constrained='no-tools', | ||
), | ||
) | ||
|
||
gemini15Flash = ModelInfo( | ||
label='Google AI - Gemini 1.5 Flash', | ||
versions=[ | ||
'gemini-1.5-flash-latest', | ||
'gemini-1.5-flash-001', | ||
'gemini-1.5-flash-002', | ||
], | ||
supports=Supports( | ||
multiturn=True, | ||
media=True, | ||
tools=True, | ||
tool_choice=True, | ||
system_role=True, | ||
constrained='no-tools', | ||
), | ||
) | ||
|
||
|
||
gemini15Flash8b = ModelInfo( | ||
label='Google AI - Gemini 1.5 Flash', | ||
versions=['gemini-1.5-flash-8b-latest', 'gemini-1.5-flash-8b-001'], | ||
supports=Supports( | ||
multiturn=True, | ||
media=True, | ||
tools=True, | ||
tool_choice=True, | ||
system_role=True, | ||
constrained='no-tools', | ||
), | ||
) | ||
|
||
|
||
gemini20Flash = ModelInfo( | ||
label='Google AI - Gemini 2.0 Flash', | ||
supports=Supports( | ||
multiturn=True, | ||
media=True, | ||
tools=True, | ||
tool_choice=True, | ||
system_role=True, | ||
constrained='no-tools', | ||
), | ||
) | ||
|
||
|
||
gemini20ProExp0205 = ModelInfo( | ||
label='Google AI - Gemini 2.0 Pro Exp 02-05', | ||
supports=Supports( | ||
multiturn=True, | ||
media=True, | ||
tools=True, | ||
tool_choice=True, | ||
system_role=True, | ||
constrained='no-tools', | ||
), | ||
) | ||
|
||
|
||
class GoogleAiVersion(StrEnum): | ||
GEMINI_1_0_PRO = 'gemini-1.0-pro' | ||
GEMINI_1_5_PRO = 'gemini-1.5-pro' | ||
GEMINI_1_5_FLASH = 'gemini-1.5-flash' | ||
GEMINI_1_5_FLASH_8B = 'gemini-1.5-flash-8b' | ||
GEMINI_2_0_FLASH = 'gemini-2.0-flash' | ||
GEMINI_2_0_PRO_EXP_02_05 = 'gemini-2.0-pro-exp-02-05' | ||
|
||
|
||
SUPPORTED_MODELS = { | ||
GoogleAiVersion.GEMINI_1_0_PRO: gemini10Pro, | ||
GoogleAiVersion.GEMINI_1_5_PRO: gemini15Pro, | ||
GoogleAiVersion.GEMINI_1_5_FLASH: gemini15Flash, | ||
GoogleAiVersion.GEMINI_1_5_FLASH_8B: gemini15Flash8b, | ||
GoogleAiVersion.GEMINI_2_0_FLASH: gemini20Flash, | ||
GoogleAiVersion.GEMINI_2_0_PRO_EXP_02_05: gemini20ProExp0205, | ||
} |
Oops, something went wrong.