diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py
index 448e3719..9f1449ef 100644
--- a/optimum/tpu/version.py
+++ b/optimum/tpu/version.py
@@ -15,5 +15,5 @@
 from pkg_resources import parse_version
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"
 
 VERSION = parse_version(__version__)
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 4254e58f..58756d96 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -40,7 +40,7 @@
 optimum_logger.setLevel("CRITICAL")
 
 # These will do some bucketing on prefill lengths to avoid too many different sizes
-PREFILL_LENGTHS = [
+PREFILL_LENGTHS = list(range(6, 16)) + [
     16,
     32,
     64,
@@ -446,6 +446,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         active_slots = slots[Slot.State.READY]
         # Delete all empty slots, no need to have them anymore
         empty_slots = slots[Slot.State.EMPTY]
+        model_batch_size = self.model.config.batch_size
+        if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
+            raise ValueError(
+                f"Cannot prefill {len(batch.requests)} new request(s)."
+                f" Maximum batch size supported is: {model_batch_size}."
+            )
         for slot in empty_slots:
             self.slots.remove(slot)
         # Assign each request to an empty slot
@@ -836,7 +842,8 @@ def return_to_caller(*data):
             cached_batch = generator.filter(batch_id, request_ids)
             return_to_caller(cached_batch.SerializeToString())
         if command == GeneratorCommand.CLEAR:
-            generator.clear()
+            batch_id = data[0]
+            generator.clear(batch_id)
         if command == GeneratorCommand.DELETE:
             if rank == 0:
                 # Set agent to ready
@@ -902,8 +909,8 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
         s_cached_batch = self.mailbox.send(GeneratorCommand.FILTER, batch_id, request_ids)[0]
         return CachedBatch.FromString(s_cached_batch)
 
-    def clear(self):
-        self.mailbox.send(GeneratorCommand.CLEAR)
+    def clear(self, batch_id: Optional[int] = None):
+        self.mailbox.send(GeneratorCommand.CLEAR, batch_id)
 
     def leave(self):
         if self.mailbox is None:
diff --git a/text-generation-inference/server/text_generation_server/generator_base.py b/text-generation-inference/server/text_generation_server/generator_base.py
index a3176d7c..647c2793 100644
--- a/text-generation-inference/server/text_generation_server/generator_base.py
+++ b/text-generation-inference/server/text_generation_server/generator_base.py
@@ -56,7 +56,7 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
         """Remove requests that are not listed from the specified batch"""
         raise NotImplementedError
 
-    def clear(self):
+    def clear(self, batch_id: Optional[int] = None):
         """Remove all requests from the generator"""
         raise NotImplementedError
 
diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py
index 0e766d0d..49fb7124 100644
--- a/text-generation-inference/server/text_generation_server/version.py
+++ b/text-generation-inference/server/text_generation_server/version.py
@@ -1,5 +1,5 @@
 from pkg_resources import parse_version
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"
 
 VERSION = parse_version(__version__)
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index edb67a3f..982c1da6 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -44,7 +44,7 @@ def test_decode_single(params):
     DecodeTestParams(
         model_id="google/gemma-7b",
         sequence_length=128,
-        expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
+        expected_text="\n\nThe first line of George Orwell’s 1984 is a perfect example",
     ),
     DecodeTestParams(
         model_id="mistralai/Mistral-7B-v0.3",
diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py
index a28d5219..26d61a0b 100644
--- a/text-generation-inference/tests/test_gpt2.py
+++ b/text-generation-inference/tests/test_gpt2.py
@@ -65,7 +65,10 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
 
 
 def test_decode_multiple(model_path):
-    generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
+    generator = TpuGenerator.from_pretrained(model_path,
+                                             revision="",
+                                             max_batch_size=2,
+                                             max_sequence_length=SEQUENCE_LENGTH)
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
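Note (reviewer illustration, not part of the patch): the guard added to TpuGenerator.prefill() above refuses to prefill when the already-active slots plus the incoming requests would exceed the model's compiled batch size. The standalone sketch below reproduces that check in isolation; the check_prefill_capacity helper name and the example numbers are invented for illustration, only the condition and the error message mirror the diff.

    from typing import Optional

    def check_prefill_capacity(model_batch_size: Optional[int],
                               active_slots: int,
                               new_requests: int) -> None:
        # Hypothetical helper: same arithmetic as the guard added in prefill().
        # A None batch size means the model imposes no fixed batch limit.
        if model_batch_size is not None and model_batch_size < active_slots + new_requests:
            raise ValueError(
                f"Cannot prefill {new_requests} new request(s)."
                f" Maximum batch size supported is: {model_batch_size}."
            )

    check_prefill_capacity(model_batch_size=2, active_slots=1, new_requests=1)   # fits, no error
    # check_prefill_capacity(model_batch_size=2, active_slots=1, new_requests=2) # would raise ValueError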
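Note (reviewer illustration, not part of the patch): clear() on the generator classes now takes an optional batch_id, and the mailbox CLEAR command forwards it, presumably so a single cached batch can be dropped instead of the whole cache. The toy class below only illustrates that calling convention under that assumption; it is not the repository's TpuGenerator.

    from typing import Dict, List, Optional

    class ToyGenerator:
        def __init__(self) -> None:
            # batch_id -> request ids cached for that batch (toy bookkeeping only).
            self.batches: Dict[int, List[int]] = {}

        def clear(self, batch_id: Optional[int] = None) -> None:
            # No argument: keep the old "drop everything" behaviour.
            # With a batch_id: drop only that batch.
            if batch_id is None:
                self.batches.clear()
            else:
                self.batches.pop(batch_id, None)

    gen = ToyGenerator()
    gen.batches = {0: [1, 2], 1: [3]}
    gen.clear(0)   # only batch 0 is removed
    assert gen.batches == {1: [3]}
    gen.clear()    # remaining batches are removed
    assert gen.batches == {}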