diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/__init__.py b/py/plugins/ollama/src/genkit/plugins/ollama/__init__.py index 62e060156..85ac6a513 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/__init__.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/__init__.py @@ -1,7 +1,7 @@ # Copyright 2025 Google LLC # SPDX-License-Identifier: Apache-2.0 -from genkit.plugins.ollama.plugin_api import Ollama +from genkit.plugins.ollama.plugin_api import Ollama, ollama_name def package_name() -> str: @@ -11,4 +11,5 @@ def package_name() -> str: __all__ = [ package_name.__name__, Ollama.__name__, + ollama_name.__name__, ] diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py index 28658ae84..a30590733 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py @@ -1,4 +1,10 @@ # Copyright 2025 Google LLC # SPDX-License-Identifier: Apache-2.0 +from enum import StrEnum DEFAULT_OLLAMA_SERVER_URL = 'http://127.0.0.1:11434' + + +class OllamaAPITypes(StrEnum): + CHAT = 'chat' + GENERATE = 'generate' diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/models.py b/py/plugins/ollama/src/genkit/plugins/ollama/models.py index ca2b9e1bd..e943c75e6 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/models.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/models.py @@ -1,17 +1,24 @@ # Copyright 2025 Google LLC # SPDX-License-Identifier: Apache-2.0 +import logging +from typing import Dict, List, Optional -from enum import StrEnum -from typing import List, Dict, Optional +from genkit.core.schema_types import ( + GenerateRequest, + GenerateResponse, + Message, + Role, + TextPart, +) +from genkit.plugins.ollama.constants import ( + DEFAULT_OLLAMA_SERVER_URL, + OllamaAPITypes, +) +from pydantic import BaseModel, Field, HttpUrl -from pydantic import BaseModel, HttpUrl, Field +import ollama as ollama_api -from genkit.plugins.ollama.constants import DEFAULT_OLLAMA_SERVER_URL - - -class OllamaAPITypes(StrEnum): - CHAT = 'chat' - GENERATE = 'generate' +LOG = logging.getLogger(__name__) class ModelDefinition(BaseModel): @@ -29,3 +36,65 @@ class OllamaPluginParams(BaseModel): embedders: List[EmbeddingModelDefinition] = Field(default_factory=list) server_address: HttpUrl = Field(default=HttpUrl(DEFAULT_OLLAMA_SERVER_URL)) request_headers: Optional[Dict[str, str]] = None + + +class OllamaModel: + def __init__( + self, client: ollama_api.Client, model_definition: ModelDefinition + ): + self.client = client + self.model_definition = model_definition + + def generate(self, request: GenerateRequest) -> GenerateResponse: + txt_response = 'Failed to get response from Ollama API' + + if self.model_definition.api_type == OllamaAPITypes.CHAT: + api_response = self._chat_with_ollama(request=request) + if api_response: + txt_response = api_response.message.content + else: + api_response = self._generate_ollama_response(request=request) + if api_response: + txt_response = api_response.response + + return GenerateResponse( + message=Message( + role=Role.model, + content=[TextPart(text=txt_response)], + ) + ) + + def _chat_with_ollama( + self, request: GenerateRequest + ) -> ollama_api.ChatResponse | None: + ollama_messages: List[Dict[str, str]] = [] + + for message in request.messages: + item = { + 'role': message.role.value, + 'content': '', + } + for text_part in message.content: + if isinstance(text_part.root, TextPart): + item['content'] += text_part.root.text + else: + LOG.error(f'Unsupported part of message: {text_part}') + ollama_messages.append(item) + return self.client.chat( + model=self.model_definition.name, messages=ollama_messages + ) + + def _generate_ollama_response( + self, request: GenerateRequest + ) -> ollama_api.GenerateResponse | None: + request_kwargs = { + 'model': self.model_definition.name, + 'prompt': '', + } + for message in request.messages: + for text_part in message.content: + if isinstance(text_part.root, TextPart): + request_kwargs['prompt'] += text_part.root.text + else: + LOG.error('Non-text messages are not supported') + return self.client.generate(**request_kwargs) diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index b94566f44..3c6def55a 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -6,38 +6,47 @@ """ import logging -from typing import Callable -import ollama as ollama_api - -from genkit.plugins.ollama.models import OllamaPluginParams -from genkit.plugins.ollama.utils import ( - register_ollama_embedder, - register_ollama_model, +from genkit.core.action import ActionKind +from genkit.core.plugin_abc import Plugin +from genkit.core.registry import Registry +from genkit.plugins.ollama.models import ( + OllamaAPITypes, + OllamaModel, + OllamaPluginParams, ) -from genkit.veneer import Genkit + +import ollama as ollama_api LOG = logging.getLogger(__name__) -def Ollama(plugin_params: OllamaPluginParams) -> Callable[[Genkit], None]: - client = ollama_api.Client( - host=plugin_params.server_address.unicode_string() - ) +def ollama_name(name: str) -> str: + return f'ollama/{name}' - def plugin(ai: Genkit) -> None: - for model in plugin_params.models: - register_ollama_model( - ai=ai, - model=model, - client=client, - ) - for embedder in plugin_params.embedders: - register_ollama_embedder( - ai=ai, - embedder=embedder, - client=client, - ) +class Ollama(Plugin): + def __init__(self, plugin_params: OllamaPluginParams): + self.plugin_params = plugin_params + self.client = ollama_api.Client( + host=self.plugin_params.server_address.unicode_string() + ) - return plugin + def initialize(self, registry: Registry) -> None: + for model_definition in self.plugin_params.models: + model = OllamaModel( + client=self.client, + model_definition=model_definition, + ) + registry.register_action( + kind=ActionKind.MODEL, + name=ollama_name(model_definition.name), + fn=model.generate, + metadata={ + 'multiturn': model_definition.api_type + == OllamaAPITypes.CHAT, + 'system_role': True, + }, + ) + # TODO: introduce embedders here + # for embedder in self.plugin_params.embedders: diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/utils/__init__.py b/py/plugins/ollama/src/genkit/plugins/ollama/utils/__init__.py deleted file mode 100644 index f319895a9..000000000 --- a/py/plugins/ollama/src/genkit/plugins/ollama/utils/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2025 Google LLC -# SPDX-License-Identifier: Apache-2.0 - -from genkit.plugins.ollama.utils.model_utils import register_ollama_model -from genkit.plugins.ollama.utils.embedding_utils import register_ollama_embedder - - -__all__ = [ - register_ollama_model.__name__, - register_ollama_embedder.__name__, -] diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/utils/embedding_utils.py b/py/plugins/ollama/src/genkit/plugins/ollama/utils/embedding_utils.py deleted file mode 100644 index ec542ed73..000000000 --- a/py/plugins/ollama/src/genkit/plugins/ollama/utils/embedding_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2025 Google LLC -# SPDX-License-Identifier: Apache-2.0 - - -""" -Ollama Embedders for Genkit. -""" - -import ollama as ollama_api - -from genkit.plugins.ollama.models import EmbeddingModelDefinition -from genkit.veneer import Genkit - - -def register_ollama_embedder( - ai: Genkit, - embedder: EmbeddingModelDefinition, - client: ollama_api.Client, -) -> None: - pass diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/utils/model_utils.py b/py/plugins/ollama/src/genkit/plugins/ollama/utils/model_utils.py deleted file mode 100644 index e59a4f9be..000000000 --- a/py/plugins/ollama/src/genkit/plugins/ollama/utils/model_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2025 Google LLC -# SPDX-License-Identifier: Apache-2.0 - - -""" -Ollama Models for Genkit. -""" - -import logging -from typing import Optional, List, Dict - - -from genkit.core.types import ( - GenerateRequest, - GenerateResponse, - TextPart, - Message, - Role, -) -from genkit.veneer import Genkit - -from genkit.plugins.ollama.models import ModelDefinition, OllamaAPITypes - -import ollama as ollama_api - -LOG = logging.getLogger(__name__) - - -def register_ollama_model( - ai: Genkit, - model: ModelDefinition, - client: ollama_api.Client, -) -> None: - def _execute_ollama_request(request: GenerateRequest) -> GenerateResponse: - def _chat_with_ollama() -> ollama_api.ChatResponse: - ollama_messages: List[Dict[str, str]] = [] - - for message in request.messages: - item = { - 'role': message.role.value, - 'content': '', - } - for text_part in message.content: - if isinstance(text_part, TextPart): - item['content'] += text_part.text - else: - LOG.warning(f'Unsupported part of message: {text_part}') - ollama_messages.append(item) - return client.chat(model=model.name, messages=ollama_messages) - - def _generate_ollama_response() -> Optional[ - ollama_api.GenerateResponse - ]: - request_kwargs = { - 'model': model.name, - 'prompt': '', - } - for message in request.messages: - for text_part in message.content: - if isinstance(text_part, TextPart): - request_kwargs['prompt'] += text_part.text - else: - LOG.error('Non-text messages are not supported') - return client.generate(**request_kwargs) - - txt_response = 'Failed to get response from Ollama API' - - if model.api_type == OllamaAPITypes.CHAT: - api_response = _chat_with_ollama() - if api_response: - txt_response = api_response.message.content - else: - api_response = _generate_ollama_response() - if api_response: - txt_response = api_response.response - - return GenerateResponse( - message=Message( - role=Role.model, - content=[TextPart(text=txt_response)], - ) - ) - - ai.define_model( - name=f'ollama/{model.name}', - fn=_execute_ollama_request, - metadata={ - 'multiturn': model.api_type == OllamaAPITypes.CHAT, - 'system_role': True, - }, - ) diff --git a/py/plugins/ollama/tests/__init__.py b/py/plugins/ollama/tests/__init__.py new file mode 100644 index 000000000..7229ac50e --- /dev/null +++ b/py/plugins/ollama/tests/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 diff --git a/py/plugins/ollama/tests/conftest.py b/py/plugins/ollama/tests/conftest.py new file mode 100644 index 000000000..8058a7784 --- /dev/null +++ b/py/plugins/ollama/tests/conftest.py @@ -0,0 +1,80 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from unittest import mock + +from genkit.plugins.ollama import Ollama +from genkit.plugins.ollama.models import ( + OllamaPluginParams, + ModelDefinition, + OllamaAPITypes, +) +from genkit.plugins.ollama.plugin_api import ollama_api +from genkit.veneer import Genkit + + +@pytest.fixture +def ollama_model() -> str: + return 'ollama/llama3.2:latest' + + +@pytest.fixture +def chat_model_plugin_params(ollama_model: str) -> OllamaPluginParams: + return OllamaPluginParams( + models=[ + ModelDefinition( + name=ollama_model.split('/')[-1], + api_type=OllamaAPITypes.CHAT, + ) + ], + ) + + +@pytest.fixture +def genkit_veneer_chat_model( + ollama_model: str, + chat_model_plugin_params: OllamaPluginParams, +) -> Genkit: + return Genkit( + plugins=[ + Ollama( + plugin_params=chat_model_plugin_params, + ) + ], + model=ollama_model, + ) + + +@pytest.fixture +def generate_model_plugin_params(ollama_model: str) -> OllamaPluginParams: + return OllamaPluginParams( + models=[ + ModelDefinition( + name=ollama_model.split('/')[-1], + api_type=OllamaAPITypes.GENERATE, + ) + ], + ) + + +@pytest.fixture +def genkit_veneer_generate_model( + ollama_model: str, + generate_model_plugin_params: OllamaPluginParams, +) -> Genkit: + return Genkit( + plugins=[ + Ollama( + plugin_params=generate_model_plugin_params, + ) + ], + model=ollama_model, + ) + + +@pytest.fixture +def mock_ollama_api_client(): + with mock.patch.object(ollama_api, 'Client') as mock_ollama_client: + yield mock_ollama_client diff --git a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py new file mode 100644 index 000000000..91b0deb55 --- /dev/null +++ b/py/plugins/ollama/tests/test_plugin_api.py @@ -0,0 +1,87 @@ +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +from unittest import mock + +import ollama as ollama_api +from genkit.core.schema_types import GenerateResponse, Message, Role, TextPart +from genkit.veneer import Genkit + + +def test_adding_ollama_chat_model_to_genkit_veneer( + ollama_model: str, + genkit_veneer_chat_model: Genkit, +): + assert len(genkit_veneer_chat_model.registry.actions) == 1 + assert ollama_model in genkit_veneer_chat_model.registry.actions['model'] + + +def test_adding_ollama_generation_model_to_genkit_veneer( + ollama_model: str, + genkit_veneer_generate_model: Genkit, +): + assert len(genkit_veneer_generate_model.registry.actions) == 1 + assert ( + ollama_model in genkit_veneer_generate_model.registry.actions['model'] + ) + + +def test_get_chat_model_response_from_llama_api_flow( + mock_ollama_api_client: mock.Mock, genkit_veneer_chat_model: Genkit +): + mock_response_message = 'Mocked response message' + + mock_ollama_api_client.return_value.chat.return_value = ( + ollama_api.ChatResponse( + message=ollama_api.Message( + content=mock_response_message, + role='user', + ) + ) + ) + + def _test_fun(): + return genkit_veneer_chat_model.generate( + messages=[ + Message( + role=Role.user, + content=[ + TextPart(text='Test message'), + ], + ) + ] + ) + + response = genkit_veneer_chat_model.flow()(_test_fun)() + + assert isinstance(response, GenerateResponse) + assert response.message.content[0].root.text == mock_response_message + + +def test_get_generate_model_response_from_llama_api_flow( + mock_ollama_api_client: mock.Mock, genkit_veneer_generate_model: Genkit +): + mock_response_message = 'Mocked response message' + + mock_ollama_api_client.return_value.generate.return_value = ( + ollama_api.GenerateResponse( + response=mock_response_message, + ) + ) + + def _test_fun(): + return genkit_veneer_generate_model.generate( + messages=[ + Message( + role=Role.user, + content=[ + TextPart(text='Test message'), + ], + ) + ] + ) + + response = genkit_veneer_generate_model.flow()(_test_fun)() + + assert isinstance(response, GenerateResponse) + assert response.message.content[0].root.text == mock_response_message diff --git a/py/samples/ollama/hello.py b/py/samples/ollama/hello.py index 9bbb43199..f384b09a2 100644 --- a/py/samples/ollama/hello.py +++ b/py/samples/ollama/hello.py @@ -2,17 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 -from genkit.core.types import Message, TextPart, Role +from genkit.core.schema_types import Message, Role, TextPart +from genkit.plugins.ollama import Ollama, ollama_name from genkit.plugins.ollama.models import ( - OllamaPluginParams, ModelDefinition, OllamaAPITypes, + OllamaPluginParams, ) -from genkit.plugins.ollama import Ollama - from genkit.veneer import Genkit - # model can be pulled with `ollama pull *LLM_VERSION*` LLM_VERSION = 'llama3.2:latest' @@ -31,7 +29,7 @@ plugin_params=plugin_params, ) ], - model=f'ollama/{LLM_VERSION}', + model=ollama_name(LLM_VERSION), )