Support store_param_remainders feature from Apex in TE Fused Adam #1408

Merged (29 commits) on Jan 31, 2025
Changes from 6 commits
Commits
9072c5f
Initial commit
Jan 13, 2025
7d5d0dc
Fixed compilation errors
Jan 13, 2025
bcaf16d
Merge branch 'main' into param_remainder
sanandaraj5597 Jan 13, 2025
979a4c1
Fixed syntax errors
Jan 16, 2025
75b737c
Merge branch 'param_remainder' of https://github.com/sanandaraj5597/T…
Jan 16, 2025
887432d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2025
1f51592
Fixed NaN issue when initial param value is zero
Jan 30, 2025
bf1a07b
Merge branch 'param_remainder' of https://github.com/sanandaraj5597/T…
Jan 30, 2025
d662d78
Removed 64 bit indexing instantiation
Jan 30, 2025
fa28b75
Made this feature an opt-in
Jan 30, 2025
0dd4b79
Removed arg from unscaled state
Jan 30, 2025
0a257f3
Fixed compilation error
Jan 30, 2025
98cf4fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2025
334e8cf
Cleaned up errors
Jan 30, 2025
a4e46c9
Cleaned up errors
Jan 30, 2025
0feb6f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2025
07f1479
Added support for checkpointing
Jan 30, 2025
5a4a91d
Merge branch 'param_remainder' of https://github.com/sanandaraj5597/T…
Jan 30, 2025
251b5ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2025
1d917cd
Fixed checkpointing logic
Jan 30, 2025
e10f6c3
Added tests
Jan 30, 2025
30499d5
Fixed merge conflicts
Jan 30, 2025
5c202ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2025
123ea3f
Added assert failure for capturable mode
Jan 30, 2025
4d90bcc
Merge branch 'param_remainder' of https://github.com/sanandaraj5597/T…
Jan 30, 2025
ec76525
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2025
7a144ec
Merge branch 'main' into param_remainder
timmoon10 Jan 30, 2025
d08b814
Fixed pylint errors
Jan 31, 2025
f6cc32b
Merge branch 'param_remainder' of https://github.com/sanandaraj5597/T…
Jan 31, 2025
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -479,6 +479,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const int step, const int mode, const int bias_correction,
const float weight_decay);

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay);

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
@@ -179,6 +179,122 @@ struct AdamFunctorMaster {
}
};

template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctorMasterParamRemainder {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<5> &tl, // NOLINT(*)
const float beta1, const float beta2,
const float beta1_correction,
const float beta2_correction, const float epsilon,
const float lr, adamMode_t mode, const float decay) {
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;

int16_t *p = reinterpret_cast<int16_t *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;

FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;

FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;

int16_t *p_remainder = reinterpret_cast<int16_t *>(tl.addresses[4][tensor_loc]);
p_remainder += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
union fp32_or_int162 {
float fp32;
int16_t int16[2];
};
fp32_or_int162 local_master_param[ILP];
int16_t local_p[ILP];
int16_t local_p_rem[ILP];
MATH_T r_g[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = static_cast<MATH_T>(g[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);

local_p[ii] = static_cast<int16_t>(p[i]);
local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
} else {
r_g[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);

local_p[ii] = int16_t(0);
local_p_rem[ii] = int16_t(0);
}
}
// Reconstruct FP32 params
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (local_p_rem[ii] < 0) local_p[ii]--; // Undo rounding
local_master_param[ii].int16[1] = local_p[ii];
local_master_param[ii].int16[0] = local_p_rem[ii];
}

MATH_T *r_p = reinterpret_cast<MATH_T *>(local_master_param);

#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}

// Split into BF16 params (rounded-to-nearest) and remainders
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
local_p[ii] = local_master_param[ii].int16[1];
local_p_rem[ii] = local_master_param[ii].int16[0];
if (local_p_rem[ii] < 0) local_p[ii]++; // Round up
}

#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
p[i] = static_cast<int16_t>(local_p[ii]);

m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
}
};
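
For reference, the packing scheme used by this functor can be reproduced outside the kernel. The sketch below is illustrative only (it is not part of this diff) and assumes little-endian storage, so int16[1] is the high half of the FP32 word; it shows how an FP32 master value splits into a round-to-nearest BF16 bit pattern plus a signed int16 remainder, and how the exact FP32 value is rebuilt by undoing that rounding, mirroring the "Undo rounding" and "Round up" steps above.

import torch

def split_master(master_fp32: torch.Tensor):
    """Split FP32 masters into BF16 bit patterns (round-to-nearest) and int16 remainders."""
    bits = master_fp32.contiguous().view(torch.int32)
    hi = (bits >> 16).to(torch.int16)  # high 16 bits: truncated BF16 pattern
    lo = bits & 0xFFFF  # low 16 bits as a non-negative int32
    lo = torch.where(lo >= 1 << 15, lo - (1 << 16), lo).to(torch.int16)  # reinterpret as signed int16
    bf16_bits = torch.where(lo < 0, hi + 1, hi)  # negative remainder: round the BF16 value up
    return bf16_bits, lo

def merge_master(bf16_bits: torch.Tensor, remainder: torch.Tensor):
    """Rebuild the exact FP32 masters from BF16 bit patterns plus int16 remainders."""
    hi = torch.where(remainder < 0, bf16_bits - 1, bf16_bits)  # undo the rounding applied at split time
    bits = (hi.to(torch.int32) << 16) | (remainder.to(torch.int32) & 0xFFFF)
    return bits.view(torch.float32)

A round trip merge_master(*split_master(x)) reproduces x bit-for-bit for ordinary values; the int16 increment can wrap for the largest bit patterns, which mirrors the kernel's int16 arithmetic.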

template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
@@ -548,6 +664,69 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
AT_CUDA_CHECK(cudaGetLastError());
}

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace at;

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}

size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
}
}
}
if (requires_64bit_indexing) {
break;
}
}

const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();

// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 5, "tensor list must contain 5");

if (requires_64bit_indexing) {
// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
} else {
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>(
BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
AT_CUDA_CHECK(cudaGetLastError());
}
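
This launcher is exposed to Python as multi_tensor_adam_param_remainder in pybind.cpp below, and FusedAdam reaches it through its apply_multi_tensor_adam helper. As a rough illustration of the expected layout (five inner lists, indexed as grads, BF16 params, exp_avg, exp_avg_sq, int16 remainders, one tensor per parameter), a direct call could look like the following sketch; the extension module name, the 2048 * 32 chunk size, and the toy tensors are assumptions rather than part of this PR.

import torch
import transformer_engine_torch as tex  # extension module name assumed

n = 1024
params_bf16 = torch.randn(n, device="cuda", dtype=torch.bfloat16)
grads = torch.randn_like(params_bf16)
exp_avg = torch.zeros(n, device="cuda", dtype=torch.float32)
exp_avg_sq = torch.zeros(n, device="cuda", dtype=torch.float32)
remainders = torch.zeros(n, device="cuda", dtype=torch.int16)  # zero remainder: master equals the BF16 value

# Outer index selects the argument position; inner lists hold one tensor per parameter.
tensor_lists = [[grads], [params_bf16], [exp_avg], [exp_avg_sq], [remainders]]
noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")  # a nonzero value skips the update

tex.multi_tensor_adam_param_remainder(
    2048 * 32,  # chunk size (assumed to match the fused optimizer's choice)
    noop_flag,
    tensor_lists,
    1e-3, 0.9, 0.999, 1e-8,  # lr, beta1, beta2, epsilon
    1,    # step, used for bias correction
    0,    # mode: 0 = L2 regularization, 1 = decoupled weight decay
    1,    # bias_correction enabled
    0.0,  # weight_decay
)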

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -213,6 +213,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"
"where the master parameters only store the remainder bits",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
59 changes: 49 additions & 10 deletions transformer_engine/pytorch/optimizers/fused_adam.py
@@ -94,6 +94,13 @@ class FusedAdam(torch.optim.Optimizer):
instead of ".grad" for reading gradients. It's useful when the dtypes
of grad and param are different.
(default: False)
store_param_remainders (bool, optional): Whether to store entire FP32 master
params or only the trailing 16 remainder bits. The full FP32 master param can be
reconstructed from the BF16 param plus the trailing remainder bits. Takes effect
only when the param dtype is BF16 and the master weight dtype is FP32; it has no
effect otherwise. Useful memory-saving optimization.
(default: True)


.. _Adam - A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
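
With the flag documented above, enabling the feature is a one-line change on the user side. A minimal sketch, assuming a CUDA build of Transformer Engine, the import path shown, and the master_weights argument referenced later in this file:

import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # import path assumed

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
optimizer = FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,          # keep FP32 master weights for the BF16 params
    store_param_remainders=True,  # store only the trailing int16 bits of each master weight
)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()

Since the BF16 params already carry the high 16 bits, each FP32 master weight shrinks from 4 bytes of optimizer state to 2 bytes, while the full FP32 value is still reconstructed exactly for the Adam update.
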
@@ -118,6 +125,7 @@ def __init__(
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
use_decoupled_grad=False,
store_param_remainders=True,
):

if amsgrad:
@@ -172,6 +180,7 @@ def __init__(
# Skip buffer
self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda")
self.multi_tensor_adam = tex.multi_tensor_adam
self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder
self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8
self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable
self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master
@@ -192,6 +201,10 @@ def __init__(
}
self._scales = {}
self.use_decoupled_grad = use_decoupled_grad
# Takes effect only when the master params are stored in FP32
self.store_param_remainders = (
store_param_remainders and master_weights and master_weight_dtype == torch.float32
)

def zero_grad(self):
# pylint: disable=missing-function-docstring
@@ -243,13 +256,14 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
unscaled_state.mul_(rscale)
scaled_state.copy_(unscaled_state)

def get_unscaled_state(self, param, state_name):
def get_unscaled_state(self, param, state_name, store_param_remainders=False):

Review thread on the store_param_remainders default:

Reviewer: The default value of store_param_remainders is False here, but it's True by default in the constructor. I think that's misleading; why not set it to True here as well?

sanandaraj5597 (author): I don't want to store param remainders for any state_name other than master_param, which is why it defaults to False.

Collaborator: I'd prefer if this function didn't expose this kwarg, since it makes its behavior less obvious. get_unscaled_state implies that it produces an FP32 value that is ready to use, so it would be better if step called a different function to access the BF16 remainder. If we want to keep this overall logic, we should change the function name to something more accurate (although a vague name like get_state_for_adam_kernel is a code smell).

sanandaraj5597 (author): Worked around it. Resolving conversation.

Collaborator: It's better, although we still have the problem that state scaling and BF16 remainders both use this function in different ways. It's troubling that get_unscaled_state might not get the unscaled state.

sanandaraj5597 (author, Jan 30, 2025): I agree with your point, but using this function both with and without the feature keeps it efficient; a separate function would need new call sites across the step function, checkpointing, etc. I've also added assert checks inside the function to tighten the understanding and correctness. Hope you are fine with that.

"""Return the unscaled state corresponding to the input `param` and `state_name`.

Arguments:
param (torch.nn.Parameter): One of parameters in this optimizer.
state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
and 'master_param`.
store_param_remainders (bool): Whether the 'master_param' state holds only the trailing int16 remainder bits (returned as-is)
"""
state = self.state[param]
dtype = self.name_to_dtype_map[state_name]
@@ -261,7 +275,10 @@ def get_unscaled_state(self, param, state_name):
unscaled = state[state_name].float()
unscaled.mul_(self._scales[param][state_name])
elif dtype == torch.float32:
assert state[state_name].dtype == torch.float32
if store_param_remainders and state_name == "master_param":
assert state[state_name].dtype == torch.int16
else:
assert state[state_name].dtype == torch.float32
unscaled = state[state_name]
else:
raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
@@ -291,17 +308,23 @@ def set_scaled_state(self, param, state_name, unscaled_state):
else:
state[state_name].copy_(unscaled_state)

def _initialize_state(self, param, state_name, zero_buffer: bool):
def _initialize_state(
self, param, state_name, zero_buffer: bool, store_param_remainders: bool = False
):
"""Initialize one of the optimizer states according to `state_name`.

Arguments:
param (torch.nn.Parameter): One of parameters in this optimizer.
state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
and 'master_param`.
zero_buffer (bool): Whether to initialize the optimizer state with zeros.
store_param_remainders (bool): Store only trailing remainder bits.
"""
dtype = self.name_to_dtype_map[state_name]
data = torch.empty_like(param, dtype=dtype)
if store_param_remainders:
data = torch.empty_like(param, dtype=torch.int16)
else:
data = torch.empty_like(param, dtype=dtype)
if zero_buffer:
data.zero_()

@@ -322,17 +345,24 @@ def _initialize_state(self, param, state_name, zero_buffer: bool):
[1], dtype=torch.float32, device=param.device
)

def initialize_state(self, param):
def initialize_state(self, param, store_param_remainders):
"""Initialize optimizer states.

Arguments:
param (torch.nn.Parameter): One of parameters in this optimizer.
store_param_remainders (bool): Store trailing remainder bits.
"""
self._initialize_state(param, "exp_avg", zero_buffer=True)
self._initialize_state(param, "exp_avg_sq", zero_buffer=True)
if self.master_weights:
self._initialize_state(param, "master_param", zero_buffer=False)
self.set_scaled_state(param, "master_param", param.clone().detach().float())
self._initialize_state(
param,
"master_param",
zero_buffer=False,
store_param_remainders=store_param_remainders,
)
if not store_param_remainders:
self.set_scaled_state(param, "master_param", param.clone().detach().float())

def state_dict(self):
"""Override the state_dict() of pytorch. Before returning the state_dict, cast all
@@ -444,9 +474,11 @@ def step(self, closure=None, grad_scaler=None):
for p in group["params"]:
state = self.state[p]

store_param_remainders = self.store_param_remainders and p.dtype == torch.bfloat16

# State initialization
if len(state) == 0:
self.initialize_state(p)
self.initialize_state(p, store_param_remainders)

if self.use_decoupled_grad:
p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None
@@ -462,7 +494,7 @@ def step(self, closure=None, grad_scaler=None):
unscaled_state = {}
for name in ["exp_avg", "exp_avg_sq", "master_param"]:
if name in state:
unscaled = self.get_unscaled_state(p, name)
unscaled = self.get_unscaled_state(p, name, store_param_remainders)
unscaled_state[name] = unscaled
if self.name_to_dtype_map[name] != torch.float32:
unscaled_lists[name].append(unscaled)
@@ -599,7 +631,14 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N
v_of_f16_model,
p_main_of_f16_model,
]
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
if self.store_param_remainders and has_bf16 and not has_fp16:
# When you have BF16 params and need FP32 master params, you can reconstruct
# the FP32 master params with BF16 params + int16 remainders
apply_multi_tensor_adam(
self.multi_tensor_adam_param_remainder, tensor_lists
)
else:
apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists)
if len(p_fp8_model) > 0:
tensor_lists = [
g_of_fp8_model,