Skip to content

Commit

Permalink
Merge branch 'master' into sqlserver-integration
Browse files Browse the repository at this point in the history
  • Loading branch information
cyrannano authored Aug 23, 2024
2 parents 845f26d + b95ca73 commit 701c046
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 0 deletions.
1 change: 1 addition & 0 deletions mindsql/_utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@
# --- SQL Server introspection queries -----------------------------------
# Lists every database on the server instance.
SQLSERVER_SHOW_DATABASE_QUERY = "SELECT name FROM sys.databases;"
# Lists base tables of database {db} as "schema.table" strings.
SQLSERVER_DB_TABLES_INFO_SCHEMA_QUERY = "SELECT CONCAT(TABLE_SCHEMA,'.',TABLE_NAME) FROM [{db}].INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE'"
# Reconstructs a CREATE TABLE statement for {schema}.{table} from the system
# catalogs (sys.tables/sys.columns/sys.types), since SQL Server has no native
# SHOW CREATE TABLE. Returns one row with column alias SQLQuery.
SQLSERVER_SHOW_CREATE_TABLE_QUERY = "DECLARE @TableName NVARCHAR(MAX) = '{table}'; DECLARE @SchemaName NVARCHAR(MAX) = '{schema}'; DECLARE @SQL NVARCHAR(MAX); SELECT @SQL = 'CREATE TABLE ' + @SchemaName + '.' + t.name + ' (' + CHAR(13) + ( SELECT ' ' + c.name + ' ' + UPPER(tp.name) + CASE WHEN tp.name IN ('char', 'varchar', 'nchar', 'nvarchar') THEN '(' + CASE WHEN c.max_length = -1 THEN 'MAX' ELSE CAST(c.max_length AS VARCHAR(10)) END + ')' WHEN tp.name IN ('decimal', 'numeric') THEN '(' + CAST(c.precision AS VARCHAR(10)) + ',' + CAST(c.scale AS VARCHAR(10)) + ')' ELSE '' END + ',' + CHAR(13) FROM sys.columns c JOIN sys.types tp ON c.user_type_id = tp.user_type_id WHERE c.object_id = t.object_id ORDER BY c.column_id FOR XML PATH(''), TYPE ).value('.', 'NVARCHAR(MAX)') + CHAR(13) + ')' FROM sys.tables t JOIN sys.schemas s ON t.schema_id = s.schema_id WHERE t.name = @TableName AND s.name = @SchemaName; SELECT @SQL AS SQLQuery;"
# Error template for missing Ollama configuration; {type} names the missing
# piece (e.g. "Client", "Model", "Model name").
OLLAMA_CONFIG_REQUIRED = "{type} configuration is required."
105 changes: 105 additions & 0 deletions mindsql/llms/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from ollama import Client, Options

from .illm import ILlm
from .._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED
from .._utils import logger

log = logger.init_loggers("Ollama Client")


class Ollama(ILlm):
    """ILlm implementation backed by an Ollama server via the `ollama` client."""

    def __init__(self, model_config: dict, client_config: dict = None, client: Client = None):
        """
        Initialize the wrapper with a model configuration and either a
        pre-built client or a client configuration.

        Parameters:
            model_config (dict): Model configuration; must contain a 'model' key.
            client_config (dict): Keyword arguments forwarded to ollama.Client
                (e.g. host). Ignored when an explicit client is supplied.
            client (Client): Pre-built ollama client; takes precedence over
                client_config.

        Raises:
            ValueError: If model_config is missing or lacks a model name, or if
                neither a client nor a client_config is provided.

        Returns:
            None
        """
        self.client = client
        self.client_config = client_config
        self.model_config = model_config

        # invoke() always reads the model name from model_config, so validate
        # it up front regardless of how the client is obtained. (Previously a
        # missing model_config slipped through when a client object was
        # passed, and only failed later inside invoke().)
        if model_config is None:
            raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model"))

        if 'model' not in model_config:
            raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Model name"))

        if self.client is not None:
            # An explicit client wins; client_config is deliberately ignored.
            if self.client_config is not None:
                log.warning("Client object provided. Ignoring client_config parameter.")
            return

        if client_config is None:
            raise ValueError(OLLAMA_CONFIG_REQUIRED.format(type="Client"))

        self.client = Client(**client_config)

    def system_message(self, message: str) -> dict:
        """
        Create a system-role chat message.

        Parameters:
            message (str): The message content.

        Returns:
            dict: {"role": "system", "content": message}
        """
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """
        Create a user-role chat message.

        Parameters:
            message (str): The message content.

        Returns:
            dict: {"role": "user", "content": message}
        """
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """
        Create an assistant-role chat message.

        Parameters:
            message (str): The message content.

        Returns:
            dict: {"role": "assistant", "content": message}
        """
        return {"role": "assistant", "content": message}

    def invoke(self, prompt: str, **kwargs) -> str:
        """
        Submit a prompt to the model and return the generated response.

        Parameters:
            prompt (str): The prompt text; must be non-empty.
            **kwargs: Additional keyword arguments (optional).
                - temperature (float): Sampling temperature for controlling
                  randomness in generation (default 0.1).

        Returns:
            str: The content of the model's reply message.

        Raises:
            ValueError: If prompt is empty.
        """
        if not prompt:
            raise ValueError(PROMPT_EMPTY_EXCEPTION)

        model = self.model_config.get('model')
        temperature = kwargs.get('temperature', 0.1)

        # Send a single user message; the client returns a dict whose
        # 'message'/'content' entry holds the generated text.
        response = self.client.chat(
            model=model,
            messages=[self.user_message(prompt)],
            options=Options(
                temperature=temperature
            )
        )

        return response['message']['content']
