Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cell): Added pretty printer for BaseTile/SharedTile/GlobalTile. #48

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ namespace tilefusion::cell::copy::warp {
namespace tl = tile_layout;
using namespace cute;

namespace {
namespace { // functions/class/structs that are not exposed to a larger scope

// FIXME(ying): This hotfix addresses the current implementation's inability
// to explicitly distinguish between shared memory's row-major or
// column-major layout and global memory's layouts. However, this should be
Expand Down Expand Up @@ -54,6 +55,18 @@ struct WarpOffsetHelper<WarpReuse::kRowReuseCont, kRowStride_, kColStride_> {

// Returns `i * kRowStride`; the second index `j` does not contribute to the
// result in this specialization.
DEVICE int operator()(int i, int j) const { return i * kRowStride; }
};

/// @brief Helper for pretty printing a BaseTile's static shape-related
///        information. This printer works ONLY on the host.
struct BaseTilePrettyPrinter {
    /// @brief Writes the compile-time shape description of `BaseShape`.
    ///
    /// The emitted text has the form
    /// `BaseShape = (kRows, kCols), Numel = kNumel, ThreadLayout =
    /// (kRowThreads, kColThreads)`.
    ///
    /// @tparam BaseShape Type exposing the static members `kRows`, `kCols`,
    ///         `kNumel`, `kRowThreads` and `kColThreads`.
    /// @param out Stream that receives the text.
    ///
    /// The second parameter only serves to deduce `BaseShape` at the call
    /// site. Nothing is read from the instance, so it is left unnamed, which
    /// also keeps unused-parameter warnings quiet.
    template <typename BaseShape>
    static HOST void print(std::ostream& out, const BaseShape& /*tile*/) {
        out << "BaseShape = (" << BaseShape::kRows << ", " << BaseShape::kCols
            << "), Numel = " << BaseShape::kNumel << ", ThreadLayout = ("
            << BaseShape::kRowThreads << ", " << BaseShape::kColThreads << ")";
    }
};
} // namespace

// @brief In a thread block, warps are organized as 2-D matrices, each with
Expand Down Expand Up @@ -261,6 +274,16 @@ struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kColMajor> {
using WarpThreadLayout = tl::ColMajor<kRowThreads, kColThreads>;
};

/// @brief Pretty printer for the static shape information of a
/// `WarpBaseTileShape`. Note: This printer function works ONLY on the
/// host.
///
/// Forwards to `BaseTilePrettyPrinter::print` and hands the stream back so
/// that insertions can be chained.
///
/// @param out  Stream that receives the text.
/// @param tile Instance to describe; it is passed on to the printer as-is.
/// @return The stream `out`.
template <typename DType, typename TileShape, const tl::Layout kType>
static HOST std::ostream& operator<<(
std::ostream& out, const WarpBaseTileShape<DType, TileShape, kType>& tile) {
BaseTilePrettyPrinter::print(out, tile);
return out;
}

template <typename WarpLayout_, const WarpReuse kMode_>
struct GlobalOffsetHelper {
static constexpr WarpReuse kMode = kMode_;
Expand Down
30 changes: 30 additions & 0 deletions include/types/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

// NOTE(review): this unnamed namespace lives in a header, so every translation
// unit that includes it gets its own copy of the helper below.
namespace {

/// @brief Helper for pretty printing a GlobalTile's static shape-related
///        information. This printer works ONLY on the host.
struct GlobalTilePrettyPrinter {
    /// @brief Writes the compile-time description of `Global` to `out`.
    ///
    /// The emitted text has the form
    /// `<layout>(kRows, kCols, kRowStride, kColStride), numel = kNumel`,
    /// where `<layout>` is whatever `layout_type_to_str(Global::kType)`
    /// returns.
    ///
    /// @tparam Global Type exposing the static members `kType`, `kRows`,
    ///         `kCols`, `kRowStride`, `kColStride` and `kNumel`.
    /// @param out Stream that receives the text.
    ///
    /// The second parameter only serves to deduce `Global` at the call site.
    /// Nothing is read from the instance, so it is left unnamed, which also
    /// keeps unused-parameter warnings quiet.
    template <typename Global>
    static HOST void print(std::ostream& out, const Global& /*tile*/) {
        out << layout_type_to_str(Global::kType) << "(" << Global::kRows << ", "
            << Global::kCols << ", " << Global::kRowStride << ", "
            << Global::kColStride << "), numel = " << Global::kNumel;
    }
};

}  // namespace

template <typename Element_, typename Layout_>
struct GlobalTile {
using DType = Element_;
Expand All @@ -24,6 +40,10 @@ struct GlobalTile {

static constexpr tl::Layout kType = tl::layout_type<Layout>;

// This default Ctor enables the use of the pretty printer of GlobalTile in
// the host code: it allows a tile to be created without a data pointer
// (`data_` is set to nullptr and `layout_` is default-constructed).
HOST GlobalTile() : data_(nullptr), layout_(Layout{}) {}

// Ctor marked `DEVICE`: stores the given pointer in `data_` and
// default-constructs `layout_`.
DEVICE GlobalTile(DType* data) : data_(data), layout_(Layout{}) {}

DEVICE GlobalTile(const DType* data)
Expand All @@ -48,4 +68,14 @@ struct GlobalTile {
DType* data_;
Layout layout_;
};

/// @brief Pretty printer for the static shape information of a GlobalTile.
/// Note: This printer function works ONLY on the host.
///
/// Forwards to `GlobalTilePrettyPrinter::print` and hands the stream back so
/// that insertions can be chained.
///
/// @param out  Stream that receives the text.
/// @param tile Instance to describe; it is passed on to the printer as-is.
/// @return The stream `out`.
template <typename Element, typename Layout>
static HOST std::ostream& operator<<(std::ostream& out,
const GlobalTile<Element, Layout>& tile) {
GlobalTilePrettyPrinter::print(out, tile);
return out;
}

} // namespace tilefusion::cell
16 changes: 9 additions & 7 deletions include/types/global_tile_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

namespace detail {
namespace {
/// @brief Helper for pretty printing a tile iterator's static shape-related
/// information. This printer works ONLY on the host.
///
/// The text is built from `TileIterator::Tile::kNumel`, the two extents of
/// `TileIterator::ChunkShape` (read through `dim_size<0, ...>` and
/// `dim_size<1, ...>`) and the static members `sc0` and `sc1`.
struct GTileIteratorPrettyPrinter {
template <typename TileIterator>
static HOST void print(std::ostream& out, const TileIterator& itr) {
// parameter `itr` is not used; only static members of `TileIterator` are read
out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape["
<< dim_size<0, typename TileIterator::ChunkShape> << ", "
<< dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = "
<< TileIterator::sc0 << ", sc1 = " << TileIterator::sc1;
// Copy the two chunk extents into locals before streaming them.
size_t size1 = dim_size<0, typename TileIterator::ChunkShape>;
size_t size2 = dim_size<1, typename TileIterator::ChunkShape>;

out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = ("
<< size1 << ", " << size2 << "), stripe count = ("
<< TileIterator::sc0 << ", " << TileIterator::sc1 << ")";
}
};
} // namespace detail
} // namespace

/// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles
/// and iterates over these smaller sub-tiles.
Expand Down Expand Up @@ -158,7 +160,7 @@ class GTileIterator {
// Stream insertion for `GTileIterator`: delegates the formatting to
// `GTileIteratorPrettyPrinter::print` and returns `out` so that insertions
// can be chained.
template <typename TileShape, typename ChunkShape>
static HOST std::ostream& operator<<(
std::ostream& out, const GTileIterator<TileShape, ChunkShape>& itr) {
detail::GTileIteratorPrettyPrinter::print(out, itr);
GTileIteratorPrettyPrinter::print(out, itr);
return out;
}

Expand Down
15 changes: 6 additions & 9 deletions include/types/register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

namespace detail {

namespace {
template <typename DType>
constexpr int get_rows = DType::kRows;
Expand All @@ -36,16 +34,15 @@ constexpr int get_cols<__half> = 1;

template <>
constexpr int get_cols<cutlass::half_t> = 1;
} // namespace

/// @brief Helper for pretty printing a register tile's static shape
/// information. This printer works ONLY on the host.
///
/// The text starts with `layout_type_to_str(Tile::kType)` and is followed by
/// two extents: `Tile::kRows * get_rows<Tile::DType>` and
/// `Tile::kCols * get_cols<Tile::DType>`.
struct RegTilePrettyPrinter {
template <typename Tile>
static HOST void print(std::ostream& out, const Tile& tile) {
// parameter `tile` is not used; only static members of `Tile` are read
out << layout_type_to_str(Tile::kType) << "["
out << layout_type_to_str(Tile::kType) << "("
<< Tile::kRows * get_rows<typename Tile::DType> << ", "
<< Tile::kCols * get_cols<typename Tile::DType> << "]";
<< Tile::kCols * get_cols<typename Tile::DType> << ")";
}
};

Expand All @@ -58,12 +55,12 @@ DEVICE void clear(__half* data, int numel) {
}

// Generic overload: walks the `numel` elements of `data` and calls `clear` on
// the storage each element exposes through `mutable_data()`.
// NOTE(review): the inner element count 8 is hard-coded — presumably the
// number of values held by one element; confirm against the element type.
template <typename DType>
DEVICE void clear(DType* data, int numel) {
DEVICE void clear_impl(DType* data, int numel) {
for (int i = 0; i < numel; ++i) {
clear(data[i].mutable_data(), 8);
}
}
} // namespace detail
} // namespace

template <typename Element_, typename Layout_>
class RegTile {
Expand Down Expand Up @@ -101,7 +98,7 @@ class RegTile {
print_tile(const_cast<DType*>(data_), layout_);
}

// Clears the tile by handing `data_` and `kNumel` to the file-local clear
// helper.
DEVICE void clear() { detail::clear<DType>(data_, kNumel); }
DEVICE void clear() { clear_impl<DType>(data_, kNumel); }

private:
DType data_[kNumel];
Expand All @@ -124,7 +121,7 @@ using BaseTileColMajor = RegTile<Element, tl::ColMajor<4, 2>>;
// Stream insertion for `RegTile`: delegates the formatting to
// `RegTilePrettyPrinter::print` and returns `out` so that insertions can be
// chained.
template <typename T, typename Layout>
static HOST std::ostream& operator<<(std::ostream& out,
const RegTile<T, Layout>& tile) {
detail::RegTilePrettyPrinter::print(out, tile);
RegTilePrettyPrinter::print(out, tile);
return out;
}

Expand Down
33 changes: 33 additions & 0 deletions include/types/shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

// NOTE(review): this unnamed namespace lives in a header, so every translation
// unit that includes it gets its own copy of the helper below.
namespace {

/// @brief Helper for pretty printing a SharedTile's static shape-related
///        information. This printer works ONLY on the host.
struct SharedTilePrettyPrinter {
    /// @brief Writes the compile-time description of `Shared` to `out`.
    ///
    /// The emitted text has the form
    /// `<layout>(kRows, kCols, kRowStride, kColStride), numel = kNumel,
    /// swizzled = <swizzled|non-swizzled>`, where `<layout>` is whatever
    /// `layout_type_to_str(Shared::kType)` returns.
    ///
    /// @tparam Shared Type exposing the static members `kType`, `kRows`,
    ///         `kCols`, `kRowStride`, `kColStride`, `kNumel` and `kSwizzled`.
    /// @param out Stream that receives the text.
    ///
    /// The second parameter only serves to deduce `Shared` at the call site.
    /// Nothing is read from the instance, so it is left unnamed, which also
    /// keeps unused-parameter warnings quiet.
    template <typename Shared>
    static HOST void print(std::ostream& out, const Shared& /*tile*/) {
        // Human-readable rendering of the compile-time swizzle flag.
        const char* const swizzled =
            Shared::kSwizzled ? "swizzled" : "non-swizzled";

        out << layout_type_to_str(Shared::kType) << "(" << Shared::kRows << ", "
            << Shared::kCols << ", " << Shared::kRowStride << ", "
            << Shared::kColStride << "), numel = " << Shared::kNumel
            << ", swizzled = " << swizzled;
    }
};

}  // namespace

template <typename Element_, typename Layout_, const bool kSwizzled_ = false>
class SharedTile {
public:
Expand All @@ -26,6 +46,10 @@ class SharedTile {
static constexpr tl::Layout kType = tl::layout_type<Layout>;
static constexpr bool kSwizzled = kSwizzled_;

// This default Ctor enables the use of the pretty printer of SharedTile in
// the host code: it allows a tile to be created without a data pointer
// (`data_` is set to nullptr and `layout_` is default-constructed).
HOST SharedTile() : data_(nullptr), layout_(Layout{}) {}

// Ctor marked `DEVICE`: stores the given pointer in `data_` and
// default-constructs `layout_`.
DEVICE SharedTile(DType* data) : data_(data), layout_(Layout{}) {}

// Returns the stored pointer `data_` as a non-const `DType*`.
DEVICE DType* mutable_data() { return data_; }
Expand All @@ -48,4 +72,13 @@ class SharedTile {
Layout layout_;
};

/// @brief Pretty printer for the static shape information of a SharedTile.
/// Note: This printer function works ONLY on the host.
///
/// Forwards to `SharedTilePrettyPrinter::print` and hands the stream back so
/// that insertions can be chained.
///
/// @param out  Stream that receives the text.
/// @param tile Instance to describe; it is passed on to the printer as-is.
/// @return The stream `out`.
template <typename Element, typename Layout, const bool kSwizzled>
static HOST std::ostream& operator<<(
std::ostream& out, const SharedTile<Element, Layout, kSwizzled>& tile) {
SharedTilePrettyPrinter::print(out, tile);
return out;
}

} // namespace tilefusion::cell
18 changes: 9 additions & 9 deletions include/types/shared_tile_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@
namespace tilefusion::cell {
namespace tl = tile_layout;

using namespace cute;

namespace detail {
namespace {
/// @brief Helper for pretty printing a tile iterator's static shape-related
/// information. This printer works ONLY on the host.
///
/// The text is built from `TileIterator::Tile::kNumel`, the two extents of
/// `TileIterator::ChunkShape` (read through `dim_size<0, ...>` and
/// `dim_size<1, ...>`) and the static members `sc0` and `sc1`.
struct STileIteratorPrettyPrinter {
template <typename TileIterator>
static HOST void print(std::ostream& out, const TileIterator& itr) {
// parameter `itr` is not used; only static members of `TileIterator` are read
out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape["
<< dim_size<0, typename TileIterator::ChunkShape> << ", "
<< dim_size<1, typename TileIterator::ChunkShape> << "], sc0 = "
<< TileIterator::sc0 << ", sc1 = " << TileIterator::sc1;
// Copy the two chunk extents into locals before streaming them.
size_t size1 = dim_size<0, typename TileIterator::ChunkShape>;
size_t size2 = dim_size<1, typename TileIterator::ChunkShape>;

out << "numel = " << TileIterator::Tile::kNumel << ", ChunkShape = ("
<< size1 << ", " << size2 << "), stripe count = ("
<< TileIterator::sc0 << ", " << TileIterator::sc1 << ")";
}
};
} // namespace detail
} // namespace

/// @brief `SharedTileIterator` chunks a shared memory tile into smaller tiles
/// and iterates over these smaller sub-tiles.
Expand Down Expand Up @@ -136,7 +136,7 @@ class STileIterator {
// Stream insertion for `STileIterator`: delegates the formatting to
// `STileIteratorPrettyPrinter::print` and returns `out` so that insertions
// can be chained.
template <typename TileShape, typename ChunkShape>
static HOST std::ostream& operator<<(
std::ostream& out, const STileIterator<TileShape, ChunkShape>& itr) {
detail::STileIteratorPrettyPrinter::print(out, itr);
STileIteratorPrettyPrinter::print(out, itr);
return out;
}

Expand Down
27 changes: 11 additions & 16 deletions tests/cpp/cell/test_swizzled_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include <sstream>

// Switches on the `#ifdef DEBUG` logging sections of this test file. The
// guard keeps a build that already defines DEBUG (for example through a
// `-DDEBUG=...` compiler flag) from hitting a macro-redefinition diagnostic.
#ifndef DEBUG
    #define DEBUG
#endif

namespace tilefusion::testing {
using namespace cell;
using namespace copy;
Expand Down Expand Up @@ -98,8 +100,8 @@ void run_test_rowmajor() {
static_assert(kShmRows == kRows, "kShmRows must be equal to kRows");

using Element = __half;
const int kThreads = tl::get_numel<WarpLayout> * 32;
static constexpr int kWarpPerRow = tl::num_rows<WarpLayout>;
const int kThreads = WarpLayout::kNumel * 32;
static constexpr int kWarpPerRow = WarpLayout::kRows;

using Global = GlobalTile<Element, tl::RowMajor<kRows, kCols>>;
using GIterator = GTileIterator<Global, TileShape<kRows, kShmCols>>;
Expand All @@ -124,12 +126,9 @@ void run_test_rowmajor() {
LOG(INFO) << "GIterator: " << GIterator{} << std::endl
<< "SIterator1: " << SIterator1{} << std::endl
<< "SIterator2: " << SIterator2{} << std::endl
<< "GlobalTile Shape: [" << kRows << ", " << kCols << "]"
<< std::endl
<< "SharedTile Shape: [" << kShmRows << ", " << kShmCols << "]"
<< std::endl
<< "sc0: " << kSc0 << ", sc1: " << kSc1 << std::endl
<< "RegTile Shape: " << Reg{} << std::endl;
<< "GlobalTile: " << Global{} << std::endl
<< "SharedTile: " << Shared1{} << std::endl
<< "RegTile: " << Reg{} << std::endl;
#endif

using G2S1 = GlobalToSharedLoader<Shared1, WarpLayout>;
Expand Down Expand Up @@ -205,16 +204,12 @@ void run_test_colmajor() {
using Reg = RegTile<BaseTileColMajor<Element>, tl::ColMajor<kSc0, kSc1>>;

#ifdef DEBUG
LOG(INFO) << std::endl
<< "GIterator: " << GIterator{} << std::endl
LOG(INFO) << "GIterator: " << GIterator{} << std::endl
<< "SIterator1: " << SIterator1{} << std::endl
<< "SIterator2: " << SIterator2{} << std::endl
<< "GlobalTile Shape: [" << kRows << ", " << kCols << "]"
<< std::endl
<< "SharedTile Shape: [" << kShmRows << ", " << kShmCols << "]"
<< std::endl
<< "sc0: " << kSc0 << ", sc1: " << kSc1 << std::endl
<< "RegTile Shape: " << Reg{} << std::endl;
<< "GlobalTile: " << Global{} << std::endl
<< "SharedTile: " << Shared1{} << std::endl
<< "RegTile: " << Reg{} << std::endl;
#endif

using G2S1 = GlobalToSharedLoader<Shared1, WarpLayout>;
Expand Down
Loading