Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhacements for TimeDerivative #128

Merged
merged 6 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .ButcherTableaux import QinZhang # noqa: F401
from .ButcherTableaux import RadauIIA # noqa: F401
from .pep_explicit_rk import PEPRK # noqa: F401
from .deriv import Dt # noqa: F401
from .deriv import Dt, expand_time_derivatives # noqa: F401
from .dirk_imex_tableaux import DIRK_IMEX # noqa: F401
from .ars_dirk_imex_tableaux import ARS_DIRK_IMEX # noqa: F401
from .sspk_tableau import SSPK_DIRK_IMEX # noqa: F401
Expand Down
95 changes: 57 additions & 38 deletions irksome/deriv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from ufl.constantvalue import as_ufl
from ufl.differentiation import Derivative
from ufl.core.ufl_type import ufl_type
from ufl.corealg.multifunction import MultiFunction
from ufl.algorithms.map_integrands import map_integrand_dags, map_expr_dag
from ufl.algorithms.apply_derivatives import GenericDerivativeRuleset
from ufl.algorithms.apply_algebra_lowering import apply_algebra_lowering
from ufl.tensors import ListTensor
from ufl.indexed import Indexed
from ufl.core.multiindex import FixedIndex


@ufl_type(num_ops=1,
Expand All @@ -18,6 +22,7 @@ class TimeDerivative(Derivative):

def __new__(cls, f):
if isinstance(f, ListTensor):
# Push TimeDerivative inside ListTensor
return ListTensor(*map(TimeDerivative, f.ufl_operands))
return Derivative.__new__(cls)

Expand All @@ -27,73 +32,87 @@ def __init__(self, f):
def __str__(self):
return "d{%s}/dt" % (self.ufl_operands[0],)

def _simplify_indexed(self, multiindex):
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
# Push Indexed inside TimeDerivative
if all(isinstance(i, FixedIndex) for i in multiindex):
f, = self.ufl_operands
return TimeDerivative(Indexed(f, multiindex))
return Derivative._simplify_indexed(self, multiindex)

def Dt(f):
"""Short-hand function to produce a :class:`TimeDerivative` of the
input."""
return TimeDerivative(f)

def Dt(f, order=1):
"""Short-hand function to produce a :class:`TimeDerivative` of a given order."""
for k in range(order):
f = TimeDerivative(f)
return f


class TimeDerivativeRuleset(GenericDerivativeRuleset):
"""Apply AD rules to time derivative expressions. WIP"""
def __init__(self, t, timedep_coeffs):
"""Apply AD rules to time derivative expressions."""
def __init__(self, t=None, timedep_coeffs=None):
GenericDerivativeRuleset.__init__(self, ())
self.t = t
self.timedep_coeffs = timedep_coeffs

def coefficient(self, o):
if o in self.timedep_coeffs:
if self.t is not None and o is self.t:
return as_ufl(1.0)
elif self.timedep_coeffs is None or o in self.timedep_coeffs:
return TimeDerivative(o)
else:
return self.independent_terminal(o)

# def indexed(self, o, Ap, ii):
# print(o, type(o))
# print(Ap, type(Ap))
# print(ii, type(ii))
# 1/0
def spatial_coordinate(self, o):
return self.independent_terminal(o)

def time_derivative(self, o, f):
return TimeDerivative(f)

def _linear_op(self, o):
return TimeDerivative(o)

grad = _linear_op
curl = _linear_op
div = _linear_op


# mapping rules to splat out time derivatives so that replacement should
# work on more complex problems.
class TimeDerivativeRuleDispatcher(MultiFunction):
def __init__(self, t, timedep_coeffs):
def __init__(self, t=None, timedep_coeffs=None):
MultiFunction.__init__(self)
self.t = t
self.timedep_coeffs = timedep_coeffs

def terminal(self, o):
return o

def derivative(self, o):
raise NotImplementedError("Missing derivative handler for {0}.".format(type(o).__name__))

expr = MultiFunction.reuse_if_untouched

def grad(self, o):
from firedrake import grad
if isinstance(o, TimeDerivative):
return TimeDerivative(grad(*o.ufl_operands))
return o

def div(self, o):
def time_derivative(self, o):
nderivs = 0
while isinstance(o, TimeDerivative):
o, = o.ufl_operands
nderivs += 1
rules = TimeDerivativeRuleset(t=self.t, timedep_coeffs=self.timedep_coeffs)
for k in range(nderivs):
o = map_expr_dag(rules, o)
return o

def reference_grad(self, o):
def _linear_op(self, o):
return o

def coefficient_derivative(self, o):
return o
terminal = _linear_op
derivative = _linear_op
grad = _linear_op
curl = _linear_op
div = _linear_op

def coordinate_derivative(self, o):
return o

def time_derivative(self, o):
f, = o.ufl_operands
rules = TimeDerivativeRuleset(self.t, self.timedep_coeffs)
return map_expr_dag(rules, f)
def apply_time_derivatives(expression, t=None, timedep_coeffs=None):
rules = TimeDerivativeRuleDispatcher(t=t, timedep_coeffs=timedep_coeffs)
return map_integrand_dags(rules, expression)


def apply_time_derivatives(expression, t, timedep_coeffs=[]):
rules = TimeDerivativeRuleDispatcher(t, timedep_coeffs)
return map_integrand_dags(rules, expression)
def expand_time_derivatives(expression, t=None, timedep_coeffs=None):
expression = apply_algebra_lowering(expression)
expression = apply_time_derivatives(expression, t=t, timedep_coeffs=timedep_coeffs)
return expression
5 changes: 4 additions & 1 deletion irksome/dirk_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from firedrake import NonlinearVariationalSolver as NLVS
from ufl.constantvalue import as_ufl

from .deriv import TimeDerivative
from .deriv import TimeDerivative, expand_time_derivatives
from .tools import component_replace, replace, MeshConstant, vecconst
from .bcs import bc2space

Expand All @@ -30,6 +30,9 @@ def getFormDIRK(F, ks, butch, t, dt, u0, bcs=None):
c = MC.Constant(1.0)
a = MC.Constant(1.0)

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))

