
Commit

More numerics fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ksivaman committed Apr 10, 2024
1 parent 86e9162 commit 4cbda4a
Showing 3 changed files with 25 additions and 37 deletions.
2 changes: 1 addition & 1 deletion tests/pytorch/test_cuda_graphs.py
@@ -52,7 +52,7 @@ def __init__(self, hidden_size, nheads, kv, seq_len):
 def reset_rng_states() -> None:
     """revert back to initial RNG state."""
     torch.set_rng_state(_cpu_rng_state)
-    _set_cuda_rng_state(_cuda_rng_state)
+    torch.cuda.set_rng_state(_cuda_rng_state)
 
 
 def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
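Note: the test now restores the GPU generator from its saved snapshot with the public torch.cuda.set_rng_state rather than Transformer Engine's _set_cuda_rng_state helper, which below gains a graph-safe default path that expects a different state object. A condensed, self-contained sketch of the save/restore pattern the test relies on (not the actual test file; a CUDA device is required):

import torch

torch.manual_seed(0)

# Capture reference states once, as the test file does at import time.
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()  # plain ByteTensor snapshot


def reset_rng_states() -> None:
    """Revert CPU and GPU RNG back to the captured initial states."""
    torch.set_rng_state(_cpu_rng_state)
    torch.cuda.set_rng_state(_cuda_rng_state)


if __name__ == "__main__":
    first = torch.rand(3, device="cuda")
    reset_rng_states()
    second = torch.rand(3, device="cuda")
    assert torch.equal(first, second)  # identical draws after the reset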
40 changes: 16 additions & 24 deletions transformer_engine/pytorch/distributed.py
@@ -56,16 +56,10 @@ def graph_safe_rng_available() -> bool:
 def _get_cuda_rng_state(
     device: Union[int, str, torch.device] = "cuda",
     clone: bool = False,
+    graph_safe: bool = True,
 ) -> torch.Tensor:
-    r"""Return the random number generator state of the specified GPU as a ByteTensor.
+    """Return the random number generator state of the specified GPU."""
 
-    Args:
-        device (torch.device or int, optional): The device to return the RNG state of.
-            Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device).
-
-    .. warning::
-        This function eagerly initializes CUDA.
-    """
     _lazy_init()
     if isinstance(device, str):
         device = torch.device(device)
@@ -75,7 +69,7 @@ def _get_cuda_rng_state(
     if idx is None:
         idx = torch.cuda.current_device()
     default_generator = torch.cuda.default_generators[idx]
-    if graph_safe_rng_available():
+    if graph_safe_rng_available() and graph_safe:
         if clone:
             # Reference to the cloned generator state
             return default_generator.clone_state()
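The new graph_safe argument selects between two state representations: the graph-safe generator state object (the clone_state / graphsafe_* API, usable while a CUDA graph is being captured) and the plain ByteTensor returned by Generator.get_state(). A condensed sketch of the gating this hunk introduces; Generator.graphsafe_get_state is assumed for the non-clone branch (the collapsed lines above hide it), and the real helper's device handling is omitted:

import torch
from transformer_engine.pytorch.distributed import graph_safe_rng_available


def get_cuda_rng_state_sketch(clone: bool = False, graph_safe: bool = True):
    """Simplified stand-in for _get_cuda_rng_state on the current device."""
    gen = torch.cuda.default_generators[torch.cuda.current_device()]
    if graph_safe_rng_available() and graph_safe:
        # Graph-safe path: return a generator-state object (optionally cloned)
        # that can be registered with and updated inside a CUDA graph.
        return gen.clone_state() if clone else gen.graphsafe_get_state()
    # Plain path: an eager ByteTensor snapshot of the generator state.
    return gen.get_state()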
@@ -84,15 +78,13 @@
     return default_generator.get_state()
 
 
-def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
-    """Sets the random number generator state of the current GPU.
+def _set_cuda_rng_state(
+    new_state: torch.Tensor,
+    device: Union[int, str] = -1,
+    graph_safe = True,
+) -> None:
+    """Sets the random number generator state of the current GPU."""
 
-    Arguments:
-        new_state (torch.ByteTensor): The desired state
-    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
-    with a single change: the input state is not cloned. Cloning caused
-    major performance issues for +4 GPU cases.
-    """
     if device == -1:
         device = torch.device("cuda")
     elif isinstance(device, str):
@@ -105,7 +97,7 @@ def cb() -> None:
         if idx is None:
             idx = torch.cuda.current_device()
         default_generator = torch.cuda.default_generators[idx]
-        if graph_safe_rng_available():
+        if graph_safe_rng_available() and graph_safe:
            default_generator.graphsafe_set_state(new_state)
            return
        default_generator.set_state(new_state)
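Because the two paths hand back different objects, a caller has to save and restore with the same graph_safe value. A usage sketch of the eager ByteTensor round trip that the checkpoint call sites below now use; it assumes a CUDA device and an installed Transformer Engine:

import torch
from transformer_engine.pytorch.distributed import (
    _get_cuda_rng_state,
    _set_cuda_rng_state,
)

# Snapshot eagerly as a ByteTensor, advance the RNG, then rewind it.
state = _get_cuda_rng_state(graph_safe=False)
first = torch.rand(4, device="cuda")

_set_cuda_rng_state(state, graph_safe=False)  # must match the flag used above
second = torch.rand(4, device="cuda")

assert torch.equal(first, second)  # the restore replays the same draws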
@@ -262,7 +254,7 @@ def forward(
 
         # Copy the rng states.
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
-        ctx.fwd_cuda_rng_state = _get_cuda_rng_state()
+        ctx.fwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()

@@ -327,13 +319,13 @@ def backward(
 
         # Store the current states.
         bwd_cpu_rng_state = torch.get_rng_state()
-        bwd_cuda_rng_state = _get_cuda_rng_state()
+        bwd_cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
         if get_rng_state_tracker is not None:
             bwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states()
 
         # Set the states to what it used to be before the forward pass.
         torch.set_rng_state(ctx.fwd_cpu_rng_state)
-        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state, graph_safe=False)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

@@ -347,7 +339,7 @@
 
         # Set the states back to what it was at the start of this function.
         torch.set_rng_state(bwd_cpu_rng_state)
-        _set_cuda_rng_state(bwd_cuda_rng_state)
+        _set_cuda_rng_state(bwd_cuda_rng_state, graph_safe=False)
         if get_rng_state_tracker is not None:
             get_rng_state_tracker().set_states(bwd_cuda_rng_state_tracker)
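These call sites are the usual activation-recomputation dance: backward saves its own RNG states, rewinds CPU and CUDA RNG to what the forward pass saw so stochastic ops such as dropout replay identically during recompute, then restores the backward-time states. A stripped-down sketch of that pattern in plain PyTorch, with the tracker handling and the graph_safe=False plumbing left out:

import torch


def recompute_with_forward_rng(fwd_cpu_state, fwd_cuda_state, recompute_fn):
    """Run recompute_fn under the forward-pass RNG states, then restore
    the RNG states that were live when backward started."""
    # Store the current (backward-time) states.
    bwd_cpu_state = torch.get_rng_state()
    bwd_cuda_state = torch.cuda.get_rng_state()

    # Rewind to the forward-pass states so the recompute replays identically.
    torch.set_rng_state(fwd_cpu_state)
    torch.cuda.set_rng_state(fwd_cuda_state)
    outputs = recompute_fn()

    # Put the backward-time states back.
    torch.set_rng_state(bwd_cpu_state)
    torch.cuda.set_rng_state(bwd_cuda_state)
    return outputs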

@@ -395,7 +387,7 @@ def cache_rng_states(self, forward=True):
         """Cache fwd/bwd RNG states in the frame to restore later."""
         rng_states = (
             torch.get_rng_state(),
-            _get_cuda_rng_state(),
+            _get_cuda_rng_state(graph_safe=False),
         )
         if self.get_rng_state_tracker is not None:
             rng_states += (self.get_rng_state_tracker().get_states(), )
@@ -413,7 +405,7 @@ def restore_rng_states(self, forward=True):
             rng_states = self.bwd_rng_states
 
         torch.set_rng_state(rng_states[0])
-        _set_cuda_rng_state(rng_states[1])
+        _set_cuda_rng_state(rng_states[1], graph_safe=False)
         if self.get_rng_state_tracker is not None:
            self.get_rng_state_tracker().set_states(rng_states[2])
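The frame-based checkpoint path stores the same trio of states (CPU, CUDA, optional tracker) as a tuple and replays it later. A toy sketch of that cache/restore pairing, with the tracker slot omitted:

import torch


class RNGStatesFrame:
    """Toy stand-in for the frame-level RNG caching shown above."""

    def __init__(self):
        self.fwd_rng_states = None
        self.bwd_rng_states = None

    def cache_rng_states(self, forward=True):
        """Snapshot CPU and CUDA RNG states for the fwd or bwd phase."""
        states = (torch.get_rng_state(), torch.cuda.get_rng_state())
        if forward:
            self.fwd_rng_states = states
        else:
            self.bwd_rng_states = states

    def restore_rng_states(self, forward=True):
        """Rewind CPU and CUDA RNG to the cached states."""
        states = self.fwd_rng_states if forward else self.bwd_rng_states
        torch.set_rng_state(states[0])
        torch.cuda.set_rng_state(states[1])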

20 changes: 8 additions & 12 deletions transformer_engine/pytorch/graph.py
@@ -13,7 +13,12 @@
     FP8GlobalStateManager,
     get_default_fp8_recipe,
 )
-from .distributed import get_all_rng_states, graph_safe_rng_available
+from .distributed import (
+    get_all_rng_states,
+    graph_safe_rng_available,
+    _get_cuda_rng_state,
+    _set_cuda_rng_state,
+)
 from .module.base import TransformerEngineBaseModule


@@ -517,24 +522,15 @@ def forward_func(*args, **kwargs):
     forward_funcs = tuple(forward_funcs)
 
     # Save RNG state.
-    if graph_safe_rng_available():
-        generators = [torch.cuda.default_generators[torch.cuda.current_device()],
-                      *get_all_rng_states().values()]
-        original_rng_states = [state.get_state() for state in generators]
-    else:
-        original_rng_states = torch.cuda.get_rng_state()
+    cuda_rng_state = _get_cuda_rng_state(graph_safe=False)
 
     graphed_callables = _make_graphed_callables(
         forward_funcs, sample_args, num_warmup_iters=num_warmup_iters,
         allow_unused_input=allow_unused_input,
         fp8_weight_caching=fp8_weight_caching, _order=_order)
 
     # Ensures warmup does not affect numerics for ops such as dropout.
-    if graph_safe_rng_available():
-        for gen, state in zip(generators, original_rng_states):
-            gen.set_state(state)
-    else:
-        torch.cuda.set_rng_state(original_rng_states)
+    _set_cuda_rng_state(cuda_rng_state, graph_safe=False)
 
     # Reset FP8 gradients.
     for module in modules:
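Here the wrapper snapshots the plain ByteTensor state before the warmup passes inside _make_graphed_callables (which consume RNG, e.g. for dropout masks) and puts it back afterwards, so capturing graphs does not perturb the random stream relative to an ungraphed run; that consistency is the "numerics" the commit title refers to. A rough usage sketch, with module, sizes, and seed made up for illustration; a CUDA device is required:

import torch
import transformer_engine.pytorch as te

torch.manual_seed(1234)

layer = te.Linear(128, 128).cuda()
sample_args = (torch.randn(32, 128, device="cuda"),)

# make_graphed_callables runs warmup iterations and captures a CUDA graph,
# but restores the CUDA RNG state it saved beforehand, so RNG draws made
# after this call match what an ungraphed run would have produced.
graphed_layer = te.make_graphed_callables(layer, sample_args)

out = graphed_layer(torch.randn(32, 128, device="cuda"))
print(out.shape)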
