Skip to content

Commit

Permalink
Update packed_stride.hpp to add CUTLASS_HOST_DEVICE decorator to new functions (#1495)
Browse files Browse the repository at this point in the history
  • Loading branch information
djns99 authored Apr 19, 2024
1 parent 7d49e6c commit 5c447dd
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tools/util/include/cutlass/util/packed_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape
// Strides with group mode

template <class StrideIntT>
CUTLASS_HOST_DEVICE
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
Expand All @@ -121,6 +122,7 @@ make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s,
}

template <class StrideIntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
Expand All @@ -140,6 +142,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s,
// right in KTRSC order and can be coalesced to just k.
// We enforce this condition here with asserts.
template <class IntT, size_t RankT_>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
Expand Down Expand Up @@ -169,6 +172,7 @@ make_cute_packed_stride(

// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
Expand All @@ -185,6 +189,7 @@ make_cute_packed_stride(

// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
Expand All @@ -202,6 +207,7 @@ make_cute_packed_stride(

// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
Expand All @@ -224,6 +230,7 @@ make_cute_packed_stride(

// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
Expand All @@ -241,6 +248,7 @@ make_cute_packed_stride(

// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
Expand All @@ -260,6 +268,7 @@ make_cute_packed_stride(

// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
Expand All @@ -286,6 +295,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
Expand All @@ -311,6 +321,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
Expand Down Expand Up @@ -339,6 +350,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
Expand Down Expand Up @@ -370,6 +382,7 @@ make_cute_packed_stride(

// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
Expand All @@ -386,6 +399,7 @@ make_cute_packed_stride(

// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
Expand All @@ -402,6 +416,7 @@ make_cute_packed_stride(

// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
Expand All @@ -424,6 +439,7 @@ make_cute_packed_stride(

// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
Expand Down Expand Up @@ -462,6 +478,7 @@ make_cute_packed_stride(

// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
Expand Down

0 comments on commit 5c447dd

Please sign in to comment.