diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index ca9e432a4f..0581c45970 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -58,6 +58,7 @@ class layernorm_fwd_codegen: bool kPadN_, bool kSaveMeanInvStd_, bool kFastFDiv_, + bool kWelford_, bool kTwoPass_, ck_tile::index_t kFusedAdd_ = 0, ck_tile::index_t kFusedQuant_ = 0> @@ -120,6 +121,7 @@ class layernorm_fwd_codegen: static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kWelford = kWelford_; static constexpr bool kTwoPass = kTwoPass_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; @@ -137,6 +139,7 @@ class layernorm_fwd_codegen: bool kPadN_, bool kSaveMeanInvStd_, bool kFastFDiv_, + bool kWelford_, bool kTwoPass_, int kFusedAdd_, int kFusedQuant_> @@ -152,6 +155,7 @@ class layernorm_fwd_codegen: kPadN_, kSaveMeanInvStd_, kFastFDiv_, + kWelford_, kTwoPass_, kFusedAdd_, kFusedQuant_>; @@ -184,6 +188,7 @@ class layernorm_fwd_codegen: using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; @@ -204,12 +209,13 @@ class layernorm_fwd_codegen: using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; using Pipeline = std::conditional_t; - using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; using Default2DEpilogue = ck_tile::Default2DEpilogue; static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; + static constexpr bool UseRawStore = sizeof(YDataType) == 4; using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + ck_tile::DynamicQuantEpilogueTraits>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; @@ -274,7 +280,7 @@ class 
layernorm_fwd_codegen: #include "layernorm2d_fwd_api_common.hpp" // clang-format off -// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p add sweep {F_instance_def} // clang-format on @@ -362,6 +368,7 @@ class h_traits: F_kPadN : bool F_kSaveMeanInvStd_ : bool F_kFastFDiv_ : bool + F_kWelford_ : bool F_kTwoPass_ : bool F_kFusedAdd : int F_kFusedQuant : int @@ -369,7 +376,7 @@ class h_traits: @property def trait_name(self) ->str: t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ @@ -422,11 +429,10 @@ def name_api(self) -> str: def name_common_header(self) -> str: return 'layernorm2d_fwd_api_common' - @property - def content_api(self) -> str: + def content_api(self, args) -> str: # 1 sort based on dtype t_dtype_dict = dict() - blobs = self.get_blobs() + blobs = self.get_blobs(args) for blob in blobs: if blob.F_DataTypePair not in t_dtype_dict: t_dtype_dict[blob.F_DataTypePair] = {} @@ -462,8 +468,8 @@ def content_api(self) -> str: inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), F_VEC_COND = _cond, F_instance_func=ins.call_name) #inner_str = inner_str + vec_str - n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else '' - n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + n_cnd = f'(a.n <= 
{n_})' if isinstance(n_, int) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) prec_i, prec_o = dtype_.split(',') d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) @@ -474,7 +480,7 @@ def content_api(self) -> str: def content_common_header(self) -> str: return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) - def get_blobs(self): + def get_blobs(self, args): h_traits = layernorm_fwd_codegen.h_traits h_instance = layernorm_fwd_codegen.h_instance @@ -484,60 +490,61 @@ def get_blobs(self): scale_list = [('fp32,fp32')] dtype_list = [('fp16,fp16'), ('bf16,bf16'), ('fp16,int8'), ('bf16,int8')] # NOTE: only fused-dynamic-quant use int8 out + types_8bit = ('int8', 'fp8') + types_16bit = ('int16', 'fp16', 'bf16') #fused_add_list = [0, 1, 2] #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant fused_add_list = [0, 1] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant - - # rm rn tm tn vn pd mv fdiv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], - '512' : [ 
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 
1, 2, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]} + # rm rn tm tn vn pd mv fdiv welford 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, 
True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 
1, 3, 1, 128, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] @@ -558,16 +565,27 @@ def get_blobs(self): h_.F_YScaleDataType = scale_x h_.F_kFusedAdd = fused_add h_.F_kFusedQuant = fused_quant + # 
disable welford update for 8bit and 16 bit smallN + if not h_.F_kTwoPass_: + #disable 16 bit when set args disable_16b_welford + if args.disable_16b_welford and prec_i in types_16bit: + h_.F_kWelford_ = False + #disable 8bit by default + elif prec_i in types_8bit or prec_o in types_8bit: + h_.F_kWelford_ = False + #disable 16bit small N + elif prec_i in types_16bit and hs_key == '64': + h_.F_kWelford_ = False current_hs.append(h_) # + "\n" #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ current_n_str = 'big' if hs_key == 'big' else current_n total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, current_hs)) return total_blob - def list_blobs(self) -> None: + def list_blobs(self, args) -> None: w_p = Path(self.working_path) list_p = w_p / 'layernorm2d_fwd_blobs.txt' - blobs = self.get_blobs() + blobs = self.get_blobs(args) with list_p.open('w') as list_f: # api related file list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") @@ -576,11 +594,12 @@ def list_blobs(self) -> None: for b in blobs: list_f.write(str(w_p / (b.name + ".cpp")) + "\n") - def gen_blobs(self) -> None: + def gen_blobs(self, args) -> None: w_p = Path(self.working_path) - (w_p / (self.name_api + ".cpp")).write_text(self.content_api) + w_str = self.content_api(args) + (w_p / (self.name_api + ".cpp")).write_text(w_str) (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) - blobs = self.get_blobs() + blobs = self.get_blobs(args) for b in blobs: (w_p / (b.name + ".cpp")).write_text(b.content) @@ -588,14 +607,14 @@ def list_blobs(args): api_list = args.api.split(',') for api in api_list: if api == 'fwd': - layernorm_fwd_codegen(args.working_path, args.filter).list_blobs() + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args) def gen_blobs(args): api_list = args.api.split(',') for api in api_list: if api == 'fwd': - layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs() + 
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -663,6 +682,13 @@ def gen_blobs(args): help="codegen receipt." ) + parser.add_argument( + "--disable_16b_welford", + default=False, + type=lambda v: str(v).lower() in ("1", "true", "t", "yes", "y", "on"), # parse the value: a raw string such as "False" or "0" would otherwise be truthy + help="disable welford for 16bit datatype when n > 64; pass 1/true/yes/on to disable (default: welford stays enabled)" + ) + args = parser.parse_args() # print(f'{args.list_blobs}-{args.gen_blobs}') diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh index b7fd354bb8..3f5c3eb134 100755 --- a/example/ck_tile/02_layernorm2d/script/smoke_test.sh +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -27,7 +27,8 @@ $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 done done diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 724f6261d5..37f87b4fe0 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -4,8 +4,8 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/welford/block/block_welford_problem.hpp" -#include "ck_tile/ops/welford/block/block_welford.hpp" +#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" +#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" namespace ck_tile { @@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto 
GetBlockWelford() + CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduce() { - using P_ = BlockWelfordProblem; - - return BlockWelford{}; + using P_ = BlockNormReduceProblem; + return BlockNormReduce{}; } template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync() + CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceSync() { - using P_ = BlockWelfordProblem; + using P_ = BlockNormReduceProblem; - return BlockWelfordSync{}; + return BlockNormReduceSync{}; } template - CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync() + CK_TILE_HOST_DEVICE static constexpr auto GetBlockNormReduceCrossWarpSync() { - using P_ = BlockWelfordProblem; + using P_ = BlockNormReduceProblem; - return BlockWelfordCrossWarpSync{}; + return BlockNormReduceCrossWarpSync{}; } template @@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy { if constexpr(Problem::kNeedCrossWarpSync) { - using P_ = BlockWelfordProblem; + using P_ = BlockNormReduceProblem; - using block_welford = BlockWelford; + using block_welford = BlockNormReduce; using x_block_tile = decltype(make_static_distributed_tensor( MakeXBlockTileDistribution())); using mean_var_block_tile = decltype(block_welford::template MakeMeanVarBlockTile()); - return GetBlockWelfordCrossWarpSync() + return GetBlockNormReduceCrossWarpSync() .template GetSmemSize(); } else diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index eefdaf9176..a30a9256ab 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kFastFDiv = 
Problem::Traits::kFastFDiv; + static constexpr bool kWelford = Problem::Traits::kWelford; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; @@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass int cur_count = 0; int max_count = block_tile_welford_calculate_max_count(row_size); - auto block_welford = Policy::template GetBlockWelford(); - auto block_welford_sync = Policy::template GetBlockWelfordSync(); - auto block_welford_cross_warp_sync = - Policy::template GetBlockWelfordCrossWarpSync(); - + auto block_norm_reduce = Policy::template GetBlockNormReduce(); + auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync(); + auto block_norm_reduce_cross_warp_sync = + Policy::template GetBlockNormReduceCrossWarpSync(); + + using XTensorType = decltype(cast_tile(x)); + auto mean = block_norm_reduce.template MakeMeanVarBlockTile(); + auto var = block_norm_reduce.template MakeMeanVarBlockTile(); + clear_tile(mean); + clear_tile(var); // load gamma/beta (TODO: support no gamma/beta?) 
const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); @@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass store_tile(y_residual_window, cast_tile(acc)); } - // compute welford each-thread->cross-lane->cross-warp - auto [mean, var] = block_welford(acc, cur_count, max_count); - block_welford_sync(mean, var, cur_count); - block_welford_cross_warp_sync(mean, var, cur_count, smem); - block_tile_welford_post_scale_var(var, cur_count, constant{}); - + // compute reduce each-thread->cross-lane->cross-warp + block_norm_reduce(acc, mean, var, cur_count, max_count); + block_norm_reduce_sync(mean, var, cur_count); + block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem); + if(kWelford) + { + block_tile_welford_post_scale_var(var, cur_count, constant{}); + } + else + { + sweep_tile(mean, [&](auto idx) { + mean(idx) = mean(idx) / type_convert(row_size); + var(idx) = var(idx) / type_convert(row_size) - mean(idx) * mean(idx); + }); + } // compute inv-std auto inv_std = tile_elementwise_in( [&](const auto& v_) { @@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass const auto beta_ = type_convert(beta[j_idx]); auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; - - ln(idx) = ln_; + ln(idx) = ln_; }); if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index 6a86cc43c9..4a37be8776 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadN = Problem::Traits::kPadN; static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; + static constexpr bool 
kWelford = Problem::Traits::kWelford; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; @@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass void* smem, Epilogue) const { + static_assert(kWelford == true, "2 pass only supports welford merge"); auto x_window = make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); auto gamma_window = make_tile_window( @@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass int max_count = (num_n_tile_iteration - 1) * count_per_iter + block_tile_welford_calculate_max_count(last_iter_n); - auto block_welford = Policy::template GetBlockWelford(); - auto block_welford_sync = Policy::template GetBlockWelfordSync(); - auto block_welford_cross_warp_sync = - Policy::template GetBlockWelfordCrossWarpSync(); + auto block_norm_reduce = Policy::template GetBlockNormReduce(); + auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync(); + auto block_norm_reduce_cross_warp_sync = + Policy::template GetBlockNormReduceCrossWarpSync(); using XTensorType = decltype(cast_tile(load_tile(x_window))); - auto mean = block_welford.template MakeMeanVarBlockTile(); - auto var = block_welford.template MakeMeanVarBlockTile(); + auto mean = block_norm_reduce.template MakeMeanVarBlockTile(); + auto var = block_norm_reduce.template MakeMeanVarBlockTile(); for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { @@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass move_tile_window(y_residual_window, {0, Block_N}); } } - block_welford(acc, mean, var, cur_count, max_count); + block_norm_reduce(acc, mean, var, cur_count, max_count); } - block_welford_sync(mean, var, cur_count); - block_welford_cross_warp_sync(mean, var, cur_count, smem); + block_norm_reduce_sync(mean, var, cur_count); + block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem); block_tile_welford_post_scale_var(var, cur_count, constant{}); // compute 
inv-std diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp index e8c22f8ab5..045bd24e49 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName @@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kWelford = kWelford_; static constexpr bool kTwoPass = kTwoPass_; static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; diff --git a/include/ck_tile/ops/welford.hpp b/include/ck_tile/ops/norm_reduce.hpp similarity index 54% rename from include/ck_tile/ops/welford.hpp rename to include/ck_tile/ops/norm_reduce.hpp index a4c479dd95..02d8eabd8a 100644 --- a/include/ck_tile/ops/welford.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -3,8 +3,8 @@ #pragma once -#include "ck_tile/ops/welford/block/block_welford.hpp" -#include "ck_tile/ops/welford/block/block_welford_problem.hpp" -#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" +#include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" +#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp similarity index 79% rename from include/ck_tile/ops/welford/block/block_welford.hpp rename to include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 56ca86d9df..15ac021631 100644 --- 
a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -4,22 +4,23 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/welford/thread/thread_welford.hpp" +#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" namespace ck_tile { template -struct BlockWelford +struct BlockNormReduce { using Problem = remove_cvref_t; using XDataType = typename Problem::XDataType; using ComputeDataType = typename Problem::ComputeDataType; static constexpr bool kFastFDiv = Problem::kFastFDiv; + static constexpr bool kWelford = Problem::kWelford; - CK_TILE_DEVICE constexpr BlockWelford() {} + CK_TILE_DEVICE constexpr BlockNormReduce() {} // [CAUSION] - max_count_ is to deal with the padding problem - // max_count_ is depend on caller, eg: naive and splitN welford will have different + // max_count_ is depend on caller, eg: naive and splitN norm_reduce will have different // calculation of max_count_ // -> use block_welford_calculate_max_count to compute template (x_tensor[in_dstr_idx]); - - welford_update(mean_tensor(out_dstr_idx), - var_tensor(out_dstr_idx), - x, - cur_count_, - constant{}); + if(kWelford) + { + welford_update(mean_tensor(out_dstr_idx), + var_tensor(out_dstr_idx), + x, + cur_count_, + constant{}); + } + else + { + mean_tensor(out_dstr_idx) += x; + var_tensor(out_dstr_idx) += x * x; + } }); } }); @@ -91,10 +98,11 @@ struct BlockWelford }; template -struct BlockWelfordSync +struct BlockNormReduceSync { using Problem = remove_cvref_t; static constexpr bool kFastFDiv = Problem::kFastFDiv; + static constexpr bool kWelford = Problem::kWelford; template CK_TILE_DEVICE void @@ -152,36 +160,48 @@ struct BlockWelfordSync (number{}.value); // pull data from remote lane - const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); - const auto v_remote_var = warp_shuffle(v_local_var, src_lane); - const auto v_remote_count = warp_shuffle(v_local_count, src_lane); - - // welford merge - 
welford_merge(v_local_mean, - v_local_var, - v_local_count, - v_remote_mean, - v_remote_var, - v_remote_count, - constant{}); + const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); + const auto v_remote_var = warp_shuffle(v_local_var, src_lane); + if(kWelford) + { + const auto v_remote_count = warp_shuffle(v_local_count, src_lane); + + // norm_reduce merge + welford_merge(v_local_mean, + v_local_var, + v_local_count, + v_remote_mean, + v_remote_var, + v_remote_count, + constant{}); + } + else + { + v_local_mean += v_remote_mean; + v_local_var += v_remote_var; + } }); } }); mean_tensor.get_thread_buffer()(i) = v_local_mean; var_tensor.get_thread_buffer()(i) = v_local_var; - - count = v_local_count; + if(kWelford) + { + count = v_local_count; + } }); } }; template -struct BlockWelfordCrossWarpSync +struct BlockNormReduceCrossWarpSync { using Problem = remove_cvref_t; using BlockShape = typename Problem::BlockShape; static constexpr bool kFastFDiv = Problem::kFastFDiv; + static constexpr bool kWelford = Problem::kWelford; + using smem_dtype = std::conditional_t; template CK_TILE_DEVICE static constexpr index_t GetReduceWarps() @@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); // Note: we always pack everything into fp32x4 - fp32x4_t* smem_ptr = reinterpret_cast(smem); + smem_dtype* smem_ptr = reinterpret_cast(smem); const index_t lane_id = get_lane_id(); const index_t warp_id = get_warp_id(); constexpr auto num_reduce_warps = GetReduceWarps(); @@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync if(lane_id == 0) { static_for<0, thread_buf_size, 1>{}([&](auto i) { - fp32x4_t local_scratch_; + smem_dtype local_scratch_; local_scratch_[0] = bit_cast(mean_tensor.get_thread_buffer()[i]); local_scratch_[1] = bit_cast(var_tensor.get_thread_buffer()[i]); - local_scratch_[2] = bit_cast(count); - + if(kWelford) + { + local_scratch_[2] = bit_cast(count); + } smem_ptr[smem_offset + 
i * num_warps] = local_scratch_; }); } @@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync // load from smem. here we let everythread to do compute :) index_t local_warp_id = warp_id / num_reduce_warps; index_t local_smem_os = local_warp_id * num_reduce_warps; - fp32x4_t all_scratch[thread_buf_size * num_reduce_warps]; + smem_dtype all_scratch[thread_buf_size * num_reduce_warps]; static_for<0, thread_buf_size, 1>{}([&](auto i_0) { static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { all_scratch[i_0 * num_reduce_warps + i_1] = @@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync static_for<0, thread_buf_size, 1>{}([&](auto i_0) { // TODO: use descriptor for this - auto v_local = all_scratch[i_0 * num_reduce_warps]; - auto v_local_mean = bit_cast(v_local[0]); - auto v_local_var = bit_cast(v_local[1]); - auto v_local_count = bit_cast(v_local[2]); + auto v_local = all_scratch[i_0 * num_reduce_warps]; + auto v_local_mean = bit_cast(v_local[0]); + auto v_local_var = bit_cast(v_local[1]); + int v_local_count = kWelford ? 
bit_cast(v_local[2]) : 0; // further reduce mean/var static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; - const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; + const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; const auto v_remote_mean = bit_cast(v_remote[0]); const auto v_remote_var = bit_cast(v_remote[1]); - const auto v_remote_count = bit_cast(v_remote[2]); - - welford_merge(v_local_mean, - v_local_var, - v_local_count, - v_remote_mean, - v_remote_var, - v_remote_count, - constant{}); + if(kWelford) + { + const auto v_remote_count = bit_cast(v_remote[2]); + + welford_merge(v_local_mean, + v_local_var, + v_local_count, + v_remote_mean, + v_remote_var, + v_remote_count, + constant{}); + } + else + { + v_local_mean += v_remote_mean; + v_local_var += v_remote_var; + } }); mean_tensor.get_thread_buffer()(i_0) = v_local_mean; var_tensor.get_thread_buffer()(i_0) = v_local_var; - - count = v_local_count; + if(kWelford) + count = v_local_count; }); } }; diff --git a/include/ck_tile/ops/welford/block/block_welford_problem.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp similarity index 66% rename from include/ck_tile/ops/welford/block/block_welford_problem.hpp rename to include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp index bcbfb7d76e..53f5bfc6ff 100644 --- a/include/ck_tile/ops/welford/block/block_welford_problem.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp @@ -7,13 +7,18 @@ namespace ck_tile { -template -struct BlockWelfordProblem +template +struct BlockNormReduceProblem { using XDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockShape = remove_cvref_t; static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kWelford = kWelford_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp 
b/include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp similarity index 100% rename from include/ck_tile/ops/welford/thread/thread_welford.hpp rename to include/ck_tile/ops/norm_reduce/thread/thread_welford.hpp