Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GaussAdjoint with callbacks #1060

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/callback_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ function _setup_reverse_callbacks(
du = first(get_tmp_cache(integrator))
λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S)

if sensealg isa GaussAdjoint
dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache
recursive_copyto!(dgrad, 0)
end

# If save_positions[2] == false, the right limit is not saved. Thus, for
# QuadratureAdjoint we would need to lift y from the left to the right limit.
# However, dgrad would then also need to be updated later on.
Expand Down Expand Up @@ -339,7 +344,10 @@ function _setup_reverse_callbacks(
vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS;
dgrad = dgrad, dy = dy)

dgrad !== nothing && (dgrad .*= -1)
if dgrad !== nothing && !(sensealg isa QuadratureAdjoint)
dgrad .*= -1
end

if cb isa Union{ContinuousCallback, VectorContinuousCallback}
# second correction to correct for left limit
@unpack Lu_left = correction
Expand All @@ -358,8 +366,13 @@ function _setup_reverse_callbacks(

λ .= dλ

if !(sensealg isa QuadratureAdjoint)
grad .-= dgrad
if sensealg isa GaussAdjoint
@assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad

#recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad)
elseif !(sensealg isa QuadratureAdjoint)
grad .= dgrad
end
end

Expand Down
22 changes: 13 additions & 9 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
G}
G, SAlg <: GaussAdjoint}
sol::S
p::pType
y::uType
Expand All @@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
f_cache::rateType
pJ::PJT
paramjac_config::PJC
sensealg::GaussAdjoint
sensealg::SAlg
dgdp_cache::DGP
dgdp::G
end

struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
Alg <: GaussAdjoint,
uType, SType, CPS, pType,
fType <: DiffEqBase.AbstractDiffEqFunction} <: SensitivityFunction
fType <: DiffEqBase.AbstractDiffEqFunction,
GI <: GaussIntegrand,
ICB} <: SensitivityFunction
diffcache::C
sensealg::Alg
discrete::Bool
Expand All @@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
checkpoint_sol::CPS
prob::pType
f::fType
GaussInt::GaussIntegrand
GaussInt::GI
integrating_cb::ICB
end

TruncatedStacktraces.@truncate_stacktrace ODEGaussAdjointSensitivityFunction
Expand All @@ -41,7 +44,7 @@ end
function ODEGaussAdjointSensitivityFunction(
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
f, alg,
checkpoints, tols, tstops = nothing;
checkpoints, integrating_cb, tols, tstops = nothing;
tspan = reverse(sol.prob.tspan))
checkpointing = ischeckpointing(sensealg, sol)
(checkpointing && checkpoints === nothing) &&
Expand Down Expand Up @@ -84,7 +87,7 @@ function ODEGaussAdjointSensitivityFunction(
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
quad = true)
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
y, sol, checkpoint_sol, sol.prob, f, gaussint)
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
end

function Gaussfindcursor(intervals, t)
Expand Down Expand Up @@ -202,7 +205,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
end

@noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg,
GaussInt::GaussIntegrand,
GaussInt::GaussIntegrand, integrating_cb,
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
Expand Down Expand Up @@ -275,7 +278,7 @@ end
λ = zero(u0)
end
sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol,
dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)

init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end]
Expand Down Expand Up @@ -565,7 +568,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,

if sol.prob isa ODEProblem
adj_prob, cb2, rcb = ODEAdjointProblem(
sol, sensealg, alg, integrand, t, dgdu_discrete,
sol, sensealg, alg, integrand, integrating_cb,
t, dgdu_discrete,
dgdp_discrete,
dgdu_continuous, dgdp_continuous, g, Val(true);
checkpoints = checkpoints,
Expand Down
4 changes: 4 additions & 0 deletions test/callbacks/continuous_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,5 +291,9 @@ println("Continuous Callbacks")
sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
@test gFD≈gZy rtol=1e-10

sensealg = GaussAdjoint(autojacvec = EnzymeVJP())
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
@test gFD≈gZy rtol=1e-10
end
end
11 changes: 11 additions & 0 deletions test/callbacks/continuous_vs_discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,13 @@ function test_continuous_wrt_discrete_callback()
saveat = tspan[2], save_start = false)),
u0, p)

du03, dp3 = Zygote.gradient(
(u0, p) -> sum(solve(prob, Tsit5(), u0 = u0, p = p,
callback = cb,
sensealg = GaussAdjoint(),
saveat = tspan[2], save_start = false)),
u0, p)

dstuff = ForwardDiff.gradient(
(θ) -> sum(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:4],
callback = cb, saveat = tspan[2],
Expand All @@ -173,8 +180,12 @@ function test_continuous_wrt_discrete_callback()
@test dp1 ≈ dstuff[3:4]
@test du02 ≈ dstuff[1:2]
@test dp2 ≈ dstuff[3:4]
@test du03 ≈ dstuff[1:2]
@test dp3 ≈ dstuff[3:4]
@test du01 ≈ du02
@test dp1 ≈ dp2
@test du01 ≈ du03
@test dp1 ≈ dp3
end

@testset "Compare continuous with discrete callbacks" begin
Expand Down
10 changes: 10 additions & 0 deletions test/callbacks/discrete_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = fal
sensealg = QuadratureAdjoint())),
u0, p)

du05, dp5 = Zygote.gradient(
(u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p,
callback = cb, tstops = tstops,
abstol = abstol, reltol = reltol,
saveat = savingtimes,
sensealg = GaussAdjoint())),
u0, p)

dstuff = ForwardDiff.gradient(
(θ) -> g(solve(prob, Tsit5(), u0 = θ[1:2], p = θ[3:6],
callback = cb, tstops = tstops,
Expand Down Expand Up @@ -135,9 +143,11 @@ function test_discrete_callback(cb, tstops, g, dg!, cboop = nothing, tprev = fal
@test du01≈du03c rtol=1e-7
@test du03 ≈ du03c
@test du01 ≈ du04
@test du01 ≈ du05
@test dp1 ≈ dp3
@test dp1 ≈ dp3c
@test dp1≈dp4 rtol=1e-7
@test dp1≈dp5 rtol=1e-7

cb2 = SciMLSensitivity.track_callbacks(CallbackSet(cb), prob.tspan[1], prob.u0, prob.p,
BacksolveAdjoint(autojacvec = ReverseDiffVJP()))
Expand Down
10 changes: 10 additions & 0 deletions test/callbacks/vector_continuous_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ function test_vector_continuous_callback(cb, g)
sensealg = BacksolveAdjoint())),
u0, p)

du02, dp2 = @time Zygote.gradient(
(u0, p) -> g(solve(prob, Tsit5(), u0 = u0, p = p,
callback = cb, abstol = abstol,
reltol = reltol,
saveat = savingtimes,
sensealg = GaussAdjoint())),
u0, p)

dstuff = @time ForwardDiff.gradient(
(θ) -> g(solve(prob, Tsit5(), u0 = θ[1:4],
p = θ[5:6], callback = cb,
Expand All @@ -38,6 +46,8 @@ function test_vector_continuous_callback(cb, g)

@test du01 ≈ dstuff[1:4]
@test dp1 ≈ dstuff[5:6]
@test du02 ≈ dstuff[1:4]
@test dp2 ≈ dstuff[5:6]
end

@testset "VectorContinuous callbacks" begin
Expand Down
Loading