Skip to content

Commit

Permalink
feat(py): Basic imagen
Browse files Browse the repository at this point in the history
  • Loading branch information
Irillit committed Feb 27, 2025
1 parent c9ea982 commit 42876d4
Show file tree
Hide file tree
Showing 12 changed files with 582 additions and 8 deletions.
2 changes: 1 addition & 1 deletion py/packages/genkit/src/genkit/veneer/veneer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def generate(
Returns:
The generated text response.
"""
model = model if model is not None else self.registry.defaultModel
model = model if model is not None else self.registry.default_model
if model is None:
raise Exception('No model configured.')
if config and not isinstance(config, GenerationCommonConfig):
Expand Down
2 changes: 1 addition & 1 deletion py/plugins/vertex-ai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",
]
dependencies = ["genkit", "google-cloud-aiplatform>=1.77.0"]
dependencies = ["genkit", "google-cloud-aiplatform>=1.77.0", "pytest-mock"]
description = "Genkit Google Cloud Vertex AI Plugin"
license = { text = "Apache-2.0" }
name = "genkit-vertex-ai-plugin"
Expand Down
12 changes: 7 additions & 5 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from genkit.plugins.vertex_ai.embedding import EmbeddingModels
from genkit.plugins.vertex_ai.gemini import GeminiVersion
from genkit.plugins.vertex_ai.imagen import ImagenVersion
from genkit.plugins.vertex_ai.plugin_api import VertexAI, vertexai_name


Expand All @@ -21,9 +22,10 @@ def package_name() -> str:


__all__ = [
'package_name',
'VertexAI',
'vertexai_name',
'EmbeddingModels',
'GeminiVersion',
package_name.__name__,
VertexAI.__name__,
vertexai_name.__name__,
EmbeddingModels.__name__,
GeminiVersion.__name__,
ImagenVersion.__name__,
]
112 changes: 112 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/imagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from enum import StrEnum
from typing import Any

from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Media1,
MediaPart,
Message,
ModelInfo,
Role,
Supports,
)
from vertexai.preview.vision_models import ImageGenerationModel


class ImagenVersion(StrEnum):
IMAGEN3 = 'imagen-3.0-generate-002'
IMAGEN3_FAST = 'imagen-3.0-fast-generate-001'
IMAGEN2 = 'imagegeneration@006'


SUPPORTED_MODELS = {
ImagenVersion.IMAGEN3: ModelInfo(
label='Vertex AI - Imagen3',
supports=Supports(
media=True,
multiturn=False,
tools=False,
systemRole=False,
output=['media'],
),
),
ImagenVersion.IMAGEN3_FAST: ModelInfo(
label='Vertex AI - Imagen3 Fast',
supports=Supports(
media=False,
multiturn=False,
tools=False,
systemRole=False,
output=['media'],
),
),
ImagenVersion.IMAGEN2: ModelInfo(
label='Vertex AI - Imagen2',
supports=Supports(
media=False,
multiturn=False,
tools=False,
systemRole=False,
output=['media'],
),
),
}


class Imagen:
"""Imagen - text to image model."""

def __init__(self, version):
self._version = version

@property
def model(self) -> ImageGenerationModel:
return ImageGenerationModel.from_pretrained(self._version)

def handle_request(self, request: GenerateRequest) -> GenerateResponse:
parts: list[str] = []
for m in request.messages:
for p in m.content:
if p.root.text is not None:
parts.append(p.root.text)
else:
raise Exception('unsupported part type')

prompt = ' '.join(parts)
images = self.model.generate_images(
prompt=prompt,
number_of_images=1,
language='en',
aspect_ratio='1:1',
safety_filter_level='block_some',
person_generation='allow_adult',
)

media_content = [
MediaPart(
media=Media1(
contentType=image._mime_type, url=image._as_base64_string()
)
)
for image in images
]

return GenerateResponse(
message=Message(
role=Role.MODEL,
content=media_content,
)
)

@property
def model_metadata(self) -> dict[str, Any]:
supports = SUPPORTED_MODELS[self._version].supports.model_dump()
return {
'model': {
'supports': supports,
}
}
10 changes: 10 additions & 0 deletions py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from genkit.plugins.vertex_ai import constants as const
from genkit.plugins.vertex_ai.embedding import Embedder, EmbeddingModels
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion
from genkit.plugins.vertex_ai.imagen import Imagen, ImagenVersion

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,3 +86,12 @@ def initialize(self, registry: Registry) -> None:
fn=embedder.handle_request,
metadata=embedder.model_metadata,
)

for imagen_version in ImagenVersion:
imagen = Imagen(imagen_version)
registry.register_action(
kind=ActionKind.MODEL,
name=vertexai_name(imagen_version),
fn=imagen.handle_request,
metadata=imagen.model_metadata,
)
52 changes: 52 additions & 0 deletions py/plugins/vertex-ai/tests/test_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Test Gemini models."""

import pytest
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Message,
Role,
TextPart,
)
from genkit.plugins.vertex_ai.gemini import Gemini, GeminiVersion


