Skip to content

Commit

Permalink
Complex number support for nb::ndarray. (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
luigifcruz authored and wjakob committed Oct 18, 2023
1 parent fea4849 commit 6cbd138
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 13 deletions.
13 changes: 9 additions & 4 deletions docs/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ The following constraints are available
- A scalar type (``float``, ``uint8_t``, etc.) constrains the representation
of the ndarray.

Complex arrays (i.e., ones based on ``std::complex<float>`` or
``std::complex<double>``) are supported but additionally require including
the header file ``<nanobind/stl/complex.h>``.

- This scalar type can be further annotated with ``const``, which is necessary
if you plan to call nanobind functions with arrays that do not permit write
access.
Expand Down Expand Up @@ -468,10 +472,11 @@ For example, the following snippet makes ``__fp16`` (half-precision type on
namespace nanobind {
template <> struct ndarray_traits<__fp16> {
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
static constexpr bool is_complex = false;
static constexpr bool is_float = true;
static constexpr bool is_bool = false;
static constexpr bool is_int = false;
static constexpr bool is_signed = true;
};
};
Expand Down
38 changes: 32 additions & 6 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ struct dltensor {

NAMESPACE_END(dlpack)

NAMESPACE_BEGIN(detail)

template <typename T>
struct is_complex : public std::false_type { };

NAMESPACE_END(detail)

constexpr size_t any = (size_t) -1;

template <size_t... Is> struct shape {
Expand All @@ -81,18 +88,19 @@ struct jax { };
struct ro { };

template <typename T> struct ndarray_traits {
static constexpr bool is_float = std::is_floating_point_v<T>;
static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
static constexpr bool is_int = std::is_integral_v<T> && !is_bool;
static constexpr bool is_signed = std::is_signed_v<T>;
static constexpr bool is_complex = detail::is_complex<T>::value;
static constexpr bool is_float = std::is_floating_point_v<T>;
static constexpr bool is_bool = std::is_same_v<std::remove_cv_t<T>, bool>;
static constexpr bool is_int = std::is_integral_v<T> && !is_bool;
static constexpr bool is_signed = std::is_signed_v<T>;
};

NAMESPACE_BEGIN(detail)

template <typename T>
constexpr bool is_ndarray_scalar_v =
ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool;
ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex;

template <typename> struct ndim_shape;
template <size_t... S> struct ndim_shape<std::index_sequence<S...>> {
Expand All @@ -115,6 +123,8 @@ template <typename T> constexpr dlpack::dtype dtype() {
result.code = (uint8_t) dlpack::dtype_code::Float;
else if constexpr (ndarray_traits<T>::is_signed)
result.code = (uint8_t) dlpack::dtype_code::Int;
else if constexpr (ndarray_traits<T>::is_complex)
result.code = (uint8_t) dlpack::dtype_code::Complex;
else if constexpr (std::is_same_v<std::remove_cv_t<T>, bool>)
result.code = (uint8_t) dlpack::dtype_code::Bool;
else
Expand Down Expand Up @@ -163,6 +173,21 @@ template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_fl
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_complex>> {
static constexpr size_t size = 0;

static constexpr auto name =
const_name("dtype=complex") +
const_name<sizeof(T) * 8>() +
const_name<std::is_const_v<T>>(", writable=False", "");

static void apply(ndarray_req &tr) {
tr.dtype = dtype<T>();
tr.req_dtype = true;
tr.req_ro = std::is_const_v<T>;
}
};

template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_int>> {
static constexpr size_t size = 0;

Expand Down Expand Up @@ -253,7 +278,8 @@ template <typename... Ts> struct ndarray_info {
template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
using scalar_type =
std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool, T, typename ndarray_info<Ts...>::scalar_type>;
ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex,
T, typename ndarray_info<Ts...>::scalar_type>;
};

template <size_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...> : ndarray_info<Ts...> {
Expand Down
9 changes: 9 additions & 0 deletions include/nanobind/stl/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
NAMESPACE_BEGIN(NB_NAMESPACE)
NAMESPACE_BEGIN(detail)

template <typename T>
struct is_complex;

template <typename T>
struct is_complex<std::complex<T>> : public std::true_type { };

template <typename T>
struct is_complex<const std::complex<T>> : public std::true_type { };

template <typename T> struct type_caster<std::complex<T>> {
NB_TYPE_CASTER(std::complex<T>, const_name("complex") )

Expand Down
7 changes: 7 additions & 0 deletions src/nb_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ static int nd_ndarray_tpbuffer(PyObject *exporter, Py_buffer *view, int) {
}
break;

case dlpack::dtype_code::Complex:
switch (t.dtype.bits) {
case 64: format = "Zf"; break;
case 128: format = "Zd"; break;
}
break;

case dlpack::dtype_code::Bool:
format = "?";
break;
Expand Down
14 changes: 12 additions & 2 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>
#include <algorithm>
#include <vector>

Expand Down Expand Up @@ -68,6 +69,7 @@ NB_MODULE(test_ndarray_ext, m) {
});

m.def("pass_float32", [](const nb::ndarray<float> &) { }, "array"_a.noconvert());
m.def("pass_complex64", [](const nb::ndarray<std::complex<float>> &) { }, "array"_a.noconvert());
m.def("pass_uint32", [](const nb::ndarray<uint32_t> &) { }, "array"_a.noconvert());
m.def("pass_bool", [](const nb::ndarray<bool> &) { }, "array"_a.noconvert());
m.def("pass_float32_shaped",
Expand Down Expand Up @@ -119,10 +121,11 @@ NB_MODULE(test_ndarray_ext, m) {
}
printf("Tensor is on CPU? %i\n", ndarray.device_type() == nb::device::cpu::value);
printf("Device ID = %u\n", ndarray.device_id());
printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i\n",
printf("Tensor dtype check: int16=%i, uint32=%i, float32=%i complex64=%i\n",
ndarray.dtype() == nb::dtype<int16_t>(),
ndarray.dtype() == nb::dtype<uint32_t>(),
ndarray.dtype() == nb::dtype<float>()
ndarray.dtype() == nb::dtype<float>(),
ndarray.dtype() == nb::dtype<std::complex<float>>()
);
});

Expand Down Expand Up @@ -261,6 +264,13 @@ NB_MODULE(test_ndarray_ext, m) {
v(i, j) = (float) (i * 10 + j);
}, "x"_a.noconvert());

m.def("fill_view_5", [](nb::ndarray<std::complex<float>, nb::shape<2, 2>, nb::c_contig, nb::device::cpu> x) {
auto v = x.view();
for (size_t i = 0; i < v.shape(0); ++i)
for (size_t j = 0; j < v.shape(1); ++j)
v(i, j) *= std::complex<float>(-1.0f, 2.0f);
}, "x"_a.noconvert());

#if defined(__aarch64__)
m.def("ret_numpy_half", []() {
__fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
Expand Down
25 changes: 24 additions & 1 deletion tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test02_docstr():
assert t.get_shape.__doc__ == "get_shape(array: ndarray[writable=False]) -> list"
assert t.pass_uint32.__doc__ == "pass_uint32(array: ndarray[dtype=uint32]) -> None"
assert t.pass_float32.__doc__ == "pass_float32(array: ndarray[dtype=float32]) -> None"
assert t.pass_complex64.__doc__ == "pass_complex64(array: ndarray[dtype=complex64]) -> None"
assert t.pass_bool.__doc__ == "pass_bool(array: ndarray[dtype=bool]) -> None"
assert t.pass_float32_shaped.__doc__ == "pass_float32_shaped(array: ndarray[dtype=float32, shape=(3, *, 4)]) -> None"
assert t.pass_float32_shaped_ordered.__doc__ == "pass_float32_shaped_ordered(array: ndarray[dtype=float32, order='C', shape=(*, *, 4)]) -> None"
Expand All @@ -82,10 +83,12 @@ def test02_docstr():
def test03_constrain_dtype():
a_u32 = np.array([1], dtype=np.uint32)
a_f32 = np.array([1], dtype=np.float32)
a_cf64 = np.array([1+1j], dtype=np.complex64)
a_bool = np.array([1], dtype=np.bool_)

t.pass_uint32(a_u32)
t.pass_float32(a_f32)
t.pass_complex64(a_cf64)
t.pass_bool(a_bool)

with pytest.raises(TypeError) as excinfo:
Expand All @@ -96,6 +99,10 @@ def test03_constrain_dtype():
t.pass_float32(a_u32)
assert 'incompatible function arguments' in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
t.pass_complex64(a_u32)
assert 'incompatible function arguments' in str(excinfo.value)

with pytest.raises(TypeError) as excinfo:
t.pass_bool(a_u32)
assert 'incompatible function arguments' in str(excinfo.value)
Expand Down Expand Up @@ -573,7 +580,7 @@ def test31_view():
t.fill_view_1(x2)
assert np.allclose(x1, x2*2)

#2
# 2
x1 = np.zeros((3, 4), dtype=np.float32, order='C')
x2 = np.zeros((3, 4), dtype=np.float32, order='F')
t.fill_view_2(x1)
Expand All @@ -585,6 +592,15 @@ def test31_view():

assert np.all(x1 == x2) and np.all(x2 == x3) and np.all(x3 == x4)

# 3
x1 = np.array([[1+2j, 3+4j], [5+6j, 7+8j]], dtype=np.complex64)
x2 = x1 * 2
t.fill_view_1(x1.view(np.float32))
assert np.allclose(x1, x2)
x2 = x1 * (-1+2j)
t.fill_view_5(x1)
assert np.allclose(x1, x2)

@needs_numpy
def test32_half():
if not hasattr(t, 'ret_numpy_half'):
Expand All @@ -601,3 +617,10 @@ def test33_cast():
assert a.ndim == 0 and b.ndim == 0
assert a.dtype == np.int32 and b.dtype == np.float32
assert a == 1 and b == 1

@needs_numpy
def test34_complex_decompose():
x1 = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64)

assert np.all(x1.real == np.array([1, 3, 5], dtype=np.float32))
assert np.all(x1.imag == np.array([2, 4, 6], dtype=np.float32))

0 comments on commit 6cbd138

Please sign in to comment.