Skip to content

Commit

Permalink
Optimizations in RK
Browse files Browse the repository at this point in the history
  • Loading branch information
oriolcg committed Oct 20, 2023
1 parent 324f163 commit f504898
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
6 changes: 6 additions & 0 deletions src/ODEs/ODETools/IMEXRungeKutta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,9 @@ function explicit_rhs!(op::IMEXRungeKuttaUpdateNonlinearOperator, x::AbstractVec
@. v = (x-op.u0)/(op.dt)
explicit_rhs!(op.ex_rhs,op.odeop,op.ti,(u,v),op.ode_cache)
end

function update!(op::IMEXRungeKuttaStageNonlinearOperator,ti::Float64,ui::AbstractVector,i::Int)
op.ti = ti
op.ui = ui
op.i = i
end
53 changes: 33 additions & 20 deletions src/ODEs/ODETools/RungeKutta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,22 @@ function solve_step!(uf::AbstractVector,
ode_cache = allocate_cache(op)
vi = similar(u0)
ui = Vector{typeof(u0)}()
sizehint!(ui,s)
[push!(ui,similar(u0)) for i in 1:s]
sizehint!(ui,s-1)
[push!(ui,similar(u0)) for i in 1:s-1]
rhs = similar(u0)
if (s>1)
u_aux = ui[1] # auxiliar variable to store the sum of stages
else
u_aux = nothing # This is needed for the case s=1
end
nls_stage_cache = nothing
nls_update_cache = nothing
else
ode_cache, vi, ui, rhs, nls_stage_cache, nls_update_cache = cache
ode_cache, vi, ui, u_aux, rhs, nls_stage_cache, nls_update_cache = cache

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

@oriolcg if u_aux points to ui[1], why are we passing it separately?

end

# Create RKNL stage operator
nlop_stage = RungeKuttaStageNonlinearOperator(op,t0,dt,u0,ode_cache,vi,ui,rhs,0,a)
nlop_stage = RungeKuttaStageNonlinearOperator(op,t0,dt,u0,ode_cache,vi,u_aux,rhs,0,a)

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

idem


# Compute intermediate stages
for i in 1:s
Expand All @@ -61,8 +66,10 @@ function solve_step!(uf::AbstractVector,
nls_stage_cache = solve!(uf,solver.nls_stage,nlop_stage,nls_stage_cache)
end

# Update stage unknown
@. nlop_stage.ui[i] = uf
# Store stage unknown
if (i<s)
@. ui[i] = uf
end

end

Expand All @@ -72,17 +79,24 @@ function solve_step!(uf::AbstractVector,
# Skip final update if not necessary
if !(c[s]==1.0 && a[s,:] == b)

# Operate with stages solution
@. u_aux = b[1]*ui[1] # u_aux points to ui[1], this needs to be done first

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

it could be u[i] *= b[1], right?

@. u_aux += b[s]*uf
for i in 2:s-1
@. u_aux += b[i]*ui[i]
end

# Create RKNL final update operator
ode_cache = update_cache!(ode_cache,op,tf)
nlop_update = RungeKuttaUpdateNonlinearOperator(op,tf,dt,u0,ode_cache,vi,ui,rhs,s,b)
nlop_update = RungeKuttaUpdateNonlinearOperator(op,tf,dt,u0,ode_cache,vi,u_aux,rhs,s,b)

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

idem


# solve at final update
nls_update_cache = solve!(uf,solver.nls_update,nlop_update,nls_update_cache)

end

# Update final cache
cache = (ode_cache, vi, ui, rhs, nls_stage_cache, nls_update_cache)
cache = (ode_cache, vi, ui, u_aux, rhs, nls_stage_cache, nls_update_cache)

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

idem


return (uf,tf,cache)

Expand All @@ -101,7 +115,7 @@ mutable struct RungeKuttaStageNonlinearOperator <: RungeKuttaNonlinearOperator
u0::AbstractVector
ode_cache
vi::AbstractVector
ui::Vector{AbstractVector}
ui::Union{AbstractVector,Nothing}

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

do we really need this?

rhs::AbstractVector
i::Int
a::Matrix
Expand All @@ -118,7 +132,7 @@ mutable struct RungeKuttaUpdateNonlinearOperator <: RungeKuttaNonlinearOperator
u0::AbstractVector
ode_cache
vi::AbstractVector
ui::Vector{AbstractVector}
ui::AbstractVector
rhs::AbstractVector
s::Int
b::Vector{Number}
Expand Down Expand Up @@ -213,21 +227,17 @@ end
function rhs!(op::RungeKuttaStageNonlinearOperator, x::AbstractVector)
v = op.vi
@. v = (x-op.u0)/(op.dt)
u = op.a[op.i,op.i] * op.ui[op.i]
for j in 1:op.i-1
@. u += op.a[op.i,j] * op.ui[j]
u = op.a[op.i,op.i] * x

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

We should use the cache to store this value

if (op.i>1)
@. u += op.ui
end
rhs!(op.rhs,op.odeop,op.ti,(u,v),op.ode_cache)
end

function rhs!(op::RungeKuttaUpdateNonlinearOperator, x::AbstractVector)
v = op.vi
@. v = (x-op.u0)/(op.dt)
u = op.b[op.s] * op.ui[op.s]

This comment has been minimized.

Copy link
@santiagobadia

santiagobadia Oct 22, 2023

Member

idem

for i in 1:op.s-1
@. u = u + op.b[i] * op.ui[i]
end
rhs!(op.rhs,op.odeop,op.ti,(u,v),op.ode_cache)
rhs!(op.rhs,op.odeop,op.ti,(op.ui,v),op.ode_cache)
end

function lhs!(b::AbstractVector, op::RungeKuttaNonlinearOperator, x::AbstractVector)
Expand All @@ -237,9 +247,12 @@ function lhs!(b::AbstractVector, op::RungeKuttaNonlinearOperator, x::AbstractVec
lhs!(b,op.odeop,op.ti,(u,v),op.ode_cache)
end

function update!(op::RungeKuttaNonlinearOperator,ti::Float64,ui::AbstractVector,i::Int)
function update!(op::RungeKuttaStageNonlinearOperator,ti::Float64,ui::AbstractVector,i::Int)
op.ti = ti
op.ui = ui
if (i>1)
op.ui = op.a[i,i-1]*ui[i-1]
[@. op.ui += op.a[i,j]*ui[j] for j in 1:i-2]
end
op.i = i
end

Expand Down

0 comments on commit f504898

Please sign in to comment.