fixed Float8Tensor creation with deferred init, all tests passing locally
denera committed Jan 12, 2024
1 parent 2f225cf commit cb055e5
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/_common.py
@@ -223,7 +223,7 @@ def init_as_weight(self, param: torch.Tensor, set_tp_attributes: bool = False) -
         if FP8GlobalStateManager.with_fp8_parameters():
             self.parent.init_fp8_metadata()
             self.parent.fp8_meta["update_amax_and_scale_fwd"] = True
-            param = Float8Tensor(
+            param = Float8Tensor.to_float8(
                 param,
                 fp8_meta=self.parent.fp8_meta,
                 fp8_meta_index=self.fp8_meta_index
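
For context on the _common.py change: Float8Tensor.to_float8() casts an existing higher-precision tensor into FP8, while constructing Float8Tensor directly expects already-quantized data, so the deferred-init path needs the classmethod. Below is a minimal sketch of that cast, assuming Transformer Engine is installed and a CUDA device is available; the import path and the reliance on default scaling arguments are assumptions, not taken from this diff.

    # Hedged sketch: assumes transformer_engine is installed and that
    # Float8Tensor.to_float8() works with its default scaling arguments.
    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    weight = torch.randn(256, 256, device="cuda")

    # Cast a high-precision tensor to FP8 via the classmethod used by this commit,
    # rather than wrapping it with the Float8Tensor constructor.
    fp8_weight = Float8Tensor.to_float8(weight)
    print(type(fp8_weight))
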
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -768,7 +768,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
         for name, param in self.named_parameters(recurse=False):
             # Ensure parameter is on a real device
             if param.device == torch.device('meta'):
-                param.to(device='cuda')
+                param = param.to(device='cuda')
 
             if 'weight' in name:
                 # Initialize weight values on device
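
A note on the hunk above: torch.Tensor.to() is out-of-place, so the original call discarded its result and the parameter stayed on the meta device. The sketch below shows the same behavior with a CPU-only dtype conversion so it runs without a GPU.

    import torch

    t = torch.zeros(4)        # float32 tensor
    t.to(torch.float64)       # returns a converted copy; the result is discarded
    print(t.dtype)            # torch.float32 -- t itself is unchanged
    t = t.to(torch.float64)   # reassignment keeps the converted tensor
    print(t.dtype)            # torch.float64
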
@@ -778,7 +778,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                 param = self.param_init_meta[name].init_as_bias(param)
 
             # Redo parameter wrap in case we broke it above
-            param = torch.nn.Parameter(param)
+            setattr(self, name, torch.nn.Parameter(param))
 
     @abstractmethod
     def forward(self):
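
And for the second base.py hunk: rebinding the loop variable param only changes a local name; the parameter registered on the module is untouched, which is why the fix re-registers it with setattr(). A minimal sketch with a toy module (the module and function names are illustrative, not from Transformer Engine):

    import torch

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.zeros(2))

    def broken_reset(module):
        for name, param in module.named_parameters(recurse=False):
            param = torch.nn.Parameter(torch.ones_like(param))  # rebinds the local only

    def fixed_reset(module):
        for name, param in module.named_parameters(recurse=False):
            setattr(module, name, torch.nn.Parameter(torch.ones_like(param)))  # replaces the registered parameter

    m = Toy()
    broken_reset(m)
    print(m.weight.detach())  # tensor([0., 0.]) -- unchanged
    fixed_reset(m)
    print(m.weight.detach())  # tensor([1., 1.]) -- replaced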
