🦙 Add llama fine-tuning notebook example #130

Merged
3 commits merged on Dec 18, 2024
README.md (1 addition, 1 deletion)
@@ -70,4 +70,4 @@ Fine-tuning is supported and tested on the TPU `v5e`. We have tested so far:
You can check the examples:

- [Fine-Tune Gemma on Google TPU](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/gemma_tuning.ipynb)
- The [Llama fine-tuning script](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.md)
- The [Llama fine-tuning script](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.ipynb)
docs/source/howto/training.mdx (1 addition, 1 deletion)
@@ -34,4 +34,4 @@ We provide several example scripts to help you get started:
- See our [Gemma fine-tuning notebook](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/gemma_tuning.ipynb) for a step-by-step guide

2. LLaMA Fine-tuning:
- Check our [LLaMA fine-tuning script](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.md) for detailed instructions
- Check our [LLaMA fine-tuning script](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.ipynb) for detailed instructions
docs/source/tutorials/overview.mdx (1 addition, 1 deletion)
@@ -24,7 +24,7 @@ Explore how to fine-tune language models on TPU infrastructure:
- Includes dataset preparation and PEFT/LoRA implementation
- Provides step-by-step training workflow

2. **LLaMA Fine-tuning Guide** ([examples/language-modeling/llama_tuning.md](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.md))
2. **LLaMA Fine-tuning Guide** ([examples/language-modeling/llama_tuning.ipynb](https://github.com/huggingface/optimum-tpu/blob/main/examples/language-modeling/llama_tuning.ipynb))
- Detailed guide for fine-tuning LLaMA-2 and LLaMA-3 models
- Explains SPMD and FSDP concepts
- Shows how to implement efficient data parallel training
examples/language-modeling/llama_tuning.ipynb (241 additions)
@@ -0,0 +1,241 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c7227736-2685-4971-9402-d6015b5319b5",
"metadata": {},
"source": [
"# Fine-Tune Llama on Google TPU\n",
"\n",
"Training Large Language Models (LLMs) on Google Tensor Processing Units (TPUs) with Single Program Multiple Data (SPMD) offers a multitude of benefits. TPUs provide competitive processing power, enabling good training times and allowing researchers to experiment with larger models and datasets efficiently. SPMD architecture optimizes resource utilization by distributing tasks across multiple TPUs, enhancing parallelism and scalability.\n",
"The easiest approach to tune a model with SPMD is using Fully Sharded Data Parallel [(FSDP)](https://engineering.fb.com/2021/07/15/open-source/fsdp/). Pytorch/XLA most recent and performant implementation is [FSDP v2](https://github.com/pytorch/xla/blob/master/docs/fsdpv2.md), that allows to shard weights, activations and outputs.\n",
"\n",
"\n",
"This example shows to tune one of Meta's Llama models on single host TPUs. For information on TPUs architecture, you can consult the [documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).\n",
"\n",
"\n",
"### Prerequisites\n",
"\n",
"We consider you have already created a single-host TPU VM, such as a `v5litepod8` setup, and you have ssh access to the machine.\n",
"You need to clone `optimum-tpu` and install few modules:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4f15b21",
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"git clone https://github.com/huggingface/optimum-tpu.git\n",
"# Install Optimum TPU\n",
"pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html\n",
"# Install TRL and PEFT for training (see later how they are used)\n",
"pip install trl peft\n",
"# Install Jupyter notebook\n",
"pip install -U jupyterlab notebook\n",
"# Optionally, install widgets extensions for better rendering\n",
"pip install ipywidgets widgetsnbextension\n",
"# This will be necessary for the language modeling example\n",
"pip install datasets evaluate accelerate\n",
"# Change directory and launch Jupyter notebook\n",
"cd optimum-tpu/examples/language-modeling\n",
"jupyter notebook --port 8888"
]
},
{
"cell_type": "markdown",
"id": "6a4cc927",
"metadata": {},
"source": [
"We should then see the familiar Jupyter output that shows the address accessible from a browser:\n",
"\n",
"```\n",
"http://localhost:8888/tree?token=3ceb24619d0a2f99acf5fba41c51b475b1ddce7cadb2a133\n",
"```\n",
"\n",
"Since we are going to use the gated `llama` model, we will need to log in using a [Hugging Face token](https://huggingface.co/settings/tokens):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d273dead",
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli login --token YOUR_HF_TOKEN"
]
},
{
"cell_type": "markdown",
"id": "75a0bd6e",
"metadata": {},
"source": [
"### Enable FSDPv2\n",
"\n",
"To fine-tune an LLM, it might be necessary to shard the model across the TPUs to prevent memory issues and enhance tuning performances. Fully Sharded Data Parallel is an algorithm that has been implemented on Pytorch and that allows to wrap modules to distribute them.\n",
"When using Pytorch/XLA on TPUs, [FSDPv2](https://pytorch.org/xla/master/#fully-sharded-data-parallel-via-spmd) is an utility that re-expresses the famous FSDP algorithm using SPMD (Single Program Multiple Data). In `optimum-tpu` it is possible to use dedicated helpers to use FSPDv2. To enable it, you can use the dedicated function, that should be called at the beginning of the execution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d3c7bc2",
"metadata": {},
"outputs": [],
"source": [
"from optimum.tpu import fsdp_v2\n",
"fsdp_v2.use_fsdp_v2()"
]
},
{
"cell_type": "markdown",
"id": "733118c2",
"metadata": {},
"source": [
"Then, the tokenizer and model need to be loaded. We will choose [`meta-llama/Llama-3.2-1B`](https://huggingface.co/meta-llama/Llama-3.2-1B) for this example."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d07e0b43",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"model_id = \"meta-llama/Llama-3.2-1B\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"# Add custom token for padding Llama\n",
"tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)"
]
},
{
"cell_type": "markdown",
"id": "d5576762",
"metadata": {},
"source": [
"To tune the model with the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, you can load it and obtain the `quote` column:\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5c365bdd",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"\n",
"data = load_dataset(\"Abirate/english_quotes\")\n",
"data = data.map(lambda samples: tokenizer(samples[\"quote\"]), batched=True)"
]
},
{
"cell_type": "markdown",
"id": "73174355",
"metadata": {},
"source": [
"You then need to specify the FSDP training arguments to enable the sharding feature, the function will deduce the classes that should be sharded:\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2c0c0797",
"metadata": {},
"outputs": [],
"source": [
"fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)"
]
},
{
"cell_type": "markdown",
"id": "50b6ebe4",
"metadata": {},
"source": [
"The `fsdp_training_args` will specify the Pytorch module that needs to be sharded:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d55e0aa",
"metadata": {},
"outputs": [],
"source": [
"fsdp_training_args"
]
},
{
"cell_type": "markdown",
"id": "ccb4820f",
"metadata": {},
"source": [
"Now training can be done as simply as using the standard `Trainer` class:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "12da486c",
"metadata": {},
"outputs": [],
"source": [
"from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments\n",
"trainer = Trainer(\n",
" model=model,\n",
" train_dataset=data[\"train\"],\n",
" args=TrainingArguments(\n",
" per_device_train_batch_size=24,\n",
" num_train_epochs=10,\n",
" max_steps=-1,\n",
" output_dir=\"/tmp/output\",\n",
" optim=\"adafactor\",\n",
" logging_steps=1,\n",
" dataloader_drop_last=True, # Required by FSDP v2 and SPMD.\n",
" **fsdp_training_args,\n",
" ),\n",
" data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),\n",
")\n",
"\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "fc3dd275",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
examples/language-modeling/llama_tuning.md (95 deletions)

This file was deleted.
