Skip to content

Commit

Permalink
fix DI api change
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Oct 3, 2024
1 parent 2323970 commit d7eee90
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
40 changes: 28 additions & 12 deletions src/icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ function augmented_f(
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1)]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ)
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
l̇ = -LinearAlgebra.dot(ϵJ, ϵ)
Ė = if NORM_Z
LinearAlgebra.norm(ż)
Expand Down Expand Up @@ -227,7 +227,7 @@ function augmented_f(
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1)]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ)
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
du[begin:(end - n_aug - 1)] .= ż
du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ)
du[(end - n_aug + 1)] = if NORM_Z
Expand Down Expand Up @@ -256,8 +256,12 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1)]
ż, Jϵ =
DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ)
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
snn,
icnf.compute_mode.adback,
z,
(ϵ,),
)
l̇ = -LinearAlgebra.dot(ϵ, Jϵ)
Ė = if NORM_Z
LinearAlgebra.norm(ż)
Expand Down Expand Up @@ -286,8 +290,12 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1)]
ż, Jϵ =
DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ)
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
snn,
icnf.compute_mode.adback,
z,
(ϵ,),
)
du[begin:(end - n_aug - 1)] .= ż
du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ)
du[(end - n_aug + 1)] = if NORM_Z
Expand Down Expand Up @@ -317,7 +325,7 @@ function augmented_f(
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1), :]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ)
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
l̇ = -sum(ϵJ .* ϵ; dims = 1)
Ė = transpose(if NORM_Z
LinearAlgebra.norm.(eachcol(ż))
Expand Down Expand Up @@ -351,7 +359,7 @@ function augmented_f(
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1), :]
ż, ϵJ =
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ)
DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, (ϵ,))
du[begin:(end - n_aug - 1), :] .= ż
du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1))
du[(end - n_aug + 1), :] .= if NORM_Z
Expand Down Expand Up @@ -380,8 +388,12 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1), :]
ż, Jϵ =
DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ)
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
snn,
icnf.compute_mode.adback,
z,
(ϵ,),
)
l̇ = -sum(ϵ .* Jϵ; dims = 1)
Ė = transpose(if NORM_Z
LinearAlgebra.norm.(eachcol(ż))
Expand Down Expand Up @@ -414,8 +426,12 @@ function augmented_f(
n_aug = n_augment(icnf, mode)
snn = Lux.StatefulLuxLayer{true}(nn, p, st)
z = u[begin:(end - n_aug - 1), :]
ż, Jϵ =
DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ)
ż, Jϵ = DifferentiationInterface.value_and_pushforward(
snn,
icnf.compute_mode.adback,
z,
(ϵ,),
)
du[begin:(end - n_aug - 1), :] .= ż
du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1))
du[(end - n_aug + 1), :] .= if NORM_Z
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[i, :, :] = DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, z)
res[i, :, :] =
DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,))
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(copy(res); dims = 3)
Expand All @@ -27,7 +28,7 @@ end
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[:, i, :] =
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, z)
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,))
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(copy(res); dims = 3)
Expand Down

0 comments on commit d7eee90

Please sign in to comment.