diff --git a/include/cell/copy/warp.hpp b/include/cell/copy/warp.hpp index 8a21c87b..b75991ce 100644 --- a/include/cell/copy/warp.hpp +++ b/include/cell/copy/warp.hpp @@ -14,7 +14,8 @@ namespace tilefusion::cell::copy::warp { namespace tl = tile_layout; using namespace cute; -namespace { +namespace { // functions/class/structs that are not exposed to a larger scope + // FIXME(ying): This hotfix addresses the current implementation's inability // to explicitly distinguish between shared memory's row-major or // column-major layout and global memory's layouts. However, this should be @@ -54,6 +55,18 @@ struct WarpOffsetHelper { DEVICE int operator()(int i, int j) const { return i * kRowStride; } }; + +/// @brief Helper for pretty printing a BaseTile's static shape-related +/// information. This printer works ONLY on the host. +struct BaseTilePrettyPrinter { + template + static HOST void print(std::ostream& out, const BaseShape& tile) { + // parameter `tile` here is not used + out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols + << "), Numel = " << BaseShape::kNumel << ", ThreadLayout = (" + << BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")"; + } +}; } // namespace // @brief In a thread block, warps are organized as 2-D matrices, each with @@ -261,6 +274,16 @@ struct WarpBaseTileShape { using WarpThreadLayout = tl::ColMajor; }; +/// @brief Pretty printer for the static shape information of a +/// `WarpBaseTileShape`. Note: This printer function works ONLY on the +/// host. 
+template +static HOST std::ostream& operator<<( + std::ostream& out, const WarpBaseTileShape& tile) { + BaseTilePrettyPrinter::print(out, tile); + return out; +} + template struct GlobalOffsetHelper { static constexpr WarpReuse kMode = kMode_; diff --git a/include/types/global.hpp b/include/types/global.hpp index 29635787..39ba84c8 100644 --- a/include/types/global.hpp +++ b/include/types/global.hpp @@ -9,6 +9,22 @@ namespace tilefusion::cell { namespace tl = tile_layout; +namespace { + +/// @brief Helper for pretty printing a GlobalTile's static shape-related +/// information. This printer works ONLY on the host. +struct GlobalTilePrettyPrinter { + template + static HOST void print(std::ostream& out, const Global& tile) { + // parameter `tile` here is not used + out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", " + << Global::kCols << ", " << Global::kRowStride << ", " + << Global::kColStride << "), numel = " << Global::kNumel; + } +}; + +} // namespace + template struct GlobalTile { using DType = Element_; @@ -24,6 +40,10 @@ struct GlobalTile { static constexpr tl::Layout kType = tl::layout_type; + // This Ctor is to enable the use of the pretty printer of GlobalTile in the + // host code. + HOST GlobalTile() : data_(nullptr), layout_(Layout{}) {} + DEVICE GlobalTile(DType* data) : data_(data), layout_(Layout{}) {} DEVICE GlobalTile(const DType* data) @@ -48,4 +68,14 @@ struct GlobalTile { DType* data_; Layout layout_; }; + +/// @brief Pretty printer for the static shape information of a GlobalTile. +/// Note: This printer function works ONLY on the host. 
+template +static HOST std::ostream& operator<<(std::ostream& out, + const GlobalTile& tile) { + GlobalTilePrettyPrinter::print(out, tile); + return out; +} + } // namespace tilefusion::cell diff --git a/include/types/global_tile_iterator.hpp b/include/types/global_tile_iterator.hpp index 87e01033..3400c76a 100644 --- a/include/types/global_tile_iterator.hpp +++ b/include/types/global_tile_iterator.hpp @@ -9,19 +9,21 @@ namespace tilefusion::cell { namespace tl = tile_layout; -namespace detail { +namespace { /// @brief Helper for pretty printing a tile iterator's static shape-related /// information. This printer works ONLY on the host. struct GTileIteratorPrettyPrinter { template static HOST void print(std::ostream& out, const TileIterator& itr) { - out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape[" - << dim_size<0, typename TileIterator::ChunkShape> << ", " - << dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = " - << TileIterator::sc0 << ", sc1 = " << TileIterator::sc1; + size_t size1 = dim_size<0, typename TileIterator::ChunkShape>; + size_t size2 = dim_size<1, typename TileIterator::ChunkShape>; + + out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = (" + << size1 << ", " << size2 << "), stripe count = (" + << TileIterator::sc0 << ", " << TileIterator::sc1 << ")"; } }; -} // namespace detail +} // namespace /// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles /// and iterates over these smaller sub-tiles. 
@@ -158,7 +160,7 @@ class GTileIterator { template static HOST std::ostream& operator<<( std::ostream& out, const GTileIterator& itr) { - detail::GTileIteratorPrettyPrinter::print(out, itr); + GTileIteratorPrettyPrinter::print(out, itr); return out; } diff --git a/include/types/register.hpp b/include/types/register.hpp index b91643dd..ab85e808 100644 --- a/include/types/register.hpp +++ b/include/types/register.hpp @@ -10,8 +10,6 @@ namespace tilefusion::cell { namespace tl = tile_layout; -namespace detail { - namespace { template constexpr int get_rows = DType::kRows; @@ -36,16 +34,15 @@ constexpr int get_cols<__half> = 1; template <> constexpr int get_cols = 1; -} // namespace /// @brief Helper for pretty printing a register tile's static shape /// information. This printer works ONLY on the host. struct RegTilePrettyPrinter { template static HOST void print(std::ostream& out, const Tile& tile) { - out << layout_type_to_str(Tile::kType) << "[" + out << layout_type_to_str(Tile::kType) << "(" << Tile::kRows * get_rows << ", " - << Tile::kCols * get_cols << "]"; + << Tile::kCols * get_cols << ")"; } }; @@ -58,12 +55,12 @@ DEVICE void clear(__half* data, int numel) { } template -DEVICE void clear(DType* data, int numel) { +DEVICE void clear_impl(DType* data, int numel) { for (int i = 0; i < numel; ++i) { clear(data[i].mutable_data(), 8); } } -} // namespace detail +} // namespace template class RegTile { @@ -101,7 +98,7 @@ class RegTile { print_tile(const_cast(data_), layout_); } - DEVICE void clear() { detail::clear(data_, kNumel); } + DEVICE void clear() { clear_impl(data_, kNumel); } private: DType data_[kNumel]; @@ -124,7 +121,7 @@ using BaseTileColMajor = RegTile>; template static HOST std::ostream& operator<<(std::ostream& out, const RegTile& tile) { - detail::RegTilePrettyPrinter::print(out, tile); + RegTilePrettyPrinter::print(out, tile); return out; } diff --git a/include/types/shared.hpp b/include/types/shared.hpp index a86064b7..5791cd11 100644 --- 
a/include/types/shared.hpp +++ b/include/types/shared.hpp @@ -9,6 +9,26 @@ namespace tilefusion::cell { namespace tl = tile_layout; +namespace { + +/// @brief Helper for pretty printing a SharedTile's static shape-related +/// information. This printer works ONLY on the host. +struct SharedTilePrettyPrinter { + template + static HOST void print(std::ostream& out, const Shared& tile) { + // parameter `tile` here is not used + + auto swizzled = Shared::kSwizzled ? "swizzled" : "non-swizzled"; + + out << layout_type_to_str(Shared::kType) << "(" << Shared::kRows << ", " + << Shared::kCols << ", " << Shared::kRowStride << ", " + << Shared::kColStride << "), numel = " << Shared::kNumel + << ", swizzled = " << swizzled; + } +}; + +} // namespace + template class SharedTile { public: @@ -26,6 +46,10 @@ class SharedTile { static constexpr tl::Layout kType = tl::layout_type; static constexpr bool kSwizzled = kSwizzled_; + // This Ctor is to enable the use of the pretty printer of SharedTile in the + // host code. + HOST SharedTile() : data_(nullptr), layout_(Layout{}) {} + DEVICE SharedTile(DType* data) : data_(data), layout_(Layout{}) {} DEVICE DType* mutable_data() { return data_; } @@ -48,4 +72,13 @@ class SharedTile { Layout layout_; }; +/// @brief Pretty printer for the static shape information of a SharedTile. +/// Note: This printer function works ONLY on the host. 
+template +static HOST std::ostream& operator<<( + std::ostream& out, const SharedTile& tile) { + SharedTilePrettyPrinter::print(out, tile); + return out; +} + } // namespace tilefusion::cell diff --git a/include/types/shared_tile_iterator.hpp b/include/types/shared_tile_iterator.hpp index 80f8f38b..46687305 100644 --- a/include/types/shared_tile_iterator.hpp +++ b/include/types/shared_tile_iterator.hpp @@ -10,21 +10,21 @@ namespace tilefusion::cell { namespace tl = tile_layout; -using namespace cute; - -namespace detail { +namespace { /// @brief Helper for pretty printing a tile iterator's static shape-related /// information. This printer works ONLY on the host. struct STileIteratorPrettyPrinter { template static HOST void print(std::ostream& out, const TileIterator& itr) { - out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape[" - << dim_size<0, typename TileIterator::ChunkShape> << ", " - << dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = " - << TileIterator::sc0 << ", sc1 = " << TileIterator::sc1; + size_t size1 = dim_size<0, typename TileIterator::ChunkShape>; + size_t size2 = dim_size<1, typename TileIterator::ChunkShape>; + + out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = (" + << size1 << ", " << size2 << "), stripe count = (" + << TileIterator::sc0 << ", " << TileIterator::sc1 << ")"; } }; -} // namespace detail +} // namespace /// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles /// and iterates over these smaller sub-tiles. 
@@ -136,7 +136,7 @@ class STileIterator { template static HOST std::ostream& operator<<( std::ostream& out, const STileIterator& itr) { - detail::STileIteratorPrettyPrinter::print(out, itr); + STileIteratorPrettyPrinter::print(out, itr); return out; } diff --git a/tests/cpp/cell/test_swizzled_copy.cu b/tests/cpp/cell/test_swizzled_copy.cu index 1db67030..8a6d9e97 100644 --- a/tests/cpp/cell/test_swizzled_copy.cu +++ b/tests/cpp/cell/test_swizzled_copy.cu @@ -12,6 +12,8 @@ #include +#define DEBUG + namespace tilefusion::testing { using namespace cell; using namespace copy; @@ -98,8 +100,8 @@ void run_test_rowmajor() { static_assert(kShmRows == kRows, "kShmRows must be equal to kRows"); using Element = __half; - const int kThreads = tl::get_numel * 32; - static constexpr int kWarpPerRow = tl::num_rows; + const int kThreads = WarpLayout::kNumel * 32; + static constexpr int kWarpPerRow = WarpLayout::kRows; using Global = GlobalTile>; using GIterator = GTileIterator>; @@ -124,12 +126,9 @@ void run_test_rowmajor() { LOG(INFO) << "GIterator: " << GIterator{} << std::endl << "SIterator1: " << SIterator1{} << std::endl << "SIterator2: " << SIterator2{} << std::endl - << "GlobalTile Shape: [" << kRows << ", " << kCols << "]" - << std::endl - << "SharedTile Shape: [" << kShmRows << ", " << kShmCols << "]" - << std::endl - << "sc0: " << kSc0 << ", sc1: " << kSc1 << std::endl - << "RegTile Shape: " << Reg{} << std::endl; + << "GlobalTile: " << Global{} << std::endl + << "SharedTile: " << Shared1{} << std::endl + << "RegTile: " << Reg{} << std::endl; #endif using G2S1 = GlobalToSharedLoader; @@ -205,16 +204,12 @@ void run_test_colmajor() { using Reg = RegTile, tl::ColMajor>; #ifdef DEBUG - LOG(INFO) << std::endl - << "GIterator: " << GIterator{} << std::endl + LOG(INFO) << "GIterator: " << GIterator{} << std::endl << "SIterator1: " << SIterator1{} << std::endl << "SIterator2: " << SIterator2{} << std::endl - << "GlobalTile Shape: [" << kRows << ", " << kCols << "]" - << 
std::endl - << "SharedTile Shape: [" << kShmRows << ", " << kShmCols << "]" - << std::endl - << "sc0: " << kSc0 << ", sc1: " << kSc1 << std::endl - << "RegTile Shape: " << Reg{} << std::endl; + << "GlobalTile: " << Global{} << std::endl + << "SharedTile: " << Shared1{} << std::endl + << "RegTile: " << Reg{} << std::endl; #endif using G2S1 = GlobalToSharedLoader;