Add more generate() kwargs (#280)
mrwyattii authored Nov 10, 2023
1 parent 62afb2c commit 7313a87
Showing 7 changed files with 86 additions and 15 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -142,11 +142,14 @@ While only the model name or path is required to stand up a non-persistent pipel
Users can also control the generation characteristics for individual prompts (i.e., when calling `pipe()`) with the following options, illustrated in the example after this list:

- `max_length: int` Sets the per-prompt maximum token length for prompt + response.
- `min_new_tokens: int` Sets the minimum number of tokens generated in the response. `max_length` will take precedence over this setting.
- `max_new_tokens: int` Sets the maximum number of tokens generated in the response.
- `ignore_eos: bool` (Defaults to `False`) Setting to `True` prevents generation from ending when the EOS token is encountered.
- `top_p: float` (Defaults to `0.9`) When set below `1.0`, keep only the smallest set of most probable tokens whose probabilities sum to ≥ `top_p`.
- `top_k: int` (Defaults to `None`) When `None`, top-k filtering is disabled. When set, keep only the `top_k` highest-probability tokens.
- `temperature: float` (Defaults to `None`) When `None`, temperature is disabled. When set, modulates token probabilities.
- `do_sample: bool` (Defaults to `True`) When `True`, sample from the output logits. When `False`, use greedy decoding.
- `return_full_text: bool` (Defaults to `False`) When `True`, prepends the input prompt to the returned text.

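For example, a minimal sketch of a non-persistent pipeline call using several of these options (the model name here is only illustrative; any supported Hugging Face model path should work):

```python
import mii

# Stand up a non-persistent pipeline and generate with per-prompt kwargs.
pipe = mii.pipeline("mistralai/Mistral-7B-v0.1")
response = pipe(
    ["DeepSpeed is"],
    max_length=128,          # cap on prompt + response tokens
    min_new_tokens=16,       # generate at least 16 new tokens
    do_sample=True,          # set to False for greedy decoding
    top_p=0.9,
    top_k=50,
    temperature=0.9,
    return_full_text=True,   # prepend the prompt to the returned text
)
print(response)
```
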
## Persistent Deployment

@@ -234,11 +237,14 @@ While only the model name or path is required to stand up a persistent deploymen
Users can also control the generation characteristics for individual prompts (i.e., when calling `client.generate()`) with the following options, illustrated in the example after this list:

- `max_length: int` Sets the per-prompt maximum token length for prompt + response.
- `min_new_tokens: int` Sets the minimum number of tokens generated in the response. `max_length` will take precedence over this setting.
- `max_new_tokens: int` Sets the maximum number of tokens generated in the response.
- `ignore_eos: bool` (Defaults to `False`) Setting to `True` prevents generation from ending when the EOS token is encountered.
- `top_p: float` (Defaults to `0.9`) When set below `1.0`, keep only the smallest set of most probable tokens whose probabilities sum to ≥ `top_p`.
- `top_k: int` (Defaults to `None`) When `None`, top-k filtering is disabled. When set, keep only the `top_k` highest-probability tokens.
- `temperature: float` (Defaults to `None`) When `None`, temperature is disabled. When set, modulates token probabilities.
- `do_sample: bool` (Defaults to `True`) When `True`, sample from the output logits. When `False`, use greedy decoding.
- `return_full_text: bool` (Defaults to `False`) When `True`, prepends the input prompt to the returned text.

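For example, a minimal sketch of querying a persistent deployment with several of these options (the model and deployment names are only illustrative):

```python
import mii

# Start a persistent deployment and query it with per-prompt kwargs.
client = mii.serve("mistralai/Mistral-7B-v0.1", deployment_name="mistral-deployment")
response = client.generate(
    ["DeepSpeed is"],
    max_new_tokens=64,
    do_sample=False,         # greedy decoding gives repeatable output
    return_full_text=True,   # prepend the prompt to the returned text
)
print(response)

# Connect from another process, then shut the deployment down when finished.
client = mii.client("mistral-deployment")
client.terminate_server()
```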

# Contributing
7 changes: 7 additions & 0 deletions mii/batching/constants.py
@@ -5,16 +5,23 @@
# generate() kwargs
MAX_LENGTH_KWARG = "max_length"
MAX_NEW_TOKENS_KWARG = "max_new_tokens"
MIN_NEW_TOKENS_KWARG = "min_new_tokens"
STREAM_KWARG = "stream"
IGNORE_EOS_KWARG = "ignore_eos"
TOP_P_KWARG = "top_p"
TOP_K_KWARG = "top_k"
TEMPERATURE_KWARG = "temperature"
RETURN_FULL_TEXT_KWARG = "return_full_text"
DO_SAMPLE_KWARG = "do_sample"
STOP_KWARG = "stop"

# Default kwarg values
MIN_NEW_TOKENS_DEFAULT = 0
STREAM_DEFAULT = False
IGNORE_EOS_DEFAULT = False
TOP_P_DEFAULT = 0.9
RETURN_FULL_TEXT_DEFAULT = False
DO_SAMPLE_DEFAULT = True

# Processing method key names
TOP_K_NAME = "TopK"
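For reference, a small stand-alone sketch of how kwarg names and defaults like these are typically consumed (an illustrative helper only; the real resolution happens in `make_request` further below):

```python
# Hypothetical helper mirroring the kwargs.pop(KWARG, DEFAULT) pattern.
TOP_P_KWARG = "top_p"
DO_SAMPLE_KWARG = "do_sample"
TOP_P_DEFAULT = 0.9
DO_SAMPLE_DEFAULT = True

def resolve_generate_kwargs(**kwargs):
    top_p = kwargs.pop(TOP_P_KWARG, TOP_P_DEFAULT)
    do_sample = kwargs.pop(DO_SAMPLE_KWARG, DO_SAMPLE_DEFAULT)
    # Anything left over is an unknown kwarg, as asserted in make_request().
    assert kwargs == {}, f"Unknown keyword arguments {kwargs}"
    return top_p, do_sample

print(resolve_generate_kwargs(top_p=0.5))   # (0.5, True)
```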
6 changes: 3 additions & 3 deletions mii/batching/generation/samplers.py
@@ -51,7 +51,7 @@ def __call__(
class GreedySampler(BaseGenerationSampler):
def __call__(self, logits: torch.Tensor) -> Tuple[torch.LongTensor, torch.Tensor]:
logits = logits.float()
- sampler = Categorical(logits=logits)
+ #sampler = Categorical(logits=logits)
next_tokens = logits.argmax(dim=-1)
- logprobs = sampler.log_prob(next_tokens)
- return next_tokens, logprobs
+ #logprobs = sampler.log_prob(next_tokens)
+ return next_tokens #, logprobs
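An illustrative sketch (not the MII classes themselves) of why the `Categorical` distribution is no longer needed here: greedy decoding just takes the argmax, whereas sampling draws from the softmax distribution and keeps log-probabilities.

```python
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 32000)   # (batch, vocab); shapes are made up

# Greedy: deterministic, highest-logit token per sequence -- no distribution object.
greedy_tokens = logits.argmax(dim=-1)

# Sampling: draw from softmax(logits) and keep the log-probabilities.
dist = Categorical(logits=logits.float())
sampled_tokens = dist.sample()
logprobs = dist.log_prob(sampled_tokens)
```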
2 changes: 1 addition & 1 deletion mii/batching/generation/stop_criterion.py
@@ -30,7 +30,7 @@ class TokenStopCriterion(BaseGenerationStopCriterion):
def __init__(self, token: Union[str, int], tokenizer) -> None:
super().__init__(tokenizer=tokenizer)
if isinstance(token, str):
- token_id = self.tokenizer.tokenize(token)[0]
+ token_id = self.tokenizer.encode(token)[0]
else:
token_id = token
self.stop_token_id = token_id
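This change matters because `tokenize()` returns subword strings while `encode()` returns integer ids, and the stop criterion compares against generated token ids. A minimal sketch, assuming a Hugging Face-style tokenizer (model name illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

print(tokenizer.tokenize("."))            # e.g. ['.']  -- strings, not comparable to ids
print(tokenizer.encode("."))              # e.g. [13]   -- integer token ids

stop_token_id = tokenizer.encode(".")[0]  # an int that can match generated token ids
```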
57 changes: 46 additions & 11 deletions mii/batching/ragged_batching.py
@@ -22,22 +22,29 @@

from mii.batching.constants import (MAX_LENGTH_KWARG,
MAX_NEW_TOKENS_KWARG,
MIN_NEW_TOKENS_KWARG,
STREAM_KWARG,
IGNORE_EOS_KWARG,
TOP_P_KWARG,
TOP_K_KWARG,
TEMPERATURE_KWARG,
RETURN_FULL_TEXT_KWARG,
DO_SAMPLE_KWARG,
STOP_KWARG,
MIN_NEW_TOKENS_DEFAULT,
STREAM_DEFAULT,
IGNORE_EOS_DEFAULT,
TOP_P_DEFAULT,
RETURN_FULL_TEXT_DEFAULT,
DO_SAMPLE_DEFAULT,
TOP_K_NAME,
TOP_P_NAME,
TEMP_NAME,
SAMPLER_NAME,
STOP_NAME)
from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor
- from mii.batching.generation.samplers import LogitsSampler
- from mii.batching.generation.stop_criterion import EosGenerationStopCriterion
+ from mii.batching.generation.samplers import LogitsSampler, GreedySampler
+ from mii.batching.generation.stop_criterion import EosGenerationStopCriterion, TokenStopCriterion
from mii.batching.postprocess import (
run_batch_logit_processing,
run_batch_sampler,
@@ -138,10 +145,12 @@ class RaggedRequest:
seq_length: int
max_length: int
max_new_tokens: int
min_new_tokens: int
last_in_prompt: bool
post_processing: List[object]
stream: bool = False
ignore_eos: bool = False
return_full_text: bool = False

_next_token: Union[None, torch.Tensor] = None
_is_done: bool = False
@@ -164,6 +173,8 @@ def next_token(self, next_token: Union[None, torch.Tensor]) -> None:
def is_done(self) -> bool:
if self.ignore_eos:
return False
if self.seq_length < self.min_new_tokens:
return False
return self._is_done

@is_done.setter
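For reference, a toy stand-alone version of this stopping rule (a hypothetical helper, not part of MII) showing how `ignore_eos` and `min_new_tokens` override an EOS-triggered stop:

```python
def should_stop(eos_seen: bool, seq_length: int, ignore_eos: bool, min_new_tokens: int) -> bool:
    if ignore_eos:
        return False                 # never stop on EOS
    if seq_length < min_new_tokens:
        return False                 # keep generating until the minimum length is reached
    return eos_seen

assert should_stop(eos_seen=True, seq_length=4, ignore_eos=False, min_new_tokens=16) is False
assert should_stop(eos_seen=True, seq_length=20, ignore_eos=False, min_new_tokens=16) is True
```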
@@ -477,6 +488,9 @@ def _generate_output(self, r: RaggedRequest) -> bool:
else:
output_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens],
dim=0)
if r.return_full_text:
# Avoid returning bos token, refactor this later
output_tokens = torch.cat((r.prompt_tokens[1:], output_tokens))
outputs.append((
r.uid,
output_tokens,
@@ -581,6 +595,7 @@ def _queue_flush_request(self, uid: int) -> None:
seq_length=None,
max_length=None,
max_new_tokens=None,
min_new_tokens=None,
last_in_prompt=None,
post_processing=None,
stream=None,
@@ -612,8 +627,10 @@ def make_request(self,
max_length = kwargs.pop(MAX_LENGTH_KWARG, self.max_length)
assert max_length > prompt_length, f"prompt length must be less than {MAX_LENGTH_KWARG}"
max_new_tokens = kwargs.pop(MAX_NEW_TOKENS_KWARG, max_length - prompt_length)
min_new_tokens = kwargs.pop(MIN_NEW_TOKENS_KWARG, MIN_NEW_TOKENS_DEFAULT)
stream = kwargs.pop(STREAM_KWARG, STREAM_DEFAULT)
ignore_eos = kwargs.pop(IGNORE_EOS_KWARG, IGNORE_EOS_DEFAULT)
return_full_text = kwargs.pop(RETURN_FULL_TEXT_KWARG, RETURN_FULL_TEXT_DEFAULT)

post_processing = []

@@ -638,14 +655,30 @@ def make_request(self,
temperature=temp)
post_processing.append(temp_name)

- if SAMPLER_NAME not in self._post_processors:
- self._post_processors[SAMPLER_NAME] = LogitsSampler()
- post_processing.append(SAMPLER_NAME)
-
- if STOP_NAME not in self._post_processors:
- self._post_processors[STOP_NAME] = EosGenerationStopCriterion(
- tokenizer=self.tokenizer)
- post_processing.append(STOP_NAME)
+ do_sample = kwargs.pop(DO_SAMPLE_KWARG, DO_SAMPLE_DEFAULT)
+ if do_sample:
+ sampler_name = "_".join((SAMPLER_NAME, "logits"))
+ if sampler_name not in self._post_processors:
+ self._post_processors[sampler_name] = LogitsSampler()
+ else:
+ sampler_name = "_".join((SAMPLER_NAME, "greedy"))
+ if sampler_name not in self._post_processors:
+ self._post_processors[sampler_name] = GreedySampler()
+ post_processing.append(sampler_name)
+
+ stop = kwargs.pop(STOP_KWARG, None)
+ if stop is not None:
+ stop_name = "_".join((STOP_NAME, stop))
+ if stop_name not in self._post_processors:
+ self._post_processors[stop_name] = TokenStopCriterion(
+ token=stop,
+ tokenizer=self.tokenizer)
+ else:
+ stop_name = STOP_NAME
+ if STOP_NAME not in self._post_processors:
+ self._post_processors[stop_name] = EosGenerationStopCriterion(
+ tokenizer=self.tokenizer)
+ post_processing.append(stop_name)

assert kwargs == {}, f"Unknown keyword arguments {kwargs}"

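An illustrative-only sketch of the registration pattern used above (names and stand-in factories are hypothetical): post-processors are cached in a shared dict keyed by a composite name, and each request records only the list of keys it needs, so identical processors are built once and reused across requests.

```python
post_processors = {}

def register(name, factory, request_pipeline):
    # Build the processor the first time its key is seen, then reuse it.
    if name not in post_processors:
        post_processors[name] = factory()
    request_pipeline.append(name)

request_a = []
register("Sampler_greedy", dict, request_a)   # dict() stands in for GreedySampler()
register("Stop_.", dict, request_a)           # stands in for TokenStopCriterion(".")

request_b = []
register("Sampler_greedy", dict, request_b)   # cached instance is reused, not rebuilt

assert len(post_processors) == 2              # three registrations, two unique processors
```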
Expand All @@ -657,10 +690,12 @@ def make_request(self,
seq_length=0,
max_length=max_length,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
last_in_prompt=True,
post_processing=post_processing,
stream=stream,
ignore_eos=ignore_eos,
return_full_text=return_full_text,
)

def make_response(self,
@@ -821,7 +856,7 @@ def get_response(self) -> Tuple[int, Response]:
# this requires some refactoring how we do the put and request in
# `ModelResponse`
if not self.is_rank_0:
- return Response(generated_text="",
+ return -1, Response(generated_text="",
prompt_length=None,
generated_length=None,
finish_reason=None)
3 changes: 3 additions & 0 deletions mii/grpc_related/modelresponse_server.py
@@ -78,6 +78,9 @@ def _run_inference(self, method_name, request_proto):
# so new requests can be processed
while uids_running:
uid, response = self.inference_pipeline.get_response()
# TODO: Ugly hack for multi-threading. Will be fixed when we refactor these methods
if uid == -1:
uid = uids_running[0]
responses.append(response)
self.inference_pipeline.flush_uid(uid)
uids_complete_order.append(uids_running.index(uid))
20 changes: 20 additions & 0 deletions tests/test_deployment.py
@@ -72,10 +72,30 @@ def test_multi_replica(deployment, query):


def test_query_kwargs(deployment, query):
# test ignore_eos
output = deployment(query,
max_length=128,
min_new_tokens=16,
ignore_eos=True,
top_p=0.9,
top_k=50,
temperature=0.9)
assert output, "output is empty"


def test_do_sample(deployment, query):
output_0 = deployment(query, do_sample=False, max_length=128)
output_1 = deployment(query, do_sample=False, max_length=128)
assert output_0.response == output_1.response, "do_sample=False should always return the same output"


def test_stop_token(deployment, query):
pytest.skip("not working yet")
output = deployment(query, stop=".", max_length=512)
print(str(output.response))
assert str(output.response[0]).endswith("."), "output should end with '.'"


def test_return_full_text(deployment, query):
output = deployment(query, max_length=128, return_full_text=True)
assert output.response[0].startswith(query), "output should start with the prompt"
