From 6b3484e77d2e2988e6168c6392c358555be50437 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Sep 2024 01:43:20 +0200 Subject: [PATCH 1/8] Reduce allocs in OS module. --- src/solver/operator_splitting/integrator.jl | 23 ++++++--------------- src/solver/operator_splitting/solver.jl | 19 +++++++---------- src/solver/time/euler.jl | 6 +++++- src/solver/time/partitioned_solver.jl | 22 +++++++++++++++----- src/solver/time/time_integrator.jl | 23 ++++++++------------- 5 files changed, 45 insertions(+), 48 deletions(-) diff --git a/src/solver/operator_splitting/integrator.jl b/src/solver/operator_splitting/integrator.jl index 7044467c..3dd6e12f 100644 --- a/src/solver/operator_splitting/integrator.jl +++ b/src/solver/operator_splitting/integrator.jl @@ -72,18 +72,15 @@ function DiffEqBase.__init( callback = DiffEqBase.CallbackSet(callback) - cache = init_cache(prob, alg; dt, kwargs...) - - u = cache.u - uprev = cache.uprev + cache = init_cache(prob, alg; u0, t0, dt, kwargs...) - subintegrators = build_subintegrators_recursive(prob.f, prob.f.synchronizers, p, cache, u, uprev, t0, dt, 1:length(u0), u, tstops, _tstops, saveat, _saveat) + subintegrators = build_subintegrators_recursive(prob.f, prob.f.synchronizers, p, cache, t0, dt, 1:length(u0), cache.u, tstops, _tstops, saveat, _saveat) integrator = OperatorSplittingIntegrator( prob.f, alg, - u, - uprev, + cache.u, + cache.uprev, p, t0, copy(t0), @@ -342,7 +339,7 @@ end advance_solution_to!(integrator::OperatorSplittingIntegrator, cache::AbstractOperatorSplittingCache, tnext::Number) = advance_solution_to!(integrator.subintegrators, cache, tnext) # Dispatch for tree node construction -function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::AbstractOperatorSplittingCache, u::AbstractArray, uprev::AbstractArray, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) +function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::AbstractOperatorSplittingCache, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) return ntuple(i -> build_subintegrators_recursive( get_operator(f, i), @@ -350,10 +347,6 @@ function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers:: p[i], cache.inner_caches[i], # TODO recover this - # cache.inner_caches[i].u, - # cache.inner_caches[i].uprev, - similar(u, length(f.dof_ranges[i])), - similar(uprev, length(f.dof_ranges[i])), t, dt, f.dof_ranges[i], # We pass the full solution, because some parameters might require # access to solution variables which are not part of the local solution range @@ -362,7 +355,7 @@ function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers:: ), length(f.functions) ) end -function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::NoExternalSynchronization, p::Tuple, cache::AbstractOperatorSplittingCache, u::AbstractArray, uprev::AbstractArray, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) +function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::NoExternalSynchronization, p::Tuple, cache::AbstractOperatorSplittingCache, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) return ntuple(i -> build_subintegrators_recursive( get_operator(f, i), @@ -370,10 +363,6 @@ function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers:: p[i], cache.inner_caches[i], # TODO recover this - # cache.inner_caches[i].u, - # cache.inner_caches[i].uprev, - similar(u, length(f.dof_ranges[i])), - similar(uprev, length(f.dof_ranges[i])), t, dt, f.dof_ranges[i], # We pass the full solution, because some parameters might require # access to solution variables which are not part of the local solution range diff --git a/src/solver/operator_splitting/solver.jl b/src/solver/operator_splitting/solver.jl index 69d9640e..45fc8623 100644 --- a/src/solver/operator_splitting/solver.jl +++ b/src/solver/operator_splitting/solver.jl @@ -13,36 +13,33 @@ end struct LieTrotterGodunovCache{uType, tmpType, iiType} <: AbstractOperatorSplittingCache u::uType uprev::uType # True previous solution - uprev2::tmpType # Previous solution used during time marching tmp::tmpType # Scratch inner_caches::iiType end # Dispatch for outer construction -function init_cache(prob::OperatorSplittingProblem, alg::LieTrotterGodunov; dt, kwargs...) # TODO +function init_cache(prob::OperatorSplittingProblem, alg::LieTrotterGodunov; u0, kwargs...) # TODO @unpack f = prob @assert f isa GenericSplitFunction - u = copy(prob.u0) - uprev = copy(prob.u0) - # Build inner integrator - return construct_inner_cache(f, alg, u, uprev) + return construct_inner_cache(f, alg; u0, kwargs...) end # Dispatch for recursive construction -function construct_inner_cache(f::AbstractOperatorSplitFunction, alg::LieTrotterGodunov, u::AbstractArray, uprev::AbstractArray) +function construct_inner_cache(f::AbstractOperatorSplitFunction, alg::LieTrotterGodunov; u0, kwargs...) dof_ranges = f.dof_ranges - uprev2 = similar(uprev) + u = copy(u0) + uprev = copy(u0) tmp = similar(u) - inner_caches = ntuple(i->construct_inner_cache(get_operator(f, i), alg.inner_algs[i], similar(u, length(dof_ranges[i])), similar(u, length(dof_ranges[i]))), length(f.functions)) - LieTrotterGodunovCache(u, uprev, uprev2, tmp, inner_caches) + inner_caches = ntuple(i->construct_inner_cache(get_operator(f, i), alg.inner_algs[i]; u0, kwargs...), length(f.functions)) + LieTrotterGodunovCache(u, uprev, tmp, inner_caches) end @inline @unroll function advance_solution_to!(subintegrators::Tuple, cache::LieTrotterGodunovCache, tnext) # We assume that the integrators are already synced - @unpack u, uprev2, uprev, inner_caches = cache + @unpack u, uprev, inner_caches = cache # Store current solution uprev .= u diff --git a/src/solver/time/euler.jl b/src/solver/time/euler.jl index 8b7101ca..a7eb2ea5 100644 --- a/src/solver/time/euler.jl +++ b/src/solver/time/euler.jl @@ -74,6 +74,10 @@ function perform_step!(f::TransientDiffusionFunction, cache::BackwardEulerSolver return !solve_failed end +function init_cache(prob, alg::BackwardEulerSolver; t0) + return setup_solver_cache(prob.f, alg, t0) +end + function setup_solver_cache(f::TransientDiffusionFunction, solver::BackwardEulerSolver, t₀) @unpack dh = f @unpack inner_solver = solver @@ -86,7 +90,7 @@ function setup_solver_cache(f::TransientDiffusionFunction, solver::BackwardEuler uprev = create_system_vector(solver.solution_vector_type, f) tmp = create_system_vector(solver.solution_vector_type, f) - T = eltype(A) + T = eltype(u0) qr = create_quadrature_rule(f, solver, field_name) diff --git a/src/solver/time/partitioned_solver.jl b/src/solver/time/partitioned_solver.jl index fdff04e3..8878051c 100644 --- a/src/solver/time/partitioned_solver.jl +++ b/src/solver/time/partitioned_solver.jl @@ -36,7 +36,6 @@ struct ForwardEulerCellSolverCache{duType, uType, dumType, umType, xType} <: Abs # These vectors hold the data uₙ::uType uₙ₋₁::uType - tmp::uType # These array view the data above to give easy indices of the form [ode index, local state index] dumat::dumType uₙmat::umType @@ -62,6 +61,12 @@ Adapt.@adapt_structure ForwardEulerCellSolverCache return true end +function init_cache(prob, alg::ForwardEulerCellSolver; t0) + cache = setup_solver_cache(prob.f, alg, t0) + resize(cache.uₙ₋₁, size(cache.uₙ)) + return cache +end + function setup_solver_cache(f::PointwiseODEFunction, solver::ForwardEulerCellSolver, t₀) @unpack npoints, ode = f ndofs_local = num_states(ode) @@ -69,12 +74,12 @@ function setup_solver_cache(f::PointwiseODEFunction, solver::ForwardEulerCellSol du = create_system_vector(solver.solution_vector_type, f) dumat = reshape(du, (npoints,ndofs_local)) uₙ = create_system_vector(solver.solution_vector_type, f) - uₙ₋₁ = create_system_vector(solver.solution_vector_type, f) - tmp = create_system_vector(solver.solution_vector_type, f) + # uₙ₋₁ = create_system_vector(solver.solution_vector_type, f) + uₙ₋₁ = similar(uₙ, 0) uₙmat = reshape(uₙ, (npoints,ndofs_local)) xs = f.x === nothing ? nothing : Adapt.adapt(solver.solution_vector_type, f.x) - return ForwardEulerCellSolverCache(du, uₙ, uₙ₋₁, tmp, dumat, uₙmat, solver.batch_size_hint, xs) + return ForwardEulerCellSolverCache(du, uₙ, uₙ₋₁, dumat, uₙmat, solver.batch_size_hint, xs) end Base.@kwdef struct AdaptiveForwardEulerSubstepper{T, SolutionVectorType <: AbstractVector{T}} <: AbstractPointwiseSolver @@ -135,6 +140,13 @@ Adapt.@adapt_structure AdaptiveForwardEulerSubstepperCache return true end + +function init_cache(prob, alg::AdaptiveForwardEulerSubstepper; t0) + cache = setup_solver_cache(prob.f, alg, t0) + resize(cache.uₙ₋₁, size(cache.uₙ)) + return cache +end + function setup_solver_cache(f::PointwiseODEFunction, solver::AdaptiveForwardEulerSubstepper, t₀) @unpack npoints, ode = f ndofs_local = num_states(ode) @@ -142,7 +154,7 @@ function setup_solver_cache(f::PointwiseODEFunction, solver::AdaptiveForwardEule du = create_system_vector(solver.solution_vector_type, f) dumat = reshape(du, (npoints,ndofs_local)) uₙ = create_system_vector(solver.solution_vector_type, f) - uₙ₋₁ = create_system_vector(solver.solution_vector_type, f) + uₙ₋₁ = similar(uₙ, 0) uₙmat = reshape(uₙ, (npoints,ndofs_local)) xs = if f.x === nothing nothing diff --git a/src/solver/time/time_integrator.jl b/src/solver/time/time_integrator.jl index c4909852..429a3922 100644 --- a/src/solver/time/time_integrator.jl +++ b/src/solver/time/time_integrator.jl @@ -84,11 +84,10 @@ end # Copy solution into subproblem uparentview = @view subintegrator.uparent[subintegrator.indexset] subintegrator.u .= uparentview - # for (i,imain) in enumerate(subintegrator.indexset) - # subintegrator.u[i] = subintegrator.uparent[imain] - # end - # Mark previous solution - subintegrator.uprev .= subintegrator.u + # Mark previous solution, if necessary + if subintegrator.uprev !== nothing && length(subintegrator.uprev) > 0 + subintegrator.uprev .= subintegrator.u + end syncronize_parameters!(subintegrator, subintegrator.f, subintegrator.synchronizer) end @inline function OS.finalize_local_step!(subintegrator::ThunderboltTimeIntegrator) @@ -96,12 +95,9 @@ end # uparentview = @view subintegrator.uparent[subintegrator.indexset] uparentview .= subintegrator.u - # for (i,imain) in enumerate(subintegrator.indexset) - # subintegrator.uparent[imain] = subintegrator.u[i] - # end end # Glue code -function OS.build_subintegrators_recursive(f, synchronizer, p::Any, cache::AbstractTimeSolverCache, u::AbstractArray, uprev::AbstractArray, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) +function OS.build_subintegrators_recursive(f, synchronizer, p::Any, cache::AbstractTimeSolverCache, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) integrator = Thunderbolt.ThunderboltTimeIntegrator( f, cache.uₙ, @@ -125,8 +121,8 @@ function OS.build_subintegrators_recursive(f, synchronizer, p::Any, cache::Abstr syncronize_parameters!(integrator, f, synchronizer) return integrator end -function OS.construct_inner_cache(f, alg::AbstractSolver, u::AbstractArray, uprev::AbstractArray) - return Thunderbolt.setup_solver_cache(f, alg, 0.0) +function OS.construct_inner_cache(f, alg::AbstractSolver; u0, t0, kwargs...) + return Thunderbolt.setup_solver_cache(f, alg, t0) end OS.recursive_null_parameters(stuff::Union{AbstractSemidiscreteProblem, AbstractSemidiscreteFunction}) = OS.DiffEqBase.NullParameters() syncronize_parameters!(integ, f, ::OS.NoExternalSynchronization) = nothing @@ -162,10 +158,9 @@ function DiffEqBase.__init( callback = DiffEqBase.CallbackSet(callback) - cache = setup_solver_cache(f, alg, t0) + cache = init_cache(prob, alg; t0) - cache.uₙ .= u0 - cache.uₙ₋₁ .= u0 + cache.uₙ .= u0 integrator = ThunderboltTimeIntegrator( f, From 0f32cb0221bb68e7adc0150e290c0c58eb2a8ff9 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Sep 2024 02:51:02 +0200 Subject: [PATCH 2/8] Remove uparent from ThunderboltTimeIntegrator --- src/solver/operator_splitting/integrator.jl | 20 +++++++------- src/solver/operator_splitting/solver.jl | 14 +++++----- src/solver/time/rtc.jl | 21 +++++++-------- src/solver/time/time_integrator.jl | 17 ++++-------- test/test_integrators.jl | 30 +++++++++------------ 5 files changed, 44 insertions(+), 58 deletions(-) diff --git a/src/solver/operator_splitting/integrator.jl b/src/solver/operator_splitting/integrator.jl index 3dd6e12f..928e99b9 100644 --- a/src/solver/operator_splitting/integrator.jl +++ b/src/solver/operator_splitting/integrator.jl @@ -277,8 +277,8 @@ function __step!(integrator) synchronize_subintegrators!(integrator) tnext = integrator.t + integrator.dt - # Solve inner problems - advance_solution_to!(integrator, tnext) + # Solve inner problems + advance_solution_to!(integrator, tnext; uparent=integrator.u) stepsize_controller!(integrator) # Update integrator @@ -299,8 +299,8 @@ function __step!(integrator) end # solvers need to define this interface -function advance_solution_to!(integrator, tnext) - advance_solution_to!(integrator, integrator.cache, tnext) +function advance_solution_to!(integrator, tnext; uparent) + advance_solution_to!(integrator, integrator.cache, tnext; uparent) end DiffEqBase.get_dt(integrator::OperatorSplittingIntegrator) = integrator._dt @@ -336,7 +336,9 @@ end end end -advance_solution_to!(integrator::OperatorSplittingIntegrator, cache::AbstractOperatorSplittingCache, tnext::Number) = advance_solution_to!(integrator.subintegrators, cache, tnext) +function advance_solution_to!(integrator::OperatorSplittingIntegrator, cache::AbstractOperatorSplittingCache, tnext::Number; uparent) + advance_solution_to!(integrator.subintegrators, cache, tnext; uparent) +end # Dispatch for tree node construction function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::AbstractOperatorSplittingCache, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) @@ -372,14 +374,14 @@ function build_subintegrators_recursive(f::GenericSplitFunction, synchronizers:: ) end -@unroll function prepare_local_step!(subintegrators::Tuple) +@unroll function prepare_local_step!(uparent, subintegrators::Tuple) @unroll for subintegrator in subintegrators - prepare_local_step!(subintegrator) + prepare_local_step!(uparent, subintegrator) end end -@unroll function finalize_local_step!(subintegrators::Tuple) +@unroll function finalize_local_step!(uparent, subintegrators::Tuple) @unroll for subintegrator in subintegrators - finalize_local_step!(subintegrator) + finalize_local_step!(uparent, subintegrator) end end diff --git a/src/solver/operator_splitting/solver.jl b/src/solver/operator_splitting/solver.jl index 45fc8623..9d67af4c 100644 --- a/src/solver/operator_splitting/solver.jl +++ b/src/solver/operator_splitting/solver.jl @@ -23,21 +23,21 @@ function init_cache(prob::OperatorSplittingProblem, alg::LieTrotterGodunov; u0, @assert f isa GenericSplitFunction # Build inner integrator - return construct_inner_cache(f, alg; u0, kwargs...) + return construct_inner_cache(f, alg; uparent=u0, u0, kwargs...) end # Dispatch for recursive construction -function construct_inner_cache(f::AbstractOperatorSplitFunction, alg::LieTrotterGodunov; u0, kwargs...) +function construct_inner_cache(f::AbstractOperatorSplitFunction, alg::LieTrotterGodunov; uparent, u0, kwargs...) dof_ranges = f.dof_ranges u = copy(u0) uprev = copy(u0) tmp = similar(u) - inner_caches = ntuple(i->construct_inner_cache(get_operator(f, i), alg.inner_algs[i]; u0, kwargs...), length(f.functions)) + inner_caches = ntuple(i->construct_inner_cache(get_operator(f, i), alg.inner_algs[i]; uparent, u0=view(uparent,dof_ranges[i]), kwargs...), length(f.functions)) LieTrotterGodunovCache(u, uprev, tmp, inner_caches) end -@inline @unroll function advance_solution_to!(subintegrators::Tuple, cache::LieTrotterGodunovCache, tnext) +@inline @unroll function advance_solution_to!(subintegrators::Tuple, cache::LieTrotterGodunovCache, tnext; uparent) # We assume that the integrators are already synced @unpack u, uprev, inner_caches = cache @@ -48,8 +48,8 @@ end i = 0 @unroll for subinteg in subintegrators i += 1 - prepare_local_step!(subinteg) - advance_solution_to!(subinteg, inner_caches[i], tnext) - finalize_local_step!(subinteg) + prepare_local_step!(uparent, subinteg) + advance_solution_to!(subinteg, inner_caches[i], tnext; uparent) + finalize_local_step!(uparent, subinteg) end end diff --git a/src/solver/time/rtc.jl b/src/solver/time/rtc.jl index fdb37850..eb2e6752 100644 --- a/src/solver/time/rtc.jl +++ b/src/solver/time/rtc.jl @@ -32,8 +32,8 @@ end @inline DiffEqBase.get_tmp_cache(integrator::OS.OperatorSplittingIntegrator, alg::OS.AbstractOperatorSplittingAlgorithm, cache::ReactionTangentControllerCache) = DiffEqBase.get_tmp_cache(integrator, alg, cache.ltg_cache) -@inline function OS.advance_solution_to!(subintegrators::Tuple, cache::ReactionTangentControllerCache, tnext) - OS.advance_solution_to!(subintegrators, cache.ltg_cache, tnext) +@inline function OS.advance_solution_to!(subintegrators::Tuple, cache::ReactionTangentControllerCache, tnext; kwargs...) + OS.advance_solution_to!(subintegrators, cache.ltg_cache, tnext; kwargs...) end @inline DiffEqBase.isadaptive(::ReactionTangentController) = true @@ -85,23 +85,20 @@ end end # Dispatch for outer construction -function OS.init_cache(prob::OS.OperatorSplittingProblem, alg::ReactionTangentController; dt, kwargs...) +function OS.init_cache(prob::OS.OperatorSplittingProblem, alg::ReactionTangentController; u0, kwargs...) @unpack f = prob @assert f isa GenericSplitFunction - u = copy(prob.u0) - uprev = copy(prob.u0) - # Build inner integrator - return OS.construct_inner_cache(f, alg, u, uprev) + return OS.construct_inner_cache(f, alg; uparent=u0, u0, kwargs...) end # Dispatch for recursive construction -function OS.construct_inner_cache(f::OS.AbstractOperatorSplitFunction, alg::ReactionTangentController, u::AbstractArray{T}, uprev::AbstractArray) where T <: Number - ltg_cache = OS.construct_inner_cache(f, alg.ltg, u, uprev) - return ReactionTangentControllerCache(ltg_cache, zero(T)) +function OS.construct_inner_cache(f::OS.AbstractOperatorSplitFunction, alg::ReactionTangentController; u0, kwargs...) + ltg_cache = OS.construct_inner_cache(f, alg.ltg; u0, kwargs...) + return ReactionTangentControllerCache(ltg_cache, zero(eltype(u0))) end -function OS.build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::ReactionTangentControllerCache, u::AbstractArray, uprev::AbstractArray, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) - OS.build_subintegrators_recursive(f, synchronizers, p, cache.ltg_cache, u, uprev, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) +function OS.build_subintegrators_recursive(f::GenericSplitFunction, synchronizers::Tuple, p::Tuple, cache::ReactionTangentControllerCache, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) + OS.build_subintegrators_recursive(f, synchronizers, p, cache.ltg_cache, t, dt, dof_range, uparent, tstops, _tstops, saveat, _saveat) end diff --git a/src/solver/time/time_integrator.jl b/src/solver/time/time_integrator.jl index 429a3922..d8f2dd65 100644 --- a/src/solver/time/time_integrator.jl +++ b/src/solver/time/time_integrator.jl @@ -5,7 +5,6 @@ over some time interval. mutable struct ThunderboltTimeIntegrator{ fType, uType, - uType2, uprevType, indexSetType, tType, @@ -19,7 +18,6 @@ mutable struct ThunderboltTimeIntegrator{ } <: DiffEqBase.SciMLBase.DEIntegrator{#=alg_type=#Nothing, true, uType, tType} # FIXME alg f::fType # Right hand side u::uType # Current local solution - uparent::uType2 # Real solution injected by OperatorSplittingIntegrator uprev::uprevType indexset::indexSetType p::pType @@ -75,14 +73,14 @@ end OS.tdir(::ThunderboltTimeIntegrator) = 1 # TODO Any -> cache supertype -function OS.advance_solution_to!(integrator::ThunderboltTimeIntegrator, cache::Any, tend) +function OS.advance_solution_to!(integrator::ThunderboltTimeIntegrator, cache::Any, tend; kwargs...) @unpack f, t = integrator dt = tend-t dt ≈ 0.0 || DiffEqBase.step!(integrator, dt, true) end -@inline function OS.prepare_local_step!(subintegrator::ThunderboltTimeIntegrator) +@inline function OS.prepare_local_step!(uparent, subintegrator::ThunderboltTimeIntegrator) # Copy solution into subproblem - uparentview = @view subintegrator.uparent[subintegrator.indexset] + uparentview = @view uparent[subintegrator.indexset] subintegrator.u .= uparentview # Mark previous solution, if necessary if subintegrator.uprev !== nothing && length(subintegrator.uprev) > 0 @@ -90,10 +88,10 @@ end end syncronize_parameters!(subintegrator, subintegrator.f, subintegrator.synchronizer) end -@inline function OS.finalize_local_step!(subintegrator::ThunderboltTimeIntegrator) +@inline function OS.finalize_local_step!(uparent, subintegrator::ThunderboltTimeIntegrator) # Copy solution out of subproblem # - uparentview = @view subintegrator.uparent[subintegrator.indexset] + uparentview = @view uparent[subintegrator.indexset] uparentview .= subintegrator.u end # Glue code @@ -101,7 +99,6 @@ function OS.build_subintegrators_recursive(f, synchronizer, p::Any, cache::Abstr integrator = Thunderbolt.ThunderboltTimeIntegrator( f, cache.uₙ, - uparent, cache.uₙ₋₁, dof_range, p, @@ -139,7 +136,6 @@ function DiffEqBase.__init( advance_to_tstop = false, save_func = (u, t) -> copy(u), # custom kwarg dtchangeable = true, # custom kwarg - uparent = nothing, # custom kwarg syncronizer = OS.NoExternalSynchronization(), # custom kwarg kwargs..., ) @@ -165,7 +161,6 @@ function DiffEqBase.__init( integrator = ThunderboltTimeIntegrator( f, cache.uₙ, - uparent, cache.uₙ₋₁, 1:length(u0), p, @@ -200,8 +195,6 @@ end @inline get_parent_index(integ::ThunderboltTimeIntegrator, local_idx::Int, range::AbstractUnitRange) = first(range) + local_idx - 1 @inline get_parent_index(integ::ThunderboltTimeIntegrator, local_idx::Int, range::StepRange) = first(range) + range.step*(local_idx - 1) -@inline get_parent_value(integ::ThunderboltTimeIntegrator, local_idx::Int) = integ.uparent[get_parent_index(integ, local_idx)] - # Compat with OrdinaryDiffEq function perform_step!(integ::ThunderboltTimeIntegrator, cache::AbstractTimeSolverCache) if !perform_step!(integ.f, cache, integ.t, integ.dt) diff --git a/test/test_integrators.jl b/test/test_integrators.jl index be21344c..30ee832b 100644 --- a/test/test_integrators.jl +++ b/test/test_integrators.jl @@ -2,7 +2,7 @@ import Thunderbolt: OS, ThunderboltTimeIntegrator # using BenchmarkTools using UnPack -@testset "Operator Splitting API" begin +# @testset "Operator Splitting API" begin ODEFunction = Thunderbolt.DiffEqBase.ODEFunction @@ -18,12 +18,15 @@ using UnPack end # Dispatch for leaf construction - function OS.construct_inner_cache(f::ODEFunction, alg::DummyForwardEuler, u::AbstractArray, uprev::AbstractArray) + function OS.construct_inner_cache(f::ODEFunction, alg::DummyForwardEuler; u0, kwargs...) + du = copy(u0) + u = copy(u0) + uprev = copy(u0) dumat = reshape(uprev, (:,1)) - DummyForwardEulerCache(copy(uprev), dumat, copy(uprev), copy(uprev)) + DummyForwardEulerCache(du, dumat, u, uprev) end - Thunderbolt.num_states(::Any) = 2 - Thunderbolt.transmembranepotential_index(::Any) = 1 + Thunderbolt.num_states(::ODEFunction) = 2 # FIXME + Thunderbolt.transmembranepotential_index(::ODEFunction) = 1 # FIXME function Thunderbolt.setup_solver_cache(f::PointwiseODEFunction, solver::DummyForwardEuler, t₀) @unpack npoints, ode = f @@ -37,13 +40,15 @@ using UnPack end # Dispatch innermost solve - function OS.advance_solution_to!(integ::ThunderboltTimeIntegrator, cache::DummyForwardEulerCache, tend) + function Thunderbolt.perform_step!(integ::ThunderboltTimeIntegrator, cache::DummyForwardEulerCache) @unpack f, dt, u, p, t = integ @unpack du = cache f isa Thunderbolt.PointwiseODEFunction ? f.ode(du, u, p, t) : f(du, u, p, t) @. u += dt * du cache.dumat[:,1] .= du + + return true end # Operator splitting @@ -228,15 +233,4 @@ using UnPack end end end - -# tnext = tspan[1]+0.01 -# @btime OS.advance_solution_to!($integrator, $tnext) setup=(DiffEqBase.reinit!(integrator, u0; tspan)) -# 326.743 ns (8 allocations: 416 bytes) for 1 (OUTDATED -# 89.949 ns (0 allocations: 0 bytes) for 2 (OUTDATED -# 31.418 ns (0 allocations: 0 bytes) for 3 -# @btime DiffEqBase.solve!($integrator) setup=(DiffEqBase.reinit!(integrator, u0; tspan)); -# 431.632 μs (10000 allocations: 507.81 KiB) for 1 (OUTDATED -# 105.712 μs (0 allocations: 0 bytes) for 2 (OUTDATED) -# 1.852 μs (0 allocations: 0 bytes) for 3 - -end +# end From a2f7c8c8a2df5d689b1455881a35cb79d70038e9 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Sep 2024 03:17:37 +0200 Subject: [PATCH 3/8] Fix CI --- src/Thunderbolt.jl | 1 - src/modeling/core/coordinate_systems.jl | 71 ++++++++----------------- src/solver/time/euler.jl | 4 -- src/solver/time/time_integrator.jl | 4 ++ 4 files changed, 27 insertions(+), 53 deletions(-) diff --git a/src/Thunderbolt.jl b/src/Thunderbolt.jl index b46e3e21..9bb0ad62 100644 --- a/src/Thunderbolt.jl +++ b/src/Thunderbolt.jl @@ -206,7 +206,6 @@ export AdaptiveForwardEulerSubstepper, # Integrator get_parent_index, - get_parent_value, # Utils calculate_volume_deformed_mesh, elementtypes, diff --git a/src/modeling/core/coordinate_systems.jl b/src/modeling/core/coordinate_systems.jl index 9a2a80fc..e901fe86 100644 --- a/src/modeling/core/coordinate_systems.jl +++ b/src/modeling/core/coordinate_systems.jl @@ -68,8 +68,8 @@ Requires a mesh with facetsets and a nodeset * Apex """ -function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{String} = [""]; up = Vec((0.0,0.0,1.0))) - @assert up ≈ Vec((0.0,0.0,1.0)) "Custom up vector not yet supported." +function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{String} = [""]; up = Vec((0.0,0.0,-1.0))) + @assert abs.(up) ≈ Vec((0.0,0.0,1.0)) "Custom up vector not yet supported." ip_collection = LagrangeCollection{1}() qr_collection = QuadratureRuleCollection(2) cv_collection = CellValueCollection(qr_collection, ip_collection) @@ -119,38 +119,18 @@ function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{St close!(ch) update!(ch, 0.0); - K_transmural = copy(K) + K_transmural = K f = zeros(ndofs(dh)) apply!(K_transmural, f, ch) - transmural = K_transmural \ f; + transmural = solve(LinearSolve.LinearProblem(K_transmural, f)).u # Apicobasal coordinate - #TODO refactor check for node set existence - if !haskey(mesh.grid.nodesets, "Apex") #TODO this is just a hotfix, assuming that z points towards the apex - apex_node_index = 1 - nodes = getnodes(mesh) - for (i,node) ∈ enumerate(nodes) - if nodes[i].x[3] > nodes[apex_node_index].x[3] - apex_node_index = i - end - end - addnodeset!(mesh, "Apex", OrderedSet{Int}((apex_node_index))) - end - - ch = ConstraintHandler(dh); - dbc = Dirichlet(:coordinates, getfacetset(mesh, "Base"), (x, t) -> 0) - Ferrite.add!(ch, dbc); - dbc = Dirichlet(:coordinates, getnodeset(mesh, "Apex"), (x, t) -> 1) - Ferrite.add!(ch, dbc); - close!(ch) - update!(ch, 0.0); - - K_apicobasal = copy(K) - f = zeros(ndofs(dh)) - - apply!(K_apicobasal, f, ch) - apicobasal = K_apicobasal \ f; + apicobasal = zeros(ndofs(dh)) + apply_analytical!(apicobasal, dh, :coordinates, x->x ⋅ up) + apicobasal .-= minimum(apicobasal) + apicobasal = abs.(apicobasal) + apicobasal ./= maximum(apicobasal) rotational = zeros(ndofs(dh)) rotational .= NaN @@ -175,7 +155,7 @@ function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{St rotational[dofs[qp.i]] = 0.0 else x = x_planar / xlen - rotational[dofs[qp.i]] = (π + atan(x[1], x[2]))/2 # TODO tilted coordinate system + rotational[dofs[qp.i]] = 1 + atan(x[1], x[2])/π # TODO tilted coordinate system end end end @@ -194,7 +174,7 @@ Requires a mesh with facetsets * Myocardium """ function compute_midmyocardial_section_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{String} = [""]; up = Vec((0.0,0.0,1.0))) - @assert up ≈ Vec((0.0,0.0,1.0)) "Custom up vector not yet supported." + @assert abs.(up) ≈ Vec((0.0,0.0,1.0)) "Custom up vector not yet supported." ip_collection = LagrangeCollection{1}() qr_collection = QuadratureRuleCollection(2) cv_collection = CellValueCollection(qr_collection, ip_collection) @@ -244,26 +224,21 @@ function compute_midmyocardial_section_coordinate_system(mesh::SimpleMesh{3}, su close!(ch) update!(ch, 0.0); - K_transmural = copy(K) + K_transmural = K f = zeros(ndofs(dh)) apply!(K_transmural, f, ch) - transmural = K_transmural \ f; - - ch = ConstraintHandler(dh); - dbc = Dirichlet(:coordinates, getfacetset(mesh, "Base"), (x, t) -> 0) - Ferrite.add!(ch, dbc); - dbc = Dirichlet(:coordinates, getfacetset(mesh, "Myocardium"), (x, t) -> 0.15) - Ferrite.add!(ch, dbc); - close!(ch) - update!(ch, 0.0); - - K_apicobasal = copy(K) - f = zeros(ndofs(dh)) - - apply!(K_apicobasal, f, ch) - apicobasal = K_apicobasal \ f; + transmural = solve(LinearSolve.LinearProblem(K_transmural, f)).u + # Apicobasal coordinate + apicobasal = zeros(ndofs(dh)) + apply_analytical!(apicobasal, dh, :coordinates, x->x ⋅ up) + apicobasal .-= minimum(apicobasal) + apicobasal = abs.(apicobasal) + apicobasal ./= maximum(apicobasal) + apicobasal .*= 0.15 + + # Rotational coordinate rotational = zeros(ndofs(dh)) rotational .= NaN @@ -283,7 +258,7 @@ function compute_midmyocardial_section_coordinate_system(mesh::SimpleMesh{3}, su x_planar = x_dof - (x_dof ⋅ up) * up # Project into plane x = x_planar / norm(x_planar) - rotational[dofs[qp.i]] = (π + atan(x[1], x[2]))/2 # TODO tilted coordinate system + rotational[dofs[qp.i]] = 1 + atan(x[1], x[2])/π # TODO tilted coordinate system end end end diff --git a/src/solver/time/euler.jl b/src/solver/time/euler.jl index a7eb2ea5..9a766ee9 100644 --- a/src/solver/time/euler.jl +++ b/src/solver/time/euler.jl @@ -74,10 +74,6 @@ function perform_step!(f::TransientDiffusionFunction, cache::BackwardEulerSolver return !solve_failed end -function init_cache(prob, alg::BackwardEulerSolver; t0) - return setup_solver_cache(prob.f, alg, t0) -end - function setup_solver_cache(f::TransientDiffusionFunction, solver::BackwardEulerSolver, t₀) @unpack dh = f @unpack inner_solver = solver diff --git a/src/solver/time/time_integrator.jl b/src/solver/time/time_integrator.jl index d8f2dd65..72194875 100644 --- a/src/solver/time/time_integrator.jl +++ b/src/solver/time/time_integrator.jl @@ -205,3 +205,7 @@ function perform_step!(integ::ThunderboltTimeIntegrator, cache::AbstractTimeSolv end return true end + +function init_cache(prob, alg; t0, kwargs...) + return setup_solver_cache(prob.f, alg, t0) +end From e0abe2fd87d921d577ec85ae4080b0e09e096dc8 Mon Sep 17 00:00:00 2001 From: termi-official Date: Sat, 21 Sep 2024 04:12:47 +0200 Subject: [PATCH 4/8] Scaling --- src/modeling/core/coordinate_systems.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modeling/core/coordinate_systems.jl b/src/modeling/core/coordinate_systems.jl index e901fe86..94cb1ff6 100644 --- a/src/modeling/core/coordinate_systems.jl +++ b/src/modeling/core/coordinate_systems.jl @@ -155,7 +155,7 @@ function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{St rotational[dofs[qp.i]] = 0.0 else x = x_planar / xlen - rotational[dofs[qp.i]] = 1 + atan(x[1], x[2])/π # TODO tilted coordinate system + rotational[dofs[qp.i]] = 1/2 + atan(x[1], x[2])/(2π) # TODO tilted coordinate system end end end @@ -258,7 +258,7 @@ function compute_midmyocardial_section_coordinate_system(mesh::SimpleMesh{3}, su x_planar = x_dof - (x_dof ⋅ up) * up # Project into plane x = x_planar / norm(x_planar) - rotational[dofs[qp.i]] = 1 + atan(x[1], x[2])/π # TODO tilted coordinate system + rotational[dofs[qp.i]] = 1/2 + atan(x[1], x[2])/(2π) # TODO tilted coordinate system end end end From b90318dc00612f79b0f6c264b3ffa728806da8da Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 27 Sep 2024 18:26:45 +0200 Subject: [PATCH 5/8] Default to CG for coordinate system --- src/modeling/core/coordinate_systems.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/modeling/core/coordinate_systems.jl b/src/modeling/core/coordinate_systems.jl index 94cb1ff6..24829efd 100644 --- a/src/modeling/core/coordinate_systems.jl +++ b/src/modeling/core/coordinate_systems.jl @@ -83,7 +83,6 @@ function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{St # Assemble Laplacian # TODO use bilinear operator for performance K = allocate_matrix(dh) - assembler = start_assemble(K) for sdh in dh.subdofhandlers cellvalues = getcellvalues(cv_collection, getcells(mesh, first(sdh.cellset))) @@ -123,7 +122,8 @@ function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{St f = zeros(ndofs(dh)) apply!(K_transmural, f, ch) - transmural = solve(LinearSolve.LinearProblem(K_transmural, f)).u + sol = solve(LinearSolve.LinearProblem(K_transmural, f), KrylovJL_CG()) + transmural = sol.u # Apicobasal coordinate apicobasal = zeros(ndofs(dh)) @@ -228,7 +228,8 @@ function compute_midmyocardial_section_coordinate_system(mesh::SimpleMesh{3}, su f = zeros(ndofs(dh)) apply!(K_transmural, f, ch) - transmural = solve(LinearSolve.LinearProblem(K_transmural, f)).u + sol = solve(LinearSolve.LinearProblem(K_transmural, f), KrylovJL_CG()) + transmural = sol.u # Apicobasal coordinate apicobasal = zeros(ndofs(dh)) From f96e38503a32b35a89a63918496456020ee6614c Mon Sep 17 00:00:00 2001 From: termi-official Date: Fri, 27 Sep 2024 18:40:11 +0200 Subject: [PATCH 6/8] derp --- src/modeling/core/coordinate_systems.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modeling/core/coordinate_systems.jl b/src/modeling/core/coordinate_systems.jl index 24829efd..7d4de6dc 100644 --- a/src/modeling/core/coordinate_systems.jl +++ b/src/modeling/core/coordinate_systems.jl @@ -122,7 +122,7 @@ function compute_lv_coordinate_system(mesh::SimpleMesh{3}, subdomains::Vector{St f = zeros(ndofs(dh)) apply!(K_transmural, f, ch) - sol = solve(LinearSolve.LinearProblem(K_transmural, f), KrylovJL_CG()) + sol = solve(LinearSolve.LinearProblem(K_transmural, f), LinearSolve.KrylovJL_CG()) transmural = sol.u # Apicobasal coordinate @@ -228,7 +228,7 @@ function compute_midmyocardial_section_coordinate_system(mesh::SimpleMesh{3}, su f = zeros(ndofs(dh)) apply!(K_transmural, f, ch) - sol = solve(LinearSolve.LinearProblem(K_transmural, f), KrylovJL_CG()) + sol = solve(LinearSolve.LinearProblem(K_transmural, f), LinearSolve.KrylovJL_CG()) transmural = sol.u # Apicobasal coordinate From cb9a8f4a3b099680fbd2ae6fc313f379712d902b Mon Sep 17 00:00:00 2001 From: termi-official Date: Tue, 1 Oct 2024 19:42:20 +0200 Subject: [PATCH 7/8] Fix type issues --- src/modeling/core/coefficients.jl | 10 ++++++---- src/modeling/core/coordinate_systems.jl | 13 ++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/modeling/core/coefficients.jl b/src/modeling/core/coefficients.jl index 95f4880d..39113f66 100644 --- a/src/modeling/core/coefficients.jl +++ b/src/modeling/core/coefficients.jl @@ -115,13 +115,15 @@ struct CoordinateSystemCoefficient{CS} end function compute_nodal_values(csc::CoordinateSystemCoefficient, dh::DofHandler, field_name::Symbol) - nodal_values = Vector{value_type(csc.cs)}(UndefInitializer(), ndofs(dh)) + Tv = value_type(csc.cs) + nodal_values = Vector{Tv}(UndefInitializer(), ndofs(dh)) + T = eltype(Tv) for sdh in dh.subdofhandlers field_name ∈ sdh.field_names || continue - ip = Ferrite.getfieldinterpolation(sdh, field_name) - positions = Ferrite.reference_coordinates(ip) + ip = Ferrite.getfieldinterpolation(sdh, field_name) + rdim = Ferrite.getrefdim(ip) + positions = Vec{rdim,T}.(Ferrite.reference_coordinates(ip)) # This little trick uses the delta property of interpolations - T = eltype(first(positions)) qr = QuadratureRule{Ferrite.getrefshape(ip)}([T(1.0) for _ in 1:length(positions)], positions) cc = setup_coefficient_cache(csc, qr, sdh) _compute_nodal_values!(nodal_values, qr, cc, sdh) diff --git a/src/modeling/core/coordinate_systems.jl b/src/modeling/core/coordinate_systems.jl index b4598ceb..30cd1502 100644 --- a/src/modeling/core/coordinate_systems.jl +++ b/src/modeling/core/coordinate_systems.jl @@ -3,10 +3,13 @@ Standard cartesian coordinate system. """ -struct CartesianCoordinateSystem{sdim} +struct CartesianCoordinateSystem{sdim,T} + function CartesianCoordinateSystem{sdim}() where sdim + return new{sdim,Float32}() + end end -value_type(::CartesianCoordinateSystem{sdim}) where sdim = Vec{sdim, Float32} +value_type(::CartesianCoordinateSystem{sdim, T}) where {sdim, T} = Vec{sdim, T} CartesianCoordinateSystem(mesh::AbstractGrid{sdim}) where sdim = CartesianCoordinateSystem{sdim}() @@ -47,7 +50,9 @@ struct LVCoordinate{T} rotational::T end -value_type(::LVCoordinateSystem) = LVCoordinate{Float32} +Base.eltype(::Type{LVCoordinate{T}}) where T = T +Base.eltype(::LVCoordinate{T}) where T = T +value_type(::LVCoordinateSystem{T}) where T = LVCoordinate{T} """ @@ -309,6 +314,8 @@ struct BiVCoordinate{T} transventricular::T end +Base.eltype(::Type{BiVCoordinate{T}}) where T = T +Base.eltype(::BiVCoordinate{T}) where T = T value_type(::BiVCoordinateSystem) = BiVCoordinate getcoordinateinterpolation(cs::BiVCoordinateSystem, cell::Ferrite.AbstractCell) = Ferrite.getfieldinterpolation(cs.dh, (1,1)) From b842d35bf60da342a99b8766cae116e35ef7bfe0 Mon Sep 17 00:00:00 2001 From: termi-official Date: Tue, 1 Oct 2024 20:00:12 +0200 Subject: [PATCH 8/8] Typo --- src/modeling/core/coordinate_systems.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modeling/core/coordinate_systems.jl b/src/modeling/core/coordinate_systems.jl index 30cd1502..7f966823 100644 --- a/src/modeling/core/coordinate_systems.jl +++ b/src/modeling/core/coordinate_systems.jl @@ -46,7 +46,7 @@ LV only part of the universal ventricular coordinate, containing """ struct LVCoordinate{T} transmural::T - apicaobasal::T + apicobasal::T rotational::T end @@ -309,7 +309,7 @@ Biventricular universal coordinate, containing """ struct BiVCoordinate{T} transmural::T - apicaobasal::T + apicobasal::T rotational::T transventricular::T end