Skip to content

Commit

Permalink
Return named tuple for eig, eigh, qr, slogdet, svd functions
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy committed Jan 23, 2025
1 parent d522480 commit 30a65a7
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 38 deletions.
19 changes: 18 additions & 1 deletion dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@
# pylint: disable=invalid-name
# pylint: disable=no-member

from typing import NamedTuple

import numpy
from dpctl.tensor._numpy_helper import normalize_axis_tuple

import dpnp

from .dpnp_utils_linalg import (
EighResult,
QRResult,
SlogdetResult,
SVDResult,
assert_2d,
assert_stacked_2d,
assert_stacked_square,
Expand All @@ -66,6 +72,11 @@
)

__all__ = [
"EigResult",
"EighResult",
"QRResult",
"SlogdetResult",
"SVDResult",
"cholesky",
"cond",
"cross",
Expand Down Expand Up @@ -100,6 +111,12 @@
]


# pylint:disable=missing-class-docstring
class EigResult(NamedTuple):
eigenvalues: dpnp.ndarray
eigenvectors: dpnp.ndarray


def cholesky(a, /, *, upper=False):
"""
Cholesky decomposition.
Expand Down Expand Up @@ -532,7 +549,7 @@ def eig(a):
# Since geev function from OneMKL LAPACK is not implemented yet,
# use NumPy for this calculation.
w_np, v_np = numpy.linalg.eig(dpnp.asnumpy(a))
return (
return EigResult(
dpnp.array(w_np, sycl_queue=a_sycl_queue, usm_type=a_usm_type),
dpnp.array(v_np, sycl_queue=a_sycl_queue, usm_type=a_usm_type),
)
Expand Down
108 changes: 71 additions & 37 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
# pylint: disable=protected-access
# pylint: disable=useless-import-alias

from typing import NamedTuple

import dpctl.tensor._tensor_impl as ti
import dpctl.utils as dpu
import numpy
Expand All @@ -50,6 +52,10 @@
from dpnp.linalg import LinAlgError as LinAlgError

__all__ = [
"EighResult",
"QRResult",
"SlogdetResult",
"SVDResult",
"assert_2d",
"assert_stacked_2d",
"assert_stacked_square",
Expand All @@ -70,6 +76,29 @@
"dpnp_svd",
]


# pylint:disable=missing-class-docstring
class EighResult(NamedTuple):
eigenvalues: dpnp.ndarray
eigenvectors: dpnp.ndarray


class QRResult(NamedTuple):
Q: dpnp.ndarray
R: dpnp.ndarray


class SlogdetResult(NamedTuple):
sign: dpnp.ndarray
logabsdet: dpnp.ndarray


class SVDResult(NamedTuple):
U: dpnp.ndarray
S: dpnp.ndarray
Vh: dpnp.ndarray


_jobz = {"N": 0, "V": 1}
_upper_lower = {"U": 0, "L": 1}

Expand Down Expand Up @@ -162,7 +191,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
# Convert to contiguous to align with NumPy
if a_orig_order == "C":
v = dpnp.ascontiguousarray(v)
return w, v
return EighResult(w, v)
return w


Expand Down Expand Up @@ -476,7 +505,7 @@ def _batched_qr(a, mode="reduced"):

r = _triu_inplace(r)

return (
return QRResult(
q.reshape(batch_shape + q.shape[-2:]),
r.reshape(batch_shape + r.shape[-2:]),
)
Expand Down Expand Up @@ -632,7 +661,7 @@ def _batched_svd(
u = dpnp.ascontiguousarray(u)
vt = dpnp.ascontiguousarray(vt)
# Swap `u` and `vt` for transposed input to restore correct order
return (vt, s, u) if trans_flag else (u, s, vt)
return SVDResult(vt, s, u) if trans_flag else SVDResult(u, s, vt)
return s


Expand Down Expand Up @@ -819,9 +848,9 @@ def _hermitian_svd(a, compute_uv):
# but dpnp.linalg.eigh returns s sorted ascending so we re-order
# the eigenvalues and related arrays to have the correct order
if compute_uv:
s, u = dpnp.linalg.eigh(a)
sgn = dpnp.sign(s)
s = dpnp.absolute(s)
s, u = s = dpnp_eigh(a, eigen_mode="V")
sgn = dpnp.sign(s, out=s)
s = dpnp.abs(s, out=s)
sidx = dpnp.argsort(s)[..., ::-1]
# Rearrange the signs according to sorted indices
sgn = dpnp.take_along_axis(sgn, sidx, axis=-1)
Expand All @@ -832,11 +861,10 @@ def _hermitian_svd(a, compute_uv):
# Singular values are unsigned, move the sign into v
# Compute V^T adjusting for the sign and conjugating
vt = dpnp.transpose(u * sgn[..., None, :]).conjugate()
return u, s, vt
return SVDResult(u, s, vt)

# TODO: use dpnp.linalg.eighvals when it is updated
s, _ = dpnp.linalg.eigh(a)
s = dpnp.abs(s)
s = dpnp_eigh(a, eigen_mode="N")
s = dpnp.abs(s, out=s)
return dpnp.sort(s)[..., ::-1]


Expand Down Expand Up @@ -1423,7 +1451,7 @@ def _zero_batched_qr(a, mode, m, n, k, res_type):
batch_shape = a.shape[:-2]

if mode == "reduced":
return (
return QRResult(
dpnp.empty_like(
a,
shape=batch_shape + (m, k),
Expand All @@ -1443,7 +1471,7 @@ def _zero_batched_qr(a, mode, m, n, k, res_type):
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
return (
return QRResult(
q,
dpnp.empty_like(
a,
Expand Down Expand Up @@ -1530,7 +1558,7 @@ def _zero_batched_svd(
usm_type=usm_type,
sycl_queue=exec_q,
)
return u, s, vt
return SVDResult(u, s, vt)
return s


Expand All @@ -1548,22 +1576,28 @@ def _zero_k_qr(a, mode, m, n, res_type):
m, n = a.shape

if mode == "reduced":
return dpnp.empty_like(
a,
shape=(m, 0),
dtype=res_type,
), dpnp.empty_like(
a,
shape=(0, n),
dtype=res_type,
return QRResult(
dpnp.empty_like(
a,
shape=(m, 0),
dtype=res_type,
),
dpnp.empty_like(
a,
shape=(0, n),
dtype=res_type,
),
)
if mode == "complete":
return dpnp.identity(
m, dtype=res_type, sycl_queue=a_sycl_queue, usm_type=a_usm_type
), dpnp.empty_like(
a,
shape=(m, n),
dtype=res_type,
return QRResult(
dpnp.identity(
m, dtype=res_type, sycl_queue=a_sycl_queue, usm_type=a_usm_type
),
dpnp.empty_like(
a,
shape=(m, n),
dtype=res_type,
),
)
if mode == "r":
return dpnp.empty_like(
Expand Down Expand Up @@ -1648,7 +1682,7 @@ def _zero_m_n_batched_svd(
usm_type=usm_type,
sycl_queue=exec_q,
)
return u, s, vt
return SVDResult(u, s, vt)
return s


Expand Down Expand Up @@ -1692,7 +1726,7 @@ def _zero_m_n_svd(
usm_type=usm_type,
sycl_queue=exec_q,
)
return u, s, vt
return SVDResult(u, s, vt)
return s


Expand Down Expand Up @@ -1993,7 +2027,7 @@ def dpnp_det(a):
return det.reshape(shape)


def dpnp_eigh(a, UPLO, eigen_mode="V"):
def dpnp_eigh(a, UPLO="L", eigen_mode="V"):
"""
dpnp_eigh(a, UPLO, eigen_mode="V")
Expand All @@ -2016,7 +2050,7 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
w = dpnp.empty_like(a, shape=a.shape[:-1], dtype=w_type)
if eigen_mode == "V":
v = dpnp.empty_like(a, dtype=v_type)
return w, v
return EighResult(w, v)
return w

