Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added watsonx.ai generator #1058

Merged
merged 9 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/garak.generators.watsonx.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
garak.generators.watsonx
========================

.. automodule:: garak.generators.watsonx
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/generators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ For a detailed overview of how a generator operates, see :ref:`garak.generato
garak.generators.rest
garak.generators.rasa
garak.generators.test
garak.generators.watsonx

126 changes: 126 additions & 0 deletions garak/generators/watsonx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from garak import _config
from garak.generators.base import Generator
from typing import List, Union
import os
import importlib


class WatsonXGenerator(Generator):
    """
    This is a generator for watsonx.ai.

    Make sure that you initialize the environment variables:
    'WATSONX_TOKEN',
    'WATSONX_URL',
    and 'WATSONX_PROJECTID'.

    To use a tuned model that is deployed, use 'deployment/deployment' for the -n flag and make sure
    to also initialize the 'WATSONX_DEPLOYID' environment variable.
    """

    # Names of the environment variables consulted for credentials / resource ids.
    ENV_VAR = "WATSONX_TOKEN"
    URI_ENV_VAR = "WATSONX_URL"
    PID_ENV_VAR = "WATSONX_PROJECTID"
    DID_ENV_VAR = "WATSONX_DEPLOYID"
    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "uri": None,
        "project_id": None,
        "deployment_id": None,
        "frequency_penalty": 0.5,
        "logprobs": True,
        "top_logprobs": 3,
        "presence_penalty": 0.3,
        "temperature": 0.7,
        "max_tokens": 100,
        "time_limit": 300000,
        "top_p": 0.9,
        "n": 1,
    }

    generator_family_name = "watsonx"

    def __init__(self, name="", config_root=_config):
        """Resolve credentials/ids from config or environment and lazily import the SDK.

        Raises ValueError when neither config nor environment provides the
        service URL or the project id.
        """
        super().__init__(name, config_root=config_root)

        # Mirror a config-supplied API key into the environment so the
        # ibm_watsonx_ai SDK can also find it there.
        if self.api_key is not None:
            os.environ[self.ENV_VAR] = self.api_key

        # Fall back to the environment for the service URL. Use the class
        # constant so the variable actually read always matches the one
        # named in the error message.
        if self.uri is None:
            self.uri = os.getenv(self.URI_ENV_VAR, None)
            if self.uri is None:
                raise ValueError(
                    f"The {self.URI_ENV_VAR} environment variable is required. Please enter the URL corresponding to the region of your provisioned service instance. \n"
                )

        # Same fallback for the project id.
        if self.project_id is None:
            self.project_id = os.getenv(self.PID_ENV_VAR, None)
            if self.project_id is None:
                raise ValueError(
                    f"The {self.PID_ENV_VAR} environment variable is required. Please enter the corresponding Project ID of the resource. \n"
                )

        # Import the SDK lazily so garak can load this module even when
        # ibm_watsonx_ai is not installed.
        self.watsonx = importlib.import_module("ibm_watsonx_ai.foundation_models")
        self.Credentials = getattr(
            importlib.import_module("ibm_watsonx_ai"), "Credentials"
        )

    def get_model(self):
        """Build a ModelInference handle for either a tuned deployment or a base model.

        Raises ValueError when a deployment is requested ('deployment/deployment')
        but no deployment id is configured or present in the environment.
        """
        # Call Credentials function with the url and api_key.
        credentials = self.Credentials(url=self.uri, api_key=self.api_key)

        if self.name == "deployment/deployment":
            # Only consult the environment when config did not already supply a
            # deployment id — previously a config-supplied value was clobbered.
            if self.deployment_id is None:
                self.deployment_id = os.getenv(self.DID_ENV_VAR, None)
            if self.deployment_id is None:
                raise ValueError(
                    f"The {self.DID_ENV_VAR} environment variable is required. Please enter the corresponding Deployment ID of the resource. \n"
                )

            return self.watsonx.ModelInference(
                deployment_id=self.deployment_id,
                credentials=credentials,
                project_id=self.project_id,
            )

        return self.watsonx.ModelInference(
            model_id=self.name,
            credentials=credentials,
            project_id=self.project_id,
            params=self.watsonx.schema.TextChatParameters(
                frequency_penalty=self.frequency_penalty,
                logprobs=self.logprobs,
                top_logprobs=self.top_logprobs,
                presence_penalty=self.presence_penalty,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                time_limit=self.time_limit,
                top_p=self.top_p,
                n=self.n,
            ),
        )

    def _call_model(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        """Send *prompt* to the model and return a one-element list with its text."""

        # Get/Create Model
        model = self.get_model()

        # The API rejects empty prompts; substitute a null byte so the call succeeds.
        if not prompt:
            prompt = "\x00"
            print(
                "WARNING: Empty prompt was found. Null byte character appended to prevent API failure."
            )

        # "generated_text" is already a str, so return it directly (the previous
        # "".join(...) over its characters was a no-op).
        return [model.generate(prompt=prompt)["results"][0]["generated_text"]]


# Class garak instantiates when this module is selected without an explicit class name.
DEFAULT_CLASS = "WatsonXGenerator"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ dependencies = [
"zalgolib>=0.2.2",
"ecoji>=0.1.1",
"deepl==1.17.0",
"ibm-watsonx-ai==1.1.25",
"fschat>=0.2.36",
"litellm>=1.41.21",
"jsonpath-ng>=1.6.1",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ deepl==1.17.0
fschat>=0.2.36
litellm>=1.41.21
jsonpath-ng>=1.6.1
ibm-watsonx-ai==1.1.25
huggingface_hub>=0.21.0
python-magic-bin>=0.4.14; sys_platform == "win32"
python-magic>=0.4.21; sys_platform != "win32"
Expand Down
1 change: 1 addition & 0 deletions tests/generators/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def test_instantiate_generators(classname):
"org_id": "fake", # required for NeMo
"uri": "https://example.com", # required for rest
"provider": "fake", # required for LiteLLM
"project_id": "fake", # required for watsonx
}
}
}
Expand Down
Loading