
[PyTorch] Stop storing fused weight tensor in linear modules #719

Merged
11 commits merged into NVIDIA:main on Apr 19, 2024

Conversation

timmoon10
Collaborator

We've encountered excessive memory usage in FSDP-like workflows because Linear and LayerNormLinear sometimes store a tensor for the fused weights. The weight params may be manipulated externally, e.g. deallocated or turned into views into an all-gather buffer, but the modules hold on to the original buffers and prevent any memory savings.

This PR updates the _noop_cat utility function so that it no longer requires the full output tensor, but instead determines whether a copy is needed by inspecting the params' data pointers and strides. There is one functional difference: Linear and LayerNormLinear no longer attempt to readjust the params if they are not contiguous, but simply concatenate them as needed. In any case, if FSDP is misconfigured it will wipe out the existing buffers and create new, misaligned weights on every forward pass, so there is no realistic way for TE to repair the situation.

This is related to #570, which removed the fused weight tensors in cases without split weight params.
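For illustration, here is a rough sketch of the idea behind the pointer-and-stride check. The function name, dimension handling, and exact checks are mine, not the actual Transformer Engine implementation: if the split params are already adjacent, contiguous chunks of one buffer, the fused tensor can be exposed as a view with no copy; otherwise fall back to a plain torch.cat.

import torch

def noop_cat_sketch(tensors):
    # Illustrative only: check whether the split params are already
    # contiguous chunks of one allocation, with matching dtype and device.
    first = tensors[0]
    fused_ok = all(
        t.is_contiguous()
        and t.dtype == first.dtype
        and t.device == first.device
        and t.untyped_storage().data_ptr() == first.untyped_storage().data_ptr()
        for t in tensors
    )
    if fused_ok:
        # Check that the chunks sit back-to-back in memory.
        ptr = first.data_ptr()
        for t in tensors:
            if t.data_ptr() != ptr:
                fused_ok = False
                break
            ptr += t.numel() * t.element_size()
    if fused_ok:
        # Params already form one contiguous buffer: expose the fused tensor
        # as a strided view, without allocating new memory.
        fused_shape = (sum(t.size(0) for t in tensors),) + tuple(first.shape[1:])
        return first.as_strided(fused_shape, first.stride())
    # Params are misaligned or separately allocated: copy into a new buffer.
    return torch.cat(tensors, dim=0)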

Stop storing fused buffers in linear modules.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the enhancement New feature or request label Mar 13, 2024
@timmoon10 timmoon10 requested a review from ksivaman March 13, 2024 22:32
@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 timmoon10 marked this pull request as draft March 14, 2024 20:46
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

timmoon10 commented Mar 14, 2024

This PR has exposed a bug in our numeric tests. Previously, Linear and LayerNormLinear would hold onto the fused weight tensor and would "repair" the params by making sure they were views of that tensor. However, most of the tests in test_numerics.py initialize TE modules like:

block = te.Linear(...).to(dtype=dtype).cuda()

Since the module was initialized in FP32, it would "repair" FP16/BF16 params by turning them back to FP32. No wonder we were able to achieve such tight numerical tolerances.

I see the existing implementation reallocates the fused weight tensor correctly:

full_tensor.data = torch.cat(tensors)

It is strange to me that the tests were previously passing with bit-wise exact accuracy though. I wonder if initializing the weights directly in FP16/BF16 instead of initializing in FP32 and casting resulted in tensors that were not bit-wise identical.
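One way to construct the modules directly in the target dtype up front, rather than casting after construction, is sketched below. The params_dtype and device constructor arguments and the sizes are assumptions for illustration, not necessarily how the tests in this PR do it.

import torch
import transformer_engine.pytorch as te

dtype = torch.bfloat16

# Problematic pattern: params start in FP32 and are cast afterwards. Before
# this PR, the module could silently "repair" them back into FP32 views of
# the fused weight tensor.
block = te.Linear(1024, 1024).to(dtype=dtype).cuda()

# Constructing in the target dtype avoids the cast entirely
# (params_dtype/device arguments assumed here for illustration).
block = te.Linear(1024, 1024, params_dtype=dtype, device="cuda")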

@timmoon10 timmoon10 marked this pull request as ready for review March 15, 2024 03:42
@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 timmoon10 added testing Improvements to tests or testing infrastructure bug Something isn't working labels Mar 15, 2024
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

/te-ci pytorch

@ksivaman
Member

/te-ci pytorch

@timmoon10
Collaborator Author

The failed test is unrelated (test_layernorm_accuracy) and it passed after rerunning.

@deepakn94
Contributor

Anything holding up this being merged?

@ksivaman
Member

/te-ci pytorch

Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10
Collaborator Author

/te-ci pytorch

@timmoon10 timmoon10 merged commit 2a0fe78 into NVIDIA:main Apr 19, 2024
19 of 20 checks passed
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
[PyTorch] Stop storing fused weight tensor in linear modules (#719)

* Support noop concat without providing full tensor

Stop storing fused buffers in linear modules.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Debug noop cat func

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Construct TE modules in tests with correct dtypes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tolerances to numerical tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use plain PyTorch concat when exporting to ONNX

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 16, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 23, 2024