From 905695926e1fe0eca81ba61d66efc6784206866a Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 15 Mar 2024 10:02:09 +0000
Subject: [PATCH 1/5] feat(test): add tqdm to get feedback when running locally

---
 text-generation-inference/tests/test_generator.py       | 9 +++++----
 text-generation-inference/tests/test_generator_gemma.py | 3 ++-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/text-generation-inference/tests/test_generator.py b/text-generation-inference/tests/test_generator.py
index 6baa8547..af001e50 100644
--- a/text-generation-inference/tests/test_generator.py
+++ b/text-generation-inference/tests/test_generator.py
@@ -1,5 +1,6 @@
 import pytest
 import os
+from tqdm import tqdm
 from text_generation_server.generator import TpuGenerator
 from text_generation_server.model import fetch_model
 from text_generation_server.pb.generate_pb2 import (
@@ -121,7 +122,7 @@ def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, mo
     batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times
-    for i in range(max_new_tokens - 1):
+    for _ in tqdm(range(max_new_tokens - 1), "Decoding tokens"):
         assert next_batch.size == 1
         assert next_batch.max_tokens == 1024
         assert len(generations) == 1
@@ -152,7 +153,7 @@ def test_decode_multiple(model_path):
     assert len(tokens[0]) == 1
     # Decode a few tokens
     gen_tokens = 4
-    for _ in range(gen_tokens - 1):
+    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
         generations, next_batch = generator.decode([next_batch])
         assert len(generations) == 1
         g = generations[0]
@@ -172,7 +173,7 @@ def test_decode_multiple(model_path):
     assert len(tokens[1]) == 1
     # Decode more tokens until we reach the maximum for the first request
     batches = [next_batch, next_batch_1]
-    for _ in range(max_new_tokens - gen_tokens):
+    for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
         generations, next_batch = generator.decode(batches)
         for g in generations:
             tokens[g.request_id].append(g.tokens.ids[0])
@@ -189,7 +190,7 @@ def test_decode_multiple(model_path):
     assert output.generated_tokens == max_new_tokens
     generated_text = output.text
     # Continue decoding until the end of the second request
-    for _ in range(gen_tokens - 1):
+    for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
         generations, next_batch = generator.decode([next_batch])
         assert len(generations) == 1
         g = generations[0]
diff --git a/text-generation-inference/tests/test_generator_gemma.py b/text-generation-inference/tests/test_generator_gemma.py
index aa7fd4bb..1269f41d 100644
--- a/text-generation-inference/tests/test_generator_gemma.py
+++ b/text-generation-inference/tests/test_generator_gemma.py
@@ -1,5 +1,6 @@
 import pytest
 import os
+from tqdm import tqdm
 from text_generation_server.generator import TpuGenerator
 from text_generation_server.model import fetch_model
 from text_generation_server.pb.generate_pb2 import (
@@ -57,7 +58,7 @@ def test_decode_single(model_path):
     batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
     generations, next_batch = generator.prefill(batch)
     # We already generated one token: call decode max_new_tokens - 1 times
-    for _ in range(max_new_tokens - 1):
+    for _ in tqdm(range(max_new_tokens - 1)):
         assert next_batch.size == 1
         assert next_batch.max_tokens == 1024
         assert len(generations) == 1
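The tqdm wrappers above pass the progress-bar label as the second
positional argument, which tqdm interprets as the bar description (desc).
A minimal sketch of the same pattern outside the test suite; decode_once()
is a hypothetical stand-in for the generator.decode() calls:

    from tqdm import tqdm

    max_new_tokens = 20

    def decode_once(step: int) -> None:
        # Hypothetical placeholder for one generator.decode(...) call from the tests above.
        pass

    # The second positional argument of tqdm() is the bar description (desc).
    for step in tqdm(range(max_new_tokens - 1), "Decoding tokens"):
        decode_once(step)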
From 1ec9adbab3d8d25322aba151a585919168b788b0 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 15 Mar 2024 10:34:36 +0000
Subject: [PATCH 2/5] fix(test): remove generation config warnings

---
 text-generation-inference/tests/test_generator.py       | 6 ++++++
 text-generation-inference/tests/test_generator_gemma.py | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/text-generation-inference/tests/test_generator.py b/text-generation-inference/tests/test_generator.py
index af001e50..ea729beb 100644
--- a/text-generation-inference/tests/test_generator.py
+++ b/text-generation-inference/tests/test_generator.py
@@ -45,6 +45,11 @@ def create_request(
     seed: int = 0,
     repetition_penalty: float = 1.0,
 ):
+    # For these tests we can safely set typical_p to 1.0 (default)
+    typical_p = 1.0
+    if do_sample == False:
+        # Drop top_p parameter to avoid warnings
+        top_p = 1.0
     parameters = NextTokenChooserParameters(
         temperature=temperature,
         top_k=top_k,
@@ -52,6 +57,7 @@ def create_request(
         do_sample=do_sample,
         seed=seed,
         repetition_penalty=repetition_penalty,
+        typical_p=typical_p,
     )
     stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
     return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
diff --git a/text-generation-inference/tests/test_generator_gemma.py b/text-generation-inference/tests/test_generator_gemma.py
index 1269f41d..92c2fd71 100644
--- a/text-generation-inference/tests/test_generator_gemma.py
+++ b/text-generation-inference/tests/test_generator_gemma.py
@@ -36,6 +36,11 @@ def create_request(
     seed: int = 0,
     repetition_penalty: float = 1.0,
 ):
+    # For these tests we can safely set typical_p to 1.0 (default)
+    typical_p = 1.0
+    if do_sample == False:
+        # Drop top_p parameter to avoid warnings
+        top_p = 1.0
     parameters = NextTokenChooserParameters(
         temperature=temperature,
         top_k=top_k,
@@ -43,6 +48,7 @@ def create_request(
         do_sample=do_sample,
         seed=seed,
         repetition_penalty=repetition_penalty,
+        typical_p=typical_p,
     )
     stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
     return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)

From c3549ab1794d91e134b3d46d2e5101237fa57a55 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 15 Mar 2024 14:42:25 +0000
Subject: [PATCH 3/5] feat: compilation can be enabled only for decoding

This only enables compilation for decoding. Note that there is not a big
speedup for now, probably because the slot buffer size increases over
time, which triggers recompilation.
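The diff below keeps the eager model for prefill and routes only the
single-token decode path through torch.compile, gated by the DBG_COMPILE
environment variable. A minimal sketch of that pattern, assuming torch_xla
is installed so the openxla backend is available; the DecodeOnlyCompiler
wrapper and its method names are illustrative, not the server's API:

    import os

    import torch


    class DecodeOnlyCompiler:
        """Sketch: compile only the single-token decode path of a model."""

        def __init__(self, model: torch.nn.Module):
            # Prefill keeps the eager model, as in the patch below.
            self.model = model
            if "DBG_COMPILE" in os.environ:
                # Assumes an XLA device and the openxla backend, as in the patch below.
                self.model_one_token = torch.compile(model, backend="openxla")
            else:
                self.model_one_token = model

        def prefill(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.model(input_ids)

        def decode(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Decode goes through the (optionally) compiled model.
            return self.model_one_token(input_ids)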
---
 .../server/text_generation_server/generator.py | 17 +++++++++++++----
 .../server/text_generation_server/modeling.py  |  6 +-----
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 2a282a82..72cf20ca 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -1,6 +1,7 @@
 import copy
 import logging
 import time
+import os
 from abc import ABC
 from enum import Enum
 from typing import List, Optional, Tuple
@@ -304,6 +305,12 @@ def __init__(
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.model = model
+        if model.device.type == "xla" and "DBG_COMPILE" in os.environ:
+            self.model_one_token = torch.compile(model, backend="openxla")
+            logger.debug("Model compiled for decoding")
+        else:
+            self.model_one_token = model
+
         # Specify padding options for decoder-only architecture
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "left"
@@ -426,7 +433,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # Reset/clear KV cache
         self.past_key_values = None
         generation, next_batch = self._generate_token(
-            batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
+            batch.id, input_ids, self.model, attention_mask=attention_mask, position_ids=position_ids, **extra_args
         )
 
         # Reactivate previously active slots for the next decode, and append
@@ -494,15 +501,17 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         else:
             extra_args["attention_mask"] = attention_mask
             extra_args["past_key_values"] = self.past_key_values
-        return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)
+        return self._generate_token(
+            next_batch_id, input_ids, self.model_one_token, position_ids=position_ids, **extra_args
+        )
 
     def _generate_token(
-        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
+        self, next_batch_id: int, input_ids: torch.LongTensor, model: torch.nn.Module, **forward_extra_params
     ) -> Tuple[List[Generation], CachedBatch]:
         # Add barrier to allow next graph step to always be the same
         xm.mark_step()
         # Forward
-        outputs = self.model(
+        outputs = model(
             input_ids,
             return_dict=True,
             use_cache=True,
diff --git a/text-generation-inference/server/text_generation_server/modeling.py b/text-generation-inference/server/text_generation_server/modeling.py
index dd420e95..bb59d1f8 100644
--- a/text-generation-inference/server/text_generation_server/modeling.py
+++ b/text-generation-inference/server/text_generation_server/modeling.py
@@ -62,11 +62,7 @@ def from_pretrained(
         model.config.batch_size = batch_size
     if sequence_length is not None or getattr(model.config, "sequence_length", None) is None:
         model.config.sequence_length = sequence_length
-
-    # Do eval, and compile
+    # Do eval
     model.eval()
-    if device == "xla" and "DBG_COMPILE" in environ:
-        model = torch.compile(model, backend="openxla_eval")
-        logger.debug("Model compiled.")
 
     return model

From 3a4d10b26d539029c85ae4248c4b15ba2dbc24ff Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 20 Mar 2024 13:47:44 +0000
Subject: [PATCH 4/5] feat: logits post-processing happens on CPU

Logits post-processing is not very heavyweight, and doing it on CPU
actually accelerates decoding, because compilation is not re-triggered.
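The diff below applies this by keeping the slot's token buffer on CPU:
detokenization only needs the token ids, and growing a CPU tensor never
changes a shape seen by the XLA graph. A minimal standalone sketch of the
same idea, assuming an arbitrary Hugging Face tokenizer; append_and_decode()
is an illustrative helper, not the server's code:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any tokenizer would do

    def append_and_decode(tokens_cpu: torch.Tensor, next_token: int) -> tuple[torch.Tensor, str]:
        # Grow the running token buffer on CPU: torch.cat here never touches the device graph.
        tokens_cpu = torch.cat([tokens_cpu, torch.tensor([next_token], dtype=tokens_cpu.dtype)])
        text = tokenizer.decode(tokens_cpu, skip_special_tokens=False)
        return tokens_cpu, text

    # Copy the device tensor to CPU once, then do all post-processing there.
    tokens = tokenizer("Hello", return_tensors="pt").input_ids[0].cpu()
    tokens, text = append_and_decode(tokens, next_token=tokenizer.eos_token_id)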
---
 .../text_generation_server/generator.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 72cf20ca..e35dba6f 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -4,7 +4,7 @@
 import os
 from abc import ABC
 from enum import Enum
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Dict
 
 import torch
 import torch_xla.core.xla_model as xm
@@ -182,7 +182,7 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
             selector: (`TokenSelector`):
                 An object implementing the updated token selection logic.
         """
-        self._tokens = input_ids.clone()
+        self._tokens = input_ids.cpu()
         self._next_text_token_start = 0
         self._next_text_token_end = torch.numel(self._tokens)
         self._next_text = ""
@@ -211,9 +211,11 @@ def _decode_next_tokens(
         self,
     ) -> str:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
+        # Copy the tokens to CPU to avoid recompilation on TPU. Post-processing is quite fast anyway.
+        tokens = self._tokens.cpu()
         # We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
         # which decide to add a space or not depending on the surrounding ids.
-        new_text = self._tokenizer.decode(self._tokens[self._next_text_token_start :], skip_special_tokens=False)
+        new_text = self._tokenizer.decode(tokens[self._next_text_token_start :], skip_special_tokens=False)
         if new_text.endswith("�"):
             # utf-8 char at the end means it's a potential unfinished byte sequence
             # from byte fallback tokenization.
@@ -221,7 +223,7 @@ def _decode_next_tokens(
 
         # Compare the generated text with the one using only the tokens producing the last one
         last_text = self._tokenizer.decode(
-            self._tokens[self._next_text_token_start : self._next_text_token_end],
+            tokens[self._next_text_token_start : self._next_text_token_end],
             skip_special_tokens=False,
         )
         if len(new_text) == len(last_text):
@@ -229,7 +231,7 @@ def _decode_next_tokens(
             return ""
         # Return the decoded text and store its token offsets
         self._next_text_token_start = self._next_text_token_end
-        self._next_text_token_end = torch.numel(self._tokens)
+        self._next_text_token_end = torch.numel(tokens)
         return new_text[len(last_text) :]
 
     def append(self, next_token: int) -> str:
@@ -249,7 +251,7 @@ def append(self, next_token: int) -> str:
             The corresponding decoded text (if any).
""" self._tokens = torch.cat( - [self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)] + [self._tokens, torch.tensor([next_token], dtype=self._tokens.dtype)] ) # Update mask only if it was set previously if self._mask is not None: @@ -521,8 +523,11 @@ def _generate_token( # Save KV cache self.past_key_values = outputs.past_key_values # Barrier for XLA model - xm.mark_step(wait=False) + xm.mark_step() + ret = self._post_generate(outputs, next_batch_id, input_ids) + return ret + def _post_generate(self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor) -> Tuple[List[Generation], CachedBatch]: generations = [] active_slots = False for i, slot in enumerate(self.slots): From 401eea67b3470c6cfc7ff1644af1975fac77dcd5 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Fri, 22 Mar 2024 09:47:25 +0100 Subject: [PATCH 5/5] fix: comparison to False should be `cond is False` --- text-generation-inference/tests/test_generator.py | 2 +- text-generation-inference/tests/test_generator_gemma.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/text-generation-inference/tests/test_generator.py b/text-generation-inference/tests/test_generator.py index ea729beb..1ab77a3e 100644 --- a/text-generation-inference/tests/test_generator.py +++ b/text-generation-inference/tests/test_generator.py @@ -47,7 +47,7 @@ def create_request( ): # For these tests we can safely set typical_p to 1.0 (default) typical_p = 1.0 - if do_sample == False: + if not do_sample: # Drop top_p parameter to avoid warnings top_p = 1.0 parameters = NextTokenChooserParameters( diff --git a/text-generation-inference/tests/test_generator_gemma.py b/text-generation-inference/tests/test_generator_gemma.py index 92c2fd71..9139edd1 100644 --- a/text-generation-inference/tests/test_generator_gemma.py +++ b/text-generation-inference/tests/test_generator_gemma.py @@ -38,7 +38,7 @@ def create_request( ): # For these tests we can safely set typical_p to 1.0 (default) typical_p = 1.0 - if do_sample == False: + if not do_sample: # Drop top_p parameter to avoid warnings top_p = 1.0 parameters = NextTokenChooserParameters(