Skip to content

Commit

Permalink
feat(py): Initial implementation of OLLAMA plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
kirgrim committed Feb 22, 2025
1 parent 0149d10 commit 5714359
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 164 deletions.
3 changes: 2 additions & 1 deletion py/plugins/ollama/src/genkit/plugins/ollama/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0

from genkit.plugins.ollama.plugin_api import Ollama
from genkit.plugins.ollama.plugin_api import Ollama, ollama_name


def package_name() -> str:
Expand All @@ -11,4 +11,5 @@ def package_name() -> str:
__all__ = [
package_name.__name__,
Ollama.__name__,
ollama_name.__name__,
]
6 changes: 6 additions & 0 deletions py/plugins/ollama/src/genkit/plugins/ollama/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0
from enum import StrEnum

DEFAULT_OLLAMA_SERVER_URL = 'http://127.0.0.1:11434'


class OllamaAPITypes(StrEnum):
CHAT = 'chat'
GENERATE = 'generate'
87 changes: 78 additions & 9 deletions py/plugins/ollama/src/genkit/plugins/ollama/models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Dict, List, Optional

from enum import StrEnum
from typing import List, Dict, Optional
from genkit.core.schema_types import (
GenerateRequest,
GenerateResponse,
Message,
Role,
TextPart,
)
from genkit.plugins.ollama.constants import (
DEFAULT_OLLAMA_SERVER_URL,
OllamaAPITypes,
)
from pydantic import BaseModel, Field, HttpUrl

from pydantic import BaseModel, HttpUrl, Field
import ollama as ollama_api

from genkit.plugins.ollama.constants import DEFAULT_OLLAMA_SERVER_URL


class OllamaAPITypes(StrEnum):
CHAT = 'chat'
GENERATE = 'generate'
LOG = logging.getLogger(__name__)


class ModelDefinition(BaseModel):
Expand All @@ -29,3 +36,65 @@ class OllamaPluginParams(BaseModel):
embedders: List[EmbeddingModelDefinition] = Field(default_factory=list)
server_address: HttpUrl = Field(default=HttpUrl(DEFAULT_OLLAMA_SERVER_URL))
request_headers: Optional[Dict[str, str]] = None


class OllamaModel:
def __init__(
self, client: ollama_api.Client, model_definition: ModelDefinition
):
self.client = client
self.model_definition = model_definition

def generate(self, request: GenerateRequest) -> GenerateResponse:
txt_response = 'Failed to get response from Ollama API'

if self.model_definition.api_type == OllamaAPITypes.CHAT:
api_response = self._chat_with_ollama(request=request)
if api_response:
txt_response = api_response.message.content
else:
api_response = self._generate_ollama_response(request=request)
if api_response:
txt_response = api_response.response

return GenerateResponse(
message=Message(
role=Role.model,
content=[TextPart(text=txt_response)],
)
)

def _chat_with_ollama(
self, request: GenerateRequest
) -> ollama_api.ChatResponse | None:
ollama_messages: List[Dict[str, str]] = []

for message in request.messages:
item = {
'role': message.role.value,
'content': '',
}
for text_part in message.content:
if isinstance(text_part.root, TextPart):
item['content'] += text_part.root.text
else:
LOG.error(f'Unsupported part of message: {text_part}')
ollama_messages.append(item)
return self.client.chat(
model=self.model_definition.name, messages=ollama_messages
)

def _generate_ollama_response(
self, request: GenerateRequest
) -> ollama_api.GenerateResponse | None:
request_kwargs = {
'model': self.model_definition.name,
'prompt': '',
}
for message in request.messages:
for text_part in message.content:
if isinstance(text_part.root, TextPart):
request_kwargs['prompt'] += text_part.root.text
else:
LOG.error('Non-text messages are not supported')
return self.client.generate(**request_kwargs)
61 changes: 35 additions & 26 deletions py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,47 @@
"""

import logging
from typing import Callable

import ollama as ollama_api

from genkit.plugins.ollama.models import OllamaPluginParams
from genkit.plugins.ollama.utils import (
register_ollama_embedder,
register_ollama_model,
from genkit.core.action import ActionKind
from genkit.core.plugin_abc import Plugin
from genkit.core.registry import Registry
from genkit.plugins.ollama.models import (
OllamaAPITypes,
OllamaModel,
OllamaPluginParams,
)
from genkit.veneer import Genkit

import ollama as ollama_api

LOG = logging.getLogger(__name__)


def Ollama(plugin_params: OllamaPluginParams) -> Callable[[Genkit], None]:
client = ollama_api.Client(
host=plugin_params.server_address.unicode_string()
)
def ollama_name(name: str) -> str:
return f'ollama/{name}'

def plugin(ai: Genkit) -> None:
for model in plugin_params.models:
register_ollama_model(
ai=ai,
model=model,
client=client,
)

for embedder in plugin_params.embedders:
register_ollama_embedder(
ai=ai,
embedder=embedder,
client=client,
)
class Ollama(Plugin):
def __init__(self, plugin_params: OllamaPluginParams):
self.plugin_params = plugin_params
self.client = ollama_api.Client(
host=self.plugin_params.server_address.unicode_string()
)

return plugin
def initialize(self, registry: Registry) -> None:
for model_definition in self.plugin_params.models:
model = OllamaModel(
client=self.client,
model_definition=model_definition,
)
registry.register_action(
kind=ActionKind.MODEL,
name=ollama_name(model_definition.name),
fn=model.generate,
metadata={
'multiturn': model_definition.api_type
== OllamaAPITypes.CHAT,
'system_role': True,
},
)
# TODO: introduce embedders here
# for embedder in self.plugin_params.embedders:
11 changes: 0 additions & 11 deletions py/plugins/ollama/src/genkit/plugins/ollama/utils/__init__.py

This file was deleted.

This file was deleted.

91 changes: 0 additions & 91 deletions py/plugins/ollama/src/genkit/plugins/ollama/utils/model_utils.py

This file was deleted.

2 changes: 2 additions & 0 deletions py/plugins/ollama/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright 2025 Google LLC
# SPDX-License-Identifier: Apache-2.0
Loading

0 comments on commit 5714359

Please sign in to comment.