Skip to content

Commit

Permalink
Use dpctl.tensor.matmul in the backend of dpnp.matmul when inputs…
Browse files Browse the repository at this point in the history
… are integer (#2296)

resolves #2270 

OneMath (OneMKL) routines (`gemm`, `gemv`, `gemm_batch`) for matrix
multiplication only support floating point data types. If inputs are
integer, to use OneMath we need to upcast them to floating point
dtypes, perform the calculation, and then cast the result back to integer
dtypes, which is unsafe and may lose some information for large
integers.
In this PR, the logic for `dpnp.matmul` is updated to use
`dpctl.tensor.matmul` when the result has an integer dtype.
  • Loading branch information
vtavana authored Feb 7, 2025
1 parent 0c455a6 commit db97d59
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 269 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,10 @@ array_api_tests/test_linalg.py::test_svd
array_api_tests/test_linalg.py::test_qr
array_api_tests/test_operators_and_elementwise_functions.py::test_clip

# unexpected result is returned
# unexpected result is returned - unmute when dpctl-1986 is resolved
array_api_tests/test_operators_and_elementwise_functions.py::test_asin
array_api_tests/test_operators_and_elementwise_functions.py::test_asinh

# missing 'correction' keyword argument
array_api_tests/test_signatures.py::test_func_signature[std]
array_api_tests/test_signatures.py::test_func_signature[var]

# arrays have different values
array_api_tests/test_linalg.py::test_linalg_tensordot
2 changes: 1 addition & 1 deletion .github/workflows/check-mkl-interfaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ jobs:
id: run_tests
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 12
timeout_minutes: 15
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ jobs:
id: run_tests_linux
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 12
timeout_minutes: 15
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
Expand Down Expand Up @@ -355,7 +355,7 @@ jobs:
id: run_tests_win
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 15
timeout_minutes: 17
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/cron-run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:
id: run_tests_linux
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 12
timeout_minutes: 15
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
Expand All @@ -143,7 +143,7 @@ jobs:
id: run_tests_win
uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
with:
timeout_minutes: 15
timeout_minutes: 17
max_attempts: ${{ env.RUN_TESTS_MAX_ATTEMPTS }}
retry_on: any
command: |
Expand Down
17 changes: 10 additions & 7 deletions dpnp/backend/extensions/blas/blas_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,18 @@ PYBIND11_MODULE(_blas_impl, m)
py::arg("sycl_queue"), py::arg("matrixA"), py::arg("vectorX"),
py::arg("vectorY"), py::arg("transpose"),
py::arg("depends") = py::list());
}

{
m.def(
"_row_major_is_available",
[](void) {
#if defined(USE_ONEMKL_CUBLAS)
return false;
#else
"_using_onemkl_interfaces",
[]() {
#ifdef USE_ONEMKL_INTERFACES
return true;
#endif // USE_ONEMKL_CUBLAS
#else
return false;
#endif
},
"Check if the onemkl::blas::row_major can be used.");
"Check if the OneMKL interfaces are being used.");
}
}
209 changes: 121 additions & 88 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,23 @@
]


def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
def _compute_res_dtype(*arrays, sycl_queue, dtype=None, out=None, casting="no"):
"""
Determines the output array data type and an intermediate data type
used in performing calculations related to a specific math function.
If dtype is ``None``, the output array data type of the operation is
determined based on the Promotion Type Rule and device capabilities.
Otherwise, `dtype` is used as output array dtype, if input arrays
can cast to it according to the casting rule determined. If casting
cannot be done, a ``TypeError`` is raised.
The intermediate data type is the data type used for performing the math
function calculations. If output array dtype is a floating-point data type,
it is also used for the intermediate data type. If output array dtype is an
integral data type, the default floating point data type of the device where
input arrays are allocated on are used for intermediate data type.
Determines the output array data type.
If `dtype` and `out` are ``None``, the output array data type of the
operation is determined based on the Promotion Type Rule and device
capabilities. If `out` is given, its data type is used as the output
array dtype. Otherwise, `dtype` is used as the output array dtype.
If input arrays cannot be cast to the determined output array dtype,
a ``TypeError`` is raised.
Parameters
----------
arrays : {dpnp.ndarray, usm_ndarray}
Input arrays.
dtype : dtype
If not ``None`` and `out` is not defined, data type of the output array.
out : {dpnp.ndarray, usm_ndarray}
If not ``None``, data type of the output array.
casting : {"no", "equiv", "safe", "same_kind", "unsafe"}, optional
Controls what kind of data casting may occur.
Expand All @@ -78,17 +75,23 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
Returns
-------
compute_dtype, res_dtype :
`compute_dtype` is the data type used in performing math function calculations.
The input arrays of the math function are cast to `compute_dtype` and then
the calculations are performed.
`res_dtype` is the output data type. When the result is obtained, it is cast
to `res_dtype`.
res_dtype :
`res_dtype` is the output data type. When the result is obtained,
it is cast to `res_dtype`.
"""

