diff --git a/README.md b/README.md index 93bb986..dd48779 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,8 @@ Originally based on [Tom Dörr's `fish.codex` repository](https://github.com/tom-doerr/codex.fish), but with some additional functionality. It uses the [chat completions API endpoint](https://platform.openai.com/docs/api-reference/chat/create) -and can be hooked up to OpenAI, Azure OpenAI or a self-hosted LLM behind any -OpenAI-compatible API. +and can be hooked up to Google, OpenAI, Azure OpenAI or a self-hosted LLM +behind any OpenAI-compatible API. Continuous integration is performed against Azure OpenAI. @@ -60,6 +60,17 @@ model = api_key = ``` +If you use [Gemini](https://deepmind.google/technologies/gemini): + +```ini +[fish-ai] +configuration = gemini + +[gemini] +provider = google +api_key = +``` + ### Install `fish-ai` Install the plugin. You can install it using [`fisher`](https://github.com/jorgebucaran/fisher). diff --git a/conf.d/fish_ai.fish b/conf.d/fish_ai.fish index 57d73c8..624bc67 100644 --- a/conf.d/fish_ai.fish +++ b/conf.d/fish_ai.fish @@ -24,11 +24,11 @@ end ## function _fish_ai_install --on-event fish_ai_install python3 -m venv ~/.fish-ai - ~/.fish-ai/bin/pip install -qq openai + ~/.fish-ai/bin/pip install -qq openai google-generativeai end function _fish_ai_update --on-event fish_ai_update - ~/.fish-ai/bin/pip install -qq --upgrade openai + ~/.fish-ai/bin/pip install -qq --upgrade openai google-generativeai end function __fish_ai_uninstall --on-event fish_ai_uninstall diff --git a/functions/_fish_ai_engine.py b/functions/_fish_ai_engine.py index 25c06e2..099bcf7 100644 --- a/functions/_fish_ai_engine.py +++ b/functions/_fish_ai_engine.py @@ -2,6 +2,7 @@ from openai import OpenAI from openai import AzureOpenAI +import google.generativeai as genai from configparser import ConfigParser from os import path import logging @@ -49,7 +50,7 @@ def get_config(key): return config.get(section=active_section, option=key) -def get_client(): 
def create_message_history(messages):
    """
    Convert an OpenAI-style message list into Gemini chat history.

    Google's chat history format differs from OpenAI's: each entry
    carries its text in a 'parts' list, the assistant role is named
    'model', and standalone system messages are not supported — so all
    system messages are folded into the 'parts' of the first user entry.
    The last message is deliberately excluded from the history because
    the caller sends it separately via chat.send_message().

    :param messages: list of dicts with 'role' ('system'/'user'/
        'assistant') and 'content' keys, OpenAI chat-completions style.
    :return: list of Gemini-style history dicts with 'role' and 'parts'.
    """
    outputs = []
    system_messages = []
    for message in messages:
        if message.get('role') == 'system':
            system_messages.append(message.get('content'))
    for i in range(len(messages) - 1):
        message = messages[i]
        if message.get('role') == 'user':
            # Prepend the system instructions to the first user entry
            # we emit. The previous check (`i == 0`) silently dropped
            # them whenever the list started with a system message —
            # the common case — since messages[0] is then not a user
            # turn.
            if not outputs:
                parts = system_messages + [message.get('content')]
            else:
                parts = [message.get('content')]
            outputs.append({'role': 'user', 'parts': parts})
        elif message.get('role') == 'assistant':
            outputs.append({
                'role': 'model',
                'parts': [message.get('content')]
            })
    return outputs
get_logger().debug('Processing time: ' + str(round((end_time - start_time) / 1000000)) + ' ms.')