From 8fb6c4a64eedbfb19d1f4bd310786d7d01eee84b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 7 Oct 2024 17:09:11 -0500 Subject: [PATCH] Move Kernel constructor and call/set_args trampolines to C++ --- pyopencl/__init__.py | 44 ++-------------------- src/wrap_cl.hpp | 83 ++++++++++++++++++++++++++++++++++++++++-- src/wrap_cl_part_2.cpp | 6 ++- 3 files changed, 87 insertions(+), 46 deletions(-) diff --git a/pyopencl/__init__.py b/pyopencl/__init__.py index fe1b4027d..fd36e6134 100644 --- a/pyopencl/__init__.py +++ b/pyopencl/__init__.py @@ -444,7 +444,6 @@ def __getattr__(self, attr): # Nvidia does not raise errors even for invalid names, # but this will give an error if the kernel is invalid. knl.num_args # noqa: B018 - knl._source = getattr(self, "_source", None) if self._build_duration_info is not None: build_descr, _was_cached, duration = self._build_duration_info @@ -793,30 +792,9 @@ def __getattr__(self, name): # {{{ Kernel - kernel_old_init = Kernel.__init__ kernel_old_get_info = Kernel.get_info kernel_old_get_work_group_info = Kernel.get_work_group_info - def kernel_init(self, prg, name): - if not isinstance(prg, _cl._Program): - prg = prg._get_prg() - - kernel_old_init(self, prg, name) - - self._setup(prg) - - def kernel__setup(self, prg): - self._source = getattr(prg, "_source", None) - - from pyopencl.invoker import generate_enqueue_and_set_args - self._enqueue, self._set_args = generate_enqueue_and_set_args( - self.function_name, self.num_args, self.num_args, - None, - warn_about_arg_count_bug=None, - work_around_arg_count_bug=None, devs=self.context.devices) - - return self - def kernel_set_arg_types(self, arg_types): arg_types = tuple(arg_types) @@ -845,14 +823,14 @@ def kernel_set_arg_types(self, arg_types): # }}} from pyopencl.invoker import generate_enqueue_and_set_args - self._enqueue, self.set_args = \ - generate_enqueue_and_set_args( + self._set_enqueue_and_set_args( + *generate_enqueue_and_set_args( self.function_name, 
len(arg_types), self.num_args, arg_types, warn_about_arg_count_bug=warn_about_arg_count_bug, work_around_arg_count_bug=work_around_arg_count_bug, - devs=self.context.devices) + devs=self.context.devices)) def kernel_get_work_group_info(self, param, device): try: @@ -870,18 +848,6 @@ def kernel_get_work_group_info(self, param, device): wg_info_cache[cache_key] = result return result - def kernel_set_args(self, *args, **kwargs): - # Need to duplicate the 'self' argument for dynamically generated method - return self._set_args(self, *args, **kwargs) - - def kernel_call(self, queue, global_size, local_size, *args, **kwargs): - # __call__ can't be overridden directly, so we need this - # trampoline hack. - - # Note: This is only used for the generic __call__, before - # kernel_set_scalar_arg_dtypes is called. - return self._enqueue(self, queue, global_size, local_size, *args, **kwargs) - def kernel_capture_call(self, output_file, queue, global_size, local_size, *args, **kwargs): from pyopencl.capture_call import capture_kernel_call @@ -896,16 +862,12 @@ def kernel_get_info(self, param_name): else: return val - Kernel.__init__ = kernel_init - Kernel._setup = kernel__setup Kernel.get_work_group_info = kernel_get_work_group_info # FIXME: Possibly deprecate this version Kernel.set_scalar_arg_dtypes = kernel_set_arg_types Kernel.set_arg_types = kernel_set_arg_types - Kernel.set_args = kernel_set_args - Kernel.__call__ = kernel_call Kernel.capture_call = kernel_capture_call Kernel.get_info = kernel_get_info diff --git a/src/wrap_cl.hpp b/src/wrap_cl.hpp index f127ef218..1a881b5ab 100644 --- a/src/wrap_cl.hpp +++ b/src/wrap_cl.hpp @@ -455,6 +455,8 @@ namespace pyopencl { + using namespace py::literals; + class program; class command_queue; @@ -4773,24 +4775,51 @@ namespace pyopencl cl_kernel m_kernel; bool m_set_arg_prefer_svm; + // Source is a Python object so that we can hold a reference to the source object + // without a need to copy it. 
+ // + // Not implementing GC traversals for this because (IMO) it's + // unlikely the source string is involved in a cycle with the + // kernel object. + py::object m_source; + + // These are generated code, unlikely to hold a reference back to the + // kernel, therefore also not implementing GC traversal for this. + py::object m_enqueue_func; + py::object m_set_args_func; + public: kernel(cl_kernel knl, bool retain) : m_kernel(knl), m_set_arg_prefer_svm(false) { if (retain) PYOPENCL_CALL_GUARDED(clRetainKernel, (knl)); + + set_up_basic_invokers(); } - kernel(program const &prg, std::string const &kernel_name) + kernel(py::object prg_py, std::string const &kernel_name) : m_set_arg_prefer_svm(false) { + program const *prg = nullptr; + try + { + prg = py::cast<program const *>(prg_py); + } + catch (py::cast_error) { + prg = py::cast<program const *>(prg_py.attr("_get_prg")()); + } + cl_int status_code; PYOPENCL_PRINT_CALL_TRACE("clCreateKernel"); - m_kernel = clCreateKernel(prg.data(), kernel_name.c_str(), - &status_code); + m_kernel = clCreateKernel(prg->data(), kernel_name.c_str(), &status_code); if (status_code != CL_SUCCESS) throw pyopencl::error("clCreateKernel", status_code); + + m_source = py::getattr(prg_py, "_source", py::object()); + + set_up_basic_invokers(); } ~kernel() @@ -4803,6 +4832,11 @@ namespace pyopencl return m_kernel; } + py::object source() const + { + return m_source; + } + PYOPENCL_EQUALITY_TESTS(kernel); #if PYOPENCL_CL_VERSION >= 0x2010 @@ -5167,8 +5201,49 @@ namespace pyopencl default: throw error("Kernel.get_sub_group_info", CL_INVALID_VALUE); } - } + } #endif + + void set_up_basic_invokers() + { + py::module_ invoker = py::module_::import_("pyopencl.invoker"); + + py::tuple res = py::cast<py::tuple>(invoker.attr("generate_enqueue_and_set_args")( + get_info(CL_KERNEL_FUNCTION_NAME), + num_args(), num_args(), + py::none(), + "warn_about_arg_count_bug"_a=py::none(), + "work_around_arg_count_bug"_a=py::none(), + "devs"_a=get_info(CL_KERNEL_CONTEXT).attr("devices") + )); + +
m_enqueue_func = res[0]; + m_set_args_func = res[1]; + } + + void set_enqueue_and_set_args(py::object enqueue_func, py::object set_args_func) + { + m_enqueue_func = enqueue_func; + m_set_args_func = set_args_func; + } + + py::object enqueue(py::args args, py::kwargs kwargs) const + { + return m_enqueue_func(py::cast(this), *args, **kwargs); + } + + void set_args(py::args args, py::kwargs kwargs) const + { + m_set_args_func(py::cast(this), *args, **kwargs); + } + + cl_uint num_args() const + { + cl_uint param_value; + PYOPENCL_CALL_GUARDED(clGetKernelInfo, + (m_kernel, CL_KERNEL_NUM_ARGS, sizeof(param_value), &param_value, 0)); + return param_value; + } }; #define PYOPENCL_KERNEL_SET_ARG_MULTI_ERROR_HANDLER \ diff --git a/src/wrap_cl_part_2.cpp b/src/wrap_cl_part_2.cpp index 76cfed278..08b873f7f 100644 --- a/src/wrap_cl_part_2.cpp +++ b/src/wrap_cl_part_2.cpp @@ -538,7 +538,8 @@ void pyopencl_expose_part_2(py::module_ &m) { typedef kernel cls; py::class_<cls>(m, "Kernel", py::dynamic_attr()) - .def(py::init<const program &, std::string const &>()) + .def(py::init<py::object, std::string const &>()) + .def_prop_ro("_source", &cls::source) .DEF_SIMPLE_METHOD(get_info) .DEF_SIMPLE_METHOD(get_work_group_info) #if PYOPENCL_CL_VERSION >= 0x2010 @@ -585,6 +586,9 @@ void pyopencl_expose_part_2(py::module_ &m) py::arg("input_value").none(true)=py::none() ) #endif + .def("__call__", &cls::enqueue) + .def("set_args", &cls::set_args) + .def("_set_enqueue_and_set_args", &cls::set_enqueue_and_set_args) ; }