if a.ndim > 2:
Expand Down Expand Up @@ -2097,7 +2131,7 @@ def dpnp_eigh(a, UPLO, eigen_mode="V"):
else:
out_v = v

return (w, out_v) if eigen_mode == "V" else w
return EighResult(w, out_v) if eigen_mode == "V" else w


def dpnp_inv(a):
Expand Down Expand Up @@ -2546,7 +2580,7 @@ def dpnp_qr(a, mode="reduced"):
r = a_t[:, :mc].transpose()

r = _triu_inplace(r)
return (q, r)
return QRResult(q, r)


def dpnp_solve(a, b):
Expand Down Expand Up @@ -2675,7 +2709,7 @@ def dpnp_slogdet(a):
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
return sign, logdet
return SlogdetResult(sign, logdet)

lu, ipiv, dev_info = _lu_factor(a, res_type)

Expand All @@ -2687,7 +2721,7 @@ def dpnp_slogdet(a):

logdet = logdet.astype(logdet_dtype, copy=False)
singular = dev_info > 0
return (
return SlogdetResult(
dpnp.where(singular, res_type.type(0), sign).reshape(shape),
dpnp.where(singular, logdet_dtype.type("-inf"), logdet).reshape(shape),
)
Expand Down Expand Up @@ -2815,10 +2849,10 @@ def dpnp_svd(
# For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
# Transpose and swap them back to restore correct order for A.
if trans_flag:
return vt_h.T, s_h, u_h.T
return SVDResult(vt_h.T, s_h, u_h.T)
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# Convert to contiguous to align with NumPy
u_h = dpnp.ascontiguousarray(u_h)
vt_h = dpnp.ascontiguousarray(vt_h)
return u_h, s_h, vt_h
return SVDResult(u_h, s_h, vt_h)
return s_h

0 comments on commit 30a65a7

Please sign in to comment.