Skip to content

Commit

Permalink
lazy import cuBLAS
Browse files: browse the repository at this point in the history
  • Loading branch information
kmaehashi committed Oct 10, 2023
1 parent 9f819ae commit 904d1c0
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 13 deletions.
6 changes: 5 additions & 1 deletion cupy/_core/_routines_linalg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ from cupy._core cimport _routines_manipulation as _manipulation
from cupy._core cimport _routines_math as _math
from cupy.cuda cimport device
from cupy_backends.cuda.api cimport runtime
from cupy_backends.cuda.libs cimport cublas


cdef extern from '../../cupy_backends/cupy_complex.h':
Expand Down Expand Up @@ -543,6 +542,8 @@ cpdef _ndarray_base dot(
cpdef _ndarray_base tensordot_core(
_ndarray_base a, _ndarray_base b, _ndarray_base out, Py_ssize_t n,
Py_ssize_t m, Py_ssize_t k, const shape_t& ret_shape):
from cupy_backends.cuda.libs import cublas

# out, if specified, must be C-contiguous and have correct shape.
cdef shape_t shape
cdef Py_ssize_t transa, transb, lda, ldb
Expand Down Expand Up @@ -704,6 +705,8 @@ cpdef _ndarray_base tensordot_core_v11(
Py_ssize_t transa, Py_ssize_t transb, Py_ssize_t m, Py_ssize_t n,
Py_ssize_t k, _ndarray_base a, Py_ssize_t lda, _ndarray_base b,
Py_ssize_t ldb, _ndarray_base c, Py_ssize_t ldc):
from cupy_backends.cuda.libs import cublas

cdef float one_f, zero_f
cdef double one_d, zero_d
cdef cuComplex one_F, zero_F
Expand Down Expand Up @@ -833,6 +836,7 @@ cpdef _ndarray_base matmul(
.. seealso:: :func:`numpy.matmul`
"""
from cupy_backends.cuda.libs import cublas

cdef Py_ssize_t i, n, m, ka, kb, a_sh, b_sh, c_sh, ldc
cdef Py_ssize_t batchCount, a_part_outshape, b_part_outshape
Expand Down
3 changes: 2 additions & 1 deletion cupy/_core/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ from cupy.cuda cimport memory
from cupy.cuda cimport stream as stream_module
from cupy_backends.cuda cimport stream as _stream_module
from cupy_backends.cuda.api cimport runtime
from cupy_backends.cuda.libs cimport cublas


# If rop of cupy.ndarray is called, cupy's op is the last chance.
Expand Down Expand Up @@ -2674,6 +2673,8 @@ cpdef _ndarray_base _internal_ascontiguousarray(_ndarray_base a):


cpdef _ndarray_base _internal_asfortranarray(_ndarray_base a):
from cupy_backends.cuda.libs import cublas

cdef _ndarray_base newarray
cdef int m, n
cdef intptr_t handle
Expand Down
5 changes: 4 additions & 1 deletion cupy/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from cupy.cuda import texture # NOQA
from cupy_backends.cuda.api import driver # NOQA
from cupy_backends.cuda.api import runtime # NOQA
from cupy_backends.cuda.libs import cublas # NOQA
from cupy_backends.cuda.libs import nvrtc # NOQA
from cupy_backends.cuda.libs import profiler # NOQA

Expand Down Expand Up @@ -63,6 +62,10 @@ def __getattr__(key):
from cupy_backends.cuda.libs import curand
_cupy.cuda.curand = curand
return curand
elif key == 'cublas':
from cupy_backends.cuda.libs import cublas
_cupy.cuda.cublas = cublas
return cublas

# `nvtx_enabled` flags are kept for backward compatibility with Chainer.
# Note: module-level getattr only runs on Python 3.7+.
Expand Down
2 changes: 1 addition & 1 deletion cupy/cuda/device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import threading
from cupy._core import syncdetect
from cupy_backends.cuda.api cimport runtime
from cupy_backends.cuda.api import runtime as runtime_module
from cupy_backends.cuda.libs import cublas
from cupy import _util


Expand Down Expand Up @@ -251,6 +250,7 @@ cdef class Device:
itself is different.
"""
from cupy_backends.cuda.libs import cublas
return self._get_handle(
'cublas_handles', cublas.create, cublas.destroy)

Expand Down
13 changes: 8 additions & 5 deletions cupy/linalg/_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import cupy
from cupy_backends.cuda.api import runtime
from cupy_backends.cuda.libs import cublas
from cupy._core import internal
from cupy.cuda import device
from cupy.linalg import _util
Expand Down Expand Up @@ -33,6 +32,7 @@ def _lu_factor(a_t, dtype):
.. seealso:: :func:`scipy.linalg.lu_factor`
"""
from cupy_backends.cuda.libs import cublas
from cupy_backends.cuda.libs import cusolver

orig_shape = a_t.shape
Expand All @@ -57,13 +57,13 @@ def _lu_factor(a_t, dtype):
a_array = cupy.arange(start, stop, step, dtype=cupy.uintp)

if dtype == numpy.float32:
getrfBatched = cupy.cuda.cublas.sgetrfBatched
getrfBatched = cublas.sgetrfBatched
elif dtype == numpy.float64:
getrfBatched = cupy.cuda.cublas.dgetrfBatched
getrfBatched = cublas.dgetrfBatched
elif dtype == numpy.complex64:
getrfBatched = cupy.cuda.cublas.cgetrfBatched
getrfBatched = cublas.cgetrfBatched
elif dtype == numpy.complex128:
getrfBatched = cupy.cuda.cublas.zgetrfBatched
getrfBatched = cublas.zgetrfBatched
else:
assert False

Expand Down Expand Up @@ -117,8 +117,10 @@ def _potrf_batched(a):
Returns:
cupy.ndarray: The lower-triangular matrix.
"""
from cupy_backends.cuda.libs import cublas
from cupy_backends.cuda.libs import cusolver
from cupyx.cusolver import check_availability

if not check_availability('potrfBatched'):
raise RuntimeError('potrfBatched is not available')

Expand Down Expand Up @@ -175,6 +177,7 @@ def cholesky(a):
.. seealso:: :func:`numpy.linalg.cholesky`
"""
from cupy_backends.cuda.libs import cublas
from cupy_backends.cuda.libs import cusolver

_util._assert_cupy_array(a)
Expand Down
2 changes: 1 addition & 1 deletion cupy/linalg/_eigenvalue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy

import cupy
from cupy_backends.cuda.libs import cublas
from cupy.cuda import device
from cupy.cuda import runtime
from cupy.linalg import _util
Expand All @@ -12,6 +11,7 @@


def _syevd(a, UPLO, with_eigen_vector, overwrite_a=False):
from cupy_backends.cuda.libs import cublas
from cupy_backends.cuda.libs import cusolver

if UPLO not in ('L', 'U'):
Expand Down
3 changes: 2 additions & 1 deletion cupy/linalg/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from cupy.cuda import device
from cupy.linalg import _decomposition
from cupy.linalg import _util
from cupy.cublas import batched_gesv, get_batched_gesv_limit
import cupyx


Expand Down Expand Up @@ -36,6 +35,8 @@ def solve(a, b):
.. seealso:: :func:`numpy.linalg.solve`
"""
from cupy.cublas import batched_gesv, get_batched_gesv_limit

if a.ndim > 2 and a.shape[-1] <= get_batched_gesv_limit():
# Note: There is a low performance issue in batched_gesv when matrix is
# large, so it is not used in such cases.
Expand Down
1 change: 0 additions & 1 deletion cupyx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from cupyx import time # NOQA
from cupyx import scipy # NOQA
from cupyx import optimizing # NOQA
from cupyx import lapack # NOQA

from cupyx._ufunc_config import errstate # NOQA
from cupyx._ufunc_config import geterr # NOQA
Expand Down
2 changes: 1 addition & 1 deletion cupyx/linalg/_solve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from cupy.linalg import _util
from cupyx import lapack


def invh(a):
Expand All @@ -15,6 +14,7 @@ def invh(a):
Returns:
cupy.ndarray: The inverse of matrix ``a``.
"""
from cupyx import lapack

_util._assert_cupy_array(a)
# TODO: Use `_assert_stacked_2d` instead, once cusolver supports nrhs > 1
Expand Down

0 comments on commit 904d1c0

Please sign in to comment.