Add tensor partition and generic copy for ck wrapper (#1108)
* Add tensor partition and generic copy for ck wrapper

* Update changelog

* Stylistic fixes

* Change shape/strides logic to descriptor transforms

* Fixes

* Fix client example

* Fix comments
bartekxk authored Jan 3, 2024
1 parent b268f27 commit 4234b3a
Showing 14 changed files with 940 additions and 306 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -19,7 +19,7 @@ None
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799)
- Support for Batched Gemm DL (#732)
- Introduce wrapper sublibrary (limited functionality). (#1071, #1098)
- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108)

### Changes
- Changed the grouped convolution API to maintain consistency with other convolution kernels (#817)
8 changes: 8 additions & 0 deletions docs/wrapper.rst
@@ -71,3 +71,11 @@ Tensor helpers
-------------------------------------

.. doxygenfile:: tensor_utils.hpp

.. doxygenfile:: tensor_partition.hpp

-------------------------------------
Operations
-------------------------------------

.. doxygenfile:: copy.hpp
11 changes: 11 additions & 0 deletions include/ck/utility/tuple_helper.hpp
@@ -178,4 +178,15 @@ __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
return math::max(TupleDepth<depth + 1>(Ts{})...);
}

/**
 * \brief Slice a tuple, keeping elements in the index range [from, to).
 *
 * \param tuple Tuple to slice.
 * \return Tuple containing elements from, ..., to-1 of the input tuple.
 */
template <index_t from, index_t to, typename... Ts>
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
{
    return generate_tuple(
        [&](auto i) {
            using Idx = Number<from + i>;
            return tuple.At(Idx{});
        },
        Number<to - from>{});
}

} // namespace ck
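
A usage sketch for the new TupleSlice helper (not part of the diff; the values are illustrative):

#include "ck/utility/tuple_helper.hpp"

// Keep elements [1, 3) of a four-element compile-time tuple.
constexpr auto whole = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<8>{}, ck::Number<16>{});
constexpr auto part  = ck::TupleSlice<1, 3>(whole); // Tuple<Number<4>, Number<8>>
static_assert(part.At(ck::Number<0>{}) == 4 && part.At(ck::Number<1>{}) == 8, "slice keeps the middle elements");
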
120 changes: 33 additions & 87 deletions include/ck/wrapper/layout.hpp
@@ -14,11 +14,9 @@ namespace wrapper {
* \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). It is possible to pass nested shapes
* (e.g. ((4, 2), 2)), nested dimensions are merged.
* \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
* (dynamic layout). Stride tuple should be nested if shape tuple is
* nested.
* \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims.
*/
template <typename Shape, typename Strides>
template <typename Shape, typename UnnestedDescriptorType>
struct Layout
{
private:
@@ -31,7 +29,7 @@ struct Layout
{
return generate_tuple(
[&](auto) {
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
{
// runtime layout
return index_t(0);
@@ -45,27 +43,6 @@ struct Layout
Number<Tuple<Ts...>::Size()>{});
}

// Generate packed (column-major) strides if not passed
template <typename... Ts>
__host__ __device__ constexpr static auto
GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
return generate_tuple(
[&](auto i) {
if constexpr(i.value == 0)
{
return I1;
}
else
{
return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
unrolled_shape);
}
},
Number<decltype(unrolled_shape)::Size()>{});
}

// Generate LowerDims at compile time for MergeTransform using the passed type.
// If an element of Tuple<Ts...> is itself a tuple, then merge it (generate a sequence for the merge).
// If the element is a scalar, then pass it through (sequence with one element).
@@ -207,33 +184,15 @@ struct Layout
return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
}

template <typename LayoutShape, typename LayoutStrides>
__host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
const LayoutStrides& strides)
{
const auto unrolled_shape = UnrollNestedTuple(shape);
const auto unrolled_strides = UnrollNestedTuple(strides);
static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
"Size of strides and shape are not consistent.");
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}

// If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
using DeducedStrides =
std::conditional_t<is_same_v<Strides, Tuple<>>,
remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
Strides>;
using FlattenDescriptorType =
remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
using Descriptor1dType =
remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnnestedDescriptorType{}))>;
using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

template <typename... ShapeDims, typename... IdxDims>
__host__ __device__ constexpr static auto
TransformDesc(const Tuple<ShapeDims...>& shape,
const Tuple<IdxDims...>& idx,
const FlattenDescriptorType& naive_descriptor)
const UnnestedDescriptorType& naive_descriptor)
{
if constexpr(Tuple<IdxDims...>::Size() == I1)
{
@@ -256,48 +215,33 @@ struct Layout
}

using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>;

public:
__host__ __device__ constexpr auto GetElementSpaceSize() const
{
return flatten_descriptor_.GetElementSpaceSize();
return unnested_descriptor_.GetElementSpaceSize();
}

__host__ __device__ Layout() = delete;

/**
* \brief Layout constructor.
*
* \param shape Shape for layout.
* \param strides Strides for layout (optional if tensor is packed).
* \param unnested_descriptor Tensor descriptor for the unnested (flattened) shape dims.
*/
__host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
: flatten_descriptor_{}, shape_(shape), strides_(strides)
__host__ __device__ constexpr Layout(const Shape& shape,
const UnnestedDescriptorType& unnested_descriptor)
: shape_(shape)
{
// Construct descriptors when in runtime mode
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
unnested_descriptor_ = unnested_descriptor;
descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
}
}

/**
* \brief Layout constructor (with default packed column-major strides).
*
* \param shape Shape for layout.
*/
__host__ __device__ constexpr Layout(const Shape& shape)
: flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
{
if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
{
flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_);
merged_nests_descriptor_ =
TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_);
}
}

@@ -310,9 +254,9 @@ struct Layout
template <typename Idxs>
__host__ __device__ constexpr index_t operator()() const
{
static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(),
"Compiletime operator used on runtime layout.");
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{}));
using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
}
@@ -339,7 +283,7 @@ struct Layout
else
{
// Custom index, need to transform descriptor
const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_);
return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
}
}
@@ -351,7 +295,7 @@ struct Layout
* \return Calculated size.
*/
template <index_t IDim>
__host__ __device__ constexpr index_t GetLength() const
__host__ __device__ constexpr auto GetLength() const
{
const auto elem = shape_.At(Number<IDim>{});
if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
@@ -371,7 +315,7 @@ struct Layout
*
* \return Calculated size.
*/
__host__ __device__ constexpr index_t GetLengths() const
__host__ __device__ constexpr auto GetLengths() const
{
const auto unrolled_shape = UnrollNestedTuple(shape_);
return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
@@ -385,13 +329,6 @@ struct Layout
*/
__host__ __device__ constexpr const Shape& GetShape() const { return shape_; }

/**
* \brief Strides getter.
*
* \return Strides.
*/
__host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }

/**
* \brief Get default lengths (tuple filled with Shape length elements).
*
@@ -417,17 +354,26 @@ struct Layout
*
* \return Default descriptor.
*/
__host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
__host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const
{
return merged_nests_descriptor_;
}

/**
* \brief Get the unnested descriptor (with unrolled dims).
*
* \return Unnested (flattened) descriptor.
*/
__host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const
{
return unnested_descriptor_;
}

private:
FlattenDescriptorType flatten_descriptor_;
UnnestedDescriptorType unnested_descriptor_;
Descriptor1dType descriptor_1d_;
MergedNestsDescriptorType merged_nests_descriptor_;
const Shape shape_;
const DeducedStrides strides_;
};

} // namespace wrapper
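
Usage sketch (not part of the diff): a Layout is now built from a shape plus an explicitly constructed unnested descriptor instead of a stride tuple, so the packed column-major case previously covered by GenerateColumnMajorPackedStrides is spelled out by the caller. A minimal host-side example, assuming ck::make_naive_tensor_descriptor behaves as in the removed MakeFlattenDescriptor; the shape and stride values are illustrative:

#include "ck/wrapper/layout.hpp"

void layout_usage_example()
{
    // Runtime 4 x 6 layout with packed column-major strides (1, 4).
    const auto shape = ck::make_tuple(ck::index_t{4}, ck::index_t{6});
    const auto desc =
        ck::make_naive_tensor_descriptor(shape, ck::make_tuple(ck::index_t{1}, ck::index_t{4}));
    auto layout = ck::wrapper::Layout<ck::remove_cvref_t<decltype(shape)>,
                                      ck::remove_cvref_t<decltype(desc)>>(shape, desc);

    const auto num_elems = layout.GetLengths();                                    // 4 * 6 = 24
    const auto offset    = layout(ck::make_tuple(ck::index_t{1}, ck::index_t{2})); // 1 + 2 * 4 = 9
    const auto unnested  = layout.GetUnnestedDescriptor(); // strides are no longer stored separately
}
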
41 changes: 41 additions & 0 deletions include/ck/wrapper/operations/copy.hpp
@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "../utils/tensor_utils.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Perform generic copy between two tensors. Tensors must have the
 * same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    if constexpr(!SrcTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(src_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else if constexpr(!DstTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(dst_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else
    {
        for(int i = 0; i < size(src_tensor); i++)
        {
            dst_tensor(i) = src_tensor(i);
        }
    }
}

} // namespace wrapper
} // namespace ck
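
Usage sketch (not part of the diff): when either tensor is a static (register) buffer the copy is unrolled with static_for; when both are dynamic buffers it falls back to the plain runtime loop above. The toy kernel below assumes the wrapper's make_tensor helper and MemoryTypeEnum from tensor_utils.hpp (not shown in this diff) together with ck::make_naive_tensor_descriptor_packed; names and sizes are illustrative:

#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/utils/tensor_utils.hpp"

// Toy kernel: a single thread copies one packed 8 x 8 fp32 tile between two global buffers.
__global__ void copy_tile_kernel(float* p_src, float* p_dst)
{
    constexpr auto shape = ck::make_tuple(ck::Number<8>{}, ck::Number<8>{});
    constexpr auto desc  = ck::make_naive_tensor_descriptor_packed(shape);
    auto layout = ck::wrapper::Layout<ck::remove_cvref_t<decltype(shape)>,
                                      ck::remove_cvref_t<decltype(desc)>>(shape, desc);

    const auto src = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_src, layout);
    auto dst       = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_dst, layout);

    // Both views are dynamic (global-memory) buffers, so copy() takes the runtime-loop path.
    ck::wrapper::copy(src, dst);
}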