Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix remake autodiff tests and Zygote adjoint #943

Merged
merged 22 commits into from
Mar 6, 2025

Conversation

AayushSabharwal
Copy link
Member

No description provided.

@AayushSabharwal
Copy link
Member Author

This should fix our Downstream, DiffEqBase and StochasticDiffEq CI. MTKStdlib needs some fixes which I'll get to, NonlinearSolve has "type piracy" between its own sub packages which is arguably fine but it would still be nice to make CI green. SciMLSensitivity has... problems. I've added Core6 and Core7 to CI here. I think Core6 should pass now, the type instability in late_binding_update_u0_p was fixed.

@AayushSabharwal
Copy link
Member Author

Okay Downstream testset failure is because of hashconsing. Specifically, WeakValueDict has a lock call which Zygote can't AD through.

@AayushSabharwal
Copy link
Member Author

On the surface though, this is due to calling getproperty(::MTK.AbstractSystem, ::Symbol). We could just define a zero adjoint for that?

@ChrisRackauckas
Copy link
Member

Yes it's non-numeric and can be skipped as a zerograd.

@AayushSabharwal
Copy link
Member Author

I have the fixes ready locally. Needs SciML/ModelingToolkit.jl#3422.

Comment on lines 21 to +22
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😅 oh no.

Comment on lines -43 to -53
@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
dprob = remake(VA.prob, p = dp)
du, dprob
else
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because the ChainRulesCore one is fixed? If so, just delete instead of comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, yeah.

function get_save_idxs_and_saved_subsystem(prob, save_idxs)
if save_idxs === nothing
saved_subsystem = nothing
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this be an unnecessary allocation in the scalar case? You only want to do this when symbolic scalar.

Comment on lines +390 to +385
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just error early? This case is a bit odd. Saving nothing? Must be a user issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user could still be saving discrete variables

@@ -35,7 +34,7 @@ p = [lorenz1.σ => 10.0,

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p)
sol = solve(prob, Rodas4())
sol = solve(prob, Tsit5())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sol = solve(prob, Tsit5())
sol = solve(prob, Rodas4())

the point is to flex a non explicit method

@ChrisRackauckas ChrisRackauckas merged commit 23f6936 into SciML:master Mar 6, 2025
35 of 53 checks passed
@AayushSabharwal AayushSabharwal deleted the as/fix-tests branch March 6, 2025 14:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants