[PyTorch] Stop storing fused weight tensor in linear modules #719
Conversation
Stop storing fused buffers in linear modules. Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
block = te.Linear(...).to(dtype=dtype).cuda()
I see the existing implementation reallocates the fused weight tensor correctly:
It is strange to me that the tests were previously passing with bit-wise exact accuracy though. I wonder if initializing the weights directly in FP16/BF16 instead of initializing in FP32 and casting resulted in tensors that were not bit-wise identical.
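For context, here is a minimal sketch of the two construction paths being compared, assuming `params_dtype` is the TE option that sets the parameter dtype; the exact initialization order inside TE may differ:

```python
import torch
import transformer_engine.pytorch as te

# Construct the module directly in the target dtype, as the updated tests do.
torch.manual_seed(0)
native = te.Linear(128, 128, params_dtype=torch.bfloat16).cuda()

# Previous test pattern: initialize in FP32, then cast the whole module.
torch.manual_seed(0)
cast = te.Linear(128, 128).to(dtype=torch.bfloat16).cuda()

# The parameters are not guaranteed to be bit-wise identical, since rounding
# to BF16 can happen at different points (during init vs. once when casting).
for p_native, p_cast in zip(native.parameters(), cast.parameters()):
    print(torch.equal(p_native.data, p_cast.data))
```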
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
/te-ci pytorch
The failed test is unrelated.
Anything holding up this being merged?
/te-ci pytorch
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
/te-ci pytorch
/te-ci pytorch
* Support noop concat without providing full tensor: stop storing fused buffers in linear modules.
* Debug noop cat func
* Construct TE modules in tests with correct dtypes
* Add tolerances to numerical tests
* Use plain PyTorch concat when exporting to ONNX

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
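One of the listed changes, falling back to a plain PyTorch concat during ONNX export, could look roughly like the sketch below; the helper name and structure are illustrative, not the actual TE code:

```python
import torch

def cat_param_views(tensors, dim=0):
    # While tracing for ONNX export, skip any pointer/stride inspection and
    # emit an ordinary concat so the exported graph stays simple.
    if torch.onnx.is_in_onnx_export():
        return torch.cat(tensors, dim=dim)
    # Otherwise, a fused-buffer check (see the PR description below) could be
    # used to skip the copy when the inputs already form one contiguous buffer.
    return torch.cat(tensors, dim=dim)
```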
We've encountered excessive memory usage with FSDP-like workflows because `Linear` and `LayerNormLinear` sometimes store a tensor for fused weights. The weight params may be manipulated externally, e.g. to deallocate them or make them views into an all-gather buffer, but the modules hold on to the original buffers and prevent any memory savings. This PR updates the `_noop_cat` utility function so that it no longer requires the full output tensor, but can figure things out by inspecting the pointers and strides. There is one functional difference: `Linear` and `LayerNormLinear` no longer make any attempt to readjust the params if they are not contiguous, but will just concatenate as needed. Besides, if FSDP is misconfigured it will just wipe out the existing buffers and create new misaligned weights in every forward pass, so there's no hope that TE can repair the situation.

This is related to #570, which removed the fused weight tensors in cases without split weight params.
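The pointer-and-stride inspection can be illustrated with a small self-contained check. This is a simplified sketch for concatenation along dim 0 only, not the actual `_noop_cat` implementation:

```python
import torch

def is_noop_cat(tensors):
    """Return True if concatenating `tensors` along dim 0 would be a no-op,
    i.e. they are already contiguous, adjacent views of a single buffer."""
    if not all(t.is_contiguous() for t in tensors):
        return False
    expected_ptr = tensors[0].data_ptr()
    for t in tensors:
        if t.data_ptr() != expected_ptr:
            return False
        expected_ptr += t.numel() * t.element_size()
    return True

# Views produced by splitting one buffer need no concat; fresh copies do.
full = torch.zeros(6, 4)
parts = torch.split(full, [2, 4], dim=0)
print(is_noop_cat(parts))                       # True: adjacent views of one buffer
print(is_noop_cat([p.clone() for p in parts]))  # typically False: separate buffers
```

When such a check passes, the module can keep using the params' existing storage as the fused weight instead of allocating and holding a separate buffer, which is what lets FSDP-style workflows actually reclaim the memory.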