diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
index ce0820c4..99690d6d 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/token_selector.py
@@ -5,6 +5,7 @@
 import jax
 import jax.numpy as jnp
 from jetstream.engine import sampling_utils
+import torch_xla2
 from transformers.generation import (
     GenerationConfig,
     GenerationMixin,
@@ -173,7 +174,10 @@ def select(self, input_ids: jnp.ndarray, logits: jnp.ndarray) -> jnp.ndarray:
             Return:
             `jnp.ndarray`: A `jnp.ndarray` containing the selected tokens.
         """
-        scores = self.logits_processor(input_ids, logits)
+        # The logits processor (from transformers) operates on torch tensors, so
+        # round-trip: JAX -> torch (j2t), process, torch -> JAX (t2j) on the same device.
+        scores = self.logits_processor(input_ids, torch_xla2.tensor.j2t(logits))
+        scores = torch_xla2.tensor.t2j(scores).to_device(logits.device)
+
         if self.mode == GenerationMode.SAMPLE:
             # split the key to avoid reusing the same key for multiple samples
             subkey, self.key = jax.random.split(self.key)