[PyTorch/C++] Comm+GEMM overlap compatibility with QuantizedTensor #1427
Conversation
…th cppqtensor
  Signed-off-by: Alp Dener <adener@nvidia.com>
CommOverlap objects can now return overlap buffers to PyTorch as QuantizedTensors
  Signed-off-by: Alp Dener <adener@nvidia.com>
updated comm+GEMM overlap test for pure GEMM, both BF16 and FP8 working with QuantizedTensor
  Signed-off-by: Alp Dener <adener@nvidia.com>
te.Linear and te.LayerNormMLP updated for TP overlap w/ QuantizedTensor. All overlaps work in BF16. All overlaps except bulk WGRAD work in FP8.
  Signed-off-by: Alp Dener <adener@nvidia.com>
completed TP overlap QuantizedTensor updates for LayerNormLinear, but issues with quantized normalization
  Signed-off-by: Alp Dener <adener@nvidia.com>
all overlaps working with bf16, all but bulk WGRAD working with FP8
  Signed-off-by: Alp Dener <adener@nvidia.com>
all overlaps work with Float8Tensor, except bulk wgrad in LayerNormMLP (works in other modules)
  Signed-off-by: Alp Dener <adener@nvidia.com>
all overlaps working with QuantizedTensor in BF16 and FP8
  Signed-off-by: Alp Dener <adener@nvidia.com>
cleaned up pytest formatting
  Signed-off-by: Alp Dener <adener@nvidia.com>
9ba5009 to f1dcf35
for more information, see https://pre-commit.ci
…and updated test sizing Signed-off-by: Alp Dener <adener@nvidia.com>
# Configure quantizer for normalization output
if fp8 and input_quantizer is None:
    raise ValueError("Missing quantizer for input tensor")
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")

    if input_quantizer is None:
        raise ValueError("Missing quantizer for input tensor")
- Putting UB logic here makes the comment incorrect
- This won't generalize when we add more quantization schemes. Instead of assuming that all recipes except MXFP8 support UB, we should only assume FP8 delayed scaling supports UB.
Suggested change:

-# Configure quantizer for normalization output
-if fp8 and input_quantizer is None:
-    raise ValueError("Missing quantizer for input tensor")
-if fp8:
-    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
-        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
-    ):
-        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
-    if input_quantizer is None:
-        raise ValueError("Missing quantizer for input tensor")
+# Check if overlapped communication is supported
+if (
+    fp8
+    and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
+    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+):
+    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
+# Configure quantizer for normalization output
+if fp8:
+    if input_quantizer is None:
+        raise ValueError("Missing quantizer for input tensor")
Overall looks reasonable, although I have some stylistic suggestions. This is fine in our last-minute scramble to restore UB support with FP8. Next we will need to think about extending it to support MXFP8 and other quantization schemes.
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (outdated comment, resolved)
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
inputmat = inp
inputmat_total = None
with_input_all_gather = parallel_mode == "column" and sequence_parallel
with_input_all_gather_nccl = (
    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
if fp8:
    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
    ):
        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
Same as #1427 (comment):
Suggested change:

-# Prepare input tensor
-# Note: Cast to expected dtype and perform tensor-parallel communication
-inputmat = inp
-inputmat_total = None
-with_input_all_gather = parallel_mode == "column" and sequence_parallel
-with_input_all_gather_nccl = (
-    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
-)
-own_quantized_input = False
-if fp8:
-    if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and isinstance(
-        FP8GlobalStateManager.get_fp8_recipe(), BlockScaling
-    ):
-        raise NotImplementedError("Comm+GEMM overlap does not support MXFP8 block scaling")
+# Check if overlapped communication is supported
+if (
+    fp8
+    and (ub_overlap_ag_fprop or ub_overlap_rs_fprop)
+    and not FP8GlobalStateManager.get_fp8_recipe().delayed()
+):
+    raise NotImplementedError("Comm+GEMM overlap is only supported with FP8 delayed scaling")
+# Prepare input tensor
+# Note: Cast to expected dtype and perform tensor-parallel communication
+inputmat = inp
+inputmat_total = None
+with_input_all_gather_nccl = (
+    parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
+)
+own_quantized_input = False
+if fp8:
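To spell out the gating that the suggestion keeps intact: the sequence-parallel input all-gather for fprop is either issued as a standalone NCCL collective or folded into the comm+GEMM overlap, depending on ub_overlap_ag_fprop. A small illustrative sketch (the helper name is hypothetical, not part of the PR):

def choose_input_gather(parallel_mode: str, sequence_parallel: bool, ub_overlap_ag_fprop: bool) -> str:
    """Return which mechanism gathers the sequence-parallel input before the fprop GEMM."""
    if parallel_mode == "column" and sequence_parallel:
        # With UB overlap enabled, the all-gather is fused with the GEMM via
        # userbuffers; otherwise a separate NCCL all-gather runs first.
        return "ub-overlap" if ub_overlap_ag_fprop else "nccl-all-gather"
    return "none"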
…ests Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci pytorch
…ppqtensor-tp-overlap-v2.0
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
LGTM pending CI; 23439789
Description
This PR updates the TE/common and TE/PyTorch APIs for comm+GEMM overlap to support the new QuantizedTensor abstraction.
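For context, a rough usage sketch of the feature this PR restores: userbuffers-based comm+GEMM overlap on a tensor-parallel te.Linear under FP8 delayed scaling. This is illustrative only; the entry point (te.module.base.initialize_ub) and module kwargs (ub_overlap_ag, ub_name) are recalled from the TE PyTorch API and may differ slightly from this branch, and the shapes are arbitrary.

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Assumes torchrun launched one process per GPU in the tensor-parallel group.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
tp_group = dist.new_group(backend="nccl")
tp_size = dist.get_world_size(tp_group)

seq_len, batch_size, hidden, ffn_hidden = 2048, 2, 4096, 16384

# Allocate the userbuffers workspaces backing the overlapped collectives.
te.module.base.initialize_ub(
    [seq_len * batch_size, hidden], tp_size, use_fp8=True, dtype=torch.bfloat16
)

layer = te.Linear(
    hidden,
    ffn_hidden,
    parallel_mode="column",
    sequence_parallel=True,
    tp_group=tp_group,
    ub_overlap_ag=True,  # overlap the input all-gather with the fprop GEMM
    ub_name="fc1",
    params_dtype=torch.bfloat16,
)

x = torch.randn(
    seq_len * batch_size // tp_size, hidden,
    dtype=torch.bfloat16, device="cuda", requires_grad=True,
)
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    y = layer(x)
y.sum().backward()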
Type of change
Checklist: