diff --git a/mindsql/_utils/constants.py b/mindsql/_utils/constants.py index 0d59f58..3d81e7b 100644 --- a/mindsql/_utils/constants.py +++ b/mindsql/_utils/constants.py @@ -32,4 +32,5 @@ OPENAI_VALUE_ERROR = "OpenAI API key is required" PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty." POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;""" -ANTHROPIC_VALUE_ERROR = "Anthropic API key is required" \ No newline at end of file +ANTHROPIC_VALUE_ERROR = "Anthropic API key is required" +OLLAMA_CONFIG_REQUIRED = "{type} configuration is required." diff --git a/mindsql/llms/ollama.py b/mindsql/llms/ollama.py new file mode 100644 index 0000000..647bdd9 --- /dev/null +++ b/mindsql/llms/ollama.py @@ -0,0 +1,105 @@ +from ollama import Client, Options + +from .illm import ILlm +from .._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED +from .._utils import logger + +log = logger.init_loggers("Ollama Client") + + +class Ollama(ILlm): + def __init__(self, model_config: dict, client_config=None, client: Client = None): + """ + Initialize the class with an optional config parameter. + + Parameters: + model_config (dict): The model configuration parameter. + config (dict): The configuration parameter. + client (Client): The client parameter. + + Returns: + None + """ + self.client = client + self.client_config = client_config + self.model_config = model_config + + if self.client is not None: + if self.client_config is not None: + log.warning("Client object provided. Ignoring client_config parameter.") + return + + if client_config is None: + raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Client")) + + if model_config is None: + raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model")) + + if 'model' not in model_config: + raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model name")) + + self.client = Client(**client_config) + + def system_message(self, message: str) -> any: + """ + Create a system message. + + Parameters: + message (str): The message parameter. + + Returns: + any + """ + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + """ + Create a user message. + + Parameters: + message (str): The message parameter. + + Returns: + any + """ + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + """ + Create an assistant message. + + Parameters: + message (str): The message parameter. + + Returns: + any + """ + return {"role": "assistant", "content": message} + + def invoke(self, prompt, **kwargs) -> str: + """ + Submit a prompt to the model for generating a response. + + Parameters: + prompt (str): The prompt parameter. + **kwargs: Additional keyword arguments (optional). + - temperature (float): The temperature parameter for controlling randomness in generation. + + Returns: + str + """ + if not prompt: + raise ValueError(PROMPT_EMPTY_EXCEPTION) + + model = self.model_config.get('model') + temperature = kwargs.get('temperature', 0.1) + + response = self.client.chat( + model=model, + messages=[self.user_message(prompt)], + options=Options( + temperature=temperature + ) + ) + + return response['message']['content'] diff --git a/tests/ollama_test.py b/tests/ollama_test.py new file mode 100644 index 0000000..385424f --- /dev/null +++ b/tests/ollama_test.py @@ -0,0 +1,83 @@ +import unittest +from unittest.mock import MagicMock, patch +from ollama import Client, Options + +from mindsql.llms import ILlm +from mindsql.llms import Ollama +from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED + + +class TestOllama(unittest.TestCase): + + def setUp(self): + # Common setup for each test case + self.model_config = {'model': 'sqlcoder'} + self.client_config = {'host': 'http://localhost:11434/'} + self.client_mock = MagicMock(spec=Client) + + def test_initialization_with_client(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + self.assertEqual(ollama.client, self.client_mock) + self.assertIsNone(ollama.client_config) + self.assertEqual(ollama.model_config, self.model_config) + + def test_initialization_with_client_config(self): + ollama = Ollama(model_config=self.model_config, client_config=self.client_config) + self.assertIsNotNone(ollama.client) + self.assertEqual(ollama.client_config, self.client_config) + self.assertEqual(ollama.model_config, self.model_config) + + def test_initialization_missing_client_and_client_config(self): + with self.assertRaises(ValueError) as context: + Ollama(model_config=self.model_config) + self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Client")) + + def test_initialization_missing_model_config(self): + with self.assertRaises(ValueError) as context: + Ollama(model_config=None, client_config=self.client_config) + self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model")) + + def test_initialization_missing_model_name(self): + with self.assertRaises(ValueError) as context: + Ollama(model_config={}, client_config=self.client_config) + self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model name")) + + def test_system_message(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + message = ollama.system_message("Test system message") + self.assertEqual(message, {"role": "system", "content": "Test system message"}) + + def test_user_message(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + message = ollama.user_message("Test user message") + self.assertEqual(message, {"role": "user", "content": "Test user message"}) + + def test_assistant_message(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + message = ollama.assistant_message("Test assistant message") + self.assertEqual(message, {"role": "assistant", "content": "Test assistant message"}) + + @patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}}) + def test_invoke_success(self, mock_chat): + ollama = Ollama(model_config=self.model_config, client=Client()) + response = ollama.invoke("Test prompt") + + # Check if the response is as expected + self.assertEqual(response, 'Test response') + + # Verify that the chat method was called with the correct arguments + mock_chat.assert_called_once_with( + model=self.model_config['model'], + messages=[{"role": "user", "content": "Test prompt"}], + options=Options(temperature=0.1) + ) + + def test_invoke_empty_prompt(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + with self.assertRaises(ValueError) as context: + ollama.invoke("") + self.assertEqual(str(context.exception), PROMPT_EMPTY_EXCEPTION) + + +if __name__ == '__main__': + unittest.main()