83 changes: 83 additions & 0 deletions tests/ollama_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import unittest
from unittest.mock import MagicMock, patch
from ollama import Client, Options

from mindsql.llms import ILlm
from mindsql.llms import Ollama
from mindsql._utils.constants import PROMPT_EMPTY_EXCEPTION, OLLAMA_CONFIG_REQUIRED


class TestOllama(unittest.TestCase):
    """Unit tests for the Ollama LLM wrapper."""

    def setUp(self):
        # Shared fixtures: a valid model config, a client config, and a
        # spec'd mock standing in for a real ollama.Client.
        self.model_config = {'model': 'sqlcoder'}
        self.client_config = {'host': 'http://localhost:11434/'}
        self.client_mock = MagicMock(spec=Client)

    def _with_mock_client(self):
        """Build an Ollama instance backed by the mocked client."""
        return Ollama(model_config=self.model_config, client=self.client_mock)

    def _assert_init_raises(self, expected_type, **kwargs):
        """Expect Ollama(**kwargs) to raise ValueError for the given config type."""
        with self.assertRaises(ValueError) as ctx:
            Ollama(**kwargs)
        self.assertEqual(str(ctx.exception), OLLAMA_CONFIG_REQUIRED.format(type=expected_type))

    def test_initialization_with_client(self):
        llm = self._with_mock_client()
        self.assertIs(llm.client, self.client_mock)
        self.assertIsNone(llm.client_config)
        self.assertEqual(llm.model_config, self.model_config)

    def test_initialization_with_client_config(self):
        llm = Ollama(model_config=self.model_config, client_config=self.client_config)
        self.assertIsNotNone(llm.client)
        self.assertEqual(llm.client_config, self.client_config)
        self.assertEqual(llm.model_config, self.model_config)

    def test_initialization_missing_client_and_client_config(self):
        self._assert_init_raises("Client", model_config=self.model_config)

    def test_initialization_missing_model_config(self):
        self._assert_init_raises("Model", model_config=None, client_config=self.client_config)

    def test_initialization_missing_model_name(self):
        self._assert_init_raises("Model name", model_config={}, client_config=self.client_config)

    def test_system_message(self):
        llm = self._with_mock_client()
        self.assertEqual(
            llm.system_message("Test system message"),
            {"role": "system", "content": "Test system message"},
        )

    def test_user_message(self):
        llm = self._with_mock_client()
        self.assertEqual(
            llm.user_message("Test user message"),
            {"role": "user", "content": "Test user message"},
        )

    def test_assistant_message(self):
        llm = self._with_mock_client()
        self.assertEqual(
            llm.assistant_message("Test assistant message"),
            {"role": "assistant", "content": "Test assistant message"},
        )

    @patch.object(Client, 'chat', return_value={'message': {'content': 'Test response'}})
    def test_invoke_success(self, mock_chat):
        llm = Ollama(model_config=self.model_config, client=Client())
        result = llm.invoke("Test prompt")

        self.assertEqual(result, 'Test response')

        # The wrapper must forward the configured model, a single user
        # message, and the default temperature of 0.1.
        mock_chat.assert_called_once_with(
            model='sqlcoder',
            messages=[{"role": "user", "content": "Test prompt"}],
            options=Options(temperature=0.1),
        )

    def test_invoke_empty_prompt(self):
        llm = self._with_mock_client()
        with self.assertRaises(ValueError) as ctx:
            llm.invoke("")
        self.assertEqual(str(ctx.exception), PROMPT_EMPTY_EXCEPTION)


if __name__ == '__main__':
    unittest.main()

0 comments on commit 701c046

Please sign in to comment.