Commit a868a5c
fixed incorrect order of fp8 metadata initialization
Signed-off-by: Alp Dener <adener@nvidia.com>
denera committed Jan 21, 2024
1 parent ce2b738 commit a868a5c
Showing 3 changed files with 6 additions and 6 deletions.
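
Why the order matters: when primary weights live in FP8, reset_parameters() materializes weight tensors whose quantization consults self.fp8_meta, so init_fp8_metadata() has to run first. Below is a minimal sketch of that dependency (a hypothetical, simplified class for illustration, not the actual Transformer Engine code):

import torch

class FP8ModuleSketch:
    """Hypothetical, simplified stand-in for a Transformer Engine module."""

    def __init__(self, out_features, in_features, primary_weights_in_fp8=True):
        self.out_features = out_features
        self.in_features = in_features
        self.primary_weights_in_fp8 = primary_weights_in_fp8
        self.fp8_meta = {}

        # Post-commit order: set up the FP8 metadata first...
        if self.primary_weights_in_fp8:
            self.init_fp8_metadata()
            self.fp8_meta["update_amax_and_scale_fwd"] = True

        # ...then materialize the parameters, which may read that metadata.
        # Calling this before init_fp8_metadata() (the pre-commit order)
        # would raise a KeyError below.
        self.reset_parameters()

    def init_fp8_metadata(self):
        # Stand-in for the real method: populate per-tensor scaling state.
        self.fp8_meta["scale"] = torch.ones(1)

    def reset_parameters(self):
        weight = torch.randn(self.out_features, self.in_features)
        if self.primary_weights_in_fp8:
            # Quantizing the primary weight needs the scale from fp8_meta.
            weight = weight * self.fp8_meta["scale"]
        self.weight = torch.nn.Parameter(weight)

module = FP8ModuleSketch(16, 16)  # constructs cleanly in the fixed order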
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -860,12 +860,12 @@ def __init__(
         del self.weight_tensor
         del self.bias_tensor

-        self.reset_parameters(defer_init=(device == 'meta'))
-
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata()
             self.fp8_meta["update_amax_and_scale_fwd"] = True

+        self.reset_parameters(defer_init=(device == 'meta'))
+
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

         # For RPL, bias has to be added after TP collectives
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1226,12 +1226,12 @@ def __init__(
         else:
             self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)

-        self.reset_parameters(defer_init=(device == 'meta'))
-
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata(num_gemms=2)
             self.fp8_meta["update_amax_and_scale_fwd"] = True

+        self.reset_parameters(defer_init=(device == 'meta'))
+
         # For RPL, bias has to be added after TP collectives
         # So it cannot be fused with the GEMM
         if self.set_parallel_mode and self.apply_bias:
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -754,12 +754,12 @@ def __init__(
         del self.weight_tensor
         del self.bias_tensor

-        self.reset_parameters(defer_init=(device == 'meta'))
-
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata()
             self.fp8_meta["update_amax_and_scale_fwd"] = True

+        self.reset_parameters(defer_init=(device == 'meta'))
+
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

         # For RPL, bias has to be added after TP collectives
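For context, the primary_weights_in_fp8 path touched in all three modules is taken when they are constructed under Transformer Engine's fp8_model_init context. A minimal usage sketch, assuming a CUDA-capable build of transformer_engine is installed:

import transformer_engine.pytorch as te

# Modules built inside fp8_model_init() keep their primary weights in FP8,
# which is the __init__ code path whose ordering this commit fixes.
with te.fp8_model_init(enabled=True):
    linear = te.Linear(768, 768, bias=True)

print(type(linear.weight))  # an FP8 weight tensor, created after fp8_meta init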
