Implement fused kernel for FP8 scale update #593

Merged · 19 commits merged into NVIDIA:main · Feb 8, 2024

Conversation

timmoon10 (Collaborator)

Running GPT 175B, I've found that the forward pass is often bottlenecked by GPU kernel launch overheads. Profiling the Python code shows that ~20% of the Transformer layer forward pass is spent in amax_and_scale_update (compared to 9% spent launching GEMM kernels). This function is called in every forward pass of Linear, LayerNormLinear, and LayerNormMLP, and each call involves ~10 small GPU operations. nvFuser and torch.compile do fuse some of these operations, leading to some improvement in GPU runtime, but the extra CPU overhead results in somewhat worse performance in the CPU-bound case.
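
For context, a rough sketch of the per-call work (a simplified illustration of the delayed-scaling recipe; the function name, tensor layout, and scale formula here are illustrative, not the exact Transformer Engine code). Each line is roughly one small GPU kernel, which is where the launch overhead comes from:

```python
import torch

def amax_and_scale_update_sketch(amax_history: torch.Tensor,
                                 scale: torch.Tensor,
                                 scale_inv: torch.Tensor,
                                 fp8_max: float,
                                 margin: int = 0) -> None:
    """Simplified delayed-scaling update; each line below launches roughly one small GPU kernel."""
    amax = amax_history.amax(dim=0)                           # reduce over the amax history window
    amax_history.copy_(torch.roll(amax_history, -1, dims=0))  # advance the history buffer
    amax_history[0].zero_()                                   # clear the slot for the next iteration's amax
    exp = torch.floor(torch.log2(fp8_max / amax)) - margin    # power-of-two scaling exponent
    new_scale = torch.pow(2.0, exp)
    new_scale = torch.where(torch.isfinite(new_scale), new_scale, scale)  # keep old scale if amax is 0/inf
    scale.copy_(new_scale)
    scale_inv.copy_(new_scale.reciprocal())
```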

This PR is an experiment with using a hand-written fused kernel to reduce these overheads. Alternative approaches:

  • Only update scaling factors once per training step instead of once per microbatch
  • Batch the scale updates for all layers together (see the sketch below)
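
For the second alternative, the idea would be to keep the FP8 metadata for every layer in shared, stacked buffers so a single update covers the whole model. A minimal sketch under that assumption (hypothetical buffer layout, not something this PR implements), reusing amax_and_scale_update_sketch from above:

```python
import torch

# Example/assumed sizes; in practice these would come from the model config.
num_layers, history_len = 96, 1024
num_tensors = 3 * num_layers  # e.g. input/weight/grad scales per layer (assumption)

amax_history = torch.zeros(history_len, num_tensors, device="cuda")
scale = torch.ones(num_tensors, device="cuda")
scale_inv = torch.ones(num_tensors, device="cuda")

# One call updates every layer's scaling factors at once, so the kernel-launch
# cost no longer grows with model depth (448.0 is the FP8 E4M3 max value).
amax_and_scale_update_sketch(amax_history, scale, scale_inv, fp8_max=448.0)
```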

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 added the enhancement label on Jan 6, 2024
@timmoon10 (Collaborator, Author)

/te-ci pytorch

@yaox12 (Collaborator) commented Jan 9, 2024

From my profiling, _default_get_amax is not fused well either. Can we follow Paddle's

tex.amax_and_scale_update_inplace(_amax_history=fp8_meta[fp8_meta_tensor_key].amax_history,

to fuse the whole amax_and_scale_update?

@ptrendx (Member) commented Jan 9, 2024

Right, instead of writing another kernel in the framework-aware portion, let's add the Paddle kernel to the common part and use it in both places.

Add unit test.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 (Collaborator, Author) commented Jan 10, 2024

One thing that'll make this tricky is that the Paddle kernel is in-place. The PyTorch kernel includes the amax history roll, so making that in-place would make the kernel much more complicated than it's worth. Perhaps we should change the Paddle implementation to be out-of-place?
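
To illustrate the concern, a rough sketch of the two semantics (illustrative PyTorch-level code, not the actual kernels):

```python
import torch

def roll_history_out_of_place(amax_history: torch.Tensor) -> torch.Tensor:
    # Straightforward to fuse: results are written to a fresh buffer.
    new_history = torch.roll(amax_history, -1, dims=0)
    new_history[0].zero_()
    return new_history

def roll_history_in_place(amax_history: torch.Tensor) -> None:
    # Same semantics, but updating the caller's buffer. At the PyTorch level
    # torch.roll still materializes a temporary; inside a single fused GPU
    # kernel, each thread must read its source entry before any other thread
    # overwrites it, which is what makes the in-place version trickier.
    amax_history.copy_(torch.roll(amax_history, -1, dims=0))
    amax_history[0].zero_()
```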

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 changed the title from "[PyTorch] Implement fused kernel for FP8 scale update" to "Implement fused kernel for FP8 scale update" on Jan 26, 2024
@timmoon10 marked this pull request as ready for review January 26, 2024 23:36
@timmoon10 (Collaborator, Author) commented Jan 26, 2024

I've moved the fused kernel to the core C++ library and modified the PyTorch and Paddle extensions so they both use it. It turns out that making the kernel work in-place was not difficult.

This includes the Paddle bugfix from #633.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 (Collaborator, Author)

/te-ci

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 force-pushed the fused-fp8-scale-update branch from 0d4a3cd to 58b9a23 on January 29, 2024 18:45
@timmoon10 (Collaborator, Author)

/te-ci

@timmoon10 requested a review from ksivaman on January 29, 2024 23:10
@timmoon10 (Collaborator, Author)

/te-ci

@ptrendx added the 1.4.0 label on Jan 30, 2024
@timmoon10 (Collaborator, Author)

The PyTorch and JAX tests are passing in pipeline 12465152, and the Paddle tests are failing due to a problem with cuDNN v9 in the upstream Paddle container. The Paddle tests pass when I use an older Paddle container in pipeline 12474691. This is ready to merge.

@ksivaman (Member) commented Feb 3, 2024

/te-ci

@ksivaman (Member) commented Feb 3, 2024

Have we tested e2e numerics for this change against the previous versions? @timmoon10

@timmoon10 (Collaborator, Author)

The tests pass, but I haven't tried full-scale training runs.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 force-pushed the fused-fp8-scale-update branch from d54bbbe to c597131 on February 8, 2024 01:30
@timmoon10 (Collaborator, Author)

@ksivaman The numerics are tested thoroughly in https://github.com/NVIDIA/TransformerEngine/blob/7fc00c0819b3249d1abbb935719a81202d097e9d/tests/pytorch/test_recipe.py, so I think full-scale convergence runs would be overly cautious.

@timmoon10 (Collaborator, Author)

/te-ci

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10 (Collaborator, Author)

After internal discussion with @ptrendx and @ksivaman, there are no objections to merging without review.

@timmoon10 merged commit a950061 into NVIDIA:main on Feb 8, 2024
9 checks passed
@timmoon10 deleted the fused-fp8-scale-update branch on April 25, 2024 18:28
Labels: 1.4.0, enhancement