diff --git a/ext/LinearSolvePardisoExt.jl b/ext/LinearSolvePardisoExt.jl index 0b4cfbb1..0318bb8a 100644 --- a/ext/LinearSolvePardisoExt.jl +++ b/ext/LinearSolvePardisoExt.jl @@ -134,12 +134,11 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs if cache.isfresh phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT Pardiso.set_phase!(cache.cacheval, phase) - Pardiso.pardiso(cache.cacheval, A, eltype(A)[]) + Pardiso.pardiso(cache.cacheval, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), eltype(A)[]) cache.isfresh = false end Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE) - Pardiso.pardiso(cache.cacheval, u, A, b) - + Pardiso.pardiso(cache.cacheval, u, SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), b) return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 7534d2fa..2559a210 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -217,7 +217,7 @@ All values default to `nothing` and the solver internally determines the values given the input types, and these keyword arguments are only for overriding the default handling process. This should not be required by most users. """ -struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm +struct PardisoJL{T1, T2} <: AbstractSparseFactorization nprocs::Union{Int, Nothing} solver_type::T1 matrix_type::T2 diff --git a/test/pardiso/pardiso.jl b/test/pardiso/pardiso.jl index a961a53d..c6af3cf4 100644 --- a/test/pardiso/pardiso.jl +++ b/test/pardiso/pardiso.jl @@ -177,3 +177,40 @@ for solver in solvers @test Pardiso.get_iparm(solver, i) == iparm[i][2] end end + +@testset "AbstractSparseMatrixCSC" begin + struct MySparseMatrixCSC2{Tv, Ti} <: SparseArrays.AbstractSparseMatrixCSC{Tv, Ti} + csc::SparseMatrixCSC{Tv, Ti} + end + + Base.size(m::MySparseMatrixCSC2) = size(m.csc) + SparseArrays.getcolptr(m::MySparseMatrixCSC2) = SparseArrays.getcolptr(m.csc) + SparseArrays.rowvals(m::MySparseMatrixCSC2) = SparseArrays.rowvals(m.csc) + SparseArrays.nonzeros(m::MySparseMatrixCSC2) = SparseArrays.nonzeros(m.csc) + + for alg in algs + N = 100 + u0 = ones(N) + A0 = spdiagm(1 => -ones(N - 1), 0 => fill(10.0, N), -1 => -ones(N - 1)) + b0 = A0 * u0 + B0 = MySparseMatrixCSC2(A0) + A1 = spdiagm(1 => -ones(N - 1), 0 => fill(100.0, N), -1 => -ones(N - 1)) + b1=A1*u0 + B1= MySparseMatrixCSC2(A1) + + + pr = LinearProblem(B0, b0) + # test default algorithn + u=solve(pr,alg) + @test norm(u - u0, Inf) < 1.0e-13 + + # test factorization with reinit! + pr = LinearProblem(B0, b0) + cache=init(pr,alg) + u=solve!(cache) + @test norm(u - u0, Inf) < 1.0e-13 + reinit!(cache; A=B1, b=b1) + u=solve!(cache) + @test norm(u - u0, Inf) < 1.0e-13 + end +end