Skip to content

Commit

Permalink
Handle wrapper types over GPU arrays correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 18, 2023
1 parent ada5366 commit 65d76be
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.21.0"
version = "2.21.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
6 changes: 3 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ end
end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
function defaultalg(A::GPUArraysCore.AnyGPUArray, b, assump::OperatorAssumptions{Bool})

Check warning on line 104 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L104

Added line #L104 was not covered by tests
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -110,7 +110,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
end

# A === nothing case
function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
function defaultalg(A::Nothing, b::GPUArraysCore.AnyGPUArray, assump::OperatorAssumptions{Bool})

Check warning on line 113 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L113

Added line #L113 was not covered by tests
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -119,7 +119,7 @@ function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::Opera
end

# Ambiguity handling
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
function defaultalg(A::GPUArraysCore.AnyGPUArray, b::GPUArraysCore.AbstractGPUArray,

Check warning on line 122 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L122

Added line #L122 was not covered by tests
assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
Expand Down
35 changes: 30 additions & 5 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
if A isa AbstractSparseMatrixCSC
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
elseif A isa GPUArraysCore.AnyGPUArray
fact = lu(A; check = false)

Check warning on line 82 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L82

Added line #L82 was not covered by tests
elseif !ArrayInterface.can_setindex(typeof(A))
fact = lu(A, alg.pivot, check = false)
else
Expand All @@ -98,6 +100,17 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, A, b
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
end

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},

Check warning on line 103 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L103

Added line #L103 was not covered by tests
A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
if alg isa LUFactorization
return lu(A; check=false)

Check warning on line 107 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L106-L107

Added lines #L106 - L107 were not covered by tests
else
A isa GPUArraysCore.AnyGPUArray && return nothing
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check=false)

Check warning on line 110 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L109-L110

Added lines #L109 - L110 were not covered by tests
end
end

const PREALLOCATED_LU = ArrayInterface.lu_instance(rand(1, 1))

function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
Expand Down Expand Up @@ -143,7 +156,7 @@ end
function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AnyGPUArray)

Check warning on line 159 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L159

Added line #L159 was not covered by tests
fact = qr!(A, alg.pivot)
else
fact = qr(A) # CUDA.jl does not allow other args!
Expand All @@ -160,6 +173,12 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
end

function init_cacheval(alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,

Check warning on line 176 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L176

Added line #L176 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
A isa GPUArraysCore.AnyGPUArray && return qr(A)
return qr(A, alg.pivot)

Check warning on line 179 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L178-L179

Added lines #L178 - L179 were not covered by tests
end

const PREALLOCATED_QR = ArrayInterface.qr_instance(rand(1, 1))

function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
Expand Down Expand Up @@ -204,6 +223,8 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if A isa SparseMatrixCSC
fact = cholesky(A; shift = alg.shift, check = false, perm = alg.perm)
elseif A isa GPUArraysCore.AnyGPUArray
fact = cholesky(A; check = false)

Check warning on line 227 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L226-L227

Added lines #L226 - L227 were not covered by tests
elseif alg.pivot === Val(false) || alg.pivot === NoPivot()
fact = cholesky!(A, alg.pivot; check = false)
else
Expand All @@ -218,9 +239,13 @@ function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl,
cholesky(A)
end

function init_cacheval(alg::CholeskyFactorization, A::GPUArraysCore.AnyGPUArray, b, u, Pl,

Check warning on line 242 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L242

Added line #L242 was not covered by tests
Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
cholesky(A; check=false)

Check warning on line 244 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L244

Added line #L244 was not covered by tests
end

function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot)
end

Expand Down Expand Up @@ -968,7 +993,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
const PREALLOCATED_NORMALCHOLESKY = ArrayInterface.cholesky_instance(rand(1, 1), NoPivot())

function init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray, GPUArraysCore.AbstractGPUArray,
A::Union{AbstractSparseArray, GPUArraysCore.AnyGPUArray,
Symmetric{<:Number, <:AbstractSparseArray}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -999,7 +1024,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
if A isa SparseMatrixCSC || A isa GPUArraysCore.AnyGPUArray || A isa SMatrix

Check warning on line 1027 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L1027

Added line #L1027 was not covered by tests
fact = cholesky(Symmetric((A)' * A); check = false)
else
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
Expand Down
3 changes: 2 additions & 1 deletion test/gpu/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
19 changes: 18 additions & 1 deletion test/gpu/cuda.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using LinearSolve, CUDA, LinearAlgebra, SparseArrays
using LinearSolve, CUDA, LinearAlgebra, SparseArrays, StableRNGs
using Test

CUDA.allowscalar(false)
Expand Down Expand Up @@ -73,3 +73,20 @@ using BlockDiagonals

@test solve(prob1, SimpleGMRES(; blocksize = 2)).u solve(prob2, SimpleGMRES()).u
end

# Test Dispatches for Adjoint/Transpose Types
rng = StableRNG(0)

A = Matrix(Hermitian(rand(rng, 5, 5) + I)) |> cu
b = rand(rng, 5) |> cu
prob1 = LinearProblem(A', b)
prob2 = LinearProblem(transpose(A), b)

@testset "Adjoint/Transpose Type: $(alg)" for alg in (NormalCholeskyFactorization(),
CholeskyFactorization(), LUFactorization(), QRFactorization(), nothing)
sol = solve(prob1, alg; alias_A = false)
@test norm(A' * sol.u .- b) < 1e-5

sol = solve(prob2, alg; alias_A = false)
@test norm(transpose(A) * sol.u .- b) < 1e-5
end

0 comments on commit 65d76be

Please sign in to comment.