From 5e66bf7e7818f7d585cd6db8de1768c2b94cfe22 Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Fri, 13 Sep 2024 14:14:59 +0000
Subject: [PATCH] feat(Jetstream Pt): speed up prefill by avoiding redundant
 compilation

A new slot is created at each prefill request, and its selector is passed
as an argument to a jitted function. The problem is that each new slot has
a new signature, even if the contents are the same. The solution is to
wrap that in a singleton slot object for the prefill, so the compiler will
always see the same object and stop recompiling.
---
 .../jetstream_pt_support/generator.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index b6e24051..8f896966 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -244,6 +244,16 @@ def seed(self) -> int:
         return self._seed
 
 
+class PrefillSlot:
+    def __init__(self):
+        self._curslot = None
+
+    def set(self, slot: Slot):
+        self._curslot = slot
+
+    def select(self, logits: jnp.ndarray) -> int:
+        return self._curslot.select(logits)
+
 class TpuGeneratorJetStream(Generator):
     """A Generator for models running on TPU, single threaded."""
 
@@ -273,6 +283,7 @@ def __init__(
         self.batch_id = 0
         # Note: this index will _never_ be decremented, and that's fine.
         self.slot_index = 0
+        self.prefill_slot = PrefillSlot()
 
     @property
     def info(self) -> InfoResponse:
@@ -443,6 +454,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         for request in batch.requests:
             # Dynamically create a new slot for each request
             slot = Slot(self._get_slot_id(), self.tokenizer)
+            self.prefill_slot.set(slot)
             self.slot_index += 1
             slot.assign(self.batch_id, request, self.model.generation_config)
             logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
@@ -459,7 +471,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             )
             slot.reset(truncated_input_ids, selector)
             # To allow jit'ing the select function, we need to wrap it in a partial
-            slot_select = jax.tree_util.Partial(slot.select)
+            slot_select = jax.tree_util.Partial(self.prefill_slot.select)
             # Ask for prefill and insert
             prefill_results, _result_tokens = self.engine.prefill(
                 params=self.params,