From d7eee90cb211c2e778b55f35767080acdd207f56 Mon Sep 17 00:00:00 2001 From: Hossein Pourbozorg Date: Thu, 3 Oct 2024 17:57:00 +0330 Subject: [PATCH] fix DI api change --- src/icnf.jl | 40 ++++++++++++++++++++++++++++------------ src/utils.jl | 5 +++-- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/src/icnf.jl b/src/icnf.jl index 906c6e73..00560f72 100644 --- a/src/icnf.jl +++ b/src/icnf.jl @@ -197,7 +197,7 @@ function augmented_f( snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,)) l̇ = -LinearAlgebra.dot(ϵJ, ϵ) Ė = if NORM_Z LinearAlgebra.norm(ż) @@ -227,7 +227,7 @@ function augmented_f( snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,)) du[begin:(end - n_aug - 1)] .= ż du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ) du[(end - n_aug + 1)] = if NORM_Z @@ -256,8 +256,12 @@ function augmented_f( n_aug = n_augment(icnf, mode) snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + ż, Jϵ = DifferentiationInterface.value_and_pushforward( + snn, + icnf.compute_mode.adback, + z, + (ϵ,), + ) l̇ = -LinearAlgebra.dot(ϵ, Jϵ) Ė = if NORM_Z LinearAlgebra.norm(ż) @@ -286,8 +290,12 @@ function augmented_f( n_aug = n_augment(icnf, mode) snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1)] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + ż, Jϵ = DifferentiationInterface.value_and_pushforward( + snn, + icnf.compute_mode.adback, + z, + (ϵ,), + ) du[begin:(end - n_aug - 1)] .= ż du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ) du[(end - n_aug + 1)] = if NORM_Z @@ -317,7 +325,7 @@ function augmented_f( snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,)) l̇ = -sum(ϵJ .* ϵ; dims = 1) Ė = transpose(if NORM_Z LinearAlgebra.norm.(eachcol(ż)) @@ -351,7 +359,7 @@ function augmented_f( snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,)) du[begin:(end - n_aug - 1), :] .= ż du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1)) du[(end - n_aug + 1), :] .= if NORM_Z @@ -380,8 +388,12 @@ function augmented_f( n_aug = n_augment(icnf, mode) snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + ż, Jϵ = DifferentiationInterface.value_and_pushforward( + snn, + icnf.compute_mode.adback, + z, + (ϵ,), + ) l̇ = -sum(ϵ .* Jϵ; dims = 1) Ė = transpose(if NORM_Z LinearAlgebra.norm.(eachcol(ż)) @@ -414,8 +426,12 @@ function augmented_f( n_aug = n_augment(icnf, mode) snn = Lux.StatefulLuxLayer{true}(nn, p, st) z = u[begin:(end - n_aug - 1), :] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + ż, Jϵ = DifferentiationInterface.value_and_pushforward( + snn, + icnf.compute_mode.adback, + z, + (ϵ,), + ) du[begin:(end - n_aug - 1), :] .= ż du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1)) du[(end - n_aug + 1), :] .= if NORM_Z diff --git a/src/utils.jl b/src/utils.jl index 4c0c389c..6d1411ea 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,8 @@ res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) for i in axes(xs, 1) ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) - res[i, :, :] = DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, z) + res[i, :, :] = + DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,)) ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T) end y, eachslice(copy(res); dims = 3) @@ -27,7 +28,7 @@ end for i in axes(xs, 1) ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) res[:, i, :] = - DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, z) + DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,)) ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T) end y, eachslice(copy(res); dims = 3)