Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DIRK-IMEX schemes #106

Merged
merged 27 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
08d85c8
First elements of DIRK IMEX
ScottMacLachlan Dec 9, 2024
198cc8e
Merge branch 'master' into smaclachlan/add_dirk_imex
ScottMacLachlan Jan 8, 2025
0a8146d
Cleanup
ScottMacLachlan Jan 10, 2025
2f60244
Remove thingies
ScottMacLachlan Jan 10, 2025
b9e0a31
Cleanups
ScottMacLachlan Jan 10, 2025
e48cf44
Fix BCs / indent error
ScottMacLachlan Jan 10, 2025
016d16a
Adding first test
ScottMacLachlan Jan 10, 2025
e026c9d
Fix IMEX-Euler syntax and test it
ScottMacLachlan Jan 10, 2025
fc10fbf
Add convergence test
ScottMacLachlan Jan 10, 2025
27d5923
heat -> convection-diffusion
ScottMacLachlan Jan 10, 2025
2e66c63
Add monodomain demo
ScottMacLachlan Jan 10, 2025
90b6f60
Fix typos in demos
ScottMacLachlan Jan 10, 2025
4fa3d56
Lighter-weight mass solver
ScottMacLachlan Jan 10, 2025
e922460
Rename property
ScottMacLachlan Jan 10, 2025
ec7cc4d
Less cryptic comment
ScottMacLachlan Jan 10, 2025
e58c477
Better name
ScottMacLachlan Jan 10, 2025
8524b4e
Add feedback on failed assertion
ScottMacLachlan Jan 10, 2025
5addfb7
Reorganize loops to prep for cases
ScottMacLachlan Jan 10, 2025
6a9a568
Add general finalize method
ScottMacLachlan Jan 10, 2025
acebd07
Special case when last explicit stage is not needed
ScottMacLachlan Jan 10, 2025
05272b1
Update demo
ScottMacLachlan Jan 10, 2025
f671062
Merge branch 'master' into smaclachlan/add_dirk_imex
ScottMacLachlan Jan 10, 2025
02a8441
Adding stiffly accurate finalize method
ScottMacLachlan Jan 10, 2025
a2d1dd8
Introducing factory code
ScottMacLachlan Jan 11, 2025
0cbb1a7
Typos in demo
ScottMacLachlan Jan 11, 2025
dccce50
Tweak docstring
ScottMacLachlan Jan 11, 2025
36504a9
Sign convention
ScottMacLachlan Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions demos/monodomain/demo_monodomain_FHN.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The basic form of the equation is:

\chi \left( C_m u_t + I_{ion}(u) \right) = \nabla \cdot \sigma \nabla u

where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ration. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation:
where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ratio. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation:

.. math::

Expand Down Expand Up @@ -58,15 +58,15 @@ Specify the physical constants and initial conditions::
sigma = as_matrix([[sigma1, 0.0], [0.0, sigma2]])


InitialPotential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791))
InitialCell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)),
initial_potential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791))
initial_cell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)),
Constant(2.0), Constant(-0.5758))


uu = Function(Z)
vu, vc = TestFunctions(Z)
uu.sub(0).interpolate(InitialPotential)
uu.sub(1).interpolate(InitialCell)
uu.sub(0).interpolate(initial_potential)
uu.sub(1).interpolate(initial_cell)

(u, c) = split(uu)

Expand Down
154 changes: 154 additions & 0 deletions demos/monodomain/demo_monodomain_FHN_dirkimex.py.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
Solving monodomain equations with Fitzhugh-Nagumo reaction and a DIRK- IMEX method
==================================================================================

We're solving monodomain (reaction-diffusion) with a particular reaction term.
The basic form of the equation is:

.. math::

\chi \left( C_m u_t + I_{ion}(u) \right) = \nabla \cdot \sigma \nabla u

where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ratio. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation:

.. math::

c_t = \epsilon( u + \beta - \gamma c)

for certain positive parameters :math:`\beta` and :math:`\gamma`, and the current takes the form of:

.. math::

I_{ion}(u, c) = \tfrac{1}{\epsilon} \left( u - \tfrac{u^3}{3} - c \right)

so that we have an overall system of two equations. One of them is linear but stiff/diffusive, and the other is nonstiff but nonlinear. This combination makes the system a good candidate for IMEX-type methods.


