Skip to content

Commit

Permalink
feat(cell): Added pretty printer for BaseTile/SharedTile/GlobalTile. (#…
Browse files Browse the repository at this point in the history
…48)

Added pretty printers for `BaseTile`, `SharedTile`, and `GlobalTile` to
display static shape and layout-related information.

For example, to use `BaseTile`, you can do the following:

```cpp
using BaseShape = WarpBaseTileShape<...>;
std::cout << BaseShape{} << std::endl;
```
  • Loading branch information
lcy-seso authored Jan 28, 2025
1 parent 29ca160 commit 1b922c1
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 42 deletions.
25 changes: 24 additions & 1 deletion include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace tilefusion::cell::copy::warp {
namespace tl = tile_layout;
using namespace cute;

namespace {
namespace { // functions/class/structs that are not exposed to a larger scope

// FIXME(ying): This hotfix addresses the current implementation's inability
// to explicitly distinguish between shared memory's row-major or
// column-major layout and global memory's layouts. However, this should be
Expand Down Expand Up @@ -54,6 +55,18 @@ struct WarpOffsetHelper<WarpReuse::kRowReuseCont, kRowStride_, kColStride_> {

DEVICE int operator()(int i, int j) const { return i * kRowStride; }
};

/// @brief Helper for pretty printing a BaseTile's static shape-related
/// information. This printer works ONLY on the host.
struct BaseTilePrettyPrinter {
template <typename BaseShape>
static HOST void print(std::ostream& out, const BaseShape& tile) {
// parameter `tile` here is not used
out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols
<< "), Numel = " << BaseShape::kNumel << ", ThreadLayout = ("
<< BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")";
}
};
} // namespace

// @brief In a thread block, warps are organized as 2-D matrices, each with
Expand Down Expand Up @@ -261,6 +274,16 @@ struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kColMajor> {
using WarpThreadLayout = tl::ColMajor<kRowThreads, kColThreads>;
};

/// @brief Pretty printer for the static shape information of a
/// `WarpBaseTileShape`. Note: This printer function works ONLY on the
/// host.
template <typename DType, typename TileShape, const tl::Layout kType>
static HOST std::ostream& operator<<(
std::ostream& out, const WarpBaseTileShape<DType, TileShape, kType>& tile) {
BaseTilePrettyPrinter::print(out, tile);
return out;
}

template <typename WarpLayout_, const WarpReuse kMode_>
struct GlobalOffsetHelper {
static constexpr WarpReuse kMode = kMode_;
Expand Down
30 changes: 30 additions & 0 deletions include/types/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

namespace {

/// @brief Helper for pretty printing a GlobalTile's static shape-related
/// information. This printer works ONLY on the host.
struct GlobalTilePrettyPrinter {
template <typename Global>
static HOST void print(std::ostream& out, const Global& tile) {
// parameter `tile` here is not used
out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", "
<< Global::kCols << ", " << Global::kRowStride << ", "
<< Global::kColStride << "), numel = " << Global::kNumel;
}
};

} // namespace

template <typename Element_, typename Layout_>
struct GlobalTile {
using DType = Element_;
Expand All @@ -24,6 +40,10 @@ struct GlobalTile {

static constexpr tl::Layout kType = tl::layout_type<Layout>;

// This Ctor is to enable the use of the pretty printer of SharedTile in the
// host code.
HOST GlobalTile() : data_(nullptr), layout_(Layout{}) {}

DEVICE GlobalTile(DType* data) : data_(data), layout_(Layout{}) {}

DEVICE GlobalTile(const DType* data)
Expand All @@ -48,4 +68,14 @@ struct GlobalTile {
DType* data_;
Layout layout_;
};

/// @brief Pretty printer for the static shape information of a SharedTile.
/// Note: This printer function works ONLY on the host.
template <typename Element, typename Layout>
static HOST std::ostream& operator<<(std::ostream& out,
const GlobalTile<Element, Layout>& tile) {
GlobalTilePrettyPrinter::print(out, tile);
return out;
}

} // namespace tilefusion::cell
16 changes: 9 additions & 7 deletions include/types/global_tile_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

namespace detail {
namespace {
/// @brief Helper for pretty printing a tile iterator's static shape-related
/// information. This printer works ONLY on the host.
struct GTileIteratorPrettyPrinter {
template <typename TileIterator>
static HOST void print(std::ostream& out, const TileIterator& itr) {
out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape["
<< dim_size<0, typename TileIterator::ChunkShape> << ", "
<< dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = "
<< TileIterator::sc0 << ", sc1 = " << TileIterator::sc1;
size_t size1 = dim_size<0, typename TileIterator::ChunkShape>;
size_t size2 = dim_size<1, typename TileIterator::ChunkShape>;

out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = ("
<< size1 << ", " << size2 << "), stripe count = ("
<< TileIterator::sc0 << ", " << TileIterator::sc1 << ")";
}
};
} // namespace detail
} // namespace

/// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles
/// and iterates over these smaller sub-tiles.
Expand Down Expand Up @@ -158,7 +160,7 @@ class GTileIterator {
template <typename TileShape, typename ChunkShape>
static HOST std::ostream& operator<<(
std::ostream& out, const GTileIterator<TileShape, ChunkShape>& itr) {
detail::GTileIteratorPrettyPrinter::print(out, itr);
GTileIteratorPrettyPrinter::print(out, itr);
return out;
}

Expand Down
15 changes: 6 additions & 9 deletions include/types/register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

namespace detail {

namespace {
template <typename DType>
constexpr int get_rows = DType::kRows;
Expand All @@ -36,16 +34,15 @@ constexpr int get_cols<__half> = 1;

template <>
constexpr int get_cols<cutlass::half_t> = 1;
} // namespace

/// @brief Helper for pretty printing a register tile's static shape
/// information. This printer works ONLY on the host.
struct RegTilePrettyPrinter {
template <typename Tile>
static HOST void print(std::ostream& out, const Tile& tile) {
out << layout_type_to_str(Tile::kType) << "["
out << layout_type_to_str(Tile::kType) << "("
<< Tile::kRows * get_rows<typename Tile::DType> << ", "
<< Tile::kCols * get_cols<typename Tile::DType> << "]";
<< Tile::kCols * get_cols<typename Tile::DType> << ")";
}
};

Expand All @@ -58,12 +55,12 @@ DEVICE void clear(__half* data, int numel) {
}

template <typename DType>
DEVICE void clear(DType* data, int numel) {
DEVICE void clear_impl(DType* data, int numel) {
for (int i = 0; i < numel; ++i) {
clear(data[i].mutable_data(), 8);
}
}
} // namespace detail
} // namespace

template <typename Element_, typename Layout_>
class RegTile {
Expand Down Expand Up @@ -101,7 +98,7 @@ class RegTile {
print_tile(const_cast<DType*>(data_), layout_);
}

DEVICE void clear() { detail::clear<DType>(data_, kNumel); }
DEVICE void clear() { clear_impl<DType>(data_, kNumel); }

private:
DType data_[kNumel];
Expand All @@ -124,7 +121,7 @@ using BaseTileColMajor = RegTile<Element, tl::ColMajor<4, 2>>;
template <typename T, typename Layout>
static HOST std::ostream& operator<<(std::ostream& out,
const RegTile<T, Layout>& tile) {
detail::RegTilePrettyPrinter::print(out, tile);
RegTilePrettyPrinter::print(out, tile);
return out;
}

Expand Down
33 changes: 33 additions & 0 deletions include/types/shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

namespace {

/// @brief Helper for pretty printing a SharedTile's static shape-related
/// information. This printer works ONLY on the host.
struct SharedTilePrettyPrinter {
template <typename Shared>
static HOST void print(std::ostream& out, const Shared& tile) {
// parameter `tile` here is not used

auto swizzled = Shared::kSwizzled ? "swizzled" : "non-swizzled";

out << layout_type_to_str(Shared::kType) << "(" << Shared::kRows << ", "
<< Shared::kCols << ", " << Shared::kRowStride << ", "
<< Shared::kColStride << "), numel = " << Shared::kNumel
<< ", swizzled = " << swizzled;
}
};

} // namespace

template <typename Element_, typename Layout_, const bool kSwizzled_ = false>
class SharedTile {
public:
Expand All @@ -26,6 +46,10 @@ class SharedTile {
static constexpr tl::Layout kType = tl::layout_type<Layout>;
static constexpr bool kSwizzled = kSwizzled_;

// This Ctor is to enable the use of the pretty printer of SharedTile in the
// host code.
HOST SharedTile() : data_(nullptr), layout_(Layout{}) {}

DEVICE SharedTile(DType* data) : data_(data), layout_(Layout{}) {}

DEVICE DType* mutable_data() { return data_; }
Expand All @@ -48,4 +72,13 @@ class SharedTile {
Layout layout_;
};

/// @brief Pretty printer for the static shape information of a SharedTile.
/// Note: This printer function works ONLY on the host.
template <typename Element, typename Layout, const bool kSwizzled>
static HOST std::ostream& operator<<(
std::ostream& out, const SharedTile<Element, Layout, kSwizzled>& tile) {
SharedTilePrettyPrinter::print(out, tile);
return out;
}

} // namespace tilefusion::cell
18 changes: 9 additions & 9 deletions include/types/shared_tile_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

using namespace cute;

namespace detail {
namespace {
/// @brief Helper for pretty printing a tile iterator's static shape-related
/// information. This printer works ONLY on the host.
struct STileIteratorPrettyPrinter {
template <typename TileIterator>
static HOST void print(std::ostream& out, const TileIterator& itr) {
out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape["
<< dim_size<0, typename TileIterator::ChunkShape> << ", "
<< dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = "
<< TileIterator::sc0 << ", sc1 = " << TileIterator::sc1;
size_t size1 = dim_size<0, typename TileIterator::ChunkShape>;
size_t size2 = dim_size<1, typename TileIterator::ChunkShape>;

out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = ("
<< size1 << ", " << size2 << "), stripe count = ("
<< TileIterator::sc0 << ", " << TileIterator::sc1 << ")";
}
};
} // namespace detail
} // namespace

/// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles
/// and iterates over these smaller sub-tiles.
Expand Down Expand Up @@ -136,7 +136,7 @@ class STileIterator {
template <typename TileShape, typename ChunkShape>
static HOST std::ostream& operator<<(
std::ostream& out, const STileIterator<TileShape, ChunkShape>& itr) {
detail::STileIteratorPrettyPrinter::print(out, itr);
STileIteratorPrettyPrinter::print(out, itr);
return out;
}

Expand Down
27 changes: 11 additions & 16 deletions tests/cpp/cell/test_swizzled_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include <sstream>

#define DEBUG

namespace tilefusion::testing {
using namespace cell;
using namespace copy;
Expand Down Expand Up @@ -98,8 +100,8 @@ void run_test_rowmajor() {
static_assert(kShmRows == kRows, "kShmRows must be equal to kRows");

using Element = __half;
const int kThreads = tl::get_numel<WarpLayout> * 32;
static constexpr int kWarpPerRow = tl::num_rows<WarpLayout>;
const int kThreads = WarpLayout::kNumel * 32;
static constexpr int kWarpPerRow = WarpLayout::kRows;

using Global = GlobalTile<Element, tl::RowMajor<kRows, kCols>>;
using GIterator = GTileIterator<Global, TileShape<kRows, kShmCols>>;
Expand All @@ -124,12 +126,9 @@ void run_test_rowmajor() {
LOG(INFO) << "GIterator: " << GIterator{} << std::endl
<< "SIterator1: " << SIterator1{} << std::endl
<< "SIterator2: " << SIterator2{} << std::endl
<< "GlobalTile Shape: [" << kRows << ", " << kCols << "]"
<< std::endl
<< "SharedTile Shape: [" << kShmRows << ", " << kShmCols << "]"
<< std::endl
<< "sc0: " << kSc0 << ", sc1: " << kSc1 << std::endl
<< "RegTile Shape: " << Reg{} << std::endl;
<< "GlobalTile: " << Global{} << std::endl
<< "SharedTile: " << Shared1{} << std::endl
<< "RegTile: " << Reg{} << std::endl;
#endif

using G2S1 = GlobalToSharedLoader<Shared1, WarpLayout>;
Expand Down Expand Up @@ -205,16 +204,12 @@ void run_test_colmajor() {
using Reg = RegTile<BaseTileColMajor<Element>, tl::ColMajor<kSc0, kSc1>>;

#ifdef DEBUG
LOG(INFO) << std::endl
<< "GIterator: " << GIterator{} << std::endl
LOG(INFO) << "GIterator: " << GIterator{} << std::endl
<< "SIterator1: " << SIterator1{} << std::endl
<< "SIterator2: " << SIterator2{} << std::endl
<< "GlobalTile Shape: [" << kRows << ", " << kCols << "]"
<< std::endl
<< "SharedTile Shape: [" << kShmRows << ", " << kShmCols << "]"
<< std::endl
<< "sc0: " << kSc0 << ", sc1: " << kSc1 << std::endl
<< "RegTile Shape: " << Reg{} << std::endl;
<< "GlobalTile: " << Global{} << std::endl
<< "SharedTile: " << Shared1{} << std::endl
<< "RegTile: " << Reg{} << std::endl;
#endif

using G2S1 = GlobalToSharedLoader<Shared1, WarpLayout>;
Expand Down

0 comments on commit 1b922c1

Please sign in to comment.