Skip to content

Commit

Permalink
generators: add output delimiters; add hook for postprocessing (#1097)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmartin-tech committed Feb 20, 2025
2 parents 859c413 + dc1d998 commit 1139fce
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 113 deletions.
9 changes: 9 additions & 0 deletions docs/source/garak.generators.base.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Attributes:
* context_len - The number of tokens in the model context window, or None
* modality - A dictionary with two keys, "in" and "out", each holding a set of the modalities supported by the generator. "in" refers to prompt expectations, and "out" refers to output. For example, a text-to-text+image model would have modality: ``dict = {"in": {"text"}, "out": {"text", "image"}}``.
* supports_multiple_generations - Whether or not the generator can natively return multiple outputs from a prompt in a single function call. When set to False, the ``generate()`` method will make repeated calls, one output at a time, until the requested number of generations (in ``generations``) is reached.
* skip_seq_start, skip_seq_end - If both are set, content between these two markers will be pruned before being returned. Useful for removing chain-of-thought, for example

Functions:

Expand All @@ -32,12 +33,20 @@ The general flow in ``generate()`` is as follows:
#. Otherwise, we need to assemble the outputs over multiple calls. There are two options here.
#. Is garak running with ``parallel_attempts > 1`` configured? In that case, start a multiprocessing pool with as many workers as the value of ``parallel_attempts``, and have each one of these work on building the required number of generations, in any order.
#. Otherwise, call ``_call_model()`` repeatedly to collect the requested number of generations.
#. Call the ``_post_generate_hook()`` (a no-op by default)
#. If skip sequence start and end are both defined, call ``_prune_skip_sequences()``
#. Return the resulting list of prompt responses.


#. **_call_model()**: This method handles direct interaction with the model. It takes a prompt and an optional number of generations this call, and returns a list of prompt responses (e.g. strings) and ``None``s. Models may return ``None`` in the case the underlying system failed unrecoverably. This is the method to write model interaction code in. If the class' supports_multiple_generations is false, _call_model does not need to accept values of ``generations_this_call`` other than ``1``.

#. **_pre_generate_hook()**: An optional hook called before generation, useful if the class needs to do some setup or housekeeping before generation.

#. **_verify_model_result()**: Validation of model output types, useful in debugging. If this fails, the generator doesn't match the expectations in the rest of garak.

#. **_post_generate_hook()**: An optional hook called after generation, useful if the class needs to do some modification of output.

#. **_prune_skip_sequences()**: Called if both ``skip_seq_start`` and ``skip_seq_end`` are defined. Strip out any response content between the start and end markers.



Expand Down
38 changes: 38 additions & 0 deletions garak/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import logging
import re
from typing import List, Union

from colorama import Fore, Style
Expand All @@ -23,6 +24,8 @@ class Generator(Configurable):
"temperature": None,
"top_k": None,
"context_len": None,
"skip_seq_start": None,
"skip_seq_end": None,
}

active = True
Expand Down Expand Up @@ -85,6 +88,35 @@ def _verify_model_result(result: List[Union[str, None]]):
def clear_history(self):
pass

def _post_generate_hook(self, outputs: List[str | None]) -> List[str | None]:
    """Hook called on the collected outputs after generation completes.

    A no-op by default: returns ``outputs`` unchanged. Subclasses may
    override this to post-process model responses (e.g. cleanup or
    normalisation) before ``generate()`` returns them.
    """
    return outputs

def _prune_skip_sequences(self, outputs: List[str | None]) -> List[str | None]:
rx_complete = (
re.escape(self.skip_seq_start) + ".*?" + re.escape(self.skip_seq_end)
)
rx_missing_final = re.escape(self.skip_seq_start) + ".*?$"

complete_seqs_removed = [
(
re.sub(rx_complete, "", o, flags=re.DOTALL | re.MULTILINE)
if o is not None
else None
)
for o in outputs
]

partial_seqs_removed = [
(
re.sub(rx_missing_final, "", o, flags=re.DOTALL | re.MULTILINE)
if o is not None
else None
)
for o in complete_seqs_removed
]

return partial_seqs_removed

