Skip to content

Commit

Permalink
Merge pull request #1006 from LilithHafner/lh/format
Browse files Browse the repository at this point in the history
Run JuliaFormatter.format()
  • Loading branch information
ChrisRackauckas authored Feb 14, 2024
2 parents a5b65e1 + 3d58032 commit 7ac94ad
Show file tree
Hide file tree
Showing 37 changed files with 603 additions and 495 deletions.
20 changes: 11 additions & 9 deletions ext/DiffEqBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@ ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), No
ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle)

function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
DiffEqBase._solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
DiffEqBase._solve_forward(
prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
kwargs...)
end

function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem,
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
DiffEqBase._solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
sensealg::Union{Nothing, AbstractSensitivityAlgorithm},
u0, p, args...;
kwargs...)
DiffEqBase._solve_adjoint(
prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...;
kwargs...)
end

end
end
29 changes: 19 additions & 10 deletions ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ using Enzyme
import Enzyme: Const
using ChainRulesCore

function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1},
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob,
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
u0, p, args...; kwargs...) where {RT}
@inline function copy_or_reuse(val, idx)
if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val)
return deepcopy(val)
Expand All @@ -16,24 +19,30 @@ function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWi
end

@inline function arg_copy(i)
copy_or_reuse(args[i].val, i+5)
copy_or_reuse(args[i].val, i + 5)
end

res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...;

res = DiffEqBase._solve_adjoint(
copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3),
copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5),
SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...;
kwargs...)

dres = deepcopy(res[1])::RT
for v in dres.u
v.= 0
v .= 0
end
tup = (dres, res[2])
return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any)
end

function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT
dres, clos = tape
function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1},
func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob,
sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}},
u0, p, args...; kwargs...) where {RT}
dres, clos = tape
dres = dres::RT
dargs = clos(dres)
dargs = clos(dres)
for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...))
if ptr isa Enzyme.Const
continue
Expand All @@ -44,9 +53,9 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, f
ptr.dval .+= darg
end
for v in dres.u
v.= 0
v .= 0
end
return ntuple(_ -> nothing, Val(length(args)+4))
return ntuple(_ -> nothing, Val(length(args) + 4))
end

end
2 changes: 1 addition & 1 deletion ext/DiffEqBaseMPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end

if isdefined(MPI, :AbstractMultiRequest)
function DiffEqBase.anyeltypedual(::Type{T},
counter = 0) where {T <: MPI.AbstractMultiRequest}
counter = 0) where {T <: MPI.AbstractMultiRequest}
Any
end
end
Expand Down
15 changes: 8 additions & 7 deletions ext/DiffEqBaseMeasurementsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ else
end

function DiffEqBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement},
p::AbstractArray{<:Measurements.Measurement}, t0)
p::AbstractArray{<:Measurements.Measurement}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0)
Expand All @@ -22,16 +22,17 @@ value(x::Measurements.Measurement) = Measurements.value(x)
@inline DiffEqBase.fastpow(x::Measurements.Measurement, y::Measurements.Measurement) = x^y

# Support adaptive steps should be errorless
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{
<:Measurements.Measurement,
N,
},
t) where {N}
@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{
<:Measurements.Measurement,
N
},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Measurements.Measurement, N},
t) where {N}
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((value(x) for x in u), Iterators.repeated(t))) / length(u))
end
Expand Down
37 changes: 20 additions & 17 deletions ext/DiffEqBaseMonteCarloMeasurementsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,43 @@ else
using ..MonteCarloMeasurements
end

function DiffEqBase.promote_u0(u0::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
},
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
t0)
function DiffEqBase.promote_u0(
u0::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
},
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
t0)
u0
end
function DiffEqBase.promote_u0(u0,
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
t0)
p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles},
t0)
eltype(p).(u0)
end

DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T
DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles)

@inline function DiffEqBase.fastpow(x::MonteCarloMeasurements.AbstractParticles,
y::MonteCarloMeasurements.AbstractParticles)
y::MonteCarloMeasurements.AbstractParticles)
x^y
end

# Support adaptive steps should be errorless
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
N}, t) where {N}
@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
N}, t) where {N}
sqrt(mean(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((value(x) for x in u), Iterators.repeated(t))))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
N},
t::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
N}) where {N}
@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
N},
t::AbstractArray{
<:MonteCarloMeasurements.AbstractParticles,
N}) where {N}
sqrt(mean(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((value(x) for x in u), Iterators.repeated(value.(t)))))
end
Expand Down
106 changes: 54 additions & 52 deletions ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ end

DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V
function DiffEqBase.value(x::Type{
ReverseDiff.TrackedArray{V, D, N, VA, DA},
ReverseDiff.TrackedArray{V, D, N, VA, DA},
}) where {V, D,
N, VA,
DA}
N, VA,
DA}
Array{V, N}
end
DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value
Expand All @@ -26,15 +26,15 @@ DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce(

DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::ReverseDiff.TrackedArray, t0)
p::ReverseDiff.TrackedArray, t0)
u0
end
function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray,
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
p::AbstractArray{<:ReverseDiff.TrackedReal}, t0)
u0
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
Expand All @@ -44,13 +44,14 @@ DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = elt
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t)
sqrt(sum(abs2, DiffEqBase.value(u)) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t) where {N}
@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
t) where {N}
t) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]),
zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u))
end
Expand All @@ -60,94 +61,95 @@ end

# Support TrackedReal time, don't drop tracking on the adaptivity there
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray,
t::ReverseDiff.TrackedReal)
t::ReverseDiff.TrackedReal)
sqrt(sum(abs2, u) / length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
@inline function DiffEqBase.ODE_DEFAULT_NORM(
u::AbstractArray{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N},
t::ReverseDiff.TrackedReal) where {N}
t::ReverseDiff.TrackedReal) where {N}
sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) /
length(u))
end
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal,
t::ReverseDiff.TrackedReal)
t::ReverseDiff.TrackedReal)
abs(u)
end

# `ReverseDiff.TrackedArray`
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::ReverseDiff.TrackedArray, args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::ReverseDiff.TrackedArray, args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::ReverseDiff.TrackedArray,
args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0, p::ReverseDiff.TrackedArray,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray, p,
args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray, p,
args...; kwargs...)
ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...)
end

# `AbstractArray{<:ReverseDiff.TrackedReal}`
function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal},
p::AbstractArray{<:ReverseDiff.TrackedReal}, args...;
kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...;
kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0,
p::AbstractArray{<:ReverseDiff.TrackedReal},
args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0,
p::AbstractArray{<:ReverseDiff.TrackedReal},
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::AbstractArray{<:ReverseDiff.TrackedReal},
args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing}, u0::ReverseDiff.TrackedArray,
p::AbstractArray{<:ReverseDiff.TrackedReal},
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p,
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
end

function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem,
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray,
args...; kwargs...)
sensealg::Union{
SciMLBase.AbstractOverloadingSensitivityAlgorithm,
Nothing},
u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray,
args...; kwargs...)
DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...)
end

Expand Down
Loading

0 comments on commit 7ac94ad

Please sign in to comment.