From 6f287b8c1e4e66e2ab88210b58420a98a6430a6f Mon Sep 17 00:00:00 2001 From: Ben Barsdell Date: Mon, 23 Oct 2023 11:39:49 +1100 Subject: [PATCH] Fix ambiguous overloads of launch() - Adds launch_raw() methods to replace overloads that take array of arg pointers, which were dangerously ambiguous with the variadic overload. - Adds explicit no-argument overload of launch() to avoid forming zero-sized array. --- jitify2.hpp | 40 ++++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/jitify2.hpp b/jitify2.hpp index 8867300..facb521 100644 --- a/jitify2.hpp +++ b/jitify2.hpp @@ -1796,13 +1796,30 @@ class ConfiguredKernelData { /*! Get the configured CUDA stream. */ CUstream stream() const { return stream_; } - // TODO: Taking void** here is dangerous due to ambiguity with the variadic // overload below. E.g., passing void*const* silently fails. /*! Launch the configured kernel. * \param arg_ptrs Array of pointers to kernel arguments. * \return An empty string on success, otherwise an error message. + * \deprecated Use \p launch_raw instead. */ - ErrorMsg launch(void** arg_ptrs) const { + JITIFY_DEPRECATED("Use launch_raw instead") + ErrorMsg launch(void** arg_ptrs) const { return launch_raw(arg_ptrs); } + + /*! Launch the configured kernel. + * \param arg_ptrs Vector of pointers to kernel arguments. + * \return An empty string on success, otherwise an error message. + * \deprecated Use \p launch_raw instead. + */ + JITIFY_DEPRECATED("Use launch_raw instead") + ErrorMsg launch(const std::vector& arg_ptrs) const { + return launch_raw(arg_ptrs); + } + + /*! Launch the configured kernel. + * \param arg_ptrs Array of pointers to kernel arguments. + * \return An empty string on success, otherwise an error message. + */ + ErrorMsg launch_raw(void** arg_ptrs) const { if (!cuda()) JITIFY_THROW_OR_RETURN(cuda().error()); JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(cuda().LaunchKernel()( kernel_.function(), grid_.x, grid_.y, grid_.z, block_.x, block_.y, @@ -1814,8 +1831,8 @@ class ConfiguredKernelData { * \param arg_ptrs Vector of pointers to kernel arguments. * \return An empty string on success, otherwise an error message. */ - ErrorMsg launch(const std::vector& arg_ptrs = {}) const { - return launch(const_cast(arg_ptrs.data())); + ErrorMsg launch_raw(const std::vector& arg_ptrs) const { + return launch_raw(const_cast(arg_ptrs.data())); } /*! Launch the configured kernel. @@ -1823,10 +1840,17 @@ class ConfiguredKernelData { * be passed as pointers. * \return An empty string on success, otherwise an error message. */ - template - ErrorMsg launch(const Args&... args) const { - void* arg_ptrs[] = {(void*)&args...}; - return this->launch(arg_ptrs); + template + ErrorMsg launch(const Arg& arg, const Args&... args) const { + void* arg_ptrs[] = {(void*)&arg, (void*)&args...}; + return this->launch_raw(arg_ptrs); + } + + /*! Launch the configured kernel. + * \return An empty string on success, otherwise an error message. + */ + ErrorMsg launch() const { + return this->launch_raw(nullptr); } };