Activation offloading to CPUs for the Linear, LayerNormLinear, and LayerNormMLP modules #571

Merged: 31 commits, merged Jan 21, 2024. Changes shown are from all commits.
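For orientation, a minimal usage sketch of the API this PR adds. It follows the pattern from the sanity test added below, where `get_cpu_offload_context` returns a context manager plus a synchronization function; the module choice, tensor shape, and anything beyond `enabled=True` are illustrative assumptions, not part of the PR:

```python
import torch
import transformer_engine.pytorch as te

# A TE module whose activations should be offloaded to CPU during training.
layer = te.Linear(1024, 1024).cuda()

# get_cpu_offload_context returns (context manager, synchronization function).
# Running the forward pass inside the context marks activations for CPU
# offloading; the synchronization function is applied to the module output
# before backward, mirroring tests/pytorch/test_sanity.py in this PR.
offload_context, sync_function = te.get_cpu_offload_context(enabled=True)

inp = torch.randn(32, 1024, device="cuda", requires_grad=True)
with offload_context:
    out = layer(inp)
out = sync_function(out)

out.sum().backward()
torch.cuda.synchronize()
```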
Commits

7cbc9c6  Added support activation offloading to CPU's (Dec 18, 2023)
991d54a  Merge branch 'NVIDIA:main' into main (sanandaraj5597, Dec 19, 2023)
54e0206  Moving CPU offloading library to TE (Dec 21, 2023)
3210f84  Merge branch 'NVIDIA:main' into main (sanandaraj5597, Jan 10, 2024)
cb831f5  Restructured code, added switch to choose between weight/activation o… (Jan 10, 2024)
ca1808c  Merge remote-tracking branch 'refs/remotes/origin/main' (Jan 10, 2024)
c8ad95b  Merge branch 'main' into main (ptrendx, Jan 11, 2024)
15293fa  Removed arg during constructor (Jan 12, 2024)
4942cc8  Merge remote-tracking branch 'refs/remotes/origin/main' (Jan 12, 2024)
6a6c1e6  Fix nit-pick errors (Jan 12, 2024)
c94961c  Merge branch 'main' into main (ptrendx, Jan 12, 2024)
73b93fe  Documentation fixes (ptrendx, Jan 12, 2024)
d890e2a  Fix to the code block in docs (ptrendx, Jan 12, 2024)
cd92bf6  Merge branch 'main' into main (ksivaman, Jan 16, 2024)
d45ad8c  Added offloading unit test (Jan 18, 2024)
bcdc562  Merge branch 'NVIDIA:main' into main (sanandaraj5597, Jan 18, 2024)
3e4b3d5  Fixed formatting (Jan 18, 2024)
7b70947  Fixed merge conflict (Jan 18, 2024)
4ef146a  Merge branch 'main' into main (ksivaman, Jan 19, 2024)
6041649  Merge branch 'NVIDIA:main' into main (sanandaraj5597, Jan 19, 2024)
dedf650  wgrad fusion fix, minor errors and lint (ksivaman, Jan 20, 2024)
06c4068  Merge branch 'main' into main (ksivaman, Jan 20, 2024)
ca99df2  Merge branch 'main' into main (ksivaman, Jan 20, 2024)
43dd5a6  Errors, test, lint (ksivaman, Jan 20, 2024)
5671662  RM test file (ksivaman, Jan 20, 2024)
ab1160d  Fixed stray PyT tensors in LayernormMLP getting offloaded (Jan 21, 2024)
f37dcf2  Fixed typi (Jan 21, 2024)
411f62e  Fix offloading for rmsnorm, rm test (ksivaman, Jan 21, 2024)
7db3b60  Fix errors (ksivaman, Jan 21, 2024)
da8fcf0  Float8Tensor compatible offloading (ksivaman, Jan 21, 2024)
60d8e83  Cleanup (ksivaman, Jan 21, 2024)
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
@@ -40,3 +40,5 @@ pyTorch
.. autoapifunction:: transformer_engine.pytorch.checkpoint

.. autoapifunction:: transformer_engine.pytorch.onnx_export

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
24 changes: 18 additions & 6 deletions tests/pytorch/test_sanity.py
@@ -4,6 +4,7 @@

from dataclasses import dataclass
from typing import Optional
from contextlib import nullcontext

import torch
import pytest
@@ -20,6 +21,7 @@
TransformerLayer,
RMSNorm,
LayerNorm,
get_cpu_offload_context,
)
from transformer_engine.common import recipe

@@ -215,17 +217,24 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload):
te_inp_hidden_states = torch.randn(
config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()

if skip_wgrad:
_disable_wgrads(block)

if cpu_offload:
offload_context, sync_function = get_cpu_offload_context(enabled=True)
else:
offload_context = nullcontext()
sync_function = lambda x: x

use_fp8 = fp8_recipe is not None
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context:
te_out = block(te_inp_hidden_states)
te_out = sync_function(te_out)
loss = te_out.sum()
loss.backward()
torch.cuda.synchronize()
@@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
zero_centered_gamma, bias, activation,
normalization, parallel_attention_mlp):
normalization, parallel_attention_mlp,
cpu_offload):
config = model_configs[model]

if fp8_recipe is not None:
@@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
.cuda()
)

_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload)


def test_sanity_gpt_126m():
@@ -512,6 +523,7 @@ def test_sanity_gpt_126m():
activation="gelu",
normalization="LayerNorm",
parallel_attention_mlp=False,
cpu_offload=False,
)


@@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
.cuda()
)

_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
@@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
.cuda()
)

_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad)
_test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False)


@pytest.mark.parametrize("dtype", param_types)
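The sanity-test hunks above nest the offload context inside `fp8_autocast`. A rough user-side sketch of that combination, assuming an FP8-capable GPU; the `LayerNormMLP` sizes, tensor shape, and default `DelayedScaling` recipe are placeholder choices, not values taken from this PR:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

layer = te.LayerNormMLP(1024, 4096).cuda()
offload_context, sync_function = te.get_cpu_offload_context(enabled=True)
fp8_recipe = recipe.DelayedScaling()

inp = torch.randn(128, 8, 1024, device="cuda", requires_grad=True)

# FP8 execution and CPU activation offloading are combined by stacking the
# two context managers, as in the updated _test_sanity_e2e above.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe), offload_context:
    out = layer(inp)
out = sync_function(out)

out.sum().backward()
torch.cuda.synchronize()
```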
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
@@ -17,6 +17,7 @@
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
from .cpu_offload import get_cpu_offload_context
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
onnx_cast_to_fp8,