Commit
muellerzr committed Feb 10, 2025
1 parent 676d5ac commit 92b3d9b
Showing 2 changed files with 40 additions and 28 deletions.
66 changes: 39 additions & 27 deletions benchmarks/fp8/torchao/distrib_deepspeed.py
@@ -31,6 +31,8 @@
from accelerate.utils import AORecipeKwargs, set_seed
from torchao.float8 import convert_to_float8_training

+from transformers.integrations import HfDeepSpeedConfig
+

MODEL_NAME = "bert-base-cased"
METRIC = evaluate.load("glue", "mrpc")
@@ -48,16 +50,20 @@ def filter_linear_layers(module, fqn, first_layer_name=None, last_layer_name=None):


def train_baseline(zero_stage: int = 1):
-    # This forces transformers to think Zero-3 Init should be used
-    with patch("transformers.integrations.deepspeed.is_deepspeed_zero3_enabled") as mock:
-        mock.return_value = zero_stage == 3
    set_seed(42)
-
-    accelerator = Accelerator()
+    config = HfDeepSpeedConfig(
+        {
+            "train_micro_batch_size_per_gpu": 16,
+            "gradient_accumulation_steps": 1,
+            "zero_optimization": {"stage": zero_stage},
+        }
+    )
+    plugin = DeepSpeedPlugin(hf_ds_config=config)
+    accelerator = Accelerator(deepspeed_plugin=plugin)
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

    first_linear = None
    last_linear = None
    for name, module in model.named_modules():
@@ -66,9 +72,8 @@ def train_baseline(zero_stage: int = 1):
                first_linear = name
            last_linear = name
    func = partial(filter_linear_layers, first_layer_name=first_linear, last_layer_name=last_linear)
-    convert_to_float8_training(model, module_filter_fn=func)

-    accelerator = Accelerator()
+    convert_to_float8_training(model, module_filter_fn=func)

    import numpy as np

@@ -125,27 +130,34 @@ def train_baseline(zero_stage: int = 1):
        trained_model_results["f1"] > base_model_results["f1"]
    ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'

+    del config
    return base_model_results, trained_model_results, model_outputs, data


def train_integration(zero_stage: int = 1):
    set_seed(42)
    AcceleratorState()._reset_state(True)
+    config = HfDeepSpeedConfig(
+        {
+            "train_micro_batch_size_per_gpu": 16,
+            "gradient_accumulation_steps": 1,
+            "zero_optimization": {"stage": zero_stage},
+        }
+    )
    deepspeed_plugin = DeepSpeedPlugin(
-        zero_stage=zero_stage,
-        zero3_init_flag=zero_stage == 3,
-        gradient_clipping=1.0,
+        hf_ds_config=config,
    )
    accelerator = Accelerator(
        mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()], deepspeed_plugin=deepspeed_plugin
    )
-    accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 16

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
        MODEL_NAME, accelerator=accelerator
    )

-    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
+    model, optimizer, lr_scheduler, train_dataloader, eval_dataloader = accelerator.prepare(
+        model, optimizer, lr_scheduler, train_dataloader, eval_dataloader
+    )
    base_model_results = evaluate_model(model, eval_dataloader, METRIC, accelerator=accelerator)
    model.train()
    model_outputs = []
@@ -169,26 +181,26 @@ def train_integration(zero_stage: int = 1):
        trained_model_results["f1"] > base_model_results["f1"]
    ), f'F1 score should be higher for the trained model: {trained_model_results["f1"]} > {base_model_results["f1"]}'

+    del config
    return base_model_results, trained_model_results, model_outputs, data


if __name__ == "__main__":
-    # for zero_stage in [1, 2, 3]:
    for zero_stage in [3]:
-        # Expected baseline: ValueError: {'accuracy': 0.7916666666666666, 'f1': 0.8513011152416357}
        baseline_not_trained, baseline_trained, baseline_outputs, baseline_data = train_baseline(zero_stage)
        accelerator_not_trained, accelerator_trained, accelerator_outputs, accelerator_data = train_integration(zero_stage)
-        print(baseline_trained, accelerator_trained)
-        # assert (
-        #     baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
-        # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
-        # assert (
-        #     baseline_not_trained["f1"] == accelerator_not_trained["f1"]
-        # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
-        # assert (
-        #     baseline_trained["accuracy"] == accelerator_trained["accuracy"]
-        # ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
-        # assert (
-        #     baseline_trained["f1"] == accelerator_trained["f1"]
-        # ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'

+        assert (
+            baseline_not_trained["accuracy"] == accelerator_not_trained["accuracy"]
+        ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_not_trained["accuracy"]} == {accelerator_not_trained["accuracy"]}'
+        assert (
+            baseline_not_trained["f1"] == accelerator_not_trained["f1"]
+        ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_not_trained["f1"]} == {accelerator_not_trained["f1"]}'
+        assert (
+            baseline_trained["accuracy"] == accelerator_trained["accuracy"]
+        ), f'ZERO stage {zero_stage}: Accuracy should be the same for the baseline and accelerator: {baseline_trained["accuracy"]} == {accelerator_trained["accuracy"]}'
+        assert (
+            baseline_trained["f1"] == accelerator_trained["f1"]
+        ), f'ZERO stage {zero_stage}: F1 score should be the same for the baseline and accelerator: {baseline_trained["f1"]} == {accelerator_trained["f1"]}'
+        AcceleratorState()._reset_state(True)
+    torch.distributed.destroy_process_group()
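
Both code paths now share the same pattern: build the DeepSpeed config first, wrap it in a DeepSpeedPlugin, and hand that to Accelerator. Below is a minimal sketch of that wiring using only APIs that appear in this diff; build_accelerator is a hypothetical helper for illustration, not part of the benchmark:

from accelerate import Accelerator, DeepSpeedPlugin
from transformers.integrations import HfDeepSpeedConfig


def build_accelerator(zero_stage: int = 1) -> Accelerator:
    # Declaring the DeepSpeed config before any model is constructed lets
    # transformers pick up the ZeRO stage (including stage-3 init) on its own,
    # rather than patching is_deepspeed_zero3_enabled as the old baseline did.
    config = HfDeepSpeedConfig(
        {
            "train_micro_batch_size_per_gpu": 16,
            "gradient_accumulation_steps": 1,
            "zero_optimization": {"stage": zero_stage},
        }
    )
    # The plugin keeps a reference to `config`; the benchmark above still
    # drops it explicitly (`del config`) once a run finishes.
    plugin = DeepSpeedPlugin(hf_ds_config=config)
    return Accelerator(deepspeed_plugin=plugin)
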
2 changes: 1 addition & 1 deletion benchmarks/fp8/torchao/fp8_utils.py
@@ -62,7 +62,7 @@ def collate_fn(examples):
    return train_dataloader, eval_dataloader


-def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None):
+def get_training_utilities(model_name: str, batch_size: int = 16, accelerator=None, prepare=True):
    """
    Returns a tuple of:
        - Model
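
The only change to fp8_utils.py is the new prepare keyword, which defaults to True so existing callers are unaffected. Below is a hedged sketch of how a caller might use it, under the assumption that prepare=False skips the internal accelerator.prepare call and hands back raw objects (the names are the benchmark's own):

# Sketch only; assumes `prepare=False` returns unprepared objects. `func` is
# the module filter built in train_baseline above.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = get_training_utilities(
    MODEL_NAME, accelerator=accelerator, prepare=False
)
# Convert the raw model to float8 before any wrapping takes place...
convert_to_float8_training(model, module_filter_fn=func)
# ...then prepare everything manually, mirroring the ordering train_baseline uses.
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
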
