Skip to content

Commit

Permalink
review comments and fix lint
Browse files Browse the repository at this point in the history
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
  • Loading branch information
ksivaman committed Mar 28, 2024
1 parent c3b28bb commit 4e1270c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def checkpoint(
# Ensure backward compatibility.
if not only_tensor_args:
warnings.warn(
"Passing non-tensor non-keyword arguments in deprecated and support will be removed in "
"Passing non-tensor non-keyword arguments is deprecated and support will be removed in "
"future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and "
"`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.",
DeprecationWarning, stacklevel=2,
Expand All @@ -547,7 +547,7 @@ def checkpoint(
isinstance(arg, None | torch.Tensor)
), f"Expected tensor argument, found {type(arg)}."

distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3]
distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking
args = args[3:]

# Trigger the native PyTorch checkpoint if:
Expand Down

0 comments on commit 4e1270c

Please sign in to comment.