res_dtype = dpnp.result_type(*arrays)
default_dtype = dpnp.default_float_type(sycl_queue=sycl_queue)

# If inputs are boolean and `out` is given and it is not boolean, the
# calculation should be performed in boolean and at the end the result
# is cast to out dtype. It is different than general case where the inputs
# are cast to out dtype and then calculation is performed. Even when inputs
# are boolean and `dtype` is given, the casting is done first and then the
# calculation is performed.
if out is not None and res_dtype != dpnp.bool:
# out dtype is prioritized over a given dtype
dtype = out.dtype

if dtype is not None:
if dpnp.can_cast(res_dtype, dtype, casting=casting):
Expand All @@ -98,11 +101,7 @@ def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
f"Cannot cast from dtype({res_dtype}) to dtype({dtype}) with casting rule {casting}"
)

compute_dtype = (
res_dtype if dpnp.issubdtype(res_dtype, dpnp.inexact) else default_dtype
)

return compute_dtype, res_dtype
return res_dtype


def _copy_array(x, copy_flag=False, dtype=None, order="C"):
Expand Down Expand Up @@ -504,6 +503,23 @@ def _gemm_matmul(exec_q, x1, x2, res):
return res


def _gemm_special_case(x1, x2, res_dtype, call_flag):
    """
    Return ``True`` when the requested operation is one of the special
    integer-dtype cases that `gemm` and `gemm_batch` can handle directly;
    `gemv` has no such support.
    """
    # TODO: replace with dpnp.int8 when it is added
    inputs_are_int8 = x1.dtype == numpy.int8 and x2.dtype == numpy.int8
    result_dtype_supported = res_dtype in [dpnp.int32, dpnp.float32]
    gemm_like_call = call_flag in ["gemm", "gemm_batch"]
    eligible = inputs_are_int8 and result_dtype_supported and gemm_like_call

    # onemkl_interfaces does not support these data types
    using_onemkl_interfaces = bi._using_onemkl_interfaces()

    return eligible and not using_onemkl_interfaces


def _shape_error(shape1, shape2, func, err_msg):
"""Validate the shapes of input and output arrays."""

Expand Down Expand Up @@ -749,17 +765,19 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
_validate_out_array(out, exec_q)

# Determine the appropriate data types
dot_dtype, res_dtype = _compute_res_dtype(a, b, sycl_queue=exec_q)
res_dtype = _compute_res_dtype(
a, b, out=out, casting=casting, sycl_queue=exec_q
)

result = _create_result_array(
a, b, out, (), dot_dtype, res_usm_type, exec_q
a, b, out, (), res_dtype, res_usm_type, exec_q
)

# input arrays should have the proper data type
if dpnp.issubdtype(res_dtype, dpnp.inexact):
# copying is needed if dtypes of input arrays are different
a = _copy_array(a, dtype=dot_dtype)
b = _copy_array(b, dtype=dot_dtype)
a = _copy_array(a, dtype=res_dtype)
b = _copy_array(b, dtype=res_dtype)

_manager = dpu.SequentialOrderManager[exec_q]

Expand All @@ -777,14 +795,11 @@ def dpnp_dot(a, b, /, out=None, *, casting="same_kind", conjugate=False):
)
_manager.add_event_pair(ht_ev, dot_ev)
else:
# oneapi::mkl::blas::dot is slow for integer data type,
# oneapi::mkl::blas::dot does not support integer dtypes,
# so using dpctl.tensor.vecdot instead
dpt_a = dpnp.get_usm_ndarray(a)
dpt_b = dpnp.get_usm_ndarray(b)
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(dpt_a, dpt_b))

if dot_dtype != res_dtype:
result = result.astype(res_dtype, copy=False)
a_usm = dpnp.get_usm_ndarray(a)
b_usm = dpnp.get_usm_ndarray(b)
result = dpnp_array._create_from_usm_ndarray(dpt.vecdot(a_usm, b_usm))

return dpnp.get_result_array(result, out, casting=casting)

