Skip to content

Commit

Permalink
unify layernorm api
Browse files Browse the repository at this point in the history
  • Loading branch information
rocking5566 committed Oct 16, 2024
1 parent 0223658 commit e50e331
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 270 deletions.
2 changes: 1 addition & 1 deletion example/ck_tile/02_layernorm2d/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL example_layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_fp16.cpp layernorm2d_fwd_fp32.cpp ${INSTANCE_SRCS})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE layernorm2d_fwd_api.cpp ${INSTANCE_SRCS})

set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS)

Expand Down
17 changes: 2 additions & 15 deletions example/ck_tile/02_layernorm2d/example_layernorm2d_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
#include "layernorm2d_fwd.hpp"
#include <cstring>

extern float layernorm2d_fwd_fp16(layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern float layernorm2d_fwd_fp32(layernorm2d_fwd_args& param, ck_tile::stream_config stream);

auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
Expand Down Expand Up @@ -95,18 +92,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
M,
N};

float ave_time = .0;

if constexpr(std::is_same<DataType, ck_tile::fp16_t>::value)
{
ave_time =
layernorm2d_fwd_fp16(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
else if constexpr(std::is_same<DataType, float>::value)
{
ave_time =
layernorm2d_fwd_fp32(args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});
}
float ave_time =
layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true, 0, warmup, repeat});

std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
sizeof(BetaDataType) * N + sizeof(YDataType) * M * N;
Expand Down
258 changes: 258 additions & 0 deletions example/ck_tile/02_layernorm2d/layernorm2d_fwd_api.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <ck_tile/core.hpp>
#include "layernorm_dispatch.hpp"

// clang-format off
// fp16
// Explicit instantiation declarations (`extern template`) for the fp16
// specializations of run_layernorm. They stop this translation unit from
// implicitly instantiating these specializations; the matching explicit
// instantiation definitions must be provided by another translation unit
// (presumably the instances/*.cpp sources - TODO confirm).
//
// The specializations whose fourth template argument is 8 are commented out
// here; the only code that would call them is the `#if 0` region of
// layernorm2d_fwd() in this file.
//
// First group: fifth template argument `false`. The dispatcher picks these
// when N equals the upper bound of its size bucket exactly (or, with the
// trailing `true` argument, when N > 2048 and N is a multiple of 2048).
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// Sixth template argument `true`: selected by the dispatcher only when N > 2048.
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// Second group: fifth template argument `true`. The dispatcher picks these
// when N does not land exactly on the bucket bound (or, for N > 2048, is not
// a multiple of 2048).
// extern template float run_layernorm<ck_tile::fp16_t, 1, 16, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 32, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 1, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 2, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// extern template float run_layernorm<ck_tile::fp16_t, 4, 64, 8, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// Sixth template argument `true`: selected by the dispatcher only when N > 2048.
extern template float run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);

// fp32
// Explicit instantiation declarations (`extern template`) for the float
// specializations of run_layernorm. They stop this translation unit from
// implicitly instantiating these specializations; the matching explicit
// instantiation definitions must be provided by another translation unit
// (presumably the instances/*.cpp sources - TODO confirm).
//
// NOTE(review): unlike the fp16 list, no specialization with a sixth template
// argument of `true` is declared here, yet layernorm2d_fwd() calls
// run_layernorm<float, 8, 64, 4, *, true> and
// run_layernorm<float, 16, 64, 2, *, true> when N > 2048. Confirm whether
// those are meant to be implicitly instantiated in this translation unit or
// whether their declarations are missing.
//
// First group: fifth template argument `false` (dispatcher: N equals the
// upper bound of its size bucket exactly).
extern template float run_layernorm<float, 1, 32, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, false>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// Second group: fifth template argument `true` (dispatcher: N below the
// upper bound of its size bucket).
extern template float run_layernorm<float, 1, 32, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 1, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 2, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 4, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 8, 64, 4, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
extern template float run_layernorm<float, 16, 64, 2, true>(const layernorm2d_fwd_args& param, ck_tile::stream_config stream);
// clang-format on

