Skip to content

Commit

Permalink
Add kernel reduction example
Browse files Browse the repository at this point in the history
  • Loading branch information
johnbowen42 committed Jan 18, 2025
1 parent 579ff77 commit de346d9
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 1 deletion.
5 changes: 5 additions & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ raja_add_executable(
NAME kernel-dynamic-tile
SOURCES kernel-dynamic-tile.cpp)

# Builds the kernel-reduction example from kernel-reduction.cpp.
# NOTE(review): kernel-reduction.cpp uses CUDA execution policies
# unconditionally -- confirm whether this target needs a CUDA-enabled guard.
raja_add_executable(
NAME kernel-reduction
SOURCES kernel-reduction.cpp)


raja_add_executable(
NAME resource-kernel
SOURCES resource-kernel.cpp)
Expand Down
73 changes: 73 additions & 0 deletions examples/kernel-reduction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include "RAJA/RAJA.hpp"
#include "RAJA/index/RangeSegment.hpp"
#include "memoryManager.hpp"


/*
 * Example driver: finds the minimum entry of an N x N integer matrix using
 * RAJA's parameter-style reducer (RAJA::expt::Reduce), first through
 * RAJA::kernel_param and then through RAJA::forall, and prints the result.
 *
 * Command-line arguments are ignored.
 *
 * NOTE(review): every execution policy used here is a CUDA policy, so this
 * example presumably requires a CUDA-enabled RAJA build -- confirm.
 */
int main(int RAJA_UNUSED_ARG(argc), char **RAJA_UNUSED_ARG(argv[]))
{
  // Kernel policy for the matrix-min loop nest: segment 1 (rows) uses the
  // block-x loop policy, segment 0 (columns) uses the thread-x loop policy,
  // and lambda 0 is invoked at every (col, row) pair.
  // Open question from the original author: does Lambda<0> need an explicit
  // RAJA::Params<0> argument here?
  using EXEC_POL8 =
    RAJA::KernelPolicy<
      RAJA::statement::CudaKernel<
        RAJA::statement::For<1, RAJA::cuda_block_x_loop,    // row
          RAJA::statement::For<0, RAJA::cuda_thread_x_loop, // col
            RAJA::statement::Lambda<0>
          >
        >
      >
    >;

  // Reference type through which a loop body updates the min reduction.
  using VALOP_INT_MIN = RAJA::expt::ValOp<int, RAJA::operators::minimum>;

  // Reduction target. Every matrix entry is <= 0, so a starting value of 0
  // cannot hide the true minimum.
  int cuda_min = 0;

  const int N = 10000;

  RAJA::TypedRangeSegment<int> row_range(0, N);
  RAJA::TypedRangeSegment<int> col_range(0, N);

  RAJA::resources::Cuda cuda_res;

  // Row-major N x N matrix in which every entry of row r equals -r, so the
  // expected minimum is -(N - 1).
  int *A = memoryManager::allocate<int>(N * N);
  for (int row = 0; row < N; ++row) {
    for (int col = 0; col < N; ++col) {
      A[col + row * N] = -row;
    }
  }

  RAJA::View<int, RAJA::Layout<2>> Aview(A, N, N);

  // Min reduction over the whole matrix via RAJA::kernel_param.
  //
  // KNOWN ISSUE (recorded by the original author): this call does not
  // compile; the compiler reports no known conversion from
  //   'RAJA::expt::detail::Reducer<RAJA::operators::minimum<int>, int,
  //        RAJA::expt::ValOp<int, RAJA::operators::minimum>>'
  // to 'VALOP_INT_MIN &', i.e. the lambda is handed the Reducer parameter
  // itself rather than its ValOp.
  RAJA::kernel_param<EXEC_POL8>(
    // segments: index 0 = columns, index 1 = rows
    RAJA::make_tuple(col_range, row_range),
    // params: the min reducer bound to cuda_min
    RAJA::make_tuple(RAJA::expt::Reduce<RAJA::operators::minimum>(&cuda_min)),
    // lambda 0
    [=] RAJA_DEVICE (int col, int row, VALOP_INT_MIN &_cuda_min) {
      _cuda_min.min(Aview(row, col));
    }
  );

  // The same reducer passed to RAJA::forall compiles. Only column 0 is
  // visited, which is sufficient because all entries of a row are equal.
  RAJA::forall<RAJA::cuda_exec<256>>(cuda_res, RAJA::RangeSegment(0, N),
    RAJA::expt::Reduce<RAJA::operators::minimum>(&cuda_min),
    [=] RAJA_DEVICE (int i, VALOP_INT_MIN &_cuda_min) {
      _cuda_min.min(Aview(i, 0));
    }
  );

  std::cout << "MIN VAL = " << cuda_min << std::endl;

  // Release the matrix storage; the original never freed it.
  memoryManager::deallocate(A);

  return 0;
}
41 changes: 41 additions & 0 deletions include/RAJA/kernel-reduce.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "RAJA/RAJA.hpp"

// Kernel policy for a three-lambda matrix multiply inside one CudaKernel.
// The loop nest iterates segment 1 (rows) with the block-x loop policy and
// segment 0 (columns) with the thread-x loop policy. For each (row, col):
//   - lambda 0 is called with parameter 0 only,
//   - lambda 1 is called once per index of segment 2 in a sequential loop,
//   - lambda 2 is called with segments 0 and 1 plus parameter 0.
using EXEC_POL8 =
RAJA::KernelPolicy<
RAJA::statement::CudaKernel<
RAJA::statement::For<1, RAJA::cuda_block_x_loop, // row
RAJA::statement::For<0, RAJA::cuda_thread_x_loop, // col
RAJA::statement::Lambda<0, RAJA::Params<0>>, // dot = 0.0 (receives only param 0)
RAJA::statement::For<2, RAJA::seq_exec,
RAJA::statement::Lambda<1> // dot += ... (no Segs/Params listed)
>,
RAJA::statement::Lambda<2, RAJA::Segs<0, 1>, RAJA::Params<0>> // set C = ... (col, row, dot)
>
>
>
>;
// _matmult_3lambdakernel_cuda_end

// Matrix multiply driven by EXEC_POL8: accumulates a per-(row, col) dot
// product in a kernel parameter and stores it into Cview.
// NOTE(review): these statements sit at file scope, and col_range,
// row_range, dot_range, Aview, Bview, Cview, N and checkResult are not
// declared anywhere in this file -- this looks like a reference snippet
// rather than a compilable translation unit; confirm intent.
RAJA::kernel_param<EXEC_POL8>(
// segments: index 0 = columns, index 1 = rows, index 2 = dot-product index
RAJA::make_tuple(col_range, row_range, dot_range),

RAJA::tuple<double>{0.0}, // thread local variable for 'dot'

// lambda 0: reset the accumulator
[=] RAJA_DEVICE (double& dot) {
dot = 0.0;
},

// lambda 1: add one term of the dot product for index k
[=] RAJA_DEVICE (int col, int row, int k, double& dot) {
dot += Aview(row, k) * Bview(k, col);
},

// lambda 2: store the accumulated dot product into C
[=] RAJA_DEVICE (int col, int row, double& dot) {
Cview(row, col) = dot;
}

);

// Presumably validates the computed product; checkResult is defined elsewhere.
checkResult<double>(Cview, N);
2 changes: 1 addition & 1 deletion include/RAJA/pattern/params/reducer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct Reducer : public ForallParamBase
{
using op = Op;
using value_type = T; // This is a basic data type

//using VOp = ValOp<T, Op>;
Reducer() = default;

// Basic data type constructor
Expand Down

0 comments on commit de346d9

Please sign in to comment.