Skip to content

Commit

Permalink
Move Kernel constructor and call/set_args trampolines to C++
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 7, 2024
1 parent 4c3dd3c commit 8fb6c4a
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 46 deletions.
44 changes: 3 additions & 41 deletions pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ def __getattr__(self, attr):
# Nvidia does not raise errors even for invalid names,
# but this will give an error if the kernel is invalid.
knl.num_args # noqa: B018
knl._source = getattr(self, "_source", None)

if self._build_duration_info is not None:
build_descr, _was_cached, duration = self._build_duration_info
Expand Down Expand Up @@ -793,30 +792,9 @@ def __getattr__(self, name):

# {{{ Kernel

kernel_old_init = Kernel.__init__
kernel_old_get_info = Kernel.get_info
kernel_old_get_work_group_info = Kernel.get_work_group_info

def kernel_init(self, prg, name):
if not isinstance(prg, _cl._Program):
prg = prg._get_prg()

kernel_old_init(self, prg, name)

self._setup(prg)

def kernel__setup(self, prg):
self._source = getattr(prg, "_source", None)

from pyopencl.invoker import generate_enqueue_and_set_args
self._enqueue, self._set_args = generate_enqueue_and_set_args(
self.function_name, self.num_args, self.num_args,
None,
warn_about_arg_count_bug=None,
work_around_arg_count_bug=None, devs=self.context.devices)

return self

def kernel_set_arg_types(self, arg_types):
arg_types = tuple(arg_types)

Expand Down Expand Up @@ -845,14 +823,14 @@ def kernel_set_arg_types(self, arg_types):
# }}}

from pyopencl.invoker import generate_enqueue_and_set_args
self._enqueue, self.set_args = \
generate_enqueue_and_set_args(
self._set_enqueue_and_set_args(
*generate_enqueue_and_set_args(
self.function_name,
len(arg_types), self.num_args,
arg_types,
warn_about_arg_count_bug=warn_about_arg_count_bug,
work_around_arg_count_bug=work_around_arg_count_bug,
devs=self.context.devices)
devs=self.context.devices))

def kernel_get_work_group_info(self, param, device):
try:
Expand All @@ -870,18 +848,6 @@ def kernel_get_work_group_info(self, param, device):
wg_info_cache[cache_key] = result
return result

def kernel_set_args(self, *args, **kwargs):
# Need to duplicate the 'self' argument for dynamically generated method
return self._set_args(self, *args, **kwargs)

def kernel_call(self, queue, global_size, local_size, *args, **kwargs):
# __call__ can't be overridden directly, so we need this
# trampoline hack.

# Note: This is only used for the generic __call__, before
# kernel_set_scalar_arg_dtypes is called.
return self._enqueue(self, queue, global_size, local_size, *args, **kwargs)

def kernel_capture_call(self, output_file, queue, global_size, local_size,
*args, **kwargs):
from pyopencl.capture_call import capture_kernel_call
Expand All @@ -896,16 +862,12 @@ def kernel_get_info(self, param_name):
else:
return val

Kernel.__init__ = kernel_init
Kernel._setup = kernel__setup
Kernel.get_work_group_info = kernel_get_work_group_info

# FIXME: Possibly deprecate this version
Kernel.set_scalar_arg_dtypes = kernel_set_arg_types
Kernel.set_arg_types = kernel_set_arg_types

Kernel.set_args = kernel_set_args
Kernel.__call__ = kernel_call
Kernel.capture_call = kernel_capture_call
Kernel.get_info = kernel_get_info

Expand Down
83 changes: 79 additions & 4 deletions src/wrap_cl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,8 @@

