From f504898661182a775364fffc0ffc82ca9f1e4750 Mon Sep 17 00:00:00 2001 From: Oriol Colomes Date: Fri, 20 Oct 2023 23:12:02 +0200 Subject: [PATCH] Optimizations in RK --- src/ODEs/ODETools/IMEXRungeKutta.jl | 6 ++++ src/ODEs/ODETools/RungeKutta.jl | 53 ++++++++++++++++++----------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/ODEs/ODETools/IMEXRungeKutta.jl b/src/ODEs/ODETools/IMEXRungeKutta.jl index a89f3da78..bbb0faa93 100644 --- a/src/ODEs/ODETools/IMEXRungeKutta.jl +++ b/src/ODEs/ODETools/IMEXRungeKutta.jl @@ -290,3 +290,9 @@ function explicit_rhs!(op::IMEXRungeKuttaUpdateNonlinearOperator, x::AbstractVec @. v = (x-op.u0)/(op.dt) explicit_rhs!(op.ex_rhs,op.odeop,op.ti,(u,v),op.ode_cache) end + +function update!(op::IMEXRungeKuttaStageNonlinearOperator,ti::Float64,ui::AbstractVector,i::Int) + op.ti = ti + op.ui = ui + op.i = i +end diff --git a/src/ODEs/ODETools/RungeKutta.jl b/src/ODEs/ODETools/RungeKutta.jl index 1b8993d2b..333f15a60 100644 --- a/src/ODEs/ODETools/RungeKutta.jl +++ b/src/ODEs/ODETools/RungeKutta.jl @@ -33,17 +33,22 @@ function solve_step!(uf::AbstractVector, ode_cache = allocate_cache(op) vi = similar(u0) ui = Vector{typeof(u0)}() - sizehint!(ui,s) - [push!(ui,similar(u0)) for i in 1:s] + sizehint!(ui,s-1) + [push!(ui,similar(u0)) for i in 1:s-1] rhs = similar(u0) + if (s>1) + u_aux = ui[1] # auxiliar variable to store the sum of stages + else + u_aux = nothing # This is needed for the case s=1 + end nls_stage_cache = nothing nls_update_cache = nothing else - ode_cache, vi, ui, rhs, nls_stage_cache, nls_update_cache = cache + ode_cache, vi, ui, u_aux, rhs, nls_stage_cache, nls_update_cache = cache end # Create RKNL stage operator - nlop_stage = RungeKuttaStageNonlinearOperator(op,t0,dt,u0,ode_cache,vi,ui,rhs,0,a) + nlop_stage = RungeKuttaStageNonlinearOperator(op,t0,dt,u0,ode_cache,vi,u_aux,rhs,0,a) # Compute intermediate stages for i in 1:s @@ -61,8 +66,10 @@ function solve_step!(uf::AbstractVector, nls_stage_cache = solve!(uf,solver.nls_stage,nlop_stage,nls_stage_cache) end - # Update stage unknown - @. nlop_stage.ui[i] = uf + # Store stage unknown + if (i1) + @. u += op.ui end rhs!(op.rhs,op.odeop,op.ti,(u,v),op.ode_cache) end @@ -223,11 +237,7 @@ end function rhs!(op::RungeKuttaUpdateNonlinearOperator, x::AbstractVector) v = op.vi @. v = (x-op.u0)/(op.dt) - u = op.b[op.s] * op.ui[op.s] - for i in 1:op.s-1 - @. u = u + op.b[i] * op.ui[i] - end - rhs!(op.rhs,op.odeop,op.ti,(u,v),op.ode_cache) + rhs!(op.rhs,op.odeop,op.ti,(op.ui,v),op.ode_cache) end function lhs!(b::AbstractVector, op::RungeKuttaNonlinearOperator, x::AbstractVector) @@ -237,9 +247,12 @@ function lhs!(b::AbstractVector, op::RungeKuttaNonlinearOperator, x::AbstractVec lhs!(b,op.odeop,op.ti,(u,v),op.ode_cache) end -function update!(op::RungeKuttaNonlinearOperator,ti::Float64,ui::AbstractVector,i::Int) +function update!(op::RungeKuttaStageNonlinearOperator,ti::Float64,ui::AbstractVector,i::Int) op.ti = ti - op.ui = ui + if (i>1) + op.ui = op.a[i,i-1]*ui[i-1] + [@. op.ui += op.a[i,j]*ui[j] for j in 1:i-2] + end op.i = i end