Skip to content

Commit

Permalink
Support bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
rocking5566 committed Oct 16, 2024
1 parent 4e14a89 commit 02b9a7d
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 13 deletions.
31 changes: 25 additions & 6 deletions example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@
#include "layernorm2d_fwd.hpp"
#include <cstring>

// Error limits used when comparing device output against the host reference.
// The primary template supplies the limits for every data type that has no dedicated
// specialization. Returns a ck_tile tuple holding (relative tolerance, absolute tolerance).
template <typename DataType>
auto get_elimit()
{
    double relative_limit = 1e-3;
    double absolute_limit = 1e-3;
    return ck_tile::make_tuple(relative_limit, absolute_limit);
}

// bf16 override of get_elimit: both limits are 1e-2 instead of the 1e-3 used by the
// primary template. Returns a ck_tile tuple holding (relative tolerance, absolute tolerance).
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double relative_limit = 1e-2;
    double absolute_limit = 1e-2;
    return ck_tile::make_tuple(relative_limit, absolute_limit);
}

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
Expand Down Expand Up @@ -52,7 +69,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<MeanDataType> mean_host_ref({M});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({M});


ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
Expand Down Expand Up @@ -105,14 +121,13 @@ bool run(const ck_tile::ArgParser& arg_parser)

y_buf.FromDevice(y_host_dev.data());

pass = ck_tile::check_err(y_host_dev, y_host_ref);
auto [rtol, atol] = get_elimit<DataType>();
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);

std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}

std::cout << std::endl << std::flush;
std::cout << "pass = " << pass << std::endl;

return pass;
}

Expand All @@ -127,6 +142,10 @@ int main(int argc, char* argv[])
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}

return -3;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"

// Assembles the layernorm2d forward kernel described by Traits_ and launches it.
//   s: stream configuration handed to ck_tile::launch_kernel.
//   a: argument bundle; its pointers, epsilon and M/N sizes are forwarded to the kernel.
// Returns the value produced by ck_tile::launch_kernel.
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
{
    // All element types come from the type config selected by the traits' data type.
    using TypeConfig = LayerNormTypeConfig<typename Traits_::DataType>;

    using Problem = ck_tile::BlockLayernorm2dFwdProblem<typename TypeConfig::XDataType,
                                                        typename TypeConfig::GammaDataType,
                                                        typename TypeConfig::BetaDataType,
                                                        typename TypeConfig::ComputeDataType,
                                                        typename TypeConfig::YDataType,
                                                        typename TypeConfig::MeanDataType,
                                                        typename TypeConfig::InvStdDataType,
                                                        typename Traits_::Shape,
                                                        Traits_::kPadN,
                                                        Traits_::kSaveMeanInvStd,
                                                        Traits_::kTwoPass>;

    using KernelType = ck_tile::Layernorm2dFwd<Problem>;

    constexpr ck_tile::index_t kBlocksPerCu = 1;
    // The block size must be a compile-time constant: its x extent is a template argument below.
    constexpr dim3 block_dim = KernelType::BlockSize();
    // The grid size is derived from the runtime value a.M.
    const dim3 grid_dim = KernelType::GridSize(a.M);

    return ck_tile::launch_kernel(
        s,
        ck_tile::make_kernel<block_dim.x, kBlocksPerCu>(
            KernelType{},
            grid_dim,
            block_dim,
            0, // NOTE(review): presumably the dynamic shared-memory size -- confirm
            a.p_x,
            a.p_gamma,
            a.p_beta,
            a.p_y,
            a.p_mean,
            a.p_invStd,
            a.epsilon,
            a.M,
            a.N));
}

template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
NRepeat,
NThread,
VectorAccessSize,
false,
false,
kTwoPass>;

// Short aliases that keep the explicit instantiation list below compact.
using S = ck_tile::stream_config;
using A = layernorm2d_fwd_args;

// All VectorAccessSize == 8 read/write instances are disabled because of a compiler-related
// performance issue.
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);

// Explicit instantiations of layernorm2d_fwd_ for the bf16 trait combinations built in this
// file; every instance below uses VectorAccessSize == 4.
template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp"