namespace pyopencl
{
using namespace py::literals;

class program;
class command_queue;

Expand Down Expand Up @@ -4773,24 +4775,51 @@ namespace pyopencl
cl_kernel m_kernel;
bool m_set_arg_prefer_svm;

// Source is a Python object so that we can hold a reference to the source object
// without a need to copy it.
//
// Not implementing GC traversals for this because (IMO) it's
// unlikely the source string is involved in a cycle with the
// kernel object.
py::object m_source;

// These are generated code, unlikely to hold a reference back to the
// kernel, therefore also not implementing GC traversal for this.
py::object m_enqueue_func;
py::object m_set_args_func;

public:
kernel(cl_kernel knl, bool retain)
: m_kernel(knl), m_set_arg_prefer_svm(false)
{
if (retain)
PYOPENCL_CALL_GUARDED(clRetainKernel, (knl));

set_up_basic_invokers();
}

kernel(program const &prg, std::string const &kernel_name)
kernel(py::object prg_py, std::string const &kernel_name)
: m_set_arg_prefer_svm(false)
{
program const *prg = nullptr;
try
{
prg = py::cast<program const *>(prg_py);
}
catch (py::cast_error) {
prg = py::cast<program const *>(prg_py.attr("_get_prg")());
}

cl_int status_code;

PYOPENCL_PRINT_CALL_TRACE("clCreateKernel");
m_kernel = clCreateKernel(prg.data(), kernel_name.c_str(),
&status_code);
m_kernel = clCreateKernel(prg->data(), kernel_name.c_str(), &status_code);
if (status_code != CL_SUCCESS)
throw pyopencl::error("clCreateKernel", status_code);

m_source = py::getattr(prg_py, "_source", py::object());

set_up_basic_invokers();
}

~kernel()
Expand All @@ -4803,6 +4832,11 @@ namespace pyopencl
return m_kernel;
}

py::object source() const
{
return m_source;
}

PYOPENCL_EQUALITY_TESTS(kernel);

#if PYOPENCL_CL_VERSION >= 0x2010
Expand Down Expand Up @@ -5167,8 +5201,49 @@ namespace pyopencl
default:
throw error("Kernel.get_sub_group_info", CL_INVALID_VALUE);
}
}
}
#endif

void set_up_basic_invokers()
{
py::module_ invoker = py::module_::import_("pyopencl.invoker");

py::tuple res = py::cast<py::tuple>(invoker.attr("generate_enqueue_and_set_args")(
get_info(CL_KERNEL_FUNCTION_NAME),
num_args(), num_args(),
py::none(),
"warn_about_arg_count_bug"_a=py::none(),
"work_around_arg_count_bug"_a=py::none(),
"devs"_a=get_info(CL_KERNEL_CONTEXT).attr("devices")
));

m_enqueue_func = res[0];
m_set_args_func = res[1];
}

void set_enqueue_and_set_args(py::object enqueue_func, py::object set_args_func)
{
m_enqueue_func = enqueue_func;
m_set_args_func = set_args_func;
}

py::object enqueue(py::args args, py::kwargs kwargs) const
{
return m_enqueue_func(py::cast(this), *args, **kwargs);
}

void set_args(py::args args, py::kwargs kwargs) const
{
m_set_args_func(py::cast(this), *args, **kwargs);
}

cl_uint num_args() const
{
cl_uint param_value;
PYOPENCL_CALL_GUARDED(clGetKernelInfo,
(m_kernel, CL_KERNEL_NUM_ARGS, sizeof(param_value), &param_value, 0));
return param_value;
}
};

#define PYOPENCL_KERNEL_SET_ARG_MULTI_ERROR_HANDLER \
Expand Down
6 changes: 5 additions & 1 deletion src/wrap_cl_part_2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ void pyopencl_expose_part_2(py::module_ &m)
{
typedef kernel cls;
py::class_<cls>(m, "Kernel", py::dynamic_attr())
.def(py::init<const program &, std::string const &>())
.def(py::init<py::object, std::string const &>())
.def_prop_ro("_source", &cls::source)
.DEF_SIMPLE_METHOD(get_info)
.DEF_SIMPLE_METHOD(get_work_group_info)
#if PYOPENCL_CL_VERSION >= 0x2010
Expand Down Expand Up @@ -585,6 +586,9 @@ void pyopencl_expose_part_2(py::module_ &m)
py::arg("input_value").none(true)=py::none()
)
#endif
.def("__call__", &cls::enqueue)
.def("set_args", &cls::set_args)
.def("_set_enqueue_and_set_args", &cls::set_enqueue_and_set_args)
;
}

Expand Down

0 comments on commit 8fb6c4a

Please sign in to comment.