diff --git a/irksome/base_time_stepper.py b/irksome/base_time_stepper.py index dd572f84..3a69501f 100644 --- a/irksome/base_time_stepper.py +++ b/irksome/base_time_stepper.py @@ -40,10 +40,9 @@ def __init__(self, F, t, dt, u0, appctx=None, nullspace=None, splitting=None, bc_type="DAE"): - super().__init__(F, butcher_tableau, t, dt, u0, + super().__init__(F, t, dt, u0, bcs=bcs, appctx=appctx, nullspace=nullspace) - def advance(self): self.solver.solve() self._update() @@ -53,5 +52,3 @@ def advance(self): @abstractmethod def getForm(self, butch=None): pass - - diff --git a/irksome/discontinuous_galerkin_stepper.py b/irksome/discontinuous_galerkin_stepper.py index 6b311085..c24d2cfe 100644 --- a/irksome/discontinuous_galerkin_stepper.py +++ b/irksome/discontinuous_galerkin_stepper.py @@ -229,7 +229,7 @@ def __init__(self, F, order, t, dt, u0, bcs=None, basis_type=None, bigNSP = getNullspace(V, UU.function_space(), order+1, nullspace) - + self.UU = UU self.bigBCs = bigBCs problem = NLVP(bigF, UU, bigBCs) diff --git a/irksome/galerkin_stepper.py b/irksome/galerkin_stepper.py index 72cf43e4..5ed63b25 100644 --- a/irksome/galerkin_stepper.py +++ b/irksome/galerkin_stepper.py @@ -212,7 +212,7 @@ def __init__(self, F, order, t, dt, u0, bcs=None, basis_type=None, bigNSP = getNullspace(V, UU.function_space(), order, nullspace) - + self.UU = UU self.bigBCs = bigBCs problem = NLVP(bigF, UU, bigBCs) diff --git a/irksome/imex.py b/irksome/imex.py index 083a77a9..ce50a76d 100644 --- a/irksome/imex.py +++ b/irksome/imex.py @@ -188,7 +188,7 @@ def __init__(self, F, Fexp, butcher_tableau, nsp = getNullspace(u0.function_space(), UU.function_space(), self.num_stages, nullspace) - + self.UU = UU self.UU_old = UU_old = Function(UU.function_space()) self.UU_old_split = UU_old.subfunctions diff --git a/irksome/pc.py b/irksome/pc.py index 399796f5..e2fd382d 100644 --- a/irksome/pc.py +++ b/irksome/pc.py @@ -3,6 +3,7 @@ import numpy from firedrake import AuxiliaryOperatorPC, derivative +from firedrake.dmhooks import get_appctx from ufl import replace from irksome.stage_derivative import getForm @@ -68,11 +69,14 @@ def form(self, pc, test, trial): butcher_new = copy.deepcopy(butcher_tableau) butcher_new.A = Atilde - # which getForm do I need to get? + # get stages + ctx = get_appctx(pc.getDM()) + w = ctx._x + # which getForm do I need to get? if stage_type in ("deriv", None): - Fnew, w, bcnew = \ - getForm(F, butcher_new, t, dt, u0, bcs, + Fnew, bcnew = \ + getForm(F, butcher_new, t, dt, u0, w, bcs, bc_type, splitting) elif stage_type == "value": Fnew, _, w, bcnew = \ @@ -124,9 +128,13 @@ def form(self, pc, test, trial): F, bcs = self.getNewForm(pc, u0, v0) # which getForm do I need to get? + # get stages + ctx = get_appctx(pc.getDM()) + w = ctx._x + if stage_type in ("deriv", None): - Fnew, w, bcnew = \ - getForm(F, butcher_tableau, t, dt, u0, bcs, + Fnew, bcnew = \ + getForm(F, butcher_tableau, t, dt, u0, w, bcs, bc_type, splitting) elif stage_type == "value": Fnew, _, w, bcnew = \ diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index 71a7b259..46583588 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -84,7 +84,7 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI): num_stages = butch.num_stages w = stages Vbig = stages.function_space() - + vnew = TestFunction(Vbig) v_np = numpy.reshape(vnew, (num_stages, *u0.ufl_shape)) w_np = numpy.reshape(w, (num_stages, *u0.ufl_shape)) @@ -305,14 +305,13 @@ def get_stages(self): num_stages = self.butcher_tableau.num_stages Vbig = reduce(mul, (self.V for _ in range(num_stages))) return Function(Vbig) - + def get_form_and_bcs(self, stages, butcher_tableau=None): if butcher_tableau is None: butcher_tableau = self.butcher_tableau return getForm(self.F, butcher_tableau, self.t, self.dt, self.u0, stages, self.orig_bcs, self.bc_type, self.splitting) - class AdaptiveTimeStepper(StageDerivativeTimeStepper): diff --git a/irksome/stage_value.py b/irksome/stage_value.py index b039f549..8fc4cc42 100644 --- a/irksome/stage_value.py +++ b/irksome/stage_value.py @@ -274,7 +274,6 @@ def __init__(self, F, butcher_tableau, t, dt, u0, bcs=None, UU.function_space(), self.num_stages, nullspace) - self.UU = UU self.bigBCs = bigBCs self.update_stuff = update_stuff diff --git a/irksome/stepper.py b/irksome/stepper.py index 4a945328..8dd5b638 100644 --- a/irksome/stepper.py +++ b/irksome/stepper.py @@ -165,5 +165,3 @@ def TimeStepper(F, butcher_tableau, t, dt, u0, **kwargs): return DIRKIMEXMethod( F, Fexp, butcher_tableau, t, dt, u0, bcs, solver_parameters, mass_parameters, appctx, nullspace) - -