diff --git a/demos/lowlevel/demo_lowlevel_homogbc.py.rst b/demos/lowlevel/demo_lowlevel_homogbc.py.rst index 11b4e70..cd6ed58 100644 --- a/demos/lowlevel/demo_lowlevel_homogbc.py.rst +++ b/demos/lowlevel/demo_lowlevel_homogbc.py.rst @@ -56,7 +56,7 @@ Continuing:: Now, we use the :func:`.getForm` function, which processes the semidiscrete problem:: - Fnew, k, bcnew, nspnew, bcdata = getForm(F, butcher_tableau, t, dt, u, bcs=bc) + Fnew, k, bcnew, nspnew = getForm(F, butcher_tableau, t, dt, u, bcs=bc) This returns several things: @@ -68,18 +68,7 @@ This returns several things: be enforced on the variational problem for the stages * ``nspnew`` is a new :class:`~firedrake.MixedVectorSpaceBasis` that can be used to express the nullspace of `Fnew` -* ``bcdata`` contains information needed to update the boundary - conditions. It is a list of pairs of the form (``f``, ``expr``), where - ``f`` is a :class:`~firedrake.function.Function` and ``expr`` is an - :class:`~ufl.core.expr.Expr` for each of the Dirichlet boundary conditions. - Because Firedrake isn't smart enough to detect that `t` changes in - the expression for the boundary condition, we need to manually - interpolate or project each :py:class:`~ufl.core.expr.Expr` onto the corresponding ``f`` at the - beginning of each time step. Firedrake will notice this change and - re-apply the boundary conditions. This hassle is easy to overlook - (not needed in this demo with homogeneous BC) and part of the reason - we recommend using the :class:`.TimeStepper` interface that does this - for you. + Solver parameters are just blunt-force LU. Other options are surely possible:: diff --git a/demos/lowlevel/demo_lowlevel_inhomogbc.py.rst b/demos/lowlevel/demo_lowlevel_inhomogbc.py.rst index 509e235..666fd8e 100644 --- a/demos/lowlevel/demo_lowlevel_inhomogbc.py.rst +++ b/demos/lowlevel/demo_lowlevel_inhomogbc.py.rst @@ -39,7 +39,7 @@ Imports:: As with the homogeneous BC case, we use the `getForm` method to process the semidiscrete problem:: - Fnew, k, bcnew, nspnew, bcdata = getForm(F, butcher_tableau, t, dt, u, bcs=bc) + Fnew, k, bcnew, nspnew = getForm(F, butcher_tableau, t, dt, u, bcs=bc) Recall that `getForm` produces: @@ -47,14 +47,6 @@ Recall that `getForm` produces: * ``k`` is a new :class:`~firedrake.function.Function` of stages on the s-way product of the space on which the problem was originally posed * ``bcnew`` is a list of new :class:`~firedrake.bcs.DirichletBC` that need to be enforced on the variational problem for the stages -* ``bcdata`` contains information needed to update the boundary - conditions. It is a list of triples of the form - (``f``,``expr``,``method``), where ``f`` is a - :class:`~firedrake.function.Function`, ``expr`` is an - :class:`~ufl.core.expr.Expr`, and ``method`` is either a project or - interpolate operation for each of the Dirichlet boundary conditions. - You're using the low-level interface and have to force Firedrake to - reapply the boundary conditions. We just use basic solver parameters and set up the variational problem @@ -77,9 +69,6 @@ boundary conditions at each time step:: if float(t) + float(dt) > 1.0: dt.assign(1.0 - float(t)) - for (gdat, gcur, gmethod) in bcdata: - gmethod(gcur, u) - solver.solve() for i in range(butcher_tableau.num_stages): diff --git a/demos/lowlevel/demo_lowlevel_mixed_heat.py.rst b/demos/lowlevel/demo_lowlevel_mixed_heat.py.rst index 7a61fed..01e62f3 100644 --- a/demos/lowlevel/demo_lowlevel_mixed_heat.py.rst +++ b/demos/lowlevel/demo_lowlevel_mixed_heat.py.rst @@ -52,7 +52,7 @@ Build the mesh and approximating spaces:: Because we aren't concerned with any strongly-enforced boundary conditions, we drop that information in calling `get_form`:: - Fnew, k, _, _, _ = getForm(F, butcher_tableau, t, dt, sigu) + Fnew, k, _, _ = getForm(F, butcher_tableau, t, dt, sigu) We set up the variational problem and solver using a sparse direct method:: diff --git a/irksome/bcs.py b/irksome/bcs.py new file mode 100644 index 0000000..7262a7c --- /dev/null +++ b/irksome/bcs.py @@ -0,0 +1,106 @@ +from functools import partial +from firedrake import (DirichletBC, Function, TestFunction, + NonlinearVariationalProblem, + NonlinearVariationalSolver, + replace, inner, dx) + + +def get_sub(u, indices): + for i in indices: + u = u.sub(i) + return u + + +def bc2space(bc, V): + return get_sub(V, bc._indices) + + +def stage2spaces4bc(bc, V, Vbig, i): + """used to figure out how to apply Dirichlet BC to each stage""" + num_fields = len(V) + sub = 0 if num_fields == 1 else bc.function_space_index() + comp = bc.function_space().component + + Vbigi = Vbig[sub+num_fields*i] + if comp is not None: # check for sub-piece of vector-valued + Vbigi = Vbigi.sub(comp) + + return Vbigi + + +def BCStageData(V, gcur, u0, u0_mult, i, t, dt): + if V.component is None: # V is not a bit of a VFS + if V.index is None: # not part of MFS, either + indices = () + else: # part of MFS + indices = (V.index,) + else: # bottommost space is bit of VFS + if V.parent.index is None: # but not part of a MFS + indices = (V.component,) + else: # V is a bit of a VFS inside an MFS + indices = (V.parent.index, V.component) + + if gcur == 0: # special case DirichletBC(V, 0, ...), do nothing + gdat = gcur + else: + gdat = gcur - u0_mult[i] * get_sub(u0, indices) + return gdat + + +def EmbeddedBCData(bc, t, dt, num_fields, butcher_tableau, ws, u0): + gorig = bc._original_arg + if gorig == 0: # special case DirichletBC(V, 0, ...), do nothing + gdat = gorig + else: + gcur = replace(gorig, {t: t+dt}) + sub = 0 if num_fields == 1 else bc.function_space_index() + comp = bc.function_space().component + num_stages = butcher_tableau.num_stages + btilde = butcher_tableau.btilde + if comp is None: # check for sub-piece of vector-valued + for j in range(num_stages): + gcur -= dt*btilde[j]*ws[num_fields*j+sub] + else: + for j in range(num_stages): + gcur -= dt*btilde[j]*ws[num_fields*j+sub].sub(comp) + + gdat = gcur - bc2space(bc, u0) + return gdat + + +class BoundsConstrainedBC(DirichletBC): + """A DirichletBC with bounds-constrained data.""" + def __init__(self, V, g, sub_domain, bounds, solver_parameters=None): + super().__init__(V, g, sub_domain) + if solver_parameters is None: + solver_parameters = { + "snes_type": "vinewtonssls", + } + self.solver_parameters = solver_parameters + self.bounds = bounds + + @property + def function_arg(self): + '''The value of this boundary condition.''' + if hasattr(self, "_function_arg_update"): + self._function_arg_update() + return self._function_arg + + @function_arg.setter + def function_arg(self, g): + '''Set the value of this boundary condition.''' + V = self.function_space() + gnew = Function(V) + try: + # Use the interpolant as initial guess + gnew.interpolate(g) + except (NotImplementedError, AttributeError): + pass + F = inner(TestFunction(V), gnew - g) * dx + problem = NonlinearVariationalProblem(F, gnew) + solver = NonlinearVariationalSolver(problem, + solver_parameters=self.solver_parameters) + + self._function_arg = gnew + self.function_arg_update = partial(solver.solve, bounds=self.bounds) + self.function_arg_update() diff --git a/irksome/getForm.py b/irksome/getForm.py index e014412..4a8e2b7 100644 --- a/irksome/getForm.py +++ b/irksome/getForm.py @@ -2,55 +2,14 @@ from operator import mul import numpy -from firedrake import (DirichletBC, Function, TestFunction, - assemble, project, split) -from firedrake.__future__ import interpolate +from firedrake import Function, TestFunction, split from ufl import diff from ufl.algorithms import expand_derivatives from ufl.classes import Zero from ufl.constantvalue import as_ufl -from .tools import ConstantOrZero, MeshConstant, replace, getNullspace, AI, stage2spaces4bc +from .tools import ConstantOrZero, MeshConstant, replace, getNullspace, AI from .deriv import TimeDerivative # , apply_time_derivatives - - -class BCStageData(object): - def __init__(self, V, gcur, u0, u0_mult, i, t, dt): - if V.component is not None: # bottommost space is bit of VFS - if V.parent.index is None: # but not part of a MFS - sub = V.component - try: - gdat = assemble(interpolate(gcur-u0_mult[i]*u0.sub(sub), V)) - gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u.sub(sub)) - except: # noqa: E722 - gdat = project(gcur-u0_mult[i]*u0.sub(sub), V) - gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u.sub(sub)) - else: # V is a bit of a VFS inside an MFS - sub0 = V.parent.index - sub1 = V.component - try: - gdat = assemble(interpolate(gcur-u0_mult[i]*u0.sub(sub0).sub(sub1), V)) - gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u.sub(sub0).sub(sub1)) - except: # noqa: E722 - gdat = project(gcur-u0_mult[i]*u0.sub(sub0).sub(sub1), V) - gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u.sub(sub0).sub(sub1)) - else: # V is not a bit of a VFS - if V.index is None: # not part of MFS, either - try: - gdat = assemble(interpolate(gcur-u0_mult[i]*u0, V)) - gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u) - except: # noqa: E722 - gdat = project(gcur-u0_mult[i]*u0, V) - gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u) - else: # part of MFS - sub = V.index - try: - gdat = assemble(interpolate(gcur-u0_mult[i]*u0.sub(sub), V)) - gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u.sub(sub)) - except: # noqa: E722 - gdat = project(gcur-u0_mult[i]*u0.sub(sub), V) - gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u.sub(sub)) - - self.gstuff = (gdat, gcur, gmethod) +from .bcs import BCStageData, bc2space, stage2spaces4bc def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI, @@ -99,16 +58,6 @@ def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI, on the stages, - 'nspnew', the :class:`firedrake.MixedVectorSpaceBasis` object that represents the nullspace of the coupled system - - `gblah`, a list of tuples of the form (f, expr, method), - where f is a :class:`firedrake.Function` and expr is a - :class:`ufl.Expr`. At each time step, each expr needs to be - re-interpolated/projected onto the corresponding f in order - for Firedrake to pick up that time-dependent boundary - conditions need to be re-applied. The - interpolation/projection is encoded in method, which is - either `f.interpolate(expr-c*u0)` or `f.project(expr-c*u0)`, depending - on whether the function space for f supports interpolation or - not. """ if bc_type is None: bc_type = "DAE" @@ -191,7 +140,6 @@ def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI, Fnew += replace(F, repl) bcnew = [] - gblah = [] if bcs is None: bcs = [] @@ -227,13 +175,12 @@ def bc2gcur(bc, i): # set up the new BCs for either method for bc in bcs: for i in range(num_stages): - Vsp, Vbigi = stage2spaces4bc(bc, V, Vbig, i) + Vsp = bc2space(bc, V) + Vbigi = stage2spaces4bc(bc, V, Vbig, i) gcur = bc2gcur(bc, i) - blah = BCStageData(Vsp, gcur, u0, u0_mult, i, t, dt) - gdat, gcr, gmethod = blah.gstuff - gblah.append((gdat, gcr, gmethod)) - bcnew.append(DirichletBC(Vbigi, gdat, bc.sub_domain)) + gdat = BCStageData(Vsp, gcur, u0, u0_mult, i, t, dt) + bcnew.append(bc.reconstruct(V=Vbigi, g=gdat)) nspnew = getNullspace(V, Vbig, butch, nullspace) - return Fnew, w, bcnew, nspnew, gblah + return Fnew, w, bcnew, nspnew diff --git a/irksome/imex.py b/irksome/imex.py index e1f3ec0..cf66f11 100644 --- a/irksome/imex.py +++ b/irksome/imex.py @@ -207,7 +207,7 @@ def __init__(self, F, Fexp, butcher_tableau, # Since this assumes stiff accuracy, we drop # the update information on the floor. - Fbig, _, UU, bigBCs, gblah, nsp = getFormStage( + Fbig, _, UU, bigBCs, nsp = getFormStage( F, butcher_tableau, u0, t, dt, bcs, splitting=splitting, nullspace=nullspace) @@ -215,7 +215,6 @@ def __init__(self, F, Fexp, butcher_tableau, self.UU_old = UU_old = Function(UU.function_space()) self.UU_old_split = UU_old.subfunctions self.bigBCs = bigBCs - self.bcdat = gblah Fit, Fprop = getFormExplicit( Fexp, butcher_tableau, u0, UU_old, t, dt, splitting) @@ -283,8 +282,6 @@ def propagate(self): for i, u0bit in enumerate(u0split): u0bit.assign(self.UU_old_split[(ns-1)*nf + i]) - for gdat, gcur, gmethod in self.bcdat: - gmethod(gdat, gcur) push_parent(self.u0.function_space().dm, self.UU.function_space().dm) ps = self.prop_solver diff --git a/irksome/pc.py b/irksome/pc.py index d100616..7c6143c 100644 --- a/irksome/pc.py +++ b/irksome/pc.py @@ -71,11 +71,11 @@ def form(self, pc, test, trial): # which getForm do I need to get? if stage_type in ("deriv", None): - Fnew, w, bcnew, bignsp, _ = \ + Fnew, w, bcnew, bignsp = \ getForm(F, butcher_new, t, dt, u0, bcs, bc_type, splitting, nullspace) elif stage_type == "value": - Fnew, _, w, bcnew, _, bignsp = \ + Fnew, _, w, bcnew, bignsp = \ getFormStage(F, butcher_new, u0, t, dt, bcs, splitting, nullspace) # Now we get the Jacobian for the modified system, @@ -125,11 +125,11 @@ def form(self, pc, test, trial): # which getForm do I need to get? if stage_type in ("deriv", None): - Fnew, w, bcnew, bignsp, _ = \ + Fnew, w, bcnew, bignsp = \ getForm(F, butcher_tableau, t, dt, u0, bcs, bc_type, splitting, nullspace) elif stage_type == "value": - Fnew, _, w, bcnew, _, bignsp = \ + Fnew, _, w, bcnew, bignsp = \ getFormStage(F, butcher_tableau, u0, t, dt, bcs, splitting, nullspace) # Now we get the Jacobian for the modified system, diff --git a/irksome/stage.py b/irksome/stage.py index 0e57e6d..68ebb51 100644 --- a/irksome/stage.py +++ b/irksome/stage.py @@ -5,19 +5,19 @@ import numpy as np from FIAT import Bernstein, ufc_simplex -from firedrake import (DirichletBC, Function, NonlinearVariationalProblem, - NonlinearVariationalSolver, TestFunction, assemble, dx, - inner, project, solve, split) -from firedrake.__future__ import interpolate +from firedrake import (Function, NonlinearVariationalProblem, + NonlinearVariationalSolver, TestFunction, dx, + inner, split) from firedrake.petsc import PETSc from numpy import vectorize from ufl.classes import Zero from ufl.constantvalue import as_ufl +from .bcs import stage2spaces4bc from .ButcherTableaux import CollocationButcherTableau from .manipulation import extract_terms, strip_dt_form from .tools import (AI, IA, ConstantOrZero, MeshConstant, getNullspace, is_ode, - replace, stage2spaces4bc) + replace) def isiterable(x): @@ -106,16 +106,6 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None, vandermonde=None on the stages, - 'nspnew', the :class:`firedrake.MixedVectorSpaceBasis` object that represents the nullspace of the coupled system - - `gblah`, a list of tuples of the form (f, expr, method), - where f is a :class:`firedrake.Function` and expr is a - :class:`ufl.Expr`. At each time step, each expr needs to be - re-interpolated/projected onto the corresponding f in order - for Firedrake to pick up that time-dependent boundary - conditions need to be re-applied. The - interpolation/projection is encoded in method, which is - either `f.interpolate(expr-c*u0)` or `f.project(expr-c*u0)`, depending - on whether the function space for f supports interpolation or - not. """ v = F.arguments()[0] V = v.function_space() @@ -261,7 +251,6 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None, vandermonde=None if bc_constraints is None: bc_constraints = {} bcsnew = [] - gblah = [] # For each BC, we need a new BC for each stage # so we need to figure out how the function is indexed (mixed + vec) @@ -271,44 +260,26 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None, vandermonde=None for bc in bcs: bcarg = as_ufl(bc._original_arg) - gblah_cur = [] - if bc in bc_constraints: bcparams, bclower, bcupper = bc_constraints[bc] - for i in range(num_stages): - Vsp, Vbigi = stage2spaces4bc(bc, V, Vbig, i) - gdat = Function(Vsp) - vbc = TestFunction(Vsp) - gmethod = lambda gd, gc: solve(inner(gd - gc, vbc) * dx == 0, - gd, solver_parameters=bcparams) gcur = replace(bcarg, {t: t+C[i] * dt}) - gblah_cur.append((gdat, gcur - Vander_col[i] * bcarg, gmethod)) + gcur = gcur - Vander_col[i] * bcarg else: - for i in range(num_stages): - Vsp, Vbigi = stage2spaces4bc(bc, V, Vbig, i) - try: - gdat = assemble(interpolate(bcarg, Vsp)) - gmethod = lambda gd, gc: gd.interpolate(gc) - except: # noqa: E722 - gdat = project(bcarg, Vsp) - gmethod = lambda gd, gc: gd.project(gc) - - gcur = replace(bcarg, {t: t+C[i]*dt}) - gblah_cur.append((gdat, gcur - Vander_col[i] * bcarg, gmethod)) - gdats_cur = np.zeros((num_stages,), dtype="O") for i in range(num_stages): - gdats_cur[i] = gblah_cur[i][0] + Vbigi = stage2spaces4bc(bc, V, Vbig, i) + gcur = replace(bcarg, {t: t+C[i]*dt}) + gcur = gcur - Vander_col[i] * bcarg + gdats_cur[i] = gcur zdats_cur = Vander_inv[1:, 1:] @ gdats_cur bcnew_cur = [] for i in range(num_stages): - Vsp, Vbigi = stage2spaces4bc(bc, V, Vbig, i) - bcnew_cur.append(DirichletBC(Vbigi, zdats_cur[i], bc.sub_domain)) + Vbigi = stage2spaces4bc(bc, V, Vbig, i) + bcnew_cur.append(bc.reconstruct(V=Vbigi, g=zdats_cur[i])) bcsnew.extend(bcnew_cur) - gblah.extend(gblah_cur) nspacenew = getNullspace(V, Vbig, butch, nullspace) @@ -323,7 +294,6 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None, vandermonde=None for i in range(num_stages): repl = {t: t + C[i] * dt} - for k in range(num_fields): repl[u0bits[k]] = UUbits[i][k] for ii in np.ndindex(u0bits[k].ufl_shape): @@ -335,39 +305,16 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None, vandermonde=None # And the BC's for the update -- just the original BC at t+dt update_bcs = [] - update_bcs_gblah = [] for bc in bcs: - if num_fields == 1: # not mixed space - comp = bc.function_space().component - if comp is not None: # check for sub-piece of vector-valued - Vsp = V.sub(comp) - else: - Vsp = V - else: # mixed space - sub = bc.function_space_index() - comp = bc.function_space().component - if comp is not None: # check for sub-piece of vector-valued - Vsp = V.sub(sub).sub(comp) - else: - Vsp = V.sub(sub) - bcarg = as_ufl(bc._original_arg) - try: - gdat = assemble(interpolate(bcarg, Vsp)) - gmethod = lambda gd, gc: gd.interpolate(gc) - except: # noqa: E722 - gdat = project(bcarg, Vsp) - gmethod = lambda gd, gc: gd.project(gc) - gcur = replace(bcarg, {t: t+dt}) - update_bcs.append(DirichletBC(Vsp, gdat, bc.sub_domain)) - update_bcs_gblah.append((gdat, gcur, gmethod)) + update_bcs.append(bc.reconstruct(g=gcur)) - update_stuff = (unew, Fupdate, update_bcs, update_bcs_gblah) + update_stuff = (unew, Fupdate, update_bcs) else: update_stuff = None - return (Fnew, update_stuff, ZZ, bcsnew, gblah, nspacenew) + return (Fnew, update_stuff, ZZ, bcsnew, nspacenew) class StageValueTimeStepper: @@ -396,17 +343,16 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, assert isinstance(butcher_tableau, CollocationButcherTableau), "Need collocation for Bernstein conversion" bern = Bernstein(ufc_simplex(1), num_stages) cc = np.reshape(np.append(0, butcher_tableau.c), (-1, 1)) - vandermonde = bern.tabulate(0, np.reshape(cc, (-1, 1)))[0,].T + vandermonde = bern.tabulate(0, np.reshape(cc, (-1, 1)))[(0, )].T else: raise ValueError("Unknown or unimplemented basis transformation type") - Fbig, update_stuff, UU, bigBCs, gblah, nsp = getFormStage( + Fbig, update_stuff, UU, bigBCs, nsp = getFormStage( F, butcher_tableau, u0, t, dt, bcs, vandermonde=vandermonde, splitting=splitting) self.UU = UU self.bigBCs = bigBCs - self.bcdat = gblah self.update_stuff = update_stuff self.prob = NonlinearVariationalProblem(Fbig, UU, bigBCs) @@ -429,7 +375,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, solver_parameters=solver_parameters) if (not butcher_tableau.is_stiffly_accurate) and (basis_type != "Bernstein"): - unew, Fupdate, update_bcs, update_bcs_gblah = self.update_stuff + unew, Fupdate, update_bcs = self.update_stuff self.update_problem = NonlinearVariationalProblem( Fupdate, unew, update_bcs) @@ -453,9 +399,7 @@ def _update_stiff_acc(self): u0bit.assign(UUs[self.num_fields*(self.num_stages-1)+i]) def _update_general(self): - (unew, Fupdate, update_bcs, update_bcs_gblah) = self.update_stuff - for gdat, gcur, gmethod in update_bcs_gblah: - gmethod(gdat, gcur) + unew, Fupdate, update_bcs = self.update_stuff self.update_solver.solve() unewbits = unew.subfunctions for u0bit, unewbit in zip(self.u0.subfunctions, unewbits): @@ -503,9 +447,6 @@ def advance(self, bounds=None): stage_bounds = (slb, sub) - for gdat, gcur, gmethod in self.bcdat: - gmethod(gdat, gcur) - self.solver.solve(bounds=stage_bounds) self.num_steps += 1 diff --git a/irksome/stepper.py b/irksome/stepper.py index 2837ef2..1117e9e 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -1,12 +1,11 @@ import numpy -from firedrake import DirichletBC, Function +from firedrake import Function from firedrake import NonlinearVariationalProblem as NLVP from firedrake import NonlinearVariationalSolver as NLVS -from firedrake import TestFunction, assemble, dx, inner, norm, project, replace -from firedrake.__future__ import interpolate +from firedrake import TestFunction, assemble, dx, inner, norm from firedrake.dmhooks import pop_parent, push_parent -from ufl.constantvalue import as_ufl +from .bcs import EmbeddedBCData, bc2space from .dirk_stepper import DIRKTimeStepper from .explicit_stepper import ExplicitTimeStepper from .getForm import AI, getForm @@ -221,12 +220,11 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, self.num_nonlinear_iterations = 0 self.num_linear_iterations = 0 - bigF, stages, bigBCs, bigNSP, bigBCdata = \ + bigF, stages, bigBCs, bigNSP = \ getForm(F, butcher_tableau, t, dt, u0, bcs, bc_type, splitting, nullspace) self.stages = stages self.bigBCs = bigBCs - self.bigBCdata = bigBCdata problem = NLVP(bigF, stages, bigBCs) appctx_irksome = {"F": F, "butcher_tableau": butcher_tableau, @@ -301,9 +299,6 @@ def _update_A2Tmb(self): def advance(self): """Advances the system from time `t` to time `t + dt`. Note: overwrites the value `u0`.""" - for gdat, gcur, gmethod in self.bigBCdata: - gmethod(gcur, self.u0) - push_parent(self.u0.function_space().dm, self.stages.function_space().dm) self.solver.solve() pop_parent(self.u0.function_space().dm, self.stages.function_space().dm) @@ -408,76 +403,19 @@ def __init__(self, F, butcher_tableau, t, dt, u0, self.dtless_form = -split_form.remainder # Set up and cache boundary conditions for error estimate + embbc = [] if self.gamma0 != 0: # Grab spaces for BCs v = F.arguments()[0] V = v.function_space() num_fields = len(V) - num_stages = butcher_tableau.num_stages - btilde = butcher_tableau.btilde ws = self.ws - class EmbeddedBCData(object): - def __init__(self, bc, t, dt, num_fields, num_stages, btilde, V, ws, u0): - gorig = as_ufl(bc._original_arg) - gcur = replace(gorig, {t: t+dt}) - if num_fields == 1: # not mixed space - comp = bc.function_space().component - if comp is not None: # check for sub-piece of vector-valued - Vsp = V.sub(comp) - for j in range(num_stages): - gcur -= dt*btilde[j]*ws[j].sub(comp) - try: - gdat = assemble(interpolate(gcur-u0.sub(comp), Vsp)) - gmethod = lambda g, u: gdat.interpolate(g-u.sub(comp)) - except: # noqa: E722 - gdat = project(gcur-u0.sub(comp), Vsp) - gmethod = lambda g, u: gdat.project(g-u.sub(comp)) - else: - Vsp = V - for j in range(num_stages): - gcur -= dt*btilde[j]*ws[j] - try: - gdat = assemble(interpolate(gcur-u0, Vsp)) - gmethod = lambda g, u: gdat.interpolate(g-u) - except: # noqa: E722 - gdat = project(gcur-u0, Vsp) - gmethod = lambda g, u: gdat.project(g-u) - - else: # mixed space - sub = bc.function_space_index() - comp = bc.function_space().component - if comp is not None: # check for sub-piece of vector-valued - Vsp = V.sub(sub).sub(comp) - for j in range(num_stages): - gcur -= dt*btilde[j]*ws[num_fields*j+sub].sub(comp) - try: - gdat = assemble(interpolate(gcur-u0.sub(sub).sub(comp), Vsp)) - gmethod = lambda g, u: gdat.interpolate(g-u.sub(sub).sub(comp)) - except: # noqa: E722 - gdat = project(gcur-u0.sub(sub).sub(comp), Vsp) - gmethod = lambda g, u: gdat.project(g-u.sub(sub).sub(comp)) - else: - Vsp = V.sub(sub) - for j in range(num_stages): - gcur -= dt*btilde[j]*ws[num_fields*j+sub] - try: - gdat = assemble(interpolate(gcur-u0.sub(sub), Vsp)) - gmethod = lambda g, u: gdat.interpolate(g-u.sub(sub)) - except: # noqa: E722 - gdat = project(gcur-u0.sub(sub), Vsp) - gmethod = lambda g, u: gdat.project(g-u.sub(sub)) - self.gstuff = (gdat, gcur, gmethod, Vsp) - - embbc = [] - gblah = [] for bc in bcs: - blah = EmbeddedBCData(bc, self.t, self.dt, num_fields, num_stages, btilde, V, ws, self.u0) - gdat, gcur, gmethod, gVsp = blah.gstuff - gblah.append((gdat, gcur, gmethod)) - embbc.append(DirichletBC(gVsp, gdat, bc.sub_domain)) - self.embbc = embbc - self.gblah = gblah + gVsp = bc2space(bc, V) + gdat = EmbeddedBCData(bc, self.t, self.dt, num_fields, butcher_tableau, ws, self.u0) + embbc.append(bc.reconstruct(V=gVsp, g=gdat)) + self.embbc = embbc def _estimate_error(self): """Assuming that the RK stages have been evaluated, estimates @@ -497,8 +435,6 @@ def _estimate_error(self): if self.gamma0 != 0.0: error_test = TestFunction(u0.function_space()) f_form = inner(error_func, error_test)*dx-self.gamma0*dtc*self.dtless_form - for gdat, gcur, gmethod in self.gblah: - gmethod(gcur, self.u0) f_problem = NLVP(f_form, error_func, bcs=self.embbc) f_solver = NLVS(f_problem, solver_parameters=self.gamma0_params) f_solver.solve() @@ -520,9 +456,6 @@ def advance(self): self.dt.assign(self.dt_max) self.print("\tTrying dt = %e" % (float(self.dt))) while 1: - for gdat, gcur, gmethod in self.bigBCdata: - gmethod(gcur, self.u0) - self.solver.solve() self.num_nonlinear_iterations += self.solver.snes.getIterationNumber() self.num_linear_iterations += self.solver.snes.getLinearSolveIterations() diff --git a/irksome/tools.py b/irksome/tools.py index 946dd42..e9bf6c1 100644 --- a/irksome/tools.py +++ b/irksome/tools.py @@ -122,39 +122,3 @@ def Constant(self, val=0.0): def ConstantOrZero(x, MC): return Zero() if abs(complex(x)) < 1.e-10 else MC.Constant(x) - - -def bc2space(bc, V): - num_fields = len(V) - if num_fields == 1: # not mixed space - comp = bc.function_space().component - Vsp = V if comp is None else V.sub(comp) - else: # mixed space - sub = bc.function_space_index() - comp = bc.function_space().component - Vsp = V.sub(sub) if comp is None else V.sub(sub).sub(comp) - return Vsp - - -# used to figure out how to apply Dirichlet BC to each stage -def stage2spaces4bc(bc, V, Vbig, i): - num_fields = len(V) - if num_fields == 1: # not mixed space - comp = bc.function_space().component - if comp is not None: # check for sub-piece of vector-valued - Vsp = V.sub(comp) - Vbigi = Vbig[i].sub(comp) - else: - Vsp = V - Vbigi = Vbig[i] - else: # mixed space - sub = bc.function_space_index() - comp = bc.function_space().component - if comp is not None: # check for sub-piece of vector-valued - Vsp = V.sub(sub).sub(comp) - Vbigi = Vbig[sub+num_fields*i].sub(comp) - else: - Vsp = V.sub(sub) - Vbigi = Vbig[sub+num_fields*i] - - return Vsp, Vbigi diff --git a/tests/test_split.py b/tests/test_split.py index 83e0714..7045a0f 100644 --- a/tests/test_split.py +++ b/tests/test_split.py @@ -137,7 +137,7 @@ def Ffull(z, test): return Fimp(z, test) + Fexp(z, test) bcs = [DirichletBC(Z.sub(0), as_vector([x*(1-x), 0]), (4,)), - DirichletBC(Z.sub(0), as_vector([0, 0]), (1, 2, 3))] + DirichletBC(Z.sub(0), 0, (1, 2, 3))] nsp = [(1, VectorSpaceBasis(constant=True, comm=COMM_WORLD))] diff --git a/tests/test_stokes.py b/tests/test_stokes.py index 4f3c5c5..b3f2515 100644 --- a/tests/test_stokes.py +++ b/tests/test_stokes.py @@ -128,7 +128,7 @@ def NSETest(butch, stage_type, splitting): ) bcs = [DirichletBC(Z.sub(0), Constant((1, 0)), (4,)), - DirichletBC(Z.sub(0), Constant((0, 0)), (1, 2, 3))] + DirichletBC(Z.sub(0), 0, (1, 2, 3))] nullspace = [(1, VectorSpaceBasis(constant=True))]