[PyTorch] Fix backward compatibility with checkpoint API #740

Merged

Member

@ksivaman ksivaman commented Mar 28, 2024

Makes the new API backward compatible with the previous version. In addition, PyTorch already checks the requires_grad field during the backward call. Checking it in the forward pass requires user intervention, since users would need to call the vanilla forward during eval steps.
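
For readers unfamiliar with this kind of fix, here is a minimal sketch of one way a wrapper can accept both a keyword-based call and an older positional call. It is not the code from this PR: the names `checkpoint_compat` and `split_checkpoint_args` are hypothetical, the option names are assumptions made for illustration, and the sketch calls the function directly instead of doing any activation recomputation.

```python
import warnings
from typing import Any, Callable, Dict, Tuple


def split_checkpoint_args(
    args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], Tuple[Any, ...], Dict[str, Any]]:
    """Separate wrapper options from the arguments meant for the wrapped function.

    Hypothetical helper; the option names below are assumptions.
    """
    kwargs = dict(kwargs)  # copy so the caller's dict is not mutated
    # New style: options arrive as keyword arguments.
    options = {
        "distribute_saved_activations": kwargs.pop("distribute_saved_activations", False),
        "get_rng_state_tracker": kwargs.pop("get_rng_state_tracker", None),
        "tp_group": kwargs.pop("tp_group", None),
    }
    # Old style (assumed): options arrive as the first three positionals.
    # This is a heuristic; a wrapped function whose own first arguments are
    # (bool, callable, anything) would be misread as a legacy call.
    if len(args) >= 3 and isinstance(args[0], bool) and callable(args[1]):
        warnings.warn(
            "Passing options positionally is deprecated; use keyword arguments.",
            DeprecationWarning,
            stacklevel=3,
        )
        options["distribute_saved_activations"] = args[0]
        options["get_rng_state_tracker"] = args[1]
        options["tp_group"] = args[2]
        args = args[3:]
    return options, tuple(args), kwargs


def checkpoint_compat(function: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Accept both calling conventions, then run the wrapped function."""
    _options, fn_args, fn_kwargs = split_checkpoint_args(args, kwargs)
    # Placeholder: a real implementation would use the options to run
    # `function` with activation recomputation. Here it is called directly.
    return function(*fn_args, **fn_kwargs)
```

With this shape, existing callers that pass the options positionally keep working and see a `DeprecationWarning`, while new callers pass the same options by keyword.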

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman added the 1.5.0 label Mar 28, 2024
@ksivaman ksivaman requested a review from denera March 28, 2024 04:18
@ksivaman ksivaman self-assigned this Mar 28, 2024
@ksivaman
Member Author

/te-ci pytorch

Collaborator

@denera denera left a comment

LGTM! (except the typo)

transformer_engine/pytorch/distributed.py (outdated, resolved)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman force-pushed the backward_compatible_activation_recompute branch from 4455c0a to 4e1270c Compare March 28, 2024 05:15
@ksivaman
Member Author

/te-ci pytorch

@ksivaman ksivaman merged commit 12cbd86 into NVIDIA:main Mar 29, 2024
20 checks passed
ksivaman added a commit that referenced this pull request Apr 3, 2024
* Fix backward compatibility with checkpoint API

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
* Fix backward compatibility with checkpoint API

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review comments and fix lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
2 participants