Skip to content

Commit

Permalink
Better?
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdAlazezAhmed committed Aug 19, 2024
1 parent bad3da1 commit b5b4195
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 97 deletions.
6 changes: 2 additions & 4 deletions examples/conduction-velocity-benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ steady_state_initializer!(u₀, odeform)

# io = ParaViewWriter("spiral-wave-test")

timestepper = Thunderbolt.AdaptiveOperatorSplittingAlgorithm(
timestepper = Thunderbolt.ReactionTangentController(
OS.LieTrotterGodunov((
BackwardEulerSolver(
solution_vector_type=Vector{Float32},
Expand All @@ -75,9 +75,7 @@ timestepper = Thunderbolt.AdaptiveOperatorSplittingAlgorithm(
reaction_threshold=0.1f0,
)
)),
Thunderbolt.ReactionTangentController(
0.5, 1.0, (0.01, 0.3)
)
0.5, 1.0, (0.01, 0.3)
)

problem = OS.OperatorSplittingProblem(odeform, u₀, tspan)
Expand Down
86 changes: 40 additions & 46 deletions src/solver/adaptivity.jl
Original file line number Diff line number Diff line change
@@ -1,42 +1,36 @@
abstract type AbstractTimeAdaptionAlgorithm end

"""
ReactionTangentController{T <: Real} <: AbstractTimeAdaptionAlgorithm
ReactionTangentController{T <: Real} <: OS.AbstractOperatorSplittingAlgorithm
A timestep length controller for [`LieTrotterGodunov`](@ref) [Lie:1880:tti,Tro:1959:psg,God:1959:dmn](@cite)
operator splitting using the reaction tangent as proposed in [OgiBalPer:2023:seats](@cite)
# Fields
- `σ_s::T`: steepness
- `σ_c::T`: offset in R axis
- `Δt_bounds::NTuple{2,T}`: lower and upper timestep length bounds
- `Rₙ₊₁::T`: updated maximal reaction magnitude
- `Rₙ::T`: previous reaction magnitude
"""
mutable struct ReactionTangentController{T <: Real} <: AbstractTimeAdaptionAlgorithm
const σ_s::T
const σ_c::T
const Δt_bounds::NTuple{2,T}
Rₙ₊₁::T
Rₙ::T
struct ReactionTangentController{LTG <: OS.LieTrotterGodunov, T <: Real} <: OS.AbstractOperatorSplittingAlgorithm
ltg::LTG
σ_s::T
σ_c::T
Δt_bounds::NTuple{2,T}
end

function ReactionTangentController(σ_s::T, σ_c::T, Δt_bounds::NTuple{2,T}) where {T <: Real}
return ReactionTangentController(σ_s, σ_c, Δt_bounds, 0.0, 0.0)
mutable struct ReactionTangentControllerCache{T <: Real, LTGCache <: OS.LieTrotterGodunovCache} <: OS.AbstractOperatorSplittingCache
const ltg_cache::LTGCache #It has Arrays so it can be const?
Rₙ₊₁::T
Rₙ::T
end

@inline OS.get_u(cache::ReactionTangentControllerCache) = OS.get_u(cache.ltg_cache)
@inline OS.get_uprev(cache::ReactionTangentControllerCache) = OS.get_uprev(cache.ltg_cache)
@inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::ReactionTangentControllerCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.ltg_cache)

"""
AdaptiveOperatorSplittingAlgorithm{TOperatorSplittingAlg <: OS.AbstractOperatorSplittingAlgorithm, TTimeAdaptionAlgorithm <: AbstractTimeAdaptionAlgorithm} <: OS.AbstractOperatorSplittingAlgorithm
A generic operator splitting algorithm `operator_splitting_algorithm` with adaptive timestepping using the controller `controller`.
# Fields
- `operator_splitting_algorithm::TOperatorSplittingAlg`: steepness
- `controller::TTimeAdaptionAlgorithm`: offset in R axis
"""
struct AdaptiveOperatorSplittingAlgorithm{TOperatorSplittingAlg <: OS.AbstractOperatorSplittingAlgorithm, TTimeAdaptionAlgorithm <: AbstractTimeAdaptionAlgorithm} <: OS.AbstractOperatorSplittingAlgorithm
operator_splitting_algorithm::TOperatorSplittingAlg
controller::TTimeAdaptionAlgorithm
@inline function OS.advance_solution_to!(subintegrators::Tuple, cache::ReactionTangentControllerCache, tnext)
OS.advance_solution_to!(subintegrators, cache.ltg_cache, tnext)
end

@inline DiffEqBase.isadaptive(::AdaptiveOperatorSplittingAlgorithm) = true
@inline DiffEqBase.isadaptive(::ReactionTangentController) = true

"""
get_reaction_tangent(integrator::OS.OperatorSplittingIntegrator)
Expand All @@ -58,42 +52,42 @@ end
# end
end

@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, alg::AdaptiveOperatorSplittingAlgorithm)
OS.stepsize_controller!(integrator, alg.controller, alg)
end

@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, alg::AdaptiveOperatorSplittingAlgorithm, q)
OS.step_accept_controller!(integrator, alg.controller, alg, q)
end

@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, alg::AdaptiveOperatorSplittingAlgorithm, q)
OS.step_reject_controller!(integrator, alg.controller, alg, q)
end

@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, controller::ReactionTangentController, alg::AdaptiveOperatorSplittingAlgorithm{<:OS.LieTrotterGodunov})
@unpack σ_s, σ_c, Δt_bounds, Rₙ₊₁, Rₙ = controller
controller.Rₙ = controller.Rₙ₊₁
controller.Rₙ₊₁ = get_reaction_tangent(integrator)
@inline function OS.stepsize_controller!(integrator::OS.OperatorSplittingIntegrator, alg::ReactionTangentController)
integrator.cache.Rₙ = integrator.cache.Rₙ₊₁
integrator.cache.Rₙ₊₁ = get_reaction_tangent(integrator)
return nothing
end

@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, controller::ReactionTangentController, alg::AdaptiveOperatorSplittingAlgorithm{<:OS.LieTrotterGodunov}, q)
@unpack σ_s, σ_c, Δt_bounds, Rₙ₊₁, Rₙ = controller
@inline function OS.step_accept_controller!(integrator::OS.OperatorSplittingIntegrator, alg::ReactionTangentController, q)
@unpack Rₙ₊₁, Rₙ = integrator.cache
@unpack σ_s, σ_c, Δt_bounds = alg
R = max(Rₙ, Rₙ₊₁)
integrator.dt = (1 - 1/(1+exp((σ_c - R)*σ_s)))*(Δt_bounds[2] - Δt_bounds[1]) + Δt_bounds[1]
integrator._dt = (1 - 1/(1+exp((σ_c - R)*σ_s)))*(Δt_bounds[2] - Δt_bounds[1]) + Δt_bounds[1]
return nothing
end

@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, controller::ReactionTangentController, alg::AdaptiveOperatorSplittingAlgorithm{<:OS.LieTrotterGodunov}, q)
@inline function OS.step_reject_controller!(integrator::OS.OperatorSplittingIntegrator, alg::ReactionTangentController, q)
return nothing # Do nothing
end

# Dispatch for outer construction
function OS.init_cache(prob::OS.OperatorSplittingProblem, alg::AdaptiveOperatorSplittingAlgorithm; dt, kwargs...) # TODO
OS.init_cache(prob, alg.operator_splitting_algorithm;dt = dt, kwargs...)
function OS.init_cache(prob::OS.OperatorSplittingProblem, alg::ReactionTangentController; dt, kwargs...) # TODO
@unpack f = prob
@assert f isa GenericSplitFunction

u = copy(prob.u0)
uprev = copy(prob.u0)

# Build inner integrator
return OS.construct_inner_cache(f, alg, u, uprev)
end

# Dispatch for recursive construction
function OS.construct_inner_cache(f::OS.AbstractOperatorSplitFunction, alg::AdaptiveOperatorSplittingAlgorithm, u::AbstractArray, uprev::AbstractArray)
OS.construct_inner_cache(f, alg.operator_splitting_algorithm, u, uprev)
function OS.construct_inner_cache(f::OS.AbstractOperatorSplitFunction, alg::ReactionTangentController, u::AbstractArray{T}, uprev::AbstractArray) where T <: Number
ltg_cache = OS.construct_inner_cache(f, alg.ltg, u, uprev)
return ReactionTangentControllerCache(ltg_cache, zero(T), zero(T))
end

function OS.build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::ReactionTangentControllerCache, u::AbstractArray, uprev::AbstractArray, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat)
OS.build_subintegrators_recursive(f, synchronizers, p, cache.ltg_cache, u, uprev, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat)
end
11 changes: 7 additions & 4 deletions src/solver/operator_splitting/integrator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,17 @@ function DiffEqBase.__init(
callback = DiffEqBase.CallbackSet(callback)

cache = init_cache(prob, alg; dt, kwargs...)

u = get_u(cache)
uprev = get_uprev(cache)

subintegrators = build_subintegrators_recursive(prob.f, prob.f.synchronizers, p, cache, cache.u, cache.uprev, t0, dt, 1:length(u0), cache.u, tstops, _tstops, saveat, _saveat)
subintegrators = build_subintegrators_recursive(prob.f, prob.f.synchronizers, p, cache, u, uprev, t0, dt, 1:length(u0), u, tstops, _tstops, saveat, _saveat)

integrator = OperatorSplittingIntegrator(
prob.f,
alg,
cache.u,
cache.uprev,
u,
uprev,
p,
t0,
copy(t0),
Expand Down Expand Up @@ -250,7 +253,7 @@ end

function __step!(integrator)
(; dtchangeable, tstops) = integrator
_dt = DiffEqBase.isadaptive(integrator.alg) ? DiffEqBase.get_dt(integrator) : integrator.dt
_dt = DiffEqBase.get_dt(integrator)

# update dt before incrementing u; if dt is changeable and there is
# a tstop within dt, reduce dt to tstop - t
Expand Down
3 changes: 3 additions & 0 deletions src/solver/operator_splitting/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ struct LieTrotterGodunovCache{uType, tmpType, iiType} <: AbstractOperatorSplitti
inner_caches::iiType
end

get_u(cache#=::LieTrotterGodunovCache=#) = cache.u
get_uprev(cache#=::LieTrotterGodunovCache=#) = cache.uprev

# Dispatch for outer construction
function init_cache(prob::OperatorSplittingProblem, alg::LieTrotterGodunov; dt, kwargs...) # TODO
@unpack f = prob
Expand Down
72 changes: 34 additions & 38 deletions test/integration/test_electrophysiology.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using Thunderbolt
end
end

function solve_waveprop(mesh, coeff, subdomains, isadaptive = false)
function solve_waveprop(mesh, coeff, subdomains, timestepper)
cs = CoordinateSystemCoefficient(CartesianCoordinateSystem(mesh))
model = MonodomainModel(
ConstantCoefficient(1.0),
Expand All @@ -45,19 +45,6 @@ using Thunderbolt
mesh
)

_timestepper = LieTrotterGodunov((
BackwardEulerSolver(),
ForwardEulerCellSolver()
))
if isadaptive
timestepper = Thunderbolt.AdaptiveOperatorSplittingAlgorithm(
_timestepper,
Thunderbolt.ReactionTangentController(0.5, 1.0, (0.01, 0.3))
)
else
timestepper = _timestepper
end

u₀ = zeros(Float64, OS.function_size(odeform))
simple_initializer!(u₀, odeform)

Expand All @@ -71,40 +58,49 @@ using Thunderbolt
return integrator.u
end

timestepper = LieTrotterGodunov((
BackwardEulerSolver(),
ForwardEulerCellSolver()
))
timestepper_adaptive = Thunderbolt.ReactionTangentController(
timestepper,
0.5, 1.0, (0.98, 1.02)
)

mesh = generate_mesh(Hexahedron, (4, 4, 4), Vec{3}((0.0,0.0,0.0)), Vec{3}((1.0,1.0,1.0)))
coeff = ConstantCoefficient(SymmetricTensor{2,3,Float64}((4.5e-5, 0, 0, 2.0e-5, 0, 1.0e-5)))
u = solve_waveprop(mesh, coeff, [""])
u_adaptive = solve_waveprop(mesh, coeff, [""], true)
@test u u_adaptive
u = solve_waveprop(mesh, coeff, [""], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, [""], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4

mesh = generate_ideal_lv_mesh(4,1,1)
coeff = ConstantCoefficient(SymmetricTensor{2,3,Float64}((4.5e-5, 0, 0, 2.0e-5, 0, 1.0e-5)))
u = solve_waveprop(mesh, coeff, [""])
u_adaptive = solve_waveprop(mesh, coeff, [""], true)
@test u u_adaptive
u = solve_waveprop(mesh, coeff, [""], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, [""], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4

mesh = to_mesh(generate_mixed_grid_2D())
coeff = ConstantCoefficient(SymmetricTensor{2,2,Float64}((4.5e-5, 0, 2.0e-5)))
u = solve_waveprop(mesh, coeff, ["Pacemaker", "Myocardium"])
u_adaptive = solve_waveprop(mesh, coeff, ["Pacemaker", "Myocardium"], true)
@test u u_adaptive
u = solve_waveprop(mesh, coeff, ["Pacemaker"])
u_adaptive = solve_waveprop(mesh, coeff, ["Pacemaker"], true)
@test u u_adaptive
u = solve_waveprop(mesh, coeff, ["Myocardium"])
u_adaptive = solve_waveprop(mesh, coeff, ["Myocardium"], true)
@test u u_adaptive
u = solve_waveprop(mesh, coeff, ["Pacemaker", "Myocardium"], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, ["Pacemaker", "Myocardium"], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4
u = solve_waveprop(mesh, coeff, ["Pacemaker"], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, ["Pacemaker"], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4
u = solve_waveprop(mesh, coeff, ["Myocardium"], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, ["Myocardium"], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4

mesh = to_mesh(generate_mixed_dimensional_grid_3D())
coeff = ConstantCoefficient(SymmetricTensor{2,3,Float64}((4.5e-5, 0, 0, 2.0e-5, 0, 1.0e-5)))
u = solve_waveprop(mesh, coeff, ["Ventricle"])
u_adaptive = solve_waveprop(mesh, coeff, ["Ventricle"])
@test u u_adaptive
u = solve_waveprop(mesh, coeff, ["Ventricle"], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, ["Ventricle"], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4
coeff = ConstantCoefficient(SymmetricTensor{2,3,Float64}((5e-5, 0, 0, 5e-5, 0, 5e-5)))
u = solve_waveprop(mesh, coeff, ["Purkinje"])
u_adaptive = solve_waveprop(mesh, coeff, ["Purkinje"])
@test u u_adaptive
u = solve_waveprop(mesh, coeff, ["Ventricle", "Purkinje"])
u_adaptive = solve_waveprop(mesh, coeff, ["Ventricle", "Purkinje"])
@test u u_adaptive
u = solve_waveprop(mesh, coeff, ["Purkinje"], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, ["Purkinje"], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4
u = solve_waveprop(mesh, coeff, ["Ventricle", "Purkinje"], timestepper)
u_adaptive = solve_waveprop(mesh, coeff, ["Ventricle", "Purkinje"], timestepper_adaptive)
@test u u_adaptive rtol = 1e-4
end
8 changes: 3 additions & 5 deletions test/test_integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,18 @@ fsplit_NaN = GenericSplitFunction((f1,f_NaN), (f1dofs, f_NaN_dofs))

@testset "OperatorSplitting" begin
for TimeStepperType in (LieTrotterGodunov,)
for controller in (Thunderbolt.ReactionTangentController(0.5, 1.0, (0.01, 0.3)),)
timestepper = TimeStepperType(
(DummyForwardEuler(), DummyForwardEuler())
)
timestepper_adaptive = Thunderbolt.AdaptiveOperatorSplittingAlgorithm(timestepper, controller)
timestepper_adaptive = Thunderbolt.ReactionTangentController(timestepper, 0.5, 1.0, (0.01, 0.3))
timestepper_inner = TimeStepperType(
(DummyForwardEuler(), DummyForwardEuler())
)
timestepper_inner_adaptive = Thunderbolt.AdaptiveOperatorSplittingAlgorithm(timestepper_inner, controller) #TODO: Copy the controller instead
timestepper_inner_adaptive = Thunderbolt.ReactionTangentController(timestepper_inner, 0.5, 1.0, (0.01, 0.3)) #TODO: Copy the controller instead
timestepper2 = TimeStepperType(
(DummyForwardEuler(), timestepper_inner)
)
timestepper2_adaptive = Thunderbolt.AdaptiveOperatorSplittingAlgorithm(timestepper2, controller)
timestepper2_adaptive = Thunderbolt.ReactionTangentController(timestepper2, 0.5, 1.0, (0.01, 0.3))

for (tstepper1, tstepper_inner, tstepper2) in (
(timestepper, timestepper_inner, timestepper2),
Expand Down Expand Up @@ -162,7 +161,6 @@ fsplit_NaN = GenericSplitFunction((f1,f_NaN), (f1dofs, f_NaN_dofs))
# DiffEqBase.solve!(integrator_NaN)
# @test integrator_NaN.sol.retcode == DiffEqBase.ReturnCode.Failure
end
end
# integrator = DiffEqBase.init(prob, timestepper, dt=0.01, verbose=true)
# for (u, t) in DiffEqBase.TimeChoiceIterator(integrator, 0.0:5.0:100.0) end
# integrator_adaptive = DiffEqBase.init(prob, timestepper_adaptive, dt=0.01, verbose=true)
Expand Down

0 comments on commit b5b4195

Please sign in to comment.