Skip to content

Commit

Permalink
making maring integrate as having side effect. removing iterator `i…
Browse files Browse the repository at this point in the history
…ntegrate`
  • Loading branch information
weinbe58 committed Jan 3, 2024
1 parent 68accef commit 2a16ad1
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 23 deletions.
6 changes: 3 additions & 3 deletions benchmarks/rabi-dormand-prince.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using BenchmarkTools
using DormandPrince:DP5Solver, DP8Solver, integrate
using DormandPrince

function fcn(x, y, f)
g(x) = 2.2*2π*sin(2π*x)
Expand All @@ -14,6 +14,6 @@ solver = DP8Solver(
ComplexF64[1.0, 0.0]
)

integrate(solver, 2π)
integrate!(solver, 2π)

@benchmark integrate(clean_solver, 2π) setup=(clean_solver = DP8Solver(fcn, 0.0, ComplexF64[1.0, 0.0])) samples=10000 evals=5 seconds=500
@benchmark integrate!(clean_solver, 2π) setup=(clean_solver = DP8Solver(fcn, 0.0, ComplexF64[1.0, 0.0])) samples=10000 evals=5 seconds=500
6 changes: 3 additions & 3 deletions benchmarks/type_stab.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DormandPrince: DP5Solver, DP8Solver, integrate
using DormandPrince
using DormandPrince.DP8: dop853, error_estimation
using JET: @report_opt

Expand All @@ -21,7 +21,7 @@ h = 1e-6
# @code_warntype error_estimation(solver, 1e-6)
# @report_opt error_estimation(solver, 1e-6)

@code_warntype integrate(solver, 2π)
@report_opt integrate(solver, 2π)
@code_warntype integrate!(solver, 2π)
@report_opt integrate!(solver, 2π)


7 changes: 6 additions & 1 deletion src/DormandPrince.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ include("dp5/mod.jl")
include("dp8/mod.jl")

# export Interface
export DP5Solver, DP8Solver, integrate
export AbstractDPSolver,
DP5Solver,
DP8Solver,
integrate!,
SolverIterator,
get_current_state


end # DormandPrince
2 changes: 1 addition & 1 deletion src/dp5/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include("checks.jl")
#include("helpers.jl")

function DormandPrince.integrate(
function DormandPrince.integrate!(
solver::DP5Solver{T},
xend::T
) where T
Expand Down
2 changes: 1 addition & 1 deletion src/dp8/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include("checks.jl")
#include("helpers.jl")

function DormandPrince.integrate(
function DormandPrince.integrate!(
solver::DP8Solver{T},
xend::T
) where T
Expand Down
13 changes: 6 additions & 7 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

struct SolverIterator{T <: Real}
solver::AbstractDPSolver
solver::AbstractDPSolver{T}
times::AbstractVector{T}
end

Expand All @@ -10,7 +10,7 @@ end
function Base.iterate(solver_iter::SolverIterator)
length(solver_iter.times) == 0 && return nothing # empty iterator
# integrate to first time
integrate(solver_iter.solver, first(solver_iter.times))
integrate!(solver_iter.solver, first(solver_iter.times))
# return value and index which is the state
return (solver_iter.times[1], get_current_state(solver_iter.solver)), 2
end
Expand All @@ -19,7 +19,7 @@ end
function Base.iterate(solver_iter::SolverIterator, index::Int)
index > length(solver_iter.times) && return nothing # end of iterator
# integrate to next time
integrate(solver_iter.solver, solver_iter.times[index])
integrate!(solver_iter.solver, solver_iter.times[index])
# return time and state
return (solver_iter.times[index], get_current_state(solver_iter.solver)), index+1
end
Expand All @@ -30,14 +30,13 @@ end
# 3. integrate(callback, solver, times) -> vector of states with callback applied

get_current_state(::AbstractDPSolver) = error("not implemented")
integrate(solver::AbstractDPSolver{T}, times::AbstractVector{T}) where {T <: Real} = SolverIterator(solver, times)

function integrate(callback, solver::AbstractDPSolver{T}, times::AbstractVector{T}; sort_times::Bool = true) where {T <: Real}
integrate!(::AbstractDPSolver{T}, ::T) where T = error("not implemented")
function integrate!(callback, solver::AbstractDPSolver{T}, times::AbstractVector{T}; sort_times::Bool = true) where {T <: Real}
times = sort_times ? sort(collect(times)) : times

result = []
for time in times
integrate(solver, time)
integrate!(solver, time)
push!(result, callback(time, get_current_state(solver)))
end

Expand Down
10 changes: 5 additions & 5 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test
using LinearAlgebra
using DormandPrince: DP5Solver, DP8Solver, integrate
using DormandPrince

function evolution_operator(t::Float64)
ϕ = 2.2 * sin* t)^2
Expand Down Expand Up @@ -36,9 +36,9 @@ end
ComplexF64[1.0, 0.0]
)

integrate(solver, 2π)
integrate!(solver, 2π)

@test solver.y solution(2π)
@test get_current_state(solver) solution(2π)
end

end
Expand Down Expand Up @@ -66,7 +66,7 @@ end
ComplexF64[1.0, 0.0]
)

iter = integrate(solver, times)
iter = SolverIterator(solver, times)

for (t,y) in iter
push!(values, copy(y))
Expand Down Expand Up @@ -99,7 +99,7 @@ end
ComplexF64[1.0, 0.0]
)

integrate(solver, times) do t, y
integrate!(solver, times) do t, y
push!(callback_times, t)
push!(values, copy(y))
end
Expand Down
4 changes: 2 additions & 2 deletions test/stiff.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test
using DormandPrince: DP5Solver, DP8Solver, integrate
using DormandPrince

function stiff_fcn(x, y, f)
f[1] = y[1]^2 - y[1]^3
Expand All @@ -14,6 +14,6 @@ end
[0.0001]
)

integrate(solver, 2/0.0001)
integrate!(solver, 2/0.0001)
end
end

0 comments on commit 2a16ad1

Please sign in to comment.