We start with standard Firedrake/Irksome imports::

import copy

from firedrake import (And, Constant, File, Function, FunctionSpace,
RectangleMesh, SpatialCoordinate, TestFunctions,
as_matrix, conditional, dx, grad, inner, split)
from irksome import Dt, MeshConstant, IMEX4, TimeStepper

And we set up the mesh and function space.::

mesh = RectangleMesh(20, 20, 70, 70, quadrilateral=True)
polyOrder = 2

V = FunctionSpace(mesh, "CG", 2)
Z = V * V

x, y = SpatialCoordinate(mesh)
MC = MeshConstant(mesh)
dt = MC.Constant(0.05)
t = MC.Constant(0.0)

Specify the physical constants and initial conditions::

eps = Constant(0.1)
beta = Constant(1.0)
gamma = Constant(0.5)

chi = Constant(1.0)
capacitance = Constant(1.0)

sigma1 = sigma2 = 1.0
sigma = as_matrix([[sigma1, 0.0], [0.0, sigma2]])


initial_potential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791))
initial_cell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)),
Constant(2.0), Constant(-0.5758))


uu = Function(Z)
vu, vc = TestFunctions(Z)
uu.sub(0).interpolate(initial_potential)
uu.sub(1).interpolate(initial_cell)

(u, c) = split(uu)


This sets up the Butcher tableau. Here, we use the DIRK-IMEX methods proposed
by Ascher, Ruuth, and Spiteri in their 1997 Applied Numerical Mathematics paper.
For this case, We use a four-stage method.::

butcher_tableau = IMEX4()
ScottMacLachlan marked this conversation as resolved.
Show resolved Hide resolved
ns = butcher_tableau.num_stages

To access an IMEX method, we need to separately specify the implicit and explicit parts of the operator.
The part to be handled implicitly is taken to contain the time derivatives as well::

F1 = (inner(chi * capacitance * Dt(u), vu)*dx
+ inner(grad(u), sigma * grad(vu))*dx
+ inner(Dt(c), vc)*dx - inner(eps * u, vc)*dx
- inner(beta * eps, vc)*dx + inner(gamma * eps * c, vc)*dx)

This is the part to be handled explicitly.::

F2 = - inner((chi/eps) * (-u + (u**3 / 3) + c), vu)*dx
rckirby marked this conversation as resolved.
Show resolved Hide resolved

If we wanted to use a fully implicit method, we would just take
F = F1 - F2. Note the minus sign, since DIRK-IMEX takes forms as F1 = F2.

Now, set up solver parameters. Since we're using a DIRK-IMEX scheme, we can
specify only parameters for each stage. We use an additive Schwarz (fieldsplit) method that applies AMG to the potential block and incomplete Cholesky to the cell block independently for each stage::

params = {"snes_type": "ksponly",
"ksp_monitor": None,
"mat_type": "aij",
"ksp_type": "fgmres",
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "additive",
ScottMacLachlan marked this conversation as resolved.
Show resolved Hide resolved
"fieldsplit_0": {
"ksp_type": "preonly",
"pc_type": "gamg",
},
"fieldsplit_1": {
"ksp_type": "preonly",
"pc_type": "icc",
}}


The DIRK-IMEX schemes also require a mass-matrix solver. Here, we just use AMG on the coupled system, which works fine.::

mass_params = {"snes_type": "ksponly",
"ksp_rtol": 1.e-8,
"ksp_monitor": None,
"mat_type": "aij",
"ksp_type": "cg",
"pc_type": "icc",
}

Now, we access the IMEX method via the `TimeStepper` as with other methods. Note that we specify somewhat different kwargs, needing to specify the implicit and explicit parts separately as well as separate solver options for the implicit and mass solvers.::

stepper = TimeStepper(F1, butcher_tableau, t, dt, uu,
stage_type="dirkimex",
solver_parameters=params,
mass_parameters=mass_params,
Fexp=F2)

uFinal, cFinal = uu.split()
outfile1 = File("FHN_results/FHN_2d_u.pvd")
outfile2 = File("FHN_results/FHN_2d_c.pvd")
outfile1.write(uFinal, time=0)
outfile2.write(cFinal, time=0)

