Correct extra token, start preparing docker image for TGI/Jetstream Pt (#93)

* fix(Jetstream Pt): remove extra token in decode

  Jetstream's `generate` function returns the input token as the result token. The next token is instead available in the decode state, so this change reads it from there (see the first sketch below).

* fix(engine): set batch_size and sequence_length
* fix(Jetstream PT): correct warmup internal params
* test(tgi): added a warmup test
* chore(jetstream pt): check input type in decode
* fix(token selector): seed can be a very large number

  Before, we could get an error if the seed was bigger than a 64-bit number (see the second sketch below).

* fix(Jetstream PT): handle slot's seed in a clean way
* feat(docker): TGI image now includes Jetstream Pytorch dependencies

  This allows testing TGI images with Jetstream Pytorch.

* fix(Jetstream Pt): batch returned in prefill initialized to None

  This is required when no more tokens are generated after prefill.

* feat(Jetstream Pt): speed-up prefill by avoiding redundant compilation

  A new slot is created for each prefill request, and its selector is passed as an argument to a jitted function. The problem is that each new slot has a new signature, even if the contents are the same. The solution is to wrap it in a singleton slot object for the prefill, so the compiler always sees the same object and stops recompiling (see the third sketch below).

* chore(generator): use prefill bucket sizes defined in Jetstream
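A minimal, self-contained sketch of the decode fix above. The names (`DecodeState`, `generate`, `tokens`) are illustrative assumptions, not Jetstream's actual API; the point is only that the returned token echoes the input while the new token must be read from the updated state.

```python
from typing import NamedTuple

class DecodeState(NamedTuple):
    tokens: list  # per-slot tokens produced by the last generate step

def generate(state: DecodeState, input_token: int):
    # Stand-in for real sampling: the model picks a next token...
    next_token = input_token + 1
    new_state = DecodeState(tokens=[next_token])
    # ...but, like Jetstream's `generate`, the returned "result" token
    # echoes the input rather than the newly generated token.
    return new_state, input_token

state, result = generate(DecodeState(tokens=[0]), input_token=41)
assert result == 41           # the echoed input token
assert state.tokens[0] == 42  # the actual next token, read from the state
```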
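For the seed fix, a plausible approach is to fold an arbitrarily large request seed into a range the PRNG accepts before building the key. This is a hedged sketch assuming a JAX-based selector; the exact reduction the commit uses may differ.

```python
import jax

def make_sampling_key(seed: int) -> jax.Array:
    # A request seed may be an arbitrarily large Python int, while
    # jax.random.PRNGKey expects a value that fits in 64 bits. Folding it
    # into the positive int64 range avoids the overflow error. The exact
    # reduction used by the commit is an assumption here.
    return jax.random.PRNGKey(seed % (2**63))
```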
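The recompilation fix can be sketched as follows. `Slot`, `PrefillSlot`, and `prefill_select` are hypothetical names standing in for the generator's real objects; what matters is that the jitted function's static argument is always the same long-lived wrapper, so its compilation cache key never changes.

```python
import jax
import jax.numpy as jnp
from functools import partial

class Slot:
    """Stand-in per-request slot; only its select() matters here."""
    def select(self, logits):
        return jnp.argmax(logits, axis=-1)

class PrefillSlot:
    """One long-lived wrapper whose identity stays stable across requests."""
    def __init__(self):
        self._slot = None
    def set(self, slot):
        self._slot = slot
    def select(self, logits):
        return self._slot.select(logits)

_prefill_slot = PrefillSlot()

@partial(jax.jit, static_argnums=(0,))
def prefill_select(selector, logits):
    # Static args are compilation-cache keys: a fresh object per request
    # forces a retrace, while the shared _prefill_slot keeps the key (and
    # the trace) stable. Reusing the trace is sound only because every
    # slot selects tokens the same way.
    return selector.select(logits)

logits = jnp.zeros((1, 8))
_prefill_slot.set(Slot())
prefill_select(_prefill_slot, logits)  # traced and compiled once
_prefill_slot.set(Slot())              # new request, same wrapper
prefill_select(_prefill_slot, logits)  # cache hit, no recompilation
```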
1 parent 03b6573 · commit 4265e13 · 6 changed files with 83 additions and 32 deletions.
The diff below shows the new warmup test added by this commit (one of the 6 changed files, 30 added lines):

```python
import pytest
from helpers import create_request, prepare_model
from text_generation_server.auto_generator import AutoGenerator
from text_generation_server.pb.generate_pb2 import Batch

from optimum.tpu.jetstream_pt_support import jetstream_pt_available


def test_warmup_jetstream_pytorch():
    if not jetstream_pt_available():
        pytest.skip("Jetstream PyTorch is not available")
    model_id = "Maykeye/TinyLLama-v0"

    # The maximum sequence length of the model is set to 1000, but warmup will
    # round that up to the next power of two in prefill (1024).
    sequence_length = 1000

    model_path = prepare_model(model_id, sequence_length)
    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
    max_new_tokens = 20

    generator = AutoGenerator.from_pretrained(
        model_path, revision="", max_batch_size=1, max_sequence_length=sequence_length
    )
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
    batch = Batch(id=0, requests=[request], size=1, max_tokens=sequence_length)
    generator.warmup(batch)
```
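The rounding mentioned in the test's comment can be illustrated with a minimal helper (purely illustrative; the actual bucketing is performed inside the generator and Jetstream, whose prefill bucket sizes this commit adopts):

```python
def next_power_of_two(n: int) -> int:
    # Round a sequence length up to the next power of two, e.g. 1000 -> 1024.
    return 1 << (n - 1).bit_length()

assert next_power_of_two(1000) == 1024
```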