diff --git a/text-generation-inference/integration-tests/test_gpt2.py b/text-generation-inference/integration-tests/test_gpt2.py
index 3219ddd5..d200bd5d 100644
--- a/text-generation-inference/integration-tests/test_gpt2.py
+++ b/text-generation-inference/integration-tests/test_gpt2.py
@@ -5,13 +5,11 @@
 
 
 MODEL_ID = "openai-community/gpt2"
-BATCH_SIZE = 4
 SEQUENCE_LENGTH = 1024
 
 
 @pytest.fixture(scope="module")
 def model_name_or_path():
-    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
     os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
     yield MODEL_ID
 
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 406764ee..6bb327ad 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -312,7 +312,10 @@ def __init__(
         tokenizer.padding_side = "left"
         self.tokenizer = tokenizer
         self.special_tokens = self.tokenizer.all_special_ids
-        self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
+        # Slots are empty to begin with, they will be populated as new batches arrive
+        self.slots = []
+        # Note: this index will _never_ be decremented, and that's fine.
+        self.slot_index = 0
         self.past_key_values = None
         # _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
         # a static cache, otherwise it is not.
@@ -350,13 +353,18 @@ def warmup(self, batch: Batch) -> int:
             The maximum number of tokens the model supports.
         """
         # Just check that the warmup request parameters match the model capacity
-        batch_size = self.model.config.batch_size
+        # NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
+        if self.model.config.batch_size is not None:
+            batch_size = self.model.config.batch_size
+        else:
+            # batch size is not set, just assume it's unlimited and accept all requests
+            batch_size = len(batch.requests)
         if len(batch.requests) > batch_size:
             raise ValueError(
                 f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
             )
         self.prefill(batch)
-        return self.model.config.batch_size * self.model.config.sequence_length
+        return batch_size * self.model.config.sequence_length
 
     @torch.no_grad
     def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
@@ -373,17 +381,18 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         for slot in self.slots:
             slots[slot.state].append(slot)
         active_slots = slots[Slot.State.READY]
+        # Delete all empty slots, no need to have them anymore
         empty_slots = slots[Slot.State.EMPTY]
-        if len(empty_slots) < len(batch.requests):
-            raise ValueError(
-                f"Cannot prefill {len(batch.requests)} new request(s) with only {len(empty_slots)} empty slots."
-                f"Please align the number of concurrent requests with the static batch size: {self.model.batch_size}."
-            )
+        for slot in empty_slots:
+            self.slots.remove(slot)
         # Assign each request to an empty slot
-        logger.debug(f"Prefilling {len(batch.requests)} new request(s) with {len(empty_slots)} empty slot(s)")
+        logger.debug(f"Prefilling {len(batch.requests)} new request(s) adding to {len(active_slots)} active slot(s)")
         for request in batch.requests:
-            slot = empty_slots.pop()
+            # Dynamically create a new slot for each request
+            slot = Slot(self.slot_index, self.tokenizer, self.model.device)
+            self.slot_index += 1
             slot.assign(request, self.model.generation_config)
+            self.slots.append(slot)
             logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
         # Reconstruct the full inputs (without padding) as seen by the model.
         # This comprises:
@@ -466,8 +475,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         # Reconstruct input_ids and attention_mask from slots
         input_ids = None
         attention_mask = None
+        batch_size = len(self.slots)
         position_ids = torch.zeros(
-            [self.model.config.batch_size, 1],
+            [batch_size, 1],
             dtype=torch.int64,
             device=self.model.device,
         )
@@ -476,7 +486,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
             if input_ids is None:
                 # Create blank inputs covering all slots (even empty ones)
                 input_ids = torch.full(
-                    [self.model.config.batch_size, 1],
+                    [batch_size, 1],
                     fill_value=self.tokenizer.eos_token_id,
                     dtype=torch.int64,
                     device=self.model.device,
@@ -488,7 +498,7 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
             if attention_mask is None:
                 # Create default mask covering all slots (even empty ones)
                 attention_mask = torch.zeros(
-                    [self.model.config.batch_size, slot.attention_mask.size(-1)],
+                    [batch_size, slot.attention_mask.size(-1)],
                     dtype=torch.int64,
                     device=self.model.device,
                 )
@@ -550,7 +560,8 @@ def _post_generate(self, outputs: Dict, next_batch_id: int, input_ids: torch.Lon
                     text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
                 )
                 logger.debug(f"Finished generating tokens for request {request_id}")
-                # mark the slot as available
+                # This slot is now empty, it will be removed from the list of
+                # active slots once a new prefill is requested
                 slot.clear()
             else:
                 active_slots = True
diff --git a/text-generation-inference/tests/test_gemma.py b/text-generation-inference/tests/test_gemma.py
index 9139edd1..69a85a08 100644
--- a/text-generation-inference/tests/test_gemma.py
+++ b/text-generation-inference/tests/test_gemma.py
@@ -12,14 +12,12 @@
 
 
 MODEL_ID = "google/gemma-2b"
-BATCH_SIZE = 4
 SEQUENCE_LENGTH = 1024
 
 
 @pytest.fixture(scope="module")
 def model_path():
     # Add variables to environment so they can be used in TpuModelForCausalLM
-    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
     os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
     path = fetch_model(MODEL_ID)
     return path
diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py
index 1ab77a3e..2638f620 100644
--- a/text-generation-inference/tests/test_gpt2.py
+++ b/text-generation-inference/tests/test_gpt2.py
@@ -12,14 +12,12 @@
 
 
 MODEL_ID = "openai-community/gpt2"
-BATCH_SIZE = 4
 SEQUENCE_LENGTH = 1024
 
 
 @pytest.fixture(scope="module")
 def model_path():
     # Add variables to environment so they can be used in TpuModelForCausalLM
-    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
     os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
     path = fetch_model(MODEL_ID)
     return path
@@ -84,7 +82,6 @@ def create_request(
 @pytest.mark.parametrize("batch_size", [1, 4], ids=["single", "multiple"])
 def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_path):
     generator = TpuGenerator.from_pretrained(model_path)
-    assert generator.model.config.batch_size >= batch_size
     requests = []
     max_new_tokens = 20
     for i in range(batch_size):
@@ -110,13 +107,13 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
         [
             "It was a bright cold day in April, and the clocks were striking thirteen.",
             20,
-            " The sun was setting, and the wind was blowing. The sun was shining, and the wind was",
+            " The sun was setting, and the wind was blowing. The sun was setting, and the wind was",
             False,
         ],
         [
             "It was a bright cold day in April, and the clocks were striking thirteen.",
             20,
-            " We sat outside the house, drinking coffee, listening to the orchestra playing through the window. We could",
+            " We sat at the front door and watched the clock on a box of yellow paper and found it almost",
             True,
         ],
     ],
@@ -144,7 +141,6 @@ def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, mo
 
 def test_decode_multiple(model_path):
     generator = TpuGenerator.from_pretrained(model_path)
-    assert generator.model.config.batch_size > 1
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
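For reference, below is a minimal, self-contained sketch of the slot lifecycle this patch moves to: slots are created on demand at prefill time with a monotonically increasing index, and slots cleared after generation finish are pruned at the next prefill instead of being reused from a fixed-size pool. The `MiniSlot` and `MiniGenerator` names are hypothetical simplifications for illustration only; they are not part of the repository and omit the tokenizer, device, and model handling of the real `Slot` and `TpuGenerator`.

```python
from enum import Enum


class State(Enum):
    EMPTY = 0
    READY = 1


class MiniSlot:
    """Illustrative stand-in for the real Slot class (hypothetical, simplified)."""

    def __init__(self, index: int):
        self.id = index
        self.state = State.EMPTY

    def assign(self, request):
        self.request = request
        self.state = State.READY

    def clear(self):
        # A finished slot becomes EMPTY; it is removed at the next prefill.
        self.state = State.EMPTY


class MiniGenerator:
    """Illustrative stand-in showing dynamic slot allocation (no model involved)."""

    def __init__(self):
        self.slots = []      # starts empty, populated as batches arrive
        self.slot_index = 0  # never decremented, so slot ids stay unique

    def prefill(self, requests):
        # Drop slots that were cleared after their generation finished.
        self.slots = [s for s in self.slots if s.state != State.EMPTY]
        # Create one new slot per incoming request.
        for request in requests:
            slot = MiniSlot(self.slot_index)
            self.slot_index += 1
            slot.assign(request)
            self.slots.append(slot)
        # The effective batch size is now simply the number of live slots.
        return len(self.slots)


if __name__ == "__main__":
    gen = MiniGenerator()
    print(gen.prefill(["req-0", "req-1"]))  # 2 active slots
    gen.slots[0].clear()                    # first request finishes
    print(gen.prefill(["req-2"]))           # 2 again: finished slot pruned, new slot gets id 2
```

This mirrors why the diff drops `HF_BATCH_SIZE` and the static batch-size assertions in the tests: the batch size used in `decode` is derived from `len(self.slots)` at runtime rather than from a fixed configuration value.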