test(tgi): added test to validate Llama3 8b on TGI
tengomucho committed Apr 30, 2024
1 parent 584ffda commit 0167fc8
Showing 1 changed file with 81 additions and 0 deletions.
text-generation-inference/tests/test_llama3_8b.py
@@ -0,0 +1,81 @@
import os

import pytest
from text_generation_server.generator import TpuGenerator
from text_generation_server.pb.generate_pb2 import (
    Batch,
    NextTokenChooserParameters,
    Request,
    StoppingCriteriaParameters,
)
from tqdm import tqdm

from optimum.tpu.model import fetch_model


MODEL_ID = "meta-llama/Meta-Llama-3-8B"
SEQUENCE_LENGTH = 256


@pytest.fixture(scope="module")
def model_path():
    # Add the variable to the environment so it can be used in AutoModelForCausalLM
    os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
    path = fetch_model(MODEL_ID)
    return path


def create_request(
    id: int,
    inputs: str,
    max_new_tokens: int = 20,
    do_sample: bool = False,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 1.0,
    seed: int = 0,
    repetition_penalty: float = 1.0,
):
    # For these tests we can safely set typical_p to 1.0 (default)
    typical_p = 1.0
    if not do_sample:
        # Drop the top_p parameter to avoid warnings
        top_p = 1.0
    parameters = NextTokenChooserParameters(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        seed=seed,
        repetition_penalty=repetition_penalty,
        typical_p=typical_p,
    )
    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
    return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)


@pytest.mark.slow
def test_decode_single(model_path):
    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
    max_new_tokens = 20
    generated_text = " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,"

    generator = TpuGenerator.from_pretrained(
        model_path, revision="", max_batch_size=1, max_sequence_length=SEQUENCE_LENGTH
    )
    # Greedy decoding keeps the output deterministic, so it can be compared to the reference text
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
    generations, next_batch = generator.prefill(batch)
    # Prefill already generated one token: call decode max_new_tokens - 1 times
    for _ in tqdm(range(max_new_tokens - 1)):
        assert next_batch.size == 1
        assert next_batch.max_tokens == SEQUENCE_LENGTH
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    assert output.finish_reason == 0
    assert output.text == generated_text
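
As an aside (not part of the commit): the create_request helper above also supports sampling. A minimal illustrative call, with arbitrary parameter values, might look like this:

    # Illustrative only: a sampled request with a fixed seed for reproducibility
    request = create_request(
        id=1,
        inputs="Once upon a time",
        max_new_tokens=32,
        do_sample=True,
        top_k=40,
        temperature=0.8,
        seed=42,
    )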

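To run this test, something along these lines should work, assuming the slow marker is registered in the repository's pytest configuration and that HF_TOKEN grants access to the gated meta-llama/Meta-Llama-3-8B weights:

    HF_TOKEN=<your token> python -m pytest -sv text-generation-inference/tests/test_llama3_8b.py -m slow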