From d4e0f80c6a08a84ea66f4be2104297827ee4dff0 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 19 Apr 2024 11:28:30 -0700 Subject: [PATCH] [PyTorch] Stop storing fused weight tensor in linear modules (#719) * Support noop concat without providing full tensor Stop storing fused buffers in linear modules. Signed-off-by: Tim Moon * Debug noop cat func Signed-off-by: Tim Moon * Construct TE modules in tests with correct dtypes Signed-off-by: Tim Moon * Add tolerances to numerical tests Signed-off-by: Tim Moon * Use plain PyTorch concat when exporting to ONNX Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_numerics.py | 344 +++++++++------ tests/pytorch/test_sanity.py | 396 ++++++++++-------- transformer_engine/pytorch/module/_common.py | 146 ++++--- .../pytorch/module/layernorm_linear.py | 96 ++--- transformer_engine/pytorch/module/linear.py | 97 +++-- 5 files changed, 597 insertions(+), 482 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 0cda82e0c4..90cfce8a6f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -4,7 +4,7 @@ import math import os -from typing import List, Optional +from typing import Dict, List, Optional import pytest import copy @@ -79,19 +79,26 @@ def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() -def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool: - """Ensures two lists are equal.""" - assert len(l1) == len(l2), "Unequal number of outputs." - failed = False - failed_tensors = "" - for i, (t1, t2) in enumerate(zip(l1, l2)): - if not torch.equal(t1, t2): - failed = True - failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n" - assert not failed, "Output mismatches in:\n" + failed_tensors +def dtype_tols(dtype: torch.dtype) -> Dict[str, float]: + """Estimated numerical error for a datatype + Based on tolerances for torch.testing.assert_close. -def assert_allclose(l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float) -> bool: + """ + if dtype == torch.float32: + return dict(rtol=1.3e-6, atol=1e-5) + if dtype == torch.float16: + return dict(rtol=1e-3, atol=1e-5) + if dtype == torch.bfloat16: + return dict(rtol=1.6e-2, atol=1e-5) + raise ValueError(f"Unsuppored dtype ({dtype})") + + +def assert_allclose( + l1: List[torch.Tensor], + l2: List[torch.Tensor], + atol: float, +) -> bool: """Ensures two lists are equal.""" assert len(l1) == len(l2), "Unequal number of outputs." 
for i, (t1, t2) in enumerate(zip(l1, l2)): @@ -424,13 +431,16 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False output_layernorm=False, params_dtype=dtype, fuse_qkv_params=True, + device="cuda", ) - .cuda() ) te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -464,7 +474,20 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) - assert_all_equal(outputs, outputs_recompute) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols["atol"] = 1e-4 + if fp8 or fp8_model_params: + tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + **tols, + ) def _test_e2e_full_recompute( @@ -481,8 +504,7 @@ def _test_e2e_full_recompute( output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) with fp8_model_init(enabled=fp8 and fp8_model_params): - block = ( - TransformerLayer( + block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, @@ -496,13 +518,15 @@ def _test_e2e_full_recompute( output_layernorm=False, params_dtype=dtype, fuse_qkv_params=True, - ) - .cuda() + device="cuda", ) te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=use_reentrant - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=use_reentrant, + ) if use_reentrant: te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -566,7 +590,19 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, # Reset bias+GELU fusion flag to avoid contaminating other tests del os.environ["NVTE_BIAS_GELU_NVFUSION"] - assert_all_equal(outputs, outputs_recompute, names=names) + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols["atol"] = 1e-3 + if fp8 or fp8_model_params: + tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + **tols, + ) def _test_e2e_checkpointing_get_model(config, dtype): @@ -574,22 +610,20 @@ def _test_e2e_checkpointing_get_model(config, dtype): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - return ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - params_dtype=dtype, - ) - .cuda() + return TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + 
init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + params_dtype=dtype, + device="cuda", ) @@ -597,8 +631,11 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= reset_rng_states() te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() block = _test_e2e_checkpointing_get_model(config, dtype) @@ -666,15 +703,29 @@ def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) - assert_all_equal(outputs, outputs_checkpoint) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + **tols, + ) def _test_e2e_gpt_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -705,12 +756,12 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): layernorm_epsilon=config.eps, attention_dropout=0.1, hidden_dropout=0.1, + params_dtype=dtype, fuse_qkv_params=True, qkv_weight_interleaved=False, parallel_attention_mlp=parallel_attention_mlp, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -765,8 +816,11 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None @@ -799,11 +853,11 @@ def test_mha_accuracy(dtype, bs, model, mask_type): config.hidden_size, config.num_attention_heads, fuse_qkv_params=True, + params_dtype=dtype, qkv_weight_interleaved=False, input_layernorm=False, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -838,8 +892,11 @@ def _test_granular_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) inp_hidden_states.retain_grad() out = block(inp_hidden_states) @@ -857,10 +914,16 @@ def _test_granular_accuracy(block, bs, dtype, config): def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() - mask = torch.triu(torch.ones(config.seq_len, config.seq_len, device="cuda"), diagonal=1).bool() + mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1) query, key, value = [ - torch.randn(config.seq_len, bs, 
config.num_attention_heads, - config.embed, dtype=dtype, requires_grad=True).cuda() for _ in range(3)] + torch.randn( + (config.seq_len, bs, config.num_attention_heads, config.embed), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + for _ in range(3) + ] query.retain_grad() key.retain_grad() @@ -921,9 +984,9 @@ def test_linear_accuracy(dtype, bs, model): config.hidden_size, 4 * config.hidden_size, bias=True, + params_dtype=dtype, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -932,9 +995,9 @@ def test_linear_accuracy(dtype, bs, model): config.hidden_size, 4 * config.hidden_size, bias=True, + device="cuda", + dtype=dtype, ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -965,10 +1028,10 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): RMSNorm( config.hidden_size, eps=eps, - zero_centered_gamma=zero_centered_gamma + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1009,10 +1072,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma): LayerNorm( config.hidden_size, eps=eps, - zero_centered_gamma=zero_centered_gamma + params_dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1058,10 +1121,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere config.eps, bias=True, normalization=normalization, + params_dtype=dtype, zero_centered_gamma=zero_centered_gamma, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1112,9 +1175,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): 4 * config.hidden_size, activation=activation, normalization=normalization, + params_dtype=dtype, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) @@ -1229,11 +1292,11 @@ def test_gpt_cuda_graph(dtype, bs, model): hidden_dropout=0.1, attention_dropout=0.1, kv_channels=config.embed, + params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, + device="cuda", ) - .to(dtype=dtype) - .cuda() ) graphed_block = copy.deepcopy(block) @@ -1257,28 +1320,29 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) with fp8_model_init(enabled=fp8_model_params): - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - params_dtype=dtype, - fuse_qkv_params=True, - ) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + params_dtype=dtype, + fuse_qkv_params=True, + device="cuda", ) te_inp_hidden_states = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) @@ -1306,7 
+1370,18 @@ def test_gpt_fp8_parameters(dtype, bs, model): outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) - assert_all_equal(outputs, outputs_fp8_params) + + # Check that results match + tols = dict(rtol=0.125, atol=0.0675) + for i, (ref, test) in enumerate(zip(outputs, outputs_fp8_params)): + torch.testing.assert_close( + test, + ref, + msg=f"Mismatch in tensor {i}", + rtol=0.125, + atol=0.0675, + ) + @pytest.mark.parametrize("dtype", param_types) @@ -1323,54 +1398,53 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): # other layer. Set `*dropout` values to 0 to make sure the forward pass # is identical to the other layer. torch.manual_seed(0) - block_sbhd = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0, - attention_dropout=0, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - attn_input_format="sbhd" - ) - .to(dtype=dtype) - .cuda() + block_sbhd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="sbhd", ) # Set `torch.manual_seed` to make sure the weights are identical to the # other layer. Set `*dropout` values to 0 to make sure the forward pass # is identical to the other layer. 
torch.manual_seed(0) - block_bshd = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0, - attention_dropout=0, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - attn_input_format="bshd" - ) - .to(dtype=dtype) - .cuda() + block_bshd = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + device="cuda", + attn_input_format="bshd", ) for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" x_sbhd = torch.randn( - config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True - ).to(dtype).cuda() + (config.seq_len, bs, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) x_bshd = x_sbhd.transpose(0,1).contiguous() @@ -1384,7 +1458,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): torch.manual_seed(0) y_bshd = block_bshd(x_bshd) - assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) + # Check that results match + torch.testing.assert_close( + y_bshd, + y_sbhd.transpose(0,1).contiguous(), + ) @pytest.mark.parametrize("dtype", param_types) @@ -1424,10 +1502,10 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, num_attention_heads=H, attn_input_format=input_format, layer_number=layer_number, - attention_dropout = 0.0 + attention_dropout = 0.0, + params_dtype=dtype, + device="cuda", ) - .to(dtype=dtype) - .cuda() .eval() ) else: @@ -1437,9 +1515,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, num_attention_heads=H, qkv_format=input_format, layer_number=layer_number, - attention_dropout = 0.0 + attention_dropout = 0.0, + params_dtype=dtype, ) - .to(dtype=dtype) .cuda() .eval() ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e91e464fa4..9f8c8f73cb 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -172,10 +172,18 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=torch.float32, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=torch.float32, + device="cuda", + requires_grad=True, + ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() + te_inp_attn_mask = torch.randint( + 2, + (1, 1, config.seq_len, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -198,9 +206,17 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() - te_inp_attn_mask = 
torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + te_inp_attn_mask = torch.randint( + 2, + (1, 1, config.seq_len, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -227,8 +243,11 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) if skip_wgrad: _disable_wgrads(block) @@ -250,10 +269,18 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) - te_inp_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5 + te_inp_attn_mask = torch.randint( + 2, + (config.batch_size, 1, 1, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -268,10 +295,24 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True - ).cuda() - te_inp_attn_mask = torch.randint(2, (1, 1, config.seq_len, config.seq_len)).cuda().bool() - enc_dec_attn_mask = torch.rand(torch.Size([config.batch_size, 1, 1, config.seq_len])).cuda() > 0.5 + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + te_inp_attn_mask = torch.randint( + 2, + (1, 1, config.seq_len, config.seq_len), + dtype=torch.bool, + device="cuda", + ) + + enc_dec_attn_mask = torch.randint( + 2, + (config.batch_size, 1, 1, config.seq_len), + dtype=torch.bool, + device="cuda", + ) if skip_wgrad: _disable_wgrads(block) @@ -294,8 +335,11 @@ def _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=not skip_dgrad - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=not skip_dgrad, + ) if skip_wgrad: _disable_wgrads(block) @@ -315,8 +359,10 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, requires_grad=True - ).cuda() + (config.seq_len, config.batch_size, config.hidden_size), + device="cuda", + requires_grad=True, + ) te_inp.retain_grad() with torch.autocast(device_type="cuda", enabled=True, dtype=dtype): @@ -371,16 +417,14 @@ def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad, sigma = 0.023 init_method = init_method_normal(sigma) - block = ( - 
LayerNormLinear( - config.hidden_size, - config.hidden_size * 3, - init_method=init_method, - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = LayerNormLinear( + config.hidden_size, + config.hidden_size * 3, + init_method=init_method, + zero_centered_gamma=zero_centered_gamma, + normalization=normalization, + params_dtype=dtype, + device="cuda", ) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) @@ -402,12 +446,12 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): sigma = 0.023 output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - Linear( - config.hidden_size, config.hidden_size, init_method=output_layer_init_method - ) - .to(dtype=dtype) - .cuda() + block = Linear( + config.hidden_size, + config.hidden_size, + init_method=output_layer_init_method, + params_dtype=dtype, + device="cuda", ) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) @@ -435,18 +479,16 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - LayerNormMLP( - config.hidden_size, - 4 * config.hidden_size, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = LayerNormMLP( + config.hidden_size, + 4 * config.hidden_size, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + zero_centered_gamma=zero_centered_gamma, + activation=activation, + normalization=normalization, + params_dtype=dtype, + device="cuda", ) _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad) @@ -477,26 +519,24 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - bias=bias, - activation=activation, - normalization=normalization, - parallel_attention_mlp=parallel_attention_mlp, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + zero_centered_gamma=zero_centered_gamma, + bias=bias, + activation=activation, + normalization=normalization, + device="cuda", + parallel_attention_mlp=parallel_attention_mlp, ) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) @@ -546,24 +586,22 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * 
config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=True, - output_layernorm=True, - zero_centered_gamma=zero_centered_gamma, - self_attn_mask_type="padding", - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=True, + output_layernorm=True, + zero_centered_gamma=zero_centered_gamma, + self_attn_mask_type="padding", + normalization=normalization, + device="cuda", ) _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad) @@ -607,24 +645,22 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - layer_type="decoder", - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="decoder", + zero_centered_gamma=zero_centered_gamma, + normalization=normalization, + device="cuda", ) _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad) @@ -665,19 +701,17 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - ) - .to(dtype=torch.float32) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=torch.float32, + device="cuda", ) _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad) @@ -700,22 +734,20 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - 
attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - drop_path_rate=1.0, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + drop_path_rate=1.0, + device="cuda", ) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) @@ -738,22 +770,20 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - fuse_qkv_params=True, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + fuse_qkv_params=True, + device="cuda", ) _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) @@ -777,24 +807,22 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - fuse_qkv_params=True, - fuse_wgrad_accumulation=True, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + zero_centered_gamma=zero_centered_gamma, + fuse_qkv_params=True, + fuse_wgrad_accumulation=True, + device="cuda", ) _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) @@ -820,30 +848,28 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.kv_channels, - 
apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - fuse_qkv_params=True, - normalization=normalization, - ) - .to(dtype=dtype) - .cuda() + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.kv_channels, + params_dtype=dtype, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + zero_centered_gamma=zero_centered_gamma, + fuse_qkv_params=True, + normalization=normalization, + device="cuda", ) _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) def test_model_multiple_cast(): - a = torch.zeros((16,16)).cuda() + a = torch.zeros((16,16), device="cuda") m = Linear(16,32) y = m(a) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 79798d2ff0..ab6455649c 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -10,6 +10,7 @@ import torch from .. import cpp_extensions as tex +from ..export import is_in_onnx_export_mode from ..fp8 import get_fp8_te_dtype from ..utils import get_default_init_method @@ -99,32 +100,79 @@ def _apply_normalization(inputmat:torch.Tensor, class _NoopCatFunc(torch.autograd.Function): - """No-op concatenate tensors along dim 0 + """Concatenate tensors, doing a no-op if possible - `full_tensor` is assumed to already be the concatenation of - `tensors`, i.e. they occupy the same memory with the correct - offsets. + See _noop_cat. """ @staticmethod def forward( - ctx, - split_ranges: List[Tuple[int, int]], - full_tensor: torch.Tensor, + ctx: Any, + dim: int, *tensors: Tuple[torch.Tensor, ...], ) -> torch.Tensor: - # pylint: disable=unused-argument + + # Check first tensor + if not tensors: + raise ValueError("Attempted to concatenate 0 tensors") + num_dims = tensors[0].dim() + if not -num_dims <= dim < num_dims: + raise ValueError( + "Attempted to concatenate tensor " + f"with shape {list(tensors[0].size())} along dim {dim}" + ) + dim %= num_dims + + # Check remaining tensors + out_shape = list(tensors[0].size()) + split_ranges = [(0, tensors[0].size(dim))] + for tensor in tensors[1:]: + in_shape = list(tensor.size()) + if ( + len(in_shape) != num_dims + or in_shape[:dim] != out_shape[:dim] + or in_shape[dim+1:] != out_shape[dim+1:] + ): + raise ValueError( + "Attempted to concatenate tensors with shapes " + f"{[list(tensor.size()) for tensor in tensors]} " + f"along dim {dim}" + ) + split_start = out_shape[dim] + split_end = split_start + in_shape[dim] + out_shape[dim] = split_end + split_ranges.append((split_start, split_end)) + + # Save state for backward + ctx.dim = dim ctx.split_ranges = split_ranges - assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient" - out = full_tensor.new() + + # Out-of-place concatenation if needed + dtype = tensors[0].dtype + device = tensors[0].device + strides = tensors[0].stride() + data_ptr_stride = strides[dim] * tensors[0].element_size() + data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride + for tensor in tensors[1:]: + if ( + tensor.dtype != dtype + or tensor.device != device + or tensor.stride() != strides + or tensor.data_ptr() != data_ptr + ): + return torch.cat(tensors, dim=dim) + data_ptr += tensor.size(dim) * data_ptr_stride + + # No-op concatenation + out 
= tensors[0].new() out.set_( - full_tensor.untyped_storage(), - full_tensor.storage_offset(), - full_tensor.size(), - full_tensor.stride(), + tensors[0].untyped_storage(), + tensors[0].storage_offset(), + out_shape, + strides, ) - out.requires_grad = True + out.requires_grad = any(tensor.requires_grad for tensor in tensors) return out @staticmethod @@ -132,64 +180,32 @@ def backward( ctx, grad_output: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: - grads = [ - grad_output[split_start:split_end] - for split_start, split_end in ctx.split_ranges - ] - return None, None, *grads + grad_inputs = [] + for split_start, split_end in ctx.split_ranges: + slices = [slice(None)] * grad_output.dim() + slices[ctx.dim] = slice(split_start, split_end) + grad_inputs.append(grad_output[tuple(slices)]) + return None, *grad_inputs def _noop_cat( tensors: List[torch.Tensor], - full_tensor: torch.Tensor, + dim: int = 0, ) -> torch.Tensor: - """Concatenate tensors along dim 0, doing a no-op if possible - - If `full_tensor` is already the concatenation of `tensors`, i.e. - they occupy the same memory region with the correct offsets, then - no copies are performed. Otherwise the buffers in all the tensors - are reallocated so that another call would result in a no-op. + """Concatenate tensors, doing a no-op if possible - In the backward pass, gradients to `partial_tensors` will just be - tensor views. + If tensors are already concatenated in memory, a tensor view of + that memory region will be returned. Otherwise the tensors will be + concatenated out-of-place, as usual. """ - - # Determine split points - split_ranges = [] - full_tensor_shape = full_tensor.size() - offset = 0 - for tensor in tensors: - tensor_shape = tensor.size() - if tensor_shape[1:] != full_tensor_shape[1:]: - raise ValueError( - f"Attempting to concatenate tensor with shape={list(tensor_shape)} " - f"into a tensor with shape={list(full_tensor_shape)}" - ) - split_start = offset - offset += tensor_shape[0] - split_end = offset - split_ranges.append((split_start, split_end)) - if offset != full_tensor_shape[0]: - raise ValueError( - f"Attempting to concatenate tensors with total shape[0]={offset} " - f"into a tensor with shape[0]={full_tensor_shape[0]}" - ) - - # Reallocate buffers if no-op concat isn't possible - need_to_reallocate = False - for tensor, (split_start, _) in zip(tensors, split_ranges): - if tensor.data_ptr() != full_tensor[split_start].data_ptr(): - need_to_reallocate = True - break - if need_to_reallocate: - with torch.no_grad(): - full_tensor.data = torch.cat(tensors) - for tensor, (split_start, split_end) in zip(tensors, split_ranges): - tensor.data = full_tensor[split_start:split_end] - - # Perform no-op concat - return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors) + if not tensors: + raise ValueError("Attempted to concatenate 0 tensors") + if len(tensors) == 1: + return tensors[0] + if is_in_onnx_export_mode(): + return torch.cat(tensors, dim=dim) + return _NoopCatFunc.apply(dim, *tensors) @dataclass diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 7d7bb0bbd5..75a8ad857e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -926,17 +926,20 @@ def __init__( else: self.layer_norm_bias = None - self.weight_tensor = torch.empty( - self.out_features, self.in_features, - device=device, dtype=params_dtype) - + # Contiguous buffers for params + weight_tensor = 
torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None if self.use_bias: - self.bias_tensor = torch.empty( + bias_tensor = torch.empty( self.out_features, device=device, - dtype=params_dtype) - else: - self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) + dtype=params_dtype, + ) # Configure parameter splits self.weight_names = [] @@ -982,7 +985,11 @@ def __init__( ) self.parameter_split_sizes[i] = size // self.tp_size - # Construct parameters from weight and bias buffers + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in LayerNormLinear.parameters(). This makes it + # more likely that they will stay contiguous if the weights + # are manipulated externally, e.g. by FSDP. offset = 0 for i, split_size in enumerate(self.parameter_split_sizes): split_start = offset @@ -998,32 +1005,30 @@ def __init__( ) # Construct weight parameter - weight = self.weight_tensor - if is_subview: - weight = weight[split_start:split_end] - weight = torch.nn.Parameter(weight) - self.register_parameter(self.weight_names[i], weight, - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) - - # Construct bias parameter if needed - if self.use_bias: - bias = self.bias_tensor - if is_subview: - bias = bias[split_start:split_end] - bias = torch.nn.Parameter(bias) - self.register_parameter(self.bias_names[i], bias, - init_fn=init_method_constant(0.0)) - else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) - setattr(self, self.bias_names[i], bias) + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) - # Concatenated tensors are not needed if not splitting - # into multiple parameters - if not is_subview: - del self.weight_tensor - del self.bias_tensor + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) if self.primary_weights_in_fp8: self.init_fp8_metadata() @@ -1150,24 +1155,15 @@ def forward( "Need to run inside fp8_autocast region when weights are stored in FP8." 
# Get concatenated weight and bias tensors - if len(self.parameter_split_sizes) == 1: - weight_tensor = getattr(self, self.weight_names[0]) - bias_tensor = getattr(self, self.bias_names[0]) - elif torch.is_grad_enabled(): - weight_tensor = _noop_cat( - [getattr(self, name) for name in self.weight_names], - self.weight_tensor, + weight_tensor = _noop_cat( + [getattr(self, name) for name in self.weight_names], + ) + if self.use_bias: + bias_tensor = _noop_cat( + [getattr(self, name) for name in self.bias_names], ) - if self.use_bias: - bias_tensor = _noop_cat( - [getattr(self, name) for name in self.bias_names], - self.bias_tensor, - ) - else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused else: - weight_tensor = self.weight_tensor - bias_tensor = self.bias_tensor + bias_tensor = getattr(self, self.bias_names[0]) # Unused # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cb2f6871b3..b48987f34c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -777,14 +777,20 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - self.weight_tensor = torch.empty( - self.out_features, self.in_features, - device=device, dtype=params_dtype) - + # Contiguous buffers for params + weight_tensor = torch.empty( + self.out_features, + self.in_features, + device=device, + dtype=params_dtype, + ) + bias_tensor = None if self.use_bias: - self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) - else: - self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) + bias_tensor = torch.empty( + self.out_features, + device=device, + dtype=params_dtype, + ) # Configure parameter splits self.weight_names = [] @@ -830,7 +836,11 @@ def __init__( ) self.parameter_split_sizes[i] = size // self.tp_size - # Construct parameters from weight and bias buffers + # Construct weight parameters + # Note: Register weights together so that they are adjacent to + # each other in Linear.parameters(). This makes it more likely + # that they will stay contiguous if the weights are + # manipulated externally, e.g. by FSDP. 
offset = 0 for i, split_size in enumerate(self.parameter_split_sizes): split_start = offset @@ -846,32 +856,30 @@ def __init__( ) # Construct weight parameter - weight = self.weight_tensor - if is_subview: - weight = weight[split_start:split_end] - weight = torch.nn.Parameter(weight) - self.register_parameter(self.weight_names[i], weight, - init_fn=init_method, - get_rng_state_tracker=get_rng_state_tracker, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) - - # Construct bias parameter if needed - if self.use_bias: - bias = self.bias_tensor - if is_subview: - bias = bias[split_start:split_end] - bias = torch.nn.Parameter(bias) - self.register_parameter(self.bias_names[i], bias, - init_fn=init_method_constant(0.0)) - else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) - setattr(self, self.bias_names[i], bias) + self.register_parameter( + self.weight_names[i], + torch.nn.Parameter(weight_tensor[split_start:split_end]), + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) - # Concatenated tensors are not needed if not splitting - # into multiple parameters - if not is_subview: - del self.weight_tensor - del self.bias_tensor + # Construct bias parameters if needed + if self.use_bias: + offset = 0 + for i, split_size in enumerate(self.parameter_split_sizes): + split_start = offset + offset += split_size + split_end = offset + self.register_parameter( + self.bias_names[i], + torch.nn.Parameter(bias_tensor[split_start:split_end]), + init_fn=init_method_constant(0.0), + ) + else: + for name in self.bias_names: + bias = torch.Tensor().to(dtype=params_dtype, device=device) + setattr(self, name, bias) if self.primary_weights_in_fp8: self.init_fp8_metadata() @@ -974,24 +982,15 @@ def forward( is_first_module_in_mha = is_first_module_in_mha and self.fp8_meta["recipe"].fp8_mha # Get concatenated weight and bias tensors - if len(self.parameter_split_sizes) == 1: - weight_tensor = getattr(self, self.weight_names[0]) - bias_tensor = getattr(self, self.bias_names[0]) - elif torch.is_grad_enabled(): - weight_tensor = _noop_cat( - [getattr(self, name) for name in self.weight_names], - self.weight_tensor, + weight_tensor = _noop_cat( + [getattr(self, name) for name in self.weight_names], + ) + if self.use_bias: + bias_tensor = _noop_cat( + [getattr(self, name) for name in self.bias_names], ) - if self.use_bias: - bias_tensor = _noop_cat( - [getattr(self, name) for name in self.bias_names], - self.bias_tensor, - ) - else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused else: - weight_tensor = self.weight_tensor - bias_tensor = self.bias_tensor + bias_tensor = getattr(self, self.bias_names[0]) # Unused # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
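The core mechanism in this change is the reworked _noop_cat in transformer_engine/pytorch/module/_common.py: rather than requiring the caller to pass a pre-allocated full_tensor, it inspects the input tensors themselves and returns a zero-copy view only when they already sit back-to-back in one buffer (same dtype, device, and strides, with adjacent data pointers), falling back to a regular torch.cat otherwise. Below is a minimal, self-contained sketch of that adjacency check, for illustration only; can_noop_cat is a hypothetical helper, not part of the patch.

    import torch

    def can_noop_cat(tensors, dim=0):
        """Sketch of the check _NoopCatFunc performs: are `tensors`
        adjacent slices of a single buffer along `dim`?"""
        first = tensors[0]
        byte_stride = first.stride(dim) * first.element_size()
        expected_ptr = first.data_ptr() + first.size(dim) * byte_stride
        for t in tensors[1:]:
            if (t.dtype != first.dtype or t.device != first.device
                    or t.stride() != first.stride()
                    or t.data_ptr() != expected_ptr):
                return False  # an out-of-place torch.cat would be needed
            expected_ptr += t.size(dim) * byte_stride
        return True

    buf = torch.empty(6, 4)
    q, k, v = buf[0:2], buf[2:4], buf[4:6]   # adjacent views of one buffer
    print(can_noop_cat([q, k, v]))           # True: a view of buf suffices
    print(can_noop_cat([q, v]))              # False: falls back to torch.cat

This is why linear.py and layernorm_linear.py now register the split weight parameters together as slices of one contiguous buffer: as long as nothing reallocates them, the forward-pass _noop_cat usually resolves to a view rather than a copy, without the modules having to keep the fused weight tensor around.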