From c7770fd63036b810849ec7807fa20f63e2821f1e Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Fri, 2 Feb 2024 14:21:14 +0100 Subject: [PATCH] Update dpnp.linalg.svd() function (#1604) * Draft commit of dpnp.linalg.svd impl * Pass empty arrays if compute_uv=False * Add logic for the input array n < m * Add a new cupy test_decomposition * Rename gesvd input parameters * Correspondence of passed parameters to gesvd signature * Correct initialization of result variables in dpnp_svd * Update test_decomposition * Add implementation of _dpnp_svd_batch * Add test_decomposition to the scope of public CI * Improve error handling for mkl_lapack::gesvd function * Declare detail variable * Use a_usm_type and a_sycl_queue variables * Add additional checks for gesvd function * Remove old dpnp_svd backend * Refresh test_svd in test_linalg * Add detailed comments for gesvd arguments * gesvd returns pair of events and uses dpctl.utils.keep_args_alive * Keep a lexicographical order * Update docstrings for svd * Add test_svd to test_usm_type * Add a new impl to get s_type * Add a description for _stacked_identity * Simplify dpnp_svd_batch * Update tests for dpnp.linalg.svd * Add hermitian argument support * Add test_svd_hermitian * Update svd docstrings * Tune tolerance * Update test_svd_errors * Update _common_type and _common_inexact_type * Remove passing n and m parameters to _gesvd * Simplify results return logic for dpnp_svd_batch * Update condition and random files in cupy/testing to use fix_random and repeat decorators * Rename cupy/testing/condition.py to .../_condition.py * Use self._tol in TestSvd * Update gesvd error handler * dpnp_svd works with F contiguous arrays * Add additional checks for output arrays * Impl parallel calculation in dpnp_svd_batch * Skip using @_condition.repeat in cupy tests * Add additional checks for output arrays * Update docstrings for svd * Use dpctl.SyclEvent.wait_for in dpnp_svd_batch * Add TODO: matching the order of returned arrays * Skip cupy tests on windows * Rename condition to _condition * Set setUpClass to skip cupy tests on cpu --- dpnp/backend/extensions/lapack/CMakeLists.txt | 1 + dpnp/backend/extensions/lapack/gesvd.cpp | 359 +++++++++++++ dpnp/backend/extensions/lapack/gesvd.hpp | 55 ++ dpnp/backend/extensions/lapack/lapack_py.cpp | 9 + .../extensions/lapack/types_matrix.hpp | 22 + dpnp/backend/include/dpnp_iface_fptr.hpp | 2 - dpnp/backend/kernels/dpnp_krnl_linalg.cpp | 44 -- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/linalg/dpnp_algo_linalg.pyx | 55 -- dpnp/linalg/dpnp_iface_linalg.py | 74 ++- dpnp/linalg/dpnp_utils_linalg.py | 481 +++++++++++++++--- tests/test_linalg.py | 205 +++++--- tests/test_sycl_queue.py | 91 ++-- tests/test_usm_type.py | 50 ++ .../cupy/linalg_tests/test_decomposition.py | 250 +++++++++ .../cupy/linalg_tests/test_solve.py | 4 +- .../cupy/random_tests/test_sample.py | 24 +- tests/third_party/cupy/testing/__init__.py | 4 +- .../testing/{condition.py => _condition.py} | 2 +- tests/third_party/cupy/testing/random.py | 17 +- 20 files changed, 1425 insertions(+), 326 deletions(-) create mode 100644 dpnp/backend/extensions/lapack/gesvd.cpp create mode 100644 dpnp/backend/extensions/lapack/gesvd.hpp rename tests/third_party/cupy/testing/{condition.py => _condition.py} (98%) diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index 626615e3e53..28fa2072d7d 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -28,6 +28,7
@@ set(python_module_name _lapack_impl) set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/lapack_py.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gesv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/gesvd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp diff --git a/dpnp/backend/extensions/lapack/gesvd.cpp b/dpnp/backend/extensions/lapack/gesvd.cpp new file mode 100644 index 00000000000..27734f4492b --- /dev/null +++ b/dpnp/backend/extensions/lapack/gesvd.cpp @@ -0,0 +1,359 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#include <pybind11/pybind11.h> + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "gesvd.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*gesvd_impl_fn_ptr_t)(sycl::queue, + const oneapi::mkl::jobsvd, + const oneapi::mkl::jobsvd, + const std::int64_t, + const std::int64_t, + char *, + const std::int64_t, + char *, + char *, + const std::int64_t, + char *, + const std::int64_t, + std::vector<sycl::event> &, + const std::vector<sycl::event> &); + +static gesvd_impl_fn_ptr_t gesvd_dispatch_table[dpctl_td_ns::num_types] + [dpctl_td_ns::num_types]; + +// Converts a given character code (ord) to the corresponding +// oneapi::mkl::jobsvd enumeration value +static oneapi::mkl::jobsvd process_job(std::int8_t job_val) +{ + switch (job_val) { + case 'A': + return oneapi::mkl::jobsvd::vectors; + case 'S': + return oneapi::mkl::jobsvd::somevec; + case 'O': + return oneapi::mkl::jobsvd::vectorsina; + case 'N': + return oneapi::mkl::jobsvd::novec; + default: + throw std::invalid_argument("Unknown value for job"); + } +} + +template <typename T, typename RealT> +static sycl::event gesvd_impl(sycl::queue exec_q, + const oneapi::mkl::jobsvd jobu, + const oneapi::mkl::jobsvd jobvt, + const std::int64_t m, + const std::int64_t n, + char *in_a, + const std::int64_t lda, + char *out_s, + char *out_u, + const std::int64_t ldu, + char *out_vt, + const std::int64_t ldvt, + std::vector<sycl::event> &host_task_events, + const std::vector<sycl::event> &depends) +{ + type_utils::validate_type_for_device<T>(exec_q); + type_utils::validate_type_for_device<RealT>(exec_q); + + T *a = reinterpret_cast<T *>(in_a); + RealT *s = reinterpret_cast<RealT *>(out_s); + T *u = reinterpret_cast<T *>(out_u); + T *vt = reinterpret_cast<T *>(out_vt); + + const std::int64_t scratchpad_size = mkl_lapack::gesvd_scratchpad_size<T>( + exec_q, jobu, jobvt, m, n, lda, ldu, ldvt); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event gesvd_event; + try { + scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q); + + gesvd_event = mkl_lapack::gesvd( + exec_q, + jobu, // Character specifying how to compute the matrix U: + // 'A' computes all columns of U, + // 'S' computes the first min(m,n) columns of U, + // 'O' overwrites A with the columns of U, + // 'N' does not compute U. + jobvt, // Character specifying how to compute the matrix VT: + // 'A' computes all rows of VT, + // 'S' computes the first min(m,n) rows of VT, + // 'O' overwrites A with the rows of VT, + // 'N' does not compute VT. + m, // The number of rows in the input matrix A (0 <= m). + n, // The number of columns in the input matrix A (0 <= n). + a, // Pointer to the input matrix A of size (m x n). + lda, // The leading dimension of A, must be at least max(1, m). + s, // Pointer to the array containing the singular values. + u, // Pointer to the matrix U in the singular value decomposition. + ldu, // The leading dimension of U, must be at least max(1, m). + vt, // Pointer to the matrix VT in the singular value decomposition. + ldvt, // The leading dimension of VT, must be at least max(1, n). + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results.
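+ // The remaining arguments are the scratchpad size in elements of T, + // as computed by gesvd_scratchpad_size above, and the list of events + // the routine must wait on before executing.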
+ scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else if (info > 0) { + error_msg << "The algorithm computing SVD failed to converge; " + << info << " off-diagonal elements of an intermediate " + << "bidiagonal form did not converge to zero.\n"; + } + else { + error_msg << "Unexpected MKL exception caught during gesvd() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during gesvd() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(gesvd_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return gesvd_event; +} + +std::pair<sycl::event, sycl::event> + gesvd(sycl::queue exec_q, + const std::int8_t jobu_val, + const std::int8_t jobvt_val, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray out_s, + dpctl::tensor::usm_ndarray out_u, + dpctl::tensor::usm_ndarray out_vt, + const std::vector<sycl::event> &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int out_u_array_nd = out_u.get_ndim(); + const int out_s_array_nd = out_s.get_ndim(); + const int out_vt_array_nd = out_vt.get_ndim(); + + if (a_array_nd != 2) { + throw py::value_error( + "The input array has ndim=" + std::to_string(a_array_nd) + + ", but a 2-dimensional array is expected."); + } + + if (out_s_array_nd != 1) { + throw py::value_error("The output array of singular values has ndim=" + + std::to_string(out_s_array_nd) + + ", but a 1-dimensional array is expected."); + } + + if (jobu_val == 'N' && jobvt_val == 'N') { + if (out_u_array_nd != 0) { + throw py::value_error( + "The output array of the left singular vectors has ndim=" + + std::to_string(out_u_array_nd) + + ", but it is not used and should have ndim=0."); + } + if (out_vt_array_nd != 0) { + throw py::value_error( + "The output array of the right singular vectors has ndim=" + + std::to_string(out_vt_array_nd) + + ", but it is not used and should have ndim=0."); + } + } + else { + if (out_u_array_nd != 2) { + throw py::value_error( + "The output array of the left singular vectors has ndim=" + + std::to_string(out_u_array_nd) + + ", but a 2-dimensional array is expected."); + } + if (out_vt_array_nd != 2) { + throw py::value_error( + "The output array of the right singular vectors has ndim=" + + std::to_string(out_vt_array_nd) + + ", but a 2-dimensional array is expected."); + } + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible( + exec_q, {a_array.get_queue(), out_s.get_queue(), out_u.get_queue(), + out_vt.get_queue()})) + { + throw std::runtime_error( + "USM allocations are not compatible with the execution queue."); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, out_s) || overlap(a_array, out_u) || + overlap(a_array, out_vt) || overlap(out_s, out_u) || +
overlap(out_s, out_vt) || overlap(out_u, out_vt)) + { + throw py::value_error("Arrays have overlapping segments of memory"); + } + + bool is_a_array_f_contig = a_array.is_f_contiguous(); + if (!is_a_array_f_contig) { + throw py::value_error("The input array must be F-contiguous"); + } + + bool is_out_u_array_f_contig = out_u.is_f_contiguous(); + bool is_out_vt_array_f_contig = out_vt.is_f_contiguous(); + + if (!is_out_u_array_f_contig || !is_out_vt_array_f_contig) { + throw py::value_error("The output arrays of the left and right " + "singular vectors must be F-contiguous"); + } + + bool is_out_s_array_c_contig = out_s.is_c_contiguous(); + bool is_out_s_array_f_contig = out_s.is_f_contiguous(); + + if (!is_out_s_array_c_contig || !is_out_s_array_f_contig) { + throw py::value_error("The output array of singular values " + "must be contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int out_u_type_id = array_types.typenum_to_lookup_id(out_u.get_typenum()); + int out_s_type_id = array_types.typenum_to_lookup_id(out_s.get_typenum()); + int out_vt_type_id = array_types.typenum_to_lookup_id(out_vt.get_typenum()); + + if (a_array_type_id != out_u_type_id || a_array_type_id != out_vt_type_id) { + throw py::type_error( + "Input array, output left singular vectors array, " + "and output right singular vectors array must have " + "the same data type"); + } + + gesvd_impl_fn_ptr_t gesvd_fn = + gesvd_dispatch_table[a_array_type_id][out_s_type_id]; + if (gesvd_fn == nullptr) { + throw py::value_error( + "No gesvd implementation is defined for the given pair " + "of array type and output singular values type."); + } + + char *a_array_data = a_array.get_data(); + char *out_s_data = out_s.get_data(); + char *out_u_data = out_u.get_data(); + char *out_vt_data = out_vt.get_data(); + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + const std::int64_t m = a_array_shape[0]; + const std::int64_t n = a_array_shape[1]; + + const std::int64_t lda = std::max<size_t>(1UL, m); + const std::int64_t ldu = std::max<size_t>(1UL, m); + const std::int64_t ldvt = + std::max<size_t>(1UL, jobvt_val == 'S' ? (m > n ? n : m) : n); + + const oneapi::mkl::jobsvd jobu = process_job(jobu_val); + const oneapi::mkl::jobsvd jobvt = process_job(jobvt_val); + + std::vector<sycl::event> host_task_events; + sycl::event gesvd_ev = + gesvd_fn(exec_q, jobu, jobvt, m, n, a_array_data, lda, out_s_data, + out_u_data, ldu, out_vt_data, ldvt, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive( + exec_q, {a_array, out_s, out_u, out_vt}, host_task_events); + + return std::make_pair(args_ev, gesvd_ev); +} + +template <typename fnT, typename T, typename RealT> +struct GesvdContigFactory +{ + fnT get() + { + if constexpr (types::GesvdTypePairSupportFactory<T, RealT>::is_defined) + { + return gesvd_impl<T, RealT>; + } + else { + return nullptr; + } + } +}; + +void init_gesvd_dispatch_table(void) +{ + dpctl_td_ns::DispatchTableBuilder<gesvd_impl_fn_ptr_t, GesvdContigFactory, + dpctl_td_ns::num_types> + contig; + contig.populate_dispatch_table(gesvd_dispatch_table); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/gesvd.hpp b/dpnp/backend/extensions/lapack/gesvd.hpp new file mode 100644 index 00000000000..17ebd0edbe7 --- /dev/null +++ b/dpnp/backend/extensions/lapack/gesvd.hpp @@ -0,0 +1,55 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved.
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include <CL/sycl.hpp> +#include <oneapi/mkl.hpp> + +#include <dpctl4pybind11.hpp> + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern std::pair<sycl::event, sycl::event> + gesvd(sycl::queue exec_q, + const std::int8_t jobu_val, + const std::int8_t jobvt_val, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray out_s, + dpctl::tensor::usm_ndarray out_u, + dpctl::tensor::usm_ndarray out_vt, + const std::vector<sycl::event> &depends); + +extern void init_gesvd_dispatch_table(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 71991be3652..0c76d0fc096 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -31,6 +31,7 @@ #include #include "gesv.hpp" +#include "gesvd.hpp" #include "getrf.hpp" #include "getri.hpp" #include "heevd.hpp" @@ -56,6 +57,7 @@ void init_dispatch_vectors(void) // populate dispatch tables void init_dispatch_tables(void) { + lapack_ext::init_gesvd_dispatch_table(); lapack_ext::init_heevd_dispatch_table(); } @@ -76,6 +78,13 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("sycl_queue"), py::arg("coeff_matrix"), py::arg("dependent_vals"), py::arg("depends") = py::list()); + m.def("_gesvd", &lapack_ext::gesvd, + "Call `gesvd` from OneMKL LAPACK library to return " + "the singular value decomposition of a general rectangular matrix", + py::arg("sycl_queue"), py::arg("jobu_val"), py::arg("jobvt_val"), + py::arg("a_array"), py::arg("res_s"), py::arg("res_u"), + py::arg("res_vt"), py::arg("depends") = py::list()); + m.def("_getrf", &lapack_ext::getrf, "Call `getrf` from OneMKL LAPACK library to return " "the LU factorization of a general n x n matrix", diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 7e5413b84c8..893619e6afb 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -70,6 +70,28 @@ struct GesvTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to
define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::gesvd + * function. + * + * @tparam T Type of array containing input matrix A and output matrices U and + * VT of singular vectors. + * @tparam RealT Type of output array containing singular values of A. + */ +template <typename T, typename RealT> +struct GesvdTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry<T, float, RealT, float>, + dpctl_td_ns::TypePairDefinedEntry<T, double, RealT, double>, + dpctl_td_ns::TypePairDefinedEntry<T, std::complex<float>, RealT, float>, + dpctl_td_ns:: + TypePairDefinedEntry<T, std::complex<double>, RealT, double>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::getrf diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 2e2ce5ab144..3061bb01f29 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -363,8 +363,6 @@ enum class DPNPFuncName : size_t parameters */ DPNP_FN_SUM, /**< Used in numpy.sum() impl */ DPNP_FN_SVD, /**< Used in numpy.linalg.svd() impl */ - DPNP_FN_SVD_EXT, /**< Used in numpy.linalg.svd() impl, requires extra - parameters */ DPNP_FN_TAKE, /**< Used in numpy.take() impl */ DPNP_FN_TAN, /**< Used in numpy.tan() impl */ DPNP_FN_TANH, /**< Used in numpy.tanh() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp index e0b6de5b1b6..610da8fda3c 100644 --- a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp @@ -824,17 +824,6 @@ template <typename _InputDT, typename _ComputeDT, typename _SVDT> void (*dpnp_svd_default_c)(void *, void *, void *, void *, size_t, size_t) = dpnp_svd_c<_InputDT, _ComputeDT, _SVDT>; -template <typename _InputDT, typename _ComputeDT, typename _SVDT> -DPCTLSyclEventRef (*dpnp_svd_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - void *, - void *, - size_t, - size_t, - const DPCTLEventVectorRef) = - dpnp_svd_c<_InputDT, _ComputeDT, _SVDT>; - void func_map_init_linalg_func(func_map_t &fmap) { fmap[DPNPFuncName::DPNP_FN_CHOLESKY][eft_FLT][eft_FLT] = { @@ -1046,38 +1035,5 @@ void func_map_init_linalg_func(func_map_t &fmap) eft_C128, (void *)dpnp_svd_default_c<std::complex<double>, std::complex<double>, double>}; - fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_INT][eft_INT] = { - get_default_floating_type(), - (void *)dpnp_svd_ext_c< - int32_t, func_type_map_t::find_type<get_default_floating_type()>, - func_type_map_t::find_type<get_default_floating_type()>>, - get_default_floating_type<std::false_type>(), - (void *) - dpnp_svd_ext_c<int32_t, - func_type_map_t::find_type< - get_default_floating_type<std::false_type>()>, - func_type_map_t::find_type< - get_default_floating_type<std::false_type>()>>}; - fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_LNG][eft_LNG] = { - get_default_floating_type(), - (void *)dpnp_svd_ext_c< - int64_t, func_type_map_t::find_type<get_default_floating_type()>, - func_type_map_t::find_type<get_default_floating_type()>>, - get_default_floating_type<std::false_type>(), - (void *) - dpnp_svd_ext_c<int64_t, - func_type_map_t::find_type< - get_default_floating_type<std::false_type>()>, - func_type_map_t::find_type< - get_default_floating_type<std::false_type>()>>}; - fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_svd_ext_c<float, float, float>}; - fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_svd_ext_c<double, double, double>}; - fmap[DPNPFuncName::DPNP_FN_SVD_EXT][eft_C128][eft_C128] = { - eft_C128, - (void *) - dpnp_svd_ext_c<std::complex<double>, std::complex<double>, double>}; - return; } diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 895b393aeff..28e21340647 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -171,8 +171,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_RNG_ZIPF_EXT DPNP_FN_SEARCHSORTED
DPNP_FN_SEARCHSORTED_EXT - DPNP_FN_SVD - DPNP_FN_SVD_EXT DPNP_FN_TRACE DPNP_FN_TRACE_EXT DPNP_FN_TRANSPOSE diff --git a/dpnp/linalg/dpnp_algo_linalg.pyx b/dpnp/linalg/dpnp_algo_linalg.pyx index 1d94a893fff..3bf6dad3ee8 100644 --- a/dpnp/linalg/dpnp_algo_linalg.pyx +++ b/dpnp/linalg/dpnp_algo_linalg.pyx @@ -51,7 +51,6 @@ __all__ = [ "dpnp_matrix_rank", "dpnp_norm", "dpnp_qr", - "dpnp_svd", ] @@ -379,57 +378,3 @@ cpdef tuple dpnp_qr(utils.dpnp_descriptor x1, str mode): c_dpctl.DPCTLEvent_Delete(event_ref) return (res_q.get_pyobj(), res_r.get_pyobj()) - - -cpdef tuple dpnp_svd(utils.dpnp_descriptor x1, cpp_bool full_matrices, cpp_bool compute_uv, cpp_bool hermitian): - cdef size_t size_m = x1.shape[0] - cdef size_t size_n = x1.shape[1] - cdef size_t size_s = min(size_m, size_n) - - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SVD_EXT, param1_type, param1_type) - - x1_obj = x1.get_array() - - cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data, - x1_obj.sycl_device.has_aspect_fp64) - cdef DPNPFuncType return_type = ret_type_and_func[0] - cdef custom_linalg_1in_3out_shape_t func = < custom_linalg_1in_3out_shape_t > ret_type_and_func[1] - - cdef utils.dpnp_descriptor res_u = utils.create_output_descriptor((size_m, size_m), - return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - cdef utils.dpnp_descriptor res_s = utils.create_output_descriptor((size_s, ), - return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - cdef utils.dpnp_descriptor res_vt = utils.create_output_descriptor((size_n, size_n), - return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - - result_sycl_queue = res_u.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - x1.get_data(), - res_u.get_data(), - res_s.get_data(), - res_vt.get_data(), - size_m, - size_n, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return (res_u.get_pyobj(), res_s.get_pyobj(), res_vt.get_pyobj()) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 800aa8de1bb..2b8506130ad 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -53,6 +53,7 @@ dpnp_inv, dpnp_slogdet, dpnp_solve, + dpnp_svd, ) __all__ = [ @@ -611,12 +612,47 @@ def solve(a, b): return dpnp_solve(a, b) -def svd(x1, full_matrices=True, compute_uv=True, hermitian=False): +def svd(a, full_matrices=True, compute_uv=True, hermitian=False): """ Singular Value Decomposition. For full documentation refer to :obj:`numpy.linalg.svd`. + Parameters + ---------- + a : (..., M, N) {dpnp.ndarray, usm_ndarray} + Input array with ``a.ndim >= 2``. + full_matrices : bool, optional + If ``True``, `u` and `Vh` are returned with full-sized matrices. + If ``False``, the matrices are reduced in size. + Default: ``True``. + compute_uv : bool, optional + If ``False``, only the singular values are returned. + Default: ``True``. + hermitian : bool, optional + If ``True``, `a` is assumed to be Hermitian (symmetric if real-valued), + enabling a more efficient method for finding singular values. + Default: ``False``.
+ + Returns + ------- + u : { (…, M, M), (…, M, K) } dpnp.ndarray + Unitary matrix, where M is the number of rows of the input array `a`. + The shape of the matrix `u` depends on the value of `full_matrices`. + If `full_matrices` is ``True``, `u` has the shape (…, M, M). + If `full_matrices` is ``False``, `u` has the shape (…, M, K), + where K = min(M, N), and N is the number of columns of the input array `a`. + If `compute_uv` is ``False``, neither `u` nor `Vh` is computed. + s : (…, K) dpnp.ndarray + Vector containing the singular values of `a`, sorted in descending order. + The length of `s` is min(M, N). + Vh : { (…, N, N), (…, K, N) } dpnp.ndarray + Unitary matrix, where N is the number of columns of the input array `a`. + The shape of the matrix `Vh` depends on the value of `full_matrices`. + If `full_matrices` is ``True``, `Vh` has the shape (…, N, N). + If `full_matrices` is ``False``, `Vh` has the shape (…, K, N). + If `compute_uv` is ``False``, neither `u` nor `Vh` is computed. + Examples -------- >>> import dpnp as np @@ -629,11 +665,11 @@ def svd(x1, full_matrices=True, compute_uv=True, hermitian=False): >>> u.shape, s.shape, vh.shape ((9, 9), (6,), (6, 6)) >>> np.allclose(a, np.dot(u[:, :6] * s, vh)) - True + array([ True]) >>> smat = np.zeros((9, 6), dtype=complex) >>> smat[:6, :6] = np.diag(s) >>> np.allclose(a, np.dot(u, np.dot(smat, vh))) - True + array([ True]) Reconstruction based on reduced SVD, 2D case: @@ -641,10 +677,10 @@ def svd(x1, full_matrices=True, compute_uv=True, hermitian=False): >>> u.shape, s.shape, vh.shape ((9, 6), (6,), (6, 6)) >>> np.allclose(a, np.dot(u * s, vh)) - True + array([ True]) >>> smat = np.diag(s) >>> np.allclose(a, np.dot(u, np.dot(smat, vh))) - True + array([ True]) Reconstruction based on full SVD, 4D case: @@ -652,9 +688,9 @@ def svd(x1, full_matrices=True, compute_uv=True, hermitian=False): >>> u.shape, s.shape, vh.shape ((2, 7, 8, 8), (2, 7, 3), (2, 7, 3, 3)) >>> np.allclose(b, np.matmul(u[..., :3] * s[..., None, :], vh)) - True + array([ True]) >>> np.allclose(b, np.matmul(u[..., :3], s[..., None] * vh)) - True + array([ True]) Reconstruction based on reduced SVD, 4D case: @@ -662,30 +698,16 @@ def svd(x1, full_matrices=True, compute_uv=True, hermitian=False): >>> u.shape, s.shape, vh.shape ((2, 7, 8, 3), (2, 7, 3), (2, 7, 3, 3)) >>> np.allclose(b, np.matmul(u * s[..., None, :], vh)) - True + array([ True]) >>> np.allclose(b, np.matmul(u, s[..., None] * vh)) - True + array([ True]) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if not x1_desc.ndim == 2: - pass - elif full_matrices is not True: - pass - elif compute_uv is not True: - pass - elif hermitian is not False: - pass - else: - result_tup = dpnp_svd(x1_desc, full_matrices, compute_uv, hermitian) - - return result_tup + dpnp.check_supported_arrays_type(a) + check_stacked_2d(a) - return call_origin( - numpy.linalg.svd, x1, full_matrices, compute_uv, hermitian - ) + return dpnp_svd(a, full_matrices, compute_uv, hermitian) def slogdet(a): diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index f2632b5b6a4..93f41883133 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -24,8 +24,9 @@ # ***************************************************************************** +import dpctl import dpctl.tensor._tensor_impl as ti -from numpy import issubdtype +from numpy import prod import dpnp import dpnp.backend.extensions.lapack._lapack_impl as li @@ -40,6 +41,7 @@ "dpnp_inv",
"dpnp_slogdet", "dpnp_solve", + "dpnp_svd", ] _jobz = {"N": 0, "V": 1} @@ -147,76 +149,6 @@ def _real_type(dtype, device=None): return dpnp.dtype(real_type) -def check_stacked_2d(*arrays): - """ - Return ``True`` if each array in `arrays` has at least two dimensions. - - If any array is less than two-dimensional, `dpnp.linalg.LinAlgError` will be raised. - - Parameters - ---------- - arrays : {dpnp.ndarray, usm_ndarray} - A sequence of input arrays to check for dimensionality. - - Returns - ------- - out : bool - ``True`` if each array in `arrays` is at least two-dimensional. - - Raises - ------ - dpnp.linalg.LinAlgError - If any array in `arrays` is less than two-dimensional. - - """ - - for a in arrays: - if a.ndim < 2: - raise dpnp.linalg.LinAlgError( - f"{a.ndim}-dimensional array given. The input " - "array must be at least two-dimensional" - ) - - -def check_stacked_square(*arrays): - """ - Return ``True`` if each array in `arrays` is a square matrix. - - If any array does not form a square matrix, `dpnp.linalg.LinAlgError` will be raised. - - Precondition: `arrays` are at least 2d. The caller should assert it - beforehand. For example, - - >>> def solve(a): - ... check_stacked_2d(a) - ... check_stacked_square(a) - ... ... - - Parameters - ---------- - arrays : {dpnp.ndarray, usm_ndarray} - A sequence of input arrays to check for square matrix shape. - - Returns - ------- - out : bool - ``True`` if each array in `arrays` forms a square matrix. - - Raises - ------ - dpnp.linalg.LinAlgError - If any array in `arrays` does not form a square matrix. - - """ - - for a in arrays: - m, n = a.shape[-2:] - if m != n: - raise dpnp.linalg.LinAlgError( - "Last 2 dimensions of the input array must be square" - ) - - def _common_type(*arrays): """ Common type for linear algebra operations. @@ -245,7 +177,8 @@ def _common_type(*arrays): dtypes = [arr.dtype for arr in arrays] - default = dpnp.default_float_type(device=arrays[0].device) + _, sycl_queue = get_usm_allocations(arrays) + default = dpnp.default_float_type(sycl_queue=sycl_queue) dtype_common = _common_inexact_type(default, *dtypes) return dtype_common @@ -275,7 +208,8 @@ def _common_inexact_type(default_dtype, *dtypes): """ inexact_dtypes = [ - dt if issubdtype(dt, dpnp.inexact) else default_dtype for dt in dtypes + dt if dpnp.issubdtype(dt, dpnp.inexact) else default_dtype + for dt in dtypes ] return dpnp.result_type(*inexact_dtypes) @@ -469,6 +403,120 @@ def _lu_factor(a, res_type): return (a_h, ipiv_h, dev_info_array) +def _stacked_identity( + batch_shape, n, dtype, usm_type="device", sycl_queue=None +): + """ + Create stacked identity matrices of size `n x n`. + + Forms multiple identity matrices based on `batch_shape`. + + Parameters + ---------- + batch_shape : tuple + Shape of the batch determining the stacking of identity matrices. + n : int + Dimension of each identity matrix. + dtype : dtype + Data type of the matrix element. + usm_type : {"device", "shared", "host"}, optional + The type of SYCL USM allocation for the output array. + sycl_queue : {None, SyclQueue}, optional + A SYCL queue to use for output array allocation and copying. + + Returns + ------- + out : dpnp.ndarray + Array of stacked `n x n` identity matrices as per `batch_shape`. 
+ + Example + ------- + >>> _stacked_identity((2,), 2, dtype=dpnp.int64) + array([[[1, 0], + [0, 1]], + + [[1, 0], + [0, 1]]]) + + """ + + shape = batch_shape + (n, n) + idx = dpnp.arange(n, usm_type=usm_type, sycl_queue=sycl_queue) + x = dpnp.zeros(shape, dtype=dtype, usm_type=usm_type, sycl_queue=sycl_queue) + x[..., idx, idx] = 1 + return x + + +def check_stacked_2d(*arrays): + """ + Return ``True`` if each array in `arrays` has at least two dimensions. + + If any array is less than two-dimensional, `dpnp.linalg.LinAlgError` will be raised. + + Parameters + ---------- + arrays : {dpnp.ndarray, usm_ndarray} + A sequence of input arrays to check for dimensionality. + + Returns + ------- + out : bool + ``True`` if each array in `arrays` is at least two-dimensional. + + Raises + ------ + dpnp.linalg.LinAlgError + If any array in `arrays` is less than two-dimensional. + + """ + + for a in arrays: + if a.ndim < 2: + raise dpnp.linalg.LinAlgError( + f"{a.ndim}-dimensional array given. The input " + "array must be at least two-dimensional" + ) + + +def check_stacked_square(*arrays): + """ + Return ``True`` if each array in `arrays` is a square matrix. + + If any array does not form a square matrix, `dpnp.linalg.LinAlgError` will be raised. + + Precondition: `arrays` are at least 2d. The caller should assert it + beforehand. For example, + + >>> def solve(a): + ... check_stacked_2d(a) + ... check_stacked_square(a) + ... ... + + Parameters + ---------- + arrays : {dpnp.ndarray, usm_ndarray} + A sequence of input arrays to check for square matrix shape. + + Returns + ------- + out : bool + ``True`` if each array in `arrays` forms a square matrix. + + Raises + ------ + dpnp.linalg.LinAlgError + If any array in `arrays` does not form a square matrix. + + """ + + for a in arrays: + m, n = a.shape[-2:] + if m != n: + raise dpnp.linalg.LinAlgError( + "Last 2 dimensions of the input array must be square" + ) + + def dpnp_cholesky_batch(a, upper_lower, res_type): """ dpnp_cholesky_batch(a, upper_lower, res_type) @@ -1088,3 +1136,290 @@ def dpnp_slogdet(a): dpnp.where(singular, res_type.type(0), sign).reshape(shape), dpnp.where(singular, logdet_dtype.type("-inf"), logdet).reshape(shape), ) + + +def dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True): + """ + dpnp_svd_batch(a, uv_type, s_type, full_matrices=True, compute_uv=True) + + Return the batched singular value decomposition (SVD) of a stack of matrices. 
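+ + Each matrix in the batch is factorized by an independent gesvd call + (dpnp_svd with batch_call=True); the calls are submitted without + intermediate synchronization and their host tasks are awaited together + via dpctl.SyclEvent.wait_for, so the per-matrix factorizations may + run concurrently.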
+ + """ + + a_usm_type = a.usm_type + a_sycl_queue = a.sycl_queue + reshape = False + batch_shape_orig = a.shape[:-2] + + if a.ndim > 3: + # get 3d input arrays by reshape + a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1]) + reshape = True + + batch_size = a.shape[0] + m, n = a.shape[-2:] + + if batch_size == 0: + k = min(m, n) + s = dpnp.empty( + batch_shape_orig + (k,), + dtype=s_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + if compute_uv: + if full_matrices: + u_shape = batch_shape_orig + (m, m) + vt_shape = batch_shape_orig + (n, n) + else: + u_shape = batch_shape_orig + (m, k) + vt_shape = batch_shape_orig + (k, n) + + u = dpnp.empty( + u_shape, + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + vt = dpnp.empty( + vt_shape, + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return u, s, vt + else: + return s + elif m == 0 or n == 0: + s = dpnp.empty( + batch_shape_orig + (0,), + dtype=s_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + if compute_uv: + if full_matrices: + u = _stacked_identity( + batch_shape_orig, + m, + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + vt = _stacked_identity( + batch_shape_orig, + n, + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + else: + u = dpnp.empty( + batch_shape_orig + (m, 0), + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + vt = dpnp.empty( + batch_shape_orig + (0, n), + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return u, s, vt + else: + return s + + u_matrices = [None] * batch_size + s_matrices = [None] * batch_size + vt_matrices = [None] * batch_size + ht_list_ev = [None] * batch_size * 2 + for i in range(batch_size): + if compute_uv: + ( + u_matrices[i], + s_matrices[i], + vt_matrices[i], + ht_list_ev[2 * i], + ht_list_ev[2 * i + 1], + ) = dpnp_svd(a[i], full_matrices, compute_uv=True, batch_call=True) + else: + s_matrices[i], ht_list_ev[2 * i], ht_list_ev[2 * i + 1] = dpnp_svd( + a[i], full_matrices, compute_uv=False, batch_call=True + ) + + dpctl.SyclEvent.wait_for(ht_list_ev) + + # TODO: Need to return C-contiguous array to match the output of numpy.linalg.svd + # Allocate 'F' order memory for dpnp output arrays to be aligned with dpnp_svd + out_s = dpnp.array(s_matrices, order="F") + if reshape: + out_s = out_s.reshape(batch_shape_orig + out_s.shape[-1:]) + + if compute_uv: + out_u = dpnp.array(u_matrices, order="F") + out_vt = dpnp.array(vt_matrices, order="F") + if reshape: + return ( + out_u.reshape(batch_shape_orig + out_u.shape[-2:]), + out_s, + out_vt.reshape(batch_shape_orig + out_vt.shape[-2:]), + ) + else: + return out_u, out_s, out_vt + else: + return out_s + + +def dpnp_svd( + a, full_matrices=True, compute_uv=True, hermitian=False, batch_call=False +): + """ + dpnp_svd(a, full_matrices=True, compute_uv=True, hermitian=False, batch_call=False) + + Return the singular value decomposition (SVD). 
+ + """ + + if hermitian: + check_stacked_square(a) + + # _gesvd returns eigenvalues with s ** 2 sorted descending, + # but dpnp.linalg.eigh returns s sorted ascending so we re-order the eigenvalues + # and related arrays to have the correct order + if compute_uv: + s, u = dpnp.linalg.eigh(a) + sgn = dpnp.sign(s) + s = dpnp.absolute(s) + sidx = dpnp.argsort(s)[..., ::-1] + # Rearrange the signs according to sorted indices + sgn = dpnp.take_along_axis(sgn, sidx, axis=-1) + # Sort the singular values in descending order + s = dpnp.take_along_axis(s, sidx, axis=-1) + # Rearrange the eigenvectors according to sorted indices + u = dpnp.take_along_axis(u, sidx[..., None, :], axis=-1) + # Singular values are unsigned, move the sign into v + # Compute V^T adjusting for the sign and conjugating + vt = dpnp.transpose(u * sgn[..., None, :]).conjugate() + return u, s, vt + else: + # TODO: use dpnp.linalg.eighvals when it is updated + s, _ = dpnp.linalg.eigh(a) + s = dpnp.abs(s) + return dpnp.sort(s)[..., ::-1] + + uv_type = _common_type(a) + s_type = _real_type(uv_type) + + if a.ndim > 2: + return dpnp_svd_batch(a, uv_type, s_type, full_matrices, compute_uv) + + a_usm_type = a.usm_type + a_sycl_queue = a.sycl_queue + m, n = a.shape + + if m == 0 or n == 0: + s = dpnp.empty( + (0,), + dtype=s_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + if compute_uv: + if full_matrices: + u_shape = (m,) + vt_shape = (n,) + else: + u_shape = (m, 0) + vt_shape = (0, n) + + u = dpnp.eye( + *u_shape, + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + vt = dpnp.eye( + *vt_shape, + dtype=uv_type, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return u, s, vt + else: + return s + + # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input. + # Allocate 'F' order memory for dpnp arrays to comply with these requirements. + a_h = dpnp.empty_like(a, order="F", dtype=uv_type) + + a_usm_arr = dpnp.get_usm_ndarray(a) + + # use DPCTL tensor function to fill the сopy of the input array + # from the input array + a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a_sycl_queue + ) + + k = min(m, n) + if compute_uv: + if full_matrices: + u_shape = (m, m) + vt_shape = (n, n) + jobu = ord("A") + jobvt = ord("A") + else: + u_shape = (m, k) + vt_shape = (k, n) + jobu = ord("S") + jobvt = ord("S") + else: + u_shape = vt_shape = () + jobu = ord("N") + jobvt = ord("N") + + # oneMKL LAPACK assumes fortran-like array as input. + # Allocate 'F' order memory for dpnp output arrays to comply with these requirements. 
+ u_h = dpnp.empty( + u_shape, + dtype=uv_type, + order="F", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + vt_h = dpnp.empty( + vt_shape, + dtype=uv_type, + order="F", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + s_h = dpnp.empty( + k, dtype=s_type, usm_type=a_usm_type, sycl_queue=a_sycl_queue + ) + + ht_lapack_ev, _ = li._gesvd( + a_sycl_queue, + jobu, + jobvt, + a_h.get_array(), + s_h.get_array(), + u_h.get_array(), + vt_h.get_array(), + [a_copy_ev], + ) + + if batch_call: + if compute_uv: + return u_h, s_h, vt_h, ht_lapack_ev, a_ht_copy_ev + else: + return s_h, ht_lapack_ev, a_ht_copy_ev + + ht_lapack_ev.wait() + a_ht_copy_ev.wait() + + # TODO: Need to return C-contiguous array to match the output of numpy.linalg.svd + if compute_uv: + return u_h, s_h, vt_h + else: + return s_h diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 5ea536c2887..85206bad5ba 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -9,6 +9,7 @@ from .helper import ( assert_dtype_allclose, get_all_dtypes, + get_complex_dtypes, has_support_aspect64, is_cpu_device, ) @@ -755,64 +756,6 @@ def test_qr_not_2D(): assert_allclose(ia, inp.matmul(dpnp_q, dpnp_r)) -@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True)) -@pytest.mark.parametrize( - "shape", - [(2, 2), (3, 4), (5, 3), (16, 16)], - ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)"], -) -def test_svd(type, shape): - a = numpy.arange(shape[0] * shape[1], dtype=type).reshape(shape) - ia = inp.array(a) - - np_u, np_s, np_vt = numpy.linalg.svd(a) - dpnp_u, dpnp_s, dpnp_vt = inp.linalg.svd(ia) - - support_aspect64 = has_support_aspect64() - - if support_aspect64: - assert dpnp_u.dtype == np_u.dtype - assert dpnp_s.dtype == np_s.dtype - assert dpnp_vt.dtype == np_vt.dtype - assert dpnp_u.shape == np_u.shape - assert dpnp_s.shape == np_s.shape - assert dpnp_vt.shape == np_vt.shape - - tol = 1e-12 - if type == inp.float32: - tol = 1e-03 - elif not support_aspect64 and type in (inp.int32, inp.int64, None): - tol = 1e-03 - - # check decomposition - dpnp_diag_s = inp.zeros(shape, dtype=dpnp_s.dtype) - for i in range(dpnp_s.size): - dpnp_diag_s[i, i] = dpnp_s[i] - - # check decomposition - assert_allclose( - ia, inp.dot(dpnp_u, inp.dot(dpnp_diag_s, dpnp_vt)), rtol=tol, atol=tol - ) - - # compare singular values - # assert_allclose(dpnp_s, np_s, rtol=tol, atol=tol) - - # change sign of vectors - for i in range(min(shape[0], shape[1])): - if np_u[0, i] * dpnp_u[0, i] < 0: - np_u[:, i] = -np_u[:, i] - np_vt[i, :] = -np_vt[i, :] - - # compare vectors for non-zero values - for i in range(numpy.count_nonzero(np_s > tol)): - assert_allclose( - inp.asnumpy(dpnp_u)[:, i], np_u[:, i], rtol=tol, atol=tol - ) - assert_allclose( - inp.asnumpy(dpnp_vt)[i, :], np_vt[i, :], rtol=tol, atol=tol - ) - - class TestSolve: @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) def test_solve(self, dtype): @@ -1028,3 +971,149 @@ def test_slogdet_errors(self): # unsupported type a_np = inp.asnumpy(a_dp) assert_raises(TypeError, inp.linalg.slogdet, a_np) + + +class TestSvd: + def get_tol(self, dtype): + tol = 1e-06 + if dtype in (inp.float32, inp.complex64): + tol = 1e-04 + elif not has_support_aspect64() and dtype in ( + inp.int32, + inp.int64, + None, + ): + tol = 1e-04 + self._tol = tol + + def check_types_shapes( + self, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt=True + ): + if has_support_aspect64(): + if compute_vt: + assert dp_u.dtype == np_u.dtype + assert dp_vt.dtype == np_vt.dtype + assert dp_s.dtype == np_s.dtype 
+ else: + if compute_vt: + assert dp_u.dtype.kind == np_u.dtype.kind + assert dp_vt.dtype.kind == np_vt.dtype.kind + assert dp_s.dtype.kind == np_s.dtype.kind + + if compute_vt: + assert dp_u.shape == np_u.shape + assert dp_vt.shape == np_vt.shape + assert dp_s.shape == np_s.shape + + # Checks the accuracy of singular value decomposition (SVD). + # Compares the reconstructed matrix from the decomposed components + # with the original matrix. + # Additionally checks for equality of singular values + # between dpnp and numpy decompositions + def check_decomposition( + self, dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt + ): + tol = self._tol + if compute_vt: + dpnp_diag_s = inp.zeros_like(dp_a, dtype=dp_s.dtype) + for i in range(min(dp_a.shape[-2], dp_a.shape[-1])): + dpnp_diag_s[..., i, i] = dp_s[..., i] + # TODO: remove it when dpnp.dot is updated + # dpnp.dot does not support complex type + if inp.issubdtype(dp_a.dtype, inp.complexfloating): + reconstructed = numpy.dot( + inp.asnumpy(dp_u), + numpy.dot(inp.asnumpy(dpnp_diag_s), inp.asnumpy(dp_vt)), + ) + else: + reconstructed = inp.dot(dp_u, inp.dot(dpnp_diag_s, dp_vt)) + # TODO: use assert dpnp.allclose() inside check_decomposition() + # when it will support complex dtypes + assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4) + + assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03) + + if compute_vt: + for i in range(min(dp_a.shape[-2], dp_a.shape[-1])): + if np_u[..., 0, i] * dp_u[..., 0, i] < 0: + np_u[..., :, i] = -np_u[..., :, i] + np_vt[..., i, :] = -np_vt[..., i, :] + for i in range(numpy.count_nonzero(np_s > tol)): + assert_allclose( + inp.asnumpy(dp_u[..., :, i]), + np_u[..., :, i], + rtol=tol, + atol=tol, + ) + assert_allclose( + inp.asnumpy(dp_vt[..., i, :]), + np_vt[..., i, :], + rtol=tol, + atol=tol, + ) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "shape", + [(2, 2), (3, 4), (5, 3), (16, 16)], + ids=["(2,2)", "(3,4)", "(5,3)", "(16,16)"], + ) + def test_svd(self, dtype, shape): + a = numpy.arange(shape[0] * shape[1], dtype=dtype).reshape(shape) + dp_a = inp.array(a) + + np_u, np_s, np_vt = numpy.linalg.svd(a) + dp_u, dp_s, dp_vt = inp.linalg.svd(dp_a) + + self.check_types_shapes(dp_u, dp_s, dp_vt, np_u, np_s, np_vt) + self.get_tol(dtype) + self.check_decomposition( + dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, True + ) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize("compute_vt", [True, False], ids=["True", "False"]) + @pytest.mark.parametrize( + "shape", + [(2, 2), (16, 16)], + ids=["(2,2)", "(16, 16)"], + ) + def test_svd_hermitian(self, dtype, compute_vt, shape): + a = numpy.random.randn(*shape) + 1j * numpy.random.randn(*shape) + a = numpy.conj(a.T) @ a + + a = a.astype(dtype) + dp_a = inp.array(a) + + if compute_vt: + np_u, np_s, np_vt = numpy.linalg.svd( + a, compute_uv=compute_vt, hermitian=True + ) + dp_u, dp_s, dp_vt = inp.linalg.svd( + dp_a, compute_uv=compute_vt, hermitian=True + ) + else: + np_s = numpy.linalg.svd(a, compute_uv=compute_vt, hermitian=True) + dp_s = inp.linalg.svd(dp_a, compute_uv=compute_vt, hermitian=True) + np_u = np_vt = dp_u = dp_vt = None + + self.check_types_shapes( + dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt + ) + + self.get_tol(dtype) + + self.check_decomposition( + dp_a, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt + ) + + def test_svd_errors(self): + a_dp = inp.array([[1, 2], [3, 4]], dtype="float32") + + # unsupported type + a_np = inp.asnumpy(a_dp) + 
assert_raises(TypeError, inp.linalg.svd, a_np) + + # a.ndim < 2 + a_dp_ndim_1 = a_dp.flatten() + assert_raises(inp.linalg.LinAlgError, inp.linalg.svd, a_dp_ndim_1) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 7a7bcd53e0b..205d4efb572 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1230,53 +1230,62 @@ def test_qr(device): valid_devices, ids=[device.filter_string for device in valid_devices], ) -def test_svd(device): - shape = (2, 2) +@pytest.mark.parametrize("full_matrices", [True, False], ids=["True", "False"]) +@pytest.mark.parametrize("compute_uv", [True, False], ids=["True", "False"]) +@pytest.mark.parametrize( + "shape", + [ + (1, 4), + (3, 2), + (4, 4), + (2, 0), + (0, 2), + (2, 2, 3), + (3, 3, 0), + (0, 2, 3), + (1, 0, 3), + ], + ids=[ + "(1, 4)", + "(3, 2)", + "(4, 4)", + "(2, 0)", + "(0, 2)", + "(2, 2, 3)", + "(3, 3, 0)", + "(0, 2, 3)", + "(1, 0, 3)", + ], +) +def test_svd(shape, full_matrices, compute_uv, device): dtype = dpnp.default_float_type(device) - numpy_data = numpy.arange(shape[0] * shape[1], dtype=dtype).reshape(shape) - dpnp_data = dpnp.arange( - shape[0] * shape[1], dtype=dtype, device=device - ).reshape(shape) - - np_u, np_s, np_vt = numpy.linalg.svd(numpy_data) - dpnp_u, dpnp_s, dpnp_vt = dpnp.linalg.svd(dpnp_data) - - assert dpnp_u.dtype == np_u.dtype - assert dpnp_s.dtype == np_s.dtype - assert dpnp_vt.dtype == np_vt.dtype - assert dpnp_u.shape == np_u.shape - assert dpnp_s.shape == np_s.shape - assert dpnp_vt.shape == np_vt.shape - - # check decomposition - dpnp_diag_s = dpnp.zeros(shape, dtype=dpnp_s.dtype, device=device) - for i in range(dpnp_s.size): - dpnp_diag_s[i, i] = dpnp_s[i] - - # check decomposition - assert_dtype_allclose( - dpnp_data, dpnp.dot(dpnp_u, dpnp.dot(dpnp_diag_s, dpnp_vt)) + + count_elems = numpy.prod(shape) + dpnp_data = dpnp.arange(count_elems, dtype=dtype, device=device).reshape( + shape ) + expected_queue = dpnp_data.get_array().sycl_queue - for i in range(min(shape[0], shape[1])): - if np_u[0, i] * dpnp_u[0, i] < 0: - np_u[:, i] = -np_u[:, i] - np_vt[i, :] = -np_vt[i, :] + if compute_uv: + dpnp_u, dpnp_s, dpnp_vt = dpnp.linalg.svd( + dpnp_data, full_matrices=full_matrices, compute_uv=compute_uv + ) - # compare vectors for non-zero values - for i in range(numpy.count_nonzero(np_s)): - assert_dtype_allclose(dpnp_u[:, i], np_u[:, i]) - assert_dtype_allclose(dpnp_vt[i, :], np_vt[i, :]) + dpnp_u_queue = dpnp_u.get_array().sycl_queue + dpnp_vt_queue = dpnp_vt.get_array().sycl_queue + dpnp_s_queue = dpnp_s.get_array().sycl_queue - expected_queue = dpnp_data.get_array().sycl_queue - dpnp_u_queue = dpnp_u.get_array().sycl_queue - dpnp_s_queue = dpnp_s.get_array().sycl_queue - dpnp_vt_queue = dpnp_vt.get_array().sycl_queue + assert_sycl_queue_equal(dpnp_u_queue, expected_queue) + assert_sycl_queue_equal(dpnp_vt_queue, expected_queue) + assert_sycl_queue_equal(dpnp_s_queue, expected_queue) - # compare queue and device - assert_sycl_queue_equal(dpnp_u_queue, expected_queue) - assert_sycl_queue_equal(dpnp_s_queue, expected_queue) - assert_sycl_queue_equal(dpnp_vt_queue, expected_queue) + else: + dpnp_s = dpnp.linalg.svd( + dpnp_data, full_matrices=full_matrices, compute_uv=compute_uv + ) + dpnp_s_queue = dpnp_s.get_array().sycl_queue + + assert_sycl_queue_equal(dpnp_s_queue, expected_queue) @pytest.mark.parametrize( diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index ada68ebfa6c..bff548a90d0 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -740,3 +740,53 @@ def 
test_inv(shape, is_empty, usm_type): result = dp.linalg.inv(x) assert x.usm_type == result.usm_type + + +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +@pytest.mark.parametrize( + "full_matrices_param", [True, False], ids=["True", "False"] +) +@pytest.mark.parametrize( + "compute_uv_param", [True, False], ids=["True", "False"] +) +@pytest.mark.parametrize( + "shape", + [ + (1, 4), + (3, 2), + (4, 4), + (2, 0), + (0, 2), + (2, 2, 3), + (3, 3, 0), + (0, 2, 3), + (1, 0, 3), + ], + ids=[ + "(1, 4)", + "(3, 2)", + "(4, 4)", + "(2, 0)", + "(0, 2)", + "(2, 2, 3)", + "(3, 3, 0)", + "(0, 2, 3)", + "(1, 0, 3)", + ], +) +def test_svd(usm_type, shape, full_matrices_param, compute_uv_param): + x = dp.ones(shape, usm_type=usm_type) + + if compute_uv_param: + u, s, vt = dp.linalg.svd( + x, full_matrices=full_matrices_param, compute_uv=compute_uv_param + ) + + assert x.usm_type == u.usm_type + assert x.usm_type == vt.usm_type + else: + s = dp.linalg.svd( + x, full_matrices=full_matrices_param, compute_uv=compute_uv_param + ) + + assert x.usm_type == s.usm_type diff --git a/tests/third_party/cupy/linalg_tests/test_decomposition.py b/tests/third_party/cupy/linalg_tests/test_decomposition.py index 42bcf122ff4..fd887c16e6c 100644 --- a/tests/third_party/cupy/linalg_tests/test_decomposition.py +++ b/tests/third_party/cupy/linalg_tests/test_decomposition.py @@ -6,6 +6,7 @@ import dpnp as cupy from tests.helper import has_support_aspect64, is_cpu_device from tests.third_party.cupy import testing +from tests.third_party.cupy.testing import _condition def random_matrix(shape, dtype, scale, sym=False): @@ -44,6 +45,14 @@ def random_matrix(shape, dtype, scale, sym=False): return new_a.astype(dtype) +def stacked_identity(xp, batch_shape, n, dtype): + shape = batch_shape + (n, n) + idx = xp.arange(n) + x = xp.zeros(shape, dtype=dtype) + x[..., idx, idx] = 1 + return x + + class TestCholeskyDecomposition: @testing.numpy_cupy_allclose(atol=1e-3, type_check=has_support_aspect64()) def check_L(self, array, xp): @@ -135,3 +144,244 @@ def check_L(self, array): def test_decomposition(self, dtype): A = numpy.array([[1, -2], [-2, 1]]).astype(dtype) self.check_L(A) + + +@testing.parameterize( + *testing.product( + { + "full_matrices": [True, False], + } + ) +) +@testing.fix_random() +class TestSVD(unittest.TestCase): + # TODO: New packages that fix issue CMPLRLLVM-53771 are only available in internal CI. + # Skip the tests on cpu until these packages are available for the external CI. 
+ # Specifically dpcpp_linux-64>=2024.1.0 + @classmethod + def setUpClass(cls): + if is_cpu_device(): + raise unittest.SkipTest("CMPLRLLVM-53771") + + def setUp(self): + self.seed = testing.generate_seed() + + @testing.for_dtypes( + [ + numpy.int32, + numpy.int64, + numpy.uint32, + numpy.uint64, + numpy.float32, + numpy.float64, + numpy.complex64, + numpy.complex128, + ] + ) + def check_usv(self, shape, dtype): + array = testing.shaped_random(shape, numpy, dtype=dtype, seed=self.seed) + a_cpu = numpy.asarray(array, dtype=dtype) + a_gpu = cupy.asarray(array, dtype=dtype) + result_cpu = numpy.linalg.svd(a_cpu, full_matrices=self.full_matrices) + result_gpu = cupy.linalg.svd(a_gpu, full_matrices=self.full_matrices) + # Check if the input matrix is not broken + testing.assert_allclose(a_gpu, a_cpu) + + assert len(result_gpu) == 3 + for i in range(3): + assert result_gpu[i].shape == result_cpu[i].shape + if has_support_aspect64(): + assert result_gpu[i].dtype == result_cpu[i].dtype + else: + assert result_gpu[i].dtype.kind == result_cpu[i].dtype.kind + u_cpu, s_cpu, vh_cpu = result_cpu + u_gpu, s_gpu, vh_gpu = result_gpu + testing.assert_allclose(s_gpu, s_cpu, rtol=1e-5, atol=1e-4) + + # reconstruct the matrix + k = s_cpu.shape[-1] + + # dpnp.dot/matmul does not support complex type and unstable on cpu + # TODO: remove it and use xp.dot/matmul when dpnp.dot/matmul is updated + u_gpu = u_gpu.asnumpy() + vh_gpu = vh_gpu.asnumpy() + s_gpu = s_gpu.asnumpy() + xp = numpy + + if len(shape) == 2: + if self.full_matrices: + a_gpu_usv = numpy.dot(u_gpu[:, :k] * s_gpu, vh_gpu[:k, :]) + else: + a_gpu_usv = numpy.dot(u_gpu * s_gpu, vh_gpu) + else: + if self.full_matrices: + a_gpu_usv = numpy.matmul( + u_gpu[..., :k] * s_gpu[..., None, :], vh_gpu[..., :k, :] + ) + else: + a_gpu_usv = numpy.matmul(u_gpu * s_gpu[..., None, :], vh_gpu) + testing.assert_allclose(a_gpu, a_gpu_usv, rtol=1e-4, atol=1e-4) + + # assert unitary + u_len = u_gpu.shape[-1] + vh_len = vh_gpu.shape[-2] + testing.assert_allclose( + xp.matmul(u_gpu.swapaxes(-1, -2).conj(), u_gpu), + stacked_identity(xp, shape[:-2], u_len, dtype), + atol=1e-4, + ) + testing.assert_allclose( + xp.matmul(vh_gpu, vh_gpu.swapaxes(-1, -2).conj()), + stacked_identity(xp, shape[:-2], vh_len, dtype), + atol=1e-4, + ) + + @testing.for_dtypes( + [ + numpy.int32, + numpy.int64, + numpy.uint32, + numpy.uint64, + numpy.float32, + numpy.float64, + numpy.complex64, + numpy.complex128, + ] + ) + # dpnp.linalg.svd() returns results as F-contiguous + # while numpy.linalg.svd() returns as C-contiguous + @testing.numpy_cupy_allclose( + rtol=1e-5, + atol=1e-4, + type_check=has_support_aspect64(), + contiguous_check=False, + ) + def check_singular(self, shape, xp, dtype): + array = testing.shaped_random(shape, xp, dtype=dtype, seed=self.seed) + a = xp.asarray(array, dtype=dtype) + a_copy = a.copy() + result = xp.linalg.svd( + a, full_matrices=self.full_matrices, compute_uv=False + ) + # Check if the input matrix is not broken + assert (a == a_copy).all() + return result + + @_condition.repeat(3, 10) + def test_svd_rank2(self): + self.check_usv((3, 7)) + self.check_usv((2, 2)) + self.check_usv((7, 3)) + + @_condition.repeat(3, 10) + def test_svd_rank2_no_uv(self): + self.check_singular((3, 7)) + self.check_singular((2, 2)) + self.check_singular((7, 3)) + + @testing.with_requires("numpy>=1.16") + def test_svd_rank2_empty_array(self): + self.check_usv((0, 3)) + self.check_usv((3, 0)) + self.check_usv((1, 0)) + + @testing.with_requires("numpy>=1.16") + 
+    @testing.with_requires("numpy>=1.16")
+    @testing.numpy_cupy_array_equal(type_check=has_support_aspect64())
+    def test_svd_rank2_empty_array_compute_uv_false(self, xp):
+        array = xp.empty((3, 0))
+        return xp.linalg.svd(
+            array, full_matrices=self.full_matrices, compute_uv=False
+        )
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank3(self):
+        self.check_usv((2, 3, 4))
+        self.check_usv((2, 3, 7))
+        self.check_usv((2, 4, 4))
+        self.check_usv((2, 7, 3))
+        self.check_usv((2, 4, 3))
+        self.check_usv((2, 32, 32))
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank3_loop(self):
+        # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
+        self.check_usv((2, 64, 64))
+        self.check_usv((2, 64, 32))
+        self.check_usv((2, 32, 64))
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank3_no_uv(self):
+        self.check_singular((2, 3, 4))
+        self.check_singular((2, 3, 7))
+        self.check_singular((2, 4, 4))
+        self.check_singular((2, 7, 3))
+        self.check_singular((2, 4, 3))
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank3_no_uv_loop(self):
+        # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
+        self.check_singular((2, 64, 64))
+        self.check_singular((2, 64, 32))
+        self.check_singular((2, 32, 64))
+
+    @testing.with_requires("numpy>=1.16")
+    def test_svd_rank3_empty_array(self):
+        self.check_usv((0, 3, 4))
+        self.check_usv((3, 0, 4))
+        self.check_usv((3, 4, 0))
+        self.check_usv((3, 0, 0))
+        self.check_usv((0, 3, 0))
+        self.check_usv((0, 0, 3))
+
+    @testing.with_requires("numpy>=1.16")
+    @testing.numpy_cupy_array_equal(type_check=has_support_aspect64())
+    def test_svd_rank3_empty_array_compute_uv_false1(self, xp):
+        array = xp.empty((3, 0, 4))
+        return xp.linalg.svd(
+            array, full_matrices=self.full_matrices, compute_uv=False
+        )
+
+    @testing.with_requires("numpy>=1.16")
+    @testing.numpy_cupy_array_equal(type_check=has_support_aspect64())
+    def test_svd_rank3_empty_array_compute_uv_false2(self, xp):
+        array = xp.empty((0, 3, 4))
+        return xp.linalg.svd(
+            array, full_matrices=self.full_matrices, compute_uv=False
+        )
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank4(self):
+        self.check_usv((2, 2, 3, 4))
+        self.check_usv((2, 2, 3, 7))
+        self.check_usv((2, 2, 4, 4))
+        self.check_usv((2, 2, 7, 3))
+        self.check_usv((2, 2, 4, 3))
+        self.check_usv((2, 2, 32, 32))
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank4_loop(self):
+        # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
+        self.check_usv((3, 2, 64, 64))
+        self.check_usv((3, 2, 64, 32))
+        self.check_usv((3, 2, 32, 64))
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank4_no_uv(self):
+        self.check_singular((2, 2, 3, 4))
+        self.check_singular((2, 2, 3, 7))
+        self.check_singular((2, 2, 4, 4))
+        self.check_singular((2, 2, 7, 3))
+        self.check_singular((2, 2, 4, 3))
+
+    @_condition.repeat(3, 10)
+    def test_svd_rank4_no_uv_loop(self):
+        # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
+        self.check_singular((3, 2, 64, 64))
+        self.check_singular((3, 2, 64, 32))
+        self.check_singular((3, 2, 32, 64))
+
+    @testing.with_requires("numpy>=1.16")
+    def test_svd_rank4_empty_array(self):
+        self.check_usv((0, 2, 3, 4))
+        self.check_usv((1, 2, 0, 4))
+        self.check_usv((1, 2, 3, 0))
diff --git a/tests/third_party/cupy/linalg_tests/test_solve.py b/tests/third_party/cupy/linalg_tests/test_solve.py
index b31082c8e84..cd397f6c9e1 100644
--- a/tests/third_party/cupy/linalg_tests/test_solve.py
+++ b/tests/third_party/cupy/linalg_tests/test_solve.py
@@ -10,7 +10,7 @@
     is_cpu_device,
 )
 from tests.third_party.cupy import testing
-from tests.third_party.cupy.testing import condition
+from tests.third_party.cupy.testing import _condition
 
 
 @testing.parameterize(
@@ -104,7 +104,7 @@ def test_invalid_shape(self):
 )
 class TestInv(unittest.TestCase):
     @testing.for_dtypes("ifdFD")
-    @condition.retry(10)
+    @_condition.retry(10)
     def check_x(self, a_shape, dtype):
         a_cpu = numpy.random.randint(0, 10, size=a_shape)
         a_cpu = a_cpu.astype(dtype, order=self.order)
diff --git a/tests/third_party/cupy/random_tests/test_sample.py b/tests/third_party/cupy/random_tests/test_sample.py
index f95f3e42710..79e2370ad05 100644
--- a/tests/third_party/cupy/random_tests/test_sample.py
+++ b/tests/third_party/cupy/random_tests/test_sample.py
@@ -7,7 +7,7 @@
 import dpnp as cupy
 from dpnp import random
 from tests.third_party.cupy import testing
-from tests.third_party.cupy.testing import condition, hypothesis
+from tests.third_party.cupy.testing import _condition, hypothesis
 
 
 @testing.gpu
@@ -43,7 +43,7 @@ def test_zero_sizes(self):
 @testing.gpu
 class TestRandint2(unittest.TestCase):
     @pytest.mark.usefixtures("allow_fall_back_on_numpy")
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_bound_1(self):
         vals = [random.randint(0, 10, (2, 3)) for _ in range(10)]
         for val in vals:
@@ -52,7 +52,7 @@ def test_bound_1(self):
         self.assertEqual(max(_.max() for _ in vals), 9)
 
     @pytest.mark.usefixtures("allow_fall_back_on_numpy")
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_bound_2(self):
         vals = [random.randint(0, 2) for _ in range(20)]
         for val in vals:
@@ -61,7 +61,7 @@ def test_bound_2(self):
         self.assertEqual(max(_.max() for _ in vals), 1)
 
     @pytest.mark.usefixtures("allow_fall_back_on_numpy")
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_bound_overflow(self):
         # 100 - (-100) exceeds the range of int8
         val = random.randint(numpy.int8(-100), numpy.int8(100), size=20)
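For context on the @_condition.repeat(3, 10) decorators swapped in throughout these tests: judging from repeat() in _condition.py further below, which ends in repeat_with_success_at_least(times_, times_), the decorated test is rerun times_ times and every run must pass, with times_ chosen by the CUPY_TEST_CASUAL switch. A rough sketch of what such retry machinery could look like (a hypothetical simplification, not the actual cupy/dpnp implementation):

    import functools

    def repeat_with_success_at_least(times, min_success):
        # Run the test up to `times` times; succeed as soon as
        # `min_success` runs have passed, otherwise re-raise the
        # last failure.
        def decorator(test):
            @functools.wraps(test)
            def wrapper(*args, **kwargs):
                successes = 0
                last_error = None
                for _ in range(times):
                    try:
                        test(*args, **kwargs)
                        successes += 1
                    except Exception as exc:
                        last_error = exc
                    if successes >= min_success:
                        return
                raise last_error
            return wrapper
        return decorator

    # Under these assumptions, retry(10) ~ pass if any of 10 runs passes,
    # while repeat(3) ~ pass only if all 3 runs pass.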
@@ -70,7 +70,7 @@ def test_bound_overflow(self):
         self.assertLess(val.max(), 100)
 
     @pytest.mark.usefixtures("allow_fall_back_on_numpy")
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_bound_float1(self):
         # generate floats s.t. int(low) < int(high)
         low, high = sorted(numpy.random.uniform(-5, 5, size=2))
@@ -90,7 +90,7 @@ def test_bound_float2(self):
         self.assertEqual(min(_.min() for _ in vals), -1)
         self.assertEqual(max(_.max() for _ in vals), 0)
 
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_goodness_of_fit(self):
         mx = 5
         trial = 100
@@ -99,7 +99,7 @@ def test_goodness_of_fit(self):
         expected = numpy.array([float(trial) / mx] * mx)
         self.assertTrue(hypothesis.chi_square_test(counts, expected))
 
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_goodness_of_fit_2(self):
         mx = 5
         vals = random.randint(mx, size=(5, 20))
@@ -169,7 +169,7 @@ def test_size_is_not_none(self):
 @testing.fix_random()
 @testing.gpu
 class TestRandomIntegers2(unittest.TestCase):
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_bound_1(self):
         vals = [random.random_integers(0, 10, (2, 3)).get() for _ in range(10)]
         for val in vals:
@@ -177,7 +177,7 @@ def test_bound_1(self):
         self.assertEqual(min(_.min() for _ in vals), 0)
         self.assertEqual(max(_.max() for _ in vals), 10)
 
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_bound_2(self):
         vals = [random.random_integers(0, 2).get() for _ in range(20)]
         for val in vals:
@@ -185,7 +185,7 @@ def test_bound_2(self):
         self.assertEqual(min(vals), 0)
         self.assertEqual(max(vals), 2)
 
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_goodness_of_fit(self):
         mx = 5
         trial = 100
@@ -194,7 +194,7 @@ def test_goodness_of_fit(self):
         expected = numpy.array([float(trial) / mx] * mx)
         self.assertTrue(hypothesis.chi_square_test(counts, expected))
 
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     def test_goodness_of_fit_2(self):
         mx = 5
         vals = random.randint(0, mx, (5, 20)).get()
@@ -289,7 +289,7 @@ def test_randn_invalid_argument(self):
 @testing.fix_random()
 @testing.gpu
 class TestMultinomial(unittest.TestCase):
-    @condition.repeat(3, 10)
+    @_condition.repeat(3, 10)
     @testing.for_float_dtypes()
     @testing.numpy_cupy_allclose(rtol=0.05)
     def test_multinomial(self, xp, dtype):
diff --git a/tests/third_party/cupy/testing/__init__.py b/tests/third_party/cupy/testing/__init__.py
index 701c381e2f3..aa6c113706b 100644
--- a/tests/third_party/cupy/testing/__init__.py
+++ b/tests/third_party/cupy/testing/__init__.py
@@ -60,6 +60,4 @@
     product,
     product_dict,
 )
-from tests.third_party.cupy.testing.random import fix_random
-
-# from tests.third_party.cupy.testing.random import generate_seed
+from tests.third_party.cupy.testing.random import fix_random, generate_seed
diff --git a/tests/third_party/cupy/testing/condition.py b/tests/third_party/cupy/testing/_condition.py
similarity index 98%
rename from tests/third_party/cupy/testing/condition.py
rename to tests/third_party/cupy/testing/_condition.py
index 4465dc3d0ee..3533ef8b84d 100644
--- a/tests/third_party/cupy/testing/condition.py
+++ b/tests/third_party/cupy/testing/_condition.py
@@ -106,7 +106,7 @@ def repeat(times, intensive_times=None):
     if intensive_times is None:
         return repeat_with_success_at_least(times, times)
 
-    casual_test = bool(int(os.environ.get("CUPY_TEST_CASUAL", "0")))
+    casual_test = bool(int(os.environ.get("CUPY_TEST_CASUAL", "1")))
 
     times_ = times if casual_test else intensive_times
     return repeat_with_success_at_least(times_, times_)
diff --git a/tests/third_party/cupy/testing/random.py b/tests/third_party/cupy/testing/random.py
index 444f2b3352c..ecc299737c0 100644
--- a/tests/third_party/cupy/testing/random.py
+++ b/tests/third_party/cupy/testing/random.py
@@ -20,12 +20,15 @@ def do_setup(deterministic=True):
     global _old_cupy_random_states
     _old_python_random_state = random.getstate()
     _old_numpy_random_state = numpy.random.get_state()
-    _old_cupy_random_states = cupy.random.generator._random_states
-    cupy.random.reset_states()
-    # Check that _random_state has been recreated in
-    # cupy.random.reset_states(). Otherwise the contents of
-    # _old_cupy_random_states would be overwritten.
-    assert cupy.random.generator._random_states is not _old_cupy_random_states
+    _old_cupy_random_states = cupy.random.dpnp_iface_random._dpnp_random_states
+    cupy.random.dpnp_iface_random._dpnp_random_states = {}
+    # Check that the random-states dict has been replaced with a fresh
+    # one above. Otherwise the contents of _old_cupy_random_states
+    # would be overwritten.
+    assert (
+        cupy.random.dpnp_iface_random._dpnp_random_states
+        is not _old_cupy_random_states
+    )
 
     if not deterministic:
         random.seed()
@@ -43,7 +46,7 @@ def do_teardown():
     global _old_cupy_random_states
     random.setstate(_old_python_random_state)
     numpy.random.set_state(_old_numpy_random_state)
-    cupy.random.generator._random_states = _old_cupy_random_states
+    cupy.random.dpnp_iface_random._dpnp_random_states = _old_cupy_random_states
     _old_python_random_state = None
     _old_numpy_random_state = None
     _old_cupy_random_states = None
@@ -91,12 +94,12 @@ def fix_random():
     """Decorator that fixes random numbers in a test.
 
     This decorator can be applied to either a test case class or a test
     method.
-    It should not be applied within ``condition.retry`` or
-    ``condition.repeat``.
+    It should not be applied within ``_condition.retry`` or
+    ``_condition.repeat``.
     """
     # TODO(niboshi): Prevent this decorator from being applied within
-    # condition.repeat or condition.retry decorators. That would repeat
+    # _condition.repeat or _condition.retry decorators. That would repeat
     # tests with the same random seeds. It's okay to apply this outside
     # these decorators.
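The do_setup/do_teardown pair above follows the usual save-reseed-restore pattern: stash every global RNG state, reseed deterministically for the test, then restore the originals so a fixed seed never leaks into the rest of the suite. A self-contained sketch of the same pattern for just the Python and NumPy generators (hypothetical helper names; the dpnp-specific state dict from the diff is omitted):

    import random

    import numpy

    _saved_states = None

    def fixed_seed_setup(seed):
        # Save the global RNG states, then reseed deterministically.
        global _saved_states
        _saved_states = (random.getstate(), numpy.random.get_state())
        random.seed(seed)
        numpy.random.seed(seed)

    def fixed_seed_teardown():
        # Restore the RNG states saved by fixed_seed_setup().
        global _saved_states
        py_state, np_state = _saved_states
        random.setstate(py_state)
        numpy.random.set_state(np_state)
        _saved_states = None

    # Usage: wrap a test body so its randomness is reproducible.
    fixed_seed_setup(1234)
    try:
        sample = numpy.random.rand(3)  # deterministic given the seed
    finally:
        fixed_seed_teardown()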