From b84a312543fcf09cf38dd1d372eae85808f5c03b Mon Sep 17 00:00:00 2001
From: dekken
Date: Sat, 23 Mar 2024 19:52:35 +0100
Subject: [PATCH] betterment

Auto-detect the active backend in a new inc/mkn/gpu/defines.hpp (via
__has_include) instead of requiring -DMKN_GPU_ROCM / -DMKN_GPU_CUDA, add
StreamEvent wrappers for CUDA and ROCm, let kernel launches optionally skip
the implicit sync via a _sync template flag plus a GDLauncher::stream()
entry point, mark free functions inline, tighten the ROCm Pointer attribute
checks, and assert pointer kinds in the tests.

---
 .github/workflows/build.yml |  2 +-
 inc/mkn/gpu.hpp             | 20 ++++------
 inc/mkn/gpu/alloc.hpp       |  8 ++--
 inc/mkn/gpu/cpu.hpp         | 23 ++++++++---
 inc/mkn/gpu/cuda.hpp        | 70 +++++++++++++++++++++-----------
 inc/mkn/gpu/cuda/def.hpp    | 24 -----------
 inc/mkn/gpu/def.hpp         | 10 -----
 inc/mkn/gpu/defines.hpp     | 52 ++++++++++++++++++++++++
 inc/mkn/gpu/device.hpp      |  3 ++
 inc/mkn/gpu/launchers.hpp   | 16 ++++++--
 inc/mkn/gpu/rocm.hpp        | 79 +++++++++++++++++++++++++++++++------
 inc/mkn/gpu/rocm/def.hpp    | 24 -----------
 mkn.yaml                    |  6 +--
 res/mkn/hipcc.yaml          |  6 +--
 test/any/add.cpp            |  7 ++++
 test/any/managed.cpp        |  2 +
 16 files changed, 227 insertions(+), 125 deletions(-)
 delete mode 100644 inc/mkn/gpu/cuda/def.hpp
 create mode 100644 inc/mkn/gpu/defines.hpp
 delete mode 100644 inc/mkn/gpu/rocm/def.hpp

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index c5912d3..5749f9d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -10,7 +10,7 @@ jobs:
   build:
     runs-on: ubuntu-latest
     steps:
-    - uses: actions/checkout@v2
+    - uses: actions/checkout@v4
 
     - name: test
       run: |
diff --git a/inc/mkn/gpu.hpp b/inc/mkn/gpu.hpp
index 148c7f6..d798fe9 100644
--- a/inc/mkn/gpu.hpp
+++ b/inc/mkn/gpu.hpp
@@ -31,27 +31,23 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifndef _MKN_GPU_HPP_
 #define _MKN_GPU_HPP_
 
-#if defined(MKN_GPU_ROCM)
-#include "mkn/gpu/rocm.hpp"
-#elif defined(MKN_GPU_CUDA)
-#include "mkn/gpu/cuda.hpp"
-#elif defined(MKN_GPU_CPU)
-#include "mkn/gpu/cpu.hpp"
-#elif !defined(MKN_GPU_FN_PER_NS) || MKN_GPU_FN_PER_NS == 0
-#error "UNKNOWN GPU / define MKN_GPU_ROCM or MKN_GPU_CUDA"
-#endif
+#include "mkn/gpu/defines.hpp"
 
 namespace mkn::gpu {
 
 __device__ uint32_t idx() {
-#if defined(MKN_GPU_ROCM)
+#if MKN_GPU_ROCM
   return mkn::gpu::hip::idx();
-#elif defined(MKN_GPU_CUDA)
+
+#elif MKN_GPU_CUDA
   return mkn::gpu::cuda::idx();
-#elif defined(MKN_GPU_CPU)
+
+#elif MKN_GPU_CPU
   return mkn::gpu::cpu::idx();
+
 #else
 #error "UNKNOWN GPU / define MKN_GPU_ROCM or MKN_GPU_CUDA"
+
 #endif
 }
 
diff --git a/inc/mkn/gpu/alloc.hpp b/inc/mkn/gpu/alloc.hpp
index b442240..4926970 100644
--- a/inc/mkn/gpu/alloc.hpp
+++ b/inc/mkn/gpu/alloc.hpp
@@ -72,8 +72,10 @@ class ManagedAllocator {
 
 template <typename T, typename Size>
 void copy(T* const dst, T const* const src, Size size) {
-  auto dst_p = Pointer{dst};
-  auto src_p = Pointer{src};
+  assert(dst and src);
+
+  Pointer src_p{src};
+  Pointer dst_p{dst};
 
   bool to_send = dst_p.is_device_ptr() && src_p.is_host_ptr();
   bool to_take = dst_p.is_host_ptr() && src_p.is_device_ptr();
@@ -81,7 +83,7 @@ void copy(T* const dst, T const* const src, Size size) {
   if (to_send)
     send(dst, src, size);
   else if (to_take)
-    take(dst, src, size);
+    take(src, dst, size);
   else
     throw std::runtime_error("Unsupported operation (PR welcome)");
 }
diff --git a/inc/mkn/gpu/cpu.hpp b/inc/mkn/gpu/cpu.hpp
index 8b34013..5f1b961 100644
--- a/inc/mkn/gpu/cpu.hpp
+++ b/inc/mkn/gpu/cpu.hpp
@@ -93,14 +93,27 @@ struct Stream {
   std::size_t stream = 0;
 };
 
+struct StreamEvent {
+  StreamEvent(Stream&) {}
+  ~StreamEvent() {}
+
+  auto& operator()() { return event; };
+  void record() { ; }
+  bool finished() const { return true; }
+  void reset() {}
+
+  Stream stream;
+  std::size_t event = 0;
+};
+
 template <typename T>
 struct Pointer {
   Pointer(T* _t) : t{_t} {}
 
   bool is_unregistered_ptr() const { return t == nullptr; }
   bool is_host_ptr() const { return true; }
-  bool is_device_ptr() const { return false; }
-  bool is_managed_ptr() const { return false; }
+  bool is_device_ptr() const { return true; }
+  bool is_managed_ptr() const { return true; }
 
   T* t;
 };
@@ -129,7 +142,7 @@ void alloc_managed(T*& p, Size size) {
   MKN_GPU_ASSERT(p = reinterpret_cast<T*>(std::malloc(size * sizeof(T))));
 }
 
-void destroy(void* p) {
+void inline destroy(void* p) {
   KLOG(TRC);
   std::free(p);
 }
@@ -177,7 +190,7 @@ void take_async(T* p, Span& span, Stream& /*stream*/, std::size_t start) {
   take(p, span.data(), span.size(), start);
 }
 
-void sync() {}
+void inline sync() {}
 
 #include "mkn/gpu/alloc.hpp"
 #include "mkn/gpu/device.hpp"
@@ -186,7 +199,7 @@ namespace detail {
 static thread_local std::size_t idx = 0;
 }
 
-template <typename F, typename... Args>
+template <bool _sync = true, typename F, typename... Args>
 void launch(F f, dim3 g, dim3 b, std::size_t /*ds*/, std::size_t /*stream*/, Args&&... args) {
   std::size_t N = (g.x * g.y * g.z) * (b.x * b.y * b.z);
   KLOG(TRC) << N;
diff --git a/inc/mkn/gpu/cuda.hpp b/inc/mkn/gpu/cuda.hpp
index edd7794..e78a0a3 100644
--- a/inc/mkn/gpu/cuda.hpp
+++ b/inc/mkn/gpu/cuda.hpp
@@ -34,13 +34,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include
 
-#include
-
 #include "mkn/kul/log.hpp"
 #include "mkn/kul/span.hpp"
 #include "mkn/kul/tuple.hpp"
 #include "mkn/kul/assert.hpp"
 
+#include
 #include "mkn/gpu/def.hpp"
 
 // #define MKN_GPU_ASSERT(x) (KASSERT((x) == cudaSuccess))
@@ -50,10 +49,26 @@
 inline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
   if (code != cudaSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
     if (abort) exit(code);
   }
 }
+
+namespace mkn::gpu::cuda {
+
+template <typename SIZE = uint32_t>
+__device__ SIZE idx() {
+  SIZE width = gridDim.x * blockDim.x;
+  SIZE height = gridDim.y * blockDim.y;
+  SIZE x = blockDim.x * blockIdx.x + threadIdx.x;
+  SIZE y = blockDim.y * blockIdx.y + threadIdx.y;
+  SIZE z = blockDim.z * blockIdx.z + threadIdx.z;
+  return x + (y * width) + (z * width * height);
+}
+
+}  // namespace mkn::gpu::cuda
+
+//
+
 #if defined(MKN_GPU_FN_PER_NS) && MKN_GPU_FN_PER_NS
 #define MKN_GPU_NS mkn::gpu::cuda
 #else
@@ -74,6 +89,24 @@ struct Stream {
   cudaStream_t stream;
 };
 
+struct StreamEvent {
+  StreamEvent(Stream& stream_) : stream{stream_} { reset(); }
+  ~StreamEvent() { /*MKN_GPU_ASSERT(result = cudaEventDestroy(event));*/
+  }
+
+  auto& operator()() { return event; };
+  void record() { MKN_GPU_ASSERT(result = cudaEventRecord(event, stream())); }
+  bool finished() const { return cudaEventQuery(event) == cudaSuccess; }
+  void reset() {
+    if (event) MKN_GPU_ASSERT(result = cudaEventDestroy(event));
+    MKN_GPU_ASSERT(result = cudaEventCreate(&event));
+  }
+
+  Stream& stream;
+  cudaError_t result;
+  cudaEvent_t event = nullptr;
+};
+
 template <typename T>
 struct Pointer {
   Pointer(T* _t) : t{_t} { MKN_GPU_ASSERT(cudaPointerGetAttributes(&attributes, t)); }
@@ -112,7 +145,7 @@ void alloc_managed(T*& p, Size size) {
   MKN_GPU_ASSERT(cudaMallocManaged((void**)&p, size * sizeof(T)));
 }
 
-void destroy(void* p) {
+void inline destroy(void* p) {
   KLOG(TRC);
   MKN_GPU_ASSERT(cudaFree(p));
 }
@@ -147,23 +180,6 @@ void take(T const* p, T* t, Size size = 1, Size start = 0) {
   MKN_GPU_ASSERT(cudaMemcpy(t, p + start, size * sizeof(T), cudaMemcpyDeviceToHost));
 }
 
-template <typename T, typename Size>
-void copy(T* dst, T const* src, Size size = 1, Size start = 0) {
-  KLOG(TRC);
-  Pointer p{dst};
-  if (p.is_host_ptr())
-    take(src, dst, size, start);
-  else
-    send(dst, src, size, start);
-}
-
-template <typename T>
-void copy(std::vector<T>& dst, std::vector<T> const& src) {
-  KLOG(TRC);
-  assert(dst.size() >= src.size());
-  copy(dst.data(), src.data(), dst.size());
-}
-
 template <typename T, typename Size>
 void send_async(T* p, T const* t, Stream& stream, Size size = 1, Size start = 0) {
   KLOG(TRC);
@@ -185,19 +201,25 @@ void take_async(T* p, Span& span, Stream& stream, std::size_t start) {
                                  stream()));
 }
 
-void sync() { MKN_GPU_ASSERT(cudaDeviceSynchronize()); }
+void inline sync() { MKN_GPU_ASSERT(cudaDeviceSynchronize()); }
+void inline sync(cudaStream_t stream) { MKN_GPU_ASSERT(cudaStreamSynchronize(stream)); }
 
 #include "mkn/gpu/alloc.hpp"
 #include "mkn/gpu/device.hpp"
 
-template <typename F, typename... Args>
+template <bool _sync = true, typename F, typename... Args>
 void launch(F&& f, dim3 g, dim3 b, std::size_t ds, cudaStream_t& s, Args&&... args) {
   std::size_t N = (g.x * g.y * g.z) * (b.x * b.y * b.z);
   KLOG(TRC) << N;
   std::apply(
       [&](auto&&... params) { f<<<g, b, ds, s>>>(params...); },
       devmem_replace(std::forward_as_tuple(args...), std::make_index_sequence<sizeof...(Args)>()));
-  sync();
+  if constexpr (_sync) {
+    if (s)
+      sync(s);
+    else
+      sync();
+  }
 }
 
 //
@@ -254,7 +276,7 @@ void fill(Container& c, T val) {
 }
 
 //
-void prinfo(size_t dev = 0) {
+void inline prinfo(size_t dev = 0) {
   cudaDeviceProp devProp;
   [[maybe_unused]] auto ret = cudaGetDeviceProperties(&devProp, dev);
   KOUT(NON) << " System version " << devProp.major << "." << devProp.minor;
diff --git a/inc/mkn/gpu/cuda/def.hpp b/inc/mkn/gpu/cuda/def.hpp
deleted file mode 100644
index 37abb46..0000000
--- a/inc/mkn/gpu/cuda/def.hpp
+++ /dev/null
@@ -1,24 +0,0 @@
-
-// IWYU pragma: private, include "mkn/gpu/def.hpp"
-
-#ifndef _MKN_GPU_CUDA_DEF_HPP_
-#define _MKN_GPU_CUDA_DEF_HPP_
-
-#include
-
-namespace mkn::gpu::cuda {
-
-template <typename SIZE = uint32_t>
-__device__ SIZE idx() {
-  SIZE width = gridDim.x * blockDim.x;
-  SIZE height = gridDim.y * blockDim.y;
-
-  SIZE x = blockDim.x * blockIdx.x + threadIdx.x;
-  SIZE y = blockDim.y * blockIdx.y + threadIdx.y;
-  SIZE z = blockDim.z * blockIdx.z + threadIdx.z;
-  return x + (y * width) + (z * width * height);
-}
-
-}  // namespace mkn::gpu::cuda
-
-#endif /*_MKN_GPU_CUDA_DEF_HPP_*/
diff --git a/inc/mkn/gpu/def.hpp b/inc/mkn/gpu/def.hpp
index da3a26c..911ddd3 100644
--- a/inc/mkn/gpu/def.hpp
+++ b/inc/mkn/gpu/def.hpp
@@ -5,16 +5,6 @@
 
 #include
 
-#if defined(MKN_GPU_ROCM)
-#include "mkn/gpu/rocm/def.hpp"
-#elif defined(MKN_GPU_CUDA)
-#include "mkn/gpu/cuda/def.hpp"
-#elif defined(MKN_GPU_CPU)
-
-#elif !defined(MKN_GPU_FN_PER_NS) || MKN_GPU_FN_PER_NS == 0
-#error "UNKNOWN GPU / define MKN_GPU_ROCM or MKN_GPU_CUDA"
-#endif
-
 namespace mkn::gpu {
 
 #if defined(MKN_GPU_CPU)
diff --git a/inc/mkn/gpu/defines.hpp b/inc/mkn/gpu/defines.hpp
new file mode 100644
index 0000000..a0e1b26
--- /dev/null
+++ b/inc/mkn/gpu/defines.hpp
@@ -0,0 +1,52 @@
+
+
+#ifndef _MKN_GPU_DEFINES_HPP_
+#define _MKN_GPU_DEFINES_HPP_
+
+#include
+
+#if !defined(MKN_GPU_FN_PER_NS)
+#define MKN_GPU_FN_PER_NS 0
+#endif
+
+#if !defined(MKN_GPU_ROCM) and __has_include("hip/hip_runtime.h")
+#define MKN_GPU_ROCM 1
+#endif
+#if !defined(MKN_GPU_ROCM)
+#define MKN_GPU_ROCM 0
+#endif
+
+#if !defined(MKN_GPU_CUDA) and __has_include(<cuda_runtime.h>)
+#define MKN_GPU_CUDA 1
+#endif
+#if !defined(MKN_GPU_CUDA)
+#define MKN_GPU_CUDA 0
+#endif
+
+#if MKN_GPU_CUDA == 1 && MKN_GPU_ROCM == 1 && MKN_GPU_FN_PER_NS == 0
+#define MKN_GPU_FN_PER_NS 1
+#endif
+
+#if MKN_GPU_ROCM == 1
+#include "mkn/gpu/rocm.hpp"
+#endif
+
+#if MKN_GPU_CUDA
+#include "mkn/gpu/cuda.hpp"
+#endif
+
+#if MKN_GPU_FN_PER_NS == 1 || MKN_GPU_CPU == 1
+#include "mkn/gpu/cpu.hpp"
+#endif
+
+namespace mkn::gpu {
+
+struct CompileFlags {
+  bool constexpr static withCUDA = MKN_GPU_CUDA;
+  bool constexpr static withROCM = MKN_GPU_ROCM;
+  bool constexpr static perNamespace = MKN_GPU_FN_PER_NS;
+};
+
+} /* namespace mkn::gpu */
+
+#endif /*_MKN_GPU_DEFINES_HPP_*/
diff --git a/inc/mkn/gpu/device.hpp b/inc/mkn/gpu/device.hpp
index 7580180..975fc85 100644
--- a/inc/mkn/gpu/device.hpp
+++ b/inc/mkn/gpu/device.hpp
@@ -100,6 +100,9 @@ struct DeviceMem {
 
   auto& size() const { return s; }
 
+  auto* data() { return p; }
+  auto* data() const { return p; }
+
   std::size_t s = 0;
   T* p = nullptr;
   bool owned = false;
diff --git a/inc/mkn/gpu/launchers.hpp b/inc/mkn/gpu/launchers.hpp
index adbd213..8bfd8e1 100644
--- a/inc/mkn/gpu/launchers.hpp
+++ b/inc/mkn/gpu/launchers.hpp
@@ -31,12 +31,20 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #ifndef _MKN_GPU_LAUNCHERS_HPP_
 #define _MKN_GPU_LAUNCHERS_HPP_
 
+template <bool _sync = true>
 struct GDLauncher : public GLauncher {
   GDLauncher(std::size_t s, size_t dev = 0) : GLauncher{s, dev} {}
 
   template <typename F, typename... Args>
   auto operator()(F&& f, Args&&... args) {
-    _launch(std::forward<F>(f),
+    _launch(s, std::forward<F>(f),
+            as_values(std::forward_as_tuple(args...), std::make_index_sequence<sizeof...(Args)>()),
+            count, args...);
+  }
+
+  template <typename F, typename... Args>
+  auto stream(Stream& s, F&& f, Args&&... args) {
+    _launch(s.stream, std::forward<F>(f),
             as_values(std::forward_as_tuple(args...), std::make_index_sequence<sizeof...(Args)>()),
             count, args...);
   }
@@ -48,9 +56,9 @@ struct GDLauncher : public GLauncher {
     return T{nullptr};
   }
 
-  template <typename F, typename... PArgs, typename... Args>
-  void _launch(F&& f, std::tuple<PArgs...>*, Args&&... args) {
-    MKN_GPU_NS::launch(&global_gd_kernel, g, b, ds, s, f, args...);
+  template <typename S, typename F, typename... PArgs, typename... Args>
+  void _launch(S& _s, F&& f, std::tuple<PArgs...>*, Args&&... args) {
+    MKN_GPU_NS::launch<_sync>(&global_gd_kernel, g, b, ds, _s, f, args...);
   }
 };
 
diff --git a/inc/mkn/gpu/rocm.hpp b/inc/mkn/gpu/rocm.hpp
index a09b5a6..bd1060d 100644
--- a/inc/mkn/gpu/rocm.hpp
+++ b/inc/mkn/gpu/rocm.hpp
@@ -32,13 +32,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #ifndef _MKN_GPU_ROCM_HPP_
 #define _MKN_GPU_ROCM_HPP_
 
-#include "hip/hip_runtime.h"
-
 #include "mkn/kul/log.hpp"
 #include "mkn/kul/span.hpp"
 #include "mkn/kul/tuple.hpp"
 #include "mkn/kul/assert.hpp"
 
+#include "hip/hip_runtime.h"
 #include "mkn/gpu/def.hpp"
 
 // #define MKN_GPU_ASSERT(x) (KASSERT((x) == hipSuccess))
@@ -48,10 +47,28 @@
 inline void gpuAssert(hipError_t code, const char* file, int line, bool abort = true) {
   if (code != hipSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", hipGetErrorString(code), file, line);
+    std::abort();
     if (abort) exit(code);
   }
 }
+
+namespace mkn::gpu::hip {
+
+template <typename SIZE = uint32_t>
+__device__ SIZE idx() {
+  SIZE width = hipGridDim_x * hipBlockDim_x;
+  SIZE height = hipGridDim_y * hipBlockDim_y;
+
+  SIZE x = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
+  SIZE y = hipBlockDim_y * hipBlockIdx_y + hipThreadIdx_y;
+  SIZE z = hipBlockDim_z * hipBlockIdx_z + hipThreadIdx_z;
+  return x + (y * width) + (z * width * height);  // max 4294967296
+}
+
+}  // namespace mkn::gpu::hip
+
+//
+
 #if defined(MKN_GPU_FN_PER_NS) && MKN_GPU_FN_PER_NS
 #define MKN_GPU_NS mkn::gpu::hip
 #else
@@ -72,17 +89,49 @@ struct Stream {
   hipStream_t stream;
 };
 
+struct StreamEvent {
+  StreamEvent(Stream& stream_) : stream{stream_} { reset(); }
+  ~StreamEvent() { /*MKN_GPU_ASSERT(result = hipEventDestroy(event));*/
+  }
+
+  auto& operator()() { return event; };
+  void record() { MKN_GPU_ASSERT(result = hipEventRecord(event, stream())); }
+  bool finished() const { return hipEventQuery(event) == hipSuccess; }
+  void reset() {
+    if (event) MKN_GPU_ASSERT(result = hipEventDestroy(event));
+    MKN_GPU_ASSERT(result = hipEventCreate(&event));
+  }
+
+  Stream& stream;
+  hipError_t result;
+  hipEvent_t event = nullptr;
+};
+
+// https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___global_defs.html#gaea86e91d3cd65992d787b39b218435a3
 template <typename T>
 struct Pointer {
-  Pointer(T* _t) : t{_t} { MKN_GPU_ASSERT(hipPointerGetAttributes(&attributes, t)); }
+  Pointer(T* _t) : t{_t} {
+    assert(t);
+    MKN_GPU_ASSERT(hipPointerGetAttributes(&attributes, t));
+    type = attributes.type;
+  }
 
-  // bool is_unregistered_ptr() const { return attributes.type == 0; }
-  bool is_host_ptr() const { return attributes.hostPointer != nullptr; }
-  bool is_device_ptr() const { return attributes.devicePointer != nullptr; }
-  bool is_managed_ptr() const { return attributes.isManaged; }
+  bool is_unregistered_ptr() const {
+    return attributes.type == hipMemoryType::hipMemoryTypeUnregistered;
+  }
+  bool is_host_ptr() const {
+    return is_unregistered_ptr() || type == hipMemoryType::hipMemoryTypeHost;
+  }
+  bool is_device_ptr() const {
+    return type == hipMemoryType::hipMemoryTypeDevice || attributes.isManaged;
+  }
+  bool is_managed_ptr() const {
+    return attributes.isManaged || type == hipMemoryType::hipMemoryTypeUnified;
+  }
 
   T* t;
   hipPointerAttribute_t attributes;
+  hipMemoryType type = hipMemoryType::hipMemoryTypeUnregistered;
 };
 
 template <typename T>
@@ -109,7 +158,7 @@ void alloc_managed(T*& p, Size size) {
   MKN_GPU_ASSERT(hipMallocManaged((void**)&p, size * sizeof(T)));
 }
 
-void destroy(void* p) {
+void inline destroy(void* p) {
   KLOG(TRC);
   MKN_GPU_ASSERT(hipFree(p));
 }
@@ -165,19 +214,25 @@ void take_async(T* p, Span& span, Stream& stream, std::size_t start) {
                                 stream()));
 }
 
-void sync() { MKN_GPU_ASSERT(hipDeviceSynchronize()); }
+void inline sync() { MKN_GPU_ASSERT(hipDeviceSynchronize()); }
+void inline sync(hipStream_t stream) { MKN_GPU_ASSERT(hipStreamSynchronize(stream)); }
 
 #include "mkn/gpu/alloc.hpp"
 #include "mkn/gpu/device.hpp"
 
-template <typename F, typename... Args>
+template <bool _sync = true, typename F, typename... Args>
 void launch(F&& f, dim3 g, dim3 b, std::size_t ds, hipStream_t& s, Args&&... args) {
   std::size_t N = (g.x * g.y * g.z) * (b.x * b.y * b.z);
   KLOG(TRC) << N;
   std::apply(
       [&](auto&&... params) { hipLaunchKernelGGL(f, g, b, ds, s, params...); },
       devmem_replace(std::forward_as_tuple(args...), std::make_index_sequence<sizeof...(Args)>()));
-  sync();
+  if constexpr (_sync) {
+    if (s)
+      sync(s);
+    else
+      sync();
+  }
 }
 
 // https://rocm-documentation.readthedocs.io/en/latest/Programming_Guides/HIP-GUIDE.html#calling-global-functions
@@ -234,7 +289,7 @@ void fill(Container& c, T val) {
 }
 
 // https://rocm-developer-tools.github.io/HIP/group__Device.html
-void prinfo(size_t dev = 0) {
+void inline prinfo(size_t dev = 0) {
   hipDeviceProp_t devProp;
   [[maybe_unused]] auto ret = hipGetDeviceProperties(&devProp, dev);
   KOUT(NON) << " System version " << devProp.major << "." << devProp.minor;
diff --git a/inc/mkn/gpu/rocm/def.hpp b/inc/mkn/gpu/rocm/def.hpp
deleted file mode 100644
index d556972..0000000
--- a/inc/mkn/gpu/rocm/def.hpp
+++ /dev/null
@@ -1,24 +0,0 @@
-
-// IWYU pragma: private, include "mkn/gpu/def.hpp"
-
-#ifndef _MKN_GPU_ROCM_DEF_HPP_
-#define _MKN_GPU_ROCM_DEF_HPP_
-
-#include "hip/hip_runtime.h"
-
-namespace mkn::gpu::hip {
-
-template <typename SIZE = uint32_t>
-__device__ SIZE idx() {
-  SIZE width = hipGridDim_x * hipBlockDim_x;
-  SIZE height = hipGridDim_y * hipBlockDim_y;
-
-  SIZE x = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
-  SIZE y = hipBlockDim_y * hipBlockIdx_y + hipThreadIdx_y;
-  SIZE z = hipBlockDim_z * hipBlockIdx_z + hipThreadIdx_z;
-  return x + (y * width) + (z * width * height);  // max 4294967296
-}
-
-}  // namespace mkn::gpu::hip
-
-#endif /*_MKN_GPU_ROCM_DEF_HPP_*/
diff --git a/mkn.yaml b/mkn.yaml
index 2bf616c..2d7b461 100644
--- a/mkn.yaml
+++ b/mkn.yaml
@@ -1,4 +1,4 @@
-#! clean build test run -tOp rocm -x res/mkn/hipcc
+#! clean build test run -Op rocm -x res/mkn/hipcc -W
 
 name: mkn.gpu
 parent: headers
@@ -10,13 +10,13 @@ profile:
 
 - name: rocm
   parent: headers
-  arg: -DMKN_GPU_ROCM
+  # arg: -DMKN_GPU_ROCM=1
   test: test/any/(\w).cpp
         test/hip/(\w).cpp
 
 - name: cuda
   parent: headers
-  arg: -DMKN_GPU_CUDA
+  # arg: -DMKN_GPU_CUDA
   test: test/any/(\w).cpp
         test/cuda/(\w).cpp
 
diff --git a/res/mkn/hipcc.yaml b/res/mkn/hipcc.yaml
index d0dc3f7..3cb9a0a 100644
--- a/res/mkn/hipcc.yaml
+++ b/res/mkn/hipcc.yaml
@@ -1,8 +1,8 @@
 ## Recommended settings commented out.
 
-# local:
-#   repo: /mkn/r
-#   mod-repo: /mkn/m
+local:
+  repo: /mkn/r
+  mod-repo: /mkn/m
 
 # remote:
 #   repo: git@github.com:mkn/
diff --git a/test/any/add.cpp b/test/any/add.cpp
index 860d7f0..64794e5 100644
--- a/test/any/add.cpp
+++ b/test/any/add.cpp
@@ -30,10 +30,17 @@ __global__ void vectoradd1(T* a, T* b) {
 template <typename T>
 uint32_t test_add1() {
   std::vector<T> b(NUM);
+
+  assert(mkn::gpu::Pointer{b.data()}.is_host_ptr());
+
   for (uint32_t i = 0; i < NUM; i++) b[i] = i;
   mkn::gpu::DeviceMem<T> devA(NUM), devB(b);
+  assert(mkn::gpu::Pointer{devA.p}.is_device_ptr());
+
   mkn::gpu::Launcher{WIDTH, HEIGHT, TPB_X, TPB_Y}(vectoradd1, devA, devB);
   auto a = devA();
+
+  // assert(mkn::gpu::Pointer{a.data()}.is_device_ptr());
   for (uint32_t i = 0; i < NUM; i++)
     if (a[i] != b[i] + 1) return 1;
   return 0;
diff --git a/test/any/managed.cpp b/test/any/managed.cpp
index 158e960..23689d9 100644
--- a/test/any/managed.cpp
+++ b/test/any/managed.cpp
@@ -21,6 +21,8 @@ __global__ void kernel(S* structs) {
 template <typename L>
 std::uint32_t _test(L&& launcher) {
   ManagedVector mem{NUM};
+  assert(mkn::gpu::Pointer{mem.data()}.is_managed_ptr());
+
   for (std::uint32_t i = 0; i < NUM; ++i) mem[i].d0 = i;
 
   launcher(kernel, mem);
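
Notes (illustrative only, not part of the diff):

The new inc/mkn/gpu/defines.hpp replaces the hard -DMKN_GPU_ROCM / -DMKN_GPU_CUDA
requirement: each backend macro is derived from __has_include, and when both
toolkits are visible MKN_GPU_FN_PER_NS is forced to 1 so every backend keeps its
functions in its own namespace (mkn::gpu::cuda / mkn::gpu::hip / mkn::gpu::cpu).
A minimal consumer-side sketch using only the names introduced above; the main()
harness itself is hypothetical:

  #include "mkn/gpu/defines.hpp"
  #include <cstdio>

  int main() {
    using Flags = mkn::gpu::CompileFlags;  // introduced in defines.hpp above
    std::printf("CUDA=%d ROCM=%d per-namespace=%d\n",
                int{Flags::withCUDA}, int{Flags::withROCM}, int{Flags::perNamespace});
  }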
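The CUDA and ROCm StreamEvent wrappers pair an event with a Stream: reset()
(re)creates the event, record() enqueues it on the stream, and finished() polls
cudaEventQuery / hipEventQuery. A sketch of polling an asynchronous host-to-device
copy, assuming a single-backend build (MKN_GPU_FN_PER_NS == 0, so everything lives
in mkn::gpu) and that Stream is default-constructible as in the existing header;
buffer names are hypothetical:

  #include "mkn/gpu.hpp"
  #include <vector>

  void poll_async_copy() {
    std::size_t constexpr N = 1024;
    std::vector<double> host(N, 1.0);

    mkn::gpu::DeviceMem<double> dev(N);  // device buffer, .data() added by this patch
    mkn::gpu::Stream stream;
    mkn::gpu::StreamEvent event{stream};

    mkn::gpu::send_async(dev.data(), host.data(), stream, N);  // H2D on the stream
    event.record();                      // mark the completion point on the stream

    while (!event.finished()) { /* overlap independent host work here */ }

    mkn::gpu::sync(stream());            // stream-scoped sync overload added above
  }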
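GDLauncher is now templated on a _sync flag and gains a stream() entry point; with
_sync == false the underlying launch<_sync>() skips the implicit synchronisation,
so the caller decides when to wait. A sketch under the same single-backend
assumption: the lambda body is illustrative, it assumes (as with the existing
operator()) that DeviceMem arguments reach the functor as raw device pointers and
that the launch covers the construction-time count indexed via mkn::gpu::idx(),
and CUDA builds of device lambdas need nvcc's --extended-lambda:

  #include "mkn/gpu.hpp"

  void scale_on_stream() {
    std::size_t constexpr N = 4096;
    mkn::gpu::DeviceMem<float> vec(N);

    mkn::gpu::Stream stream;

    // _sync == false: returns as soon as the kernel is enqueued on `stream`.
    mkn::gpu::GDLauncher<false>{N}.stream(
        stream, [=] __device__(float* v) { v[mkn::gpu::idx()] *= 2.f; }, vec);

    mkn::gpu::sync(stream());  // or poll a StreamEvent as in the previous sketch
  }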