Skip to content

Commit

Permalink
fix(cell): bug fix for inferring the BaseTile in g2s data tile tran…
Browse files Browse the repository at this point in the history
…sfer. (#50)

In the process of transferring data tiles between global and shared
memory, `BaseTile` should be inferred from a warp tile instead of a
shared memory tile. This pull request fixes that bug.
  • Loading branch information
lcy-seso authored Jan 28, 2025
1 parent e976316 commit 0a1a287
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 84 deletions.
33 changes: 21 additions & 12 deletions include/cell/copy/global_to_shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ namespace tl = tile_layout;
/**
* @brief Load a warp tile from global memory to shared memory.
*
* This function loads a warp tile whose shape is specified by `BaseShape`
* from global memory to shared memory.
* This function loads a data tile from global to shared memory.
*
* @tparam Global_ The type of the global memory pointer.
* @tparam Shared_ The type of the shared memory pointer.
* @tparam BaseShape_ The shape of the warp tile.
* @tparam kRowExec_ The number of rows to execute.
* @tparam kColExec_ The number of columns to execute.
* @tparam kType The type of the elements to be loaded.
* @tparam Global The type of the global memory tile.
* @tparam Shared The type of the shared memory tile.
* @tparam BaseShape The shape of the base tile.
* @tparam kRowExec The number of rows to execute.
* @tparam kColExec The number of columns to execute.
* @tparam kType The type of Global and Shared memory layout.
*/
template <typename Global, typename Shared, typename BaseShape,
const int kRowExec, const int kColExec,
Expand Down Expand Up @@ -399,8 +398,16 @@ struct GlobalToSharedLoader {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

using BaseShape =
warp::WarpBaseTileShape<DType, typename Shared::Layout, Shared::kType>;
// NOTE: The WarpShape calculated here is for the warp reuse mode `kCont`.
// If you use a different mode, update the WarpShape accordingly.
static_assert((Shared::kRows % WarpLayout ::kRows == 0) &&
(Shared::kCols % WarpLayout::kCols == 0),
"The shape of SharedTile must be divisible by the shape of "
"WarpLayout.");

using WarpShape = TileShape<Shared::kRows / WarpLayout::kRows,
Shared::kCols / WarpLayout::kCols>;
using BaseShape = warp::WarpBaseTileShape<DType, WarpShape, Shared::kType>;

static_assert(Shared::kRows % BaseShape ::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
Expand Down Expand Up @@ -452,8 +459,9 @@ struct SharedToGlobalStorer {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

using BaseShape =
warp::WarpBaseTileShape<DType, typename Shared::Layout, Shared::kType>;
using WarpShape = TileShape<Shared::kRows / WarpLayout::kRows,
Shared::kCols / WarpLayout::kCols>;
using BaseShape = warp::WarpBaseTileShape<DType, WarpShape, Shared::kType>;

static_assert(Shared::kRows % BaseShape::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
Expand Down Expand Up @@ -494,4 +502,5 @@ struct SharedToGlobalStorer {
SharedOffset shared_offset_;
GlobalOffset global_offset_;
};

} // namespace tilefusion::cell::copy
49 changes: 27 additions & 22 deletions include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@

#include "cell/copy/constants.hpp"
#include "types/layout.hpp"
#include "types/tile_shape.hpp"

namespace tilefusion::cell::copy::warp {
using namespace tilefusion::cell;
namespace tl = tile_layout;
using namespace cute;

namespace { // functions/class/structs that are not exposed to a larger scope

Expand Down Expand Up @@ -191,32 +192,34 @@ struct ExecCounter {
/// @brief Determine the automatic shape of a single warp based on the shape of
/// the entire tile. The final warp tile shape is multiple of this atomic
/// shape.
template <typename DType, typename TileLayout, const tl::Layout kType>
template <typename DType, typename TileShape, const tl::Layout kType>
struct WarpBaseTileShape;

template <typename DType, typename TileLayout>
struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kRowMajor> {
template <typename DType, typename TileShape>
struct WarpBaseTileShape<DType, TileShape, tl::Layout::kRowMajor> {
using AccessInfo = traits::AccessBase<DType>;

static constexpr int kTileRows = dim_size<0, TileShape>;
static constexpr int kTileCols = dim_size<1, TileShape>;

// In a row-major layout, columns are the contiguous dimension in memory. We
// enforce the use of 128-bit vectorized instructions for data loading by a
// single thread. This implies that the minimum number of columns should be
// at least 128 bits.
static constexpr int kMinCols =
AccessInfo::kAccessInBits / (sizeof(DType) * 8);

static_assert(TileLayout::kCols >= kMinCols,
"The number of columns is too small.");
static_assert(kTileCols >= kMinCols, "The number of columns is too small.");

static_assert(TileLayout::kCols < AccessInfo::kExpectedSize ||
(TileLayout::kCols >= AccessInfo::kExpectedSize &&
TileLayout::kCols % AccessInfo::kExpectedSize == 0),
static_assert(kTileCols < AccessInfo::kExpectedSize ||
(kTileCols >= AccessInfo::kExpectedSize &&
kTileCols % AccessInfo::kExpectedSize == 0),
"The current implementation requires that the number of "
"columns of the tile be divisible by the cache line width.");

static constexpr int kCols = TileLayout::kCols >= AccessInfo::kExpectedSize
static constexpr int kCols = kTileCols >= AccessInfo::kExpectedSize
? AccessInfo::kExpectedSize
: TileLayout::kCols;
: kTileCols;

// number of columns in a warp
static constexpr int kColThreads = kCols / AccessInfo::kNumPerAccess;
Expand All @@ -225,7 +228,7 @@ struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kRowMajor> {
static constexpr int kRowThreads = WARP_SIZE / kColThreads;

static constexpr int kRows = kRowThreads;
static_assert(TileLayout::kRows % kRowThreads == 0,
static_assert(kTileRows % kRowThreads == 0,
"The number of rows of the tile isn't evenly divisible by "
"the number of threads in a column.");

Expand All @@ -234,29 +237,31 @@ struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kRowMajor> {
using WarpThreadLayout = tl::RowMajor<kRowThreads, kColThreads>;
};

template <typename DType, typename TileLayout>
struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kColMajor> {
template <typename DType, typename TileShape>
struct WarpBaseTileShape<DType, TileShape, tl::Layout::kColMajor> {
using AccessInfo = traits::AccessBase<DType>;

static constexpr int kTileRows = dim_size<0, TileShape>;
static constexpr int kTileCols = dim_size<1, TileShape>;

// In a column-major layout, columns are the contiguous dimension in memory.
// We enforce the use of 128-bit vectorized instructions for data loading by
// a single thread. This implies that the minimum number of columns should
// be at least 128 bits.
static constexpr int kMinRows =
AccessInfo::kAccessInBits / (sizeof(DType) * 8);

static_assert(TileLayout::kRows >= kMinRows,
"The number of rows is too small.");
static_assert(kTileRows >= kMinRows, "The number of rows is too small.");

static_assert(TileLayout::kRows < AccessInfo::kExpectedSize ||
(TileLayout::kRows >= AccessInfo::kExpectedSize &&
TileLayout::kRows % AccessInfo::kExpectedSize == 0),
static_assert(kTileRows < AccessInfo::kExpectedSize ||
(kTileRows >= AccessInfo::kExpectedSize &&
kTileRows % AccessInfo::kExpectedSize == 0),
"The current implementation requires that the number of "
"rows of the tile be divisible by the cache line width.");

static constexpr int kRows = TileLayout::kRows >= AccessInfo::kExpectedSize
static constexpr int kRows = kTileRows >= AccessInfo::kExpectedSize
? AccessInfo::kExpectedSize
: TileLayout::kRows;
: kTileRows;

// number of rows in a warp
static constexpr int kRowThreads = kRows / AccessInfo::kNumPerAccess;
Expand All @@ -265,7 +270,7 @@ struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kColMajor> {
static constexpr int kColThreads = WARP_SIZE / kRowThreads;

static constexpr int kCols = kColThreads;
static_assert(TileLayout::kCols % kColThreads == 0,
static_assert(kTileCols % kColThreads == 0,
"The number of columns of the tile isn't evenly divisible by "
"the number of threads in a row.");

Expand Down
71 changes: 21 additions & 50 deletions tests/cpp/cell/test_atomic_warp_tile_shape.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
#include "common/test_utils.hpp"

namespace tilefusion::testing {

using namespace cell::copy::warp;
namespace tl = tile_layout;

#define DEBUG_PRINT 1

TEST(InferAtomicWarpTile, test1_half_row_major) {
using DType = __half;
const tl::Layout kLayout = tl::Layout::kRowMajor;

{ // atomic warp shape: 32x8, thread layout: 32x1
using Layout = tl::RowMajor<128, 8>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<128, 8>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 8);
Expand All @@ -26,9 +24,7 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {
}

{ // atomic warp shape: 16x16, thread layout: 16x2
using Layout = tl::RowMajor<64, 16>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<64, 16>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 16);
Expand All @@ -38,9 +34,7 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {
}

{ // atomic warp shape: 8x32, thread layout: 8x4
using Layout = tl::RowMajor<16, 32>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<16, 32>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 32);
Expand All @@ -50,9 +44,7 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {
}

{ // atomic warp shape: 4x64, thread layout: 4x8
using Layout = tl::RowMajor<128, 128>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<128, 128>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 4);
EXPECT_EQ(WarpTile::kCols, 64);
Expand All @@ -64,11 +56,10 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {

TEST(InferAtomicWarpTile, test2_half_column_major) {
using DType = __half;
const tl::Layout kLayout = tl::Layout::kColMajor;

{ // atomic warp shape: 8x32, thread layout: 1x32
using Layout = tl::ColMajor<8, 128>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<8, 128>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 32);
Expand All @@ -78,9 +69,7 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {
}

{ // atomic warp shape: 16x16, thread layout: 2x16
using Layout = tl::ColMajor<16, 64>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<16, 64>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 16);
Expand All @@ -90,9 +79,7 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {
}

{ // atomic warp shape: 32x8, thread layout: 4x8
using Layout = tl::ColMajor<32, 16>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<32, 16>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 8);
Expand All @@ -102,9 +89,7 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {
}

{ // atomic warp shape: 64x4, thread layout: 8x4
using Layout = tl::ColMajor<128, 128>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<128, 128>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 64);
EXPECT_EQ(WarpTile::kCols, 4);
Expand All @@ -116,11 +101,10 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {

TEST(InferAtomicWarpTile, test3_float_row_major) {
using DType = float;
const tl::Layout kLayout = tl::Layout::kRowMajor;

{ // atomic warp shape: 32x4, thread layout: 32x1
using Layout = tl::RowMajor<128, 4>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<128, 4>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 4);
Expand All @@ -130,9 +114,7 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {
}

{ // atomic warp shape: 16x8, thread layout: 16x2
using Layout = tl::RowMajor<64, 8>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<64, 8>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 8);
Expand All @@ -142,9 +124,7 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {
}

{ // atomic warp shape: 8x16, thread layout: 8x4
using Layout = tl::RowMajor<16, 16>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<16, 16>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 16);
Expand All @@ -154,9 +134,7 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {
}

{ // atomic warp shape: 4x32, thread layout: 4x8
using Layout = tl::RowMajor<128, 128>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<128, 128>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 4);
EXPECT_EQ(WarpTile::kCols, 32);
Expand All @@ -168,11 +146,10 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {

TEST(InferAtomicWarpTile, test4_float_column_major) {
using DType = float;
const tl::Layout kLayout = tl::Layout::kColMajor;

{ // atomic warp shape: 4x32, thread layout: 1x32
using Layout = tl::ColMajor<4, 128>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<4, 128>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 4);
EXPECT_EQ(WarpTile::kCols, 32);
Expand All @@ -182,9 +159,7 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {
}

{ // atomic warp shape: 8x16, thread layout: 2x16
using Layout = tl::ColMajor<8, 64>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<8, 64>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 16);
Expand All @@ -194,9 +169,7 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {
}

{ // atomic warp shape: 16x8, thread layout: 4x8
using Layout = tl::ColMajor<16, 32>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<16, 32>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 8);
Expand All @@ -206,9 +179,7 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {
}

{ // atomic warp shape: 4x32, thread layout: 8x4
using Layout = tl::ColMajor<128, 128>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile = WarpBaseTileShape<DType, TileShape<128, 128>, kLayout>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 4);
Expand Down

0 comments on commit 0a1a287

Please sign in to comment.