Skip to content

Commit

Permalink
Fix ambiguous overloads of launch()
Browse files Browse the repository at this point in the history
- Adds launch_raw() methods to replace overloads that take array of
  arg pointers, which were dangerously ambiguous with the variadic
  overload.
- Adds explicit no-argument overload of launch() to avoid forming
  zero-sized array.
  • Loading branch information
benbarsdell committed Oct 23, 2023
1 parent 88ab8fa commit 6f287b8
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions jitify2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1796,13 +1796,30 @@ class ConfiguredKernelData {
/*! Get the configured CUDA stream. */
CUstream stream() const { return stream_; }

// TODO: Taking void** here is dangerous due to ambiguity with the variadic
// overload below. E.g., passing void*const* silently fails.
/*! Launch the configured kernel.
* \param arg_ptrs Array of pointers to kernel arguments.
* \return An empty string on success, otherwise an error message.
* \deprecated Use \p launch_raw instead.
*/
ErrorMsg launch(void** arg_ptrs) const {
JITIFY_DEPRECATED("Use launch_raw instead")
ErrorMsg launch(void** arg_ptrs) const { return launch_raw(arg_ptrs); }

/*! Launch the configured kernel.
* \param arg_ptrs Vector of pointers to kernel arguments.
* \return An empty string on success, otherwise an error message.
* \deprecated Use \p launch_raw instead.
*/
JITIFY_DEPRECATED("Use launch_raw instead")
ErrorMsg launch(const std::vector<void*>& arg_ptrs) const {
return launch_raw(arg_ptrs);
}

/*! Launch the configured kernel.
* \param arg_ptrs Array of pointers to kernel arguments.
* \return An empty string on success, otherwise an error message.
*/
ErrorMsg launch_raw(void** arg_ptrs) const {
if (!cuda()) JITIFY_THROW_OR_RETURN(cuda().error());
JITIFY_THROW_OR_RETURN_IF_CUDA_ERROR(cuda().LaunchKernel()(
kernel_.function(), grid_.x, grid_.y, grid_.z, block_.x, block_.y,
Expand All @@ -1814,19 +1831,26 @@ class ConfiguredKernelData {
* \param arg_ptrs Vector of pointers to kernel arguments.
* \return An empty string on success, otherwise an error message.
*/
ErrorMsg launch(const std::vector<void*>& arg_ptrs = {}) const {
return launch(const_cast<void**>(arg_ptrs.data()));
ErrorMsg launch_raw(const std::vector<void*>& arg_ptrs) const {
return launch_raw(const_cast<void**>(arg_ptrs.data()));
}

/*! Launch the configured kernel.
* \param args Arguments for the kernel. Note that reference arguments must
* be passed as pointers.
* \return An empty string on success, otherwise an error message.
*/
template <typename... Args>
ErrorMsg launch(const Args&... args) const {
void* arg_ptrs[] = {(void*)&args...};
return this->launch(arg_ptrs);
template <typename Arg, typename... Args>
ErrorMsg launch(const Arg& arg, const Args&... args) const {
void* arg_ptrs[] = {(void*)&arg, (void*)&args...};
return this->launch_raw(arg_ptrs);
}

/*! Launch the configured kernel.
* \return An empty string on success, otherwise an error message.
*/
ErrorMsg launch() const {
return this->launch_raw(nullptr);
}
};

Expand Down

0 comments on commit 6f287b8

Please sign in to comment.