[PyTorch] Fix backward compatibility with checkpoint API #740
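
For context, here is a minimal usage sketch of the two call styles this change distinguishes. The module `block`, the input `inp`, and the direct import from `transformer_engine.pytorch.distributed` are illustrative assumptions based on the file touched by this diff, not code taken from the PR itself.

```python
import torch
from transformer_engine.pytorch.distributed import checkpoint

# Illustrative module and input (hypothetical, not from the PR).
block = torch.nn.Linear(16, 16)
inp = torch.randn(4, 16, requires_grad=True)

# Preferred style after this change: non-tensor options are keyword arguments
# that checkpoint() pops out of **kwargs before calling `block`.
out = checkpoint(
    block,
    inp,
    distribute_saved_activations=False,
    get_rng_state_tracker=None,
    tp_group=None,
    use_reentrant=True,
)
out.sum().backward()

# Deprecated positional style that this PR keeps accepting (it now emits a
# DeprecationWarning); note that `get_rng_state_tracker` must be a callable here:
#   checkpoint(block, False, rng_tracker_fn, tp_group, inp)
```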

Merged
Changes from 1 commit
37 changes: 27 additions & 10 deletions transformer_engine/pytorch/distributed.py
@@ -516,13 +516,40 @@ def checkpoint(
kwargs : dict
dictionary of string keys for keyword arguments to :attr:`function`.
"""
only_tensor_args = True
for arg in args:
if not isinstance(arg, torch.Tensor):
only_tensor_args = False
break

# Pop out te.distributed.checkpoint() arguments
global _USE_REENTRANT_ACTIVATION_RECOMPUTE
_USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True)
distribute_saved_activations = kwargs.pop("distribute_saved_activations", False)
tp_group = kwargs.pop("tp_group", None)
get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None)

# Ensure backward compatibility.
if not only_tensor_args:
warnings.warn(
"Passing non-tensor non-keyword arguments in deprecated and support will be removed in "
"future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
DeprecationWarning, stacklevel=2,
)
assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API."
assert (
isinstance(args[0], bool) and callable(args[1])
and isinstance(args[2], None | dist_group_type)
), "Incorrect arguments for deprecated `checkpoint` API."
for arg in args[3:]:
assert (
isinstance(arg, None | torch.Tensor)
), f"Expected tensor argument, found {type(arg)}."

distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]
args = args[3:]

# Trigger the native PyTorch checkpoint if:
# 1. `function` is a `torch.nn.Module`
# AND
@@ -555,16 +582,6 @@ def checkpoint(
assert torch.distributed.is_initialized(), "torch.distributed is not initialized."
tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group

# Make sure at least one tensor input has `requires_grad=True`
input_requires_grad = False
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
input_requires_grad = True
break
assert input_requires_grad, (
"`use_reentrant=True` requires at least one input tensor with `requires_grad=True`."
)

return _CheckpointFunction.apply(
function,
distribute_saved_activations,