Feat: Handle Adjoints through Initialization #1168
base: master
Changes from all commits: d6290bf, d3199c0, a4fa7c5, 5a7dd26, 94ec324, 0c5564e
```diff
@@ -24,6 +24,7 @@ jobs:
       - Core5
       - Core6
       - Core7
+      - Core8
       - QA
       - SDE1
       - SDE2
```
```diff
@@ -425,6 +425,21 @@ function DiffEqBase._concrete_solve_adjoint(
             save_end = true, kwargs_fwd...)
     end

+    # Get gradients for the initialization problem if it exists
+    igs = if _prob.f.initialization_data.initializeprob != nothing
+        iprob = _prob.f.initialization_data.initializeprob
+        ip = parameter_values(iprob)
+        itunables, irepack, ialiases = canonicalize(Tunable(), ip)
+        igs, = Zygote.gradient(ip) do ip
+            iprob2 = remake(iprob, p = ip)
+            sol = solve(iprob2)
+            sum(Array(sol))
+        end
+        igs
+    else
+        nothing
+    end
+
     # Force `save_start` and `save_end` in the forward pass. This forces the
     # solver to do the backsolve all the way back to `u0`. Since the start aliases
     # `_prob.u0`, this doesn't actually use more memory. But it cleans up the
```

Review discussion on the added `Zygote.gradient` block:

> **Comment:** This gradient isn't used? I think this would go into the backpass, and if I'm thinking clearly, the resulting return is …
>
> **Reply:** Not yet. These gradients are currently against the parameters of the initialization problem, not the system exactly, and the mapping between the two is ill-defined, so we cannot simply … I spoke with @AayushSabharwal about a way to map; it seems …
>
> **Reply:** There's another subtlety. I am not sure we haven't missed some part of the CFG by manually handling accumulation of gradients, or any transforms we might need to calculate gradients for. The regular AD graph building typically took care of these details for us, but in this case we would need to worry about incorrect gradients manually.
>
> **Comment:** Oh yes, you need to use the `initializeprobmap` (https://github.com/SciML/SciMLBase.jl/blob/master/src/initialization.jl#L268) to map it back to the shape of the initial parameters. `p` and `dp` just need the same ordering, so `initializeprobmap` should do the trick. This is the only change to …
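The thread above is about mapping a gradient computed against the *initialization problem's* parameter ordering back to the *parent problem's* ordering. A toy sketch of that idea in plain Julia (nothing here is SciML API: `inner_solve`, the 3-parameter system, and the permutation standing in for `initializeprobmap` are all invented for illustration, and finite differences stand in for `Zygote.gradient`):

```julia
# Toy stand-in for differentiating through an "initialization solve" and
# mapping the gradient back to the parent parameter ordering.
# NOTE: all names are illustrative; this is not the SciMLSensitivity API.

# A tiny closed-form "solve": returns the state implied by parameters ip.
inner_solve(ip) = [ip[1]^2 + ip[2], ip[2] * ip[3]]

# The loss used in the gradient call in the diff: sum of the solution.
loss(ip) = sum(inner_solve(ip))

# Central finite differences (stands in for Zygote.gradient).
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

ip = [1.0, 2.0, 3.0]        # initialization-problem parameters
igs = fd_gradient(loss, ip)  # gradient in the *initialization* ordering

# The initialization problem may order parameters differently from the
# parent problem; a map (like `initializeprobmap` in SciMLBase) carries
# values back. Here it is just a permutation.
perm = [3, 1, 2]             # parent index -> initialization index
parent_gs = igs[perm]        # gradient in the *parent* ordering
```

The point of the sketch: once `p` and `dp` share an ordering, accumulating the initialization gradient into the parent gradient is a plain index map, which is why the reviewer suggests `initializeprobmap` suffices.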
```diff
@@ -50,13 +50,12 @@ end
     sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp,
         f, f.colorvec, needs_jac)
     (; diffcache, y, sol, λ, vjp, linsolve) = sense

     if needs_jac
         if SciMLBase.has_jac(f)
             f.jac(diffcache.J, y, p, nothing)
         else
             if DiffEqBase.isinplace(sol.prob)
-                jacobian!(diffcache.J, diffcache.uf, y, diffcache.f_cache,
+                jacobian!(diffcache.J.du, diffcache.uf, y, diffcache.f_cache,
                     sensealg, diffcache.jac_config)
             else
                 diffcache.J .= jacobian(diffcache.uf, y, sensealg)
```

```diff
@@ -103,15 +102,18 @@ end
     else
         if linsolve === nothing && isempty(sensealg.linsolve_kwargs)
             # For the default case use `\` to avoid any form of unnecessary cache allocation
-            vec(λ) .= diffcache.J' \ vec(dgdu_val)
+            linear_problem = LinearProblem(diffcache.J.du', vec(dgdu_val'); u0 = vec(λ))
+            solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
         else
             linear_problem = LinearProblem(diffcache.J', vec(dgdu_val'); u0 = vec(λ))
             solve(linear_problem, linsolve; alias = LinearAliasSpecifier(alias_A = true), sensealg.linsolve_kwargs...) # u is vec(λ)
         end
     end

     try
-        vecjacobian!(vec(dgdu_val), y, λ, p, nothing, sense; dgrad = vjp, dy = nothing)
+        tunables, repack, aliases = canonicalize(Tunable(), p)
+        vjp_tunables, vjp_repack, vjp_aliases = canonicalize(Tunable(), vjp)
+        vecjacobian!(vec(dgdu_val), y, λ, tunables, nothing, sense; dgrad = vjp_tunables, dy = nothing)
     catch e
         if sense.sensealg.autojacvec === nothing
             @warn "Automatic AD choice of autojacvec failed in nonlinear solve adjoint, failing back to ODE adjoint + numerical vjp"
```

Review discussion on the `# For the default case use `\`` comment:

> **Comment:** Yeah, I don't know about that comment. I think it's just old.
>
> **Reply:** So glad we can remove this branch altogether.
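The old branch computed `vec(λ) .= diffcache.J' \ vec(dgdu_val)`; the new branch instead builds a `LinearProblem` for the same adjoint system `J' λ = dgdu` and hands it to a linear solver. A minimal pure-`LinearAlgebra` check of the algebra the change relies on (no LinearSolve here, and `J`/`dgdu` are made-up stand-ins for the diffcache fields):

```julia
using LinearAlgebra

# The adjoint linear system in the steady-state adjoint is J' λ = dgdu.
# Old branch: solve it directly with `\`.
J = [4.0 1.0; 2.0 3.0]   # stand-in Jacobian
dgdu = [1.0, 2.0]        # stand-in cost gradient

λ_backslash = J' \ dgdu

# New branch, spelled out: factor J' once and solve, which is what a
# linear-solver library does under the hood for a LinearProblem.
F = lu(J')
λ_linprob = F \ dgdu

# Residual of the adjoint system; both paths must agree.
resid = norm(J' * λ_linprob - dgdu)
```

Routing the default case through the same `LinearProblem` path as the non-default case (as the diff does) removes a special-cased branch while keeping the same solution, at the cost of letting the solver library manage caches instead of relying on `\`'s allocation behavior.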
New test file (`@@ -0,0 +1,64 @@`):

```julia
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkitStandardLibrary.Electrical
using ModelingToolkitStandardLibrary.Blocks: Sine
using NonlinearSolve
using LinearSolve   # needed for DefaultLinearSolver below
using Test
import SciMLStructures as SS
import SciMLSensitivity
using Zygote

function create_model(; C₁ = 3e-5, C₂ = 1e-6)
    @variables t
    @named resistor1 = Resistor(R = 5.0)
    @named resistor2 = Resistor(R = 2.0)
    @named capacitor1 = Capacitor(C = C₁)
    @named capacitor2 = Capacitor(C = C₂)
    @named source = Voltage()
    @named input_signal = Sine(frequency = 100.0)
    @named ground = Ground()
    @named ampermeter = CurrentSensor()

    eqs = [connect(input_signal.output, source.V)
           connect(source.p, capacitor1.n, capacitor2.n)
           connect(source.n, resistor1.p, resistor2.p, ground.g)
           connect(resistor1.n, capacitor1.p, ampermeter.n)
           connect(resistor2.n, capacitor2.p, ampermeter.p)]

    @named circuit_model = ODESystem(eqs, t,
        systems = [
            resistor1, resistor2, capacitor1, capacitor2,
            source, input_signal, ground, ampermeter,
        ])
end

desauty_model = create_model()
sys = structural_simplify(desauty_model)

prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.0])
iprob = prob.f.initialization_data.initializeprob
isys = iprob.f.sys

tunables, repack, aliases = SS.canonicalize(SS.Tunable(), parameter_values(iprob))

linsolve = LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
sensealg = SciMLSensitivity.SteadyStateAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP(),
    linsolve = linsolve)
igs, = Zygote.gradient(tunables) do p
    iprob2 = remake(iprob, p = repack(p))
    sol = solve(iprob2, sensealg = sensealg)
    sum(Array(sol))
end

@test !iszero(sum(igs))

# tunable_parameters(isys) .=> gs
# gradient_unk1_idx = only(findfirst(x -> isequal(x, Initial(sys.capacitor1.v)), tunable_parameters(isys)))
# gs[gradient_unk1_idx]
# prob.f.initialization_data.update_initializeprob!(iprob, prob)
# prob.f.initialization_data.update_initializeprob!(iprob, ::Vector)
# prob.f.initialization_data.update_initializeprob!(iprob, gs)
```
Review discussion on the gradient call in the test:

> **Comment:** This should be before the solve, since you can use the initialization solution from here in the `remake`s of 397-405 in order to set the new `u0` and `p`, and thus skip running the initialization a second time.
>
> **Reply:** How can I indicate to `solve` to avoid running initialization?
>
> **Reply:** `initializealg = NoInit()`. Should probably just do `CheckInit()` for safety, but either is fine.
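A sketch of the suggestion in the last reply. This is not run here and assumes an OrdinaryDiffEq setup with a `prob` whose initialization has already been solved; `new_u0`, `new_p`, and the `Rodas5P` choice are illustrative placeholders, while `SciMLBase.NoInit`/`SciMLBase.CheckInit` and the `initializealg` keyword are the real names from the discussion:

```julia
# Sketch (not executed): reuse an already-computed initialization solution
# and tell `solve` not to re-run initialization.
using OrdinaryDiffEq

# u0/p taken from the initialization solve performed earlier
prob2 = remake(prob; u0 = new_u0, p = new_p)

# `NoInit()` skips initialization entirely; `CheckInit()` verifies that the
# supplied `u0` already satisfies the algebraic constraints without solving
# the initialization problem again.
sol = solve(prob2, Rodas5P(); initializealg = SciMLBase.CheckInit())
```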