From d307849bcf2c875b73763e8a772c4aad7bd246f4 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 17 Jan 2024 10:18:38 -0800 Subject: [PATCH] Avoid using torch.compile for roll and fill_ Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 7bec34c861..d4d82cf0be 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -583,7 +583,7 @@ def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: return amax_history -@jit_fuser +@torch.jit.script def _default_get_amax( amax_history: torch.Tensor, amax_compute_algo: str, @@ -625,7 +625,7 @@ def _compute_scaling_factor_inverse( return torch.where(non_weight_mask, 1.0 / scale, scale_inv) -@jit_fuser +@torch.jit.script def _fused_amax_and_scale_update( amax_history: torch.Tensor, scale: torch.Tensor,