From 5904a80246fcd8756e0a31a0c72ca3285b7ca2e6 Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Wed, 29 Jan 2025 17:48:02 -0800
Subject: [PATCH] [PyTorch] Respect existing quantizer usages in functional linear API (#1440)

Respect existing quantizer usages in functional linear API

Signed-off-by: Tim Moon
---
 .../pytorch/ops/basic/basic_linear.py | 22 ++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index 3b4c9579c9..1747877996 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -449,7 +449,7 @@ def _functional_forward(
         if with_quantized_compute and not w_is_quantized:
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
-            weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            weight_quantizer.set_usage(rowwise=True)
             w = weight_quantizer(w)
         elif not with_quantized_compute and w_is_quantized:
             w = w.dequantize()
@@ -666,7 +666,7 @@ def _functional_backward(
             if with_quantized_compute:
                 if input_quantizer is None:
                     raise ValueError("Missing quantizer for input tensor")
-                input_quantizer.set_usage(rowwise=True, columnwise=True)
+                input_quantizer.set_usage(columnwise=True)
                 if with_x_all_gather:
                     x, x_async = gather_along_first_dim(
                         x_local,
@@ -705,7 +705,7 @@
             if with_quantized_compute and not w_is_quantized:
                 if weight_quantizer is None:
                     raise ValueError("Missing quantizer for weight tensor")
-                weight_quantizer.set_usage(rowwise=True, columnwise=True)
+                weight_quantizer.set_usage(columnwise=True)
                 w = weight_quantizer(w)
             elif not with_quantized_compute and w_is_quantized:
                 w = w.dequantize()
@@ -833,6 +833,10 @@ def op_forward(
         next_op: Optional[BasicOperation] = None,
     ) -> torch.Tensor:
 
+        # Check which grads are required
+        input_requires_grad = ctx.requires_grad and input_.requires_grad
+        weight_requires_grad = ctx.requires_grad and self.weight.requires_grad
+
         # FP8 metadata
         with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
         input_quantizer = None
@@ -841,6 +845,8 @@
         grad_output_quantizer = None
         grad_input_quantizer = None
         if with_quantized_compute:
+
+            # Get quantizers
             input_quantizer = self.get_quantizer("forward", 0)
             weight_quantizer = self.get_quantizer("forward", 1)
             if next_op is not None and next_op.num_quantizers("forward") > 0:
@@ -849,6 +855,12 @@
             if prev_op is not None and prev_op.num_quantizers("backward") > 0:
                 grad_input_quantizer = prev_op.get_quantizer("backward", 0)
 
+            # Configure quantizers
+            # Note: We cache the quantized input for backward pass,
+            # but discard the quantized weights.
+            input_quantizer.set_usage(columnwise=weight_requires_grad)
+            weight_quantizer.set_usage(columnwise=False)
+
         # Get autocast dtype if needed
         dtype = None
         if torch.is_autocast_enabled():
@@ -876,8 +888,8 @@
         ctx.grad_output_quantizer = grad_output_quantizer
         ctx.grad_input_quantizer = grad_input_quantizer
         ctx.dtype = dtype
-        ctx.input_requires_grad = input_.requires_grad
-        ctx.weight_requires_grad = self.weight.requires_grad
+        ctx.input_requires_grad = input_requires_grad
+        ctx.weight_requires_grad = weight_requires_grad
        ctx.has_prev_op = prev_op is not None
         return output
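
Illustrative note: the change relies on Quantizer.set_usage treating an omitted keyword argument as "leave this usage unchanged", so dropping the explicit rowwise=True / columnwise=False arguments preserves any usage that a neighboring op has already requested. The sketch below uses a hypothetical ToyQuantizer stand-in, not Transformer Engine's actual quantizer class, to show the assumed behavior:

from typing import Optional


class ToyQuantizer:
    """Hypothetical stand-in for a quantizer with row-wise/column-wise usage flags."""

    def __init__(self) -> None:
        self.rowwise_usage = True
        self.columnwise_usage = False

    def set_usage(
        self,
        rowwise: Optional[bool] = None,
        columnwise: Optional[bool] = None,
    ) -> None:
        # Omitted (None) arguments leave the existing usage untouched.
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise


quantizer = ToyQuantizer()
quantizer.set_usage(columnwise=True)                 # requested elsewhere, e.g. for the wgrad pass
quantizer.set_usage(rowwise=True)                    # patched call: column-wise usage survives
assert quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)  # pre-patch call: clobbers it
assert not quantizer.columnwise_usage

Under this behavior, a usage configured once in op_forward (for example input_quantizer.set_usage(columnwise=weight_requires_grad)) is not overridden by the later functional-API calls.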