From ab4243427e3317ae322ac9649b6abd498932bd37 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 3 Mar 2025 18:58:55 -0600 Subject: [PATCH 1/6] Indexed high-order TimeDerivatives --- irksome/deriv.py | 16 ++++++++++++---- irksome/tools.py | 12 ++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/irksome/deriv.py b/irksome/deriv.py index 694dd7e..7cf56f2 100644 --- a/irksome/deriv.py +++ b/irksome/deriv.py @@ -4,6 +4,7 @@ from ufl.algorithms.map_integrands import map_integrand_dags, map_expr_dag from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset from ufl.tensors import ListTensor +from ufl.indexed import Indexed @ufl_type(num_ops=1, @@ -18,6 +19,7 @@ class TimeDerivative(Derivative): def __new__(cls, f): if isinstance(f, ListTensor): + # Push TimeDerivative inside ListTensor return ListTensor(*map(TimeDerivative, f.ufl_operands)) return Derivative.__new__(cls) @@ -27,11 +29,17 @@ def __init__(self, f): def __str__(self): return "d{%s}/dt" % (self.ufl_operands[0],) + def _simplify_indexed(self, multiindex): + """Return a simplified Expr used in the constructor of Indexed(self, multiindex).""" + # Push Indexed inside TimeDerivative + return TimeDerivative(Indexed(self.ufl_operands[0], multiindex)) -def Dt(f): - """Short-hand function to produce a :class:`TimeDerivative` of the - input.""" - return TimeDerivative(f) + +def Dt(f, order=1): + """Short-hand function to produce a :class:`TimeDerivative` of a given order.""" + for k in range(order): + f = TimeDerivative(f) + return f class TimeDerivativeRuleset(GenericDerivativeRuleset): diff --git a/irksome/tools.py b/irksome/tools.py index cd04701..4f18172 100644 --- a/irksome/tools.py +++ b/irksome/tools.py @@ -97,22 +97,14 @@ def replace(e, mapping): return map_integrand_dags(MyReplacer(mapping2), e) -def get_component(expr, index): - if isinstance(expr, TimeDerivative): - expr, = expr.ufl_operands - return TimeDerivative(expr[index]) - else: - return expr[index] - - def component_replace(e, mapping): - # Replace, reccurring on components + """Replace, recurring on components""" cmapping = {} for key, value in mapping.items(): cmapping[key] = as_tensor(value) if key.ufl_shape: for j in numpy.ndindex(key.ufl_shape): - cmapping[get_component(key, j)] = value[j] + cmapping[key[j]] = value[j] return replace(e, cmapping) From 93d75a2b6bbfc9db0b237317380782a5db5f2448 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Mon, 3 Mar 2025 22:40:38 -0600 Subject: [PATCH 2/6] TimeDerivativeRuleSet --- irksome/deriv.py | 46 ++++++++++++++-------------------------------- 1 file changed, 14 insertions(+), 32 deletions(-) diff --git a/irksome/deriv.py b/irksome/deriv.py index 7cf56f2..16f0ccd 100644 --- a/irksome/deriv.py +++ b/irksome/deriv.py @@ -3,6 +3,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.algorithms.map_integrands import map_integrand_dags, map_expr_dag from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset +from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering from ufl.tensors import ListTensor from ufl.indexed import Indexed @@ -43,31 +44,29 @@ def Dt(f, order=1): class TimeDerivativeRuleset(GenericDerivativeRuleset): - """Apply AD rules to time derivative expressions. WIP""" - def __init__(self, t, timedep_coeffs): + """Apply AD rules to time derivative expressions.""" + def __init__(self, timedep_coeffs=None): GenericDerivativeRuleset.__init__(self, ()) - self.t = t self.timedep_coeffs = timedep_coeffs def coefficient(self, o): - if o in self.timedep_coeffs: + if self.timedep_coeffs is None or o in self.timedep_coeffs: return TimeDerivative(o) else: return self.independent_terminal(o) - # def indexed(self, o, Ap, ii): - # print(o, type(o)) - # print(Ap, type(Ap)) - # print(ii, type(ii)) - # 1/0 + def indexed(self, o, Ap, ii): + return TimeDerivative(o) + + def time_derivative(self, o): + return TimeDerivative(o) # mapping rules to splat out time derivatives so that replacement should # work on more complex problems. class TimeDerivativeRuleDispatcher(MultiFunction): - def __init__(self, t, timedep_coeffs): + def __init__(self, timedep_coeffs=None): MultiFunction.__init__(self) - self.t = t self.timedep_coeffs = timedep_coeffs def terminal(self, o): @@ -78,30 +77,13 @@ def derivative(self, o): expr = MultiFunction.reuse_if_untouched - def grad(self, o): - from firedrake import grad - if isinstance(o, TimeDerivative): - return TimeDerivative(grad(*o.ufl_operands)) - return o - - def div(self, o): - return o - - def reference_grad(self, o): - return o - - def coefficient_derivative(self, o): - return o - - def coordinate_derivative(self, o): - return o - def time_derivative(self, o): f, = o.ufl_operands - rules = TimeDerivativeRuleset(self.t, self.timedep_coeffs) + rules = TimeDerivativeRuleset(timedep_coeffs=self.timedep_coeffs) return map_expr_dag(rules, f) -def apply_time_derivatives(expression, t, timedep_coeffs=[]): - rules = TimeDerivativeRuleDispatcher(t, timedep_coeffs) +def apply_time_derivatives(expression, timedep_coeffs=None): + expression = apply_algebra_lowering(expression) + rules = TimeDerivativeRuleDispatcher(timedep_coeffs=timedep_coeffs) return map_integrand_dags(rules, expression) From af21c748be3b677b159bd52aacdf6a8cea535ee6 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 4 Mar 2025 08:00:04 -0600 Subject: [PATCH 3/6] High-order TimeDerivatives of expressions --- irksome/deriv.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/irksome/deriv.py b/irksome/deriv.py index 16f0ccd..6028a25 100644 --- a/irksome/deriv.py +++ b/irksome/deriv.py @@ -59,7 +59,8 @@ def indexed(self, o, Ap, ii): return TimeDerivative(o) def time_derivative(self, o): - return TimeDerivative(o) + f, = o.ufl_operands + return TimeDerivative(map_expr_dag(self, f)) # mapping rules to splat out time derivatives so that replacement should @@ -77,13 +78,23 @@ def derivative(self, o): expr = MultiFunction.reuse_if_untouched - def time_derivative(self, o): - f, = o.ufl_operands + def time_derivative(self, o, f): + nderivs = 0 + while isinstance(o, TimeDerivative): + o, = o.ufl_operands + nderivs += 1 rules = TimeDerivativeRuleset(timedep_coeffs=self.timedep_coeffs) - return map_expr_dag(rules, f) + for k in range(nderivs): + o = map_expr_dag(rules, o) + return o def apply_time_derivatives(expression, timedep_coeffs=None): - expression = apply_algebra_lowering(expression) rules = TimeDerivativeRuleDispatcher(timedep_coeffs=timedep_coeffs) return map_integrand_dags(rules, expression) + + +def expand_time_derivatives(expression, timedep_coeffs=None): + expression = apply_algebra_lowering(expression) + expression = apply_time_derivatives(expression) + return expression From cc18a77caef8be16cf7a258038cffb038ba2af35 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 4 Mar 2025 09:31:53 -0600 Subject: [PATCH 4/6] add tests --- irksome/__init__.py | 2 +- tests/test_differentiation.py | 57 +++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 tests/test_differentiation.py diff --git a/irksome/__init__.py b/irksome/__init__.py index bc7c773..dde3467 100644 --- a/irksome/__init__.py +++ b/irksome/__init__.py @@ -8,7 +8,7 @@ from .ButcherTableaux import QinZhang # noqa: F401 from .ButcherTableaux import RadauIIA # noqa: F401 from .pep_explicit_rk import PEPRK # noqa: F401 -from .deriv import Dt # noqa: F401 +from .deriv import Dt, expand_time_derivatives # noqa: F401 from .dirk_imex_tableaux import DIRK_IMEX # noqa: F401 from .ars_dirk_imex_tableaux import ARS_DIRK_IMEX # noqa: F401 from .sspk_tableau import SSPK_DIRK_IMEX # noqa: F401 diff --git a/tests/test_differentiation.py b/tests/test_differentiation.py new file mode 100644 index 0000000..4ab4962 --- /dev/null +++ b/tests/test_differentiation.py @@ -0,0 +1,57 @@ +import pytest +from irksome import Dt, expand_time_derivatives +from firedrake import Constant, dot, FunctionSpace, Function, UnitIntervalMesh, VectorFunctionSpace + + +@pytest.fixture +def mesh(): + return UnitIntervalMesh(1) + + +@pytest.fixture(params=("scalar",)) +def V(request, mesh): + if request.param == "scalar": + return FunctionSpace(mesh, "DG", 0) + elif request.param == "vector": + return VectorFunctionSpace(mesh, "DG", 0) + + +def test_second_derivative(V): + u = Function(V) + assert Dt(u, 2) == Dt(Dt(u)) + + +def test_expand_sum(V): + u = Function(V) + w = Function(V) + k1 = Constant(1) + k2 = Constant(2) + expr = Dt(k1*u + k2*w) + + expr = expand_time_derivatives(expr) + expected = k1*Dt(u) + k2*Dt(w) + assert expr == expand_time_derivatives(expected) + + +def test_expand_product_rule(V): + u = Function(V) + w = Function(V) + expr = Dt(dot(u, w)) + + expr = expand_time_derivatives(expr) + expected = dot(u, Dt(w)) + dot(Dt(u), w) + assert expr == expand_time_derivatives(expected) + + +def test_expand_second_derivative_product_rule(V): + u = Function(V) + w = Function(V) + expr = Dt(Dt(dot(u, w))) + + expr = expand_time_derivatives(expr) + expected = (dot(Dt(u, 2), w) + + dot(Dt(u), Dt(w)) + + dot(Dt(u), Dt(w)) + + dot(u, Dt(w, 2))) + # UFL equality is failing here due to different index numbers + assert str(expr) == str(expected) From 80417741c1a32ecb0d0a293a7875b665ed5cec55 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 4 Mar 2025 10:35:06 -0600 Subject: [PATCH 5/6] Splat out non-autonomous terms --- irksome/deriv.py | 20 ++++++++++++-------- tests/test_differentiation.py | 15 +++++++++++++-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/irksome/deriv.py b/irksome/deriv.py index 6028a25..d7ca580 100644 --- a/irksome/deriv.py +++ b/irksome/deriv.py @@ -45,12 +45,15 @@ def Dt(f, order=1): class TimeDerivativeRuleset(GenericDerivativeRuleset): """Apply AD rules to time derivative expressions.""" - def __init__(self, timedep_coeffs=None): + def __init__(self, t=None, timedep_coeffs=None): GenericDerivativeRuleset.__init__(self, ()) + self.t = t self.timedep_coeffs = timedep_coeffs def coefficient(self, o): - if self.timedep_coeffs is None or o in self.timedep_coeffs: + if self.t is not None and o is self.t: + return 1.0 + elif self.timedep_coeffs is None or o in self.timedep_coeffs: return TimeDerivative(o) else: return self.independent_terminal(o) @@ -66,8 +69,9 @@ def time_derivative(self, o): # mapping rules to splat out time derivatives so that replacement should # work on more complex problems. class TimeDerivativeRuleDispatcher(MultiFunction): - def __init__(self, timedep_coeffs=None): + def __init__(self, t=None, timedep_coeffs=None): MultiFunction.__init__(self) + self.t = t self.timedep_coeffs = timedep_coeffs def terminal(self, o): @@ -83,18 +87,18 @@ def time_derivative(self, o, f): while isinstance(o, TimeDerivative): o, = o.ufl_operands nderivs += 1 - rules = TimeDerivativeRuleset(timedep_coeffs=self.timedep_coeffs) + rules = TimeDerivativeRuleset(t=self.t, timedep_coeffs=self.timedep_coeffs) for k in range(nderivs): o = map_expr_dag(rules, o) return o -def apply_time_derivatives(expression, timedep_coeffs=None): - rules = TimeDerivativeRuleDispatcher(timedep_coeffs=timedep_coeffs) +def apply_time_derivatives(expression, t=None, timedep_coeffs=None): + rules = TimeDerivativeRuleDispatcher(t=t, timedep_coeffs=timedep_coeffs) return map_integrand_dags(rules, expression) -def expand_time_derivatives(expression, timedep_coeffs=None): +def expand_time_derivatives(expression, t=None, timedep_coeffs=None): expression = apply_algebra_lowering(expression) - expression = apply_time_derivatives(expression) + expression = apply_time_derivatives(expression, t=t, timedep_coeffs=timedep_coeffs) return expression diff --git a/tests/test_differentiation.py b/tests/test_differentiation.py index 4ab4962..8ca4c50 100644 --- a/tests/test_differentiation.py +++ b/tests/test_differentiation.py @@ -1,6 +1,7 @@ import pytest -from irksome import Dt, expand_time_derivatives -from firedrake import Constant, dot, FunctionSpace, Function, UnitIntervalMesh, VectorFunctionSpace +from ufl.algorithms import expand_derivatives +from irksome import MeshConstant, Dt, expand_time_derivatives +from firedrake import Constant, diff, dot, FunctionSpace, Function, sin, UnitIntervalMesh, VectorFunctionSpace @pytest.fixture @@ -21,6 +22,16 @@ def test_second_derivative(V): assert Dt(u, 2) == Dt(Dt(u)) +def test_diff(mesh): + MC = MeshConstant(mesh) + t = MC.Constant(0.0) + q = sin(t**2) + expr = Dt(q) + expected = expand_derivatives(diff(q, t)) + expr = expand_time_derivatives(expr, t=t) + assert expr == expected + + def test_expand_sum(V): u = Function(V) w = Function(V) From 88d24b6610d30c1ad9e4f9f47687b09b828e54d1 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Tue, 4 Mar 2025 13:29:31 -0600 Subject: [PATCH 6/6] diff(x, t) -> Dt(x) --- irksome/deriv.py | 40 +++++++++++++++-------- irksome/dirk_stepper.py | 5 ++- irksome/discontinuous_galerkin_stepper.py | 4 +++ irksome/galerkin_stepper.py | 4 ++- irksome/imex.py | 9 ++++- irksome/stage_derivative.py | 11 ++++--- irksome/stage_value.py | 2 ++ tests/test_accuracy.py | 5 ++- tests/test_bern.py | 2 +- tests/test_curl.py | 3 +- tests/test_dirk.py | 11 +++---- tests/test_disc_galerkin.py | 7 ++-- tests/test_galerkin.py | 7 ++-- tests/test_imex.py | 2 +- 14 files changed, 70 insertions(+), 42 deletions(-) diff --git a/irksome/deriv.py b/irksome/deriv.py index d7ca580..6d876f6 100644 --- a/irksome/deriv.py +++ b/irksome/deriv.py @@ -1,3 +1,4 @@ +from ufl.constantvalue import as_ufl from ufl.differentiation import Derivative from ufl.core.ufl_type import ufl_type from ufl.corealg.multifunction import MultiFunction @@ -6,6 +7,7 @@ from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering from ufl.tensors import ListTensor from ufl.indexed import Indexed +from ufl.core.multiindex import FixedIndex @ufl_type(num_ops=1, @@ -33,7 +35,10 @@ def __str__(self): def _simplify_indexed(self, multiindex): """Return a simplified Expr used in the constructor of Indexed(self, multiindex).""" # Push Indexed inside TimeDerivative - return TimeDerivative(Indexed(self.ufl_operands[0], multiindex)) + if all(isinstance(i, FixedIndex) for i in multiindex): + f, = self.ufl_operands + return TimeDerivative(Indexed(f, multiindex)) + return Derivative._simplify_indexed(self, multiindex) def Dt(f, order=1): @@ -52,18 +57,24 @@ def __init__(self, t=None, timedep_coeffs=None): def coefficient(self, o): if self.t is not None and o is self.t: - return 1.0 + return as_ufl(1.0) elif self.timedep_coeffs is None or o in self.timedep_coeffs: return TimeDerivative(o) else: return self.independent_terminal(o) - def indexed(self, o, Ap, ii): + def spatial_coordinate(self, o): + return self.independent_terminal(o) + + def time_derivative(self, o, f): + return TimeDerivative(f) + + def _linear_op(self, o): return TimeDerivative(o) - def time_derivative(self, o): - f, = o.ufl_operands - return TimeDerivative(map_expr_dag(self, f)) + grad = _linear_op + curl = _linear_op + div = _linear_op # mapping rules to splat out time derivatives so that replacement should @@ -74,15 +85,9 @@ def __init__(self, t=None, timedep_coeffs=None): self.t = t self.timedep_coeffs = timedep_coeffs - def terminal(self, o): - return o - - def derivative(self, o): - raise NotImplementedError("Missing derivative handler for {0}.".format(type(o).__name__)) - expr = MultiFunction.reuse_if_untouched - def time_derivative(self, o, f): + def time_derivative(self, o): nderivs = 0 while isinstance(o, TimeDerivative): o, = o.ufl_operands @@ -92,6 +97,15 @@ def time_derivative(self, o, f): o = map_expr_dag(rules, o) return o + def _linear_op(self, o): + return o + + terminal = _linear_op + derivative = _linear_op + grad = _linear_op + curl = _linear_op + div = _linear_op + def apply_time_derivatives(expression, t=None, timedep_coeffs=None): rules = TimeDerivativeRuleDispatcher(t=t, timedep_coeffs=timedep_coeffs) diff --git a/irksome/dirk_stepper.py b/irksome/dirk_stepper.py index 4189f41..cab9f6f 100644 --- a/irksome/dirk_stepper.py +++ b/irksome/dirk_stepper.py @@ -4,7 +4,7 @@ from firedrake import NonlinearVariationalSolver as NLVS from ufl.constantvalue import as_ufl -from .deriv import TimeDerivative +from .deriv import TimeDerivative, expand_time_derivatives from .tools import component_replace, replace, MeshConstant, vecconst from .bcs import bc2space @@ -30,6 +30,9 @@ def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None): c = MC.Constant(1.0) a = MC.Constant(1.0) + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) + repl = {t: t + c * dt, u0: g + k * (a * dt), TimeDerivative(u0): k} diff --git a/irksome/discontinuous_galerkin_stepper.py b/irksome/discontinuous_galerkin_stepper.py index 75c54a2..931544a 100644 --- a/irksome/discontinuous_galerkin_stepper.py +++ b/irksome/discontinuous_galerkin_stepper.py @@ -5,6 +5,7 @@ from ufl.constantvalue import as_ufl from .base_time_stepper import StageCoupledTimeStepper from .bcs import stage2spaces4bc +from .deriv import expand_time_derivatives from .manipulation import extract_terms, strip_dt_form from .tools import component_replace, replace, vecconst import numpy as np @@ -79,6 +80,9 @@ def getFormDiscGalerkin(F, L, Q, t, dt, u0, stages, bcs=None): test_vals_w = vecconst(basis_vals_w) qpts = vecconst(qpts.reshape((-1,))) + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) + split_form = extract_terms(F) F_dtless = strip_dt_form(split_form.time) F_remainder = split_form.remainder diff --git a/irksome/galerkin_stepper.py b/irksome/galerkin_stepper.py index edfc56d..79ea42e 100644 --- a/irksome/galerkin_stepper.py +++ b/irksome/galerkin_stepper.py @@ -5,7 +5,7 @@ from ufl.constantvalue import as_ufl from .base_time_stepper import StageCoupledTimeStepper from .bcs import bc2space, stage2spaces4bc -from .deriv import TimeDerivative +from .deriv import TimeDerivative, expand_time_derivatives from .tools import component_replace, replace, vecconst import numpy as np from firedrake import TestFunction @@ -81,6 +81,8 @@ def getFormGalerkin(F, L_trial, L_test, Q, t, dt, u0, stages, bcs=None): dtu0sub = trial_dvals.T @ u_np dtu0 = TimeDerivative(u0) + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) # now loop over quadrature points Fnew = zero() diff --git a/irksome/imex.py b/irksome/imex.py index 1175987..37f121e 100644 --- a/irksome/imex.py +++ b/irksome/imex.py @@ -7,7 +7,7 @@ from ufl import zero from .ButcherTableaux import RadauIIA -from .deriv import TimeDerivative +from .deriv import TimeDerivative, expand_time_derivatives from .stage_value import getFormStage from .tools import AI, ConstantOrZero, IA, MeshConstant, replace, component_replace, getNullspace, get_stage_space from .bcs import bc2space @@ -59,6 +59,9 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None): Fit = zero() Fprop = zero() + # preprocess time derivatives + Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,)) + if splitting == AI: for i in range(num_stages): # replace test function @@ -292,6 +295,10 @@ def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None): if bcs is None: bcs = [] + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) + Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,)) + v = F.arguments()[0] V = v.function_space() msh = V.mesh() diff --git a/irksome/stage_derivative.py b/irksome/stage_derivative.py index b03f9ad..7c6adec 100644 --- a/irksome/stage_derivative.py +++ b/irksome/stage_derivative.py @@ -4,11 +4,9 @@ from firedrake import NonlinearVariationalSolver as NLVS from firedrake import assemble, dx, inner, norm -from ufl import diff, zero -from ufl.algorithms import expand_derivatives -from ufl.constantvalue import as_ufl +from ufl.constantvalue import as_ufl, zero from .tools import component_replace, replace, AI, vecconst -from .deriv import TimeDerivative # , apply_time_derivatives +from .deriv import Dt, TimeDerivative, expand_time_derivatives from .bcs import EmbeddedBCData, BCStageData, bc2space from .manipulation import extract_terms from .base_time_stepper import StageCoupledTimeStepper @@ -76,6 +74,8 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI): A1w = A1 @ w_np A2invw = A2inv @ w_np + # preprocess time derivatives + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) dtu = TimeDerivative(u0) Fnew = zero() for i in range(num_stages): @@ -92,7 +92,7 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI): def bc2gcur(bc, i): gorig = as_ufl(bc._original_arg) - gfoo = expand_derivatives(diff(gorig, t)) + gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,)) return replace(gfoo, {t: t + c[i] * dt}) elif bc_type == "DAE": @@ -291,6 +291,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0, self.err_old = 0.0 self.contreject = 0 + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) split_form = extract_terms(F) self.dtless_form = -split_form.remainder diff --git a/irksome/stage_value.py b/irksome/stage_value.py index 46665fd..be5f082 100644 --- a/irksome/stage_value.py +++ b/irksome/stage_value.py @@ -9,6 +9,7 @@ from .bcs import stage2spaces4bc from .ButcherTableaux import CollocationButcherTableau +from .deriv import expand_time_derivatives from .manipulation import extract_terms, strip_dt_form from .tools import AI, is_ode, replace, component_replace, vecconst from .base_time_stepper import StageCoupledTimeStepper @@ -103,6 +104,7 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo # assuming we have something of the form inner(Dt(g(u0)), v)*dx # For each stage i, this gets replaced with # inner((g(stages[i]) - g(u0))/dt, v)*dx + F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,)) split_form = extract_terms(F) F_dtless = strip_dt_form(split_form.time) F_remainder = split_form.remainder diff --git a/tests/test_accuracy.py b/tests/test_accuracy.py index f8a2e3e..d29a2b5 100644 --- a/tests/test_accuracy.py +++ b/tests/test_accuracy.py @@ -1,10 +1,9 @@ import numpy as np import pytest from firedrake import (DirichletBC, FunctionSpace, SpatialCoordinate, - TestFunction, UnitIntervalMesh, cos, diff, div, dx, + TestFunction, UnitIntervalMesh, cos, div, dx, errornorm, exp, grad, inner, norm, pi, project) from irksome import Dt, MeshConstant, RadauIIA, TimeStepper -from ufl.algorithms import expand_derivatives # test the accuracy of the 1d heat equation using CG elements @@ -26,7 +25,7 @@ def heat(n, deg, time_stages, **kwargs): dt = MC.Constant(2.0 / N) uexact = exp(-t) * cos(pi * x) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) butcher_tableau = RadauIIA(time_stages) diff --git a/tests/test_bern.py b/tests/test_bern.py index 907479c..0823833 100644 --- a/tests/test_bern.py +++ b/tests/test_bern.py @@ -46,7 +46,7 @@ def heat(n, deg, butcher_tableau, solver_parameters, bounds_type, **kwargs): dt = MC.Constant(2.0 / N) uexact = exp(-t) * cos(pi * x)**2 - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u = project(uexact, V) diff --git a/tests/test_curl.py b/tests/test_curl.py index bb79310..e2bf541 100644 --- a/tests/test_curl.py +++ b/tests/test_curl.py @@ -3,7 +3,6 @@ from firedrake.__future__ import interpolate from irksome import GaussLegendre, Dt, MeshConstant, TimeStepper from irksome.tools import AI, IA -from ufl.algorithms.ad import expand_derivatives def curlcross(a, b): @@ -25,7 +24,7 @@ def curltest(N, deg, butcher_tableau, splitting): x, y = SpatialCoordinate(msh) uexact = as_vector([t + 2*t*x + 4*t*y + 3*t*(y**2) + 2*t*x*y, 7*t + 5*t*x + 6*t*y - 3*t*x*y - 2*t*(x**2)]) - rhs = expand_derivatives(diff(uexact, t)) + curl(curl(uexact)) + rhs = Dt(uexact) + curl(curl(uexact)) u = assemble(interpolate(uexact, V)) diff --git a/tests/test_dirk.py b/tests/test_dirk.py index b045fd5..28b27b0 100644 --- a/tests/test_dirk.py +++ b/tests/test_dirk.py @@ -3,7 +3,6 @@ from firedrake import * from irksome import WSODIRK, Alexander, Dt, MeshConstant, TimeStepper from ufl import replace -from ufl.algorithms.ad import expand_derivatives wsodirks = [WSODIRK(*x) for x in ((4, 3, 2), (4, 3, 3))] @@ -35,7 +34,7 @@ def test_1d_heat_dirichletbc(butcher_tableau): + u_0 + ((x - x0) / x1) * (u_1 - u_0) ) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u = Function(V) u.interpolate(uexact) v = TestFunction(V) @@ -80,7 +79,7 @@ def test_1d_heat_neumannbc(butcher_tableau): (x,) = SpatialCoordinate(msh) uexact = cos(pi*x)*exp(-(pi**2)*t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u_dirk = Function(V) u = Function(V) u_dirk.interpolate(uexact) @@ -125,7 +124,7 @@ def test_1d_heat_homogdbc(butcher_tableau): (x,) = SpatialCoordinate(msh) uexact = sin(pi*x)*exp(-(pi**2)*t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u_dirk = Function(V) u = Function(V) u_dirk.interpolate(uexact) @@ -174,7 +173,7 @@ def test_1d_vectorheat_componentBC(butcher_tableau): uexact = as_vector([sin(pi*x/2)*exp(-(pi**2)*t/4), cos(pi*x/2)*exp(-(pi**2)*t/4)]) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u_dirk = Function(V) u = Function(V) u_dirk.interpolate(uexact) @@ -241,7 +240,7 @@ def test_stokes_bcs(butcher_tableau, bctype): uexact = as_vector([x*t + y**2, -y*t+t*(x**2)]) pexact = x + y * t - u_rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + grad(pexact) + u_rhs = Dt(uexact) - div(grad(uexact)) + grad(pexact) p_rhs = -div(uexact) z = Function(Z) diff --git a/tests/test_disc_galerkin.py b/tests/test_disc_galerkin.py index 80c45b3..9e2f586 100644 --- a/tests/test_disc_galerkin.py +++ b/tests/test_disc_galerkin.py @@ -4,7 +4,6 @@ from firedrake import * from irksome import Dt, MeshConstant, DiscontinuousGalerkinTimeStepper from irksome import TimeStepper, RadauIIA -from ufl.algorithms.ad import expand_derivatives import FIAT @@ -36,7 +35,7 @@ def test_1d_heat_dirichletbc(order, basis_type): + u_0 + ((x - x0) / x1) * (u_1 - u_0) ) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u = Function(V) u.interpolate(uexact) v = TestFunction(V) @@ -81,7 +80,7 @@ def test_1d_heat_neumannbc(order): butcher_tableau = RadauIIA(order+1) uexact = cos(pi*x)*exp(-(pi**2)*t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u_Radau = Function(V) u = Function(V) u_Radau.interpolate(uexact) @@ -130,7 +129,7 @@ def test_1d_heat_homogeneous_dirichletbc(order): butcher_tableau = RadauIIA(order+1) uexact = sin(pi*x)*exp(-(pi**2)*t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) bcs = DirichletBC(V, uexact, "on_boundary") u_Radau = Function(V) u = Function(V) diff --git a/tests/test_galerkin.py b/tests/test_galerkin.py index 0c4de9a..ba6d2c6 100644 --- a/tests/test_galerkin.py +++ b/tests/test_galerkin.py @@ -4,7 +4,6 @@ from firedrake import * from irksome import Dt, MeshConstant, GalerkinTimeStepper from irksome import TimeStepper, GaussLegendre -from ufl.algorithms.ad import expand_derivatives from FIAT import make_quadrature, ufc_simplex @@ -36,7 +35,7 @@ def test_1d_heat_dirichletbc(order, basis_type): + u_0 + ((x - x0) / x1) * (u_1 - u_0) ) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u = Function(V) u.interpolate(uexact) v = TestFunction(V) @@ -82,7 +81,7 @@ def test_1d_heat_neumannbc(order, num_quad_points): butcher_tableau = GaussLegendre(order) uexact = cos(pi*x)*exp(-(pi**2)*t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) u_GL = Function(V) u = Function(V) u_GL.interpolate(uexact) @@ -131,7 +130,7 @@ def test_1d_heat_homogeneous_dirichletbc(order): butcher_tableau = GaussLegendre(order) uexact = sin(pi*x)*exp(-(pi**2)*t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + rhs = Dt(uexact) - div(grad(uexact)) bcs = DirichletBC(V, uexact, "on_boundary") u_GL = Function(V) u = Function(V) diff --git a/tests/test_imex.py b/tests/test_imex.py index b749bed..5325d67 100644 --- a/tests/test_imex.py +++ b/tests/test_imex.py @@ -16,7 +16,7 @@ def convdiff_neumannbc(butcher_tableau, order, N): # Choose uexact so rhs is nonzero uexact = cos(pi*x)*exp(-t) - rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + uexact.dx(0) + rhs = Dt(uexact) - div(grad(uexact)) + uexact.dx(0) u = Function(V) u.interpolate(uexact)