diff --git a/dpnp/backend/extensions/ufunc/CMakeLists.txt b/dpnp/backend/extensions/ufunc/CMakeLists.txt index d45bfa822e5..5f892506b81 100644 --- a/dpnp/backend/extensions/ufunc/CMakeLists.txt +++ b/dpnp/backend/extensions/ufunc/CMakeLists.txt @@ -38,6 +38,7 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/lcm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/ldexp.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/logaddexp2.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/nan_to_num.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/radians.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/sinc.cpp ${CMAKE_CURRENT_SOURCE_DIR}/elementwise_functions/spacing.cpp diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp index 43a68e487cf..f9d179d5ca4 100644 --- a/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/common.cpp @@ -38,6 +38,7 @@ #include "lcm.hpp" #include "ldexp.hpp" #include "logaddexp2.hpp" +#include "nan_to_num.hpp" #include "radians.hpp" #include "sinc.hpp" #include "spacing.hpp" @@ -64,6 +65,7 @@ void init_elementwise_functions(py::module_ m) init_lcm(m); init_ldexp(m); init_logaddexp2(m); + init_nan_to_num(m); init_radians(m); init_sinc(m); init_spacing(m); diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp new file mode 100644 index 00000000000..ec5dfd0a78b --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.cpp @@ -0,0 +1,408 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "dpctl4pybind11.hpp" +#include +#include +#include + +#include "kernels/elementwise_functions/nan_to_num.hpp" + +#include "../../elementwise_functions/simplify_iteration_space.hpp" + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" +#include "utils/type_utils.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +// declare pybind11 wrappers in py_internal namespace +namespace dpnp::extensions::ufunc +{ + +namespace impl +{ + +template +struct value_type_of +{ + using type = T; +}; + +template +struct value_type_of> +{ + using type = T; +}; + +template +using value_type_of_t = typename value_type_of::type; + +typedef sycl::event (*nan_to_num_fn_ptr_t)(sycl::queue &, + int, + std::size_t, + const py::ssize_t *, + const py::object &, + const py::object &, + const py::object &, + const char *, + py::ssize_t, + char *, + py::ssize_t, + const std::vector &); + +template +sycl::event nan_to_num_strided_call(sycl::queue &exec_q, + int nd, + std::size_t nelems, + const py::ssize_t *shape_strides, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const char *arg_p, + py::ssize_t arg_offset, + char *dst_p, + py::ssize_t dst_offset, + const std::vector &depends) +{ + using dpctl::tensor::type_utils::is_complex_v; + using scT = std::conditional_t, value_type_of_t, T>; + + const scT nan_v = py::cast(py_nan); + const scT posinf_v = py::cast(py_posinf); + const scT neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_strided_impl; + sycl::event to_num_ev = nan_to_num_strided_impl( + exec_q, nd, nelems, shape_strides, nan_v, posinf_v, neginf_v, arg_p, + arg_offset, dst_p, dst_offset, depends); + + return to_num_ev; +} + +typedef sycl::event (*nan_to_num_contig_fn_ptr_t)( + sycl::queue &, + std::size_t, + const py::object &, + const py::object &, + const py::object &, + const char *, + char *, + const std::vector &); + +template +sycl::event nan_to_num_contig_call(sycl::queue &exec_q, + std::size_t nelems, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const char *arg_p, + char *dst_p, + const std::vector &depends) +{ + using dpctl::tensor::type_utils::is_complex_v; + using scT = std::conditional_t, value_type_of_t, T>; + + const scT nan_v = py::cast(py_nan); + const scT posinf_v = py::cast(py_posinf); + const scT neginf_v = py::cast(py_neginf); + + using dpnp::kernels::nan_to_num::nan_to_num_contig_impl; + sycl::event to_num_contig_ev = nan_to_num_contig_impl( + exec_q, nelems, nan_v, posinf_v, neginf_v, arg_p, dst_p, depends); + + return to_num_contig_ev; +} + +namespace td_ns = dpctl::tensor::type_dispatch; +nan_to_num_fn_ptr_t nan_to_num_dispatch_vector[td_ns::num_types]; +nan_to_num_contig_fn_ptr_t nan_to_num_contig_dispatch_vector[td_ns::num_types]; + +std::pair + py_nan_to_num(const dpctl::tensor::usm_ndarray &src, + const py::object &py_nan, + const py::object &py_posinf, + const py::object &py_neginf, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &q, + const std::vector &depends) +{ + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = 
td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error("Array data types are not the same."); + } + + if (!dpctl::utils::queues_are_compatible(q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + int src_nd = src.get_ndim(); + if (src_nd != dst.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + const py::ssize_t *src_shape = src.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + + std::size_t nelems = src.get_size(); + bool shapes_equal = std::equal(src_shape, src_shape + src_nd, dst_shape); + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, nelems); + + // check memory overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if (overlap(src, dst) && !same_logical_tensors(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + // handle contiguous inputs + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool both_c_contig = (is_src_c_contig && is_dst_c_contig); + bool both_f_contig = (is_src_f_contig && is_dst_f_contig); + + if (both_c_contig || both_f_contig) { + auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid]; + + if (contig_fn == nullptr) { + throw std::runtime_error( + "Contiguous implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + auto comp_ev = contig_fn(q, nelems, py_nan, py_posinf, py_neginf, + src_data, dst_data, depends); + sycl::event ht_ev = + dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + + // simplify iteration space + // if 1d with strides 1 - input is contig + // dispatch to strided + + auto const &src_strides = src.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src_strides; + shT simplified_dst_strides; + py::ssize_t src_offset(0); + py::ssize_t dst_offset(0); + + int nd = src_nd; + const py::ssize_t *shape = src_shape; + + py_internal::simplify_iteration_space( + nd, shape, src_strides, dst_strides, + // output + simplified_shape, simplified_src_strides, simplified_dst_strides, + src_offset, dst_offset); + + if (nd == 1 && simplified_src_strides[0] == 1 && + simplified_dst_strides[0] == 1) { + // Special case of contiguous data + auto contig_fn = nan_to_num_contig_dispatch_vector[src_typeid]; + + if (contig_fn == nullptr) { + throw std::runtime_error( + "Contiguous implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + int src_elem_size = src.get_elemsize(); + int dst_elem_size = dst.get_elemsize(); + auto comp_ev = + contig_fn(q, nelems, py_nan, py_posinf, py_neginf, + src_data + src_elem_size * src_offset, + dst_data + 
dst_elem_size * dst_offset, depends); + + sycl::event ht_ev = + dpctl::utils::keep_args_alive(q, {src, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + + auto fn = nan_to_num_dispatch_vector[src_typeid]; + + if (fn == nullptr) { + throw std::runtime_error( + "nan_to_num implementation is missing for src_typeid=" + + std::to_string(src_typeid)); + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + std::vector host_tasks{}; + host_tasks.reserve(2); + + auto ptr_size_event_triple_ = device_allocate_and_pack( + q, host_tasks, simplified_shape, simplified_src_strides, + simplified_dst_strides); + auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_)); + const sycl::event ©_shape_ev = std::get<2>(ptr_size_event_triple_); + const py::ssize_t *shape_strides = shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_ev); + + sycl::event comp_ev = + fn(q, nelems, nd, shape_strides, py_nan, py_posinf, py_neginf, src_data, + src_offset, dst_data, dst_offset, all_deps); + + // async free of shape_strides temporary + sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + q, {comp_ev}, shape_strides_owner); + + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(q, {src, dst}, host_tasks), comp_ev); +} + +/** + * @brief A factory to define pairs of supported types for which + * nan-to-num function is available. + * + * @tparam T Type of input vector `a` and of result vector `y`. + */ +template +struct NanToNumOutputType +{ + using value_type = typename std::disjunction< + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry, + td_ns::TypeMapResultEntry>, + td_ns::TypeMapResultEntry>, + td_ns::DefaultResultEntry>::result_type; +}; + +template +struct NanToNumFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + return nullptr; + } + else { + return nan_to_num_strided_call; + } + } +}; + +template +struct NanToNumContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v::value_type, + void>) { + return nullptr; + } + else { + return nan_to_num_contig_call; + } + } +}; + +void populate_nan_to_num_dispatch_vectors(void) +{ + using namespace td_ns; + + DispatchVectorBuilder dvb1; + dvb1.populate_dispatch_vector(nan_to_num_dispatch_vector); + + DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(nan_to_num_contig_dispatch_vector); +} + +} // namespace impl + +void init_nan_to_num(py::module_ m) +{ + { + impl::populate_nan_to_num_dispatch_vectors(); + + using impl::py_nan_to_num; + m.def("_nan_to_num", &py_nan_to_num, "", py::arg("src"), + py::arg("py_nan"), py::arg("py_posinf"), py::arg("py_neginf"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp new file mode 100644 index 00000000000..26ac37bf1c4 --- /dev/null +++ b/dpnp/backend/extensions/ufunc/elementwise_functions/nan_to_num.hpp @@ -0,0 +1,35 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpnp::extensions::ufunc +{ +void init_nan_to_num(py::module_ m); +} // namespace dpnp::extensions::ufunc diff --git a/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp new file mode 100644 index 00000000000..c4219de63f4 --- /dev/null +++ b/dpnp/backend/kernels/elementwise_functions/nan_to_num.hpp @@ -0,0 +1,284 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include +#include +#include + +#include +// dpctl tensor headers +#include "kernels/alignment.hpp" +#include "kernels/dpctl_tensor_types.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpnp::kernels::nan_to_num +{ + +template +inline T to_num(const T v, const T nan, const T posinf, const T neginf) +{ + return (sycl::isnan(v)) ? nan + : (sycl::isinf(v)) ? (v > 0) ? posinf : neginf + : v; +} + +template +struct NanToNumFunctor +{ +private: + const T *inp_ = nullptr; + T *out_ = nullptr; + const InOutIndexerT inp_out_indexer_; + const scT nan_; + const scT posinf_; + const scT neginf_; + +public: + NanToNumFunctor(const T *inp, + T *out, + const InOutIndexerT &inp_out_indexer, + const scT nan, + const scT posinf, + const scT neginf) + : inp_(inp), out_(out), inp_out_indexer_(inp_out_indexer), nan_(nan), + posinf_(posinf), neginf_(neginf) + { + } + + void operator()(sycl::id<1> wid) const + { + const auto &offsets_ = inp_out_indexer_(wid.get(0)); + const dpctl::tensor::ssize_t &inp_offset = offsets_.get_first_offset(); + const dpctl::tensor::ssize_t &out_offset = offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::is_complex_v; + if constexpr (is_complex_v) { + using realT = typename T::value_type; + static_assert(std::is_same_v); + T z = inp_[inp_offset]; + realT x = to_num(z.real(), nan_, posinf_, neginf_); + realT y = to_num(z.imag(), nan_, posinf_, neginf_); + out_[out_offset] = T{x, y}; + } + else { + out_[out_offset] = to_num(inp_[inp_offset], nan_, posinf_, neginf_); + } + } +}; + +template +struct NanToNumContigFunctor +{ +private: + const T *in_ = nullptr; + T *out_ = nullptr; + std::size_t nelems_; + const scT nan_; + const scT posinf_; + const scT neginf_; + +public: + NanToNumContigFunctor(const T *in, + T *out, + const std::size_t n_elems, + const scT nan, + const scT posinf, + const scT neginf) + : in_(in), out_(out), nelems_(n_elems), nan_(nan), posinf_(posinf), + neginf_(neginf) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NOTE: work-group size must be divisible by sub-group size */ + + using dpctl::tensor::type_utils::is_complex_v; + if constexpr (enable_sg_loadstore && !is_complex_v) { + auto sg = ndit.get_sub_group(); + const std::uint16_t sgSize = sg.get_max_local_range()[0]; + const std::size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { + using dpctl::tensor::sycl_utils::sub_group_load; + using dpctl::tensor::sycl_utils::sub_group_store; +#pragma unroll + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const std::size_t offset = base + it * sgSize; + auto in_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&in_[offset]); + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&out_[offset]); + + sycl::vec arg_vec = + sub_group_load(sg, in_multi_ptr); +#pragma unroll + for (std::uint32_t k = 0; k < vec_sz; ++k) { + arg_vec[k] = to_num(arg_vec[k], nan_, posinf_, neginf_); + } + sub_group_store(sg, arg_vec, out_multi_ptr); + } + } + else { + const std::size_t lane_id = 
sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) { + out_[k] = to_num(in_[k], nan_, posinf_, neginf_); + } + } + } + else { + const std::uint16_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const std::size_t gid = ndit.get_global_linear_id(); + const std::uint16_t elems_per_sg = sgSize * elems_per_wi; + + const std::size_t start = + (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems_, start + elems_per_sg); + for (std::size_t offset = start; offset < end; offset += sgSize) { + if constexpr (is_complex_v) { + using realT = typename T::value_type; + static_assert(std::is_same_v); + + T z = in_[offset]; + realT x = to_num(z.real(), nan_, posinf_, neginf_); + realT y = to_num(z.imag(), nan_, posinf_, neginf_); + out_[offset] = T{x, y}; + } + else { + out_[offset] = to_num(in_[offset], nan_, posinf_, neginf_); + } + } + } + } +}; + +template +sycl::event nan_to_num_strided_impl(sycl::queue &q, + const size_t nelems, + const int nd, + const dpctl::tensor::ssize_t *shape_strides, + const scT nan, + const scT posinf, + const scT neginf, + const char *in_cp, + const dpctl::tensor::ssize_t in_offset, + char *out_cp, + const dpctl::tensor::ssize_t out_offset, + const std::vector &depends) +{ + dpctl::tensor::type_utils::validate_type_for_device(q); + + const T *in_tp = reinterpret_cast(in_cp); + T *out_tp = reinterpret_cast(out_cp); + + using InOutIndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + const InOutIndexerT indexer{nd, in_offset, out_offset, shape_strides}; + + sycl::event comp_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using NanToNumFunc = NanToNumFunctor; + cgh.parallel_for( + {nelems}, + NanToNumFunc(in_tp, out_tp, indexer, nan, posinf, neginf)); + }); + return comp_ev; +} + +template +sycl::event nan_to_num_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const scT nan, + const scT posinf, + const scT neginf, + const char *in_cp, + char *out_cp, + const std::vector &depends = {}) +{ + constexpr std::uint8_t elems_per_wi = n_vecs * vec_sz; + const std::size_t n_work_items_needed = nelems / elems_per_wi; + const std::size_t empirical_threshold = std::size_t(1) << 21; + const std::size_t lws = (n_work_items_needed <= empirical_threshold) + ? 
std::size_t(128) + : std::size_t(256); + + const std::size_t n_groups = + ((nelems + lws * elems_per_wi - 1) / (lws * elems_per_wi)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + const T *in_tp = reinterpret_cast(in_cp); + T *out_tp = reinterpret_cast(out_cp); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using dpctl::tensor::kernels::alignment_utils::is_aligned; + using dpctl::tensor::kernels::alignment_utils::required_alignment; + if (is_aligned(in_tp) && + is_aligned(out_tp)) + { + constexpr bool enable_sg_loadstore = true; + using NanToNumFunc = NanToNumContigFunctor; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf)); + } + else { + constexpr bool disable_sg_loadstore = false; + using NanToNumFunc = NanToNumContigFunctor; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + NanToNumFunc(in_tp, out_tp, nelems, nan, posinf, neginf)); + } + }); + + return comp_ev; +} + +} // namespace dpnp::kernels::nan_to_num diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index cf3d14b98de..8be5d6e28ce 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -3125,21 +3125,11 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): "nan must be a scalar of an integer, float, bool, " f"but got {type(nan)}" ) - - out = dpnp.empty_like(x) if copy else x x_type = x.dtype.type if not issubclass(x_type, dpnp.inexact): - return x + return dpnp.copy(x) if copy else dpnp.get_result_array(x) - parts = ( - (x.real, x.imag) if issubclass(x_type, dpnp.complexfloating) else (x,) - ) - parts_out = ( - (out.real, out.imag) - if issubclass(x_type, dpnp.complexfloating) - else (out,) - ) max_f, min_f = _get_max_min(x.real.dtype) if posinf is not None: if not isinstance(posinf, (int, float)): @@ -3156,16 +3146,26 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): ) min_f = neginf - for part, part_out in zip(parts, parts_out): - nan_mask = dpnp.isnan(part) - posinf_mask = dpnp.isposinf(part) - neginf_mask = dpnp.isneginf(part) + if copy: + out = dpnp.empty_like(x) + else: + if not x.flags.writable: + raise ValueError("copy is required for read-only array `x`") + out = x + + x_ary = dpnp.get_usm_ndarray(x) + out_ary = dpnp.get_usm_ndarray(out) + + q = x.sycl_queue + _manager = dpu.SequentialOrderManager[q] + + h_ev, comp_ev = ufi._nan_to_num( + x_ary, nan, max_f, min_f, out_ary, q, depends=_manager.submitted_events + ) - part = dpnp.where(nan_mask, nan, part, out=part_out) - part = dpnp.where(posinf_mask, max_f, part, out=part_out) - part = dpnp.where(neginf_mask, min_f, part, out=part_out) + _manager.add_event_pair(h_ev, comp_ev) - return out + return dpnp.get_result_array(out) _NEGATIVE_DOCSTRING = """ diff --git a/dpnp/tests/test_mathematical.py b/dpnp/tests/test_mathematical.py index 059cce931d2..75c18075d3c 100644 --- a/dpnp/tests/test_mathematical.py +++ b/dpnp/tests/test_mathematical.py @@ -1558,6 +1558,27 @@ def test_errors_diff_types(self, kwarg, value): with pytest.raises(TypeError): dpnp.nan_to_num(ia, **{kwarg: value}) + def test_error_readonly(self): + a = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf]) + a.flags.writable = False + with pytest.raises(ValueError): + dpnp.nan_to_num(a, copy=False) + + @pytest.mark.parametrize("copy", [True, False]) + @pytest.mark.parametrize("dt", get_all_dtypes(no_bool=True, no_none=True)) + 
def test_nan_to_num_strided(self, copy, dt): + n = 10 + dt = numpy.dtype(dt) + np_a = numpy.arange(2 * n, dtype=dt) + dp_a = dpnp.arange(2 * n, dtype=dt) + if dt.kind in "fc": + np_a[::4] = numpy.nan + dp_a[::4] = dpnp.nan + dp_r = dpnp.nan_to_num(dp_a[::-2], copy=copy, nan=57.0) + np_r = numpy.nan_to_num(np_a[::-2], copy=copy, nan=57.0) + + assert_dtype_allclose(dp_r, np_r) + class TestProd: @pytest.mark.parametrize("axis", [None, 0, 1, -1, 2, -2, (1, 2), (0, -2)])
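
Note on the replacement rule: both the strided and contiguous kernels apply the same per-element mapping through the `to_num` helper, and complex values are handled component-wise by `NanToNumFunctor`. A minimal Python mirror of that rule for reference (illustrative only, not the backend code; the helper names below are hypothetical):

    import math

    def to_num(v, nan, posinf, neginf):
        # NaN -> nan, +inf -> posinf, -inf -> neginf; finite values pass through
        if math.isnan(v):
            return nan
        if math.isinf(v):
            return posinf if v > 0 else neginf
        return v

    def to_num_complex(z, nan, posinf, neginf):
        # complex inputs have their real and imaginary parts mapped independently
        return complex(to_num(z.real, nan, posinf, neginf),
                       to_num(z.imag, nan, posinf, neginf))

    assert to_num(float("nan"), 0.0, 1e300, -1e300) == 0.0
    assert to_num(float("-inf"), 0.0, 1e300, -1e300) == -1e300
    assert to_num_complex(complex(float("inf"), 2.0), 0.0, 1e300, -1e300) == complex(1e300, 2.0)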
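
Usage sketch of the updated Python API (illustrative; by default +inf and -inf are replaced with the largest and lowest finite values of the array's real dtype, as in NumPy):

    import dpnp

    a = dpnp.array([dpnp.nan, dpnp.inf, -dpnp.inf, 1.5])

    # defaults: nan -> 0.0, +inf -> largest finite value, -inf -> lowest finite value
    b = dpnp.nan_to_num(a)

    # explicit replacement values, written in place into `a` (copy=False)
    dpnp.nan_to_num(a, copy=False, nan=57.0, posinf=1e10, neginf=-1e10)

    # copy=False on a read-only array now raises ValueError (see test_error_readonly)
    a.flags.writable = False
    try:
        dpnp.nan_to_num(a, copy=False)
    except ValueError:
        pass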