diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index b3bc00a2..359d90cd 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -483,8 +483,7 @@ def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndar pad_token_id = self.tokenizer.pad_token_id batch_size = logits.shape[0] tokens = jnp.full((batch_size, 1), pad_token_id) - active_slots = [slot for slot in self.slots if slot.state == slot.State.READY] - for slot in active_slots: + for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots): # Every slot might have a different selection criteria, so we are obliged to call select in a loop next_token = slot.select(logits) tokens = tokens.at[slot.id].set(next_token)