Skip to content

Commit

Permalink
Fix numerics
Browse files: browse the repository at this point in the history
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
  • Loading branch information
ksivaman committed Apr 10, 2024
1 parent 31dc133 commit 19d0bd4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
- if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph:
+ if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)

@classmethod
Expand Down

0 comments on commit 19d0bd4

Please sign in to comment.