Skip to content

Commit

Permalink
Merge pull request #431 from avik-pal/ap/preserve_dims
Browse files Browse the repository at this point in the history
Preserve the dimensions of ReverseDiff TrackedArray
  • Loading branch information
ChrisRackauckas authored Mar 8, 2024
2 parents bbf06e2 + f3f07b6 commit 36ae9e5
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 89 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.8.0"
version = "7.8.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

Expand All @@ -30,7 +29,6 @@ ArrayInterfaceTrackerExt = "Tracker"
[compat]
Adapt = "4"
LinearAlgebra = "1.10"
Requires = "1"
SparseArrays = "1.10"
SuiteSparse = "1.10"
julia = "1.10"
Expand Down
15 changes: 4 additions & 11 deletions ext/ArrayInterfaceBandedMatricesExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
module ArrayInterfaceBandedMatricesExt

if isdefined(Base, :get_extension)
using ArrayInterface
using ArrayInterface: BandedMatrixIndex
using BandedMatrices
using LinearAlgebra
else
using ..ArrayInterface
using ..ArrayInterface: BandedMatrixIndex
using ..BandedMatrices
using ..LinearAlgebra
end
using ArrayInterface
using ArrayInterface: BandedMatrixIndex
using BandedMatrices
using LinearAlgebra

const TransOrAdjBandedMatrix = Union{
Adjoint{T, <:BandedMatrix{T}},
Expand Down
17 changes: 4 additions & 13 deletions ext/ArrayInterfaceBlockBandedMatricesExt.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
module ArrayInterfaceBlockBandedMatricesExt



if isdefined(Base, :get_extension)
using ArrayInterface
using ArrayInterface: BandedMatrixIndex
using BlockBandedMatrices
using BlockBandedMatrices.BlockArrays
else
using ..ArrayInterface
using ..ArrayInterface: BandedMatrixIndex
using ..BlockBandedMatrices
using ..BlockBandedMatrices.BlockArrays
end
using ArrayInterface
using ArrayInterface: BandedMatrixIndex
using BlockBandedMatrices
using BlockBandedMatrices.BlockArrays

struct BlockBandedMatrixIndex <: ArrayInterface.MatrixIndex
count::Int
Expand Down
13 changes: 3 additions & 10 deletions ext/ArrayInterfaceCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
module ArrayInterfaceCUDAExt

using ArrayInterface

if isdefined(Base, :get_extension)
using CUDA
using CUDA.CUSOLVER
using LinearAlgebra
else
using ..CUDA
using ..CUDA.CUSOLVER
using ..LinearAlgebra
end
using CUDA
using CUDA.CUSOLVER
using LinearAlgebra

function ArrayInterface.lu_instance(A::CuMatrix{T}) where {T}
if VERSION >= v"1.8-"
Expand Down
16 changes: 4 additions & 12 deletions ext/ArrayInterfaceGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
module ArrayInterfaceGPUArraysCoreExt


if isdefined(Base, :get_extension)
using Adapt
using ArrayInterface
using LinearAlgebra: lu
import GPUArraysCore
else
using Adapt # Will cause problems for relocatability.
using ..ArrayInterface
using ..LinearAlgebra: lu
import ..GPUArraysCore
end
using Adapt
using ArrayInterface
using LinearAlgebra: lu
import GPUArraysCore

ArrayInterface.fast_scalar_indexing(::Type{<:GPUArraysCore.AbstractGPUArray}) = false
@inline ArrayInterface.allowed_getindex(x::GPUArraysCore.AbstractGPUArray, i...) = GPUArraysCore.@allowscalar(x[i...])
Expand Down
15 changes: 5 additions & 10 deletions ext/ArrayInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
module ArrayInterfaceReverseDiffExt

if isdefined(Base, :get_extension)
using ArrayInterface
import ReverseDiff
else
using ..ArrayInterface
import ..ReverseDiff
end
using ArrayInterface
import ReverseDiff

ArrayInterface.ismutable(::Type{<:ReverseDiff.TrackedArray}) = false
ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false
ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false
function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal,N}) where {N}
function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N}
if length(x) > 1
reduce(vcat,x)
return reshape(reduce(vcat, x), size(x))
else
reduce(vcat,[x[1],x[1]])[1:1]
return reduce(vcat,[x[1], x[1]])[1:1]
end
end

Expand Down
12 changes: 3 additions & 9 deletions ext/ArrayInterfaceStaticArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
module ArrayInterfaceStaticArraysCoreExt

if isdefined(Base, :get_extension)
import ArrayInterface
using LinearAlgebra
import StaticArraysCore
else
import ..ArrayInterface
using ..LinearAlgebra
import ..StaticArraysCore
end
import ArrayInterface
using LinearAlgebra
import StaticArraysCore

function ArrayInterface.undefmatrix(::StaticArraysCore.MArray{S, T, N, L}) where {S, T, N, L}
return StaticArraysCore.MMatrix{L, L, T, L*L}(undef)
Expand Down
9 changes: 2 additions & 7 deletions ext/ArrayInterfaceTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
module ArrayInterfaceTrackerExt

if isdefined(Base, :get_extension)
using ArrayInterface
import Tracker
else
using ..ArrayInterface
import ..Tracker
end
using ArrayInterface
import Tracker

ArrayInterface.ismutable(::Type{<:Tracker.TrackedArray}) = false
ArrayInterface.ismutable(T::Type{<:Tracker.TrackedReal}) = false
Expand Down
14 changes: 0 additions & 14 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1000,18 +1000,4 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true
ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false
ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x))

## Extensions

import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require BandedMatrices = "aae01518-5342-5314-be14-df237901396f" begin include("../ext/ArrayInterfaceBandedMatricesExt.jl") end
Requires.@require BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" begin include("../ext/ArrayInterfaceBlockBandedMatricesExt.jl") end
Requires.@require GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" begin include("../ext/ArrayInterfaceGPUArraysCoreExt.jl") end
Requires.@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysCoreExt.jl") end
Requires.@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin include("../ext/ArrayInterfaceCUDAExt.jl") end
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/ArrayInterfaceTrackerExt.jl") end
end
end

end # module

0 comments on commit 36ae9e5

Please sign in to comment.