repl = {t: t + c * dt,
u0: g + k * (a * dt),
TimeDerivative(u0): k}
Expand Down
4 changes: 4 additions & 0 deletions irksome/discontinuous_galerkin_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ufl.constantvalue import as_ufl
from .base_time_stepper import StageCoupledTimeStepper
from .bcs import stage2spaces4bc
from .deriv import expand_time_derivatives
from .manipulation import extract_terms, strip_dt_form
from .tools import component_replace, replace, vecconst
import numpy as np
Expand Down Expand Up @@ -79,6 +80,9 @@ def getFormDiscGalerkin(F, L, Q, t, dt, u0, stages, bcs=None):
test_vals_w = vecconst(basis_vals_w)
qpts = vecconst(qpts.reshape((-1,)))

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))

split_form = extract_terms(F)
F_dtless = strip_dt_form(split_form.time)
F_remainder = split_form.remainder
Expand Down
4 changes: 3 additions & 1 deletion irksome/galerkin_stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ufl.constantvalue import as_ufl
from .base_time_stepper import StageCoupledTimeStepper
from .bcs import bc2space, stage2spaces4bc
from .deriv import TimeDerivative
from .deriv import TimeDerivative, expand_time_derivatives
from .tools import component_replace, replace, vecconst
import numpy as np
from firedrake import TestFunction
Expand Down Expand Up @@ -81,6 +81,8 @@ def getFormGalerkin(F, L_trial, L_test, Q, t, dt, u0, stages, bcs=None):
dtu0sub = trial_dvals.T @ u_np

dtu0 = TimeDerivative(u0)
# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))

# now loop over quadrature points
Fnew = zero()
Expand Down
9 changes: 8 additions & 1 deletion irksome/imex.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ufl import zero

from .ButcherTableaux import RadauIIA
from .deriv import TimeDerivative
from .deriv import TimeDerivative, expand_time_derivatives
from .stage_value import getFormStage
from .tools import AI, ConstantOrZero, IA, MeshConstant, replace, component_replace, getNullspace, get_stage_space
from .bcs import bc2space
Expand Down Expand Up @@ -59,6 +59,9 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None):
Fit = zero()
Fprop = zero()

# preprocess time derivatives
Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,))

if splitting == AI:
for i in range(num_stages):
# replace test function
Expand Down Expand Up @@ -292,6 +295,10 @@ def getFormsDIRKIMEX(F, Fexp, ks, khats, butch, t, dt, u0, bcs=None):
if bcs is None:
bcs = []

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
Fexp = expand_time_derivatives(Fexp, t=t, timedep_coeffs=(u0,))

v = F.arguments()[0]
V = v.function_space()
msh = V.mesh()
Expand Down
11 changes: 6 additions & 5 deletions irksome/stage_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
from firedrake import NonlinearVariationalSolver as NLVS
from firedrake import assemble, dx, inner, norm

from ufl import diff, zero
from ufl.algorithms import expand_derivatives
from ufl.constantvalue import as_ufl
from ufl.constantvalue import as_ufl, zero
from .tools import component_replace, replace, AI, vecconst
from .deriv import TimeDerivative # , apply_time_derivatives
from .deriv import Dt, TimeDerivative, expand_time_derivatives
from .bcs import EmbeddedBCData, BCStageData, bc2space
from .manipulation import extract_terms
from .base_time_stepper import StageCoupledTimeStepper
Expand Down Expand Up @@ -76,6 +74,8 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI):
A1w = A1 @ w_np
A2invw = A2inv @ w_np

