From 7cce24ce3059e74e449c965797ee62e2a5436922 Mon Sep 17 00:00:00 2001
From: Alvaro Moran <6949769+tengomucho@users.noreply.github.com>
Date: Mon, 8 Jul 2024 10:30:02 +0200
Subject: [PATCH] Few more Inference Endpoints fixes (#69)

* fix(TGI): correct clear request with a given batch id

* ci(tgi): create images when pushing on current branch

* fix(generator): raise error if prefill receives too many requests

* feat(tgi): add more prefill lengths

Since bucketing does not work for now, we add more (small) prefill
lengths. This will increase the warmup time, but it will also help
speed up generation.

* Revert "ci(tgi): create images when pushing on current branch"

This reverts commit 26e119330f52e46a0b450c7757c77f73fd34cb9b.

* fix(test): multiple decode test requires max_batch_size to be > 1

* fix(test): expected result is different when model is compiled

Compiled model results are not always very good. While this should be
better investigated later on, the current solution is just to use the
non-compiled version. This results in some tests generating different
results, so expectations have been updated accordingly.

* chore: bump to version v0.1.2
---
 optimum/tpu/version.py                           |  2 +-
 .../server/text_generation_server/generator.py   | 15 +++++++++++----
 .../text_generation_server/generator_base.py     |  2 +-
 .../server/text_generation_server/version.py     |  2 +-
 text-generation-inference/tests/test_decode.py   |  2 +-
 text-generation-inference/tests/test_gpt2.py     |  5 ++++-
 6 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/optimum/tpu/version.py b/optimum/tpu/version.py
index 448e3719..9f1449ef 100644
--- a/optimum/tpu/version.py
+++ b/optimum/tpu/version.py
@@ -15,5 +15,5 @@
 from pkg_resources import parse_version
 
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"
 VERSION = parse_version(__version__)
diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 4254e58f..58756d96 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -40,7 +40,7 @@
 optimum_logger.setLevel("CRITICAL")
 
 # These will do some bucketing on prefill lengths to avoid too many different sizes
-PREFILL_LENGTHS = [
+PREFILL_LENGTHS = list(range(6, 16)) + [
     16,
     32,
     64,
@@ -446,6 +446,12 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         active_slots = slots[Slot.State.READY]
         # Delete all empty slots, no need to have them anymore
         empty_slots = slots[Slot.State.EMPTY]
+        model_batch_size = self.model.config.batch_size
+        if model_batch_size is not None and model_batch_size < len(active_slots) + len(batch.requests):
+            raise ValueError(
+                f"Cannot prefill {len(batch.requests)} new request(s)."
+                f" Maximum batch size supported is: {model_batch_size}."
+            )
         for slot in empty_slots:
             self.slots.remove(slot)
         # Assign each request to an empty slot
@@ -836,7 +842,8 @@ def return_to_caller(*data):
             cached_batch = generator.filter(batch_id, request_ids)
             return_to_caller(cached_batch.SerializeToString())
         if command == GeneratorCommand.CLEAR:
-            generator.clear()
+            batch_id = data[0]
+            generator.clear(batch_id)
         if command == GeneratorCommand.DELETE:
             if rank == 0:
                 # Set agent to ready
@@ -902,8 +909,8 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
         s_cached_batch = self.mailbox.send(GeneratorCommand.FILTER, batch_id, request_ids)[0]
         return CachedBatch.FromString(s_cached_batch)
 
-    def clear(self):
-        self.mailbox.send(GeneratorCommand.CLEAR)
+    def clear(self, batch_id: Optional[int] = None):
+        self.mailbox.send(GeneratorCommand.CLEAR, batch_id)
 
     def leave(self):
         if self.mailbox is None:
diff --git a/text-generation-inference/server/text_generation_server/generator_base.py b/text-generation-inference/server/text_generation_server/generator_base.py
index a3176d7c..647c2793 100644
--- a/text-generation-inference/server/text_generation_server/generator_base.py
+++ b/text-generation-inference/server/text_generation_server/generator_base.py
@@ -56,7 +56,7 @@ def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
         """Remove requests that are not listed from the specified batch"""
         raise NotImplementedError
 
-    def clear(self):
+    def clear(self, batch_id: Optional[int] = None):
         """Remove all requests from the generator"""
         raise NotImplementedError
 
diff --git a/text-generation-inference/server/text_generation_server/version.py b/text-generation-inference/server/text_generation_server/version.py
index 0e766d0d..49fb7124 100644
--- a/text-generation-inference/server/text_generation_server/version.py
+++ b/text-generation-inference/server/text_generation_server/version.py
@@ -1,5 +1,5 @@
 from pkg_resources import parse_version
 
 
-__version__ = "0.1.1"
+__version__ = "0.1.2"
 VERSION = parse_version(__version__)
diff --git a/text-generation-inference/tests/test_decode.py b/text-generation-inference/tests/test_decode.py
index edb67a3f..982c1da6 100644
--- a/text-generation-inference/tests/test_decode.py
+++ b/text-generation-inference/tests/test_decode.py
@@ -44,7 +44,7 @@ def test_decode_single(params):
         DecodeTestParams(
             model_id="google/gemma-7b",
             sequence_length=128,
-            expected_text="\n\nThe time is 1984. The place is Airstrip One, the British",
+            expected_text="\n\nThe first line of George Orwell’s 1984 is a perfect example",
         ),
         DecodeTestParams(
             model_id="mistralai/Mistral-7B-v0.3",
diff --git a/text-generation-inference/tests/test_gpt2.py b/text-generation-inference/tests/test_gpt2.py
index a28d5219..26d61a0b 100644
--- a/text-generation-inference/tests/test_gpt2.py
+++ b/text-generation-inference/tests/test_gpt2.py
@@ -65,7 +65,10 @@ def test_prefill(input_text, token_id, token_text, do_sample, batch_size, model_
 
 
 def test_decode_multiple(model_path):
-    generator = TpuGenerator.from_pretrained(model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH)
+    generator = TpuGenerator.from_pretrained(model_path,
+                                             revision="",
+                                             max_batch_size=2,
+                                             max_sequence_length=SEQUENCE_LENGTH)
     input_text = "Once upon a time"
     max_new_tokens = 20
     # Prefill a single request, remembering the generated token
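
For reviewers who want to exercise the new prefill guard outside the server, below is a minimal standalone sketch of the check that prefill() now performs. The helper name check_prefill_capacity is hypothetical; it mirrors the inline logic from the diff, with self.model.config.batch_size, the active slots, and the incoming batch replaced by plain arguments.

from typing import Optional


def check_prefill_capacity(model_batch_size: Optional[int],
                           num_active_slots: int,
                           num_new_requests: int) -> None:
    """Raise if active slots plus new requests would exceed the model's static batch size."""
    if model_batch_size is not None and model_batch_size < num_active_slots + num_new_requests:
        raise ValueError(
            f"Cannot prefill {num_new_requests} new request(s)."
            f" Maximum batch size supported is: {model_batch_size}."
        )


# A model compiled for batch_size=2 with one active slot accepts one more request...
check_prefill_capacity(model_batch_size=2, num_active_slots=1, num_new_requests=1)

# ...but rejects a prefill that would overflow the static batch.
try:
    check_prefill_capacity(model_batch_size=2, num_active_slots=2, num_new_requests=1)
except ValueError as e:
    print(e)  # Cannot prefill 1 new request(s). Maximum batch size supported is: 2.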