Skip to content

Commit

Permalink
WIP: Wrap BLIS
Browse files Browse the repository at this point in the history
Test case:

```julia
using LinearSolve, blis_jll

A = rand(4, 4)
b = rand(4)
prob = LinearProblem(A, b)
sol = solve(prob,LinearSolve.BLISLUFactorization())
sol.u
```

throws:

```julia
julia> sol = solve(prob,LinearSolve.BLISLUFactorization())
ERROR: TypeError: in ccall: first argument not a pointer or valid constant expression, expected Ptr, got a value of type Tuple{Symbol, Ptr{Nothing}}
Stacktrace:
 [1] getrf!(A::Matrix{Float64}; ipiv::Vector{Int64}, info::Base.RefValue{Int64}, check::Bool)
   @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:67
 [2] getrf!
   @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:55 [inlined]
 [3] #solve!#9
   @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:222 [inlined]
 [4] solve!
   @ LinearSolveBLISExt ~/.julia/dev/LinearSolve/ext/LinearSolveBLISExt.jl:216 [inlined]
 [5] #solve!#6
   @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:209 [inlined]
 [6] solve!
   @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:208 [inlined]
 [7] #solve#5
   @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:205 [inlined]
 [8] solve(::LinearProblem{…}, ::LinearSolve.BLISLUFactorization)
   @ LinearSolve ~/.julia/dev/LinearSolve/src/common.jl:202
 [9] top-level scope
   @ REPL[8]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
  • Loading branch information
ChrisRackauckas committed Nov 12, 2023
1 parent a455e27 commit 2fea1c2
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
blis_jll = "6136c539-28a5-5bf0-87cc-b183200dce32"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771"
Expand All @@ -44,6 +45,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBLISExt = "blis_jll"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = "Enzyme"
Expand All @@ -58,6 +60,7 @@ LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
[compat]
ArrayInterface = "7.4.11"
BandedMatrices = "1"
blis_jll = "0.9.0"
BlockDiagonals = "0.1"
ConcreteStructs = "0.2"
DocStringExtensions = "0.9"
Expand Down
248 changes: 248 additions & 0 deletions ext/LinearSolveBLISExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
module LinearSolveBLISExt

using Libdl
using blis_jll
using LinearAlgebra
using LinearSolve

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1,
@blasfunc, chkargsok
using LinearSolve: ArrayInterface, BLISLUFactorization, @get_cacheval, LinearCache, SciMLBase

const global libblis = dlopen(blis_jll.blis_path)

function getrf!(A::AbstractMatrix{<:ComplexF64};

Check warning on line 15 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L15

Added line #L15 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 25 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L19-L25

Added lines #L19 - L25 were not covered by tests
end
ccall((@blasfunc(zgetrf_), libblis), Cvoid,

Check warning on line 27 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L27

Added line #L27 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 32 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L31-L32

Added lines #L31 - L32 were not covered by tests
end

function getrf!(A::AbstractMatrix{<:ComplexF32};

Check warning on line 35 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L35

Added line #L35 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 45 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L39-L45

Added lines #L39 - L45 were not covered by tests
end
ccall((@blasfunc(cgetrf_), libblis), Cvoid,

Check warning on line 47 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L47

Added line #L47 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 52 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L51-L52

Added lines #L51 - L52 were not covered by tests
end

function getrf!(A::AbstractMatrix{<:Float64};

Check warning on line 55 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L55

Added line #L55 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 65 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L59-L65

Added lines #L59 - L65 were not covered by tests
end
ccall((@blasfunc(dgetrf_), libblis), Cvoid,

Check warning on line 67 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L67

Added line #L67 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 72 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L71-L72

Added lines #L71 - L72 were not covered by tests
end

function getrf!(A::AbstractMatrix{<:Float32};

Check warning on line 75 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L75

Added line #L75 was not covered by tests
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2))),
info = Ref{BlasInt}(),
check = false)
require_one_based_indexing(A)
check && chkfinite(A)
chkstride1(A)
m, n = size(A)
lda = max(1, stride(A, 2))
if isempty(ipiv)
ipiv = similar(A, BlasInt, min(size(A, 1), size(A, 2)))

Check warning on line 85 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L79-L85

Added lines #L79 - L85 were not covered by tests
end
ccall((@blasfunc(sgetrf_), libblis), Cvoid,

Check warning on line 87 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L87

Added line #L87 was not covered by tests
(Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
m, n, A, lda, ipiv, info)
chkargsok(info[])
A, ipiv, info[], info #Error code is stored in LU factorization type

Check warning on line 92 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L91-L92

Added lines #L91 - L92 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 95 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L95

Added line #L95 was not covered by tests
A::AbstractMatrix{<:ComplexF64},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:ComplexF64};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 105 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L100-L105

Added lines #L100 - L105 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 108 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end
nrhs = size(B, 2)
ccall(("zgetrs_", libblis), Cvoid,

Check warning on line 111 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L110-L111

Added lines #L110 - L111 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 117 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 120 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L120

Added line #L120 was not covered by tests
A::AbstractMatrix{<:ComplexF32},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:ComplexF32};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 130 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L125-L130

Added lines #L125 - L130 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 133 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L132-L133

Added lines #L132 - L133 were not covered by tests
end
nrhs = size(B, 2)
ccall(("cgetrs_", libblis), Cvoid,

Check warning on line 136 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 142 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L141-L142

Added lines #L141 - L142 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 145 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L145

Added line #L145 was not covered by tests
A::AbstractMatrix{<:Float64},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:Float64};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 155 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L150-L155

Added lines #L150 - L155 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 158 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L157-L158

Added lines #L157 - L158 were not covered by tests
end
nrhs = size(B, 2)
ccall(("dgetrs_", libblis), Cvoid,

Check warning on line 161 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L160-L161

Added lines #L160 - L161 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 167 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
end

function getrs!(trans::AbstractChar,

Check warning on line 170 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L170

Added line #L170 was not covered by tests
A::AbstractMatrix{<:Float32},
ipiv::AbstractVector{BlasInt},
B::AbstractVecOrMat{<:Float32};
info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B)
LinearAlgebra.LAPACK.chktrans(trans)
chkstride1(A, B, ipiv)
n = LinearAlgebra.checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n"))

Check warning on line 180 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L175-L180

Added lines #L175 - L180 were not covered by tests
end
if n != length(ipiv)
throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n"))

Check warning on line 183 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
end
nrhs = size(B, 2)
ccall(("sgetrs_", libblis), Cvoid,

Check warning on line 186 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L185-L186

Added lines #L185 - L186 were not covered by tests
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, max(1, stride(B, 2)), info,
1)
LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
B

Check warning on line 192 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L191-L192

Added lines #L191 - L192 were not covered by tests
end

default_alias_A(::BLISLUFactorization, ::Any, ::Any) = false
default_alias_b(::BLISLUFactorization, ::Any, ::Any) = false

Check warning on line 196 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L195-L196

Added lines #L195 - L196 were not covered by tests

const PREALLOCATED_BLIS_LU = begin
A = rand(0, 0)
luinst = ArrayInterface.lu_instance(A), Ref{BlasInt}()
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization, A, b, u, Pl, Pr,

Check warning on line 203 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L203

Added line #L203 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
PREALLOCATED_BLIS_LU

Check warning on line 206 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L206

Added line #L206 was not covered by tests
end

function LinearSolve.init_cacheval(alg::BLISLUFactorization, A::AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}}, b, u, Pl, Pr,

Check warning on line 209 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L209

Added line #L209 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
A = rand(eltype(A), 0, 0)
ArrayInterface.lu_instance(A), Ref{BlasInt}()

Check warning on line 213 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L212-L213

Added lines #L212 - L213 were not covered by tests
end

function SciMLBase.solve!(cache::LinearCache, alg::BLISLUFactorization;

Check warning on line 216 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L216

Added line #L216 was not covered by tests
kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
cacheval = @get_cacheval(cache, :BLISLUFactorization)
res = getrf!(A; ipiv = cacheval[1].ipiv, info = cacheval[2])
fact = LU(res[1:3]...), res[4]
cache.cacheval = fact
cache.isfresh = false

Check warning on line 225 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L218-L225

Added lines #L218 - L225 were not covered by tests
end

y = ldiv!(cache.u, @get_cacheval(cache, :BLISLUFactorization)[1], cache.b)
SciMLBase.build_linear_solution(alg, y, nothing, cache)

Check warning on line 229 in ext/LinearSolveBLISExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LinearSolveBLISExt.jl#L228-L229

Added lines #L228 - L229 were not covered by tests

#=
A, info = @get_cacheval(cache, :BLISLUFactorization)
LinearAlgebra.require_one_based_indexing(cache.u, cache.b)
m, n = size(A, 1), size(A, 2)
if m > n
Bc = copy(cache.b)
getrs!('N', A.factors, A.ipiv, Bc; info)
return copyto!(cache.u, 1, Bc, 1, n)
else
copyto!(cache.u, cache.b)
getrs!('N', A.factors, A.ipiv, cache.u; info)
end
SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
=#
end

end
2 changes: 2 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,5 @@ A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pr
to avoid allocations and automatically offloads to the GPU.
"""
struct MetalLUFactorization <: AbstractFactorization end

struct BLISLUFactorization <: AbstractFactorization end

0 comments on commit 2fea1c2

Please sign in to comment.