Skip to content

Commit

Permalink
Removing Constant domains everywhere I can find them
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottMacLachlan committed Jan 26, 2024
1 parent 05e1220 commit 0e6650e
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 26 deletions.
4 changes: 2 additions & 2 deletions demos/demo_nitsche_heat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
V = FunctionSpace(msh, "CG", 1)
x, y = SpatialCoordinate(msh)

MC = MeshConstnt(msh)
MC = MeshConstant(msh)
dt = MC.Constant(10.0 / N)
t = MC.Constant(0.0)

Expand All @@ -28,7 +28,7 @@
# define the variational form once outside the loop
h = CellSize(msh)
n = FacetNormal(msh)
beta = Constant(100.0, domain=msh)
beta = Constant(100.0)
v = TestFunction(V)
F = (inner(Dt(u), v)*dx + inner(grad(u), grad(v))*dx - inner(rhs, v) * dx
- inner(dot(grad(u), n), v) * ds
Expand Down
9 changes: 5 additions & 4 deletions irksome/dirk_stepper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy
from firedrake import Constant, DirichletBC, Function
from firedrake import DirichletBC, Function
from firedrake import NonlinearVariationalProblem as NLVP
from firedrake import NonlinearVariationalSolver as NLVS
from firedrake import interpolate, split
from ufl.constantvalue import as_ufl

from .deriv import TimeDerivative
from .tools import replace
from .tools import replace, MeshConstant


class BCThingy:
Expand Down Expand Up @@ -89,8 +89,9 @@ def getFormDIRK(F, butch, t, dt, u0, bcs=None):
# variational form and BC's, and we update it for each stage in
# the loop over stages in the advance method. The Constant a is
# used similarly in the variational form
c = Constant(1.0, domain=msh)
a = Constant(1.0, domain=msh)
MC = MeshConstant(msh)
c = MC.Constant(1.0)
a = MC.Constant(1.0)

repl = {t: t+c*dt}
for u0bit, kbit, gbit in zip(u0bits, k_bits, gbits):
Expand Down
19 changes: 10 additions & 9 deletions irksome/getForm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ufl.algorithms import expand_derivatives
from ufl.classes import Zero
from ufl.constantvalue import as_ufl
from .tools import replace, getNullspace, AI
from .tools import MeshConstant, replace, getNullspace, AI
from .deriv import TimeDerivative # , apply_time_derivatives


Expand Down Expand Up @@ -52,8 +52,8 @@ def __init__(self, V, gcur, u0, u0_mult, i, t, dt):
self.gstuff = (gdat, gcur, gmethod)


def ConstantOrZero(x, msh):
return Zero() if abs(complex(x)) < 1.e-10 else Constant(x, domain=msh)
def ConstantOrZero(x, MC):
return Zero() if abs(complex(x)) < 1.e-10 else MC.Constant(x)


def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI,
Expand Down Expand Up @@ -124,21 +124,22 @@ def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI,

bA1, bA2 = splitting(butch.A)

MC = MeshConstant(msh)
try:
bA1inv = numpy.linalg.inv(bA1)
except numpy.linalg.LinAlgError:
bA1inv = None
try:
bA2inv = numpy.linalg.inv(bA2)
A2inv = numpy.array([[ConstantOrZero(aa, msh) for aa in arow] for arow in bA2inv],
A2inv = numpy.array([[ConstantOrZero(aa, MC) for aa in arow] for arow in bA2inv],
dtype=object)
except numpy.linalg.LinAlgError:
raise NotImplementedError("We require A = A1 A2 with A2 invertible")

A1 = numpy.array([[ConstantOrZero(aa, msh) for aa in arow] for arow in bA1],
A1 = numpy.array([[ConstantOrZero(aa, MC) for aa in arow] for arow in bA1],
dtype=object)
if bA1inv is not None:
A1inv = numpy.array([[ConstantOrZero(aa, msh) for aa in arow] for arow in bA1inv],
A1inv = numpy.array([[ConstantOrZero(aa, MC) for aa in arow] for arow in bA1inv],
dtype=object)
else:
A1inv = None
Expand Down Expand Up @@ -198,7 +199,7 @@ def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI,
if bc_type == "ODE":
assert splitting == AI, "ODE-type BC aren't implemented for this splitting strategy"
u0_mult_np = numpy.divide(1.0, butch.c, out=numpy.zeros_like(butch.c), where=butch.c != 0)
u0_mult = numpy.array([ConstantOrZero(mi, msh)/dt for mi in u0_mult_np],
u0_mult = numpy.array([ConstantOrZero(mi, MC)/dt for mi in u0_mult_np],
dtype=object)

def bc2gcur(bc, i):
Expand All @@ -211,14 +212,14 @@ def bc2gcur(bc, i):
raise NotImplementedError("Cannot have DAE BCs for this Butcher Tableau/splitting")

u0_mult_np = A1inv @ numpy.ones_like(butch.c)
u0_mult = numpy.array([ConstantOrZero(mi, msh)/dt for mi in u0_mult_np],
u0_mult = numpy.array([ConstantOrZero(mi, MC)/dt for mi in u0_mult_np],
dtype=object)

def bc2gcur(bc, i):
gorig = as_ufl(bc._original_arg)
gcur = 0
for j in range(num_stages):
gcur += ConstantOrZero(bA1inv[i, j], msh) / dt * replace(gorig, {t: t + c[j]*dt})
gcur += ConstantOrZero(bA1inv[i, j], MC) / dt * replace(gorig, {t: t + c[j]*dt})
return gcur
else:
raise ValueError("Unrecognised bc_type: %s", bc_type)
Expand Down
7 changes: 4 additions & 3 deletions irksome/imex.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import FIAT
import numpy as np
from firedrake import (Constant, Function, NonlinearVariationalProblem,
from firedrake import (Function, NonlinearVariationalProblem,
NonlinearVariationalSolver, TestFunction)
from firedrake.dmhooks import pop_parent, push_parent
from ufl.classes import Zero

from .ButcherTableaux import RadauIIA
from .stage import getBits, getFormStage
from .tools import AI, IA, replace
from .tools import AI, IA, MeshConstant, replace


def riia_explicit_coeffs(k):
Expand Down Expand Up @@ -45,7 +45,8 @@ def getFormExplicit(Fexp, butch, u0, UU, t, dt, splitting=None):

num_stages = butch.num_stages
num_fields = len(V)
vc = np.vectorize(lambda c: Constant(c, domain=msh))
MC = MeshConstant(msh)
vc = np.vectorize(lambda c: MC.Constant(c))
Aexp = riia_explicit_coeffs(num_stages)
Aprop = vc(Aexp)
Ait = vc(butch.A)
Expand Down
13 changes: 7 additions & 6 deletions irksome/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from operator import mul

import numpy as np
from firedrake import (Constant, DirichletBC, Function,
from firedrake import (DirichletBC, Function,
NonlinearVariationalProblem, NonlinearVariationalSolver,
TestFunction, dx, inner, interpolate, project, split)
from numpy import vectorize
from ufl.classes import Zero
from ufl.constantvalue import as_ufl

from .manipulation import extract_terms, strip_dt_form
from .tools import AI, IA, getNullspace, is_ode, replace
from .tools import AI, IA, getNullspace, MeshConstant, is_ode, replace


def getBits(num_stages, num_fields, u0, UU, v, VV):
Expand Down Expand Up @@ -116,7 +116,8 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None,
u0bits, vbits, VVbits, UUbits = getBits(num_stages, num_fields,
u0, UU, v, VV)

vecconst = np.vectorize(lambda c: Constant(c, domain=V.mesh()))
MC = MeshConstant(V.mesh())
vecconst = np.vectorize(lambda c: MC.Constant(c))
C = vecconst(butch.c)
A = vecconst(butch.A)

Expand Down Expand Up @@ -172,7 +173,7 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None,
Fnew += A[i, j] * dt * replace(Ftmp, repl)

elif splitting == IA:
Ainv = np.vectorize(lambda c: Constant(c, domain=V.mesh()))(np.linalg.inv(butch.A))
Ainv = np.vectorize(lambda c: MC.Constant(c))(np.linalg.inv(butch.A))

# time derivative part gets inverse of Butcher matrix.
for i in range(num_stages):
Expand Down Expand Up @@ -257,8 +258,8 @@ def getFormStage(F, butch, u0, t, dt, bcs=None, splitting=None,
unew = Function(V)

Fupdate = inner(unew - u0, v) * dx
B = vectorize(lambda c: Constant(c, domain=V.mesh()))(butch.b)
C = vectorize(lambda c: Constant(c, domain=V.mesh()))(butch.c)
B = vectorize(lambda c: MC.Constant(c))(butch.b)
C = vectorize(lambda c: MC.Constant(c))(butch.c)

for i in range(num_stages):
repl = {t: t + C[i] * dt}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from firedrake import (diff, div, dx, errornorm, exp, grad,
inner, norm, pi, project, sin,
Constant, DirichletBC, FunctionSpace,
DirichletBC, FunctionSpace,
SpatialCoordinate, TestFunction, UnitIntervalMesh)

from irksome import Dt, MeshConstant, TimeStepper, GaussLegendre
Expand Down Expand Up @@ -40,7 +40,7 @@ def heat(n, deg, time_stages, stage_type="deriv", splitting=IA):
F = (inner(Dt(u), v) * dx + inner(grad(u), grad(v)) * dx
- inner(rhs, v) * dx)

bc = DirichletBC(V, Constant(0, domain=msh), "on_boundary")
bc = DirichletBC(V, MC.Constant(0), "on_boundary")

stepper = TimeStepper(F, butcher_tableau, t, dt, u,
bcs=bc, solver_parameters=params,
Expand Down

0 comments on commit 0e6650e

Please sign in to comment.