Skip to content

Commit

Permalink
WIP
Browse files: browse the repository at this point in the history
  • Loading branch information
tengomucho committed Nov 27, 2024
1 parent 7643035 commit d8d290a
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
from jetstream.engine import sampling_utils
import torch_xla2
from transformers.generation import (
GenerationConfig,
GenerationMixin,
Expand Down Expand Up @@ -173,7 +174,10 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
Return:
`jnp.ndarray`: A `jnp.ndarray` containing the selected tokens.
"""
scores = self.logits_processor(input_ids, logits)
scores = self.logits_processor(input_ids, torch_xla2.tensor.j2t(logits))
scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)
# breakpoint()

if self.mode == GenerationMode.SAMPLE:
# split the key to avoid reusing the same key for multiple samples
subkey, self.key = jax.random.split(self.key)
Expand Down

0 comments on commit d8d290a

Please sign in to comment.