Activation offloading to CPUs for the Linear, LayerNorm Linear and the LayerNorm MLP modules (#571)

* Added support for activation offloading to CPUs

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Moving CPU offloading library to TE

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Restructured code, added switch to choose between weight/activation offloading

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed arg from constructor

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix nit-pick errors

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Documentation fixes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix to the code block in docs

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added offloading unit test

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed formatting

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* wgrad fusion fix, minor errors and lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Errors, test, lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* RM test file

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixed stray PyT tensors in LayernormMLP getting offloaded

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed typo

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fix offloading for rmsnorm, rm test

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix errors

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Float8Tensor compatible offloading

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Cleanup

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
4 people committed Jan 22, 2024
1 parent b25611b commit c6f0a1f
Showing 7 changed files with 615 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
@@ -40,3 +40,5 @@ pyTorch
 .. autoapifunction:: transformer_engine.pytorch.checkpoint
 
 .. autoapifunction:: transformer_engine.pytorch.onnx_export
+
+.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
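
For reference, the sketch below shows how the newly documented get_cpu_offload_context is meant to be driven, mirroring the test added in this PR: it returns a context manager under which activations saved for backward are offloaded to CPU, plus a synchronization function that must be applied to the module output before backward. Only the enabled= argument and the (context, sync_function) return pair come from this commit; the module, shapes, and training step around them are illustrative assumptions.

import torch
import transformer_engine.pytorch as te

# Hypothetical single-layer setup; any of the supported modules
# (Linear, LayerNormLinear, LayerNormMLP) could stand in here.
layer = te.Linear(4096, 4096).cuda()
inp = torch.randn(2048, 4096, device="cuda", requires_grad=True)

# Context manager that offloads saved activations to CPU, and a function
# that synchronizes the offloaded copies before the backward pass.
offload_context, sync_function = te.get_cpu_offload_context(enabled=True)

with offload_context:
    out = layer(inp)
out = sync_function(out)

out.sum().backward()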
24 changes: 18 additions & 6 deletions tests/pytorch/test_sanity.py
@@ -4,6 +4,7 @@
 
 from dataclasses import dataclass
 from typing import Optional
+from contextlib import nullcontext
 
 import torch
 import pytest
@@ -20,6 +21,7 @@
     TransformerLayer,
     RMSNorm,
     LayerNorm,
+    get_cpu_offload_context,
 )
 from transformer_engine.common import recipe
 
@@ -215,17 +217,24 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
             assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."
 
 
-def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
+def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
     te_inp_hidden_states = torch.randn(
         config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
     ).cuda()
 
     if skip_wgrad:
         _disable_wgrads(block)
 
+    if cpu_offload:
+        offload_context, sync_function = get_cpu_offload_context(enabled=True)
+    else:
+        offload_context = nullcontext()
+        sync_function = lambda x: x
+
     use_fp8 = fp8_recipe is not None
-    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
         te_out = block(te_inp_hidden_states)
+    te_out = sync_function(te_out)
     loss = te_out.sum()
     loss.backward()
     torch.cuda.synchronize()
@@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
+@pytest.mark.parametrize("cpu_offload", all_boolean)
 def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
                     zero_centered_gamma, bias, activation,
-                    normalization, parallel_attention_mlp):
+                    normalization, parallel_attention_mlp,
+                    cpu_offload):
     config = model_configs[model]
 
     if fp8_recipe is not None:
@@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
         .cuda()
     )
 
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)
 
 
 def test_sanity_gpt_126m():
@@ -512,6 +523,7 @@ def test_sanity_gpt_126m():
         activation="gelu",
         normalization="LayerNorm",
         parallel_attention_mlp=False,
+        cpu_offload=False,
     )
 
 
@@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
         .cuda()
     )
 
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
 
 
 @pytest.mark.parametrize("dtype", param_types)
@@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
         .cuda()
     )
 
-    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
+    _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)
 
 
 @pytest.mark.parametrize("dtype", param_types)
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
@@ -17,6 +17,7 @@
 from .export import onnx_export
 from .distributed import checkpoint
 from .distributed import CudaRNGStatesTracker
+from .cpu_offload import get_cpu_offload_context
 # Register custom op symbolic ONNX functions
 from .te_onnx_extensions import (
     onnx_cast_to_fp8,