From 9ad1bb5a53087776a8a680e0dbc19cffb0a9015e Mon Sep 17 00:00:00 2001 From: Anton <100830759+antonwolfy@users.noreply.github.com> Date: Mon, 20 Jan 2025 12:29:26 +0100 Subject: [PATCH] Implement `__usm_ndarray_`_ protocol (#2261) The PR is intended to adopt to dpctl changes implemented in [dpctl#1959](https://github.com/IntelPython/dpctl/pull/1959). It implements support of `__usm_ndarray__` protocol for `dpnp.ndarray` and returns a property with `dpctl.tensor.usm_ndarray` instance corresponding to the content of the array object. This property is intended to speed-up conversion from `dpnp.ndarray` to `dpt.usm_ndarray` in `x=dpt.asarray(dpnp_array_obj)`. The input object that implements `__usm_ndarray__` is recognized as owner of USM allocation that is managed by a smart pointer, and asynchronous deallocation of `x` need not involve GIL. --- dpnp/dpnp_array.py | 19 +++++++++++++++++++ dpnp/tests/test_ndarray.py | 12 ++++++++++++ 2 files changed, 31 insertions(+) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index ff28b0c4256..8b5b12fa865 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -605,6 +605,25 @@ def __truediv__(self, other): """Return ``self/value``.""" return dpnp.true_divide(self, other) + @property + def __usm_ndarray__(self): + """ + Property to support `__usm_ndarray__` protocol. + + It assumes to return :class:`dpctl.tensor.usm_ndarray` instance + corresponding to the content of the object. + + This property is intended to speed-up conversion from + :class:`dpnp.ndarray` to :class:`dpctl.tensor.usm_ndarray` passed + into `dpctl.tensor.asarray` function. The input object that implements + `__usm_ndarray__` protocol is recognized as owner of USM allocation + that is managed by a smart pointer, and asynchronous deallocation + will not involve GIL. + + """ + + return self._array_obj + def __xor__(self, other): """Return ``self^value``.""" return dpnp.bitwise_xor(self, other) diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index a184af6ba22..0dfbf6c1e1c 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -176,6 +176,18 @@ def test_error(self): ia.item() +class TestUsmNdarrayProtocol: + def test_basic(self): + a = dpnp.arange(256, dtype=dpnp.int64) + usm_a = dpt.asarray(a) + + assert a.sycl_queue == usm_a.sycl_queue + assert a.usm_type == usm_a.usm_type + assert a.dtype == usm_a.dtype + assert usm_a.usm_data.reference_obj is None + assert (a == usm_a).all() + + def test_print_dpnp_int(): result = repr(dpnp.array([1, 0, 2, -3, -1, 2, 21, -9], dtype="i4")) expected = "array([ 1, 0, 2, -3, -1, 2, 21, -9], dtype=int32)"