diff --git a/include/cell/copy/global_to_shared.hpp b/include/cell/copy/global_to_shared.hpp index b8b39ad..434b991 100644 --- a/include/cell/copy/global_to_shared.hpp +++ b/include/cell/copy/global_to_shared.hpp @@ -14,15 +14,14 @@ namespace tl = tile_layout; /** * @brief Load a warp tile from global memory to shared memory. * - * This function loads a warp tile whose shape is specified by `BaseShape` - * from global memory to shared memory. + * This function loads a data tile from global to shared memory. * - * @tparam Global_ The type of the global memory pointer. - * @tparam Shared_ The type of the shared memory pointer. - * @tparam BaseShape_ The shape of the warp tile. - * @tparam kRowExec_ The number of rows to execute. - * @tparam kColExec_ The number of columns to execute. - * @tparam kType The type of the elements to be loaded. + * @tparam Global The type of the global memory tile. + * @tparam Shared The type of the shared memory tile. + * @tparam BaseShape The shape of the base tile. + * @tparam kRowExec The number of rows to execute. + * @tparam kColExec The number of columns to execute. + * @tparam kType The type of Global and Shared memory layout. */ template ; + // NOTE: The WarpShape calculated here is for the warp reuse mode `kCont`. + // If you use a different mode, update the WarpShape accordingly. 
+ static_assert((Shared::kRows % WarpLayout ::kRows == 0) && + (Shared::kCols % WarpLayout::kCols == 0), + "The shape of SharedTile must be divisible by the shape of " + "WarpLayout."); + + using WarpShape = TileShape; + using BaseShape = warp::WarpBaseTileShape; static_assert(Shared::kRows % BaseShape ::kRows == 0, "Shared::kRows must be divisible by BaseShape::kRows."); @@ -452,8 +459,9 @@ struct SharedToGlobalStorer { using DType = Shared::DType; using WarpLayout = WarpLayout_; - using BaseShape = - warp::WarpBaseTileShape; + using WarpShape = TileShape; + using BaseShape = warp::WarpBaseTileShape; static_assert(Shared::kRows % BaseShape::kRows == 0, "Shared::kRows must be divisible by BaseShape::kRows."); @@ -494,4 +502,5 @@ struct SharedToGlobalStorer { SharedOffset shared_offset_; GlobalOffset global_offset_; }; + } // namespace tilefusion::cell::copy diff --git a/include/cell/copy/warp.hpp b/include/cell/copy/warp.hpp index b75991c..baa09aa 100644 --- a/include/cell/copy/warp.hpp +++ b/include/cell/copy/warp.hpp @@ -9,10 +9,11 @@ #include "cell/copy/constants.hpp" #include "types/layout.hpp" +#include "types/tile_shape.hpp" namespace tilefusion::cell::copy::warp { +using namespace tilefusion::cell; namespace tl = tile_layout; -using namespace cute; namespace { // functions/class/structs that are not exposed to a larger scope @@ -191,13 +192,16 @@ struct ExecCounter { /// @brief Determine the automatic shape of a single warp based on the shape of /// the entire tile. The final warp tile shape is multiple of this atomic /// shape. -template +template struct WarpBaseTileShape; -template -struct WarpBaseTileShape { +template +struct WarpBaseTileShape { using AccessInfo = traits::AccessBase; + static constexpr int kTileRows = dim_size<0, TileShape>; + static constexpr int kTileCols = dim_size<1, TileShape>; + // In a row-major layout, columns are the contiguous dimension in memory. 
We // enforce the use of 128-bit vectorized instructions for data loading by a // single thread. This implies that the minimum number of columns should be @@ -205,18 +209,17 @@ struct WarpBaseTileShape { static constexpr int kMinCols = AccessInfo::kAccessInBits / (sizeof(DType) * 8); - static_assert(TileLayout::kCols >= kMinCols, - "The number of columns is too small."); + static_assert(kTileCols >= kMinCols, "The number of columns is too small."); - static_assert(TileLayout::kCols < AccessInfo::kExpectedSize || - (TileLayout::kCols >= AccessInfo::kExpectedSize && - TileLayout::kCols % AccessInfo::kExpectedSize == 0), + static_assert(kTileCols < AccessInfo::kExpectedSize || + (kTileCols >= AccessInfo::kExpectedSize && + kTileCols % AccessInfo::kExpectedSize == 0), "The current implementation requires that the number of " "columns of the tile be divisible by the cache line width."); - static constexpr int kCols = TileLayout::kCols >= AccessInfo::kExpectedSize + static constexpr int kCols = kTileCols >= AccessInfo::kExpectedSize ? AccessInfo::kExpectedSize - : TileLayout::kCols; + : kTileCols; // number of columns in a warp static constexpr int kColThreads = kCols / AccessInfo::kNumPerAccess; @@ -225,7 +228,7 @@ struct WarpBaseTileShape { static constexpr int kRowThreads = WARP_SIZE / kColThreads; static constexpr int kRows = kRowThreads; - static_assert(TileLayout::kRows % kRowThreads == 0, + static_assert(kTileRows % kRowThreads == 0, "The number of rows of the tile isn't evenly divisible by " "the number of threads in a column."); @@ -234,10 +237,13 @@ struct WarpBaseTileShape { using WarpThreadLayout = tl::RowMajor; }; -template -struct WarpBaseTileShape { +template +struct WarpBaseTileShape { using AccessInfo = traits::AccessBase; + static constexpr int kTileRows = dim_size<0, TileShape>; + static constexpr int kTileCols = dim_size<1, TileShape>; + // In a column-major layout, rows are the contiguous dimension in memory. 
// We enforce the use of 128-bit vectorized instructions for data loading by // a single thread. This implies that the minimum number of rows should @@ -245,18 +251,17 @@ struct WarpBaseTileShape { static constexpr int kMinRows = AccessInfo::kAccessInBits / (sizeof(DType) * 8); - static_assert(TileLayout::kRows >= kMinRows, - "The number of rows is too small."); + static_assert(kTileRows >= kMinRows, "The number of rows is too small."); - static_assert(TileLayout::kRows < AccessInfo::kExpectedSize || - (TileLayout::kRows >= AccessInfo::kExpectedSize && - TileLayout::kRows % AccessInfo::kExpectedSize == 0), + static_assert(kTileRows < AccessInfo::kExpectedSize || + (kTileRows >= AccessInfo::kExpectedSize && + kTileRows % AccessInfo::kExpectedSize == 0), "The current implementation requires that the number of " "rows of the tile be divisible by the cache line width."); - static constexpr int kRows = TileLayout::kRows >= AccessInfo::kExpectedSize + static constexpr int kRows = kTileRows >= AccessInfo::kExpectedSize ? 
AccessInfo::kExpectedSize - : TileLayout::kRows; + : kTileRows; // number of rows in a warp static constexpr int kRowThreads = kRows / AccessInfo::kNumPerAccess; @@ -265,7 +270,7 @@ struct WarpBaseTileShape { static constexpr int kColThreads = WARP_SIZE / kRowThreads; static constexpr int kCols = kColThreads; - static_assert(TileLayout::kCols % kColThreads == 0, + static_assert(kTileCols % kColThreads == 0, "The number of columns of the tile isn't evenly divisible by " "the number of threads in a row."); diff --git a/tests/cpp/cell/test_atomic_warp_tile_shape.cu b/tests/cpp/cell/test_atomic_warp_tile_shape.cu index 1736d8d..3639dac 100644 --- a/tests/cpp/cell/test_atomic_warp_tile_shape.cu +++ b/tests/cpp/cell/test_atomic_warp_tile_shape.cu @@ -5,18 +5,16 @@ #include "common/test_utils.hpp" namespace tilefusion::testing { + using namespace cell::copy::warp; namespace tl = tile_layout; -#define DEBUG_PRINT 1 - TEST(InferAtomicWarpTile, test1_half_row_major) { using DType = __half; + const tl::Layout kLayout = tl::Layout::kRowMajor; { // atomic warp shape: 32x8, thread layout: 32x1 - using Layout = tl::RowMajor<128, 8>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 32); EXPECT_EQ(WarpTile::kCols, 8); @@ -26,9 +24,7 @@ TEST(InferAtomicWarpTile, test1_half_row_major) { } { // atomic warp shape: 16x16, thread layout: 16x2 - using Layout = tl::RowMajor<64, 16>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 16); EXPECT_EQ(WarpTile::kCols, 16); @@ -38,9 +34,7 @@ TEST(InferAtomicWarpTile, test1_half_row_major) { } { // atomic warp shape: 8x32, thread layout: 8x4 - using Layout = tl::RowMajor<16, 32>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 8); EXPECT_EQ(WarpTile::kCols, 32); @@ -50,9 +44,7 @@ TEST(InferAtomicWarpTile, test1_half_row_major) { } { // atomic warp 
shape: 4x64, thread layout: 4x8 - using Layout = tl::RowMajor<128, 128>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 4); EXPECT_EQ(WarpTile::kCols, 64); @@ -64,11 +56,10 @@ TEST(InferAtomicWarpTile, test1_half_row_major) { TEST(InferAtomicWarpTile, test2_half_column_major) { using DType = __half; + const tl::Layout kLayout = tl::Layout::kColMajor; { // atomic warp shape: 8x32, thread layout: 1x32 - using Layout = tl::ColMajor<8, 128>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 8); EXPECT_EQ(WarpTile::kCols, 32); @@ -78,9 +69,7 @@ TEST(InferAtomicWarpTile, test2_half_column_major) { } { // atomic warp shape: 16x16, thread layout: 2x16 - using Layout = tl::ColMajor<16, 64>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 16); EXPECT_EQ(WarpTile::kCols, 16); @@ -90,9 +79,7 @@ TEST(InferAtomicWarpTile, test2_half_column_major) { } { // atomic warp shape: 32x8, thread layout: 4x8 - using Layout = tl::ColMajor<32, 16>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 32); EXPECT_EQ(WarpTile::kCols, 8); @@ -102,9 +89,7 @@ TEST(InferAtomicWarpTile, test2_half_column_major) { } { // atomic warp shape: 64x4, thread layout: 8x4 - using Layout = tl::ColMajor<128, 128>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 64); EXPECT_EQ(WarpTile::kCols, 4); @@ -116,11 +101,10 @@ TEST(InferAtomicWarpTile, test2_half_column_major) { TEST(InferAtomicWarpTile, test3_float_row_major) { using DType = float; + const tl::Layout kLayout = tl::Layout::kRowMajor; { // atomic warp shape: 32x4, thread layout: 32x1 - using Layout = tl::RowMajor<128, 4>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; 
EXPECT_EQ(WarpTile::kRows, 32); EXPECT_EQ(WarpTile::kCols, 4); @@ -130,9 +114,7 @@ TEST(InferAtomicWarpTile, test3_float_row_major) { } { // atomic warp shape: 16x8, thread layout: 16x2 - using Layout = tl::RowMajor<64, 8>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 16); EXPECT_EQ(WarpTile::kCols, 8); @@ -142,9 +124,7 @@ TEST(InferAtomicWarpTile, test3_float_row_major) { } { // atomic warp shape: 8x16, thread layout: 8x4 - using Layout = tl::RowMajor<16, 16>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 8); EXPECT_EQ(WarpTile::kCols, 16); @@ -154,9 +134,7 @@ TEST(InferAtomicWarpTile, test3_float_row_major) { } { // atomic warp shape: 4x32, thread layout: 4x8 - using Layout = tl::RowMajor<128, 128>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 4); EXPECT_EQ(WarpTile::kCols, 32); @@ -168,11 +146,10 @@ TEST(InferAtomicWarpTile, test3_float_row_major) { TEST(InferAtomicWarpTile, test4_float_column_major) { using DType = float; + const tl::Layout kLayout = tl::Layout::kColMajor; { // atomic warp shape: 4x32, thread layout: 1x32 - using Layout = tl::ColMajor<4, 128>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 4); EXPECT_EQ(WarpTile::kCols, 32); @@ -182,9 +159,7 @@ TEST(InferAtomicWarpTile, test4_float_column_major) { } { // atomic warp shape: 8x16, thread layout: 2x16 - using Layout = tl::ColMajor<8, 64>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 8); EXPECT_EQ(WarpTile::kCols, 16); @@ -194,9 +169,7 @@ TEST(InferAtomicWarpTile, test4_float_column_major) { } { // atomic warp shape: 16x8, thread layout: 4x8 - using Layout = tl::ColMajor<16, 32>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, 
kLayout>; EXPECT_EQ(WarpTile::kRows, 16); EXPECT_EQ(WarpTile::kCols, 8); @@ -206,9 +179,7 @@ TEST(InferAtomicWarpTile, test4_float_column_major) { } { // atomic warp shape: 32x4, thread layout: 8x4 - using Layout = tl::ColMajor<128, 128>; - using WarpTile = - WarpBaseTileShape; + using WarpTile = WarpBaseTileShape, kLayout>; EXPECT_EQ(WarpTile::kRows, 32); EXPECT_EQ(WarpTile::kCols, 4);