review(Jetstream Pt): avoid double iteration on slots
tengomucho committed Sep 12, 2024
1 parent 4674367 commit ed796f9
Showing 1 changed file with 1 addition and 2 deletions.
@@ -483,8 +483,7 @@ def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndar
         pad_token_id = self.tokenizer.pad_token_id
         batch_size = logits.shape[0]
         tokens = jnp.full((batch_size, 1), pad_token_id)
-        active_slots = [slot for slot in self.slots if slot.state == slot.State.READY]
-        for slot in active_slots:
+        for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
             # Every slot might have a different selection criteria, so we are obliged to call select in a loop
             next_token = slot.select(logits)
             tokens = tokens.at[slot.id].set(next_token)
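
Note on the change: the removed lines first built a list of ready slots and then looped over that list, while the added line loops over a lazy `filter` iterator, so the slots are walked once and no intermediate list is created. The sketch below illustrates that both forms visit the same slots in the same order. It is not code from the repository: the `Slot` class here is a hypothetical stand-in that models only `id` and `state`, and the `EMPTY` state name is an assumption made for illustration.

```python
from enum import Enum


class Slot:
    """Hypothetical stand-in for the real slot; only `id` and `state` are modeled."""

    class State(Enum):
        EMPTY = 0  # assumed name for a non-ready state
        READY = 1

    def __init__(self, slot_id, state):
        self.id = slot_id
        self.state = state


slots = [
    Slot(0, Slot.State.READY),
    Slot(1, Slot.State.EMPTY),
    Slot(2, Slot.State.READY),
]

# Before: one pass over all slots to build a list, then a second loop over that list.
active_slots = [slot for slot in slots if slot.state == slot.State.READY]
ids_before = [slot.id for slot in active_slots]

# After: filter() returns a lazy iterator, so filtering happens while looping.
ready = filter(lambda slot: slot.state == slot.State.READY, slots)
ids_after = [slot.id for slot in ready]

print(ids_before, ids_after)
```

One caveat with the lazy form: a `filter` object can be consumed only once, which is fine here because the loop is its only consumer.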