@pytest.mark.parametrize('version', [x for x in GeminiVersion])
def test_generate_text_response(mocker, version):
mocked_respond = 'Mocked Respond'
request = GenerateRequest(
messages=[
Message(
role=Role.USER,
content=[
TextPart(text=f'Hi, mock!'),
],
),
]
)
gemini = Gemini(version)
genai_model_mock = mocker.MagicMock()
model_response_mock = mocker.MagicMock()
model_response_mock.text = mocked_respond
genai_model_mock.generate_content.return_value = model_response_mock
mocker.patch(
'genkit.plugins.vertex_ai.gemini.Gemini.gemini_model', genai_model_mock
)

response = gemini.handle_request(request)
assert isinstance(response, GenerateResponse)
assert response.message.content[0].root.text == mocked_respond


@pytest.mark.parametrize('version', [x for x in GeminiVersion])
def test_gemini_metadata(version):
gemini = Gemini(version)
supports = gemini.model_metadata['model']['supports']
assert isinstance(supports, dict)
assert supports['multiturn']
assert supports['media']
assert supports['tools']
assert supports['system_role']
57 changes: 57 additions & 0 deletions py/plugins/vertex-ai/tests/test_imagen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Test Gemini models."""

import pytest
from genkit.core.typing import (
GenerateRequest,
GenerateResponse,
Media1,
Message,
Role,
TextPart,
)
from genkit.plugins.vertex_ai.imagen import Imagen, ImagenVersion


@pytest.mark.parametrize('version', [x for x in ImagenVersion])
def test_generate(mocker, version):
mocked_respond = 'Supposed Base64 string'
request = GenerateRequest(
messages=[
Message(
role=Role.USER,
content=[
TextPart(text=f'Draw a test.'),
],
),
]
)
imagen = Imagen(version)
genai_model_mock = mocker.MagicMock()
model_response_mock = mocker.MagicMock()
model_response_mock._mime_type = ''
model_response_mock._as_base64_string.return_value = mocked_respond
genai_model_mock.generate_images.return_value = [model_response_mock]
mocker.patch(
'genkit.plugins.vertex_ai.imagen.Imagen.model', genai_model_mock
)

response = imagen.handle_request(request)
assert isinstance(response, GenerateResponse)
assert isinstance(response.message.content[0].root.media, Media1)
assert response.message.content[0].root.media.url == mocked_respond


@pytest.mark.parametrize('version', [x for x in ImagenVersion])
def test_gemini_metadata(version):
imagen = Imagen(version)
supports = imagen.model_metadata['model']['supports']
assert isinstance(supports, dict)
assert not supports['multiturn']
assert not supports['tools']
assert not supports['system_role']
Loading

0 comments on commit 42876d4

Please sign in to comment.