WIP debug
tengomucho committed Jan 17, 2025
1 parent e0ce48e commit 192fef3
Showing 4 changed files with 77 additions and 3 deletions.
@@ -42,9 +42,9 @@ def __init__(
super().__init__(pt_model, env, weights)
# Call model prefill and generate needs to be JIT'ed, because it is called with sharded notations, and it would
# otherwise not work for some models.
self._call_model_prefill = jax.jit(
self._call_model_prefill,
)
# self._call_model_prefill = jax.jit(
# self._call_model_prefill,
# )
self._call_model_generate = jax.jit(
self._call_model_generate,
)
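
For context, a minimal hypothetical sketch of the jit-wrapping pattern the comment above refers to; the class and method bodies here are stand-ins, not the repository's actual engine:

import jax
import jax.numpy as jnp


class TinyEngine:
    """Hypothetical stand-in: only shows how a bound method can be re-bound to its jitted version."""

    def __init__(self):
        # Mirror the pattern kept for _call_model_generate above (and commented out for prefill):
        # wrapping the bound method and re-assigning it makes later calls go through the compiled path.
        self._call_model_generate = jax.jit(self._call_model_generate)

    def _call_model_generate(self, logits: jnp.ndarray) -> jnp.ndarray:
        # Toy body standing in for the real sharded model call.
        return jnp.argmax(logits, axis=-1)


engine = TinyEngine()
print(engine._call_model_generate(jnp.ones((1, 4))))  # executes the jit-compiled version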
@@ -242,6 +242,7 @@ def set(self, slot: Slot):
self._curslot = slot

def select(self, logits: jnp.ndarray) -> int:
jax.debug.print("logits shape: {} logits {}", logits.shape, logits)
return self._curslot.select(logits)

class TpuGeneratorJetStream(Generator):
@@ -411,6 +412,7 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
max_prefill_length=max_prefill_length,
jax_padding=True,
)
print(f"tokens: {tokens}, true_length: {true_length}")
return tokens, true_length

def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
@@ -426,6 +426,13 @@ def forward(
ragged_batch_index: precomputed batch index for ragged attention
ragged_block_index: precomputed block index for ragged attention
"""
# jax.debug.print("tokens", tokens)
# jax.debug.print("input_pos", input_pos)
# jax.debug.print("mask", mask)
# jax.debug.print("start", start)
# jax.debug.print("ragged_batch_index", ragged_batch_index)
# jax.debug.print("ragged_block_index", ragged_block_index)

with jax.named_scope("transformer_tok"):
seqlen = tokens.shape[-1]
h = self.tok_embeddings(tokens)
@@ -435,6 +442,8 @@
freqs_cis = self.freqs_cis[input_pos]
freqs_cis = freqs_cis.reshape(bsz, seqlen, -1)

print(f"h tok_embeddings: {h}")

end = None if start is None else (start + input_pos) % self.env.cache_len
# For stacked case, cannot get cache inside the loop which will cause cache copy
for layer_id, layer in enumerate(self.layers):
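
As background for the debugging statements added in this commit, a small self-contained sketch (not the repository's code) of how plain print and jax.debug.print behave differently under jax.jit:

import jax
import jax.numpy as jnp


@jax.jit
def select_next(logits: jnp.ndarray) -> jnp.ndarray:
    # Plain print runs once, at trace time, and shows an abstract tracer instead of values.
    print(f"traced logits: {logits}")
    # jax.debug.print is staged into the computation and prints runtime values on every call.
    jax.debug.print("runtime logits: {}", logits)
    return jnp.argmax(logits, axis=-1)


select_next(jnp.array([[0.1, 0.7, 0.2]]))

This is presumably also why the prefill jit is commented out in the first file of this diff: with that wrapper removed, the plain print statements added above can show concrete values instead of tracers.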
63 changes: 63 additions & 0 deletions text-generation-inference/tests/test_qwen2.py
@@ -0,0 +1,63 @@

import pytest
from decode_tests_utils import *


# All tests in this file are for jetstream
pytestmark = pytest.mark.jetstream

@pytest.mark.filterwarnings("ignore:.*:UserWarning")
def test_decode_single_jetstream_pytorch():
params = DecodeTestParams(
model_id="Qwen/Qwen2.5-0.5B",
sequence_length=256,
expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city",
max_new_tokens=1,
)


model_path = prepare_model(params.model_id, params.sequence_length)
input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
max_new_tokens = params.max_new_tokens

generator = AutoGenerator.from_pretrained(
model_path, revision="", max_batch_size=1, max_sequence_length=params.sequence_length
)
request = create_request(
id=0,
inputs=input_text,
max_new_tokens=max_new_tokens,
do_sample=params.do_sample,
top_k=params.top_k,
seed=1234,
repetition_penalty=params.repetition_penalty,
)
batch = Batch(id=0, requests=[request], size=1, max_tokens=params.sequence_length)
generations, next_batch = generator.prefill(batch)
print(f"generations prefill: {generations}")
# We already generated one token: call decode max_new_tokens - 1 times
for _ in tqdm(range(max_new_tokens - 1)):
assert next_batch.size == 1
assert next_batch.max_tokens == params.sequence_length
assert len(generations) == 1
assert len(generations[0].tokens.ids) == 1
generations, next_batch = generator.decode([next_batch])
# Destroy generator: this will properly stop threads and prevent them from getting stuck if one of the following
# assertions fails.
del generator
assert next_batch is None
assert len(generations) == 1
output = generations[0].generated_text
assert output.generated_tokens == max_new_tokens
assert output.finish_reason == 0
# print(f"generations: {generations}")
print(f"Generated text: {output.text}")
if params.do_sample:
if output.text == params.expected_text:
print("❌: Expected text is equal to generated text")
return
else:
if output.text != params.expected_text:
print("❌: Expected text is not equal to generated text")
return
print("✅: Test passed")
