Skip to content

Commit

Permalink
chore(jetstream): token selector operations are done in torch
Browse files Browse the repository at this point in the history
Conversions of scores tensors from jax to torch and back are done when
calling logits processor. This will be required in newer versions of
transformers.
  • Loading branch information
tengomucho committed Nov 28, 2024
1 parent 4182b74 commit 73866d6
Showing 1 changed file with 7 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import jax
import jax.numpy as jnp
import torch_xla2
from jetstream.engine import sampling_utils
from transformers.generation import (
GenerationConfig,
Expand Down Expand Up @@ -173,7 +174,12 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
Return:
`jnp.ndarray`: A `jnp.ndarray` containing the selected tokens.
"""
scores = self.logits_processor(input_ids, logits)
# Logits processors is written in pytorch, so parameters are cast to float32 and converted to pytorch and back
# to jax with j2t/t2j (that is a bit expensive, it does copies), otherwise some operations are not supported.
logits_t = torch_xla2.tensor.j2t(logits.astype(jnp.float32))
scores = self.logits_processor(input_ids, logits_t)
scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)

if self.mode == GenerationMode.SAMPLE:
# split the key to avoid reusing the same key for multiple samples
subkey, self.key = jax.random.split(self.key)
Expand Down

0 comments on commit 73866d6

Please sign in to comment.