-
-
Notifications
You must be signed in to change notification settings - Fork 105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: fix remake autodiff tests and Zygote adjoint #943
Conversation
This should fix our Downstream, DiffEqBase and StochasticDiffEq CI. MTKStdlib needs some fixes which I'll get to, NonlinearSolve has "type piracy" between its own sub packages which is arguably fine but it would still be nice to make CI green. SciMLSensitivity has... problems. I've added Core6 and Core7 to CI here. I think Core6 should pass now, the type instability in |
Okay Downstream testset failure is because of hashconsing. Specifically, |
On the surface though, this is due to calling |
Yes it's non-numeric and can be skipped as a zerograd. |
I have the fixes ready locally. Needs SciML/ModelingToolkit.jl#3422. |
cca4c94
to
73e20dc
Compare
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) | ||
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] | ||
dp = grz[3] # pullback for p | ||
du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😅 oh no.
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int) | ||
function ODESolution_getindex_pullback(Δ) | ||
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym | ||
du, dprob = if i === nothing | ||
getter = getobserved(VA) | ||
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ) | ||
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)] | ||
dp = grz[3] # pullback for p | ||
dprob = remake(VA.prob, p = dp) | ||
du, dprob | ||
else |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this because the ChainRulesCore one is fixed? If so, just delete instead of comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, yeah.
function get_save_idxs_and_saved_subsystem(prob, save_idxs) | ||
if save_idxs === nothing | ||
saved_subsystem = nothing | ||
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't this be an unnecessary allocation in the scalar case? You only want to do this when symbolic scalar.
if isempty(_save_idxs) | ||
# no states to save | ||
save_idxs = Int[] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just error early? This case is a bit odd. Saving nothing? Must be a user issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The user could still be saving discrete variables
test/downstream/adjoints.jl
Outdated
@@ -35,7 +34,7 @@ p = [lorenz1.σ => 10.0, | |||
|
|||
tspan = (0.0, 100.0) | |||
prob = ODEProblem(sys, u0, tspan, p) | |||
sol = solve(prob, Rodas4()) | |||
sol = solve(prob, Tsit5()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sol = solve(prob, Tsit5()) | |
sol = solve(prob, Rodas4()) |
the point is to flex a non explicit method
1eebb96
to
949ae78
Compare
…ovements I have no idea why this works
949ae78
to
4853750
Compare
No description provided.