From 6cbd1387753ea8f519ac0fe2242f0a54dd670ede Mon Sep 17 00:00:00 2001 From: Luigi Cruz Date: Wed, 18 Oct 2023 17:16:02 -0300 Subject: [PATCH] Complex number support for ``nb::ndarray``. (#319) --- docs/ndarray.rst | 13 ++++++++---- include/nanobind/ndarray.h | 38 ++++++++++++++++++++++++++++------ include/nanobind/stl/complex.h | 9 ++++++++ src/nb_ndarray.cpp | 7 +++++++ tests/test_ndarray.cpp | 14 +++++++++++-- tests/test_ndarray.py | 25 +++++++++++++++++++++- 6 files changed, 93 insertions(+), 13 deletions(-) diff --git a/docs/ndarray.rst b/docs/ndarray.rst index 8d2c0508..8673d078 100644 --- a/docs/ndarray.rst +++ b/docs/ndarray.rst @@ -115,6 +115,10 @@ The following constraints are available - A scalar type (``float``, ``uint8_t``, etc.) constrains the representation of the ndarray. + Complex arrays (i.e., ones based on ``std::complex`` or + ``std::complex``) are supported but additionally require including + the header file ````. + - This scalar type can be further annotated with ``const``, which is necessary if you plan to call nanobind functions with arrays that do not permit write access. @@ -468,10 +472,11 @@ For example, the following snippet makes ``__fp16`` (half-precision type on namespace nanobind { template <> struct ndarray_traits<__fp16> { - static constexpr bool is_float = true; - static constexpr bool is_bool = false; - static constexpr bool is_int = false; - static constexpr bool is_signed = true; + static constexpr bool is_complex = false; + static constexpr bool is_float = true; + static constexpr bool is_bool = false; + static constexpr bool is_int = false; + static constexpr bool is_signed = true; }; }; diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h index 0b360219..7b9e6551 100644 --- a/include/nanobind/ndarray.h +++ b/include/nanobind/ndarray.h @@ -65,6 +65,13 @@ struct dltensor { NAMESPACE_END(dlpack) +NAMESPACE_BEGIN(detail) + +template +struct is_complex : public std::false_type { }; + +NAMESPACE_END(detail) + constexpr size_t any = (size_t) -1; template struct shape { @@ -81,10 +88,11 @@ struct jax { }; struct ro { }; template struct ndarray_traits { - static constexpr bool is_float = std::is_floating_point_v; - static constexpr bool is_bool = std::is_same_v, bool>; - static constexpr bool is_int = std::is_integral_v && !is_bool; - static constexpr bool is_signed = std::is_signed_v; + static constexpr bool is_complex = detail::is_complex::value; + static constexpr bool is_float = std::is_floating_point_v; + static constexpr bool is_bool = std::is_same_v, bool>; + static constexpr bool is_int = std::is_integral_v && !is_bool; + static constexpr bool is_signed = std::is_signed_v; }; NAMESPACE_BEGIN(detail) @@ -92,7 +100,7 @@ NAMESPACE_BEGIN(detail) template constexpr bool is_ndarray_scalar_v = ndarray_traits::is_float || ndarray_traits::is_int || - ndarray_traits::is_bool; + ndarray_traits::is_bool || ndarray_traits::is_complex; template struct ndim_shape; template struct ndim_shape> { @@ -115,6 +123,8 @@ template constexpr dlpack::dtype dtype() { result.code = (uint8_t) dlpack::dtype_code::Float; else if constexpr (ndarray_traits::is_signed) result.code = (uint8_t) dlpack::dtype_code::Int; + else if constexpr (ndarray_traits::is_complex) + result.code = (uint8_t) dlpack::dtype_code::Complex; else if constexpr (std::is_same_v, bool>) result.code = (uint8_t) dlpack::dtype_code::Bool; else @@ -163,6 +173,21 @@ template struct ndarray_arg::is_fl } }; +template struct ndarray_arg::is_complex>> { + static constexpr size_t size = 0; + + static constexpr auto name = + const_name("dtype=complex") + + const_name() + + const_name>(", writable=False", ""); + + static void apply(ndarray_req &tr) { + tr.dtype = dtype(); + tr.req_dtype = true; + tr.req_ro = std::is_const_v; + } +}; + template struct ndarray_arg::is_int>> { static constexpr size_t size = 0; @@ -253,7 +278,8 @@ template struct ndarray_info { template struct ndarray_info : ndarray_info { using scalar_type = std::conditional_t::is_float || ndarray_traits::is_int || - ndarray_traits::is_bool, T, typename ndarray_info::scalar_type>; + ndarray_traits::is_bool || ndarray_traits::is_complex, + T, typename ndarray_info::scalar_type>; }; template struct ndarray_info, Ts...> : ndarray_info { diff --git a/include/nanobind/stl/complex.h b/include/nanobind/stl/complex.h index d69b8467..a86e0f22 100644 --- a/include/nanobind/stl/complex.h +++ b/include/nanobind/stl/complex.h @@ -15,6 +15,15 @@ NAMESPACE_BEGIN(NB_NAMESPACE) NAMESPACE_BEGIN(detail) +template +struct is_complex; + +template +struct is_complex> : public std::true_type { }; + +template +struct is_complex> : public std::true_type { }; + template struct type_caster> { NB_TYPE_CASTER(std::complex, const_name("complex") ) diff --git a/src/nb_ndarray.cpp b/src/nb_ndarray.cpp index 5e4ff1d0..a2917b54 100644 --- a/src/nb_ndarray.cpp +++ b/src/nb_ndarray.cpp @@ -69,6 +69,13 @@ static int nd_ndarray_tpbuffer(PyObject *exporter, Py_buffer *view, int) { } break; + case dlpack::dtype_code::Complex: + switch (t.dtype.bits) { + case 64: format = "Zf"; break; + case 128: format = "Zd"; break; + } + break; + case dlpack::dtype_code::Bool: format = "?"; break; diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp index 092f6401..30b0f6cd 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -68,6 +69,7 @@ NB_MODULE(test_ndarray_ext, m) { }); m.def("pass_float32", [](const nb::ndarray &) { }, "array"_a.noconvert()); + m.def("pass_complex64", [](const nb::ndarray> &) { }, "array"_a.noconvert()); m.def("pass_uint32", [](const nb::ndarray &) { }, "array"_a.noconvert()); m.def("pass_bool", [](const nb::ndarray &) { }, "array"_a.noconvert()); m.def("pass_float32_shaped", @@ -119,10 +121,11 @@ NB_MODULE(test_ndarray_ext, m) { } printf("Tensor is on CPU? %i\n", ndarray.device_type() == nb::device::cpu::value); printf("Device ID = %u\n", ndarray.device_id()); - printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i\n", + printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i complex64=%i\n", ndarray.dtype() == nb::dtype(), ndarray.dtype() == nb::dtype(), - ndarray.dtype() == nb::dtype() + ndarray.dtype() == nb::dtype(), + ndarray.dtype() == nb::dtype>() ); }); @@ -261,6 +264,13 @@ NB_MODULE(test_ndarray_ext, m) { v(i, j) = (float) (i * 10 + j); }, "x"_a.noconvert()); + m.def("fill_view_5", [](nb::ndarray, nb::shape<2, 2>, nb::c_contig, nb::device::cpu> x) { + auto v = x.view(); + for (size_t i = 0; i < v.shape(0); ++i) + for (size_t j = 0; j < v.shape(1); ++j) + v(i, j) *= std::complex(-1.0f, 2.0f); + }, "x"_a.noconvert()); + #if defined(__aarch64__) m.def("ret_numpy_half", []() { __fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 2554787e..05c9fabc 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -71,6 +71,7 @@ def test02_docstr(): assert t.get_shape.__doc__ == "get_shape(array: ndarray[writable=False]) -> list" assert t.pass_uint32.__doc__ == "pass_uint32(array: ndarray[dtype=uint32]) -> None" assert t.pass_float32.__doc__ == "pass_float32(array: ndarray[dtype=float32]) -> None" + assert t.pass_complex64.__doc__ == "pass_complex64(array: ndarray[dtype=complex64]) -> None" assert t.pass_bool.__doc__ == "pass_bool(array: ndarray[dtype=bool]) -> None" assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(array: ndarray[dtype=float32, shape=(3, *, 4)]) -> None" assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(array: ndarray[dtype=float32, order='C', shape=(*, *, 4)]) -> None" @@ -82,10 +83,12 @@ def test02_docstr(): def test03_constrain_dtype(): a_u32 = np.array([1], dtype=np.uint32) a_f32 = np.array([1], dtype=np.float32) + a_cf64 = np.array([1+1j], dtype=np.complex64) a_bool = np.array([1], dtype=np.bool_) t.pass_uint32(a_u32) t.pass_float32(a_f32) + t.pass_complex64(a_cf64) t.pass_bool(a_bool) with pytest.raises(TypeError) as excinfo: @@ -96,6 +99,10 @@ def test03_constrain_dtype(): t.pass_float32(a_u32) assert 'incompatible function arguments' in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: + t.pass_complex64(a_u32) + assert 'incompatible function arguments' in str(excinfo.value) + with pytest.raises(TypeError) as excinfo: t.pass_bool(a_u32) assert 'incompatible function arguments' in str(excinfo.value) @@ -573,7 +580,7 @@ def test31_view(): t.fill_view_1(x2) assert np.allclose(x1, x2*2) - #2 + # 2 x1 = np.zeros((3, 4), dtype=np.float32, order='C') x2 = np.zeros((3, 4), dtype=np.float32, order='F') t.fill_view_2(x1) @@ -585,6 +592,15 @@ def test31_view(): assert np.all(x1 == x2) and np.all(x2 == x3) and np.all(x3 == x4) + # 3 + x1 = np.array([[1+2j, 3+4j], [5+6j, 7+8j]], dtype=np.complex64) + x2 = x1 * 2 + t.fill_view_1(x1.view(np.float32)) + assert np.allclose(x1, x2) + x2 = x1 * (-1+2j) + t.fill_view_5(x1) + assert np.allclose(x1, x2) + @needs_numpy def test32_half(): if not hasattr(t, 'ret_numpy_half'): @@ -601,3 +617,10 @@ def test33_cast(): assert a.ndim == 0 and b.ndim == 0 assert a.dtype == np.int32 and b.dtype == np.float32 assert a == 1 and b == 1 + +@needs_numpy +def test34_complex_decompose(): + x1 = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) + + assert np.all(x1.real == np.array([1, 3, 5], dtype=np.float32)) + assert np.all(x1.imag == np.array([2, 4, 6], dtype=np.float32))