diff --git a/mindsql/_utils/constants.py b/mindsql/_utils/constants.py index 58bc3fd..704dc57 100644 --- a/mindsql/_utils/constants.py +++ b/mindsql/_utils/constants.py @@ -36,3 +36,4 @@ SQLSERVER_SHOW_DATABASE_QUERY= "SELECT name FROM sys.databases;" SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT CONCAT(TABLE_SCHEMA,'.',TABLE_NAME) FROM [{db}].INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'" SQLSERVER_SHOW_CREATE_TABLE_QUERY = "DECLARE @TableName NVARCHAR(MAX) = '{table}'; DECLARE @SchemaName NVARCHAR(MAX) = '{schema}'; DECLARE @SQL NVARCHAR(MAX); SELECT @SQL = 'CREATE TABLE ' + @SchemaName + '.' + t.name + ' (' + CHAR(13) + ( SELECT ' ' + c.name + ' ' + UPPER(tp.name) + CASE WHEN tp.name IN ('char', 'varchar', 'nchar', 'nvarchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' ELSE CAST(c.max_length AS VARCHAR(10)) END + ')' WHEN tp.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR(10)) + ',' + CAST(c.scale AS VARCHAR(10)) + ')' ELSE '' END + ',' + CHAR(13) FROM sys.columns c JOIN sys.types tp ON c.user_type_id = tp.user_type_id WHERE c.object_id = t.object_id ORDER BY c.column_id FOR XML PATH(''), TYPE ).value('.', 'NVARCHAR(MAX)') + CHAR(13) + ')' FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id WHERE t.name = @TableName AND s.name = @SchemaName; SELECT @SQL AS SQLQuery;" +OLLAMA_CONFIG_REQUIRED = "{type} configuration is required." diff --git a/mindsql/llms/ollama.py b/mindsql/llms/ollama.py new file mode 100644 index 0000000..647bdd9 --- /dev/null +++ b/mindsql/llms/ollama.py @@ -0,0 +1,105 @@ +from ollama import Client, Options + +from .illm import ILlm +from .._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED +from .._utils import logger + +log = logger.init_loggers("Ollama Client") + + +class Ollama(ILlm): + def __init__(self, model_config: dict, client_config=None, client: Client = None): + """ + Initialize the class with an optional config parameter. + + Parameters: + model_config (dict): The model configuration parameter. + config (dict): The configuration parameter. + client (Client): The client parameter. + + Returns: + None + """ + self.client = client + self.client_config = client_config + self.model_config = model_config + + if self.client is not None: + if self.client_config is not None: + log.warning("Client object provided. Ignoring client_config parameter.") + return + + if client_config is None: + raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Client")) + + if model_config is None: + raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model")) + + if 'model' not in model_config: + raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model name")) + + self.client = Client(**client_config) + + def system_message(self, message: str) -> any: + """ + Create a system message. + + Parameters: + message (str): The message parameter. + + Returns: + any + """ + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + """ + Create a user message. + + Parameters: + message (str): The message parameter. + + Returns: + any + """ + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + """ + Create an assistant message. + + Parameters: + message (str): The message parameter. + + Returns: + any + """ + return {"role": "assistant", "content": message} + + def invoke(self, prompt, **kwargs) -> str: + """ + Submit a prompt to the model for generating a response. + + Parameters: + prompt (str): The prompt parameter. + **kwargs: Additional keyword arguments (optional). + - temperature (float): The temperature parameter for controlling randomness in generation. + + Returns: + str + """ + if not prompt: + raise ValueError(PROMPT_EMPTY_EXCEPTION) + + model = self.model_config.get('model') + temperature = kwargs.get('temperature', 0.1) + + response = self.client.chat( + model=model, + messages=[self.user_message(prompt)], + options=Options( + temperature=temperature + ) + ) + + return response['message']['content'] diff --git a/tests/ollama_test.py b/tests/ollama_test.py new file mode 100644 index 0000000..385424f --- /dev/null +++ b/tests/ollama_test.py @@ -0,0 +1,83 @@ +import unittest +from unittest.mock import MagicMock, patch +from ollama import Client, Options + +from mindsql.llms import ILlm +from mindsql.llms import Ollama +from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED + + +class TestOllama(unittest.TestCase): + + def setUp(self): + # Common setup for each test case + self.model_config = {'model': 'sqlcoder'} + self.client_config = {'host': 'http://localhost:11434/'} + self.client_mock = MagicMock(spec=Client) + + def test_initialization_with_client(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + self.assertEqual(ollama.client, self.client_mock) + self.assertIsNone(ollama.client_config) + self.assertEqual(ollama.model_config, self.model_config) + + def test_initialization_with_client_config(self): + ollama = Ollama(model_config=self.model_config, client_config=self.client_config) + self.assertIsNotNone(ollama.client) + self.assertEqual(ollama.client_config, self.client_config) + self.assertEqual(ollama.model_config, self.model_config) + + def test_initialization_missing_client_and_client_config(self): + with self.assertRaises(ValueError) as context: + Ollama(model_config=self.model_config) + self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Client")) + + def test_initialization_missing_model_config(self): + with self.assertRaises(ValueError) as context: + Ollama(model_config=None, client_config=self.client_config) + self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model")) + + def test_initialization_missing_model_name(self): + with self.assertRaises(ValueError) as context: + Ollama(model_config={}, client_config=self.client_config) + self.assertEqual(str(context.exception), OLLAMA_CONFIG_REQUIRED.format(type="Model name")) + + def test_system_message(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + message = ollama.system_message("Test system message") + self.assertEqual(message, {"role": "system", "content": "Test system message"}) + + def test_user_message(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + message = ollama.user_message("Test user message") + self.assertEqual(message, {"role": "user", "content": "Test user message"}) + + def test_assistant_message(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + message = ollama.assistant_message("Test assistant message") + self.assertEqual(message, {"role": "assistant", "content": "Test assistant message"}) + + @patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}}) + def test_invoke_success(self, mock_chat): + ollama = Ollama(model_config=self.model_config, client=Client()) + response = ollama.invoke("Test prompt") + + # Check if the response is as expected + self.assertEqual(response, 'Test response') + + # Verify that the chat method was called with the correct arguments + mock_chat.assert_called_once_with( + model=self.model_config['model'], + messages=[{"role": "user", "content": "Test prompt"}], + options=Options(temperature=0.1) + ) + + def test_invoke_empty_prompt(self): + ollama = Ollama(model_config=self.model_config, client=self.client_mock) + with self.assertRaises(ValueError) as context: + ollama.invoke("") + self.assertEqual(str(context.exception), PROMPT_EMPTY_EXCEPTION) + + +if __name__ == '__main__': + unittest.main()