diff --git a/pyproject.toml b/pyproject.toml
index 30c1ef5b..01ad1a5d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -61,7 +61,7 @@
 quality = ["black", "ruff", "isort"]
 # Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
 # Pallas is pulled because it will install a compatible version of jax[tpu].
 jetstream-pt = [
-    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@df92015289953c506004e674d57651b03e4e89f2",
+    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921",
     "torch-xla[pallas] == 2.4.0"
 ]
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py
deleted file mode 100644
index 4874aa3a..00000000
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from typing import Any, Callable, Optional, Tuple
-
-import jax
-import jax.numpy as jnp
-import torch
-from jetstream.engine import engine_api
-from jetstream_pt import engine
-
-
-class HfEngine(engine.PyTorchEngine):
-    def __init__(
-        self,
-        pt_model: torch.nn.Module,
-        env: engine.JetEngineEnvironment,
-        weights=None,
-    ):
-        super().__init__(pt_model, env, weights)
-        self.prefill_ex = jax.jit(
-            self.prefill_ex,
-            out_shardings=(self.get_prefix_destination_sharding(), None),
-        )
-
-    def generate_ex(
-        self, params: Any, decode_state: engine.DecodeState, sampling_fn: Callable[[Any, int], jax.Array]
-    ) -> tuple[engine.DecodeState, engine_api.ResultTokens]:
-        sampling_fn_backup = self._sampling
-        self._sampling = sampling_fn
-        new_decode_state, result_tokens = self.generate(params, decode_state)
-        self._sampling = sampling_fn_backup
-        return new_decode_state, result_tokens
-
-    def prefill_ex(
-        self,
-        *,
-        params: Any,  # Weights
-        _existing_prefix: Optional[engine.Prefix] = None,
-        padded_tokens: jax.Array,
-        true_length: int,
-        sampling_fn: Callable[[jax.Array], jax.Array],
-    ) -> Tuple[engine.Prefix, engine_api.ResultTokens]:
-        if isinstance(padded_tokens, jax.Array):
-            batched_token = padded_tokens.reshape(1, -1)
-        else:
-            raise TypeError("Input tokens should be of type Jax Array, but receiving:" " {prefill_inputs}")
-        seq_len = padded_tokens.shape[0]
-        input_indexes = jnp.arange(0, seq_len)
-        logits, updated_caches = self._call_model_prefill(
-            params,
-            batched_token,
-            input_indexes,
-        )
-        if len(logits.shape) == 3:  # b, seqlen, num words
-            logits = logits[0]  # seqlen, num words
-
-        # This is equivalent to last_logits = logits[:, true_length - 1, :], but it can be jitted
-        last_logits = jax.lax.dynamic_slice_in_dim(logits, true_length - 1, 1, axis=0)
-        token = sampling_fn(last_logits)
-        token_out = jnp.reshape(token, (1, 1))
-        data = jnp.concatenate(
-            [
-                token_out,  # First token
-                jnp.ones_like(token_out),  # validity of first token
-                jnp.zeros((1, 1), dtype=jnp.int32),  # length = 0
-            ],
-            axis=-1,
-        )
-        length = token_out.shape[1]
-        result = engine_api.ResultTokens(
-            data=data,
-            tokens_idx=(0, length),
-            valid_idx=(length, 2 * length),
-            length_idx=(2 * length, 2 * length + 1),
-            samples_per_slot=1,
-        )
-        # truncate to true_length didnt work need to be out side of jit
-        # caches = [
-        #   (jax.lax.dynamic_slice_in_dim(
-        #       k, seq_len - true_length, true_length, axis=2),
-        #    jax.lax.dynamic_slice_in_dim(
-        #       v, seq_len - true_length, true_length, axis=2))
-        #   for k, v in updated_caches
-        # ]
-        return engine.Prefix(token, updated_caches, true_length), result
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index 332cf5de..91cf9784 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -5,6 +5,7 @@
 
 import jax
 from jetstream_pt import fetch_models, torchjax
