☝️ Update Jetstream Pytorch revision (#91)
* feat(Jetstream PT): update git version

* feat(jetstream_pt): drop prefill_ex method

* feat(jetstream_pt): drop generate_ex method

Now that decode takes a specialized method, this is no longer
necessary.

* refactor(Jetstream PT): drop HfEngine

Now that sampling can be passed as a parameter to prefill and generate,
the custom engine is no longer required.

* review(Jetstream Pt): avoid double iteration on slots
tengomucho authored Sep 12, 2024
1 parent b25e973 commit 03b6573
Showing 4 changed files with 28 additions and 108 deletions.
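In short: jetstream-pt's stock PyTorchEngine now accepts the sampling logic as a callable (a sampler argument on prefill, a selection function on generate), so the HfEngine subclass that existed only to inject custom sampling can be deleted. A runnable toy sketch of that inversion, with a hypothetical StubEngine standing in for the real engine:

import jax.numpy as jnp

class StubEngine:
    # Hypothetical stand-in for jetstream_pt.engine.PyTorchEngine: it only
    # models the sampling hook; the real prefill also takes params and state.
    def prefill(self, *, padded_tokens, sampler):
        logits = jnp.ones((1, 8))   # fake logits for illustration
        return sampler(logits)      # the sampling policy is caller-provided

greedy = lambda logits: jnp.argmax(logits, axis=-1)
print(StubEngine().prefill(padded_tokens=jnp.zeros(8), sampler=greedy))  # [0]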
pyproject.toml: 2 changes (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ quality = ["black", "ruff", "isort"]
 # Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
 # Pallas is pulled because it will install a compatible version of jax[tpu].
 jetstream-pt = [
-    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@df92015289953c506004e674d57651b03e4e89f2",
+    "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921",
     "torch-xla[pallas] == 2.4.0"
 ]


This file was deleted.

@@ -5,6 +5,7 @@
 import jax
 from jetstream_pt import fetch_models, torchjax
+from jetstream_pt.engine import PyTorchEngine
 from jetstream_pt.environment import (
     JetEngineEnvironment,
     JetEngineEnvironmentData,
@@ -17,19 +18,19 @@
 from transformers import PretrainedConfig
 from transformers import AutoConfig

-from .engine import HfEngine
 from .llama_model_exportable_hf import TransformerHf


 def load_llama_model_info(config: "PretrainedConfig") -> Any:
     num_layers = config.num_hidden_layers
     num_heads = config.num_attention_heads
     head_dim = config.hidden_size // num_heads
-    n_reps = num_heads // config.num_key_value_heads
+    num_kv_heads = config.num_key_value_heads
+    n_reps = num_heads // num_kv_heads
     model_info = fetch_models.ModelInfo(
         TransformerHf,
         num_layers=num_layers,
         num_heads=num_heads,
+        num_kv_heads=num_kv_heads,
         head_dim=head_dim,
         n_reps=n_reps,
     )
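For a grouped-query-attention checkpoint, the bookkeeping above works out as follows (the numbers are illustrative, not taken from any particular config):

# Hypothetical GQA config values, roughly Llama-shaped.
hidden_size, num_heads, num_key_value_heads = 4096, 32, 8
head_dim = hidden_size // num_heads   # 128
num_kv_heads = num_key_value_heads    # 8
n_reps = num_heads // num_kv_heads    # 4: each KV head serves 4 query heads
assert (head_dim, num_kv_heads, n_reps) == (128, 8, 4)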
@@ -76,7 +77,7 @@ def create_engine_env_data(
     )
     env_data.cache_shape = (
         batch_size,
-        config.num_key_value_heads,
+        model_info.num_kv_heads,
         max_cache_length,
         model_info.head_dim,
     )
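Sizing the cache with model_info.num_kv_heads instead of the full attention-head count is what keeps the KV cache small under grouped-query attention. Continuing the illustrative numbers above:

batch_size, max_cache_length, num_kv_heads, head_dim = 4, 2048, 8, 128
cache_shape = (batch_size, num_kv_heads, max_cache_length, head_dim)
# (4, 8, 2048, 128); using all 32 attention heads here would cost 4x the memory.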
@@ -135,7 +136,7 @@ def create_engine(
     sequence_length: int,
     max_input_tokens: int,
     max_output_tokens: int,
-) -> HfEngine:
+) -> PyTorchEngine:
     # NOTE: for now no quantization is done
     env_data = create_engine_env_data(model_path, batch_size, sequence_length, max_input_tokens, max_output_tokens)
     if env_data is None:
@@ -146,7 +147,7 @@
     weight_shardings = model.get_sharding_annotations()
     sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)

-    return HfEngine(
+    return PyTorchEngine(
         pt_model=model,
         env=env,
         weights=torchjax.from_torch_with_copy(sharded_weights),
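Since HfEngine was apparently a thin subclass adding only the _ex methods, call sites need nothing beyond the annotation change; a hypothetical invocation (model name and sizes made up):

engine = create_engine(
    "meta-llama/Llama-2-7b-hf",   # hypothetical model_path
    batch_size=4,
    sequence_length=2048,
    max_input_tokens=1024,
    max_output_tokens=1024,
)
# engine is now a plain jetstream_pt.engine.PyTorchEngine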
@@ -2,14 +2,15 @@
 import logging
 import time
 from enum import Enum
-from typing import Any, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import jax
 import jax.numpy as jnp
 import numpy as np
 import torch
 import torch_xla2
 from jetstream.engine.token_utils import pad_tokens, take_nearest_length
+from jetstream_pt.engine import PyTorchEngine
 from loguru import logger
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from transformers.generation import GenerationConfig
@@ -27,7 +28,6 @@
     StoppingCriteriaParameters,
     Tokens,
 )
-from .engine import HfEngine
 from .engine_loader import create_engine
 from .token_selector import TokenSelector
@@ -215,6 +215,8 @@ def select(self, logits: jnp.ndarray) -> int:
         Return:
             int: A scalar of the selected token.
         """
+        if len(logits.shape) == 1:
+            logits = logits.reshape(1, -1)
         return self._selector.select(self._tokens, logits)[0]

     @property
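The new guard handles the prefill path, where the sampler can be handed a rank-1 logits vector while TokenSelector.select expects a (batch, vocab) array. A tiny illustration (the vocab size is arbitrary):

import jax.numpy as jnp

logits = jnp.zeros((32000,))        # rank-1 logits
if len(logits.shape) == 1:
    logits = logits.reshape(1, -1)  # -> (1, 32000), the batched shape select() needs
assert logits.shape == (1, 32000)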
@@ -241,7 +243,7 @@ class TpuGeneratorJetStream(Generator):

     def __init__(
         self,
-        engine: HfEngine,
+        engine: PyTorchEngine,
         tokenizer: PreTrainedTokenizerBase,
     ):
         self.engine = engine
@@ -452,11 +454,11 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             # To allow jit'ing the select function, we need to wrap it in a partial
             slot_select = jax.tree_util.Partial(slot.select)
             # Ask for prefill and insert
-            prefill_results, _result_tokens = self.engine.prefill_ex(
+            prefill_results, _result_tokens = self.engine.prefill(
                 params=self.params,
                 padded_tokens=input_ids,
                 true_length=true_lengths,
-                sampling_fn=slot_select,
+                sampler=slot_select,
             )
             next_token = prefill_results.token.item()
             self.decode_state = self.engine.insert(prefill_results, self.decode_state, slot.id)
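Only names change here: upstream prefill already takes the callable (as sampler), so the prefill_ex variant is redundant. Note that slot.select is a bound method, which jax.tree_util.Partial wraps as readily as a plain function, so each request keeps its own sampling behaviour. A minimal sketch with a hypothetical ToySlot:

import jax
import jax.numpy as jnp

class ToySlot:
    # Hypothetical stand-in for Slot; a real slot would carry per-request
    # sampling settings (temperature, top-k, ...) used by select().
    def select(self, logits):
        return jnp.argmax(logits, axis=-1)

slot_select = jax.tree_util.Partial(ToySlot().select)
print(slot_select(jnp.array([0.1, 2.0, 0.3])))  # 1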
@@ -477,6 +479,16 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
         logger.debug("Model ready for decoding")
         return generations, batch

+    def _select_from_slots(self, logits: jnp.ndarray, batch_size: int=0) -> jnp.ndarray:
+        pad_token_id = self.tokenizer.pad_token_id
+        batch_size = logits.shape[0]
+        tokens = jnp.full((batch_size, 1), pad_token_id)
+        for slot in filter(lambda slot: slot.state == slot.State.READY, self.slots):
+            # Every slot might have a different selection criteria, so we are obliged to call select in a loop
+            next_token = slot.select(logits)
+            tokens = tokens.at[slot.id].set(next_token)
+        return tokens
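tokens.at[slot.id].set(next_token) is JAX's functional update: it returns a new array instead of mutating in place, which is what lets this selection be traced by the engine's compiled decode step. For example:

import jax.numpy as jnp

pad_token_id = 0                 # hypothetical pad id
tokens = jnp.full((4, 1), pad_token_id)
tokens = tokens.at[2].set(57)    # new array; only row 2 changes
print(tokens.ravel())            # [ 0  0 57  0]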

     def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         """Decode the specified prefilled requests.
@@ -510,19 +522,9 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
         if len(active_slots) < len(request_ids):
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

-        # Define a custom function to select the next token for each slot
-        pad_token_id = self.tokenizer.pad_token_id
-
-        def select_from_slots(logits: Any, batch_size: int) -> jnp.ndarray:
-            tokens = jnp.full((batch_size, 1), pad_token_id)
-            for slot in active_slots:
-                # Every slot might have a different selection criteria, so we are obliged to call select in a loop
-                next_token = slot.select(logits)
-                tokens = tokens.at[slot.id].set(next_token)
-            return tokens
-
-        select_fn = select_from_slots
-        self.decode_state, result_tokens = self.engine.generate_ex(self.params, self.decode_state, select_fn)
+        # Use a custom function to select the next token for each slot
+        select_fn = jax.tree_util.Partial(self._select_from_slots)
+        self.decode_state, result_tokens = self.engine.generate(self.params, self.decode_state, select_fn)

         newly_empty = []
         generations = []
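jax.tree_util.Partial is what makes this work: unlike a bare function or functools.partial, it is registered as a pytree, so jit-compiled engine code can accept it as an argument. A self-contained demonstration:

import jax
import jax.numpy as jnp

def select(logits):
    return jnp.argmax(logits, axis=-1)

@jax.jit
def step(fn, logits):
    return fn(logits)

print(step(jax.tree_util.Partial(select), jnp.array([[0.1, 2.0, 0.3]])))  # [1]
# Passing the bare `select` instead would raise a TypeError: plain Python
# callables are not valid JAX types.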