# preprocess time derivatives
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
dtu = TimeDerivative(u0)
Fnew = zero()
for i in range(num_stages):
Expand All @@ -92,7 +92,7 @@ def getForm(F, butch, t, dt, u0, stages, bcs=None, bc_type=None, splitting=AI):

def bc2gcur(bc, i):
gorig = as_ufl(bc._original_arg)
gfoo = expand_derivatives(diff(gorig, t))
gfoo = expand_time_derivatives(Dt(gorig), t=t, timedep_coeffs=(u0,))
return replace(gfoo, {t: t + c[i] * dt})

elif bc_type == "DAE":
Expand Down Expand Up @@ -291,6 +291,7 @@ def __init__(self, F, butcher_tableau, t, dt, u0,
self.err_old = 0.0
self.contreject = 0

F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
split_form = extract_terms(F)
self.dtless_form = -split_form.remainder

Expand Down
2 changes: 2 additions & 0 deletions irksome/stage_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .bcs import stage2spaces4bc
from .ButcherTableaux import CollocationButcherTableau
from .deriv import expand_time_derivatives
from .manipulation import extract_terms, strip_dt_form
from .tools import AI, is_ode, replace, component_replace, vecconst
from .base_time_stepper import StageCoupledTimeStepper
Expand Down Expand Up @@ -103,6 +104,7 @@ def getFormStage(F, butch, t, dt, u0, stages, bcs=None, splitting=None, vandermo
# assuming we have something of the form inner(Dt(g(u0)), v)*dx
# For each stage i, this gets replaced with
# inner((g(stages[i]) - g(u0))/dt, v)*dx
F = expand_time_derivatives(F, t=t, timedep_coeffs=(u0,))
split_form = extract_terms(F)
F_dtless = strip_dt_form(split_form.time)
F_remainder = split_form.remainder
Expand Down
12 changes: 2 additions & 10 deletions irksome/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,22 +97,14 @@ def replace(e, mapping):
return map_integrand_dags(MyReplacer(mapping2), e)


def get_component(expr, index):
if isinstance(expr, TimeDerivative):
expr, = expr.ufl_operands
return TimeDerivative(expr[index])
else:
return expr[index]


def component_replace(e, mapping):
# Replace, reccurring on components
"""Replace, recurring on components"""
cmapping = {}
for key, value in mapping.items():
cmapping[key] = as_tensor(value)
if key.ufl_shape:
for j in numpy.ndindex(key.ufl_shape):
cmapping[get_component(key, j)] = value[j]
cmapping[key[j]] = value[j]
return replace(e, cmapping)


Expand Down
5 changes: 2 additions & 3 deletions tests/test_accuracy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import pytest
from firedrake import (DirichletBC, FunctionSpace, SpatialCoordinate,
TestFunction, UnitIntervalMesh, cos, diff, div, dx,
TestFunction, UnitIntervalMesh, cos, div, dx,
errornorm, exp, grad, inner, norm, pi, project)
from irksome import Dt, MeshConstant, RadauIIA, TimeStepper
from ufl.algorithms import expand_derivatives


# test the accuracy of the 1d heat equation using CG elements
Expand All @@ -26,7 +25,7 @@ def heat(n, deg, time_stages, **kwargs):
dt = MC.Constant(2.0 / N)

uexact = exp(-t) * cos(pi * x)
rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact))
rhs = Dt(uexact) - div(grad(uexact))

butcher_tableau = RadauIIA(time_stages)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_bern.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def heat(n, deg, butcher_tableau, solver_parameters, bounds_type, **kwargs):
dt = MC.Constant(2.0 / N)

uexact = exp(-t) * cos(pi * x)**2
rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact))
rhs = Dt(uexact) - div(grad(uexact))

u = project(uexact, V)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_curl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from firedrake.__future__ import interpolate
from irksome import GaussLegendre, Dt, MeshConstant, TimeStepper
from irksome.tools import AI, IA
from ufl.algorithms.ad import expand_derivatives


def curlcross(a, b):
Expand All @@ -25,7 +24,7 @@ def curltest(N, deg, butcher_tableau, splitting):
x, y = SpatialCoordinate(msh)

uexact = as_vector([t + 2*t*x + 4*t*y + 3*t*(y**2) + 2*t*x*y, 7*t + 5*t*x + 6*t*y - 3*t*x*y - 2*t*(x**2)])
rhs = expand_derivatives(diff(uexact, t)) + curl(curl(uexact))
rhs = Dt(uexact) + curl(curl(uexact))

u = assemble(interpolate(uexact, V))

Expand Down
Loading