Skip to content

Commit

Permalink
dynamic initialized (again)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcapriot committed Mar 1, 2024
1 parent 9e72e6e commit 6906999
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions pydiso/mkl_solver.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,11 @@ cdef class MKLPardisoSolver:
cdef int_t mat_type
cdef int_t _factored
cdef size_t shape[2]
cdef int_t _initialized
cdef PyThread_type_lock lock
cdef void * a

cdef object _data_type
cdef object _Adata #a reference to make sure the pointer "a" doesn't get destroyed
cdef object _Adata # a reference to make sure the pointer "a" doesn't get destroyed

def __cinit__(self, *args, **kwargs):
self.lock = PyThread_allocate_lock()
Expand Down Expand Up @@ -259,7 +258,6 @@ cdef class MKLPardisoSolver:
>>> np.allclose(x, x_solved)
True
'''
self._initialized = False
n_row, n_col = A.shape
if n_row != n_col:
raise ValueError("Matrix is not square")
Expand Down Expand Up @@ -307,15 +305,14 @@ cdef class MKLPardisoSolver:
self._par64 = _PardisoParams64()
self._initialize(self._par64, A, matrix_type, verbose)

if(verbose):
if verbose:
#for reporting factorization progress via python's `print`
mkl_set_progress(mkl_progress)
else:
mkl_set_progress(mkl_no_progress)

self._set_A(A.data)
self._analyze()
self._initialized = True
self._factored = False
if factor:
self._factor()
Expand Down Expand Up @@ -345,6 +342,13 @@ cdef class MKLPardisoSolver:
self._set_A(A.data)
self._factor()

cdef _initialized(self):
cdef int i
for i in range(64):
if self.handle[i]:
return 1
return 0

def __call__(self, b):
return self.solve(b)

Expand Down Expand Up @@ -505,18 +509,18 @@ cdef class MKLPardisoSolver:
cdef int_t phase=-1, nrhs=0, error=0
cdef long_t phase64=-1, nrhs64=0, error64=0

if self._initialized:
if self._initialized():
PyThread_acquire_lock(self.lock, 1)
if self._is_32:
pardiso(
self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
&phase, &self._par.n, self.a, NULL, NULL, NULL, &nrhs, self._par.iparm,
&phase, &self._par.n, NULL, NULL, NULL, NULL, &nrhs, self._par.iparm,
&self._par.msglvl, NULL, NULL, &error
)
else:
pardiso_64(
self.handle, &self._par64.maxfct, &self._par64.mnum, &self._par64.mtype,
&phase64, &self._par64.n, self.a, NULL, NULL, NULL, &nrhs64,
&phase64, &self._par64.n, NULL, NULL, NULL, NULL, &nrhs64,
self._par64.iparm, &self._par64.msglvl, NULL, NULL, &error64
)
PyThread_release_lock(self.lock)
Expand Down

0 comments on commit 6906999

Please sign in to comment.