feat(tgi): warmup runs prefill/decode on all supported combinations
This will prevent XLA compilation at inference time. Note that I had to
disable dynamo compilation, though; otherwise the model was not
generating correct results. This leads to slower generation, but at
least generation now seems stable.
tengomucho committed Jul 4, 2024
1 parent feba7a4 commit aba86d9
Showing 1 changed file with 63 additions and 4 deletions.
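
For context, each new tensor shape seen by an XLA device triggers a fresh graph compilation the first time it runs; warming up every supported (batch size, sequence length) combination moves that cost out of the inference path. A minimal sketch of the effect, not from this repository and assuming a torch_xla environment:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
layer = torch.nn.Linear(128, 128).to(device)

for seq_len in (16, 32, 64):  # hypothetical sequence-length buckets
    x = torch.zeros(1, seq_len, 128, device=device)
    layer(x)        # the first call for each new shape compiles a new XLA graph
    xm.mark_step()  # materialize and cache that graph now, during warmup

Running the loop once means later calls with these shapes hit the compilation cache instead of compiling on the fly.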
@@ -28,7 +28,9 @@
GeneratedText,
Generation,
InfoResponse,
NextTokenChooserParameters,
Request,
StoppingCriteriaParameters,
Tokens,
)

@@ -325,7 +327,7 @@ def __init__(
)
self._supports_static_cache = False
# compile model when possible to accelerate decoding
if model.device.type == "xla" and ("DBG_COMPILE" in os.environ or self._supports_static_cache):
if model.device.type == "xla" and ("DBG_COMPILE" in os.environ):
self.model_one_token = torch.compile(model, backend="openxla")
logger.debug("Model compiled for decoding")
else:
@@ -341,6 +343,34 @@ def info(self) -> InfoResponse:
device_type="xla",
)

def _create_dummy_request(self, max_tokens: int) -> Request:
"""Create a dummy request for warmup."""
# Generate a random input with slightly more tokens than requested, because special tokens are going to be
# skipped.
MARGIN = 10
input_tokens = torch.randint(self.model.config.vocab_size, (1, max_tokens + MARGIN), dtype=torch.int64)
text = self.tokenizer.decode(input_tokens[0], skip_special_tokens=True)
# These are just dummy params to allow Request creation
parameters = NextTokenChooserParameters(
temperature=1.0,
top_k=None,
top_p=None,
do_sample=False,
seed=None,
repetition_penalty=1.0,
typical_p=1.0,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=20, ignore_eos_token=True)
dummy_request = Request(
id=0,
inputs=text,
truncate=max_tokens,
parameters=parameters,
stopping_parameters=stopping_parameters,
)
return dummy_request


def warmup(self, batch: Batch) -> int:
"""Verify if the hardware can support the target load.
@@ -352,6 +382,7 @@ def warmup(self, batch: Batch) -> int:
The maximum number of tokens the model supports.
"""
logger.debug("Warming up the model")
start = time.time()
# Just check that the warmup request parameters match the model capacity
# NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
if self.model.config.batch_size is not None:
@@ -363,9 +394,37 @@
raise ValueError(
f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
)
self.prefill(batch)
self.clear()
return batch_size * self.model.config.sequence_length

# Counter-intuitively, now we ignore the input batch. Instead, we create dummy batches to cover all possible
# batch sizes and sequence lengths.
seq_len = self.model.config.sequence_length
bucket_seq_len = take_nearest_length(seq_len)
requests = []
for _ in range(batch_size):
requests.append(self._create_dummy_request(seq_len))
# Prefill with different truncate sizes to test all prefill lengths
for l in PREFILL_LENGTHS:
if l > bucket_seq_len:
break
# Set all truncate values for all requests
for r in requests:
r.truncate = l
r.stopping_parameters.max_new_tokens = 10
warmup_batch = Batch(id=0,
requests=requests,
size=len(requests),
max_tokens=batch.max_tokens)
logger.debug(f"Warmup for {len(requests)} requests, truncate value {l} seq_len {seq_len}")
_generations, next_batch = self.prefill(warmup_batch)
if next_batch is not None:
self.decode([next_batch])
else:
logger.debug(f"No decode on warmup for {len(requests)}x{l}")
self.clear()

elapsed = time.time() - start
logger.debug(f"Warmup done, took {elapsed:.2f}s")
return batch_size * seq_len

@torch.no_grad
def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
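The new warmup loop depends on PREFILL_LENGTHS and take_nearest_length, which are defined elsewhere in the modified file and are not shown in this diff. Purely as an assumption to make the loop easier to follow, one plausible shape for these helpers is a sorted list of prefill buckets plus a bisect lookup that snaps a length to the nearest bucket:

import bisect

# Hypothetical bucket values; the real list is not part of this diff.
PREFILL_LENGTHS = [32, 64, 128, 256, 512, 1024, 2048, 4096]

def take_nearest_length(length: int) -> int:
    """Return the smallest bucket that fits `length`, or the largest bucket otherwise."""
    pos = bisect.bisect_left(PREFILL_LENGTHS, length)
    if pos == len(PREFILL_LENGTHS):
        return PREFILL_LENGTHS[-1]
    return PREFILL_LENGTHS[pos]

With buckets of this form, the warmup loop stops as soon as a bucket exceeds the bucketed sequence length, so only shapes the server can actually produce get compiled.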