diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py
index be01f2c011..96acb699ad 100644
--- a/tests/pytorch/test_fused_optimizer.py
+++ b/tests/pytorch/test_fused_optimizer.py
@@ -184,6 +184,7 @@ def gen_precision_aware_test(
         grad_dtype,
         exp_avg_dtype,
         exp_avg_sq_dtype,
+        store_param_remainders=False,
         model_rtol=None,
         model_atol=None,
         master_rtol=None,
@@ -220,6 +221,7 @@ def gen_precision_aware_test(
             "weight_decay": 0,
             "amsgrad": False,
         }
+
        ref_optim = torch.optim.Adam(ref_params, **options)
        tst_optim = te.optimizers.FusedAdam(
            model_params,
@@ -228,6 +230,7 @@ def gen_precision_aware_test(
            exp_avg_dtype=exp_avg_dtype,
            exp_avg_sq_dtype=exp_avg_sq_dtype,
            use_decoupled_grad=True,
+           store_param_remainders=store_param_remainders,
            **options,
        )

@@ -237,7 +240,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
                p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
            ref_optimizer.step()
            tst_optimizer.step()
-           if use_master_weights:
+           if use_master_weights and not store_param_remainders:
                master_weights_to_fp32 = [
                    tst_optim.get_unscaled_state(p, "master_param") for p in model_params
                ]
@@ -270,6 +273,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
            exp_avg_dtype=exp_avg_dtype,
            exp_avg_sq_dtype=exp_avg_sq_dtype,
            use_decoupled_grad=True,
+           store_param_remainders=store_param_remainders,
            **options,
        )
        tst_optim.load_state_dict(state_dict)
@@ -300,6 +304,19 @@ def test_fp32_master(self):
            exp_avg_sq_dtype=torch.float32,
        )

+   @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+   def test_fp32_master_store_param_remainders(self):
+       self.gen_precision_aware_test(
+           use_fp8_params=False,
+           param_dtype=torch.bfloat16,
+           use_master_weights=True,
+           master_weight_dtype=torch.float32,
+           grad_dtype=torch.float32,
+           exp_avg_dtype=torch.float32,
+           exp_avg_sq_dtype=torch.float32,
+           store_param_remainders=True,
+       )
+
    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
    def test_fp16_master(self):
        self.gen_precision_aware_test(
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 67fd1caf5b..58527ef6d5 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -479,6 +479,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                             const int step, const int mode, const int bias_correction,
                             const float weight_decay);

+void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
+                                            std::vector<std::vector<at::Tensor>> tensor_lists,
+                                            const float lr, const float beta1, const float beta2,
+                                            const float epsilon, const int step, const int mode,
+                                            const int bias_correction, const float weight_decay);
+
 void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                 std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                 const float beta1, const float beta2, const float epsilon,
diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
index cb5e878fb2..548dd5a267 100644
--- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
@@ -179,6 +179,122 @@ struct AdamFunctorMaster {
   }
 };

+template <typename GRAD_T, typename FULL_T, typename MATH_T, typename index_t>
+struct AdamFunctorMasterParamRemainder {
+  __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
+                                             TensorListMetadata<5> &tl,  // NOLINT(*)
+                                             const float beta1, const float beta2,
+                                             const float beta1_correction,
+                                             const float beta2_correction, const float epsilon,
+                                             const float lr, adamMode_t mode, const float decay) {
+    index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
+
+    index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
+    index_t n = tl.sizes[tensor_loc];
+
+    GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
+    g += chunk_idx * chunk_size;
+
+    int16_t *p = reinterpret_cast<int16_t *>(tl.addresses[1][tensor_loc]);
+    p += chunk_idx * chunk_size;
+
+    FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
+    m += chunk_idx * chunk_size;
+
+    FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
+    v += chunk_idx * chunk_size;
+
+    int16_t *p_remainder = reinterpret_cast<int16_t *>(tl.addresses[4][tensor_loc]);
+    p_remainder += chunk_idx * chunk_size;
+
+    n -= chunk_idx * chunk_size;
+
+    // see note in multi_tensor_scale_kernel.cu
+    for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
+      union fp32_or_int162 {
+        float fp32;
+        int16_t int16[2];
+      };
+      fp32_or_int162 local_master_param[ILP];
+      int16_t local_p[ILP];
+      int16_t local_p_rem[ILP];
+      MATH_T r_g[ILP];
+      MATH_T r_m[ILP];
+      MATH_T r_v[ILP];
+#pragma unroll
+      for (int ii = 0; ii < ILP; ii++) {
+        int i = i_start + threadIdx.x + ii * blockDim.x;
+        if (i < n && i < chunk_size) {
+          r_g[ii] = static_cast<MATH_T>(g[i]);
+          r_m[ii] = static_cast<MATH_T>(m[i]);
+          r_v[ii] = static_cast<MATH_T>(v[i]);
+
+          local_p[ii] = static_cast<int16_t>(p[i]);
+          local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
+        } else {
+          r_g[ii] = MATH_T(0);
+          r_m[ii] = MATH_T(0);
+          r_v[ii] = MATH_T(0);
+
+          local_p[ii] = int16_t(0);
+          local_p_rem[ii] = int16_t(0);
+        }
+      }
+// Reconstruct FP32 params
+#pragma unroll
+      for (int ii = 0; ii < ILP; ii++) {
+        if (local_p_rem[ii] < 0) local_p[ii]--;  // Undo rounding
+        local_master_param[ii].int16[1] = local_p[ii];
+        local_master_param[ii].int16[0] = local_p_rem[ii];
+      }
+
+      MATH_T *r_p = reinterpret_cast<MATH_T *>(local_master_param);
+
+#pragma unroll
+      for (int ii = 0; ii < ILP; ii++) {
+        if (mode == ADAM_MODE_0) {  // L2
+          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
+          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
+          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
+          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
+          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
+          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
+          MATH_T update = next_m_unbiased / denom;
+          r_p[ii] = r_p[ii] - (lr * update);
+        } else {  // weight decay
+          r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
+          r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
+          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
+          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
+          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
+          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
+          r_p[ii] = r_p[ii] - (lr * update);
+        }
+      }
+
+// Split into BF16 params (rounded-to-nearest) and remainders
+#pragma unroll
+      for (int ii = 0; ii < ILP; ii++) {
+        local_p[ii] = local_master_param[ii].int16[1];
+        local_p_rem[ii] = local_master_param[ii].int16[0];
+        if (local_p_rem[ii] < 0) local_p[ii]++;  // Round up
+      }
+
+#pragma unroll
+      for (int ii = 0; ii < ILP; ii++) {
+        int i = i_start + threadIdx.x + ii * blockDim.x;
+        if (i < n && i < chunk_size) {
+          p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
+          p[i] = static_cast<int16_t>(local_p[ii]);
+
+          m[i] = static_cast<FULL_T>(r_m[ii]);
+          v[i] = static_cast<FULL_T>(r_v[ii]);
+        }
+      }
+    }
+  }
+};
+
 template <typename T, typename FULL_T, typename index_t>
 struct AdamFunctor {
   __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
@@ -548,6 +664,42 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
   AT_CUDA_CHECK(cudaGetLastError());
 }

+void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
+                                            std::vector<std::vector<at::Tensor>> tensor_lists,
+                                            const float lr, const float beta1, const float beta2,
+                                            const float epsilon, const int step, const int mode,
+                                            const int bias_correction, const float weight_decay) {
+  using namespace at;
+
+  // Handle bias correction mode
+  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
+  if (bias_correction == 1) {
+    bias_correction1 = 1 - std::pow(beta1, step);
+    bias_correction2 = 1 - std::pow(beta2, step);
+  }
+
+  const auto g_in_type = tensor_lists[0][0].scalar_type();
+  const auto p_in_type = tensor_lists[1][0].scalar_type();
+  auto tl_size = tensor_lists.size();
+
+  // case 5: g, p, m, v, p_master
+  TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
+  TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
+              "Adam with BF16 param remainders requires BF16 params");
+
+  // g, p, m, v, p_master
+  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+      p_in_type, 0, "adam",
+      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+          g_in_type, 1, "adam",
+          multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                                AdamFunctorMasterParamRemainder<scalar_t_1, float, float, int64_t>(),
+                                beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
+                                (adamMode_t)mode, weight_decay);));
+
+  AT_CUDA_CHECK(cudaGetLastError());
+}
+
 void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                 std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                 const float beta1, const float beta2, const float epsilon,
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 165855d430..e5d8744eef 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -213,6 +213,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer",
         py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
+        "Compute and apply gradient update to parameters for Adam optimizer "
+        "where the master parameters only store the remainder bits",
+        py::call_guard<py::gil_scoped_release>());
   m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer",
         py::call_guard<py::gil_scoped_release>());
diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index 170c95442f..b86c973304 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -94,6 +94,13 @@ class FusedAdam(torch.optim.Optimizer):
             instead of ".grad" for reading gradients. It's useful when the dtypes of grad
             and param are different.
             (default: False)
+        store_param_remainders (bool, optional): Whether to store the entire FP32 master
+            params or only the trailing 16 remainder bits. The full FP32 master params can
+            be reconstructed from the BF16 params plus the trailing remainder bits. Takes
+            effect only when the param dtype is BF16 and the master weight dtype is FP32;
+            it has no effect otherwise. A useful memory-saving optimization.
+            (default: False)
+
     .. _Adam - A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
@@ -118,6 +125,7 @@ def __init__(
         exp_avg_dtype=torch.float32,
         exp_avg_sq_dtype=torch.float32,
         use_decoupled_grad=False,
+        store_param_remainders=False,
     ):

         if amsgrad:
@@ -142,6 +150,8 @@ def __init__(
             raise RuntimeError("Capturable mode only supports fp32 exp_avg.")
         if capturable and exp_avg_sq_dtype != torch.float32:
             raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq")
+        if capturable and store_param_remainders:
+            raise RuntimeError("Capturable mode doesn't support storing param remainders")

         # If the optimizer is capturable then LR should be a tensor (on GPU)
         lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr
@@ -172,6 +182,7 @@ def __init__(
         # Skip buffer
         self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
         self.multi_tensor_adam = tex.multi_tensor_adam
+        self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder
         self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8
         self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
         self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
@@ -192,6 +203,10 @@ def __init__(
         }
         self._scales = {}
         self.use_decoupled_grad = use_decoupled_grad
+        # Works only when master params are in FP32
+        self.store_param_remainders = (
+            store_param_remainders and master_weights and master_weight_dtype == torch.float32
+        )

     def zero_grad(self):
         # pylint: disable=missing-function-docstring
@@ -261,7 +276,14 @@ def get_unscaled_state(self, param, state_name):
             unscaled = state[state_name].float()
             unscaled.mul_(self._scales[param][state_name])
         elif dtype == torch.float32:
-            assert state[state_name].dtype == torch.float32
+            if (
+                self.store_param_remainders
+                and state_name == "master_param"
+                and param.dtype == torch.bfloat16
+            ):
+                assert state[state_name].dtype == torch.int16
+            else:
+                assert state[state_name].dtype == torch.float32
             unscaled = state[state_name]
         else:
             raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
@@ -279,10 +301,19 @@ def set_scaled_state(self, param, state_name, unscaled_state):
             and 'master_param`.
             unscaled_state (torch.Tensor): The original high-precision(FP32) state.
         """
-        assert unscaled_state.dtype == torch.float32
+        store_param_remainders = (
+            self.store_param_remainders
+            and state_name == "master_param"
+            and param.dtype == torch.bfloat16
+        )
+
+        if store_param_remainders:
+            assert unscaled_state.dtype == torch.int16
+        else:
+            assert unscaled_state.dtype == torch.float32
         state = self.state[param]
         if state_name not in state:
-            self._initialize_state(param, state_name, False)
+            self._initialize_state(param, state_name, False, store_param_remainders)
         dtype = self.name_to_dtype_map[state_name]
         if dtype != torch.float32:
@@ -291,7 +322,9 @@ def set_scaled_state(self, param, state_name, unscaled_state):
         else:
             state[state_name].copy_(unscaled_state)

-    def _initialize_state(self, param, state_name, zero_buffer: bool):
+    def _initialize_state(
+        self, param, state_name, zero_buffer: bool, store_param_remainders: bool = False
+    ):
         """Initialize one of the optimizer states according to `state_name`.

         Arguments:
@@ -299,9 +332,13 @@ def _initialize_state(self, param, state_name, zero_buffer: bool):
             param (torch.nn.Parameter): One of parameters in this optimizer.
             state_name (string): Name of optimizer states, can be one of 'exp_avg',
                 'exp_avg_sq', and 'master_param`.
             zero_buffer (bool): Whether to initialize the optimizer state with zeros.
+            store_param_remainders (bool): Store only trailing remainder bits.
""" dtype = self.name_to_dtype_map[state_name] - data = torch.empty_like(param, dtype=dtype) + if store_param_remainders: + data = torch.zeros_like(param, dtype=torch.int16) + else: + data = torch.empty_like(param, dtype=dtype) if zero_buffer: data.zero_() @@ -322,17 +359,24 @@ def _initialize_state(self, param, state_name, zero_buffer: bool): [1], dtype=torch.float32, device=param.device ) - def initialize_state(self, param): + def initialize_state(self, param, store_param_remainders): """Initialize optimizer states. Arguments: param (torch.nn.Parameter): One of parameters in this optimizer. + store_param_remainders (bool): Store trailing remainder bits. """ self._initialize_state(param, "exp_avg", zero_buffer=True) self._initialize_state(param, "exp_avg_sq", zero_buffer=True) if self.master_weights: - self._initialize_state(param, "master_param", zero_buffer=False) - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + self._initialize_state( + param, + "master_param", + zero_buffer=False, + store_param_remainders=store_param_remainders, + ) + if not store_param_remainders: + self.set_scaled_state(param, "master_param", param.clone().detach().float()) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all @@ -377,7 +421,15 @@ def load_state_dict(self, state_dict): param = id_map[k] self.state[param] = {} for name in v: - self.set_scaled_state(param, name, v[name].float()) + if ( + self.store_param_remainders + and name == "master_param" + and param.dtype == torch.bfloat16 + ): + self.set_scaled_state(param, name, v[name]) + assert v[name].dtype == torch.int16 + else: + self.set_scaled_state(param, name, v[name].float()) def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. @@ -444,9 +496,11 @@ def step(self, closure=None, grad_scaler=None): for p in group["params"]: state = self.state[p] + store_param_remainders = self.store_param_remainders and p.dtype == torch.bfloat16 + # State initialization if len(state) == 0: - self.initialize_state(p) + self.initialize_state(p, store_param_remainders) if self.use_decoupled_grad: p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None @@ -462,8 +516,12 @@ def step(self, closure=None, grad_scaler=None): unscaled_state = {} for name in ["exp_avg", "exp_avg_sq", "master_param"]: if name in state: - unscaled = self.get_unscaled_state(p, name) - unscaled_state[name] = unscaled + if name == "master_param" and store_param_remainders: + unscaled_state[name] = self.state[p][name] + assert unscaled_state[name].dtype == torch.int16 + else: + unscaled = self.get_unscaled_state(p, name) + unscaled_state[name] = unscaled if self.name_to_dtype_map[name] != torch.float32: unscaled_lists[name].append(unscaled) scaled_lists[name].append(state[name]) @@ -506,6 +564,12 @@ def step(self, closure=None, grad_scaler=None): ) if has_fp16 and has_bf16: + if self.store_param_remainders: + raise RuntimeError( + "FusedAdam doesn't support a mix of FP16/BF16 weights + Store param" + " remainder." + ) + # simple to add support for this, but not needed for now raise RuntimeError( "FusedAdam does not support a mix of float16 and bfloat16 model weights." 
@@ -599,7 +663,14 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N
                     v_of_f16_model,
                     p_main_of_f16_model,
                 ]
-                apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
+                if self.store_param_remainders and has_bf16 and not has_fp16:
+                    # When you have BF16 params and need FP32 master params, you can reconstruct
+                    # the FP32 master params with BF16 params + int16 remainders
+                    apply_multi_tensor_adam(
+                        self.multi_tensor_adam_param_remainder, tensor_lists
+                    )
+                else:
+                    apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
            if len(p_fp8_model) > 0:
                tensor_lists = [
                    g_of_fp8_model,
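
Note on the remainder trick used by AdamFunctorMasterParamRemainder: the upper 16 bits of an FP32 value are (up to the rounding carry handled by the kernel's +/-1 correction) the bit pattern of its BF16 counterpart, so the optimizer only needs to persist the lower 16 bits next to the BF16 param. The snippet below is an illustrative pure-Python model of that split/reconstruct round trip, assuming a little-endian host and PyTorch's dtype-reinterpreting Tensor.view; it is not code from this patch.

import torch

master = torch.randn(8, dtype=torch.float32)

# Split each FP32 value into two int16 halves (little-endian: [..., 0] = low bits,
# [..., 1] = high bits; the high 16 bits of an FP32 value form a BF16 bit pattern).
halves = master.view(torch.int16).view(-1, 2)
remainder = halves[:, 0].clone()   # low 16 bits -> what the optimizer stores as "master_param"
high = halves[:, 1].clone()        # high 16 bits -> truncated BF16 pattern

# Rounding step: when the low half is negative as int16 (its MSB is set), bump the
# high half, mirroring `if (local_p_rem[ii] < 0) local_p[ii]++;` in the kernel.
bf16_param = torch.where(remainder < 0, high + 1, high).view(torch.bfloat16)

# Reconstruction: undo the rounding bump, then re-interleave the halves,
# mirroring the "Reconstruct FP32 params" loop in the kernel.
p_high = bf16_param.view(torch.int16)
p_high = torch.where(remainder < 0, p_high - 1, p_high)
rebuilt = torch.stack([remainder, p_high], dim=1).reshape(-1).view(torch.float32)

assert torch.equal(rebuilt, master)   # bit-exact round trip

This is also where the memory saving comes from: one int16 remainder (2 bytes) per parameter replaces a full FP32 master copy (4 bytes), e.g. roughly 2 GB saved for a 1B-parameter BF16 model.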
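
A minimal usage sketch of the new flag, mirroring the added test above (illustrative only: the module, shapes, and hyperparameters are made up, and it assumes a CUDA device with a Transformer Engine build that contains this change):

import torch
import transformer_engine.pytorch as te

# BF16 model params; the FP32 master state is kept only as int16 remainders.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

optimizer = te.optimizers.FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,                # keep high-precision optimizer state
    master_weight_dtype=torch.float32,  # required for store_param_remainders
    store_param_remainders=True,        # int16 remainders instead of full FP32 copies
    use_decoupled_grad=True,            # gradients are read from p.decoupled_grad
)

# One fake step; in real training the gradients would come from a backward pass.
for p in model.parameters():
    p.decoupled_grad = torch.randn_like(p, dtype=torch.float32)
optimizer.step()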