for j in range(12):
print(f"{float(t)}")
stepper.advance()
t.assign(float(t) + float(dt))

if (j % 5 == 0):
outfile1.write(uFinal, time=j * float(dt))
outfile2.write(cFinal, time=j * float(dt))

nsteps, n_nonlin, n_lin, n_nonlin_mass, n_lin_mass = stepper.solver_stats()
print(f"Time steps taken: {nsteps}")
print(f" {n_nonlin} nonlinear steps in implicit stage solves (should be {nsteps*ns})")
print(f" {n_lin} linear steps in implicit stage solves")
print(f" {n_nonlin_mass} nonlinear steps in mass solves (should be {nsteps*(ns+1)})")
print(f" {n_lin_mass} linear steps in mass solves")

3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ and for adaptive IRK methods:
demos/demo_heat_adapt.py


Or check out an IMEX-type method for the monodomain equations:
Or check out two IMEX-type methods for the monodomain equations:

.. toctree::
:maxdepth: 1

demos/demo_monodomain_FHN.py
demos/demo_monodomain_FHN_dirkimex.py

Advanced demos
--------------
Expand Down
5 changes: 5 additions & 0 deletions irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
from .ButcherTableaux import RadauIIA # noqa: F401
from .pep_explicit_rk import PEPRK # noqa: F401
from .deriv import Dt # noqa: F401
from .dirk_imex_tableaux import IMEXEuler # noqa: F401
from .dirk_imex_tableaux import IMEX2 # noqa: F401
from .dirk_imex_tableaux import IMEX3 # noqa: F401
from .dirk_imex_tableaux import IMEX4 # noqa: F401
from .dirk_stepper import DIRKTimeStepper # noqa: F401
from .getForm import getForm # noqa: F401
from .imex import RadauIIAIMEXMethod # noqa: F401
from .imex import DIRKIMEXMethod # noqa: F401
from .pc import RanaBase, RanaDU, RanaLD # noqa: F401
from .pc import IRKAuxiliaryOperatorPC # noqa: F401
from .stage import StageValueTimeStepper # noqa: F401
Expand Down
100 changes: 100 additions & 0 deletions irksome/dirk_imex_tableaux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from .ButcherTableaux import ButcherTableau
import numpy as np


# IMEX Butcher tableau using numpy arrays
class IMEXEuler(ButcherTableau):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Euler instead of 1, while other methods have numbers? Maybe we want these in a kind of factory (like PEPRK and WSODIRK) with a more descriptive name like AscherRuuthSpiteriDIRKIMEX(order)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ascher-Ruuth-Spiteri have a numbering system that obfuscates the schemes to me, writing (s, sigma, p), where s is the number of stages, sigma is the number of explicit stages needed to be calculated (sometimes s, sometimes s+1 -- we currently assume it is always s+1, which means we sometimes do an unnecessary mass solve), and p is order of the scheme. I'd still prefer to avoid that, so perhaps just the factory you describe is sufficient naming?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the right thing to do is write out the loop "both ways" (the general case and optimized to avoid the extra mass solve). Maybe there is a clever way to fuse the loops, or maybe we just want an "if/else" branching on whether that optimization is possible. We could create advance_normal and advance_optimized methods and set self.advance to be one of these in the constructor. It would look like code duplication, but avoiding the extra mass solve is worthwhile if you're trying to use IMEX as a performance win.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about a small refactor to loop over the solves as each stages' mass solve then implicit solve for the first ns stages, then having finalize_normal (compute last explicit stage and general-case update), finalize_no_last_explicit (general-case update where last explicit stage is not used), and finalize_stiffly_accurate (update where last explicit stage is not used, and last implicit stage is new u0)? That would get us out of code duplication altogether, and we could sniff which case in __init__

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's a great plan. Does it get the optimal pattern for each case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay - 6a9a568 passes tests locally with the last mass solve and the final update separated into _final_general. Now to implement the special cases

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And 02a8441 implements the special cases. So, seems like this is good to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which brings us back to naming... Shall I just implement a factory that uses the Ascher-Ruuth-Spiteri indexing convention? Seems the easiest for users, even if it is a bit opaque

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the factory approach is probably cleaner for the user interface.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Factory is now implemented, in a2d1dd8

