Skip to content

Commit

Permalink
Merge pull request #943 from AayushSabharwal/as/fix-tests
Browse files Browse the repository at this point in the history
fix: fix remake autodiff tests and Zygote adjoint
  • Loading branch information
ChrisRackauckas authored Mar 6, 2025
2 parents 6d1c9e6 + 5e2e25e commit 23f6936
Show file tree
Hide file tree
Showing 15 changed files with 117 additions and 215 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ jobs:
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core7}
- {user: SciML, repo: Catalyst.jl, group: All}

steps:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ SciMLBasePartialFunctionsExt = "PartialFunctions"
SciMLBasePyCallExt = "PyCall"
SciMLBasePythonCallExt = "PythonCall"
SciMLBaseRCallExt = "RCall"
SciMLBaseZygoteExt = "Zygote"
SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
ADTypes = "0.2.5,1.0.0"
Expand Down
55 changes: 16 additions & 39 deletions ext/SciMLBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module SciMLBaseChainRulesCoreExt

using SciMLBase
using SciMLBase: getobserved
import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable
import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad
using SymbolicIndexingInterface

function ChainRulesCore.rrule(
Expand All @@ -15,52 +16,28 @@ function ChainRulesCore.rrule(
j::Integer)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if i === nothing
du, dprob = if i === nothing
getter = getobserved(VA)
grz = rrule_via_ad(config, getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
du = [k == j ? grz[3] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[4] # pullback for p
if dp == NoTangent()
dp = zero_tangent(parameter_values(VA.prob))
end
dprob = remake(VA.prob, p = dp)
T = eltype(eltype(VA.u))
N = length(VA.prob.p)
Δ′ = ODESolution{T, N, typeof(du), Nothing, Nothing, Nothing, Nothing,
typeof(dprob), Nothing, Nothing, Nothing, Nothing}(du, nothing,
nothing, nothing, nothing, dprob, nothing, nothing,
VA.dense, 0, nothing, nothing, VA.retcode)
(NoTangent(), Δ′, NoTangent(), NoTangent())
du, dprob
else
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
dp = zero(VA.prob.p)
dp = zero_tangent(VA.prob.p)
dprob = remake(VA.prob, p = dp)
Δ′ = ODESolution{
T,
N,
typeof(du),
Nothing,
Nothing,
typeof(VA.t),
typeof(VA.k),
typeof(dprob),
typeof(VA.alg),
typeof(VA.interp),
typeof(VA.alg_choice),
typeof(VA.stats)
}(du,
nothing,
nothing,
VA.t,
VA.k,
dprob,
VA.alg,
VA.interp,
VA.dense,
0,
VA.stats,
VA.alg_choice,
VA.retcode)
(NoTangent(), Δ′, NoTangent(), NoTangent())
du, dprob
end
T = eltype(eltype(du))
N = ndims(eltype(du)) + 1
Δ′ = ODESolution{T, N}(du, nothing, nothing, VA.t, VA.k, nothing, dprob,
VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode)
(NoTangent(), Δ′, NoTangent(), NoTangent())
end
VA[sym, j], ODESolution_getindex_pullback
end
Expand Down
29 changes: 4 additions & 25 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module SciMLBaseZygoteExt
using Zygote
using Zygote: @adjoint, pullback
import Zygote: literal_getproperty
import ChainRulesCore
using SciMLBase
using SciMLBase: ODESolution, remake,
getobserved, build_solution, EnsembleSolution,
Expand Down Expand Up @@ -40,31 +41,9 @@ import SciMLStructures
VA[i, j], ODESolution_getindex_pullback
end

@adjoint function Base.getindex(VA::ODESolution, sym, j::Int)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
du, dprob = if i === nothing
getter = getobserved(VA)
grz = pullback(getter, sym, VA.u[j], VA.prob.p, VA.t[j])[2](Δ)
du = [k == j ? grz[2] : zero(VA.u[1]) for k in 1:length(VA.u)]
dp = grz[3] # pullback for p
dprob = remake(VA.prob, p = dp)
du, dprob
else
du = [m == j ? [i == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])] :
zero(VA.u[1]) for m in 1:length(VA.u)]
dp = zero(VA.prob.p)
dprob = remake(VA.prob, p = dp)
du, dprob
end
T = eltype(eltype(VA.u))
N = ndims(VA)
Δ′ = ODESolution{T, N}(du, nothing, nothing,
VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats,
VA.alg_choice, VA.retcode)
(Δ′, nothing, nothing)
end
VA[sym, j], ODESolution_getindex_pullback
@adjoint function Base.getindex(VA::ODESolution, sym, j::Integer)
res, pullback = ChainRulesCore.rrule(Zygote.ZygoteRuleConfig(), getindex, VA, sym, j)
return res, Base.tail pullback
end

@adjoint function EnsembleSolution(sim, time, converged, stats)
Expand Down
6 changes: 5 additions & 1 deletion src/problems/sde_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,14 @@ function ConstructionBase.constructorof(::Type{P}) where {P <: SDEProblem}
function ctor(f, g, u0, tspan, p, noise, kw, noise_rate_prototype, seed)
if f isa AbstractSDEFunction
iip = isinplace(f)
if g !== f.g
f = remake(f; g)
end
return SDEProblem{iip}(f, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
else
iip = isinplace(f, 4)
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
end
return SDEProblem{iip}(f, g, u0, tspan, p; kw..., noise, noise_rate_prototype, seed)
end
end

Expand Down
8 changes: 2 additions & 6 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ end
function Base.showerror(io::IO, err::CyclicDependencyError)
println(io, "Detected cyclic dependency in initial values:")
for (k, v) in err.varmap
println(io, k, " => ", "v")
println(io, k, " => ", v)
end
println(io, "While trying to solve for variables: ", err.vars)
end
Expand Down Expand Up @@ -1085,10 +1085,6 @@ calling `SymbolicIndexingInterface.symbolic_container`, provided for dispatch. R
the updated `newu0` and `newp`.
"""
function late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
if hasmethod(symbolic_container, Tuple{typeof(root_indp)}) &&
(sc = symbolic_container(root_indp)) !== root_indp
return late_binding_update_u0_p(prob, sc, u0, p, t0, newu0, newp)
end
return newu0, newp
end

Expand All @@ -1099,7 +1095,7 @@ Calls `late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)` after
`root_indp`.
"""
function late_binding_update_u0_p(prob, u0, p, t0, newu0, newp)
root_indp = prob
root_indp = get_root_indp(prob)
return late_binding_update_u0_p(prob, root_indp, u0, p, t0, newu0, newp)
end

Expand Down
4 changes: 3 additions & 1 deletion src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4818,8 +4818,10 @@ for S in [:ODEFunction
end
end

const EMPTY_SYMBOLCACHE = SymbolCache()

function SymbolicIndexingInterface.symbolic_container(fn::AbstractSciMLFunction)
has_sys(fn) ? fn.sys : SymbolCache()
has_sys(fn) ? fn.sys : EMPTY_SYMBOLCACHE
end

function SymbolicIndexingInterface.is_observed(fn::AbstractSciMLFunction, sym)
Expand Down
72 changes: 43 additions & 29 deletions src/solutions/save_idxs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ function as_diffeq_array(vt::Vector{VectorTemplate}, t)
return DiffEqArray(typeof(TupleOfArraysWrapper(vt))[], t, (1, 1))
end

function is_empty_indp(indp)
isempty(variable_symbols(indp)) && isempty(parameter_symbols(indp)) &&
isempty(independent_variable_symbols(indp))
function get_root_indp(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) && (sc = symbolic_container(indp)) !== indp
return get_root_indp(sc)
end
return indp
end

# Everything from this point on is public API
Expand Down Expand Up @@ -105,17 +107,26 @@ struct SavedSubsystem{V, T, M, I, P, Q, C}
partition_count::C
end

function SavedSubsystem(indp, pobj, saved_idxs)
# nothing saved
if saved_idxs === nothing || isempty(saved_idxs)
SavedSubsystem(indp, pobj, ::Nothing) = nothing

function SavedSubsystem(indp, pobj, idx::Int)
_indp = get_root_indp(indp)
if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
return nothing
end
state_map = Dict(1 => idx)
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
end

# this is required because problems with no system have an empty `SymbolCache`
# as their symbolic container.
if is_empty_indp(indp)
function SavedSubsystem(indp, pobj, saved_idxs::Union{AbstractArray, Tuple})
_indp = get_root_indp(indp)
if _indp === EMPTY_SYMBOLCACHE || _indp === nothing
return nothing
end
if eltype(saved_idxs) == Int
state_map = Dict{Int, Int}(v => k for (k, v) in enumerate(saved_idxs))
return SavedSubsystem(state_map, nothing, nothing, nothing, nothing, nothing, nothing)
end

# array state symbolics must be scalarized
saved_idxs = collect(Iterators.flatten(map(saved_idxs) do sym
Expand Down Expand Up @@ -357,29 +368,32 @@ corresponding to the state variables and a `SavedSubsystem` to pass to `build_so
The second return value (corresponding to the `SavedSubsystem`) may be `nothing` in case
one is not required. `save_idxs` may be a scalar or `nothing`.
"""
get_save_idxs_and_saved_subsystem(prob, ::Nothing) = nothing, nothing
function get_save_idxs_and_saved_subsystem(prob, save_idxs::Vector{Int})
save_idxs, SavedSubsystem(prob, parameter_values(prob), save_idxs)
end
function get_save_idxs_and_saved_subsystem(prob, save_idx::Int)
save_idx, SavedSubsystem(prob, parameter_values(prob), save_idx)
end
function get_save_idxs_and_saved_subsystem(prob, save_idxs)
if save_idxs === nothing
saved_subsystem = nothing
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
_save_idxs = (save_idxs,)
else
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
_save_idxs = [save_idxs]
_save_idxs = save_idxs
end
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
if saved_subsystem !== nothing
_save_idxs = get_saved_state_idxs(saved_subsystem)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) ||
symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
_save_idxs = save_idxs
end
saved_subsystem = SavedSubsystem(prob, parameter_values(prob), _save_idxs)
if saved_subsystem !== nothing
_save_idxs = get_saved_state_idxs(saved_subsystem)
if isempty(_save_idxs)
# no states to save
save_idxs = Int[]
elseif !(save_idxs isa AbstractArray) ||
symbolic_type(save_idxs) != NotSymbolic()
# only a single state to save, and save it as a scalar timeseries instead of
# single-element array
save_idxs = only(_save_idxs)
else
save_idxs = _save_idxs
end
save_idxs = _save_idxs
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ DelayDiffEq = "5"
DiffEqCallbacks = "3, 4"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "9.64.1"
ModelingToolkit = "9.64.3"
ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3, 4"
Optimization = "4"
Expand Down
5 changes: 2 additions & 3 deletions test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ u0 = [lorenz1.x => 1.0,
lorenz1.z => 0.0,
lorenz2.x => 0.0,
lorenz2.y => 1.0,
lorenz2.z => 0.0,
a => 2.0]
lorenz2.z => 0.0]

p = [lorenz1.σ => 10.0,
lorenz1.ρ => 28.0,
Expand Down Expand Up @@ -68,7 +67,7 @@ gs_ts, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))
@test all(map(x -> x == true_grad_vecsym, gs_ts.u))

# BatchedInterface AD
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
Expand Down
8 changes: 2 additions & 6 deletions test/downstream/comprehensive_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,7 @@ timeseries_systems = [osys, ssys, jsys]
set! = setsym(indp, sym)
@inferred get(valp)
@test get(valp) == val
if valp isa JumpProblem && sym isa Union{Tuple, AbstractArray}
@test_broken valp[sym]
else
@test valp[sym] == val
end
@test valp[sym] == val

if !(valp isa SciMLBase.AbstractNoTimeSolution)
@inferred set!(valp, newval)
Expand Down Expand Up @@ -872,7 +868,7 @@ end
ud2interp = ud2val[2:4]

c1 = SciMLBase.Clock(0.1)
c2 = SciMLBase.SolverStepClock
c2 = SciMLBase.SolverStepClock()
for (sym, t, val) in [
(x, c1[2], xinterp[1]),
(x, c1[2:4], xinterp),
Expand Down
Loading

0 comments on commit 23f6936

Please sign in to comment.