Commit 076e86e: works for fp8 with deepspeed
XiaobingSuper committed Feb 5, 2025
1 parent f076495 commit 076e86e
Showing 6 changed files with 31 additions and 16 deletions.
12 changes: 6 additions & 6 deletions docs/source/usage_guides/low_precision_training.md
@@ -53,13 +53,13 @@ accelerator = Accelerator(mixed_precision="fp8", kwarg_handlers=kwargs)
```{yaml}
mixed_precision: fp8
fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
  backend: TE
  fp8_format: HYBRID
  interval: 1
  margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
  use_autocast_during_eval: false
```
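
For reference, the same recipe can be built in code rather than YAML. Below is a minimal sketch assuming accelerate's `FP8RecipeKwargs` handler; the field names match the corrected YAML keys above, and the training loop itself is omitted. This is an illustration, not the documented example from this repo.

```python
# A minimal sketch, assuming `FP8RecipeKwargs` accepts the TE fields shown
# in the YAML above; not the documented example from this commit.
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

kwargs = FP8RecipeKwargs(
    backend="TE",
    fp8_format="HYBRID",
    amax_compute_algo="max",  # renamed from `amax_compute_algorithm`
    amax_history_len=1024,    # renamed from `amax_history_length`
    interval=1,
    margin=0,
    override_linear_precision=(False, False, False),  # (fprop, dgrad, wgrad)
    use_autocast_during_eval=False,
)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
```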

@@ -114,13 +114,13 @@ Similarly this can be set in your `config.yaml`:
```{yaml}
mixed_precision: fp8
fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
  backend: TE
  fp8_format: HYBRID
  interval: 1
  margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
  use_autocast_during_eval: false
```

4 changes: 2 additions & 2 deletions examples/config_yaml_templates/fp8.yaml
@@ -7,11 +7,11 @@ fp8_config:
  backend: TE # Can be TE | MS-AMP
  # The following are TE specific arguments.
  # See https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#common-api for more details
-  amax_history_length: 1024
+  amax_history_len: 1024
  fp8_format: E4M3
  interval: 1
  margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
  # Generally this should always be set to `false` to have the most realistic fp8 eval performance
  use_autocast_during_eval: false
  # If using MS-AMP, we ignore all of the prior and set a opt_level
12 changes: 8 additions & 4 deletions src/accelerate/accelerator.py
@@ -462,13 +462,14 @@ def __init__(
            **kwargs,
        )

-        if self.state.mixed_precision == "fp8" and self.fp8_recipe_handler is None:
+        self._mixed_precision = mixed_precision
+        if mixed_precision == "fp8" and self.fp8_recipe_handler is None:
            self.fp8_recipe_handler = FP8RecipeKwargs()

        self.delayed_fp8_autocast = False
        if self.fp8_recipe_handler is not None:
            # We already check if FP8 is available during `self.state`
-            if self.state.mixed_precision != "fp8" and (
+            if mixed_precision != "fp8" and (
                self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
            ):
                raise ValueError("Passing in a `FP8RecipeKwargs` object requires setting `mixed_precision='fp8'`.")
@@ -536,7 +537,10 @@
if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available():
raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

elif self.state.mixed_precision == "fp8":
# for DeepSpeed, self.state.mixed_precision is always "bf16",
# see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1263.
elif mixed_precision == "fp8" or self.state.mixed_precision == "fp8":
# We always enable `native_amp` for FP8
self.native_amp = True
if self.fp8_backend == "MSAMP":
@@ -3640,7 +3644,7 @@ def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
    @property
    def fp8_backend(self):
        "Returns the configured backend for training in FP8"
-        if self.mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
+        if self._mixed_precision == "fp8" and self.fp8_recipe_handler is not None:
            return self.fp8_recipe_handler.backend
        elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
            return "MSAMP"
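
The point of caching `_mixed_precision` shows up in `fp8_backend`: under DeepSpeed, `self.state.mixed_precision` is forced to "bf16" (see the comment added above), so the property could no longer detect FP8 from state alone. A rough illustration follows; it assumes an FP8-capable environment with Transformer Engine installed, and the DeepSpeed behavior noted in the comments is taken from the linked state.py lines.

```python
# Illustration only, assuming an FP8-capable setup with Transformer Engine.
# Under DeepSpeed, `state.mixed_precision` reports "bf16" while the cached
# `_mixed_precision` still lets `fp8_backend` resolve the FP8 backend.
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs(backend="TE")],
)
print(accelerator.state.mixed_precision)  # "fp8" here; "bf16" under DeepSpeed
print(accelerator.fp8_backend)            # "TE" in both cases after this fix
```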
5 changes: 5 additions & 0 deletions src/accelerate/commands/launch.py
@@ -1060,6 +1060,11 @@ def _validate_launch_command(args):
            setattr(args, k, defaults.ipex_config[k])
        for k in defaults.mpirun_config:
            setattr(args, k, defaults.mpirun_config[k])
+        for k in defaults.fp8_config:
+            arg_to_set = k
+            if "fp8" not in arg_to_set:
+                arg_to_set = "fp8_" + arg_to_set
+            setattr(args, arg_to_set, defaults.fp8_config[k])
        continue

# Those args are handled separately
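
The new loop mirrors the existing `ipex_config`/`mpirun_config` handling, but namespaces bare keys so they land on the `fp8_*` CLI arguments. A standalone sketch of the prefixing rule, with an assumed sample config dict rather than a real accelerate config file:

```python
# Standalone sketch of the key-prefixing rule above; `fp8_config` is an
# assumed sample. Keys already containing "fp8" pass through unchanged,
# bare keys gain an "fp8_" prefix before being set on the parsed args.
fp8_config = {"backend": "TE", "fp8_format": "HYBRID", "amax_history_len": 1024}
for k, v in fp8_config.items():
    arg_to_set = k if "fp8" in k else "fp8_" + k
    print(f"args.{arg_to_set} = {v!r}")
# args.fp8_backend = 'TE'
# args.fp8_format = 'HYBRID'
# args.fp8_amax_history_len = 1024
```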
8 changes: 7 additions & 1 deletion src/accelerate/utils/launch.py
@@ -83,7 +83,13 @@ def setup_fp8_env(args: argparse.Namespace, current_env: Dict[str, str]):
        if arg.startswith("fp8_"):
            value = getattr(args, arg)
            if value is not None:
-                current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
+                if arg == "fp8_override_linear_precision":
+                    values = value.strip("()").split(",")
+                    current_env[prefix + "FP8_OVERRIDE_FPROP"] = values[0].strip()
+                    current_env[prefix + "FP8_OVERRIDE_DGRAD"] = values[1].strip()
+                    current_env[prefix + "FP8_OVERRIDE_WGRAD"] = values[2].strip()
+                else:
+                    current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
    return current_env


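
Because `override_linear_precision` is now a tuple string such as `(false, false, false)`, it can no longer be exported as a single environment variable; the new branch splits it into per-pass flags. A standalone sketch of that split (the `ACCELERATE_` prefix is an assumption about how `setup_fp8_env` names its variables):

```python
# Standalone sketch of the tuple split above. The "ACCELERATE_" prefix is
# assumed; the input string mimics the config value from the YAML files.
value = "(false, false, false)"
prefix = "ACCELERATE_"
current_env = {}
fprop, dgrad, wgrad = (part.strip() for part in value.strip("()").split(","))
current_env[prefix + "FP8_OVERRIDE_FPROP"] = fprop  # forward prop
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = dgrad  # grad w.r.t. activations
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = wgrad  # grad w.r.t. weights
print(current_env)
```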
6 changes: 3 additions & 3 deletions tests/test_configs/0_34_0_fp8.yaml
@@ -4,13 +4,13 @@ distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
fp8_config:
-  amax_compute_algorithm: max
-  amax_history_length: 1024
+  amax_compute_algo: max
+  amax_history_len: 1024
  backend: TE
  fp8_format: E4M3
  interval: 1
  margin: 0
-  override_linear_precision: false
+  override_linear_precision: (false, false, false)
  use_autocast_during_eval: false
gpu_ids: all
machine_rank: 0
