Support store_param_remainders feature from Apex in TE Fused Adam (#1408)

* Initial commit

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fixed compilation errors

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fixed syntax errors

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed NaN issue when initial param value is zero

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Removed 64 bit indexing instantiation

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Made this feature an opt-in

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Removed arg from unscaled state

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Fixed compilation error

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleaned up errors

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added support for checkpointing

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed checkpointing logic

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* Added tests

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added assert failure for capturable mode

Signed-off-by: Selvaraj Anandaraj <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed pylint errors

Signed-off-by: Selvaraj Anandaraj <[email protected]>

---------

Signed-off-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: Selvaraj Anandaraj <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <[email protected]>
4 people authored Jan 31, 2025
1 parent 96534aa commit e536954
Showing 5 changed files with 264 additions and 14 deletions.
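
When parameters are stored in BF16 and the optimizer keeps FP32 master weights, the upper 16 bits of each FP32 master value match the BF16 parameter (up to rounding), so only the lower 16 "remainder" bits need to be kept to recover the exact FP32 value. This commit ports Apex's store_param_remainders optimization to TE's FusedAdam as an opt-in flag, roughly halving the memory spent on FP32 master weights. Below is a minimal usage sketch; it assumes BF16 parameters and the decoupled-grad path exercised by the test below, and the master_weights=True keyword is inferred from the test helper's use_master_weights argument rather than quoted from this diff.

import torch
import transformer_engine.pytorch as te

# Hedged usage sketch: BF16 params, FP32 optimizer state, and master weights
# kept as int16 remainders instead of full FP32 copies.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)

optim = te.optimizers.FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,            # assumption: named after the test helper's use_master_weights
    store_param_remainders=True,    # the new opt-in added by this commit
    use_decoupled_grad=True,        # gradients are read from p.decoupled_grad, as in the test
)

for p in model.parameters():
    p.decoupled_grad = torch.randn_like(p, dtype=torch.float32)
optim.step()
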
19 changes: 18 additions & 1 deletion tests/pytorch/test_fused_optimizer.py
@@ -184,6 +184,7 @@ def gen_precision_aware_test(
grad_dtype,
exp_avg_dtype,
exp_avg_sq_dtype,
store_param_remainders=False,
model_rtol=None,
model_atol=None,
master_rtol=None,
@@ -220,6 +221,7 @@ def gen_precision_aware_test(
"weight_decay": 0,
"amsgrad": False,
}

ref_optim = torch.optim.Adam(ref_params, **options)
tst_optim = te.optimizers.FusedAdam(
model_params,
@@ -228,6 +230,7 @@
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)

@@ -237,7 +240,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
p.decoupled_grad = p_ref.grad.clone().to(grad_dtype)
ref_optimizer.step()
tst_optimizer.step()
if use_master_weights:
if use_master_weights and not store_param_remainders:
master_weights_to_fp32 = [
tst_optim.get_unscaled_state(p, "master_param") for p in model_params
]
@@ -270,6 +273,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer):
exp_avg_dtype=exp_avg_dtype,
exp_avg_sq_dtype=exp_avg_sq_dtype,
use_decoupled_grad=True,
store_param_remainders=store_param_remainders,
**options,
)
tst_optim.load_state_dict(state_dict)
@@ -300,6 +304,19 @@ def test_fp32_master(self):
exp_avg_sq_dtype=torch.float32,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp32_master_store_param_remainders(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.float32,
store_param_remainders=True,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_fp16_master(self):
self.gen_precision_aware_test(
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -479,6 +479,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
const int step, const int mode, const int bias_correction,
const float weight_decay);

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay);

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
@@ -179,6 +179,122 @@ struct AdamFunctorMaster {
}
};

template <typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctorMasterParamRemainder {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<5> &tl, // NOLINT(*)
const float beta1, const float beta2,
const float beta1_correction,
const float beta2_correction, const float epsilon,
const float lr, adamMode_t mode, const float decay) {
index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
index_t n = tl.sizes[tensor_loc];

GRAD_T *g = reinterpret_cast<GRAD_T *>(tl.addresses[0][tensor_loc]);
g += chunk_idx * chunk_size;

int16_t *p = reinterpret_cast<int16_t *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;

FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;

FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;

int16_t *p_remainder = reinterpret_cast<int16_t *>(tl.addresses[4][tensor_loc]);
p_remainder += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

// see note in multi_tensor_scale_kernel.cu
for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
union fp32_or_int162 {
float fp32;
int16_t int16[2];
};
fp32_or_int162 local_master_param[ILP];
int16_t local_p[ILP];
int16_t local_p_rem[ILP];
MATH_T r_g[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = static_cast<MATH_T>(g[i]);
r_m[ii] = static_cast<MATH_T>(m[i]);
r_v[ii] = static_cast<MATH_T>(v[i]);

local_p[ii] = static_cast<int16_t>(p[i]);
local_p_rem[ii] = static_cast<int16_t>(p_remainder[i]);
} else {
r_g[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);

local_p[ii] = int16_t(0);
local_p_rem[ii] = int16_t(0);
}
}
// Reconstruct FP32 params
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (local_p_rem[ii] < 0) local_p[ii]--; // Undo rounding
local_master_param[ii].int16[1] = local_p[ii];
local_master_param[ii].int16[0] = local_p_rem[ii];
}

MATH_T *r_p = reinterpret_cast<MATH_T *>(local_master_param);

#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}

// Split into BF16 params (rounded-to-nearest) and remainders
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
local_p[ii] = local_master_param[ii].int16[1];
local_p_rem[ii] = local_master_param[ii].int16[0];
if (local_p_rem[ii] < 0) local_p[ii]++; // Round up
}

#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p_remainder[i] = static_cast<int16_t>(local_p_rem[ii]);
p[i] = static_cast<int16_t>(local_p[ii]);

m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
}
}
}
}
};
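
The functor above packs each FP32 master weight into its BF16 parameter (upper 16 bits, rounded to nearest) plus an int16 remainder (lower 16 bits), and undoes the rounding when it reassembles the FP32 value. The following is a minimal CPU sketch of the same bit manipulation in PyTorch; it assumes a little-endian layout (as the fp32_or_int162 union does) and 1-D tensors, and is an illustration only, not the optimizer's actual Python path.

import torch

def split_master(master_fp32):
    # Split FP32 master weights into BF16 params (round-to-nearest) and int16
    # remainders; mirrors the "Round up" step in the functor above.
    halves = master_fp32.view(torch.int16).reshape(-1, 2)
    rem = halves[:, 0].clone()           # low 16 bits
    top = halves[:, 1].clone()           # high 16 bits, i.e. the truncated BF16
    top[rem < 0] += 1                    # remainder's sign bit set: round the BF16 up
    return top.view(torch.bfloat16), rem

def merge_master(param_bf16, rem):
    # Reconstruct FP32 master weights; mirrors the "Undo rounding" step.
    top = param_bf16.view(torch.int16).clone()
    top[rem < 0] -= 1                    # undo the round-up
    return torch.stack((rem, top), dim=1).reshape(-1).view(torch.float32)

master = torch.randn(1024)               # FP32 master weights
p_bf16, p_rem = split_master(master)
assert torch.equal(merge_master(p_bf16, p_rem), master)   # lossless round-trip
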

template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
@@ -548,6 +664,42 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
AT_CUDA_CHECK(cudaGetLastError());
}

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace at;

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}

const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();

// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
"Adam with BF16 param remainders requires BF16 params");

// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));

AT_CUDA_CHECK(cudaGetLastError());
}
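
For reference, this entry point expects five tensor lists in the order checked above (grads, BF16 params, exp_avg, exp_avg_sq, int16 remainders), followed by the scalar hyperparameters. The sketch below calls the raw binding (registered as multi_tensor_adam_param_remainder in pybind.cpp further down) directly; the extension module name, the chunk size, and the single-tensor lists are assumptions for illustration, since FusedAdam drives this kernel internally and batches many parameters per list.

import torch
import transformer_engine_torch as tex   # assumption: importable name of the TORCH_EXTENSION_NAME module

n = 4096
g      = torch.randn(n, device="cuda")                      # grads (FP32 here; dtype is dispatched)
p_bf16 = torch.randn(n, device="cuda").bfloat16()           # BF16 params
m      = torch.zeros(n, device="cuda")                      # exp_avg
v      = torch.zeros(n, device="cuda")                      # exp_avg_sq
rem    = torch.zeros(n, dtype=torch.int16, device="cuda")   # low 16 bits of the FP32 master
noop   = torch.zeros(1, dtype=torch.int32, device="cuda")   # noop flag polled by the kernel

tex.multi_tensor_adam_param_remainder(
    65536,                             # chunk_size (assumed value)
    noop,
    [[g], [p_bf16], [m], [v], [rem]],  # g, p, m, v, p_remainder: order must match the functor
    1e-3, 0.9, 0.999, 1e-8,            # lr, beta1, beta2, epsilon
    1,                                 # step
    0,                                 # mode (0 = ADAM_MODE_0, the L2 branch)
    1,                                 # bias_correction
    0.0,                               # weight_decay
)
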

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -213,6 +213,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
"Compute and apply gradient update to parameters for Adam optimizer"
"where the master parameters only store the remainder bits",
py::call_guard<py::gil_scoped_release>());
m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
"Compute and apply gradient update to parameters for Adam optimizer",
py::call_guard<py::gil_scoped_release>());