diff --git a/Project.toml b/Project.toml index 5f3aba73..b5ba13fb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.7.1" +version = "7.8.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -14,6 +14,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -22,16 +23,17 @@ ArrayInterfaceBandedMatricesExt = "BandedMatrices" ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" ArrayInterfaceCUDAExt = "CUDA" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" +ArrayInterfaceReverseDiffExt = "ReverseDiff" ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" ArrayInterfaceTrackerExt = "Tracker" [compat] -Adapt = "3, 4" -LinearAlgebra = "1.9" +Adapt = "4" +LinearAlgebra = "1.10" Requires = "1" -SparseArrays = "1.9" -SuiteSparse = "1.9" -julia = "1.9" +SparseArrays = "1.10" +SuiteSparse = "1.10" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" @@ -41,6 +43,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -50,4 +53,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [targets] -test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker"] +test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker", "ReverseDiff"] diff --git a/ext/ArrayInterfaceReverseDiffExt.jl b/ext/ArrayInterfaceReverseDiffExt.jl new file mode 100644 index 00000000..13bfd23f --- /dev/null +++ b/ext/ArrayInterfaceReverseDiffExt.jl @@ -0,0 +1,24 @@ +module ArrayInterfaceReverseDiffExt + +if isdefined(Base, :get_extension) + using ArrayInterface + import ReverseDiff +else + using ..ArrayInterface + import ..ReverseDiff +end + +ArrayInterface.ismutable(::Type{<:ReverseDiff.TrackedArray}) = false +ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false +ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false +ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false +function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal,N}) where {N} + if length(x) > 1 + reduce(vcat,x) + else + @show "here?" + reduce(vcat,[x[1],x[1]])[1:1] + end +end + +end # module diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 00000000..7c29c8dd --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,20 @@ +using ArrayInterface, ReverseDiff, Tracker, Test +x = ReverseDiff.track([4.0]) +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = reduce(vcat, ReverseDiff.track([4.0,4.0])) +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = [ReverseDiff.track([4.0])[1]] +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = reduce(vcat, ReverseDiff.track([4.0,4.0])) +x = [x[1],x[2]] +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray + +x = Tracker.TrackedArray([4.0]) +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray +x = [Tracker.TrackedArray([4.0])[1]] +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray +x = Tracker.TrackedArray([4.0,4.0]) +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray +x = reduce(vcat, Tracker.TrackedArray([4.0,4.0])) +x = [x[1],x[2]] +@test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray diff --git a/test/runtests.jl b/test/runtests.jl index ec3493fd..bedb7693 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,7 @@ end @time @safetestset "BandedMatrices" begin include("bandedmatrices.jl") end @time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end @time @safetestset "Core" begin include("core.jl") end + @time @safetestset "AD Integration" begin include("ad.jl") end @time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end end @@ -20,4 +21,4 @@ end activate_gpu_env() @time @safetestset "CUDA" begin include("gpu/cuda.jl") end end -end \ No newline at end of file +end