From 02e0138a53bd404e6253e8737711224ed3a0f22e Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Mon, 20 Jan 2025 14:54:41 +0000 Subject: [PATCH] WIP --- .../jetstream_pt_support/compatibility.py | 2 +- .../jetstream_pt_support/engine_loader.py | 6 +- .../jetstream_pt_support/generator.py | 2 +- .../jetstream_pt_support/models/__init__.py | 1 + .../models/qwen2_model.py | 292 ++++++++++++------ .../models/qwen2_model_v2.py | 214 +++++++++++++ .../tests/test_decode_jetstream.py | 8 +- text-generation-inference/tests/test_qwen2.py | 66 ++++ 8 files changed, 500 insertions(+), 91 deletions(-) create mode 100644 text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model_v2.py create mode 100644 text-generation-inference/tests/test_qwen2.py diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py index d1fc325d..456eac26 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py @@ -25,7 +25,7 @@ def model_can_use_jetstream_pt(model_path: str) -> bool: """ config = AutoConfig.from_pretrained(model_path) # For now few models are supported - supported_models = ["llama", "gemma", "mixtral"] + supported_models = ["llama", "gemma", "mixtral", "qwen2"] if config.model_type not in supported_models: return False if jetstream_pt_available(): diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py index 24bd600a..f5110a02 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py @@ -24,7 +24,7 @@ from transformers import AutoConfig from .compatibility import model_can_use_jetstream_pt -from .models import GemmaModel, LlamaModel, MixtralModel +from .models import GemmaModel, LlamaModel, MixtralModel, Qwen2Model class OptimumJetstreamEngine(PyTorchEngine): @@ -66,6 +66,8 @@ def load_model_info(config: "PretrainedConfig") -> Any: model_class = GemmaModel elif config.model_type == "mixtral": model_class = MixtralModel + elif config.model_type == "qwen2": + model_class = Qwen2Model else: raise ValueError(f"Unsupported model type {config.model_type}") model_info = fetch_models.ModelInfo( @@ -101,7 +103,7 @@ def create_engine_env_data( head_dim_shardable = model_info.num_kv_heads == 1 and model_info.head_dim % num_devices == 0 if num_kv_heads_shardable or head_dim_shardable: - shard_on_batch = False + shard_on_batch = False else: shard_on_batch = True aligned_batch_size = (batch_size + num_devices - 1) // num_devices * num_devices diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py index 1baed358..64148eea 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py @@ -407,7 +407,7 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]: tokens, true_length = pad_tokens(input_ids[0], 
self.tokenizer.bos_token_id, self.tokenizer.pad_token_id, - is_bos=True, + is_bos=(self.tokenizer.bos_token_id is not None), max_prefill_length=max_prefill_length, jax_padding=True, ) diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py index 9855bde6..4a835fe1 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py @@ -1,3 +1,4 @@ from .gemma_model_hf import GemmaModelHf as GemmaModel from .llama_model_exportable_hf import TransformerHf as LlamaModel from .mixtral_model_hf import MixtralModelHf as MixtralModel +from .qwen2_model import Qwen2Model diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py index 16e17a98..d340ade2 100644 --- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py +++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py @@ -2,22 +2,25 @@ Qwen2 model implementation, based on Jetstream implementation of Llama model. """ -from typing import Any, List, Optional import copy +from typing import Any, List, Optional + import jax -import math import torch import torch.nn.functional as F -from jetstream_pt.model_base import ModuleBase from jetstream_pt.layers import ( - Attention, - RMSNorm, - get_quantized_embedding_layer, - get_quantized_linear_layer, + AttentionKernel, + Int8KVAttentionKernel, + RMSNorm, + apply_rotary_emb, + get_quantized_embedding_layer, + get_quantized_linear_layer, ) -from torch import nn +from jetstream_pt.model_base import ModuleBase -from . 
import model_args +# Use llama's functions and classes that are the same as in Qwen2 +from jetstream_pt.third_party.llama.model_exportable import model_args +from transformers import GenerationConfig, GenerationMixin, Qwen2Config class FeedForward(ModuleBase): @@ -82,8 +85,145 @@ def forward(self, x): result = self.w2(F.silu(self.w1(x)) * self.w3(x)) return result +class QwenAttention(ModuleBase): + """Attention module.""" + + def __init__( + self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id + ): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.head_dim = head_dim + self.n_rep = self.n_heads // self.n_kv_heads + self.env = env + self.hidden_size = hidden_size + self.layer_id = layer_id + + LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs = {"quant_config": env.quant_config} -class TransformerBlock(ModuleBase): + self.wo = LinearLayer( + n_heads * self.head_dim, + hidden_size, + bias=False, + device=device, + **linear_kwargs, + ) + + Kernel = ( + Int8KVAttentionKernel + if env.quant_config.enable_kv_quantization + else AttentionKernel + ) + self.attention_kernel = Kernel(env, self.layer_id) + + self.q_size = n_heads * self.head_dim + self.kv_size = self.n_kv_heads * self.head_dim + if self.env.qkv_fusion: + self._register_load_state_dict_pre_hook(self.load_hook) + self.wqkv = LinearLayer( + hidden_size, + (n_heads + 2 * self.n_kv_heads) * self.head_dim, + bias=True, + device=device, + **linear_kwargs, + ) + else: + self.wq = LinearLayer( + hidden_size, + n_heads * self.head_dim, + bias=True, + device=device, + **linear_kwargs, + ) + self.wk = LinearLayer( + hidden_size, + self.n_kv_heads * self.head_dim, + bias=True, + device=device, + **linear_kwargs, + ) + self.wv = LinearLayer( + hidden_size, + self.n_kv_heads * self.head_dim, + bias=True, + device=device, + **linear_kwargs, + ) + + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + cache, + start=None, + end=None, + ragged_batch_index=None, + ragged_block_index=None, + ): + with jax.named_scope("attn_linear_before_cache"): + bsz, seqlen = x.shape[0], x.shape[-2] + + # qkv fuse + if self.env.qkv_fusion: + xq, xk, xv = self.wqkv(x).split( + [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + else: + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + shard_axis = 0 if self.env.shard_on_batch else 2 + self.env.apply_sharding(xq, axis=shard_axis) + self.env.apply_sharding(xk, axis=shard_axis) + self.env.apply_sharding(xv, axis=shard_axis) + + with jax.named_scope("attn_rope"): + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + xq = xq.transpose(1, 2) + + if mask.ndim == 2: + if seqlen == 1: + mask = mask[:, None, None, :] + else: + mask = mask[None, None, :, :] + + # if cache is not None and cache.cache_k is not None: + # print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}") + output = self.attention_kernel( + 
xq=xq, + xk=xk, + xv=xv, + mask=mask, + # cache[self.layer_id], + cache=cache, + start=start, + end=end, + ragged_batch_index=ragged_batch_index, + ragged_block_index=ragged_block_index, + ).type_as(xq) + # print(f"output {output.shape}") + output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class Qwen2DecoderLayer(ModuleBase): """Transformer block.""" def __init__( @@ -99,7 +239,7 @@ def __init__( self.head_dim = args.dim // args.n_heads self.args = args - self.attention = Attention( + self.attention = QwenAttention( args.n_heads, args.n_kv_heads or args.n_heads, args.dim // args.n_heads, @@ -132,6 +272,10 @@ def __init__( self.attention.annotate_sharding("wk.weight", 0) self.attention.annotate_sharding("wv.weight", 0) self.attention.annotate_sharding("wo.weight", 1) + self.attention.annotate_sharding("wq.weight.bias", 0) + self.attention.annotate_sharding("wk.weight.bias", 0) + self.attention.annotate_sharding("wv.weight.bias", 0) + self.attention.annotate_sharding("wo.weight.bias", -1) self.hf_name("feed_forward", "mlp") self.hf_name("attention_norm", "input_layernorm") @@ -168,71 +312,71 @@ def forward( return out -def apply_scaling(freqs: torch.Tensor, config: model_args.RopeScalingArgs): - # Values obtained from grid search - scale_factor = config.factor - low_freq_factor = config.low_freq_factor - high_freq_factor = config.high_freq_factor - old_context_len = config.original_max_position_embeddings - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - - def precompute_freqs_cis( dim: int, end: int, theta: float = 10000.0, - rope_scaling_config: model_args.RopeScalingArgs = None, ): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end, device=freqs.device, dtype=torch.float32) - if rope_scaling_config is not None: - freqs = apply_scaling(freqs, rope_scaling_config) freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 return freqs_cis -class Transformer(ModuleBase): - """Transformer module.""" +class Qwen2Model(ModuleBase, GenerationMixin): + """Qwen2 module.""" def __init__( self, - params: model_args.ModelArgs, + config: Qwen2Config, + device, env, ): + if config.sliding_window is not None: + raise ValueError("Sliding window is not supported for Qwen2 model") + if config.rope_scaling is not None: + raise ValueError("Rope scaling is not supported for Qwen2 model") + super().__init__() + self.config = config + self.generation_config = GenerationConfig.from_model_config(config) + + # NOTE: these parameters are deduced from the config's intermediate_size and hidden_size, so to be compatible + # with the original Jestream/Pytorch model. 
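+        # Assuming the Llama-style FeedForward sizing (hidden_dim = int(2 * 4 * dim / 3), scaled by
+        # ffn_dim_multiplier and rounded up to multiple_of): with multiple_of=1 the rounding is a no-op,
+        # so the multiplier below makes the feed-forward width resolve back to config.intermediate_size.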
+ ffn_dim_multiplier = config.intermediate_size / int(8 * config.hidden_size / 3) + multiple_of = 1 + params = model_args.ModelArgs( + dim=config.hidden_size, + n_layers=config.num_hidden_layers, + n_heads=config.num_attention_heads, + n_kv_heads=config.num_key_value_heads, + vocab_size=config.vocab_size, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + max_seq_len=env.cache_len, + bf16_enable=env.bf16_enable, + rope_theta=config.rope_theta, + ) + params.device = device self.env = env + self.params = params - self.vocab_size = params.vocab_size - self.n_layers = params.n_layers + self.vocab_size = config.vocab_size + self.n_layers = config.num_hidden_layers Embedding = get_quantized_embedding_layer(env.quant_config) self.tok_embeddings = Embedding( - params.vocab_size, - params.dim, - device=params.device, + config.vocab_size, + config.hidden_size, + device=device, ) self.layers = torch.nn.ModuleList() - for layer_id in range(params.n_layers): - self.layers.append(TransformerBlock(layer_id, params, env)) - self.norm = RMSNorm(params.dim, eps=params.norm_eps, device=params.device) + for layer_id in range(config.num_hidden_layers): + self.layers.append(Qwen2DecoderLayer(layer_id, params, env)) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=params.device) LinearLayer = get_quantized_linear_layer(env.quant_config) linear_kwargs = {} @@ -240,18 +384,16 @@ def __init__( linear_kwargs["quant_config"] = env.quant_config self.output = LinearLayer( - params.dim, - params.vocab_size, + config.hidden_size, + config.vocab_size, bias=False, device=params.device, **linear_kwargs, ) - # TODO what to do with this freqs_cis = precompute_freqs_cis( - self.params.dim // self.params.n_heads, + config.hidden_size // config.num_attention_heads, self.params.max_seq_len * 2, - theta=self.params.rope_theta, - rope_scaling_config=self.params.rope_scaling_args, + theta=config.rope_theta, ) self.register_buffer("freqs_cis", freqs_cis) @@ -319,36 +461,6 @@ def forward( output = self.output(h).float() return output - @classmethod - def from_hf_model_id(cls, model_id, env, is_tiny=False): - if is_tiny: - name = "llama-2-tiny" - else: - name = { - "meta-llama/Llama-2-7b-chat-hf": "llama-2-7b", - "meta-llama/Llama-2-7b-hf": "llama-2-7b", - "meta-llama/Llama-2-13b-chat-hf": "llama-2-13b", - "meta-llama/Llama-2-13b-hf": "llama-2-13b", - "meta-llama/Llama-2-70b-hf": "llama-2-70b", - "meta-llama/Llama-2-70b-chat-hf": "llama-2-70b", - "meta-llama/Meta-Llama-3-8B": "llama-3-8b", - "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b", - "meta-llama/Meta-Llama-3-70B": "llama-3-70b", - "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b", - "meta-llama/Llama-3.1-8B": "llama-3.1-8b", - "meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-8b", - "meta-llama/Llama-3.2-1B": "llama-3.2-1b", - "meta-llama/Llama-3.2-1B-Instruct": "llama-3.2-1b", - "meta-llama/Llama-3.3-70B": "llama-3.3-70b", - "meta-llama/Llama-3.3-70B-Instruct": "llama-3.3-70b", - }.get(model_id) - assert name - args = model_args.get_model_args( - name, env.cache_len, env.batch_size, env.bf16_enable - ) - args.device = "meta" - model = cls(args, env) - return model def convert_hf_weights(self, hf_weights): @@ -363,6 +475,8 @@ def transform(val, n_heads): updated = copy.copy(hf_weights) for key, value in hf_weights.items(): + if "bias" in key and ("q_proj" in key or "k_proj" in key): + continue if "q_proj" in key: updated[key] = transform(value, self.params.n_heads) if "k_proj" in key: @@ -372,3 
+486,9 @@ def transform(val, n_heads):
         res = super().convert_hf_weights(updated)
         res["freqs_cis"] = self.freqs_cis
         return res
+
+    @classmethod
+    def from_config(cls, config, env):
+        device = "meta"
+        model = cls(config, device, env)
+        return model
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model_v2.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model_v2.py
new file mode 100644
index 00000000..b2d6514d
--- /dev/null
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model_v2.py
@@ -0,0 +1,214 @@
+import copy
+from typing import Any, List
+
+import jax
+import torch
+from jetstream_pt.layers import (
+    RMSNorm,
+    get_quantized_embedding_layer,
+    get_quantized_linear_layer,
+)
+from jetstream_pt.model_base import ModuleBase
+# Reuse the Llama TransformerBlock and ModelArgs from the Jetstream Pytorch implementation
+from jetstream_pt.third_party.llama.model_exportable import TransformerBlock, model_args
+from transformers import GenerationConfig, GenerationMixin, Qwen2Config
+
+
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+class Qwen2Model(ModuleBase, GenerationMixin):
+    """Transformer module that uses a HF Qwen2Config instead of the Jetstream PyTorch ModelArgs + device.
+
+    Note that this class also derives from GenerationMixin, so that we can use its methods.
+    """
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        device,
+        env,
+    ):
+        super().__init__()
+        if config.sliding_window is not None:
+            raise ValueError("Sliding window is not supported for Qwen2 model")
+        if config.rope_scaling is not None:
+            raise ValueError("Rope scaling is not supported for Qwen2 model")
+
+        self.config = config
+        self.generation_config = GenerationConfig.from_model_config(config)
+
+        # NOTE: these parameters are deduced from the config's intermediate_size and hidden_size, so as to be
+        # compatible with the original Jetstream/PyTorch model.
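+        # (Same derivation as in qwen2_model.py: with multiple_of=1 the Llama FeedForward rounding is a no-op.)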
+ ffn_dim_multiplier = config.intermediate_size / int(8 * config.hidden_size / 3) + multiple_of = 1 + + params = model_args.ModelArgs( + dim=config.hidden_size, + n_layers=config.num_hidden_layers, + n_heads=config.num_attention_heads, + n_kv_heads=config.num_key_value_heads, + vocab_size=config.vocab_size, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + max_seq_len=env.cache_len, + bf16_enable=env.bf16_enable, + rope_theta=config.rope_theta, + ) + params.device = device + + self.env = env + self.params = params + self.vocab_size = config.vocab_size + self.n_layers = config.num_hidden_layers + + Embedding = get_quantized_embedding_layer(env.quant_config) + self.tok_embeddings = Embedding( + config.vocab_size, + config.hidden_size, + device=device, + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(config.num_hidden_layers): + self.layers.append(TransformerBlock(layer_id, params, env)) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=params.device) + + LinearLayer = get_quantized_linear_layer(env.quant_config) + linear_kwargs = {} + if LinearLayer != torch.nn.Linear: + linear_kwargs["quant_config"] = env.quant_config + + self.output = LinearLayer( + config.hidden_size, + config.vocab_size, + bias=False, + device=device, + **linear_kwargs, + ) + freqs_cis = precompute_freqs_cis( + config.hidden_size // config.num_attention_heads, + self.params.max_seq_len * 2, + theta=config.rope_theta, + ) + + self.register_buffer("freqs_cis", freqs_cis) + + self.hf_name("output", "lm_head") + self.hf_name("norm", "model.norm") + self.hf_name("layers", "model.layers") + self.hf_name("tok_embeddings", "model.embed_tokens") + + self.annotate_sharding("tok_embeddings.weight", 1) + self.annotate_sharding("output.weight", 0) + + + @torch.no_grad() + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + caches: List[Any], + mask, + start=None, + ragged_batch_index=None, + ragged_block_index=None, + ): + """ + tokens: the input token for decoding + input_pos: the decoding position relative to the start, which is the length of the decoding results + caches: kv caches + mask: causal mask to filter the attention results + start: the starting position for each slot + ragged_batch_index: precomputed batch index for ragged attention + ragged_block_index: precomputed block index for ragged attention + """ + with jax.named_scope("transformer_tok"): + seqlen = tokens.shape[-1] + h = self.tok_embeddings(tokens) + + with jax.named_scope("transformer_freq"): + bsz, seqlen = tokens.shape + freqs_cis = self.freqs_cis[input_pos] + freqs_cis = freqs_cis.reshape(bsz, seqlen, -1) + + end = None if start is None else (start + input_pos) % self.env.cache_len + # For stacked case, cannot get cache inside the loop which will cause cache copy + for layer_id, layer in enumerate(self.layers): + if caches[0].stacked: + cache = caches[0] + else: + cache = caches[layer_id] + # else: # For stacked case, there is only 1 yer of kv cache + + with jax.named_scope("TransformerBlock_Layer_" + str(layer_id)): + h = layer( + h, + freqs_cis, + mask, + cache, + start, + end, + ragged_batch_index, + ragged_block_index, + ) + + with jax.named_scope("transformer_norm"): + h = self.norm(h) + output = self.output(h).float() + return output + + + def convert_hf_weights(self, hf_weights): + + def transform(val, n_heads): + dim1, dim2 = val.shape + return ( + val.reshape(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, 
dim2) + ) + + updated = copy.copy(hf_weights) + + for key, value in hf_weights.items(): + if "bias" in key and ("q_proj" in key or "k_proj" in key): + continue + if "q_proj" in key: + updated[key] = transform(value, self.params.n_heads) + if "k_proj" in key: + updated[key] = transform( + value, self.params.n_kv_heads or self.params.n_heads + ) + res = super().convert_hf_weights(updated) + res["freqs_cis"] = self.freqs_cis + return res + + @classmethod + def from_config(cls, config, env): + device = "meta" + model = cls(config, device, env) + return model diff --git a/text-generation-inference/tests/test_decode_jetstream.py b/text-generation-inference/tests/test_decode_jetstream.py index 9bf72947..79e4ba20 100644 --- a/text-generation-inference/tests/test_decode_jetstream.py +++ b/text-generation-inference/tests/test_decode_jetstream.py @@ -70,9 +70,15 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample): sequence_length=256, expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city", max_new_tokens=20, + ), + DecodeTestParams( + model_id="Qwen/Qwen2.5-0.5B", + sequence_length=256, + expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city", + max_new_tokens=20, ) ], - ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny", "Trendyol-LLM-7b-base-v0.1", "Llama-3.2-1B"], + ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny", "Trendyol-LLM-7b-base-v0.1", "Llama-3.2-1B", "Qwen2.5-0.5B"], ) def test_decode_single_jetstream_pytorch(params, do_sample): params.do_sample = do_sample diff --git a/text-generation-inference/tests/test_qwen2.py b/text-generation-inference/tests/test_qwen2.py new file mode 100644 index 00000000..c7495bf4 --- /dev/null +++ b/text-generation-inference/tests/test_qwen2.py @@ -0,0 +1,66 @@ + +import pytest +from decode_tests_utils import * + + +# All tests in this file are for jetstream +pytestmark = pytest.mark.jetstream + +@pytest.mark.filterwarnings("ignore:.*:UserWarning") +def test_decode_single_jetstream_pytorch(): + params = DecodeTestParams( + model_id="Qwen/Qwen2.5-0.5B", + # model_id="Maykeye/TinyLLama-v0", + # model_id="tengomucho/tiny_qwen2.5", + sequence_length=256, + expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city", + max_new_tokens=20, + ) + + + model_path = prepare_model(params.model_id, params.sequence_length) + # model_path = paqrams.model_id + + # input_text = "It was a bright cold day in April, and the clocks were striking thirteen." 
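+    # Note: params.expected_text above matches the Qwen2.5-0.5B expectation added to test_decode_jetstream.py,
+    # which appears to pair with the commented-out prompt rather than the shorter prompt below.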
+    input_text = "Winston Smith, his chin nuzzled into his"
+    max_new_tokens = params.max_new_tokens
+
+    generator = AutoGenerator.from_pretrained(
+        model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length
+    )
+    request = create_request(
+        id=0,
+        inputs=input_text,
+        max_new_tokens=max_new_tokens,
+        do_sample=params.do_sample,
+        top_k=params.top_k,
+        seed=1234,
+        repetition_penalty=params.repetition_penalty,
+    )
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=params.sequence_length)
+    generations, next_batch = generator.prefill(batch)
+    print(f"generations prefill: {generations}")
+    # We already generated one token: call decode max_new_tokens - 1 times
+    for _ in tqdm(range(max_new_tokens - 1)):
+        assert next_batch.size == 1
+        assert next_batch.max_tokens == params.sequence_length
+        assert len(generations) == 1
+        assert len(generations[0].tokens.ids) == 1
+        generations, next_batch = generator.decode([next_batch])
+
+    assert next_batch is None
+    assert len(generations) == 1
+    output = generations[0].generated_text
+    assert output.generated_tokens == max_new_tokens
+    assert output.finish_reason == 0
+    print(f"Generated text: {output.text}")
+    if params.do_sample:
+        if output.text == params.expected_text:
+            print("❌: Expected text is equal to generated text")
+            return
+    else:
+        if output.text != params.expected_text:
+            print("❌: Expected text is not equal to generated text")
+            return
+    print("✅: Test passed")
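
A minimal usage sketch of the new Qwen2 path (illustrative only, not part of the patch): it assumes the server package is importable as text_generation_server and uses only names that appear in the diff above.

    from transformers import AutoConfig

    from text_generation_server.jetstream_pt_support.compatibility import model_can_use_jetstream_pt
    from text_generation_server.jetstream_pt_support.engine_loader import load_model_info

    model_path = "Qwen/Qwen2.5-0.5B"  # same model the new tests exercise

    # model_can_use_jetstream_pt() now returns True for model_type == "qwen2" when jetstream_pt is available
    if model_can_use_jetstream_pt(model_path):
        config = AutoConfig.from_pretrained(model_path)
        # load_model_info() resolves model_class to Qwen2Model for qwen2 configs
        model_info = load_model_info(config)
        # Qwen2Model.from_config(config, env) then builds the model on the "meta" device, and
        # convert_hf_weights() re-orders the q/k projection weights for the rotary layout used here
        # while leaving the q/k biases untransformed.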