
[PyTorch/C++] Comm+GEMM overlap compatibility with QuantizedTensor #1427

Merged

Conversation

Collaborator

@denera denera commented Jan 28, 2025

Description

This PR updates TE/common and TE/PyTorch API for comm+GEMM overlap to support the new QuantizedTensor abstraction.
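
For context, here is a minimal sketch of how FP8 comm+GEMM (userbuffers) overlap is typically driven from the PyTorch API this PR touches. It is not code from the PR; the entry points shown (initialize_ub, ub_overlap_rs, ub_name) and their exact signatures are assumptions that may differ between TE versions.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Assumed setup: torchrun with one rank per GPU, NCCL backend.
torch.distributed.init_process_group(backend="nccl")
tp_group = torch.distributed.new_group()  # tensor-parallel group (all ranks here)
tp_size = torch.distributed.get_world_size(tp_group)

seq_len, batch, hidden = 2048, 2, 4096

# Allocate the userbuffers workspace that the overlapped all-gather /
# reduce-scatter collectives operate on (name/signature assumed).
te.initialize_ub([seq_len * batch, hidden], tp_size, use_fp8=True, dtype=torch.bfloat16)

# Row-parallel projection whose reduce-scatter is overlapped with the GEMM.
proj = te.Linear(
    hidden, hidden,
    parallel_mode="row",
    tp_group=tp_group,
    sequence_parallel=True,
    ub_overlap_rs=True,
    ub_name="proj",
    params_dtype=torch.bfloat16,
    device="cuda",
)

# Row-parallel input: the feature dimension is split across the TP group.
x = torch.randn(seq_len * batch, hidden // tp_size, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y = proj(x)  # output is reduce-scattered along the sequence dimension
```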

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • [x] I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera added the 2.0.0 label Jan 28, 2025
@denera denera requested review from timmoon10 and ptrendx January 28, 2025 01:56
@denera denera self-assigned this Jan 28, 2025
…th cppqtensor

Signed-off-by: Alp Dener <adener@nvidia.com>

CommOverlap objects can now return overlap buffers to PyTorch as QuantizedTensors

Signed-off-by: Alp Dener <adener@nvidia.com>

updated comm+GEMM overlap test for pure GEMM, both BF16 and FP8 working with QuantizedTensor

Signed-off-by: Alp Dener <adener@nvidia.com>

te.Linear and te.LayerNormMLP updated for TP overlap w/ QuantizedTensor. All overlaps work in BF16. All overlaps except bulk WGRAD work in FP8.

Signed-off-by: Alp Dener <adener@nvidia.com>

completed TP overlap QuantizedTensor updates for LayerNormLinear, but issues with quantized normalization

Signed-off-by: Alp Dener <adener@nvidia.com>

all overlaps working with bf16, all but bulk WGRAD working with FP8

Signed-off-by: Alp Dener <adener@nvidia.com>

all overlaps work with Float8Tensor, except bulk wgrad in LayerNormMLP (works in other modules)

Signed-off-by: Alp Dener <adener@nvidia.com>

all overlaps working with QuantizedTensor in BF16 and FP8

Signed-off-by: Alp Dener <adener@nvidia.com>

cleaned up pytest formatting

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the blackwell-cppqtensor-tp-overlap-v2.0 branch from 9ba5009 to f1dcf35 on January 28, 2025 02:38
pre-commit-ci bot and others added 2 commits January 28, 2025 02:38
…and updated test sizing

Signed-off-by: Alp Dener <adener@nvidia.com>
@ksivaman ksivaman self-requested a review January 28, 2025 20:42
Comment on lines 142 to 150
# Configure quantizer for normalization output
if fp8 and input_quantizer is None:
    raise ValueError("Missing quantizer for input tensor")
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")

    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
Collaborator

  • Putting UB logic here makes the comment incorrect
  • This won't generalize when we add more quantization schemes. Instead of assuming that all recipes except MXFP8 support UB, we should only assume that FP8 delayed scaling supports UB (see the sketch after the suggested change below).
Suggested change
-# Configure quantizer for normalization output
-if fp8 and input_quantizer is None:
-    raise ValueError("Missing quantizer for input tensor")
-if fp8:
-    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
-        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
-    ):
-        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
-    if input_quantizer is None:
-        raise ValueError("Missing quantizer for input tensor")
+# Check if overlapped communication is supported
+if (
+    fp8
+    and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
+    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+):
+    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
+# Configure quantizer for normalization output
+if fp8:
+    if input_quantizer is None:
+        raise ValueError("Missing quantizer for input tensor")

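To make the second bullet concrete, here is a minimal sketch of the kind of centralized check being suggested: gate userbuffers (UB) overlap on recipes explicitly known to support it, rather than excluding MXFP8 by name. The helper names and import path below are illustrative assumptions, not code from this PR; the recipe test mirrors the `.delayed()` check used in the suggested change above.

```python
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager


def _ub_overlap_supported() -> bool:
    """Hypothetical helper: True only for recipes known to support
    comm+GEMM (userbuffers) overlap -- currently FP8 delayed scaling."""
    return FP8GlobalStateManager.get_fp8_recipe().delayed()


def check_ub_overlap(fp8: bool, ub_overlap_requested: bool) -> None:
    """Raise if UB overlap is requested under an unsupported recipe."""
    # New quantization schemes remain unsupported until they are explicitly
    # added to _ub_overlap_supported(), instead of being allowed by default.
    if fp8 and ub_overlap_requested and not _ub_overlap_supported():
        raise NotImplementedError(
            "Comm+GEMM overlap is only supported with FP8 delayed scaling"
        )
```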
transformer_engine/pytorch/module/layernorm_linear.py (outdated review thread, resolved)
transformer_engine/pytorch/module/layernorm_linear.py (outdated review thread, resolved)
@timmoon10 timmoon10 self-requested a review January 28, 2025 22:16
@timmoon10 timmoon10 self-requested a review January 28, 2025 23:44
Collaborator

@timmoon10 timmoon10 left a comment


Overall looks reasonable, although I have some stylistic suggestions. This is fine in our last-minute scramble to restore UB support with FP8. Next we will need to think about extending it to support MXFP8 and other quantization schemes.

transformer_engine/pytorch/cpp_extensions/gemm.py (outdated review thread, resolved)
Comment on lines 112 to 125
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
inputmat = inp
inputmat_total = None
with_input_all_gather = parallel_mode == "column" and sequence_parallel
with_input_all_gather_nccl = (
    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")

Collaborator

Same as #1427 (comment):

Suggested change
-# Prepare input tensor
-# Note: Cast to expected dtype and perform tensor-parallel communication
-inputmat = inp
-inputmat_total = None
-with_input_all_gather = parallel_mode == "column" and sequence_parallel
-with_input_all_gather_nccl = (
-    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
-)
-own_quantized_input = False
-if fp8:
-    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
-        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
-    ):
-        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+# Check if overlapped communication is supported
+if (
+    fp8
+    and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
+    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+):
+    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
+# Prepare input tensor
+# Note: Cast to expected dtype and perform tensor-parallel communication
+inputmat = inp
+inputmat_total = None
+with_input_all_gather_nccl = (
+    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
+)
+own_quantized_input = False
+if fp8:

transformer_engine/pytorch/module/linear.py (outdated review thread, resolved)
tests/pytorch/distributed/run_gemm_with_overlap.py (outdated review thread, resolved)
transformer_engine/pytorch/module/layernorm_mlp.py (outdated review thread, resolved)
transformer_engine/pytorch/module/layernorm_mlp.py (outdated review thread, resolved)
transformer_engine/pytorch/module/layernorm_mlp.py (outdated review thread, resolved)
denera and others added 4 commits January 30, 2025 07:06
…ests

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member

/te-ci pytorch

ksivaman and others added 9 commits February 3, 2025 04:08
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Member

@ksivaman ksivaman left a comment


LGTM pending CI; 23439789

@timmoon10 timmoon10 merged commit d715c83 into NVIDIA:release_v2.0 Feb 4, 2025
11 checks passed