
Commit

Merge pull request #1 from Bihan/add_llama_31_support
Add Llama 3.1 support
Bihan authored Aug 29, 2024
2 parents d4e2294 + 3aed4b8 commit 24b54e1
Showing 11 changed files with 49 additions and 37 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tpu-tgi-release.yml
@@ -74,7 +74,7 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
-TGI_VERSION=v2.0.3
+TGI_VERSION=v2.2.0
- name: Generate artifact attestation for TGI
@@ -95,7 +95,7 @@ jobs:
labels: ${{ steps.meta-ie.outputs.labels }}
build-args: |
VERSION=${{ steps.version.outputs.version }}
-TGI_VERSION=v2.0.3
+TGI_VERSION=v2.2.0
target: inference-endpoint


2 changes: 1 addition & 1 deletion Makefile
@@ -19,7 +19,7 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))

.PHONY: build_dist style style_check clean

-TGI_VERSION ?= v2.0.3
+TGI_VERSION ?= v2.2.0

rwildcard=$(wildcard $1) $(foreach d,$1,$(call rwildcard,$(addsuffix /$(notdir $d),$(wildcard $(dir $d)*))))

52 changes: 26 additions & 26 deletions examples/language-modeling/gemma_tuning.ipynb
@@ -35,12 +35,12 @@
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"gcloud compute tpus tpu-vm ssh $TPU_NAME \\\n",
" --zone=$ZONE \\\n",
" -- -L 8888:localhost:8888"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -59,7 +59,6 @@
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"git clone https://github.com/huggingface/optimum-tpu.git\n",
"# Install Optimum tpu\n",
@@ -73,7 +72,8 @@
"# Change directory and launch Jupyter notebook\n",
"cd optimum-tpu/examples/language-modeling\n",
"jupyter notebook --port 8888"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -94,10 +94,10 @@
"execution_count": null,
"id": "37bccce7-1ce4-4470-9e81-c15b120ef294",
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli login --token YOUR_HF_TOKEN"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -115,13 +115,13 @@
"execution_count": null,
"id": "6d3c7bc2",
"metadata": {},
"outputs": [],
"source": [
"from optimum.tpu import fsdp_v2\n",
"\n",
"\n",
"fsdp_v2.use_fsdp_v2()"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -140,13 +140,13 @@
"execution_count": null,
"id": "f0196b5d",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"\n",
"dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -161,10 +161,10 @@
"execution_count": null,
"id": "12409299",
"metadata": {},
"outputs": [],
"source": [
"dataset[321]"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -179,13 +179,13 @@
"execution_count": null,
"id": "9c24e0b1",
"metadata": {},
"outputs": [],
"source": [
"{'instruction': 'When was the 8088 processor released?',\n",
" 'context': 'The 8086 (also called iAPX 86) is a 16-bit microprocessor chip designed by Intel between early 1976 and June 8, 1978, when it was released. The Intel 8088, released July 1, 1979, is a slightly modified chip with an external 8-bit data bus (allowing the use of cheaper and fewer supporting ICs),[note 1] and is notable as the processor used in the original IBM PC design.',\n",
" 'response': 'The Intel 8088 processor was released July 1, 1979.',\n",
" 'category': 'information_extraction'}"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -200,7 +200,6 @@
"execution_count": null,
"id": "f1497e0f",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
@@ -218,7 +217,8 @@
" prompt += tokenizer.eos_token\n",
" sample[\"prompt\"] = prompt\n",
" return sample"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -233,10 +233,10 @@
"execution_count": null,
"id": "16b44a9b",
"metadata": {},
"outputs": [],
"source": [
"data = dataset.map(preprocess_function, remove_columns=list(dataset.features))"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -253,14 +253,14 @@
"execution_count": null,
"id": "f18472ce",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoModelForCausalLM\n",
"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -275,7 +275,6 @@
"execution_count": null,
"id": "4a01f651",
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig\n",
"\n",
@@ -286,7 +285,8 @@
" target_modules=[\"k_proj\", \"v_proj\"],\n",
" task_type=\"CAUSAL_LM\",\n",
")"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -301,7 +301,6 @@
"execution_count": null,
"id": "780f1033",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"from trl import SFTTrainer\n",
@@ -329,7 +328,8 @@
" max_seq_length=1024,\n",
" packing=True,\n",
")"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
@@ -344,10 +344,10 @@
"execution_count": null,
"id": "4c437a81",
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
-]
+],
+"outputs": []
},
{
"cell_type": "markdown",
1 change: 1 addition & 0 deletions optimum/tpu/generation/token_selector.py
@@ -120,6 +120,7 @@ def create(
)
generation_config.max_length = max_seq_length

+generation_config._eos_token_tensor = None
# Instantiate transformers library processors and criterias
logits_processor = model._get_logits_processor(
generation_config,
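The added line pre-seeds a private attribute that transformers 4.43 normally populates inside generate() (via _prepare_special_tokens). Because TokenSelector calls model._get_logits_processor() directly instead of going through generate(), the attribute would otherwise be missing on the config. A minimal sketch of the situation, assuming transformers 4.43.x:

from transformers import GenerationConfig

generation_config = GenerationConfig(max_length=256)

# A fresh config carries no _eos_token_tensor; transformers only sets it
# inside generate(), a code path this token selector bypasses.
print(hasattr(generation_config, "_eos_token_tensor"))  # expected: False

# Pre-seeding it, as the diff does, keeps the direct
# model._get_logits_processor(generation_config, ...) call from raising
# AttributeError when the processor setup reads the attribute.
generation_config._eos_token_tensor = None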
4 changes: 2 additions & 2 deletions optimum/tpu/modeling_llama.py
@@ -340,7 +340,7 @@ def _init_rope(self):
base=self.rope_theta,
)
else:
-scaling_type = self.config.rope_scaling["type"]
+scaling_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
@@ -349,7 +349,7 @@
scaling_factor=scaling_factor,
base=self.rope_theta,
)
-elif scaling_type == "dynamic":
+elif scaling_type in ["dynamic", "llama3"]:
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
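For context: Llama 3.1 checkpoints ship a rope_scaling block keyed by rope_type (value "llama3"), while older checkpoints use the type key, so the lookup now falls back across both; llama3-scaled checkpoints are then routed to the existing dynamic-NTK rotary embedding. A sketch of the two config shapes and the fallback, with field values abridged from public checkpoint configs (treat the numbers as illustrative):

# Older-style scaling block (pre-Llama 3.1 checkpoints):
legacy_scaling = {"type": "linear", "factor": 2.0}

# Llama 3.1-style scaling block (transformers 4.43+ uses "rope_type"):
llama31_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

def resolve_scaling_type(rope_scaling: dict) -> str:
    # Same fallback as the diff: prefer "rope_type", fall back to "type".
    return rope_scaling.get("rope_type", rope_scaling.get("type"))

assert resolve_scaling_type(legacy_scaling) == "linear"
assert resolve_scaling_type(llama31_scaling) == "llama3"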
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -42,7 +42,7 @@ keywords = [
]

dependencies = [
"transformers == 4.41.1",
"transformers == 4.43.3",
"torch >= 2.3.0, <= 2.4.0",
"torch-xla[tpu] >= 2.3.0, <= 2.4.0",
"loguru == 0.6.0",
4 changes: 2 additions & 2 deletions text-generation-inference/docker/Dockerfile
@@ -8,7 +8,7 @@ RUN tar -C /tgi -xf /tgi/sources.tar.gz --strip-components=1

# Build cargo components (adapted from TGI original Dockerfile)
# Note that the build image is aligned on the same Linux version as the base image (Debian bookworm/ Ubuntu 22.04)
-FROM lukemathwalker/cargo-chef:latest-rust-1.77-bookworm AS chef
+FROM lukemathwalker/cargo-chef:latest-rust-1.79-bookworm AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse
@@ -95,7 +95,7 @@ RUN apt-get update -y \
RUN pip install --upgrade pip

# Install HuggingFace packages
-ARG TRANSFORMERS_VERSION='4.41.1'
+ARG TRANSFORMERS_VERSION='4.43.3'
ARG ACCELERATE_VERSION='0.27.2'
ARG SAFETENSORS_VERSION='0.4.2'

2 changes: 1 addition & 1 deletion text-generation-inference/server/Makefile
@@ -2,7 +2,7 @@
pkg_name := text_generation_server
BUILDDIR ?= $(CURDIR)/build
VERSION ?= 0.0.1
-TGI_VERSION ?= v2.0.3
+TGI_VERSION ?= v2.2.0
mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
mkfile_dir := $(dir $(mkfile_path))
pkg_dir := $(BUILDDIR)/$(pkg_name)
2 changes: 1 addition & 1 deletion text-generation-inference/server/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
'grpc-interceptor == 0.15.2',
'typer == 0.6.1',
'safetensors == 0.4.2',
-'transformers == 4.41.1',
+'transformers == 4.43.3',
'loguru == 0.6.0',
"sentencepiece == 0.2.0",
"numpy<2.0",
6 changes: 6 additions & 0 deletions text-generation-inference/server/text_generation_server/cli.py
@@ -18,6 +18,8 @@ def serve(
uds_path: str = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
+otlp_service_name: str = "text-generation-inference.server",
+max_input_tokens: Optional[int] = None,
):
"""This is the main entry-point for the server CLI.
@@ -54,6 +56,10 @@

if trust_remote_code is not None:
logger.warning("'trust_remote_code' argument is not supported and will be ignored.")
+if otlp_service_name is not None:
+logger.warning("'otlp_service_name' argument is not supported and will be ignored.")
+if max_input_tokens is not None:
+logger.warning("'max_input_tokens' argument is not supported and will be ignored.")

# Import here after the logger is added to log potential import exceptions
from optimum.tpu.model import fetch_model
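These two parameters exist only so the server keeps accepting the flags the TGI v2.2.0 launcher now passes; both are logged as unsupported and ignored. Note that otlp_service_name defaults to a non-None string, so its warning fires even on a default launch. A small self-contained sketch of the guard pattern (the helper name is hypothetical):

from loguru import logger

def warn_if_ignored(name: str, value) -> None:
    # Mirrors the guard added to serve(): warn only for non-None values.
    if value is not None:
        logger.warning(f"'{name}' argument is not supported and will be ignored.")

warn_if_ignored("otlp_service_name", "text-generation-inference.server")  # warns
warn_if_ignored("max_input_tokens", None)  # silent: value left at its None default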
7 changes: 6 additions & 1 deletion text-generation-inference/tests/test_decode.py
@@ -36,6 +36,11 @@ def test_decode_single(params):
@pytest.mark.slow
@pytest.mark.parametrize("params",
[
+DecodeTestParams(
+model_id="meta-llama/Meta-Llama-3.1-8B",
+sequence_length=256,
+expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind,",
+),
DecodeTestParams(
model_id="meta-llama/Meta-Llama-3-8B",
sequence_length=256,
@@ -52,7 +57,7 @@ def test_decode_single(params):
expected_text=" Winston Smith, his chin nuzzled into his breast in an effort to escape the v",
),
],
-ids=["Meta-Llama-3-8B", "gemma-7b", "Mistral-7B-v0.3"],
+ids=["Meta-Llama-3.1-8B", "Meta-Llama-3-8B", "gemma-7b", "Mistral-7B-v0.3"],
)
def test_decode_single_slow(params):
_test_decode_single(params)
