
Commit f77bde2: Fix

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ksivaman committed Apr 9, 2024
1 parent 13574a7 commit f77bde2
Showing 1 changed file with 11 additions and 5 deletions: transformer_engine/pytorch/graph.py
@@ -517,18 +517,24 @@ def forward_func(*args, **kwargs):
     forward_funcs = tuple(forward_funcs)

     # Save RNG state.
-    generators = [torch.cuda.default_generators[torch.cuda.current_device()],
-                  *get_all_rng_states().values()]
-    original_rng_states = [state.get_state() for state in generators]
+    if graph_safe_rng_available():
+        generators = [torch.cuda.default_generators[torch.cuda.current_device()],
+                      *get_all_rng_states().values()]
+        original_rng_states = [state.get_state() for state in generators]
+    else:
+        original_rng_states = torch.cuda.get_rng_state()

     graphed_callables = _make_graphed_callables(
         forward_funcs, sample_args, num_warmup_iters=num_warmup_iters,
         allow_unused_input=allow_unused_input,
         fp8_weight_caching=fp8_weight_caching, _order=_order)

     # Ensures warmup does not affect numerics for ops such as dropout.
-    for gen, state in zip(generators, original_rng_states):
-        gen.set_state(state)
+    if graph_safe_rng_available():
+        for gen, state in zip(generators, original_rng_states):
+            gen.set_state(state)
+    else:
+        torch.cuda.set_rng_state(original_rng_states)

     # Reset FP8 gradients.
     for module in modules:
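The change replaces an unconditional per-generator snapshot with a branch on `graph_safe_rng_available()`. A minimal sketch of the resulting save/restore pattern, assuming the `graph_safe_rng_available` and `get_all_rng_states` helpers that graph.py already imports; the two wrapper functions below are hypothetical names for illustration, not part of the module:

```python
import torch

# Hypothetical wrappers for illustration; in graph.py this logic is inline.
# graph_safe_rng_available() and get_all_rng_states() are the Transformer
# Engine helpers referenced in the diff above.

def save_rng_states():
    if graph_safe_rng_available():
        # Graph-safe path: snapshot each generator object individually --
        # the device's default CUDA generator plus every additional state
        # Transformer Engine tracks (e.g. for model-parallel RNG streams).
        generators = [torch.cuda.default_generators[torch.cuda.current_device()],
                      *get_all_rng_states().values()]
        return generators, [gen.get_state() for gen in generators]
    # Fallback path: a single global state tensor for the current device.
    return None, torch.cuda.get_rng_state()

def restore_rng_states(generators, states):
    # Rewind the RNG so warmup iterations leave no trace on later numerics.
    if generators is not None:
        for gen, state in zip(generators, states):
            gen.set_state(state)
    else:
        torch.cuda.set_rng_state(states)
```

The fallback exists because `torch.cuda.get_rng_state()` only covers the default generator of the current device; when the graph-safe generator API is available, each tracked generator is saved and restored individually instead.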

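And the reason the restore matters, per the in-code comment: warmup passes consume random numbers, so without rewinding the RNG, stochastic ops such as dropout would see a shifted random stream compared to an uncaptured run. A self-contained illustration of the fallback (global-state) path, runnable on any CUDA-enabled PyTorch build:

```python
import torch
import torch.nn.functional as F

x = torch.ones(32, device="cuda")
saved = torch.cuda.get_rng_state()

# A "warmup" pass advances the CUDA RNG...
_ = F.dropout(x, p=0.5)

# ...but restoring the saved state rewinds it, so dropout reproduces
# exactly the mask an un-warmed run would have drawn.
torch.cuda.set_rng_state(saved)
a = F.dropout(x, p=0.5)
torch.cuda.set_rng_state(saved)
b = F.dropout(x, p=0.5)
assert torch.equal(a, b)  # identical masks after each rewind
```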