Expand Down Expand Up @@ -902,8 +917,8 @@ def dpnp_multiplication(
axes_res = normalize_axis_tuple(axes_res, len(result_shape), "axes")

# Determine the appropriate data types
compute_dtype, res_dtype = _compute_res_dtype(
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
res_dtype = _compute_res_dtype(
x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q
)

call_flag = None
Expand Down Expand Up @@ -998,7 +1013,7 @@ def dpnp_multiplication(
x2,
out,
res_shape,
compute_dtype,
res_dtype,
res_usm_type,
exec_q,
res_order,
Expand All @@ -1010,64 +1025,82 @@ def dpnp_multiplication(
elif x1.size == 0 or x2.size == 0:
result.fill(0)
else:
# input arrays should have the proper data type and
# their base (last 2-dimensions) to be c-contiguous or f-contiguous
x1 = _copy_array(
x1,
copy_flag=not x1_contig_flag,
dtype=compute_dtype,
order=res_order,
)
x2 = _copy_array(
x2,
copy_flag=not x2_contig_flag,
dtype=compute_dtype,
order=res_order,
)

if call_flag == "gemv":
if transpose:
a_usm = dpnp.get_usm_ndarray(x2)
x_usm = dpnp.get_usm_ndarray(x1)
else:
a_usm = dpnp.get_usm_ndarray(x1)
x_usm = dpnp.get_usm_ndarray(x2)

_manager = dpu.SequentialOrderManager[exec_q]

ht_ev, gemv_ev = bi._gemv(
exec_q,
a_usm,
x_usm,
dpnp.get_usm_ndarray(result),
transpose,
depends=_manager.submitted_events,
if _gemm_special_case(x1, x2, res_dtype, call_flag):
x1 = _copy_array(
x1, copy_flag=not x1_contig_flag, order=res_order
)
_manager.add_event_pair(ht_ev, gemv_ev)
elif call_flag == "gemm":
result = _gemm_matmul(
exec_q,
x1,
x2,
result,
x2 = _copy_array(
x2, copy_flag=not x2_contig_flag, order=res_order
)
else: # call_flag == "gemm_batch"
assert call_flag == "gemm_batch"
result = _gemm_batch_matmul(
exec_q,
if call_flag == "gemm":
result = _gemm_matmul(exec_q, x1, x2, result)
else:
assert call_flag == "gemm_batch"
result = _gemm_batch_matmul(exec_q, x1, x2, result)
elif dpnp.issubdtype(res_dtype, dpnp.inexact):
# copying is needed if dtypes of input arrays are different or
# their base (last 2-dimensions) is not c-contiguous or f-contiguous
x1 = _copy_array(
x1,
copy_flag=not x1_contig_flag,
dtype=res_dtype,
order=res_order,
)
x2 = _copy_array(
x2,
result,
copy_flag=not x2_contig_flag,
dtype=res_dtype,
order=res_order,
)

if call_flag == "gemv":
if transpose:
a_usm = dpnp.get_usm_ndarray(x2)
x_usm = dpnp.get_usm_ndarray(x1)
else:
a_usm = dpnp.get_usm_ndarray(x1)
x_usm = dpnp.get_usm_ndarray(x2)

_manager = dpu.SequentialOrderManager[exec_q]

ht_ev, gemv_ev = bi._gemv(
exec_q,
a_usm,
x_usm,
dpnp.get_usm_ndarray(result),
transpose,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, gemv_ev)
elif call_flag == "gemm":
result = _gemm_matmul(exec_q, x1, x2, result)
else:
assert call_flag == "gemm_batch"
result = _gemm_batch_matmul(exec_q, x1, x2, result)
else:
# oneapi::mkl::blas::gemm/gemv do not support integer dtypes,
# except for special cases determined in `_gemm_special_case`,
# use dpctl.tensor.matmul for unsupported cases

# `dpt.matmul` does not support `casting` kwarg.
# We may need to change input dtypes based on given `casting`.
# The possibility of casting is already validated in
# `_compute_res_dtype`.
x1 = _copy_array(x1, dtype=res_dtype, order=res_order)
x2 = _copy_array(x2, dtype=res_dtype, order=res_order)

x1_usm = dpnp.get_usm_ndarray(x1)
x2_usm = dpnp.get_usm_ndarray(x2)
out_usm = dpnp.get_usm_ndarray(result)
dpt.matmul(
x1_usm, x2_usm, out=out_usm, dtype=dtype, order=order
)

if NumPy_special_case:
result = dpnp.tile(result, out.shape)
elif res_shape != result_shape:
result = dpnp.reshape(result, result_shape)

if compute_dtype != res_dtype:
result = dpnp.astype(result, res_dtype, copy=False)

if out is None:
if axes is not None:
# Move the data back to the appropriate axes of the result array
Expand Down Expand Up @@ -1207,8 +1240,8 @@ def dpnp_vecdot(
)

# Determine the appropriate data types
_, res_dtype = _compute_res_dtype(
x1, x2, dtype=dtype, casting=casting, sycl_queue=exec_q
res_dtype = _compute_res_dtype(
x1, x2, dtype=dtype, out=out, casting=casting, sycl_queue=exec_q
)

_, x1_is_1D, _ = _define_dim_flags(x1, axis=-1)
Expand Down
Loading

0 comments on commit db97d59

Please sign in to comment.