Add SFTTrainer integration example #42

Merged · 3 commits · Mar 12, 2024
8 changes: 5 additions & 3 deletions docs/getting_started/index.md
@@ -107,19 +107,21 @@ for epoch in range(100):
Selectors that are currently implemented are [`Energy`][zeus.optimizer.power_limit.Energy], [`Time`][zeus.optimizer.power_limit.Time], [`ZeusCost`][zeus.optimizer.power_limit.ZeusCost], and [`MaxSlowdownConstraint`][zeus.optimizer.power_limit.MaxSlowdownConstraint].
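
For example, a different selector can be supplied when constructing the power limit optimizer. The snippet below is a minimal sketch; it assumes the selector instance is passed directly to [`GlobalPowerLimitOptimizer`][zeus.optimizer.power_limit.GlobalPowerLimitOptimizer], and the `factor=1.1` bound is only illustrative:

```python
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer, MaxSlowdownConstraint

monitor = ZeusMonitor()

# Choose the lowest power limit whose training slowdown stays within 1.1x
# of the fastest measured setting. (Sketch; see the optimizer's API reference
# for the exact constructor signature.)
optimizer = GlobalPowerLimitOptimizer(monitor, MaxSlowdownConstraint(factor=1.1))
```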

### `HFGlobalPowerLimitOptimizer`
For easy use with [HuggingFace 🤗 Transformers](https://huggingface.co/docs/transformers/en/index), [`HFGlobalPowerLimitOptimizer`][zeus.optimizer.power_limit.HFGlobalPowerLimitOptimizer] is a drop-in compatible [HuggingFace 🤗 Trainer Callback](https://huggingface.co/docs/transformers/en/main_classes/callback). When initializing a [HuggingFace 🤗 Trainer](https://huggingface.co/docs/transformers/main_classes/trainer), initialize and pass in [`HFGlobalPowerLimitOptimizer`][zeus.optimizer.power_limit.HFGlobalPowerLimitOptimizer] as shown below:
For easy use with [HuggingFace 🤗 Transformers](https://huggingface.co/docs/transformers/en/index), [`HFGlobalPowerLimitOptimizer`][zeus.optimizer.power_limit.HFGlobalPowerLimitOptimizer] is a drop-in compatible [HuggingFace 🤗 Trainer Callback](https://huggingface.co/docs/transformers/en/main_classes/callback). When initializing a [HuggingFace 🤗 Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) or a [TRL SFTTrainer](https://huggingface.co/docs/trl/main/en/sft_trainer), initialize and pass in [`HFGlobalPowerLimitOptimizer`][zeus.optimizer.power_limit.HFGlobalPowerLimitOptimizer] as shown below:

```python
monitor = ZeusMonitor()
optimizer = HFGlobalPowerLimitOptimizer(monitor)

# Initialize HuggingFace 🤗 Trainer
# Also works with SFTTrainer.
trainer = Trainer(
    ...,
    callbacks=[optimizer], # Add the `HFGlobalPowerLimitOptimizer` callback
)
```
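
The same callback also plugs into TRL's [`SFTTrainer`](https://huggingface.co/docs/trl/main/en/sft_trainer), which subclasses `Trainer` and accepts the same `callbacks` argument. A minimal sketch, with all other constructor arguments elided:

```python
from trl import SFTTrainer

# Attach the same Zeus callback to TRL's SFTTrainer.
trainer = SFTTrainer(
    ...,
    callbacks=[optimizer],  # Add the `HFGlobalPowerLimitOptimizer` callback
)

trainer.train()
```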
Refer to our [HuggingFace 🤗 example training code for fine-tuning using HFGlobalPowerLimitOptimizer](https://github.com/ml-energy/zeus/tree/master/examples/huggingface/) for complete running examples for single-GPU and multi-GPU training using [HuggingFace 🤗 Trainer](https://huggingface.co/docs/transformers/main_classes/trainer).
Refer to our [HuggingFace 🤗 example integration](https://github.com/ml-energy/zeus/tree/master/examples/huggingface/) for:
- Transformers [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer) integration for **causal language modeling** (i.e., pre-training)
- TRL [`SFTTrainer`](https://huggingface.co/docs/trl/main/en/sft_trainer) integration for **Gemma 7b supervised fine-tuning with QLoRA**

## Recurring jobs

25 changes: 18 additions & 7 deletions examples/huggingface/README.md
@@ -1,16 +1,22 @@
# Integrating Zeus with HuggingFace 🤗

This example will demonstrate how to integrate Zeus with `HuggingFace 🤗 Trainer` using `HFGlobalPowerLimitOptimizer`.

[`run_clm.py`](run_clm.py) was adapted from [HuggingFace 🤗's example training code for fine-tuning language models](https://github.com/huggingface/transformers/tree/f3aa7db439a2a3942f76c115197fe953984ac334/examples/pytorch/language-modeling).
This example will demonstrate how to integrate Zeus's `HFGlobalPowerLimitOptimizer` with HuggingFace Transformers:
- `run_clm.py`: Transformers [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer) for **causal language modeling** (i.e., pre-training)
- `run_gemma_sft_qlora.py`: TRL [`SFTTrainer`](https://huggingface.co/docs/trl/main/en/sft_trainer) for **Gemma 7b supervised fine-tuning with QLoRA**

## Dependencies

Use the included requirements.txt file to include all extra dependencies:
To run the `Trainer` integration script (`run_clm.py`):
```sh
pip install -r requirements.txt
pip install -r requirements.txt
```

To run the `SFTTrainer` integration script (`run_gemma_sft_qlora.py`):
```sh
pip install -r requirements-qlora.txt
```
Note that you may have to tweak `requirements-qlora.txt` depending on your setup. The current requirements file assumes CUDA 11 and installs `nvidia-cusparse-cu11` for `bitsandbytes`. In short, first get to a setup where training runs, and then add `pip install zeus-ml` on top of it.

## `ZeusMonitor` and `HFGlobalPowerLimitOptimizer`

- [`ZeusMonitor`](https://ml.energy/zeus/reference/monitor/energy/#zeus.monitor.energy.ZeusMonitor): Measures the GPU time and energy consumption of arbitrary code blocks.
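
  For reference, measuring an arbitrary code block with `ZeusMonitor` looks roughly like the sketch below (the window name is arbitrary; the printed fields follow the measurement object returned by `end_window`):

```python
from zeus.monitor import ZeusMonitor

monitor = ZeusMonitor()  # Monitors all visible GPUs by default.

monitor.begin_window("training")
# ... the code block you want to measure ...
measurement = monitor.end_window("training")

# The measurement carries the elapsed time and total GPU energy of the window.
print(f"Took {measurement.time:.1f} s and consumed {measurement.total_energy:.1f} J.")
```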
@@ -23,7 +29,7 @@ For easy use with [HuggingFace 🤗 Transformers](https://huggingface.co/docs/tr
monitor = ZeusMonitor()
optimizer = HFGlobalPowerLimitOptimizer(monitor)

# Initialize HuggingFace 🤗 Trainer
# Also works for SFTTrainer.
trainer = Trainer(
    ...,
    callbacks=[optimizer], # Add the `HFGlobalPowerLimitOptimizer` callback
@@ -32,9 +38,10 @@ For easy use with [HuggingFace 🤗 Transformers](https://huggingface.co/docs/tr

## Running the Example

By default, `Trainer` will make use of all available GPUs. If you would like to use only a subset of the GPUs, specify the `CUDA_VISIBLE_DEVICES` environment variable, which Zeus will also automatically respect.
By default, `Trainer`/`SFTTrainer` will make use of all available GPUs. If you would like to use only a subset of the GPUs, specify the `CUDA_VISIBLE_DEVICES` environment variable, which Zeus will also automatically respect.

```bash
# For Trainer.
python run_clm.py \
--model_name_or_path gpt2 \
--dataset_name wikitext \
@@ -44,4 +51,8 @@ python run_clm.py \
--do_train \
--do_eval \
--output_dir /tmp/test-clm

# For SFTTrainer.
python run_gemma_sft_qlora.py \
--dataset_name stingning/ultrachat
```
9 changes: 9 additions & 0 deletions examples/huggingface/requirements-qlora.txt
@@ -0,0 +1,9 @@
zeus-ml
accelerate >= 0.12.0
torch >= 1.3
datasets >= 1.8.0
transformers>=4.37.2
peft
trl
bitsandbytes
nvidia-cusparse-cu11
1 change: 1 addition & 0 deletions examples/huggingface/requirements.txt
@@ -1,3 +1,4 @@
zeus-ml
accelerate >= 0.12.0
torch >= 1.3
datasets >= 1.8.0
2 changes: 2 additions & 0 deletions examples/huggingface/run_clm.py
@@ -18,6 +18,8 @@

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation

Forked from https://github.com/huggingface/transformers/tree/f3aa7db439a2a3942f76c115197fe953984ac334/examples/pytorch/language-modeling
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

156 changes: 156 additions & 0 deletions examples/huggingface/run_gemma_sft_qlora.py
@@ -0,0 +1,156 @@
"""
Forked from https://huggingface.co/google/gemma-7b/blob/359f554e6b7dcfe8026a8705a20ff166141a5adf/examples/example_sft_qlora.py
"""

from dataclasses import dataclass, field
from typing import Optional

import torch

from transformers import AutoTokenizer, HfArgumentParser, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import HFGlobalPowerLimitOptimizer

@dataclass
class ScriptArguments:
    """
    These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
    """
    per_device_train_batch_size: Optional[int] = field(default=1)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=2)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[float] = field(default=0.001)
    lora_alpha: Optional[int] = field(default=16)
    lora_dropout: Optional[float] = field(default=0.1)
    lora_r: Optional[int] = field(default=8)
    max_seq_length: Optional[int] = field(default=2048)
    model_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        }
    )
    dataset_name: Optional[str] = field(
        default="stingning/ultrachat",
        metadata={"help": "The dataset to use for supervised fine-tuning."},
    )
    fp16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables fp16 training."},
    )
    bf16: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables bf16 training."},
    )
    packing: Optional[bool] = field(
        default=True,
        metadata={"help": "Use packing when creating the dataset."},
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant is a bit better than cosine and has an advantage for analysis"},
    )
    max_steps: int = field(default=1000, metadata={"help": "How many optimizer update steps to take"})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for"})
    save_steps: int = field(default=10, metadata={"help": "Save checkpoint every X update steps."})
    logging_steps: int = field(default=10, metadata={"help": "Log every X update steps."})
    output_dir: str = field(
        default="./results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )

parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]


def formatting_func(example):
    text = f"### USER: {example['data'][0]}\n### ASSISTANT: {example['data'][1]}"
    return text

# Load the Gemma 7B model from the Hugging Face Hub.
model_id = "google/gemma-7b"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4"
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float32,
    # attn_implementation="sdpa" if not script_args.use_flash_attention_2 else "flash_attention_2"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

lora_config = LoraConfig(
    r=script_args.lora_r,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout
)

train_dataset = load_dataset(script_args.dataset_name, split="train[:5%]")

output_dir = "gemma-qlora-ultrachat"

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optim=script_args.optim,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    learning_rate=script_args.learning_rate,
    max_grad_norm=script_args.max_grad_norm,
    max_steps=script_args.max_steps,
    warmup_ratio=script_args.warmup_ratio,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
)

monitor = ZeusMonitor()
# Since one iteration is in the order of tens of seconds or even minutes,
# no need to profile for too many iterations.
optimizer = HFGlobalPowerLimitOptimizer(monitor, warmup_steps=1, profile_steps=3)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    peft_config=lora_config,
    packing=script_args.packing,
    dataset_text_field="id",
    tokenizer=tokenizer,
    max_seq_length=script_args.max_seq_length,
    formatting_func=formatting_func,
    callbacks=[optimizer],
)

trainer.train()
4 changes: 2 additions & 2 deletions scripts/lint.sh
@@ -3,9 +3,9 @@
set -ev

if [[ -z $GITHUB_ACTION ]]; then
black --exclude examples/ZeusDataLoader/cifar100/models zeus capriccio examples
black zeus capriccio
else
black --check --exclude examples/ZeusDataLoader/cifar100/models zeus capriccio examples
black zeus capriccio
fi

ruff zeus
12 changes: 6 additions & 6 deletions zeus/run/dataloader.py
@@ -1011,17 +1011,17 @@ def __next__(self):
self.train_epoch_time.append(time_consumed)
# Record the energy consumption for each GPU.
for index in range(self.world_size):
self.train_epoch_energy[index][
self.epoch_num - 1
] = energy_consumed[index]
self.train_epoch_energy[index][self.epoch_num - 1] = (
energy_consumed[index]
)
else:
# Integrate the last time_consumed seconds.
self.eval_epoch_time.append(time_consumed)
# Record the energy consumption for each GPU.
for index in range(self.world_size):
self.eval_epoch_energy[index][
self.epoch_num - 1
] = energy_consumed[index]
self.eval_epoch_energy[index][self.epoch_num - 1] = (
energy_consumed[index]
)
# For the eval dataloader, we want to record the throughput and power
# for the current power limit. Since the train dataloader sets the power limit
# to the optimal power limit right after profiling is done, this will naturally