/// @brief Selects and runs a pre-instantiated run_layernorm specialization.
///
/// The specialization is chosen from the data-type string in @p t and the
/// value of `a.N`:
///   - `t.data_type` must compare equal to "fp16" or "fp32";
///   - N must be a multiple of 4 or of 2 (multiples of 4 take priority);
///   - N is bucketed as <=128, <=256, <=512, <=1024, <=2048, or larger.
///
/// Within a bucket, the fifth template argument is `false` when N equals the
/// bucket's upper bound exactly and `true` otherwise (presumably a padding
/// switch - TODO confirm against run_layernorm). For N > 2048 the sixth
/// template argument is `true` and the fifth is `false` only when N is a
/// multiple of 2048.
///
/// Specializations with a fourth template argument of 8 (an N % 8 == 0 path)
/// are intentionally not dispatched; their extern declarations in this file
/// are commented out.
///
/// @param t Traits; only `t.data_type` is read.
/// @param a Kernel arguments; forwarded unchanged to run_layernorm.
/// @param s Stream configuration; forwarded unchanged to run_layernorm.
/// @return Whatever the selected run_layernorm returns, or -1 when no
///         specialization matches (unknown data type or odd N).
float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
    // Sentinel handed back when the request matches no dispatch entry.
    constexpr float no_matching_kernel = -1;

    const auto n = a.N;

    if(t.data_type.compare("fp16") == 0)
    {
        if(n % 4 == 0)
        {
            if(n <= 128)
            {
                if(n == 128)
                    return run_layernorm<ck_tile::fp16_t, 1, 32, 4, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 1, 32, 4, true>(a, s);
            }
            if(n <= 256)
            {
                if(n == 256)
                    return run_layernorm<ck_tile::fp16_t, 1, 64, 4, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 1, 64, 4, true>(a, s);
            }
            if(n <= 512)
            {
                if(n == 512)
                    return run_layernorm<ck_tile::fp16_t, 2, 64, 4, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 2, 64, 4, true>(a, s);
            }
            if(n <= 1024)
            {
                if(n == 1024)
                    return run_layernorm<ck_tile::fp16_t, 4, 64, 4, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 4, 64, 4, true>(a, s);
            }
            if(n <= 2048)
            {
                if(n == 2048)
                    return run_layernorm<ck_tile::fp16_t, 8, 64, 4, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 8, 64, 4, true>(a, s);
            }

            // N > 2048: variants with the sixth template argument set.
            if(n % 2048 == 0)
                return run_layernorm<ck_tile::fp16_t, 8, 64, 4, false, true>(a, s);
            return run_layernorm<ck_tile::fp16_t, 8, 64, 4, true, true>(a, s);
        }
        else if(n % 2 == 0)
        {
            if(n <= 128)
            {
                if(n == 128)
                    return run_layernorm<ck_tile::fp16_t, 1, 64, 2, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 1, 64, 2, true>(a, s);
            }
            if(n <= 256)
            {
                if(n == 256)
                    return run_layernorm<ck_tile::fp16_t, 2, 64, 2, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 2, 64, 2, true>(a, s);
            }
            if(n <= 512)
            {
                if(n == 512)
                    return run_layernorm<ck_tile::fp16_t, 4, 64, 2, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 4, 64, 2, true>(a, s);
            }
            if(n <= 1024)
            {
                if(n == 1024)
                    return run_layernorm<ck_tile::fp16_t, 8, 64, 2, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 8, 64, 2, true>(a, s);
            }
            if(n <= 2048)
            {
                if(n == 2048)
                    return run_layernorm<ck_tile::fp16_t, 16, 64, 2, false>(a, s);
                return run_layernorm<ck_tile::fp16_t, 16, 64, 2, true>(a, s);
            }

            // N > 2048: variants with the sixth template argument set.
            if(n % 2048 == 0)
                return run_layernorm<ck_tile::fp16_t, 16, 64, 2, false, true>(a, s);
            return run_layernorm<ck_tile::fp16_t, 16, 64, 2, true, true>(a, s);
        }
        // Odd N: nothing to dispatch for fp16; fall through to the sentinel.
    }
    else if(t.data_type.compare("fp32") == 0)
    {
        if(n % 4 == 0)
        {
            if(n <= 128)
            {
                if(n == 128)
                    return run_layernorm<float, 1, 32, 4, false>(a, s);
                return run_layernorm<float, 1, 32, 4, true>(a, s);
            }
            if(n <= 256)
            {
                if(n == 256)
                    return run_layernorm<float, 1, 64, 4, false>(a, s);
                return run_layernorm<float, 1, 64, 4, true>(a, s);
            }
            if(n <= 512)
            {
                if(n == 512)
                    return run_layernorm<float, 2, 64, 4, false>(a, s);
                return run_layernorm<float, 2, 64, 4, true>(a, s);
            }
            if(n <= 1024)
            {
                if(n == 1024)
                    return run_layernorm<float, 4, 64, 4, false>(a, s);
                return run_layernorm<float, 4, 64, 4, true>(a, s);
            }
            if(n <= 2048)
            {
                if(n == 2048)
                    return run_layernorm<float, 8, 64, 4, false>(a, s);
                return run_layernorm<float, 8, 64, 4, true>(a, s);
            }

            // N > 2048: variants with the sixth template argument set.
            // NOTE(review): this file declares no extern template for these
            // two float specializations - confirm how they get instantiated.
            if(n % 2048 == 0)
                return run_layernorm<float, 8, 64, 4, false, true>(a, s);
            return run_layernorm<float, 8, 64, 4, true, true>(a, s);
        }
        else if(n % 2 == 0)
        {
            if(n <= 128)
            {
                if(n == 128)
                    return run_layernorm<float, 1, 64, 2, false>(a, s);
                return run_layernorm<float, 1, 64, 2, true>(a, s);
            }
            if(n <= 256)
            {
                if(n == 256)
                    return run_layernorm<float, 2, 64, 2, false>(a, s);
                return run_layernorm<float, 2, 64, 2, true>(a, s);
            }
            if(n <= 512)
            {
                if(n == 512)
                    return run_layernorm<float, 4, 64, 2, false>(a, s);
                return run_layernorm<float, 4, 64, 2, true>(a, s);
            }
            if(n <= 1024)
            {
                if(n == 1024)
                    return run_layernorm<float, 8, 64, 2, false>(a, s);
                return run_layernorm<float, 8, 64, 2, true>(a, s);
            }
            if(n <= 2048)
            {
                if(n == 2048)
                    return run_layernorm<float, 16, 64, 2, false>(a, s);
                return run_layernorm<float, 16, 64, 2, true>(a, s);
            }

            // N > 2048: variants with the sixth template argument set.
            // NOTE(review): this file declares no extern template for these
            // two float specializations - confirm how they get instantiated.
            if(n % 2048 == 0)
                return run_layernorm<float, 16, 64, 2, false, true>(a, s);
            return run_layernorm<float, 16, 64, 2, true, true>(a, s);
        }
        // Odd N: nothing to dispatch for fp32; fall through to the sentinel.
    }

    return no_matching_kernel;
}

Loading

0 comments on commit e50e331

Please sign in to comment.