add benchmark
Zheng Niu committed Jul 15, 2024
1 parent 0f79d54 commit 90802c5
Showing 2 changed files with 243 additions and 6 deletions.
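For reference, a minimal invocation of the new streaming path might look like the following. The flag names are inferred from the argparse definitions in this diff and from the existing benchmark_throughput.py conventions, so treat them as an assumption rather than documented usage:

    python benchmarks/benchmark_tts.py --backend vllm --streaming \
        --model <model_path> --tokenizer <tokenizer_path> \
        --dataset prompts.txt --num-prompts 100

Without --streaming, the script falls back to the pre-existing offline run_vllm/run_hf paths.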
3 changes: 2 additions & 1 deletion benchmarks/backend_request_func.py
@@ -28,7 +28,8 @@ class RequestFuncInput:

@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    # Mutable defaults are not allowed on dataclass fields, so build the
    # token list with a default_factory rather than a bare [].
    output_tokens: Union[List[int], List[List[int]]] = field(default_factory=list)
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
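A note on the new field above: in a dataclass, a bare mutable default such as [] raises a ValueError at class-definition time, and a trailing comma would turn the default into a tuple, which is why the corrected hunk uses a default_factory. A minimal standalone sketch of the pattern (illustrative, not taken from this repository):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Output:
        output_tokens: List[int] = field(default_factory=list)

    a, b = Output(), Output()
    a.output_tokens.append(1)
    assert b.output_tokens == []  # each instance gets its own list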
246 changes: 241 additions & 5 deletions benchmarks/benchmark_tts.py
@@ -1,19 +1,243 @@
"""Benchmark offline inference throughput."""
import argparse
import asyncio
import json
import random
import time
import uuid
from typing import AsyncGenerator, List, Optional, Tuple
import warnings

import numpy as np
import torch
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from benchmarks.backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmarks.benchmark_serving import BenchmarkMetrics
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser

def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]:
    actual_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            # Count generated tokens from the token IDs returned by the
            # engine rather than from len(outputs[i].itl), since multiple
            # output tokens may be bundled together in a single iteration.
            output_len = len(outputs[i].output_tokens)
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append(
                    (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)
            completed += 1
        else:
            actual_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(actual_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(actual_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) *
        1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
    )

    return metrics, actual_output_lens
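# Illustrative example of the TPOT formula above (numbers are made up): a
# request with latency 2.10 s, TTFT 0.30 s and 601 output tokens yields
# TPOT = (2.10 - 0.30) / (601 - 1) = 0.003 s = 3.0 ms per output token,
# while ITL is measured directly from the gaps between streamed chunks.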

async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)
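# Note on get_request above: drawing inter-arrival times from an exponential
# distribution with mean 1.0 / request_rate makes request arrivals a Poisson
# process; e.g. request_rate=16 issues on average 16 requests per second.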

async def generate_streaming(llm: AsyncLLMEngine,
                             request_func_input: RequestFuncInput,
                             pbar: Optional[tqdm] = None) -> RequestFuncOutput:
    """Submit one request to the async engine and record streaming timings:
    TTFT for the first streamed chunk, ITL between subsequent chunks, and
    end-to-end latency."""
    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len
    ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    sampling_params = SamplingParams(n=1,
                                     temperature=1,
                                     detokenize=False,
                                     stop_token_ids=[21803],
                                     max_tokens=2048,
                                     top_k=1)
    # Give each request a unique request_id.
    results_generator = llm.generate(request_func_input.prompt,
                                     sampling_params,
                                     request_id=str(uuid.uuid4()))
    async for request_output in results_generator:
        token_ids = request_output.outputs[0].token_ids
        # print(f'{[x - 21178 for x in token_ids]}')
        timestamp = time.perf_counter()
        # First token
        if ttft == 0.0:
            ttft = time.perf_counter() - st
            output.ttft = ttft
        # Decoding phase
        else:
            output.itl.append(timestamp - most_recent_timestamp)

        most_recent_timestamp = timestamp

    output.latency = most_recent_timestamp - st
    output.success = True
    output.output_tokens = token_ids

    if pbar:
        pbar.update(1)
    return output

async def run_vllm_async(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    request_rate: float = 16,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
    load_format: str = EngineArgs.load_format,
) -> None:

    engine_args = AsyncEngineArgs(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
        load_format=load_format,
    )
    llm = AsyncLLMEngine.from_engine_args(engine_args)
    pbar = tqdm(total=len(requests))
    # One task per in-flight request.
    tasks: List[asyncio.Task] = []
    benchmark_start_time = time.perf_counter()

    async for request in get_request(requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model,
            prompt=prompt,
            prompt_len=prompt_len,
            output_len=output_len,
            use_beam_search=use_beam_search,
        )
        tasks.append(
            asyncio.create_task(
                generate_streaming(llm,
                                   request_func_input=request_func_input,
                                   pbar=pbar)))

    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, actual_output_lens = calculate_metrics(
        input_requests=requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )

print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:",
metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput))
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
metrics.input_throughput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
metrics.output_throughput))
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
n=50,
c='-'))
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("=" * 50)

def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
@@ -144,7 +368,6 @@ def run_hf(
    end = time.perf_counter()
    return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
@@ -153,10 +376,21 @@ def main(args: argparse.Namespace):
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = open(args.dataset).read().splitlines()
    # Wrap each dataset line in TTS control tokens; each entry is a
    # (prompt, prompt_len, output_len) tuple.
    requests = [(f'[Stts][spk_emb][speed_5]{request}[Ptts]',
                 len(tokenizer(request).input_ids), 2048)
                for request in requests]
    requests = requests[:args.num_prompts]

    total_input_tokens = sum(count for _, count, _ in requests)

    if args.streaming:
        asyncio.run(
            run_vllm_async(
                requests, args.model, args.tokenizer, args.quantization,
                args.tensor_parallel_size, args.seed, args.n,
                args.use_beam_search, args.trust_remote_code, args.dtype,
                args.max_model_len, args.enforce_eager, args.kv_cache_dtype,
                args.quantization_param_path, args.device,
                args.enable_prefix_caching, args.enable_chunked_prefill,
                args.max_num_batched_tokens,
                args.distributed_executor_backend,
                # Pass the trailing options by keyword so they do not shift
                # into the request_rate parameter.
                gpu_memory_utilization=args.gpu_memory_utilization,
                download_dir=args.download_dir,
                load_format=args.load_format))
        return

if args.backend == "vllm":
elapsed_time, total_num_tokens = run_vllm(
@@ -176,7 +410,7 @@ def main(args: argparse.Namespace):
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    print(f"Total input {total_input_tokens}, total output {total_num_tokens}")
    print(f"Elapsed time: {elapsed_time:.2f}s")
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
@@ -200,6 +434,7 @@ def main(args: argparse.Namespace):
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
@@ -369,4 +604,5 @@ def main(args: argparse.Namespace):
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")

    main(args)