+from jetstream_pt.engine import PyTorchEngine
 from jetstream_pt.environment import (
     JetEngineEnvironment,
     JetEngineEnvironmentData,
@@ -17,7 +18,6 @@
 from transformers import PretrainedConfig
 from transformers import AutoConfig
 
-from .engine import HfEngine
 from .llama_model_exportable_hf import TransformerHf
 
 
@@ -25,11 +25,12 @@ def load_llama_model_info(config: "PretrainedConfig") -> Any:
     num_layers = config.num_hidden_layers
     num_heads = config.num_attention_heads
     head_dim = config.hidden_size // num_heads
-    n_reps = num_heads // config.num_key_value_heads
+    num_kv_heads = config.num_key_value_heads
+    n_reps = num_heads // num_kv_heads
     model_info = fetch_models.ModelInfo(
         TransformerHf,
         num_layers=num_layers,
-        num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         head_dim=head_dim,
         n_reps=n_reps,
     )
@@ -76,7 +77,7 @@ def create_engine_env_data(
     )
     env_data.cache_shape = (
         batch_size,
-        config.num_key_value_heads,
+        model_info.num_kv_heads,
         max_cache_length,
         model_info.head_dim,
     )
@@ -135,7 +136,7 @@ def create_engine(
     sequence_length: int,
     max_input_tokens: int,
     max_output_tokens: int,
-) -> HfEngine:
+) -> PyTorchEngine:
     # NOTE: for now no quantization is done
     env_data = create_engine_env_data(model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens)
     if env_data is None:
@@ -146,7 +147,7 @@
     weight_shardings = model.get_sharding_annotations()
     sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
 
-    return HfEngine(
+    return PyTorchEngine(
         pt_model=model,
         env=env,
         weights=torchjax.from_torch_with_copy(sharded_weights),
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 59291bf7..359d90cd 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -2,7 +2,7 @@
 import logging
 import time
 from enum import Enum
-from typing import Any, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -10,6 +10,7 @@
 import torch
 import torch_xla2
 from jetstream.engine.token_utils import pad_tokens, take_nearest_length
+from jetstream_pt.engine import PyTorchEngine
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from transformers.generation import GenerationConfig
@@ -27,7 +28,6 @@
     StoppingCriteriaParameters,
     Tokens,
 )
-from .engine import HfEngine
 from .engine_loader import create_engine
 from .token_selector import TokenSelector
 
@@ -215,6 +215,8 @@ def select(self, logits: jnp.ndarray) -> int:
         Return:
             int: A scalar of the selected token.
         """
+        if len(logits.shape) == 1:
+            logits = logits.reshape(1, -1)
         return self._selector.select(self._tokens, logits)[0]
 
     @property
@@ -241,7 +243,7 @@
 class TpuGeneratorJetStream(Generator):
     def __init__(
         self,
-        engine: HfEngine,
+        engine: PyTorchEngine,
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.engine = engine
@@ -452,11 +454,11 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         # To allow jit'ing the select function, we need to wrap it in a partial
         slot_select = jax.tree_util.Partial(slot.select)
         # Ask for prefill and insert
-        prefill_results, _result_tokens = self.engine.prefill_ex(
+        prefill_results, _result_tokens = self.engine.prefill(
             params=self.params,
             padded_tokens=input_ids,
             true_length=true_lengths,
-            sampling_fn=slot_select,
+            sampler=slot_select,
         )
         next_token = prefill_results.token.item()
         self.decode_state = self.engine.insert(prefill_results, self.decode_state, slot.id)
@@ -477,6 +479,16 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         logger.debug("Model ready for decoding")
         return generations, batch
 
+    def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray:
+        pad_token_id = self.tokenizer.pad_token_id
+        batch_size = logits.shape[0]
+        tokens = jnp.full((batch_size, 1), pad_token_id)
+        for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
+            # Every slot might have a different selection criteria, so we are obliged to call select in a loop
+            next_token = slot.select(logits)
+            tokens = tokens.at[slot.id].set(next_token)
+        return tokens
+
     def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         """Decode the specified prefilled requests.
 
@@ -510,19 +522,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
         if len(active_slots) < len(request_ids):
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
 
-        # Define a custom function to select the next token for each slot
-        pad_token_id = self.tokenizer.pad_token_id
-
-        def select_from_slots(logits: Any, batch_size: int) -> jnp.ndarray:
-            tokens = jnp.full((batch_size, 1), pad_token_id)
-            for slot in active_slots:
-                # Every slot might have a different selection criteria, so we are obliged to call select in a loop
-                next_token = slot.select(logits)
-                tokens = tokens.at[slot.id].set(next_token)
-            return tokens
-
-        select_fn = select_from_slots
-        self.decode_state, result_tokens = self.engine.generate_ex(self.params, self.decode_state, select_fn)
+        # Use a custom function to select the next token for each slot
+        select_fn = jax.tree_util.Partial(self._select_from_slots)
+        self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)
 
         newly_empty = []
         generations = []