diff --git a/demos/monodomain/demo_monodomain_FHN.py.rst b/demos/monodomain/demo_monodomain_FHN.py.rst index 16254519..595945c7 100644 --- a/demos/monodomain/demo_monodomain_FHN.py.rst +++ b/demos/monodomain/demo_monodomain_FHN.py.rst @@ -8,7 +8,7 @@ The basic form of the equation is: \chi \left( C_m u_t + I_{ion}(u) \right) = \nabla \cdot \sigma \nabla u -where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ration. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation: +where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ratio. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation: .. math:: @@ -58,15 +58,15 @@ Specify the physical constants and initial conditions:: sigma = as_matrix([[sigma1, 0.0], [0.0, sigma2]]) - InitialPotential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791)) - InitialCell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)), + initial_potential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791)) + initial_cell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)), Constant(2.0), Constant(-0.5758)) uu = Function(Z) vu, vc = TestFunctions(Z) - uu.sub(0).interpolate(InitialPotential) - uu.sub(1).interpolate(InitialCell) + uu.sub(0).interpolate(initial_potential) + uu.sub(1).interpolate(initial_cell) (u, c) = split(uu) diff --git a/demos/monodomain/demo_monodomain_FHN_dirkimex.py.rst b/demos/monodomain/demo_monodomain_FHN_dirkimex.py.rst new file mode 100644 index 00000000..ca3d3acd --- /dev/null +++ b/demos/monodomain/demo_monodomain_FHN_dirkimex.py.rst @@ -0,0 +1,157 @@ +Solving monodomain equations with Fitzhugh-Nagumo reaction and a DIRK-IMEX method +================================================================================= + +We're solving monodomain (reaction-diffusion) with a particular reaction term. +The basic form of the equation is: + +.. math:: + + \chi \left( C_m u_t + I_{ion}(u) \right) = \nabla \cdot \sigma \nabla u + +where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ratio. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation: + +.. math:: + + c_t = \epsilon( u + \beta - \gamma c) + +for certain positive parameters :math:`\beta` and :math:`\gamma`, and the current takes the form of: + +.. math:: + + I_{ion}(u, c) = \tfrac{1}{\epsilon} \left( u - \tfrac{u^3}{3} - c \right) + +so that we have an overall system of two equations. One of them is linear but stiff/diffusive, and the other is nonstiff but nonlinear. This combination makes the system a good candidate for IMEX-type methods. + + +We start with standard Firedrake/Irksome imports:: + + import copy + + from firedrake import (And, Constant, File, Function, FunctionSpace, + RectangleMesh, SpatialCoordinate, TestFunctions, + as_matrix, conditional, dx, grad, inner, split) + from irksome import Dt, MeshConstant, DIRK_IMEX, TimeStepper + +And we set up the mesh and function space.:: + + mesh = RectangleMesh(20, 20, 70, 70, quadrilateral=True) + polyOrder = 2 + + V = FunctionSpace(mesh, "CG", 2) + Z = V * V + + x, y = SpatialCoordinate(mesh) + MC = MeshConstant(mesh) + dt = MC.Constant(0.05) + t = MC.Constant(0.0) + +Specify the physical constants and initial conditions:: + + eps = Constant(0.1) + beta = Constant(1.0) + gamma = Constant(0.5) + + chi = Constant(1.0) + capacitance = Constant(1.0) + + sigma1 = sigma2 = 1.0 + sigma = as_matrix([[sigma1, 0.0], [0.0, sigma2]]) + + + initial_potential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791)) + initial_cell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)), + Constant(2.0), Constant(-0.5758)) + + + uu = Function(Z) + vu, vc = TestFunctions(Z) + uu.sub(0).interpolate(initial_potential) + uu.sub(1).interpolate(initial_cell) + + (u, c) = split(uu) + + +This sets up the Butcher tableau. Here, we use the DIRK-IMEX methods proposed +by Ascher, Ruuth, and Spiteri in their 1997 Applied Numerical Mathematics paper. +For this case, We use a four-stage method.:: + + butcher_tableau = DIRK_IMEX(4, 4, 3) + ns = butcher_tableau.num_stages + +To access an IMEX method, we need to separately specify the implicit and explicit parts of the operator. +The part to be handled implicitly is taken to contain the time derivatives as well:: + + F1 = (inner(chi * capacitance * Dt(u), vu)*dx + + inner(grad(u), sigma * grad(vu))*dx + + inner(Dt(c), vc)*dx - inner(eps * u, vc)*dx + - inner(beta * eps, vc)*dx + inner(gamma * eps * c, vc)*dx) + +This is the part to be handled explicitly.:: + + F2 = inner((chi/eps) * (-u + (u**3 / 3) + c), vu)*dx + +If we wanted to use a fully implicit method, we would just take +F = F1 + F2. + +Now, set up solver parameters. Since we're using a DIRK-IMEX scheme, we can +specify only parameters for each stage. We use an additive Schwarz (fieldsplit) method that applies AMG to the potential block and incomplete Cholesky to the cell block independently for each stage:: + + params = {"snes_type": "ksponly", + "ksp_monitor": None, + "mat_type": "aij", + "ksp_type": "fgmres", + "pc_type": "fieldsplit", + "pc_fieldsplit_type": "additive", + "fieldsplit_0": { + "ksp_type": "preonly", + "pc_type": "gamg", + }, + "fieldsplit_1": { + "ksp_type": "preonly", + "pc_type": "icc", + }} + + +The DIRK-IMEX schemes also require a mass-matrix solver. Here, we just use an incomplete Cholesky preconditioner for CG on the coupled system, which works fine.:: + + mass_params = {"snes_type": "ksponly", + "ksp_rtol": 1.e-8, + "ksp_monitor": None, + "mat_type": "aij", + "ksp_type": "cg", + "pc_type": "icc", + } + +Now, we access the IMEX method via the `TimeStepper` as with other methods. Note that we specify somewhat different kwargs, needing to specify the implicit and explicit parts separately as well as separate solver options for the implicit and mass solvers.:: + + stepper = TimeStepper(F1, butcher_tableau, t, dt, uu, + stage_type="dirkimex", + solver_parameters=params, + mass_parameters=mass_params, + Fexp=F2) + + uFinal, cFinal = uu.split() + outfile1 = File("FHN_results/FHN_2d_u.pvd") + outfile2 = File("FHN_results/FHN_2d_c.pvd") + outfile1.write(uFinal, time=0) + outfile2.write(cFinal, time=0) + + for j in range(12): + print(f"{float(t)}") + stepper.advance() + t.assign(float(t) + float(dt)) + + if (j % 5 == 0): + outfile1.write(uFinal, time=j * float(dt)) + outfile2.write(cFinal, time=j * float(dt)) + + +We can print out some solver statistics here. We expect one implicit solve per stage per timestep, and that's what we see with the four-stage method. For this Butcher Tableau, we can avoid computing the final explicit stage (since it's coefficient in the next stage reconstruction is zero), so we see the same number of mass solves.:: + + nsteps, n_nonlin, n_lin, n_nonlin_mass, n_lin_mass = stepper.solver_stats() + print(f"Time steps taken: {nsteps}") + print(f" {n_nonlin} nonlinear steps in implicit stage solves (should be {nsteps*ns})") + print(f" {n_lin} linear steps in implicit stage solves") + print(f" {n_nonlin_mass} nonlinear steps in mass solves (should be {nsteps*ns})") + print(f" {n_lin_mass} linear steps in mass solves") + diff --git a/docs/source/index.rst b/docs/source/index.rst index e6561a54..c46a02ff 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -106,12 +106,13 @@ and for adaptive IRK methods: demos/demo_heat_adapt.py -Or check out an IMEX-type method for the monodomain equations: +Or check out two IMEX-type methods for the monodomain equations: .. toctree:: :maxdepth: 1 demos/demo_monodomain_FHN.py + demos/demo_monodomain_FHN_dirkimex.py Advanced demos -------------- diff --git a/irksome/__init__.py b/irksome/__init__.py index d83bd5ac..03f2f20c 100644 --- a/irksome/__init__.py +++ b/irksome/__init__.py @@ -8,9 +8,11 @@ from .ButcherTableaux import RadauIIA # noqa: F401 from .pep_explicit_rk import PEPRK # noqa: F401 from .deriv import Dt # noqa: F401 +from .dirk_imex_tableaux import DIRK_IMEX # noqa: F401 from .dirk_stepper import DIRKTimeStepper # noqa: F401 from .getForm import getForm # noqa: F401 from .imex import RadauIIAIMEXMethod # noqa: F401 +from .imex import DIRKIMEXMethod # noqa: F401 from .pc import RanaBase, RanaDU, RanaLD # noqa: F401 from .pc import IRKAuxiliaryOperatorPC # noqa: F401 from .stage import StageValueTimeStepper # noqa: F401 diff --git a/irksome/dirk_imex_tableaux.py b/irksome/dirk_imex_tableaux.py new file mode 100644 index 00000000..ebd53337 --- /dev/null +++ b/irksome/dirk_imex_tableaux.py @@ -0,0 +1,71 @@ +from .ButcherTableaux import ButcherTableau +import numpy as np + +# For the implicit scheme, the full Butcher Table is given as A, b, c. + +# For the explicit scheme, the full b_hat and c_hat are given, but (to +# avoid a lot of offset-by-ones in the code we store only the +# lower-left ns x ns block of A_hat + +# IMEX Butcher tableau for 1 stage +imex111A = np.array([[1.0]]) +imex111A_hat = np.array([[1.0]]) +imex111b = np.array([1.0]) +imex111b_hat = np.array([1.0, 0.0]) +imex111c = np.array([1.0]) +imex111c_hat = np.array([0.0, 1.0]) + + +# IMEX Butcher tableau for s = 2 +gamma = (2 - np.sqrt(2)) / 2 +delta = -2 * np.sqrt(2) / 3 +imex232A = np.array([[gamma, 0], [1 - gamma, gamma]]) +imex232A_hat = np.array([[gamma, 0], [delta, 1 - delta]]) +imex232b = np.array([1 - gamma, gamma]) +imex232b_hat = np.array([0, 1 - gamma, gamma]) +imex232c = np.array([gamma, 1.0]) +imex232c_hat = np.array([0, gamma, 1.0]) + +# IMEX Butcher tableau for 3 stages +imex343A = np.array([[0.4358665215, 0, 0], [0.2820667392, 0.4358665215, 0], [1.208496649, -0.644363171, 0.4358665215]]) +imex343A_hat = np.array([[0.4358665215, 0, 0], [0.3212788860, 0.3966543747, 0], [-0.105858296, 0.5529291479, 0.5529291479]]) +imex343b = np.array([1.208496649, -0.644363171, 0.4358665215]) +imex343b_hat = np.array([0, 1.208496649, -0.644363171, 0.4358665215]) +imex343c = np.array([0.4358665215, 0.7179332608, 1]) +imex343c_hat = np.array([0, 0.4358665215, 0.7179332608, 1.0]) + + +# IMEX Butcher tableau for 4 stages +imex443A = np.array([[1/2, 0, 0, 0], + [1/6, 1/2, 0, 0], + [-1/2, 1/2, 1/2, 0], + [3/2, -3/2, 1/2, 1/2]]) +imex443A_hat = np.array([[1/2, 0, 0, 0], + [11/18, 1/18, 0, 0], + [5/6, -5/6, 1/2, 0], + [1/4, 7/4, 3/4, -7/4]]) +imex443b = np.array([3/2, -3/2, 1/2, 1/2]) +imex443b_hat = np.array([1/4, 7/4, 3/4, -7/4, 0]) +imex443c = np.array([1/2, 2/3, 1/2, 1]) +imex443c_hat = np.array([0, 1/2, 2/3, 1/2, 1]) + +dirk_imex_dict = { + (1, 1, 1): (imex111A, imex111b, imex111c, imex111A_hat, imex111b_hat, imex111c_hat), + (2, 3, 2): (imex232A, imex232b, imex232c, imex232A_hat, imex232b_hat, imex232c_hat), + (3, 4, 3): (imex343A, imex343b, imex343c, imex343A_hat, imex343b_hat, imex343c_hat), + (4, 4, 3): (imex443A, imex443b, imex443c, imex443A_hat, imex443b_hat, imex443c_hat) +} + + +class DIRK_IMEX(ButcherTableau): + def __init__(self, ns_imp, ns_exp, order): + try: + A, b, c, A_hat, b_hat, c_hat = dirk_imex_dict[ns_imp, ns_exp, order] + except KeyError: + raise NotImplementedError("No DIRK-IMEX method for that combination of implicit and explicit stages and order") + self.order = order + super(DIRK_IMEX, self).__init__(A, b, None, c, order, None, None) + self.A_hat = A_hat + self.b_hat = b_hat + self.c_hat = c_hat + self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme diff --git a/irksome/imex.py b/irksome/imex.py index 815aa2a6..74bce7dc 100644 --- a/irksome/imex.py +++ b/irksome/imex.py @@ -1,13 +1,16 @@ import FIAT import numpy as np from firedrake import (Constant, Function, NonlinearVariationalProblem, - NonlinearVariationalSolver, TestFunction) + NonlinearVariationalSolver, TestFunction, + as_ufl, dx, inner, split) from firedrake.dmhooks import pop_parent, push_parent from ufl.classes import Zero from .ButcherTableaux import RadauIIA +from .deriv import TimeDerivative from .stage import getBits, getFormStage -from .tools import AI, IA, replace +from .tools import AI, IA, MeshConstant, replace +from .bcs import bc2space def riia_explicit_coeffs(k): @@ -300,3 +303,284 @@ def solver_stats(self): self.num_linear_iterations_prop, self.num_nonlinear_iterations_it, self.num_linear_iterations_it) + + +def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None): + if bcs is None: + bcs = [] + + v = F.arguments()[0] + V = v.function_space() + msh = V.mesh() + assert V == u0.function_space() + + num_fields = len(V) + num_stages = butch.num_stages + k = Function(V) + g = Function(V) + + khat = Function(V) + ghat = Function(V) + vhat = TestFunction(V) + + # If we're on a mixed problem, we need to replace pieces of the + # solution. Stores array of the splittings of the functions for each stage. + if num_fields == 1: + k_bits = [k] + u0bits = [u0] + gbits = [g] + ghat_bits = [ghat] + else: + k_bits = np.array(split(k), dtype=object) + u0bits = split(u0) + gbits = split(g) + ghat_bits = split(g) + + # Note: the Constant c is used for substitution in both the + # implicit variational form and BC's, and we update it for each stage in + # the loop over stages in the advance method. The Constants a and chat are + # used similarly in the variational forms + MC = MeshConstant(msh) + c = MC.Constant(1.0) + chat = MC.Constant(1.0) + a = MC.Constant(1.0) + + # Implicit replacement, solve at time t + c * dt, for k + repl = {t: t + c * dt} + for u0bit, kbit, gbit in zip(u0bits, k_bits, gbits): + repl[u0bit] = gbit + dt * a * kbit + repl[TimeDerivative(u0bit)] = kbit + stage_F = replace(F, repl) + + # Explicit replacement, solve at time t + chat * dt, for khat + replhat = {t: t + chat * dt} + for u0bit, ghatbit in zip(u0bits, ghat_bits): + replhat[u0bit] = ghatbit + Fhat = inner(khat, vhat)*dx + replace(Fexp, replhat) + + bcnew = [] + + # For the DIRK-IMEX case, we need one new BC for each old one + # (rather than one per stage), but we need a `Function` inside of + # each BC and a rule for computing that function at each time for + # each stage. + + a_vals = np.array([MC.Constant(0) for i in range(num_stages)], + dtype=object) + ahat_vals = np.array([MC.Constant(0) for i in range(num_stages+1)], + dtype=object) + d_val = MC.Constant(1.0) + + for bc in bcs: + bcarg = as_ufl(bc._original_arg) + bcarg_stage = replace(bcarg, {t: t+c*dt}) + + gdat = bcarg_stage - bc2space(bc, u0) + for i in range(num_stages): + gdat -= dt*a_vals[i]*bc2space(bc, ks[i]) + for i in range(num_stages+1): + gdat -= dt*ahat_vals[i]*bc2space(bc, khats[i]) + + gdat /= dt*d_val + bcnew.append(bc.reconstruct(g=gdat)) + + return stage_F, (k, g, a, c), bcnew, Fhat, (khat, ghat, chat), (a_vals, ahat_vals, d_val) + + +class DIRKIMEXMethod: + """Front-end class for advancing a time-dependent PDE via a + diagonally-implicit Runge-Kutta IMEX method formulated in terms of + stage derivatives. This implementation assumes a weak form + written as F + F_explicit = 0, where both F and F_explicit are UFL + Forms, with terms in F to be handled implicitly and those in + F_explicit to be handled explicitly + """ + + def __init__(self, F, F_explicit, butcher_tableau, t, dt, u0, bcs=None, + solver_parameters=None, mass_parameters=None, appctx=None, nullspace=None): + assert butcher_tableau.is_dirk_imex + + self.num_steps = 0 + self.num_nonlinear_iterations = 0 + self.num_linear_iterations = 0 + self.num_mass_nonlinear_iterations = 0 + self.num_mass_linear_iterations = 0 + + self.butcher_tableau = butcher_tableau + self.num_stages = butcher_tableau.num_stages + + self.V = V = u0.function_space() + self.u0 = u0 + self.t = t + self.dt = dt + self.num_fields = len(u0.function_space()) + self.ks = [Function(V) for _ in range(self.num_stages)] + self.k_hat_s = [Function(V) for _ in range(self.num_stages+1)] + + stage_F, (k, g, a, c), bcnew, Fhat, (khat, ghat, chat), (a_vals, ahat_vals, d_val) = getFormsDIRKIMEX( + F, F_explicit, self.ks, self.k_hat_s, butcher_tableau, t, dt, u0, bcs=bcs) + + self.bcnew = bcnew + + appctx_irksome = {"F": F, + "F_explicit": F_explicit, + "butcher_tableau": butcher_tableau, + "t": t, + "dt": dt, + "u0": u0, + "bcs": bcs, + "bc_type": "DAE", + "nullspace": nullspace} + if appctx is None: + appctx = appctx_irksome + else: + appctx = {**appctx, **appctx_irksome} + + self.problem = NonlinearVariationalProblem(stage_F, k, bcnew) + self.solver = NonlinearVariationalSolver(self.problem, appctx=appctx, + solver_parameters=solver_parameters, + nullspace=nullspace) + + self.mass_problem = NonlinearVariationalProblem(Fhat, khat) + self.mass_solver = NonlinearVariationalSolver(self.mass_problem, + solver_parameters=mass_parameters) + + self.kgac = k, g, a, c + self.kgchat = khat, ghat, chat + self.bc_constants = a_vals, ahat_vals, d_val + + AA = butcher_tableau.A + A_hat = butcher_tableau.A_hat + BB = butcher_tableau.b + B_hat = butcher_tableau.b_hat + + if B_hat[-1] == 0: + if np.allclose(AA[-1, :], BB) and np.allclose(A_hat[-1, :], B_hat[:-1]): + self._finalize = self._finalize_stiffly_accurate + else: + self._finalize = self._finalize_no_last_explicit + else: + self._finalize = self._finalize_general + + def advance(self): + k, g, a, c = self.kgac + khat, ghat, chat = self.kgchat + ks = self.ks + k_hat_s = self.k_hat_s + u0 = self.u0 + dtc = float(self.dt) + bt = self.butcher_tableau + ns = self.num_stages + AA = bt.A + A_hat = bt.A_hat + CC = bt.c + C_hat = bt.c_hat + a_vals, ahat_vals, d_val = self.bc_constants + + # Calculate explicit term for the first stage + ghat.assign(u0) + + for i in range(ns): + + chat.assign(C_hat[i]) + self.mass_solver.solve() + self.num_mass_nonlinear_iterations += self.mass_solver.snes.getIterationNumber() + self.num_mass_linear_iterations += self.mass_solver.snes.getLinearSolveIterations() + k_hat_s[i].assign(khat) + + g.assign(u0) + # Update g with contributions from previous stages + for j in range(i): + ksplit = ks[j].subfunctions + for gbit, kbit in zip(g.subfunctions, ksplit): + gbit += dtc * AA[i, j] * kbit + for j in range(i+1): + k_hat_split = k_hat_s[j].subfunctions + for gbit, k_hat_bit in zip(g.subfunctions, k_hat_split): + gbit += dtc * A_hat[i, j] * k_hat_bit + + # Solve for current stage + for j in range(i): + a_vals[j].assign(AA[i, j]) + for j in range(i, ns): + a_vals[j].assign(0) + for j in range(i+1): + ahat_vals[j].assign(A_hat[i, j]) + for j in range(i+1, ns+1): + ahat_vals[j].assign(0) + d_val.assign(AA[i, i]) + + # Solve the nonlinear problem at stage i + a.assign(AA[i, i]) + c.assign(CC[i]) + self.solver.solve() + self.num_nonlinear_iterations += self.solver.snes.getIterationNumber() + self.num_linear_iterations += self.solver.snes.getLinearSolveIterations() + ks[i].assign(k) + + # Update the solution for next stage + for ghatbit, gbit in zip(ghat.subfunctions, g.subfunctions): + ghatbit.assign(gbit) + for ghatbit, kbit in zip(ghat.subfunctions, ks[i].subfunctions): + ghatbit += dtc * AA[i, i] * kbit + + self._finalize() + self.num_steps += 1 + + # Last part of advance for the general case, where last explicit stage is calculated and used + def _finalize_general(self): + khat, ghat, chat = self.kgchat + ks = self.ks + k_hat_s = self.k_hat_s + u0 = self.u0 + dtc = float(self.dt) + bt = self.butcher_tableau + ns = self.num_stages + C_hat = bt.c_hat + BB = bt.b + B_hat = bt.b_hat + + chat.assign(C_hat[ns]) + self.mass_solver.solve() + self.num_mass_nonlinear_iterations += self.mass_solver.snes.getIterationNumber() + self.num_mass_linear_iterations += self.mass_solver.snes.getLinearSolveIterations() + k_hat_s[ns].assign(khat) + + # Final solution update + for i in range(ns): + for u0bit, kbit in zip(u0.subfunctions, ks[i].subfunctions): + u0bit += dtc * BB[i] * kbit + + for i in range(ns+1): + for u0bit, k_hat_bit in zip(u0.subfunctions, k_hat_s[i].subfunctions): + u0bit += dtc * B_hat[i] * k_hat_bit + + # Last part of advance for the general case, where last explicit stage is not used + def _finalize_no_last_explicit(self): + ks = self.ks + k_hat_s = self.k_hat_s + u0 = self.u0 + dtc = float(self.dt) + bt = self.butcher_tableau + ns = self.num_stages + BB = bt.b + B_hat = bt.b_hat + + # Final solution update + for i in range(ns): + for u0bit, kbit in zip(u0.subfunctions, ks[i].subfunctions): + u0bit += dtc * BB[i] * kbit + + for i in range(ns): + for u0bit, k_hat_bit in zip(u0.subfunctions, k_hat_s[i].subfunctions): + u0bit += dtc * B_hat[i] * k_hat_bit + + # Last part of advance for the general case, where last implicit stage is new solution + def _finalize_stiffly_accurate(self): + khat, ghat, chat = self.kgchat + u0 = self.u0 + for u0bit, ghatbit in zip(u0.subfunctions, ghat.subfunctions): + u0bit.assign(ghatbit) + + def solver_stats(self): + return self.num_steps, self.num_nonlinear_iterations, self.num_linear_iterations, self.num_mass_nonlinear_iterations, self.num_mass_linear_iterations diff --git a/irksome/stepper.py b/irksome/stepper.py index 1117e9e3..74ffea0f 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -9,7 +9,7 @@ from .dirk_stepper import DIRKTimeStepper from .explicit_stepper import ExplicitTimeStepper from .getForm import AI, getForm -from .imex import RadauIIAIMEXMethod +from .imex import RadauIIAIMEXMethod, DIRKIMEXMethod from .manipulation import extract_terms from .stage import StageValueTimeStepper @@ -67,7 +67,8 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs): "imex": ["Fexp", "stage_type", "bcs", "nullspace", "it_solver_parameters", "prop_solver_parameters", "splitting", "appctx", - "num_its_initial", "num_its_per_step"]} + "num_its_initial", "num_its_per_step"], + "dirkimex": ["Fexp", "stage_type", "bcs", "nullspace", "solver_parameters", "mass_parameters", "appctx"]} valid_adapt_parameters = ["tol", "dtmin", "dtmax", "KI", "KP", "max_reject", "onscale_factor", @@ -146,7 +147,7 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs): solver_parameters, appctx) elif stage_type == "imex": Fexp = kwargs.get("Fexp") - assert Fexp is not None + assert Fexp is not None, "Calling an IMEX scheme with no explicit form. Did you really mean to do this?" bcs = kwargs.get("bcs") appctx = kwargs.get("appctx") splitting = kwargs.get("splitting", AI) @@ -161,6 +162,17 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs): it_solver_parameters, prop_solver_parameters, splitting, appctx, nullspace, num_its_initial, num_its_per_step) + elif stage_type == "dirkimex": + Fexp = kwargs.get("Fexp") + assert Fexp is not None, "Calling an IMEX scheme with no explicit form. Did you really mean to do this?" + bcs = kwargs.get("bcs") + appctx = kwargs.get("appctx") + solver_parameters = kwargs.get("solver_parameters") + mass_parameters = kwargs.get("mass_parameters") + nullspace = kwargs.get("nullspace") + return DIRKIMEXMethod( + F, Fexp, butcher_tableau, t, dt, u0, bcs, + solver_parameters, mass_parameters, appctx, nullspace) class StageDerivativeTimeStepper: diff --git a/tests/test_imex.py b/tests/test_imex.py new file mode 100644 index 00000000..3d3faa7a --- /dev/null +++ b/tests/test_imex.py @@ -0,0 +1,126 @@ +from math import isclose + +import pytest +from firedrake import * +from irksome import Dt, MeshConstant, TimeStepper, DIRK_IMEX +from ufl.algorithms.ad import expand_derivatives + + +def convdiff_neumannbc(butcher_tableau, order, N): + msh = UnitIntervalMesh(N) + V = FunctionSpace(msh, "CG", order) + MC = MeshConstant(msh) + dt = MC.Constant(0.1 / N) + t = MC.Constant(0.0) + (x,) = SpatialCoordinate(msh) + + # Choose uexact so rhs is nonzero + uexact = cos(pi*x)*exp(-t) + rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + uexact.dx(0) + u = Function(V) + u.interpolate(uexact) + + v = TestFunction(V) + F = ( + inner(Dt(u), v) * dx + + inner(grad(u), grad(v)) * dx + - inner(rhs, v) * dx + ) + Fexp = inner(u.dx(0), v)*dx + + luparams = {"mat_type": "aij", "ksp_type": "preonly", "pc_type": "lu"} + + stepper = TimeStepper( + F, butcher_tableau, t, dt, u, Fexp=Fexp, + solver_parameters=luparams, mass_parameters=luparams, + stage_type="dirkimex" + ) + + t_end = 0.1 + while float(t) < t_end: + if float(t) + float(dt) > t_end: + dt.assign(t_end - float(t)) + stepper.advance() + t.assign(float(t) + float(dt)) + + return (errornorm(uexact, u) / norm(uexact)) + + +@pytest.mark.parametrize("imp_stages, exp_stages, order", + [(1, 1, 1), (2, 3, 2), + (3, 4, 3), (4, 4, 3)]) +def test_1d_convdiff_neumannbc(imp_stages, exp_stages, order): + bt = DIRK_IMEX(imp_stages, exp_stages, order) + errs = np.array([convdiff_neumannbc(bt, order, 10*2**p) for p in [3, 4]]) + print(errs) + conv = np.log2(errs[0]/errs[1]) + print(conv) + assert conv > order-0.4 + + +# Note that DIRK_IMEX(1,1,1) and DIRK_IMEX(4,4,2) are stiffly +# accurate, so the DAE-style BC imposition leads to satisfying the BCs +# exactly at each timestep, which we check here. The 2- and 3-stage +# methods are not. +@pytest.mark.parametrize("imp_stages, exp_stages, order", + [(1, 1, 1), (4, 4, 3)]) +def test_1d_heat_dirichletbc(imp_stages, exp_stages, order): + # Boundary values + u_0 = Constant(2.0) + u_1 = Constant(3.0) + + N = 10 + x0 = 0.0 + x1 = 10.0 + msh = IntervalMesh(N, x1) + V = FunctionSpace(msh, "CG", 1) + MC = MeshConstant(msh) + dt = MC.Constant(1.0 / N) + t = MC.Constant(0.0) + (x,) = SpatialCoordinate(msh) + + # Method of manufactured solutions copied from Heat equation demo. + S = Constant(2.0) + C = Constant(1000.0) + B = (x - Constant(x0)) * (x - Constant(x1)) / C + R = (x * x) ** 0.5 + # Note end linear contribution + uexact = ( + B * atan(t) * (pi / 2.0 - atan(S * (R - t))) + + u_0 + + ((x - x0) / x1) * (u_1 - u_0) + ) + rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + u = Function(V) + u.interpolate(uexact) + v = TestFunction(V) + F = ( + inner(Dt(u), v) * dx + + inner(grad(u), grad(v)) * dx + ) + Fexp = -inner(rhs, v) * dx + + bc = [ + DirichletBC(V, u_1, 2), + DirichletBC(V, u_0, 1), + ] + + luparams = {"mat_type": "aij", "ksp_type": "preonly", "pc_type": "lu"} + + butcher_tableau = DIRK_IMEX(imp_stages, exp_stages, order) + stepper = TimeStepper( + F, butcher_tableau, t, dt, u, Fexp=Fexp, bcs=bc, + solver_parameters=luparams, mass_parameters=luparams, + stage_type="dirkimex" + ) + + t_end = 2.0 + while float(t) < t_end: + if float(t) + float(dt) > t_end: + dt.assign(t_end - float(t)) + stepper.advance() + t.assign(float(t) + float(dt)) + # Check solution and boundary values + assert errornorm(uexact, u) / norm(uexact) < 10.0 ** -3 + assert isclose(u.at(x0), u_0) + assert isclose(u.at(x1), u_1)