def generate(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
Expand Down Expand Up @@ -152,4 +184,10 @@ def generate(
self._verify_model_result(output_one)
outputs.append(output_one[0])

outputs = self._post_generate_hook(outputs)

if hasattr(self, "skip_seq_start") and hasattr(self, "skip_seq_end"):
if self.skip_seq_start is not None and self.skip_seq_end is not None:
outputs = self._prune_skip_sequences(outputs)

return outputs
2 changes: 2 additions & 0 deletions garak/generators/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ class LiteLLMGenerator(Generator):
"top_k",
"frequency_penalty",
"presence_penalty",
"skip_seq_start",
"skip_seq_end",
"stop",
)

Expand Down
6 changes: 3 additions & 3 deletions garak/generators/nim.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class NVOpenAIChat(OpenAICompatible):
"uri": "https://integrate.api.nvidia.com/v1/",
"vary_seed_each_call": True, # encourage variation when generations>1. not respected by all NIMs
"vary_temp_each_call": True, # encourage variation when generations>1. not respected by all NIMs
"suppressed_params": {"n", "frequency_penalty", "presence_penalty"},
"suppressed_params": {"n", "frequency_penalty", "presence_penalty", "timeout"},
}
active = True
supports_multiple_generations = False
Expand Down Expand Up @@ -91,8 +91,8 @@ def _call_model(
logging.critical(msg, exc_info=uee)
raise GarakException(f"🛑 {msg}") from uee
# except openai.NotFoundError as oe:
except Exception as oe:
msg = "NIM endpoint not found. Is the model name spelled correctly?"
except Exception as oe: # too broad
msg = "NIM generation failed. Is the model name spelled correctly?"
logging.critical(msg, exc_info=oe)
raise GarakException(f"🛑 {msg}") from oe

Expand Down
21 changes: 20 additions & 1 deletion garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class OpenAICompatible(Generator):
"stop": ["#", ";"],
"suppressed_params": set(),
"retry_json": True,
"extra_params": {},
}

