-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support of optimum-nvidia's trt-llm (#98)
- Loading branch information
1 parent
487583f
commit 9b3fc4d
Showing 13 changed files with 350 additions and 42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Runs the TensorRT-LLM test subset inside the optimum-nvidia Docker image
# on a GPU runner.
name: TensorRT-LLM Tests

on:
  workflow_dispatch:
  push:
    branches: [main]
  pull_request:
    types: [opened, reopened, synchronize]

# One run per pull request (or per ref for pushes); a newer run cancels the
# one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  pull_image_and_run_gpu_tests:
    runs-on: hf-dgx-01
    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Pull image
        run: docker pull huggingface/optimum-nvidia:latest

      # The checkout is bind-mounted into the container, so files created by
      # the test run land in the runner's workspace. USER_ID / GROUP_ID carry
      # the runner's uid/gid into the container so the final `chown` can hand
      # ownership of those files back.
      #
      # The dollars in `\$USER_ID:\$GROUP_ID` are escaped on purpose: the
      # command string is double-quoted, so without the backslashes the
      # runner's shell would expand the variables before the container starts.
      # They are only defined inside the container (via --env), so the chown
      # would be handed empty values instead of the runner's uid/gid.
      - name: Run tests
        run: docker run
          --rm
          --net host
          --pid host
          --shm-size 64G
          --env USE_CUDA="1"
          --env USER_ID=$(id -u)
          --env GROUP_ID=$(id -g)
          --volume $(pwd):/workspace/optimum-benchmark
          --workdir /workspace/optimum-benchmark
          --gpus '"device=0,1"'
          --entrypoint /bin/bash
          huggingface/optimum-nvidia:latest
          -c "pip install -e .[test] && pytest -k 'tensorrt_llm' -x && chown -R \$USER_ID:\$GROUP_ID ."
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,4 +168,5 @@ data/ | |
version.txt | ||
|
||
actions-runner/ | ||
experiments/ | ||
experiments/ | ||
.engine/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
defaults: | ||
- backend: tensorrt # default backend | ||
- launcher: process # default launcher | ||
- benchmark: inference # default benchmark | ||
- experiment # inheriting experiment schema | ||
- _self_ # for hydra 1.1 compatibility | ||
- override hydra/job_logging: colorlog # colorful logging | ||
- override hydra/hydra_logging: colorlog # colorful logging | ||
|
||
experiment_name: trt_llama | ||
model: NousResearch/Llama-2-7b-hf | ||
device: cuda | ||
|
||
backend: | ||
continuous_isolation: false | ||
|
||
benchmark: | ||
input_shapes: | ||
batch_size: 1 | ||
sequence_length: 64 | ||
new_tokens: 128 | ||
|
||
hydra: | ||
run: | ||
dir: runs/${experiment_name} | ||
sweep: | ||
dir: sweeps/${experiment_name} | ||
job: | ||
chdir: true | ||
env_set: | ||
OVERRIDE_BENCHMARKS: 1 | ||
CUDA_VISIBLE_DEVICES: 0 | ||
CUDA_DEVICE_ORDER: PCI_BUS_ID |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from logging import getLogger | ||
from typing import Any, Dict | ||
|
||
from hydra.utils import get_class | ||
from transformers.utils import ModelOutput | ||
|
||
from ..base import Backend | ||
from .config import TRTConfig | ||
from .utils import MODEL_TYPE_TO_TRTMODEL | ||
|
||
LOGGER = getLogger("tensorrt") | ||
|
||
|
||
class TRTBackend(Backend):
    """Backend that loads and runs a model through a TRT model class.

    The concrete model class is looked up in ``MODEL_TYPE_TO_TRTMODEL`` from
    ``self.model_type`` and imported with hydra's ``get_class``. Only the
    ``cuda`` device and the model types listed in that mapping are accepted.
    """

    NAME: str = "tensorrt"

    def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]) -> None:
        """Initialise the base backend, then reject unsupported setups.

        Raises:
            NotImplementedError: if the device is not ``"cuda"`` or the model
                type has no entry in ``MODEL_TYPE_TO_TRTMODEL``.
        """
        super().__init__(model, task, device, hub_kwargs)
        # Device is checked first, then the model type.
        self.validate_device()
        self.validate_model_type()

    def validate_model_type(self) -> None:
        """Raise ``NotImplementedError`` unless the model type is supported."""
        if self.model_type in MODEL_TYPE_TO_TRTMODEL:
            return
        raise NotImplementedError(f"TRTBackend does not support model_type {self.model_type}")

    def validate_device(self) -> None:
        """Raise ``NotImplementedError`` unless the device is ``"cuda"``."""
        if self.device == "cuda":
            return
        raise NotImplementedError(f"TRTBackend only supports device cuda, got {self.device}")

    def configure(self, config: TRTConfig) -> None:
        """Apply ``config``, resolve the TRT model class and load the model."""
        super().configure(config)

        # The mapping stores a dotted import path; get_class turns it into
        # the class object.
        class_path = MODEL_TYPE_TO_TRTMODEL[self.model_type]
        self.trtmodel_class = get_class(class_path)
        trtmodel_name = self.trtmodel_class.__name__
        LOGGER.info(
            f"\t+ Inferred TRTModel class {trtmodel_name} for task {self.task} and model_type {self.model_type}"
        )

        # TODO: save engine path for reuse, then maybe re build with max_prompt_size
        self.load_trtmodel_from_pretrained()

    @property
    def trtmodel_kwargs(self) -> Dict[str, Any]:
        """Extra keyword arguments for ``from_pretrained``; currently none."""
        return dict()

    def load_trtmodel_from_pretrained(self) -> None:
        """Instantiate the TRT model and store it on ``self.pretrained_model``."""
        load = self.trtmodel_class.from_pretrained
        # Both mappings are unpacked separately, exactly as keyword arguments.
        self.pretrained_model = load(self.model, **self.trtmodel_kwargs, **self.hub_kwargs)

    def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:
        """Run a single-token ``generate`` call on the loaded model.

        ``kwargs`` is accepted but not used; only ``input_ids`` and
        ``attention_mask`` are read from ``input`` (missing keys become
        ``None``).
        """
        run_generate = self.pretrained_model.generate
        token_ids = input.get("input_ids")
        mask = input.get("attention_mask")
        return run_generate(input_ids=token_ids, attention_mask=mask, max_new_tokens=1)

    def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:
        """Call the loaded model's ``generate`` with explicitly named arguments.

        Every argument is passed by name ("spelling args to avoid conflict"),
        falling back to the defaults listed below when a key is absent.
        """
        run_generate = self.pretrained_model.generate
        generation_args = {
            # Read from the "inputs" key but passed as input_ids
            # (marked "diff api" by the author).
            "input_ids": input.get("inputs"),
            "attention_mask": input.get("attention_mask"),
            "max_new_tokens": kwargs.get("max_new_tokens", -1),
            # Read from "min_new_tokens" but passed as min_length
            # (marked "diff api" by the author).
            "min_length": kwargs.get("min_new_tokens", -1),
            "num_beams": kwargs.get("num_beams", 1),
            "temperature": kwargs.get("temperature", 1.0),
            "top_k": kwargs.get("top_k", 50),
            "top_p": kwargs.get("top_p", 1.0),
            "repetition_penalty": kwargs.get("repetition_penalty", 1.0),
            "length_penalty": kwargs.get("length_penalty", 1.0),
            "seed": kwargs.get("seed", 42),
            "pad_token_id": kwargs.get("pad_token_id", 0),
            "bos_token_id": kwargs.get("bos_token_id", 1),
            "eos_token_id": kwargs.get("eos_token_id", 2),
        }
        return run_generate(**generation_args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from dataclasses import dataclass | ||
from logging import getLogger | ||
|
||
from omegaconf import OmegaConf | ||
|
||
from ...import_utils import tesnorrt_version | ||
from ..config import BackendConfig | ||
|
||
LOGGER = getLogger("tensorrt") | ||
|
||
OmegaConf.register_new_resolver("tensorrt_version", tesnorrt_version) | ||
|
||
|
||
@dataclass
class TRTConfig(BackendConfig):
    """Configuration schema for the ``tensorrt`` backend."""

    # Backend identifier.
    name: str = "tensorrt"
    # OmegaConf interpolation; resolved through the "tensorrt_version"
    # resolver registered at import time in this module.
    version: str = "${tensorrt_version:}"
    # Dotted import path of the backend class this config belongs to
    # (presumably consumed by hydra instantiation -- confirm against caller).
    _target_: str = "optimum_benchmark.backends.tensorrt.backend.TRTBackend"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Maps a model type string to the dotted import path of the TRT model class
# that handles it. Only the model types listed here are supported.
MODEL_TYPE_TO_TRTMODEL = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.