Skip to content

Commit

Permalink
fix(Jetstream Pt): correct generation
Browse files Browse the repository at this point in the history
Text generation was not correct because the model weights were not
loaded correctly. This was not easy to spot by just looking at a few
generated tokens, and it had actually already been fixed in the
Jetstream/Pytorch code, but the fix hadn't been ported to
optimum-tpu.

This fix implements the necessary weight changes, aligning with Jetstream
Pytorch, and the tests' expected output has been modified accordingly.
  • Loading branch information
tengomucho committed Sep 18, 2024
1 parent c897a7f commit d7b8978
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def instantiate_model_from_repo_id(
env.device = "meta"
model = create_model(model_dir, env)
weights = fetch_models._load_weights(model_dir)
updated_keys = model.get_hf_names_to_real_name()
for name, updated in updated_keys.items():
if name in weights:
val = weights.pop(name)
weights[updated] = val
weights = model.convert_hf_weights(weights)

model.load_state_dict(weights, assign=True, strict=False)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import Any, List, Optional

import jax
Expand Down Expand Up @@ -295,3 +296,26 @@ def shard_weights(self, _weights_dict):
Assumes the weights_dict is a list of XLATensor2
"""

def convert_hf_weights(self, hf_weights):
    """Return a converted copy of ``hf_weights`` with q/k projections re-ordered.

    Every entry whose key contains ``"q_proj"`` or ``"k_proj"`` has its rows
    permuted head by head (see ``_interleave_halves``); all other entries are
    carried over unchanged. The result is then passed to the parent class'
    ``convert_hf_weights`` and ``self.freqs_cis`` is stored under the
    ``"freqs_cis"`` key of whatever that call returns.

    This method itself only writes into a shallow copy of ``hf_weights``;
    the mapping passed in by the caller is not reassigned here.
    """

    def _interleave_halves(weight, head_count):
        # Rows are grouped per head; inside each head the first and second
        # half of the rows are interleaved (first-half row i is followed by
        # second-half row i), then the original 2-D shape is restored.
        # NOTE(review): presumably this maps the HF rotary layout onto the
        # one Jetstream Pytorch expects -- confirm against upstream.
        # NOTE(review): `.transpose(1, 2)` on a 4-D value suggests torch
        # tensors -- confirm against the weight loader.
        rows, cols = weight.shape
        rows_per_half = rows // head_count // 2
        grouped = weight.reshape(head_count, 2, rows_per_half, cols)
        return grouped.transpose(1, 2).reshape(rows, cols)

    converted = copy.copy(hf_weights)

    for name, tensor in hf_weights.items():
        if "q_proj" in name:
            converted[name] = _interleave_halves(
                tensor, self.config.num_attention_heads
            )
        if "k_proj" in name:
            # Fall back to the attention head count when no dedicated
            # key/value head count is configured (falsy value).
            kv_heads = (
                self.config.num_key_value_heads or self.config.num_attention_heads
            )
            converted[name] = _interleave_halves(tensor, kv_heads)

    result = super().convert_hf_weights(converted)
    result["freqs_cis"] = self.freqs_cis
    return result
10 changes: 6 additions & 4 deletions text-generation-inference/tests/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class DecodeTestParams:
sequence_length: int
expected_text: str
do_sample: bool = False
max_new_tokens: int = 20


@pytest.mark.parametrize("params",
Expand Down Expand Up @@ -64,7 +65,7 @@ def test_decode_single_slow(params):
def _test_decode_single(params):
model_path = prepare_model(params.model_id, params.sequence_length)
input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
max_new_tokens = 20
max_new_tokens = params.max_new_tokens

generator = AutoGenerator.from_pretrained(
model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length
Expand Down Expand Up @@ -100,12 +101,12 @@ def _test_decode_single(params):
DecodeTestParams(
model_id="meta-llama/Llama-2-7b-hf",
sequence_length=256,
expected_text="\nThe clocks were striking thirteen\nThe clocks were striking thirteen\nThe",
expected_text="\nWinston Smith, his chin nuzzled into his breast in an effort to escape",
),
DecodeTestParams(
model_id="meta-llama/Meta-Llama-3-8B",
sequence_length=256,
expected_text=" Winston Smith, his chin on his hands, and the clock in the Ministry of Truth, Minit",
expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,",
),
],
ids=["Llama-2-7b-hf", "Meta-Llama-3-8B"],
Expand All @@ -123,7 +124,8 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
DecodeTestParams(
model_id="Maykeye/TinyLLama-v0",
sequence_length=256,
expected_text=" She had a big and it had a big, blue, and a big, red and a big",
expected_text=" The sun was shining and the sky was shining.\nSuddenly, a big wind came and blew the wind away.",
max_new_tokens=25,
),
],
ids=["TinyLLama-v0"],
Expand Down

0 comments on commit d7b8978

Please sign in to comment.