# avoid attempt to pickle the client attribute
Expand Down Expand Up @@ -220,8 +221,15 @@ def _call_model(
if arg == "model":
create_args[arg] = self.name
continue
if arg == "extra_params":
continue
if hasattr(self, arg) and arg not in self.suppressed_params:
create_args[arg] = getattr(self, arg)
if getattr(self, arg) is not None:
create_args[arg] = getattr(self, arg)

if hasattr(self, "extra_params"):
for k, v in self.extra_params.items():
create_args[k] = v

if self.generator == self.client.completions:
if not isinstance(prompt, str):
Expand Down Expand Up @@ -263,6 +271,17 @@ def _call_model(
else:
raise e

if not hasattr(response, "choices"):
logging.debug(
"Did not get a well-formed response, retrying. Expected object with .choices member, got: '%s'"
% repr(response)
)
msg = "no .choices member in generator response"
if self.retry_json:
raise garak.exception.GarakBackoffTrigger(msg)
else:
return [None]

if self.generator == self.client.completions:
return [c.text for c in response.choices]
elif self.generator == self.client.chat.completions:
Expand Down
2 changes: 2 additions & 0 deletions garak/generators/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class RestGenerator(Generator):
"request_timeout",
"ratelimit_codes",
"skip_codes",
"skip_seq_start",
"skip_seq_end",
"temperature",
"top_k",
"proxies",
Expand Down
1 change: 0 additions & 1 deletion tests/generators/test_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import re

from garak import cli
Expand Down
147 changes: 39 additions & 108 deletions tests/generators/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import importlib
import inspect
import pytest
import random

from garak import _plugins
from garak import _config
from garak.generators.test import Blank, Repeat, Single
from garak.generators.base import Generator


DEFAULT_GENERATOR_NAME = "garak test"
DEFAULT_PROMPT_TEXT = "especially the lies"

Expand All @@ -20,112 +19,6 @@
]


def test_generators_test_blank():
    """Blank generator returns one empty string per requested generation."""
    blank_gen = Blank(DEFAULT_GENERATOR_NAME)
    results = blank_gen.generate(prompt="test", generations_this_call=5)
    expected = [""] * 5
    assert (
        results == expected
    ), "generators.test.Blank with generations_this_call=5 should return five empty strings"


def test_generators_test_repeat():
    """Repeat generator echoes the posed prompt back verbatim."""
    repeat_gen = Repeat(DEFAULT_GENERATOR_NAME)
    echoed = repeat_gen.generate(prompt=DEFAULT_PROMPT_TEXT)
    assert echoed == [
        DEFAULT_PROMPT_TEXT
    ], "generators.test.Repeat should send back a list of the posed prompt string"


def test_generators_test_single_one():
    """Single generator yields exactly one string from generate() and _call_model()."""
    single_gen = Single(DEFAULT_GENERATOR_NAME)

    # public entry point
    via_generate = single_gen.generate(prompt="test")
    assert isinstance(
        via_generate, list
    ), "Single generator .generate() should send back a list"
    assert (
        len(via_generate) == 1
    ), "Single.generate() without generations_this_call should send a list of one string"
    assert isinstance(
        via_generate[0], str
    ), "Single generator output list should contain strings"

    # direct model call
    via_call_model = single_gen._call_model(prompt="test")
    assert isinstance(
        via_call_model, list
    ), "Single generator _call_model should return a list"
    assert (
        len(via_call_model) == 1
    ), "_call_model w/ generations_this_call 1 should return a list of length 1"
    assert isinstance(
        via_call_model[0], str
    ), "Single generator output list should contain strings"


def test_generators_test_single_many():
    """Single generator honours an arbitrary generations_this_call count."""
    requested = random.randint(2, 12)
    single_gen = Single(DEFAULT_GENERATOR_NAME)
    results = single_gen.generate(prompt="test", generations_this_call=requested)
    assert isinstance(
        results, list
    ), "Single generator .generate() should send back a list"
    assert (
        len(results) == requested
    ), "Single.generate() with generations_this_call should return equal generations"
    for item in results:
        assert isinstance(
            item, str
        ), "Single generator output list should contain strings (all positions)"


def test_generators_test_single_too_many():
    """Single._call_model should refuse to process generations_this_call > 1."""
    # The refusal itself is verified by pytest.raises; the original trailing
    # `assert "<message>"` was a bare string literal — always truthy, checking
    # nothing — so its text now lives in the docstring instead.
    g = Single(DEFAULT_GENERATOR_NAME)
    with pytest.raises(ValueError):
        g._call_model(prompt="test", generations_this_call=2)


def test_generators_test_blank_one():
    """Blank generator defaults to a single empty-string output."""
    blank_gen = Blank(DEFAULT_GENERATOR_NAME)
    results = blank_gen.generate(prompt="test")
    assert isinstance(
        results, list
    ), "Blank generator .generate() should send back a list"
    assert (
        len(results) == 1
    ), "Blank generator .generate() without generations_this_call should return a list of length 1"
    assert isinstance(
        results[0], str
    ), "Blank generator output list should contain strings"
    assert (
        results[0] == ""
    ), "Blank generator .generate() output list should contain strings"


def test_generators_test_blank_many():
    """Blank generator returns two empty strings when two generations are requested."""
    blank_gen = Blank(DEFAULT_GENERATOR_NAME)
    results = blank_gen.generate(prompt="test", generations_this_call=2)
    assert isinstance(
        results, list
    ), "Blank generator .generate() should send back a list"
    assert (
        len(results) == 2
    ), "Blank generator .generate() w/ generations_this_call=2 should return a list of length 2"
    assert isinstance(
        results[0], str
    ), "Blank generator output list should contain strings (first position)"
    assert isinstance(
        results[1], str
    ), "Blank generator output list should contain strings (second position)"
    assert (
        results[0] == ""
    ), "Blank generator .generate() output list should contain strings (first position)"
    assert (
        results[1] == ""
    ), "Blank generator .generate() output list should contain strings (second position)"


def test_parallel_requests():
_config.system.parallel_requests = 2

Expand Down Expand Up @@ -221,3 +114,41 @@ def test_instantiate_generators(classname):
m = importlib.import_module("garak." + ".".join(classname.split(".")[:-1]))
g = getattr(m, classname.split(".")[-1])(config_root=config_root)
assert isinstance(g, Generator)


def test_skip_seq():
    """Exercise skip_seq_start/skip_seq_end pruning on the test.Repeat generator:
    unset markers pass output through; set markers strip single, multiple,
    newline-filled, truncated, and output-consuming skip sequences."""
    target_string = "TEST TEST 1234"
    test_string_with_thinking = "TEST TEST <think>not thius tho</think>1234"
    test_string_with_thinking_complex = '<think></think>TEST TEST <think>not thius tho</think>1234<think>!"(^-&$(!$%*))</think>'
    test_string_with_newlines = "<think>\n\n</think>" + target_string
    g = _plugins.load_plugin("generators.test.Repeat")

    # BUG FIX: clear the skip sequence *before* generating — the original set
    # these to None after calling generate(), so the "unset" assertion only
    # passed because the plugin's defaults happened to be None already.
    g.skip_seq_start = None
    g.skip_seq_end = None
    r = g.generate(test_string_with_thinking)
    assert (
        r[0] == test_string_with_thinking
    ), "test.Repeat should give same output as input when no think tokens specified"

    g.skip_seq_start = "<think>"
    g.skip_seq_end = "</think>"
    r = g.generate(test_string_with_thinking)
    assert (
        r[0] == target_string
    ), "content between single skip sequence should be removed"
    r = g.generate(test_string_with_thinking_complex)
    assert (
        r[0] == target_string
    ), "content between multiple skip sequences should be removed"
    r = g.generate(test_string_with_newlines)
    assert r[0] == target_string, "skip seqs full of newlines should be removed"

    test_no_answer = "<think>not sure the output to provide</think>"
    r = g.generate(test_no_answer)
    assert r[0] == "", "Output of all skip strings should be empty"

    test_truncated_think = f"<think>thinking a bit</think>{target_string}<think>this process required a lot of details that is processed by"
    r = g.generate(test_truncated_think)
    assert r[0] == target_string, "truncated skip strings should be omitted"

    test_truncated_think_no_answer = "<think>thinking a bit</think><think>this process required a lot of details that is processed by"
    r = g.generate(test_truncated_think_no_answer)
    assert r[0] == "", "truncated skip strings should be omitted"
Loading

0 comments on commit 1139fce

Please sign in to comment.