def __init__(self):
A = np.array([[1.0]]) # Implicit matrix (for implicit part)
A_hat = np.array([[1.0]]) # Explicit matrix (for explicit part)
b = np.array([1.0]) # Implicit weights
b_hat = np.array([1.0, 0.0]) # Explicit weights
c = np.array([1.0]) # Time steps for implicit part
c_hat = np.array([0.0, 1.0]) # Time steps for explicit part
order = 1 # First-order method (Euler)
embedded_order = None # set to None
gamma0 = None # set to None
super().__init__(A, b, b_hat, c, order, embedded_order, gamma0)
self.A_hat = A_hat
self.b_hat = b_hat
self.c_hat = c_hat
self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme


# IMEX Butcher tableau for s = 2
class IMEX2(ButcherTableau):
def __init__(self):
# Parameters for the s = 2 method
gamma = (2 - np.sqrt(2)) / 2
delta = -2 * np.sqrt(2) / 3

# Implicit and explicit coefficients
A = np.array([[gamma, 0], [1 - gamma, gamma]]) # Implicit matrix (A)
A_hat = np.array([[gamma, 0], [delta, 1 - delta]]) # Explicit matrix (A_hat)
b = np.array([1 - gamma, gamma]) # Implicit weights
b_hat = np.array([0, 1 - gamma, gamma]) # Explicit weights (b_hat)
c = np.array([gamma, 1.0]) # Time steps for implicit part (c)
c_hat = np.array([0, gamma, 1.0]) # Time steps for explicit part (c_hat)

# The method order is 2
order = 2
embedded_order = None # set to None
gamma0 = None # set to None
btilde = None

super().__init__(A, b, btilde, c, order, embedded_order, gamma0)
self.A_hat = A_hat
self.b_hat = b_hat
self.c_hat = c_hat
self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme


# IMEX Butcher tableau for s = 3
class IMEX3(ButcherTableau):
def __init__(self):
A = np.array([[0.4358665215, 0, 0], [0.2820667392, 0.4358665215, 0], [1.208496649, -0.644363171, 0.4358665215]]) # Implicit matrix (A)
A_hat = np.array([[0.4358665215, 0, 0], [0.3212788860, 0.3966543747, 0], [-0.105858296, 0.5529291479, 0.5529291479]]) # Explicit matrix (A_hat)
b = np.array([1.208496649, -0.644363171, 0.4358665215]) # Implicit weights
b_hat = np.array([0, 1.208496649, -0.644363171, 0.4358665215]) # Explicit weights (b_hat)
c = np.array([0.4358665215, 0.7179332608, 1]) # Time steps for implicit part (c)
c_hat = np.array([0, 0.4358665215, 0.7179332608, 1.0]) # Time steps for explicit part (c_hat)

# The method order is 3
order = 3
embedded_order = None # set to None
gamma0 = None # set to None
btilde = None

super().__init__(A, b, btilde, c, order, embedded_order, gamma0)
self.A_hat = A_hat
self.b_hat = b_hat
self.c_hat = c_hat
self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme


# IMEX Butcher tableau for s = 4
class IMEX4(ButcherTableau):
def __init__(self):
A = np.array([[1/2, 0, 0, 0],
[1/6, 1/2, 0, 0],
[-1/2, 1/2, 1/2, 0],
[3/2, -3/2, 1/2, 1/2]]) # Corrected A matrix definition
A_hat = np.array([[1/2, 0, 0, 0],
[11/18, 1/18, 0, 0],
[5/6, -5/6, 1/2, 0],
[1/4, 7/4, 3/4, -7/4]]) # Explicit matrix (A_hat)
b = np.array([3/2, -3/2, 1/2, 1/2]) # Implicit weights
b_hat = np.array([1/4, 7/4, 3/4, -7/4, 0]) # Explicit weights (b_hat)
c = np.array([1/2, 2/3, 1/2, 1]) # Time steps for implicit part (c)
c_hat = np.array([0, 1/2, 2/3, 1/2, 1]) # Time steps for explicit part (c_hat)

order = 3
embedded_order = None
gamma0 = None
btilde = None

super().__init__(A, b, btilde, c, order, embedded_order, gamma0)
self.A_hat = A_hat
self.b_hat = b_hat
self.c_hat = c_hat
self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme
Loading