Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addressed #11: Added anthropic support #13

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mindsql/_utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@
CONFIG_REQUIRED_ERROR = "Configuration is required."
LLAMA_PROMPT_EXCEPTION = "Prompt cannot be empty."
OPENAI_VALUE_ERROR = "OpenAI API key is required"
OPENAI_PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
PROMPT_EMPTY_EXCEPTION = "Prompt cannot be empty."
POSTGRESQL_SHOW_CREATE_TABLE_QUERY = """SELECT 'CREATE TABLE "' || table_name || '" (' || array_to_string(array_agg(column_name || ' ' || data_type), ', ') || ');' AS create_statement FROM information_schema.columns WHERE table_name = '{table}' GROUP BY table_name;"""
ANTHROPIC_VALUE_ERROR = "Anthropic API key is required"
91 changes: 91 additions & 0 deletions mindsql/llms/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from anthropic import Anthropic

from . import ILlm
from .._utils.constants import ANTHROPIC_VALUE_ERROR, PROMPT_EMPTY_EXCEPTION


class AnthropicAi(ILlm):
    def __init__(self, config=None, client=None):
        """
        Initialize the Anthropic LLM wrapper.

        Parameters:
            config (dict | None): Configuration. When no client is supplied it
                must contain 'api_key'. It may also contain 'model' (consumed
                by invoke(), not by the Anthropic client constructor) and any
                additional keyword arguments accepted by Anthropic().
            client (any): An already-constructed Anthropic client (optional).

        Raises:
            ValueError: If neither a client nor an 'api_key' entry is provided.

        Returns:
            None
        """
        # Normalize to a dict so invoke() never hits `None.get(...)` when a
        # pre-built client is injected without a config.
        self.config = dict(config) if config else {}
        self.client = client

        if client is not None:
            # A ready-made client was supplied; no API key is required.
            return

        if 'api_key' not in self.config:
            raise ValueError(ANTHROPIC_VALUE_ERROR)

        # Build constructor kwargs on a copy so the caller's dict is never
        # mutated, and strip keys the Anthropic() constructor does not accept:
        # 'api_key' is passed explicitly and 'model' belongs to invoke().
        client_kwargs = {k: v for k, v in self.config.items()
                         if k not in ('api_key', 'model')}
        self.client = Anthropic(api_key=self.config['api_key'], **client_kwargs)

    def system_message(self, message: str) -> any:
        """
        Create a system message.

        NOTE(review): Anthropic's Messages API takes the system prompt as a
        separate `system` parameter rather than a "system" role inside the
        messages list — confirm callers do not pass this dict to
        messages.create() directly.

        Parameters:
            message (str): The message content.

        Returns:
            any: A role/content mapping with role "system".
        """
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        """
        Create a user message.

        Parameters:
            message (str): The message content.

        Returns:
            any: A role/content mapping with role "user".
        """
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        """
        Create an assistant message.

        Parameters:
            message (str): The message content.

        Returns:
            any: A role/content mapping with role "assistant".
        """
        return {"role": "assistant", "content": message}

    def invoke(self, prompt, **kwargs) -> str:
        """
        Submit a prompt to the model for generating a response.

        Parameters:
            prompt (str): The prompt to send; must be non-empty.
            **kwargs: Additional keyword arguments (optional).
                - temperature (float): Sampling temperature (default 0.1).
                - max_tokens (int): Maximum tokens to generate (default 1024).

        Returns:
            str | None: The first text block of the model's response, or None
            if the response contains no text content.

        Raises:
            Exception: If the prompt is empty (PROMPT_EMPTY_EXCEPTION), to
            mirror the behavior of the OpenAi wrapper in this package.
        """
        if not prompt:
            raise Exception(PROMPT_EMPTY_EXCEPTION)

        model = self.config.get("model", "claude-3-opus-20240229")
        temperature = kwargs.get("temperature", 0.1)
        max_tokens = kwargs.get("max_tokens", 1024)
        response = self.client.messages.create(model=model, messages=[{"role": "user", "content": prompt}],
                                               max_tokens=max_tokens, temperature=temperature)
        # The SDK may return content blocks as objects with a `.text`
        # attribute or as plain dicts; return the first text block found.
        for content in response.content:
            if isinstance(content, dict) and content.get("type") == "text":
                return content["text"]
            elif hasattr(content, "text"):
                return content.text
        # No text block present — make the implicit fall-through explicit.
        return None
4 changes: 2 additions & 2 deletions mindsql/llms/open_ai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from openai import OpenAI

from . import ILlm
from .._utils.constants import OPENAI_VALUE_ERROR, OPENAI_PROMPT_EMPTY_EXCEPTION
from .._utils.constants import OPENAI_VALUE_ERROR, PROMPT_EMPTY_EXCEPTION


class OpenAi(ILlm):
Expand Down Expand Up @@ -77,7 +77,7 @@ def invoke(self, prompt, **kwargs) -> str:
str: The generated response from the model.
"""
if prompt is None or len(prompt) == 0:
raise Exception(OPENAI_PROMPT_EMPTY_EXCEPTION)
raise Exception(PROMPT_EMPTY_EXCEPTION)

model = self.config.get("model", "gpt-3.5-turbo")
temperature = kwargs.get("temperature", 0.1)
Expand Down