Skip to content

Commit

Permalink
feat(Jetstream Pt): speed-up prefill by avoiding redundant compilation
Browse files Browse the repository at this point in the history
A new slot is created at each prefill request, and its selector is
passed as an argument to a jitted function. The problem is that each new
slot looks like a new function signature to the compiler, even if its
contents are the same. The solution is to wrap the slot in a singleton
object for the prefill, so the compiler always sees the same object and
stops recompiling.
  • Loading branch information
tengomucho committed Sep 16, 2024
1 parent abf5f5a commit 5e66bf7
Showing 1 changed file with 13 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ def seed(self) -> int:
return self._seed


class PrefillSlot:
    """Long-lived holder that forwards token selection to the current slot.

    The holder is re-pointed at a different slot with ``set()``, while
    ``select`` always remains a bound method of this same holder object.
    Callers that need a callable with a stable identity (for example one that
    is wrapped and handed to a jitted function) can therefore keep passing
    ``holder.select`` even though the underlying slot changes between calls.
    """

    def __init__(self):
        # No slot is bound yet; set() must be called before select().
        self._curslot = None

    def set(self, slot: "Slot"):
        """Make ``slot`` the target of subsequent ``select()`` calls."""
        self._curslot = slot

    def select(self, logits: jnp.ndarray) -> int:
        """Delegate selection to the currently bound slot.

        Args:
            logits: Candidate logits, passed through unchanged to the bound
                slot's ``select`` method.

        Returns:
            Whatever the bound slot's ``select`` returns for ``logits``.

        Raises:
            RuntimeError: If no slot has been bound with ``set()`` yet.
        """
        if self._curslot is None:
            # Fail with an explicit message rather than an AttributeError on
            # None, which gives no hint about the missing set() call.
            raise RuntimeError("PrefillSlot.select() called before a slot was set")
        return self._curslot.select(logits)

class TpuGeneratorJetStream(Generator):
"""A Generator for models running on TPU, single threaded."""

Expand Down Expand Up @@ -273,6 +283,7 @@ def __init__(
self.batch_id = 0
# Note: this index will _never_ be decremented, and that's fine.
self.slot_index = 0
self.prefill_slot = PrefillSlot()

@property
def info(self) -> InfoResponse:
Expand Down Expand Up @@ -443,6 +454,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
for request in batch.requests:
# Dynamically create a new slot for each request
slot = Slot(self._get_slot_id(), self.tokenizer)
self.prefill_slot.set(slot)
self.slot_index += 1
slot.assign(self.batch_id, request, self.model.generation_config)
logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
Expand All @@ -459,7 +471,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
)
slot.reset(truncated_input_ids, selector)
# To allow jit'ing the select function, we need to wrap it in a partial
slot_select = jax.tree_util.Partial(slot.select)
slot_select = jax.tree_util.Partial(self.prefill_slot.select)
# Ask for prefill and insert
prefill_results, _result_tokens = self.engine.prefill(
params=self.params,
Expand Down

0 comments on commit 5e66bf7

Please sign in to comment.