From 3b05baf34eedec0b82863d3230f190af6729ec2f Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Sep 2024 03:04:02 -0500 Subject: [PATCH] Enzyme: Adapt to pending version breaking update (#2490) [only downstream] --- .buildkite/pipeline.yml | 11 +- Project.toml | 4 +- ext/EnzymeCoreExt.jl | 283 +++++++++++++++++++++----------------- test/extensions/enzyme.jl | 6 +- 4 files changed, 172 insertions(+), 132 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2b82534b31..472e35fb98 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -221,11 +221,16 @@ steps: # to check compatibility, also add Enzyme to the main environment # (or Pkg.test, which merges both environments, could fail) Pkg.activate(".") + # Try to co-develop Enzyme and KA, if that fails, try just to dev Enzyme try - Pkg.develop("Enzyme") + Pkg.develop([PackageSpec("Enzyme"), PackageSpec("KernelAbstractions")]) catch err - @error "Could not install Enzyme" exception=(err,catch_backtrace()) - exit(3) + try + Pkg.develop([PackageSpec("Enzyme")]) + catch err + @error "Could not install Enzyme" exception=(err,catch_backtrace()) + exit(3) + end end end diff --git a/Project.toml b/Project.toml index 21c476cbca..c4c834f77f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CUDA" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "5.5.0" +version = "5.5.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -56,7 +56,7 @@ CUDA_Runtime_jll = "0.15" ChainRulesCore = "1" Crayons = "4" DataFrames = "1" -EnzymeCore = "0.7.3" +EnzymeCore = "0.8.2" ExprTools = "0.1" GPUArrays = "10.0.1" GPUCompiler = "0.24, 0.25, 0.26, 0.27" diff --git a/ext/EnzymeCoreExt.jl b/ext/EnzymeCoreExt.jl index f8c8fe2c7d..51b6b20e38 100644 --- a/ext/EnzymeCoreExt.jl +++ b/ext/EnzymeCoreExt.jl @@ -1,5 +1,4 @@ # compatibility with EnzymeCore - module EnzymeCoreExt using CUDA @@ -32,19 +31,19 @@ function EnzymeCore.compiler_job_from_backend(::CUDABackend, @nospecialize(F::Ty return GPUCompiler.CompilerJob(mi, CUDA.compiler_config(CUDA.device())) end -function metaf(fn, args::Vararg{Any, N}) where N - EnzymeCore.autodiff_deferred(Forward, fn, Const, args...) +function metaf(config, fn, args::Vararg{Any, N}) where N + EnzymeCore.autodiff_deferred(EnzymeCore.set_runtime_activity(Forward, config), Const(fn), Const, args...) nothing end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)}, ::Type{<:Duplicated}, f::Const{F}, tt::Const{TT}; kwargs...) where {F,TT} res = ofn.val(f.val, tt.val; kwargs...) return Duplicated(res, res) end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cufunction)}, ::Type{BatchDuplicated{T,N}}, f::Const{F}, tt::Const{TT}; kwargs...) where {F,TT,T,N} res = ofn.val(f.val, tt.val; kwargs...) @@ -53,24 +52,32 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cufunction)}, end) end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(cudaconvert)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, x::IT) where {RT, IT} - if RT <: Duplicated - Duplicated(ofn.val(x.val), ofn.val(x.dval)) - elseif RT <: Const - ofn.val(x.val)::eltype(RT) - elseif RT <: DuplicatedNoNeed - ofn.val(x.val)::eltype(RT) - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta - ofn.val(x.dval[i])::eltype(RT) + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + Duplicated(ofn.val(x.val), ofn.val(x.dval)) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(x.dval[i])::eltype(RT) + end + BatchDuplicated(ofn.val(x.val), tup) end - if RT <: BatchDuplicated - BatchDuplicated(ofv.val(x.val), tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + ofn.val(x.dval)::EnzymeCore.shadow_type(config, RT) else - tup + (ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(x.dval[i])::eltype(RT) + end)::EnzymeCore.shadow_type(config, RT) end + elseif EnzymeRules.needs_primal(config) + ofn.val(uval.val)::eltype(RT) + else + nothing end end @@ -93,7 +100,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cudac else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, ::Type{RT}, tape, x::IT) where {RT, IT} @@ -101,64 +108,85 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(cudaconvert)}, end -function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}}, ::Type{RT}, uval::EnzymeCore.Annotation{UndefInitializer}, args...) where {CT <: CuArray, RT} primargs = ntuple(Val(length(args))) do i Base.@_inline_meta args[i].val end - if RT <: Duplicated - shadow = ofn.val(uval.val, primargs...)::CT - fill!(shadow, 0) - Duplicated(ofn.val(uval.val, primargs...), shadow) - elseif RT <: Const - ofn.val(uval.val, primargs...) - elseif RT <: DuplicatedNoNeed - shadow = ofn.val(uval.val, primargs...)::CT - fill!(shadow, 0) - shadow::CT - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 shadow = ofn.val(uval.val, primargs...)::CT fill!(shadow, 0) - shadow::CT + Duplicated(ofn.val(uval.val, primargs...), shadow) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + shadow = ofn.val(uval.val, primargs...)::CT + fill!(shadow, 0) + shadow::CT + end + BatchDuplicated(ofn.val(uval.val, primargs...), tup) end - if RT <: BatchDuplicated - BatchDuplicated(ofv.val(uval.val), tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(uval.val, primargs...)::CT + fill!(shadow, 0) + shadow::shadow_type(config, RT) else - tup + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + shadow = ofn.val(uval.val, primargs...)::CT + fill!(shadow, 0) + shadow::CT + end + tup::shadow_type(config, RT) end + elseif EnzymeRules.needs_primal(config) + ofn.val(uval.val, primargs...) + else + nothing end end -function EnzymeCore.EnzymeRules.forward(ofn::Const{Type{CT}}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{Type{CT}}, ::Type{RT}, uval::EnzymeCore.Annotation{DR}, args...; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT} primargs = ntuple(Val(length(args))) do i Base.@_inline_meta args[i].val end - if RT <: Duplicated - shadow = ofn.val(uval.val, primargs...; kwargs...) - Duplicated(ofn.val(uval.dval, primargs...; kwargs...), shadow) - elseif RT <: Const - ofn.val(uval.val, primargs...; kwargs...) - elseif RT <: DuplicatedNoNeed - ofn.val(uval.dval, primargs...; kwargs...) - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta - shadow = ofn.val(uval.dval[i], primargs...; kwargs...) + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(uval.val, primargs...; kwargs...) + Duplicated(ofn.val(uval.val, primargs...; kwargs...), shadow) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(uval.val, primargs...; kwargs...) + end + BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup) end - if RT <: BatchDuplicated - BatchDuplicated(ofv.val(uval.val), tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(uval.val, primargs...; kwargs...) + shadow else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + ofn.val(uval.val, primargs...; kwargs...) + end tup end + elseif EnzymeRules.needs_primal(config) + ofn.val(uval.val, primargs...; kwargs...) + else + nothing end end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(synchronize)}, ::Type{RT}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {RT, N} pargs = ntuple(Val(N)) do i Base.@_inline_meta @@ -166,35 +194,42 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(synchronize)}, end res = ofn.val(pargs...; kwargs...) - if RT <: Duplicated - return Duplicated(res, res) - elseif RT <: Const - return res - elseif RT <: DuplicatedNoNeed - return res - else - tup = ntuple(Val(EnzymeCore.batch_size(RT))) do i - Base.@_inline_meta - res + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + Duplicated(res, res) + else + tup = ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + res + end + BatchDuplicated(ofn.val(uval.val, primargs...; kwargs...), tup) end - if RT <: BatchDuplicated - return BatchDuplicated(res, tup) + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + res else - return tup + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + res + end end + elseif EnzymeRules.needs_primal(config) + res + else + nothing end end -function EnzymeCore.EnzymeRules.forward(ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}}, +function EnzymeCore.EnzymeRules.forward(config, ofn::EnzymeCore.Annotation{CUDA.HostKernel{F,TT}}, ::Type{Const{Nothing}}, args...; kwargs...) where {F,TT} GC.@preserve args begin args = ((cudaconvert(a) for a in args)...,) - T2 = (F, (typeof(a) for a in args)...) + T2 = (typeof(config), F, (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(metaf, TT2) - res = cuf(ofn.val.f, args...; kwargs...) + res = cuf(config, ofn.val.f, args...; kwargs...) end return nothing @@ -223,16 +258,17 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{typeof(cufun else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? CT : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? CT : NTuple{EnzymeRules.width(config), CT}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Const{typeof(cufunction)},::Type{RT}, subtape, f, tt; kwargs...) where RT return (nothing, nothing) end -function meta_augf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType} +function meta_augf(config, f, tape::CuDeviceArray{TapeType}, args::Vararg{Any, N}) where {N, TapeType} + ModifiedBetween = overwritten(config) forward, _ = EnzymeCore.autodiff_deferred_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, @@ -270,7 +306,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotat ModifiedBetween = overwritten(config) TapeType = EnzymeCore.tape_type( EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(Base.identity), Tuple{Float64}), - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), Const{F}, Const{Nothing}, map(typeof, args)..., @@ -281,18 +317,19 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::EnzymeCore.Annotat GC.@preserve args subtape, begin subtape2 = cudaconvert(subtape) - T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...) + T2 = (typeof(config), F, typeof(subtape2), (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(meta_augf, TT2) - res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) + res = cuf(config, ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) end return AugmentedReturn{Nothing,Nothing,CuArray}(nothing, nothing, subtape) end -function meta_revf(f, tape::CuDeviceArray{TapeType}, ::Val{ModifiedBetween}, args::Vararg{Any, N}) where {N, ModifiedBetween, TapeType} +function meta_revf(config, f, tape::CuDeviceArray{TapeType}, args::Vararg{Any, N}) where {N, TapeType} + ModifiedBetween = overwritten(config) _, reverse = EnzymeCore.autodiff_deferred_thunk( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), TapeType, Const{Core.Typeof(f)}, Const{Nothing}, @@ -328,7 +365,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. args = ((cudaconvert(arg) for arg in args0)...,) ModifiedBetween = overwritten(config) TapeType = EnzymeCore.tape_type( - ReverseSplitModified(ReverseSplitWithPrimal, Val(ModifiedBetween)), + ReverseSplitModified(EnzymeCore.set_runtime_activity(ReverseSplitWithPrimal, config), Val(ModifiedBetween)), Const{F}, Const{Nothing}, map(typeof, args)..., @@ -338,10 +375,10 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. GC.@preserve args0 subtape, begin subtape2 = cudaconvert(subtape) - T2 = (F, typeof(subtape2), Val{ModifiedBetween}, (typeof(a) for a in args)...) + T2 = (typeof(config), F, typeof(subtape2), (typeof(a) for a in args)...) TT2 = Tuple{T2...} cuf = cufunction(meta_revf, TT2) - res = cuf(ofn.val.f, subtape2, Val(ModifiedBetween), args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) + res = cuf(config, ofn.val.f, subtape2, args...; threads=(threads.x, threads.y, threads.z), blocks=(blocks.x, blocks.y, blocks.z), kwargs...) end return ntuple(Val(length(args0))) do i @@ -350,7 +387,7 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::EnzymeCore.Annotation{CUDA. end end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes} +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(Base.fill!)}, ::Type{RT}, A::EnzymeCore.Annotation{<:DenseCuArray{T}}, x) where {RT, T <: CUDA.MemsetCompatTypes} if A isa Const || A isa Duplicated || A isa BatchDuplicated ofn.val(A.val, x.val) end @@ -365,16 +402,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(Base.fill!)}, ::Type{R end end - if RT <: Duplicated - return A - elseif RT <: Const - return A.val - elseif RT <: DuplicatedNoNeed - return A.dval - elseif RT <: BatchDuplicated - return A + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + A + elseif EnzymeRules.needs_shadow(config) + A.dval + elseif EnzymeRules.needs_primal(config) + A.val else - return A.dval + nothing end end @@ -469,7 +504,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, : else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{UndefInitializer}, args::Vararg{EnzymeCore.Annotation, N}) where {CT <: CuArray, RT, N} @@ -503,7 +538,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, ofn::Const{Type{CT}}, : else nothing end - return EnzymeRules.AugmentedReturn{(EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing), (EnzymeRules.needs_shadow(config) ? (EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{EnzymeRules.width(config), eltype(RT)}) : Nothing), Nothing}(primal, shadow, nothing) + return EnzymeRules.AugmentedReturn{EnzymeRules.primal_type(config, RT), EnzymeRules.shadow_type(config, RT), Nothing}(primal, shadow, nothing) end function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{Type{CT}}, ::Type{RT}, tape, A::EnzymeCore.Annotation{DR}, args::Vararg{EnzymeCore.Annotation, N}; kwargs...) where {CT <: CuArray, DR <: CUDA.DataRef, RT, N} @@ -517,7 +552,7 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...) return nothing end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim!)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays.mapreducedim!)}, ::Type{RT}, f::EnzymeCore.Const{typeof(Base.identity)}, op::EnzymeCore.Const{typeof(Base.add_sum)}, @@ -544,16 +579,14 @@ function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays.mapreducedim end end - if RT <: Duplicated - return R - elseif RT <: Const - return R.val - elseif RT <: DuplicatedNoNeed - return R.dval - elseif RT <: BatchDuplicated - return R + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + R + elseif EnzymeRules.needs_shadow(config) + R.dval + elseif EnzymeRules.needs_primal(config) + R.val else - return R.dval + nothing end end @@ -605,34 +638,36 @@ function EnzymeCore.EnzymeRules.reverse(config, ofn::Const{typeof(GPUArrays.mapr return (nothing, nothing, nothing, nothing) end -function EnzymeCore.EnzymeRules.forward(ofn::Const{typeof(GPUArrays._mapreduce)}, +function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays._mapreduce)}, ::Type{RT}, f::EnzymeCore.Const{typeof(Base.identity)}, op::EnzymeCore.Const{typeof(Base.add_sum)}, A::EnzymeCore.Annotation{<:AnyCuArray{T}}; dims::D, init) where {RT, T, D} - if RT <: Const + + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + shadow = ofn.val(f.val, op.val, A.dval; dims, init) + Duplicated(ofn.val(f.val, op.val, A.val; dims, init), shadow) + else + tup = ntuple(Val(EnzymeRules.batch_width(RT))) do i + Base.@_inline_meta + ofn.val(f.val, op.val, A.dval[i]; dims, init) + end + BatchDuplicated(ofn.val(f.val, op.val, A.val; dims, init), tup) + end + elseif EnzymeRules.needs_shadow(config) + if EnzymeRules.width(config) == 1 + ofn.val(f.val, op.val, A.dval; dims, init) + else + ntuple(Val(EnzymeRules.batch_width(RT))) do i + Base.@_inline_meta + ofn.val(f.val, op.val, A.dval[i]; dims, init) + end + end + elseif EnzymeRules.needs_primal(config) ofn.val(f.val, op.val, A.val; dims, init) - elseif RT <: Duplicated - ( - ofn.val(f.val, op.val, A.val; dims, init), - ofn.val(f.val, op.val, A.dval; dims, init) - ) - elseif RT <: DuplicatedNoNeed - ofn.val(f.val, op.val, A.dval; dims, init) - elseif RT <: BatchDuplicated - ( - ofn.val(f.val, op.val, A.val; dims, init), - ntuple(Val(EnzymeRules.batch_width(RT))) do i - Base.@_inline_meta - ofn.val(f.val, op.val, A.dval[i]; dims, init) - end - ) else - @assert RT <: BatchDuplicatedNoNeed - ntuple(Val(EnzymeRules.batch_width(RT))) do i - Base.@_inline_meta - ofn.val(f.val, op.val, A.dval[i]; dims, init) - end + nothing end end diff --git a/test/extensions/enzyme.jl b/test/extensions/enzyme.jl index 0ad29c91eb..093472ffd3 100644 --- a/test/extensions/enzyme.jl +++ b/test/extensions/enzyme.jl @@ -78,10 +78,10 @@ end alloc(x) = CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}(undef, (x,)) @testset "Forward allocate" begin - dup = Enzyme.autodiff(Forward, alloc, Duplicated, Const(10)) - @test all(dup[2] .≈ 0.0) + dup = Enzyme.autodiff(ForwardWithPrimal, alloc, Duplicated, Const(10)) + @test all(dup[1] .≈ 0.0) - dup = Enzyme.autodiff(Forward, alloc, DuplicatedNoNeed, Const(10)) + dup = Enzyme.autodiff(Forward, alloc, Duplicated, Const(10)) @test all(dup[1] .≈ 0.0) end