// Assembles the layernorm2d forward kernel described by Traits_ and launches it.
//   s: stream configuration handed to ck_tile::launch_kernel.
//   a: argument bundle; its pointers, epsilon and M/N sizes are forwarded to the kernel.
// Returns the value produced by ck_tile::launch_kernel.
template <typename Traits_>
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a)
{
    // All element types come from the type config selected by the traits' data type.
    using TypeConfig = LayerNormTypeConfig<typename Traits_::DataType>;

    using Problem = ck_tile::BlockLayernorm2dFwdProblem<typename TypeConfig::XDataType,
                                                        typename TypeConfig::GammaDataType,
                                                        typename TypeConfig::BetaDataType,
                                                        typename TypeConfig::ComputeDataType,
                                                        typename TypeConfig::YDataType,
                                                        typename TypeConfig::MeanDataType,
                                                        typename TypeConfig::InvStdDataType,
                                                        typename Traits_::Shape,
                                                        Traits_::kPadN,
                                                        Traits_::kSaveMeanInvStd,
                                                        Traits_::kTwoPass>;

    using KernelType = ck_tile::Layernorm2dFwd<Problem>;

    constexpr ck_tile::index_t kBlocksPerCu = 1;
    // The block size must be a compile-time constant: its x extent is a template argument below.
    constexpr dim3 block_dim = KernelType::BlockSize();
    // The grid size is derived from the runtime value a.M.
    const dim3 grid_dim = KernelType::GridSize(a.M);

    return ck_tile::launch_kernel(
        s,
        ck_tile::make_kernel<block_dim.x, kBlocksPerCu>(
            KernelType{},
            grid_dim,
            block_dim,
            0, // NOTE(review): presumably the dynamic shared-memory size -- confirm
            a.p_x,
            a.p_gamma,
            a.p_beta,
            a.p_y,
            a.p_mean,
            a.p_invStd,
            a.epsilon,
            a.M,
            a.N));
}

template <ck_tile::index_t NRepeat,
ck_tile::index_t NThread,
ck_tile::index_t VectorAccessSize,
bool kTwoPass>
using t = layernorm2d_fwd_traits_<ck_tile::bf16_t,
NRepeat,
NThread,
VectorAccessSize,
true,
false,
kTwoPass>;

using S = const ck_tile::stream_config;
using A = layernorm2d_fwd_args;

// Disable all vector 8fp16 read/write instances as it has performance issue regarding compiler
// template float layernorm2d_fwd_<t<1, 16, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 32, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<1, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<2, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, false>>(const S&, A);
// template float layernorm2d_fwd_<t<4, 64, 8, true>>(const S&, A);

template float layernorm2d_fwd_<t<1, 32, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<1, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 4, true>>(const S&, A);

template float layernorm2d_fwd_<t<1, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<2, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<4, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<8, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, false>>(const S&, A);
template float layernorm2d_fwd_<t<16, 64, 2, true>>(const S&, A);

template float layernorm2d_fwd_<t<32, 64, 1, false>>(const S&, A);
template float layernorm2d_fwd_<t<32, 64, 1, true>>(const S&, A);
14 changes: 7 additions & 7 deletions example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ struct LayerNormTypeConfig<ck_tile::half_t>
};

template <>
struct LayerNormTypeConfig<float>
struct LayerNormTypeConfig<ck_tile::bf16_t>
{
using XDataType = float;
using YDataType = float;
using GammaDataType = float;
using BetaDataType = float;
using MeanDataType = float;
using InvStdDataType = float;
using XDataType = ck_tile::bf16_t;
using YDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using BetaDataType = ck_tile::bf16_t;
using MeanDataType = ck_tile::bf16_t;
using InvStdDataType = ck_tile::bf16_t;
using ComputeDataType = float;
};

Expand Down
72 changes: 72 additions & 0 deletions example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,78 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
: layernorm2d_fwd_<trait_<ck_tile::fp16_t, 32, 64, 1, true, true>>(s, a);
}
}
else if(t.data_type.compare("bf16") == 0)
{
if(a.N % 4 == 0)
{
if(a.N <= 128)
{
return a.N == 128 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 32, 4, true>>(s, a);
}
else if(a.N <= 256)
{
return a.N == 256 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 4, true>>(s, a);
}
else if(a.N <= 512)
{
return a.N == 512 ? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 4, true>>(s, a);
}
else if(a.N <= 1024)
{
return a.N == 1024
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 4, true>>(s, a);
}
else if(a.N <= 2048)
{
return a.N == 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true>>(s, a);
}
else
{
return a.N % 2048 == 0
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, false, true>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 4, true, true>>(s, a);
}
}
else if(a.N % 2 == 0)
{
if(a.N <= 128)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 64, 2, true>>(s, a);
}
else if(a.N <= 256)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 2, 64, 2, true>>(s, a);
}
else if(a.N <= 512)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 4, 64, 2, true>>(s, a);
}
else if(a.N <= 1024)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 8, 64, 2, true>>(s, a);
}
else if(a.N <= 2048)
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true>>(s, a);
}
else
{
return layernorm2d_fwd_<trait_<ck_tile::bf16_t, 16, 64, 2, true, true>>(s, a);
}
}
else
{
return a.N <= 2048
? layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, false>>(s, a)
: layernorm2d_fwd_<trait_<ck_tile::bf16_t, 32, 64, 1, true, true>>(s, a);
}
}

if(r < 0)
throw std::runtime_error("Without supported instances!");
Expand Down
Empty file modified example/ck_tile/02_layernorm2d/perf_test.sh
100644 → 100755
Empty file.

0 comments on commit 02b9a7d

Please sign in to comment.