Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Sep 28, 2024
1 parent 5852794 commit 2f31389
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqBase"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
authors = ["Chris Rackauckas <[email protected]>"]
version = "6.155.4"
version = "6.156.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
2 changes: 1 addition & 1 deletion ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ end

Enzyme.Compiler.known_ops[typeof(DiffEqBase.fastpow)] = (:pow, 2, nothing)

end
end
9 changes: 7 additions & 2 deletions ext/DiffEqBaseReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import ReverseDiff
import DiffEqBase.ArrayInterface
import DiffEqBase.ForwardDiff

function DiffEqBase.anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
function DiffEqBase.anyeltypedual(::Type{T},
::Type{Val{counter}} = Val{0}) where {counter} where {
V, D, N, VA, DA, T <: ReverseDiff.TrackedArray{V, D, N, VA, DA}}
DiffEqBase.anyeltypedual(V, Val{counter})
end

Expand Down Expand Up @@ -38,7 +40,10 @@ function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal},
u0
end
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0)
DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual} = ReverseDiff.track(T.(u0))
function DiffEqBase.promote_u0(
u0, p::ReverseDiff.TrackedArray{T}, t0) where {T <: ForwardDiff.Dual}
ReverseDiff.track(T.(u0))
end
DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = eltype(p).(u0)

# Support adaptive with non-tracked time
Expand Down
12 changes: 8 additions & 4 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,10 @@ function apply_callback!(integrator,
end

if integrator.u_modified
if hasmethod(reeval_internals_due_to_modification!, Tuple{typeof(integrator)}, (:callback_initializealg,))
reeval_internals_due_to_modification!(integrator, callback_initializealg = callback.initializealg)
if hasmethod(reeval_internals_due_to_modification!,
Tuple{typeof(integrator)}, (:callback_initializealg,))
reeval_internals_due_to_modification!(
integrator, callback_initializealg = callback.initializealg)
else # handle legacy dispatch without kwarg
reeval_internals_due_to_modification!(integrator)
end
Expand Down Expand Up @@ -617,8 +619,10 @@ end
integrator.u_modified = true
callback.affect!(integrator)
if integrator.u_modified
if hasmethod(reeval_internals_due_to_modification!, Tuple{typeof(integrator), Bool}, (:callback_initializealg,))
reeval_internals_due_to_modification!(integrator, false, callback_initializealg = callback.initializealg)
if hasmethod(reeval_internals_due_to_modification!,
Tuple{typeof(integrator), Bool}, (:callback_initializealg,))
reeval_internals_due_to_modification!(
integrator, false, callback_initializealg = callback.initializealg)
else # handle legacy dispatch without kwarg
reeval_internals_due_to_modification!(integrator, false)
end
Expand Down
5 changes: 3 additions & 2 deletions src/fastpow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ const EXP2FT = (Float32(0x1.6a09e667f3bcdp-1),
if iszero(x)
return zero(outT)
elseif isinf(x) && isinf(y)
return convert(outT,Inf)
return convert(outT, Inf)
else
return convert(outT,@fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))))
return convert(
outT, @fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))))
end
end

