From 990507e487c30aafe1658cf95cb40dc540455964 Mon Sep 17 00:00:00 2001
From: XiaobingSuper
Date: Fri, 10 Jan 2025 15:29:19 +0800
Subject: [PATCH 1/2] support adam bf16 state

Signed-off-by: XiaobingSuper
---
 tests/pytorch/test_fused_optimizer.py         |  28 +++++
 .../multi_tensor/multi_tensor_adam.cu         | 103 ++++++++++++------
 .../pytorch/optimizers/fused_adam.py          |  20 ++--
 3 files changed, 107 insertions(+), 44 deletions(-)

diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py
index 507fd3f350..cec25803f2 100644
--- a/tests/pytorch/test_fused_optimizer.py
+++ b/tests/pytorch/test_fused_optimizer.py
@@ -360,6 +360,20 @@ def test_fp16_exp_avg(self):
             master_atol=2e-3,
         )
 
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    def test_bf16_exp_avg(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.bfloat16,
+            exp_avg_sq_dtype=torch.float32,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
@@ -389,6 +403,20 @@ def test_fp16_exp_avg_sq(self):
             master_atol=2e-3,
         )
 
+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
+    def test_bf16_exp_avg_sq(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.bfloat16,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg_sq(self):
diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
index 548dd5a267..65c791f572 100644
--- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
@@ -53,7 +53,7 @@ struct FP8Data {
 template <>
 struct FP8Data<void> {};
 
-template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
+template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename M_T, typename V_T, typename index_t>
 struct AdamFunctorMaster {
   static constexpr bool is_fp8_type = is_fp8<PARAM_T>::value;
 
@@ -83,10 +83,10 @@ struct AdamFunctorMaster {
     PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
     p += chunk_idx * chunk_size;
 
-    FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
+    M_T *m = reinterpret_cast<M_T *>(tl.addresses[2][tensor_loc]);
     m += chunk_idx * chunk_size;
 
-    FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
+    V_T *v = reinterpret_cast<V_T *>(tl.addresses[3][tensor_loc]);
     v += chunk_idx * chunk_size;
 
     FULL_T *p_master = reinterpret_cast<FULL_T *>(tl.addresses[4][tensor_loc]);
@@ -151,8 +151,8 @@ struct AdamFunctorMaster {
       int i = i_start + threadIdx.x + ii * blockDim.x;
       if (i < n && i < chunk_size) {
         p_master[i] = static_cast<FULL_T>(r_p[ii]);
-        m[i] = static_cast<FULL_T>(r_m[ii]);
-        v[i] = static_cast<FULL_T>(r_v[ii]);
+        m[i] = static_cast<M_T>(r_m[ii]);
+        v[i] = static_cast<V_T>(r_v[ii]);
         if constexpr (is_fp8_type) {
           __builtin_assume(fp8_data.max >= 0);
           fp8_data.max = fmaxf(fabsf(r_p[ii]), fp8_data.max);
@@ -295,7 +295,8 @@ struct AdamFunctorMasterParamRemainder {
   }
 };
 
-template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename index_t>
+
+template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename M_T, typename V_T, typename index_t>
 struct AdamFunctor {
   __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
                                              TensorListMetadata<4> &tl,  // NOLINT(*)
@@ -321,10 +322,10 @@ struct AdamFunctor {
     PARAM_T *p = reinterpret_cast<PARAM_T *>(tl.addresses[1][tensor_loc]);
     p += chunk_idx * chunk_size;
 
-    FULL_T *m = reinterpret_cast<FULL_T *>(tl.addresses[2][tensor_loc]);
+    M_T *m = reinterpret_cast<M_T *>(tl.addresses[2][tensor_loc]);
     m += chunk_idx * chunk_size;
 
-    FULL_T *v = reinterpret_cast<FULL_T *>(tl.addresses[3][tensor_loc]);
+    V_T *v = reinterpret_cast<V_T *>(tl.addresses[3][tensor_loc]);
     v += chunk_idx * chunk_size;
 
     n -= chunk_idx * chunk_size;
@@ -376,8 +377,8 @@ struct AdamFunctor {
      int i = i_start + threadIdx.x + ii * blockDim.x;
      if (i < n && i < chunk_size) {
        p[i] = static_cast<PARAM_T>(r_p[ii]);
-        m[i] = static_cast<FULL_T>(r_m[ii]);
-        v[i] = static_cast<FULL_T>(r_v[ii]);
+        m[i] = static_cast<M_T>(r_m[ii]);
+        v[i] = static_cast<V_T>(r_v[ii]);
      }
    }
  }
@@ -609,6 +610,9 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
 
   const auto g_in_type = tensor_lists[0][0].scalar_type();
   const auto p_in_type = tensor_lists[1][0].scalar_type();
+  const auto m_in_type = tensor_lists[2][0].scalar_type();
+  const auto v_in_type = tensor_lists[3][0].scalar_type();
+
   auto tl_size = tensor_lists.size();
 
   // case 4: g, p, m, v
@@ -622,22 +626,31 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
           p_in_type, 0, "adam",
           DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
               g_in_type, 1, "adam",
-              multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
-                                    tensor_lists,
-                                    AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
-                                    beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+              DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                  m_in_type, 2, "adam",
+                  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                      v_in_type, 3, "adam",
+                      multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
+                                            tensor_lists,
+                                            AdamFunctor<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int64_t>(), beta1,
+                                            beta2, bias_correction1, bias_correction2, epsilon, lr,
+                                            (adamMode_t)mode, weight_decay);
+                      ))));
     } else {
       // g, p, m, v, p_master
       DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
           p_in_type, 0, "adam",
           DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
               g_in_type, 1, "adam",
-              multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
-                                    tensor_lists,
-                                    AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
-                                    beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+              DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                  m_in_type, 2, "adam",
+                  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                      v_in_type, 3, "adam",
+                      multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
+                                            tensor_lists,
+                                            AdamFunctorMaster<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int64_t>(),
+                                            beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
+                                            (adamMode_t)mode, weight_decay);))));
     }
   } else {
     if (tl_size == 4) {
@@ -646,19 +659,27 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
           p_in_type, 0, "adam",
           DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
              g_in_type, 1, "adam",
-              multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                    AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
-                                    beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+              DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                  m_in_type, 2, "adam",
+                  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                      v_in_type, 3, "adam",
+                      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                                            AdamFunctor<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int32_t>(), beta1,
+                                            beta2, bias_correction1, bias_correction2, epsilon, lr,
+                                            (adamMode_t)mode, weight_decay);))));
     } else {
       DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
          p_in_type, 0, "adam",
          DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
             g_in_type, 1, "adam",
-              multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                    AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
-                                    beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                    (adamMode_t)mode, weight_decay);));
+              DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                  m_in_type, 2, "adam",
+                  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                      v_in_type, 3, "adam",
+                      multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                                            AdamFunctorMaster<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int32_t>(),
+                                            beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
+                                            (adamMode_t)mode, weight_decay);))));
     }
   }
   AT_CUDA_CHECK(cudaGetLastError());
@@ -732,6 +753,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
   }
 
   const auto g_in_type = tensor_lists[0][0].scalar_type();
+  const auto m_in_type = tensor_lists[2][0].scalar_type();
+  const auto v_in_type = tensor_lists[3][0].scalar_type();
   auto tl_size = tensor_lists.size();
 
   // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
@@ -742,19 +765,27 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
         fp8_dtype, FP8_T,
         DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
             g_in_type, 0, "adam",
-            multi_tensor_apply<5, true>(
-                (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
-                AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
-                bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
+            DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                m_in_type, 1, "adam",
+                DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                    v_in_type, 2, "adam",
+                    multi_tensor_apply<5, true>(
+                        (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
+                        AdamFunctorMaster<FP8_T, scalar_t_0, float, scalar_t_1, scalar_t_2, int64_t>(), beta1, beta2,
+                        bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);))));
   } else {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        fp8_dtype, FP8_T,
        DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
            g_in_type, 0, "adam",
-            multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                        AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
-                                        beta1, beta2, bias_correction1, bias_correction2, epsilon,
-                                        lr, (adamMode_t)mode, weight_decay);));
+            DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                m_in_type, 1, "adam",
+                DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
+                    v_in_type, 2, "adam",
+                    multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                                                AdamFunctorMaster<FP8_T, scalar_t_0, float, scalar_t_1, scalar_t_2, int32_t>(),
+                                                beta1, beta2, bias_correction1, bias_correction2, epsilon,
+                                                lr, (adamMode_t)mode, weight_decay);))));
   }
   AT_CUDA_CHECK(cudaGetLastError());
 }
diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py
index d972fd96ab..8f65881701 100644
--- a/transformer_engine/pytorch/optimizers/fused_adam.py
+++ b/transformer_engine/pytorch/optimizers/fused_adam.py
@@ -130,10 +130,10 @@ def __init__(
         # Add constraints to dtypes of states.
         if master_weights and master_weight_dtype not in [torch.float32, torch.float16]:
             raise RuntimeError("FusedAdam only supports fp32/fp16 master weights.")
-        if exp_avg_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg.")
-        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.uint8]:
-            raise RuntimeError("FusedAdam only supports fp32/fp16/fp8 exp_avg_sq.")
+        if exp_avg_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg.")
+        if exp_avg_sq_dtype not in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
+            raise RuntimeError("FusedAdam only supports fp32/fp16/bf16/fp8 exp_avg_sq.")
 
         # Currently, capturable mode only supports fp32 master weights and optimizer states.
         # The reason is, if the master weights or optimizer states are not in fp32 dtype,
@@ -284,8 +284,11 @@ def get_unscaled_state(self, param, state_name):
             else:
                 assert state[state_name].dtype == torch.float32
                 unscaled = state[state_name]
+        elif dtype == torch.bfloat16:
+            assert state[state_name].dtype == torch.bfloat16
+            unscaled = state[state_name]
         else:
-            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.")
+            raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/bf16/fp32.")
         return unscaled
 
     def set_scaled_state(self, param, state_name, unscaled_state):
@@ -300,6 +303,7 @@ def set_scaled_state(self, param, state_name, unscaled_state):
                 and 'master_param`.
             unscaled_state (torch.Tensor): The original high-precision(FP32) state.
         """
+
         store_param_remainders = (
             self.store_param_remainders and state_name == "master_param"
         )
@@ -315,7 +319,7 @@ def set_scaled_state(self, param, state_name, unscaled_state):
             self._initialize_state(param, state_name, False, store_param_remainders)
 
         dtype = self.name_to_dtype_map[state_name]
-        if dtype != torch.float32:
+        if dtype not in [torch.float32, torch.bfloat16]:
             scale = self._scales[param]
             self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
         else:
@@ -354,7 +358,7 @@ def _initialize_state(
         self.state[param][state_name] = data
 
         # Create scale if necessary.
-        if dtype != torch.float32:
+        if dtype not in [torch.float32, torch.bfloat16]:
             if param not in self._scales:
                 self._scales[param] = {}
             self._scales[param][state_name] = torch.ones(
@@ -526,7 +530,7 @@ def step(self, closure=None, grad_scaler=None):
                     else:
                         unscaled = self.get_unscaled_state(p, name)
                     unscaled_state[name] = unscaled
-                    if self.name_to_dtype_map[name] != torch.float32:
+                    if self.name_to_dtype_map[name] not in [torch.float32, torch.bfloat16]:
                         unscaled_lists[name].append(unscaled)
                         scaled_lists[name].append(state[name])
                         state_scales[name].append(self._scales[p][name])

From 364982ae5bcbdc853b075cb7dec679edf2232f35 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 8 Feb 2025 01:56:39 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../multi_tensor/multi_tensor_adam.cu         | 53 +++++++++++--------
 1 file changed, 30 insertions(+), 23 deletions(-)

diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
index 65c791f572..0778aa3ecd 100644
--- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu
@@ -53,7 +53,8 @@ struct FP8Data {
 template <>
 struct FP8Data<void> {};
 
-template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename M_T, typename V_T, typename index_t>
+template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename M_T, typename V_T,
+          typename index_t>
 struct AdamFunctorMaster {
   static constexpr bool is_fp8_type = is_fp8<PARAM_T>::value;
 
@@ -295,8 +296,8 @@ struct AdamFunctorMasterParamRemainder {
   }
 };
 
-
-template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename M_T, typename V_T, typename index_t>
+template <typename PARAM_T, typename GRAD_T, typename FULL_T, typename M_T, typename V_T,
+          typename index_t>
 struct AdamFunctor {
   __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem,
                                              TensorListMetadata<4> &tl,  // NOLINT(*)
@@ -632,10 +633,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                      v_in_type, 3, "adam",
                      multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
                                            tensor_lists,
-                                            AdamFunctor<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int64_t>(), beta1,
-                                            beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                            (adamMode_t)mode, weight_decay);
-                      ))));
+                                            AdamFunctor<scalar_t_0, scalar_t_1, float, scalar_t_2,
+                                                        scalar_t_3, int64_t>(),
+                                            beta1, beta2, bias_correction1, bias_correction2,
+                                            epsilon, lr, (adamMode_t)mode, weight_decay);))));
     } else {
       // g, p, m, v, p_master
       DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
@@ -648,9 +649,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                      v_in_type, 3, "adam",
                      multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
                                            tensor_lists,
-                                            AdamFunctorMaster<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int64_t>(),
-                                            beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                            (adamMode_t)mode, weight_decay);))));
+                                            AdamFunctorMaster<scalar_t_0, scalar_t_1, float,
+                                                              scalar_t_2, scalar_t_3, int64_t>(),
+                                            beta1, beta2, bias_correction1, bias_correction2,
+                                            epsilon, lr, (adamMode_t)mode, weight_decay);))));
     }
   } else {
     if (tl_size == 4) {
@@ -664,9 +666,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
                      v_in_type, 3, "adam",
                      multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                            AdamFunctor<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int32_t>(), beta1,
-                                            beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                            (adamMode_t)mode, weight_decay);))));
+                                            AdamFunctor<scalar_t_0, scalar_t_1, float, scalar_t_2,
+                                                        scalar_t_3, int32_t>(),
+                                            beta1, beta2, bias_correction1, bias_correction2,
+                                            epsilon, lr, (adamMode_t)mode, weight_decay);))));
     } else {
       DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
          p_in_type, 0, "adam",
@@ -677,9 +680,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
                      v_in_type, 3, "adam",
                      multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                            AdamFunctorMaster<scalar_t_0, scalar_t_1, float, scalar_t_2, scalar_t_3, int32_t>(),
-                                            beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
-                                            (adamMode_t)mode, weight_decay);))));
+                                            AdamFunctorMaster<scalar_t_0, scalar_t_1, float,
+                                                              scalar_t_2, scalar_t_3, int32_t>(),
+                                            beta1, beta2, bias_correction1, bias_correction2,
+                                            epsilon, lr, (adamMode_t)mode, weight_decay);))));
     }
   }
   AT_CUDA_CHECK(cudaGetLastError());
@@ -769,10 +773,12 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                m_in_type, 1, "adam",
                DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
                    v_in_type, 2, "adam",
-                    multi_tensor_apply<5, true>(
-                        (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists,
-                        AdamFunctorMaster<FP8_T, scalar_t_0, float, scalar_t_1, scalar_t_2, int64_t>(), beta1, beta2,
-                        bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);))));
+                    multi_tensor_apply<5, true>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag,
+                                                tensor_lists,
+                                                AdamFunctorMaster<FP8_T, scalar_t_0, float, scalar_t_1,
+                                                                  scalar_t_2, int64_t>(),
+                                                beta1, beta2, bias_correction1, bias_correction2,
+                                                epsilon, lr, (adamMode_t)mode, weight_decay);))));
   } else {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        fp8_dtype, FP8_T,
@@ -783,9 +789,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
                    v_in_type, 2, "adam",
                    multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
-                                                AdamFunctorMaster<FP8_T, scalar_t_0, float, scalar_t_1, scalar_t_2, int32_t>(),
-                                                beta1, beta2, bias_correction1, bias_correction2, epsilon,
-                                                lr, (adamMode_t)mode, weight_decay);))));
+                                                AdamFunctorMaster<FP8_T, scalar_t_0, float, scalar_t_1,
+                                                                  scalar_t_2, int32_t>(),
+                                                beta1, beta2, bias_correction1, bias_correction2,
+                                                epsilon, lr, (adamMode_t)mode, weight_decay);))));
   }
   AT_CUDA_CHECK(cudaGetLastError());
 }
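
The sketch below is not part of the patch series; it only illustrates how the bf16 optimizer-state support added above might be exercised from Python. The FusedAdam keyword arguments master_weights, master_weight_dtype, exp_avg_dtype and exp_avg_sq_dtype are taken from the diff; the import path, the toy model, and the remaining Adam-style arguments (e.g. lr) are assumed to follow the existing FusedAdam interface.

    # Illustrative sketch only: assumes TransformerEngine with this patch series applied
    # and a CUDA device with bf16 support.
    import torch
    from transformer_engine.pytorch.optimizers import FusedAdam

    model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")

    optimizer = FusedAdam(
        model.parameters(),
        lr=1e-3,
        master_weights=True,               # keep an fp32 master copy of the bf16 params
        master_weight_dtype=torch.float32,
        exp_avg_dtype=torch.bfloat16,      # first moment stored in bf16 (enabled by this series)
        exp_avg_sq_dtype=torch.float32,    # second moment kept in fp32
    )

    loss = model(torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")).sum()
    loss.backward()
    optimizer.step()

Note that, unlike fp16 or fp8 states, bf16 states are stored without a per-state scale: the Python changes treat torch.bfloat16 the same way as torch.float32 in the scaling paths (no entry in _scales, no unscale/rescale around step), and the CUDA kernels simply cast the moments to and from the new M_T/V_T template types.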