diff --git a/benchmarks/rabi-dormand-prince.jl b/benchmarks/rabi-dormand-prince.jl index c6b0760..d21d11f 100644 --- a/benchmarks/rabi-dormand-prince.jl +++ b/benchmarks/rabi-dormand-prince.jl @@ -1,5 +1,5 @@ using BenchmarkTools -using DormandPrince:DP5Solver, DP8Solver, integrate +using DormandPrince function fcn(x, y, f) g(x) = 2.2*2π*sin(2π*x) @@ -14,6 +14,6 @@ solver = DP8Solver( ComplexF64[1.0, 0.0] ) -integrate(solver, 2π) +integrate!(solver, 2π) -@benchmark integrate(clean_solver, 2π) setup=(clean_solver = DP8Solver(fcn, 0.0, ComplexF64[1.0, 0.0])) samples=10000 evals=5 seconds=500 \ No newline at end of file +@benchmark integrate!(clean_solver, 2π) setup=(clean_solver = DP8Solver(fcn, 0.0, ComplexF64[1.0, 0.0])) samples=10000 evals=5 seconds=500 \ No newline at end of file diff --git a/benchmarks/type_stab.jl b/benchmarks/type_stab.jl index 0b29921..83d8999 100644 --- a/benchmarks/type_stab.jl +++ b/benchmarks/type_stab.jl @@ -1,4 +1,4 @@ -using DormandPrince: DP5Solver, DP8Solver, integrate +using DormandPrince using DormandPrince.DP8: dop853, error_estimation using JET: @report_opt @@ -21,7 +21,7 @@ h = 1e-6 # @code_warntype error_estimation(solver, 1e-6) # @report_opt error_estimation(solver, 1e-6) -@code_warntype integrate(solver, 2π) -@report_opt integrate(solver, 2π) +@code_warntype integrate!(solver, 2π) +@report_opt integrate!(solver, 2π) diff --git a/src/DormandPrince.jl b/src/DormandPrince.jl index a66a041..5fc9909 100644 --- a/src/DormandPrince.jl +++ b/src/DormandPrince.jl @@ -11,7 +11,12 @@ include("dp5/mod.jl") include("dp8/mod.jl") # export Interface -export DP5Solver, DP8Solver, integrate +export AbstractDPSolver, + DP5Solver, + DP8Solver, + integrate!, + SolverIterator, + get_current_state end # DormandPrince diff --git a/src/dp5/solver.jl b/src/dp5/solver.jl index 6628fee..2c1273f 100644 --- a/src/dp5/solver.jl +++ b/src/dp5/solver.jl @@ -3,7 +3,7 @@ #include("checks.jl") #include("helpers.jl") -function DormandPrince.integrate( +function DormandPrince.integrate!( solver::DP5Solver{T}, xend::T ) where T diff --git a/src/dp8/solver.jl b/src/dp8/solver.jl index 9377eea..7cef125 100644 --- a/src/dp8/solver.jl +++ b/src/dp8/solver.jl @@ -3,7 +3,7 @@ #include("checks.jl") #include("helpers.jl") -function DormandPrince.integrate( +function DormandPrince.integrate!( solver::DP8Solver{T}, xend::T ) where T diff --git a/src/interface.jl b/src/interface.jl index 69a7f28..562a65a 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,6 +1,6 @@ struct SolverIterator{T <: Real} - solver::AbstractDPSolver + solver::AbstractDPSolver{T} times::AbstractVector{T} end @@ -10,7 +10,7 @@ end function Base.iterate(solver_iter::SolverIterator) length(solver_iter.times) == 0 && return nothing # empty iterator # integrate to first time - integrate(solver_iter.solver, first(solver_iter.times)) + integrate!(solver_iter.solver, first(solver_iter.times)) # return value and index which is the state return (solver_iter.times[1], get_current_state(solver_iter.solver)), 2 end @@ -19,7 +19,7 @@ end function Base.iterate(solver_iter::SolverIterator, index::Int) index > length(solver_iter.times) && return nothing # end of iterator # integrate to next time - integrate(solver_iter.solver, solver_iter.times[index]) + integrate!(solver_iter.solver, solver_iter.times[index]) # return time and state return (solver_iter.times[index], get_current_state(solver_iter.solver)), index+1 end @@ -30,14 +30,13 @@ end # 3. integrate(callback, solver, times) -> vector of states with callback applied get_current_state(::AbstractDPSolver) = error("not implemented") -integrate(solver::AbstractDPSolver{T}, times::AbstractVector{T}) where {T <: Real} = SolverIterator(solver, times) - -function integrate(callback, solver::AbstractDPSolver{T}, times::AbstractVector{T}; sort_times::Bool = true) where {T <: Real} +integrate!(::AbstractDPSolver{T}, ::T) where T = error("not implemented") +function integrate!(callback, solver::AbstractDPSolver{T}, times::AbstractVector{T}; sort_times::Bool = true) where {T <: Real} times = sort_times ? sort(collect(times)) : times result = [] for time in times - integrate(solver, time) + integrate!(solver, time) push!(result, callback(time, get_current_state(solver))) end diff --git a/test/interface.jl b/test/interface.jl index 1500138..d4f57cb 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,6 +1,6 @@ using Test using LinearAlgebra -using DormandPrince: DP5Solver, DP8Solver, integrate +using DormandPrince function evolution_operator(t::Float64) ϕ = 2.2 * sin(π * t)^2 @@ -36,9 +36,9 @@ end ComplexF64[1.0, 0.0] ) - integrate(solver, 2π) + integrate!(solver, 2π) - @test solver.y ≈ solution(2π) + @test get_current_state(solver) ≈ solution(2π) end end @@ -66,7 +66,7 @@ end ComplexF64[1.0, 0.0] ) - iter = integrate(solver, times) + iter = SolverIterator(solver, times) for (t,y) in iter push!(values, copy(y)) @@ -99,7 +99,7 @@ end ComplexF64[1.0, 0.0] ) - integrate(solver, times) do t, y + integrate!(solver, times) do t, y push!(callback_times, t) push!(values, copy(y)) end diff --git a/test/stiff.jl b/test/stiff.jl index b22c5b2..62c9b75 100644 --- a/test/stiff.jl +++ b/test/stiff.jl @@ -1,5 +1,5 @@ using Test -using DormandPrince: DP5Solver, DP8Solver, integrate +using DormandPrince function stiff_fcn(x, y, f) f[1] = y[1]^2 - y[1]^3 @@ -14,6 +14,6 @@ end [0.0001] ) - integrate(solver, 2/0.0001) + integrate!(solver, 2/0.0001) end end