Expand Down
5 changes: 4 additions & 1 deletion src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,10 @@ function anyeltypedual(x::NamedTuple, ::Type{Val{counter}} = Val{0}) where {coun
anyeltypedual(values(x))
end

DiffEqBase.anyeltypedual(f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter} = Any
function DiffEqBase.anyeltypedual(
f::SciMLBase.AbstractSciMLFunction, ::Type{Val{counter}}) where {counter}
Any
end

@inline promote_u0(::Nothing, p, t0) = nothing

Expand Down
9 changes: 6 additions & 3 deletions src/integrator_accessors.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# the following are setup per how integrators are implemented in OrdinaryDiffEq and
# StochasticDiffEq and provide dispatch points that JumpProcesses and others can use.

get_tstops(integ::DEIntegrator) =
function get_tstops(integ::DEIntegrator)
error("get_tstops not implemented for integrators of type $(nameof(typeof(integ)))")
get_tstops_array(integ::DEIntegrator) =
end
function get_tstops_array(integ::DEIntegrator)
error("get_tstops_array not implemented for integrators of type $(nameof(typeof(integ)))")
get_tstops_max(integ::DEIntegrator) =
end
function get_tstops_max(integ::DEIntegrator)
error("get_tstops_max not implemented for integrators of type $(nameof(typeof(integ)))")
end
6 changes: 4 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,8 @@ function __solve(
kwargs...)
if second_time
throw(NoDefaultAlgorithmError())
elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
elseif length(args) > 0 && !(first(args) isa
Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
throw(NonSolverError())
else
__solve(prob, nothing, args...; default_set = false, second_time = true, kwargs...)
Expand All @@ -1403,7 +1404,8 @@ function __init(prob::AbstractDEProblem, args...; default_set = false, second_ti
kwargs...)
if second_time
throw(NoDefaultAlgorithmError())
elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
elseif length(args) > 0 && !(first(args) isa
Union{Nothing, AbstractDEAlgorithm, AbstractNonlinearAlgorithm})
throw(NonSolverError())
else
__init(prob, nothing, args...; default_set = false, second_time = true, kwargs...)
Expand Down
8 changes: 5 additions & 3 deletions test/downstream/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@ using Test
@testset for RT in (Duplicated, DuplicatedNoNeed),
Tx in (Const, Duplicated),
Ty in (Const, Duplicated)

x = 3.0
y = 2.0
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol=0.005, rtol=0.005)
test_forward(fastpow, RT, (x, Tx), (y, Ty), atol = 0.005, rtol = 0.005)
end
end

@testset "Fast pow - Enzyme reverse rule" begin
@testset for RT in (Active,),
Tx in (Active,),
Ty in (Active,)

x = 2.0
y = 3.0
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol=0.001, rtol=0.001)
test_reverse(fastpow, RT, (x, Tx), (y, Ty), atol = 0.001, rtol = 0.001)
end
end
end
6 changes: 5 additions & 1 deletion test/downstream/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ sol1 = solve(prob, Euler(); dt = 1 // 2^(4));
df = DataFrame(sol1)
@test names(df) == ["timestamp", "value1", "value2", "value3", "value4"]

prob = ODEProblem(ODEFunction(f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)), rand(2, 2), (0.0, 1.0));
prob = ODEProblem(
ODEFunction(
f_2dlinear, sys = SymbolicIndexingInterface.SymbolCache([:a, :b, :c, :d], [], :t)),
rand(2, 2),
(0.0, 1.0));
sol2 = solve(prob, Euler(); dt = 1 // 2^(4));
df = DataFrame(sol2)
@test names(df) == ["timestamp", "a", "b", "c", "d"]
2 changes: 1 addition & 1 deletion test/fastpow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ end
@test fastpow(1.0, 1.0) isa Float64
errors = [abs(^(x, y) - fastpow(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5]
@test maximum(errors) < 1e-4
end
end
25 changes: 18 additions & 7 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,19 +352,30 @@ DiffEqBase.anyeltypedual((; x = foo, y = prob.f))

@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(3))) == Any
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(3)))) == Any
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) == eltype(ones(ForwardDiff.Dual, 3))
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) == eltype(ones(ForwardDiff.Dual, 3))
@test DiffEqBase.anyeltypedual(ReverseDiff.track(ones(ForwardDiff.Dual, 3))) ==
eltype(ones(ForwardDiff.Dual, 3))
@test DiffEqBase.anyeltypedual(typeof(ReverseDiff.track(ones(ForwardDiff.Dual, 3)))) ==
eltype(ones(ForwardDiff.Dual, 3))

struct FakeParameterObject{T}
tunables::T
end

SciMLStructures.isscimlstructure(::FakeParameterObject) = true
SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject) = f.tunables, x -> FakeParameterObject(x), true
function SciMLStructures.canonicalize(::SciMLStructures.Tunable, f::FakeParameterObject)
f.tunables, x -> FakeParameterObject(x), true
end

@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedArray
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa ReverseDiff.TrackedReal
@test DiffEqBase.promote_u0(ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa ReverseDiff.TrackedReal{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(
ones(3), FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa
ReverseDiff.TrackedArray
@test DiffEqBase.promote_u0(1.0, FakeParameterObject(ReverseDiff.track(ones(3))), 0.0) isa
ReverseDiff.TrackedReal
@test DiffEqBase.promote_u0(
ones(3), FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa
ReverseDiff.TrackedArray{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(
1.0, FakeParameterObject(ReverseDiff.track(ones(ForwardDiff.Dual, 3))), 0.0) isa
ReverseDiff.TrackedReal{<:ForwardDiff.Dual}
@test DiffEqBase.promote_u0(NaN, [NaN], 0.0) isa Float64
@test DiffEqBase.promote_u0([1.0], [NaN], 0.0) isa Vector{Float64}

0 comments on commit 2f31389

Please sign in to comment.