support adam bf16 state #1465

Open · wants to merge 2 commits into main
28 changes: 28 additions & 0 deletions tests/pytorch/test_fused_optimizer.py
@@ -360,6 +360,20 @@ def test_fp16_exp_avg(self):
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_exp_avg(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.bfloat16,
exp_avg_sq_dtype=torch.float32,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg(self):
@@ -389,6 +403,20 @@ def test_fp16_exp_avg_sq(self):
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
def test_bf16_exp_avg_sq(self):
self.gen_precision_aware_test(
use_fp8_params=False,
param_dtype=torch.bfloat16,
use_master_weights=True,
master_weight_dtype=torch.float32,
grad_dtype=torch.float32,
exp_avg_dtype=torch.float32,
exp_avg_sq_dtype=torch.bfloat16,
master_rtol=2e-3,
master_atol=2e-3,
)

@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
def test_fp8_exp_avg_sq(self):
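The two new tests above drive the bf16 moment dtypes through `gen_precision_aware_test`. For orientation only, here is a hedged usage sketch of how the option would be selected from the user-facing optimizer, assuming `FusedAdam` exposes the keyword arguments referenced in the `fused_adam.py` hunks further down (`master_weights`, `master_weight_dtype`, `exp_avg_dtype`, `exp_avg_sq_dtype`) and that the import path matches that file's location; it is not part of this PR's diff.

```python
# Hedged usage sketch (not part of this diff): selecting a bf16 exp_avg state.
import torch
from transformer_engine.pytorch.optimizers import FusedAdam  # assumed import path

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

optimizer = FusedAdam(
    model.parameters(),
    lr=1e-3,
    master_weights=True,              # keep an fp32 master copy of the bf16 params
    master_weight_dtype=torch.float32,
    exp_avg_dtype=torch.bfloat16,     # new in this PR: first moment stored in bf16
    exp_avg_sq_dtype=torch.float32,   # second moment kept in fp32
)

out = model(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda"))
out.float().pow(2).mean().backward()
optimizer.step()
```

Unlike the fp16 and fp8 state dtypes, bf16 shares fp32's exponent range, which is why the Python changes below store it directly instead of allocating a per-tensor scale.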
@@ -53,7 +53,8 @@ struct FP8Data {
template <>
struct FP8Data<false> {};

template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
template <typename PARAM_T, typename GRAD_T, typename M_T, typename V_T, typename FULL_T,
typename index_t>
struct AdamFunctorMaster {
static constexpr bool is_fp8_type = is_fp8<PARAM_T>::value;

@@ -83,10 +84,10 @@ struct AdamFunctorMaster {
PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;

FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
M_T *m = reinterpret_cast<M_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;

FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
V_T *v = reinterpret_cast<V_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;

FULL_T *p_master = reinterpret_cast<FULL_T *>(tl.addresses[4][tensor_loc]);
@@ -151,8 +152,8 @@ struct AdamFunctorMaster {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p_master[i] = static_cast<FULL_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
m[i] = static_cast<M_T>(r_m[ii]);
v[i] = static_cast<V_T>(r_v[ii]);
if constexpr (is_fp8_type) {
__builtin_assume(fp8_data.max >= 0);
fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max);
@@ -295,7 +296,8 @@ struct AdamFunctorMasterParamRemainder {
}
};

template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
template <typename PARAM_T, typename GRAD_T, typename M_T, typename V_T, typename FULL_T,
typename index_t>
struct AdamFunctor {
__device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
TensorListMetadata<4> &tl, // NOLINT(*)
@@ -321,10 +323,10 @@ struct AdamFunctor {
PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
p += chunk_idx * chunk_size;

FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
M_T *m = reinterpret_cast<M_T *>(tl.addresses[2][tensor_loc]);
m += chunk_idx * chunk_size;

FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
V_T *v = reinterpret_cast<V_T *>(tl.addresses[3][tensor_loc]);
v += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;
@@ -376,8 +378,8 @@ struct AdamFunctor {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = static_cast<PARAM_T>(r_p[ii]);
m[i] = static_cast<FULL_T>(r_m[ii]);
v[i] = static_cast<FULL_T>(r_v[ii]);
m[i] = static_cast<M_T>(r_m[ii]);
v[i] = static_cast<V_T>(r_v[ii]);
}
}
}
@@ -609,6 +611,9 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,

const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
const auto m_in_type = tensor_lists[2][0].scalar_type();
const auto v_in_type = tensor_lists[3][0].scalar_type();

auto tl_size = tensor_lists.size();

// case 4: g, p, m, v
@@ -622,22 +627,32 @@
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
m_in_type, 2, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
v_in_type, 3, "adam",
multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, scalar_t_2,
scalar_t_3, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);))));
} else {
// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
m_in_type, 2, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
v_in_type, 3, "adam",
multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, scalar_t_2,
scalar_t_3, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);))));
}
} else {
if (tl_size == 4) {
@@ -646,19 +661,29 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
m_in_type, 2, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
v_in_type, 3, "adam",
multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, scalar_t_2,
scalar_t_3, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);))));
} else {
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
m_in_type, 2, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
v_in_type, 3, "adam",
multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, scalar_t_2,
scalar_t_3, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);))));
}
}
AT_CUDA_CHECK(cudaGetLastError());
@@ -732,6 +757,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
}

const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto m_in_type = tensor_lists[2][0].scalar_type();
const auto v_in_type = tensor_lists[3][0].scalar_type();
auto tl_size = tensor_lists.size();

// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
@@ -742,19 +769,30 @@
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
multi_tensor_apply<5, true>(
(int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
m_in_type, 1, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
v_in_type, 2, "adam",
multi_tensor_apply<5, true>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, scalar_t_1,
scalar_t_2, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);))));
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);));
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
m_in_type, 1, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
v_in_type, 2, "adam",
multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, scalar_t_1,
scalar_t_2, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);))));
}
AT_CUDA_CHECK(cudaGetLastError());
}
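In the kernel changes above, the single FULL_T type used for both moments is split into separate M_T and V_T template parameters, and the host-side launchers now also dispatch on the dtypes of tensor lists 2 (exp_avg) and 3 (exp_avg_sq). As a rough reference for what one element-wise update computes with mixed-precision states, here is a plain PyTorch sketch under stated assumptions: the names are illustrative, the math follows the standard Adam update with the usual bias corrections, and only the L2-style weight-decay mode is shown; it is not the CUDA kernel.

```python
# Illustrative reference for AdamFunctorMaster's per-element update (assumed
# simplification, not the kernel): math in fp32, write-back in each state's dtype.
import torch

def adam_master_step(p, p_master, g, m, v, step, lr,
                     beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
    g32, m32, v32 = g.float(), m.float(), v.float()
    if weight_decay != 0.0:
        g32 = g32 + weight_decay * p_master        # L2 mode: decay folded into the gradient
    m32 = beta1 * m32 + (1.0 - beta1) * g32        # first moment
    v32 = beta2 * v32 + (1.0 - beta2) * g32 * g32  # second moment
    m_hat = m32 / (1.0 - beta1 ** step)            # bias_correction1
    v_hat = v32 / (1.0 - beta2 ** step)            # bias_correction2
    p_master -= lr * m_hat / (v_hat.sqrt() + eps)  # FULL_T (fp32) master update
    p.copy_(p_master.to(p.dtype))                  # PARAM_T write-back (e.g. bf16 or fp8)
    m.copy_(m32.to(m.dtype))                       # M_T write-back (e.g. bf16)
    v.copy_(v32.to(v.dtype))                       # V_T write-back (e.g. fp32)
```

Splitting M_T and V_T lets exp_avg and exp_avg_sq take independent dtypes, at the cost of the extra nested DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT levels visible in the launcher hunks.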
20 changes: 12 additions & 8 deletions transformer_engine/pytorch/optimizers/fused_adam.py
@@ -130,10 +130,10 @@ def __init__(
# Add constraints to dtypes of states.
if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.")
if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.")

# Currently, capturable mode only supports fp32 master weights and optimizer states.
# The reason is, if the master weights or optimizer states are not in fp32 dtype,
@@ -284,8 +284,11 @@ def get_unscaled_state(self, param, state_name):
else:
assert state[state_name].dtype == torch.float32
unscaled = state[state_name]
elif dtype == torch.bfloat16:
assert state[state_name].dtype == torch.bfloat16
unscaled = state[state_name]
else:
raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.")
return unscaled

def set_scaled_state(self, param, state_name, unscaled_state):
@@ -300,6 +303,7 @@ def set_scaled_state(self, param, state_name, unscaled_state):
and `master_param`.
unscaled_state (torch.Tensor): The original high-precision (FP32) state.
"""

store_param_remainders = (
self.store_param_remainders
and state_name == "master_param"
@@ -315,7 +319,7 @@ def set_scaled_state(self, param, state_name, unscaled_state):
self._initialize_state(param, state_name, False, store_param_remainders)

dtype = self.name_to_dtype_map[state_name]
if dtype != torch.float32:
if dtype not in [torch.float32, torch.bfloat16]:
scale = self._scales[param]
self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
else:
@@ -354,7 +358,7 @@ def _initialize_state(
self.state[param][state_name] = data

# Create scale if necessary.
if dtype != torch.float32:
if dtype not in [torch.float32, torch.bfloat16]:
if param not in self._scales:
self._scales[param] = {}
self._scales[param][state_name] = torch.ones(
@@ -526,7 +530,7 @@ def step(self, closure=None, grad_scaler=None):
else:
unscaled = self.get_unscaled_state(p, name)
unscaled_state[name] = unscaled
if self.name_to_dtype_map[name] != torch.float32:
if self.name_to_dtype_map[name] not in [torch.float32, torch.bfloat16]:
unscaled_lists[name].append(unscaled)
scaled_lists[name].append(state[name])
state_scales[name].append(self._scales[p][name])
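The Python-side rule introduced above is that bf16 optimizer states follow the same path as fp32: they are stored and returned as-is, with no per-tensor scale allocated, while fp16 and fp8 (torch.uint8) states keep going through the scaled path. A condensed sketch of that decision, with simplified names that are not the actual class methods:

```python
# Condensed sketch of the dtype rule in fused_adam.py (names are illustrative).
import torch

UNSCALED_DTYPES = (torch.float32, torch.bfloat16)  # stored directly, no scale tensor
SCALED_DTYPES = (torch.float16, torch.uint8)       # fp16 / fp8: kept with a per-tensor scale

def state_needs_scale(state_dtype: torch.dtype) -> bool:
    if state_dtype in UNSCALED_DTYPES:
        return False
    if state_dtype in SCALED_DTYPES:
        return True
    raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 optimizer states.")
```

Consistently with this, get_unscaled_state returns a bf16 state tensor without upcasting it to fp32, so checkpoint save and restore needs no scale bookkeeping for bf16 states.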