Skip to content

Commit

Permalink
Fix boundary conditions (#98)
Browse files Browse the repository at this point in the history
* Fix shapeless zero bc

* clean up gblah/gdat monstrosity

* Add BoundsConstrainedBC (used internally)



---------

Co-authored-by: Mingdong He <[email protected]>
Co-authored-by: Robert Kirby <[email protected]>
Co-authored-by: Pablo Brubeck <[email protected]>
  • Loading branch information
4 people authored Oct 21, 2024
1 parent 0a0fdc5 commit dc9f91c
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 287 deletions.
15 changes: 2 additions & 13 deletions demos/lowlevel/demo_lowlevel_homogbc.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Continuing::

Now, we use the :func:`.getForm` function, which processes the semidiscrete problem::

Fnew, k, bcnew, nspnew, bcdata = getForm(F, butcher_tableau, t, dt, u, bcs=bc)
Fnew, k, bcnew, nspnew = getForm(F, butcher_tableau, t, dt, u, bcs=bc)

This returns several things:

Expand All @@ -68,18 +68,7 @@ This returns several things:
be enforced on the variational problem for the stages
* ``nspnew`` is a new :class:`~firedrake.MixedVectorSpaceBasis` that
can be used to express the nullspace of `Fnew`
* ``bcdata`` contains information needed to update the boundary
conditions. It is a list of pairs of the form (``f``, ``expr``), where
``f`` is a :class:`~firedrake.function.Function` and ``expr`` is an
:class:`~ufl.core.expr.Expr` for each of the Dirichlet boundary conditions.
Because Firedrake isn't smart enough to detect that `t` changes in
the expression for the boundary condition, we need to manually
interpolate or project each :py:class:`~ufl.core.expr.Expr` onto the corresponding ``f`` at the
beginning of each time step. Firedrake will notice this change and
re-apply the boundary conditions. This hassle is easy to overlook
(not needed in this demo with homogeneous BC) and part of the reason
we recommend using the :class:`.TimeStepper` interface that does this
for you.


Solver parameters are just blunt-force LU. Other options are surely possible::

Expand Down
13 changes: 1 addition & 12 deletions demos/lowlevel/demo_lowlevel_inhomogbc.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,14 @@ Imports::
As with the homogeneous BC case, we use the `getForm` method to
process the semidiscrete problem::

Fnew, k, bcnew, nspnew, bcdata = getForm(F, butcher_tableau, t, dt, u, bcs=bc)
Fnew, k, bcnew, nspnew = getForm(F, butcher_tableau, t, dt, u, bcs=bc)

Recall that `getForm` produces:

* ``Fnew`` is the UFL variational form for the fully discrete method.
* ``k`` is a new :class:`~firedrake.function.Function` of stages on the s-way product of the space on which the problem was originally posed
* ``bcnew`` is a list of new :class:`~firedrake.bcs.DirichletBC` that need to
be enforced on the variational problem for the stages
* ``bcdata`` contains information needed to update the boundary
conditions. It is a list of triples of the form
(``f``,``expr``,``method``), where ``f`` is a
:class:`~firedrake.function.Function`, ``expr`` is an
:class:`~ufl.core.expr.Expr`, and ``method`` is either a project or
interpolate operation for each of the Dirichlet boundary conditions.
You're using the low-level interface and have to force Firedrake to
reapply the boundary conditions.


We just use basic solver parameters and set up the variational problem
Expand All @@ -77,9 +69,6 @@ boundary conditions at each time step::
if float(t) + float(dt) > 1.0:
dt.assign(1.0 - float(t))

for (gdat, gcur, gmethod) in bcdata:
gmethod(gcur, u)

solver.solve()

for i in range(butcher_tableau.num_stages):
Expand Down
2 changes: 1 addition & 1 deletion demos/lowlevel/demo_lowlevel_mixed_heat.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Build the mesh and approximating spaces::

Because we aren't concerned with any strongly-enforced boundary conditions, we drop that information in calling `get_form`::

Fnew, k, _, _, _ = getForm(F, butcher_tableau, t, dt, sigu)
Fnew, k, _, _ = getForm(F, butcher_tableau, t, dt, sigu)

We set up the variational problem and solver using a sparse direct method::

Expand Down
106 changes: 106 additions & 0 deletions irksome/bcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from functools import partial
from firedrake import (DirichletBC, Function, TestFunction,
NonlinearVariationalProblem,
NonlinearVariationalSolver,
replace, inner, dx)


def get_sub(u, indices):
for i in indices:
u = u.sub(i)
return u


def bc2space(bc, V):
return get_sub(V, bc._indices)


def stage2spaces4bc(bc, V, Vbig, i):
"""used to figure out how to apply Dirichlet BC to each stage"""
num_fields = len(V)
sub = 0 if num_fields == 1 else bc.function_space_index()
comp = bc.function_space().component

Vbigi = Vbig[sub+num_fields*i]
if comp is not None: # check for sub-piece of vector-valued
Vbigi = Vbigi.sub(comp)

return Vbigi


def BCStageData(V, gcur, u0, u0_mult, i, t, dt):
if V.component is None: # V is not a bit of a VFS
if V.index is None: # not part of MFS, either
indices = ()
else: # part of MFS
indices = (V.index,)
else: # bottommost space is bit of VFS
if V.parent.index is None: # but not part of a MFS
indices = (V.component,)
else: # V is a bit of a VFS inside an MFS
indices = (V.parent.index, V.component)

if gcur == 0: # special case DirichletBC(V, 0, ...), do nothing
gdat = gcur
else:
gdat = gcur - u0_mult[i] * get_sub(u0, indices)
return gdat


def EmbeddedBCData(bc, t, dt, num_fields, butcher_tableau, ws, u0):
gorig = bc._original_arg
if gorig == 0: # special case DirichletBC(V, 0, ...), do nothing
gdat = gorig
else:
gcur = replace(gorig, {t: t+dt})
sub = 0 if num_fields == 1 else bc.function_space_index()
comp = bc.function_space().component
num_stages = butcher_tableau.num_stages
btilde = butcher_tableau.btilde
if comp is None: # check for sub-piece of vector-valued
for j in range(num_stages):
gcur -= dt*btilde[j]*ws[num_fields*j+sub]
else:
for j in range(num_stages):
gcur -= dt*btilde[j]*ws[num_fields*j+sub].sub(comp)

gdat = gcur - bc2space(bc, u0)
return gdat


class BoundsConstrainedBC(DirichletBC):
"""A DirichletBC with bounds-constrained data."""
def __init__(self, V, g, sub_domain, bounds, solver_parameters=None):
super().__init__(V, g, sub_domain)
if solver_parameters is None:
solver_parameters = {
"snes_type": "vinewtonssls",
}
self.solver_parameters = solver_parameters
self.bounds = bounds

@property
def function_arg(self):
'''The value of this boundary condition.'''
if hasattr(self, "_function_arg_update"):
self._function_arg_update()
return self._function_arg

@function_arg.setter
def function_arg(self, g):
'''Set the value of this boundary condition.'''
V = self.function_space()
gnew = Function(V)
try:
# Use the interpolant as initial guess
gnew.interpolate(g)
except (NotImplementedError, AttributeError):
pass
F = inner(TestFunction(V), gnew - g) * dx
problem = NonlinearVariationalProblem(F, gnew)
solver = NonlinearVariationalSolver(problem,
solver_parameters=self.solver_parameters)

self._function_arg = gnew
self.function_arg_update = partial(solver.solve, bounds=self.bounds)
self.function_arg_update()
69 changes: 8 additions & 61 deletions irksome/getForm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,14 @@
from operator import mul

import numpy
from firedrake import (DirichletBC, Function, TestFunction,
assemble, project, split)
from firedrake.__future__ import interpolate
from firedrake import Function, TestFunction, split
from ufl import diff
from ufl.algorithms import expand_derivatives
from ufl.classes import Zero
from ufl.constantvalue import as_ufl
from .tools import ConstantOrZero, MeshConstant, replace, getNullspace, AI, stage2spaces4bc
from .tools import ConstantOrZero, MeshConstant, replace, getNullspace, AI
from .deriv import TimeDerivative # , apply_time_derivatives


class BCStageData(object):
def __init__(self, V, gcur, u0, u0_mult, i, t, dt):
if V.component is not None: # bottommost space is bit of VFS
if V.parent.index is None: # but not part of a MFS
sub = V.component
try:
gdat = assemble(interpolate(gcur-u0_mult[i]*u0.sub(sub), V))
gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u.sub(sub))
except: # noqa: E722
gdat = project(gcur-u0_mult[i]*u0.sub(sub), V)
gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u.sub(sub))
else: # V is a bit of a VFS inside an MFS
sub0 = V.parent.index
sub1 = V.component
try:
gdat = assemble(interpolate(gcur-u0_mult[i]*u0.sub(sub0).sub(sub1), V))
gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u.sub(sub0).sub(sub1))
except: # noqa: E722
gdat = project(gcur-u0_mult[i]*u0.sub(sub0).sub(sub1), V)
gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u.sub(sub0).sub(sub1))
else: # V is not a bit of a VFS
if V.index is None: # not part of MFS, either
try:
gdat = assemble(interpolate(gcur-u0_mult[i]*u0, V))
gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u)
except: # noqa: E722
gdat = project(gcur-u0_mult[i]*u0, V)
gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u)
else: # part of MFS
sub = V.index
try:
gdat = assemble(interpolate(gcur-u0_mult[i]*u0.sub(sub), V))
gmethod = lambda g, u: gdat.interpolate(g-u0_mult[i]*u.sub(sub))
except: # noqa: E722
gdat = project(gcur-u0_mult[i]*u0.sub(sub), V)
gmethod = lambda g, u: gdat.project(g-u0_mult[i]*u.sub(sub))

self.gstuff = (gdat, gcur, gmethod)
from .bcs import BCStageData, bc2space, stage2spaces4bc


def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI,
Expand Down Expand Up @@ -99,16 +58,6 @@ def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI,
on the stages,
- 'nspnew', the :class:`firedrake.MixedVectorSpaceBasis` object
that represents the nullspace of the coupled system
- `gblah`, a list of tuples of the form (f, expr, method),
where f is a :class:`firedrake.Function` and expr is a
:class:`ufl.Expr`. At each time step, each expr needs to be
re-interpolated/projected onto the corresponding f in order
for Firedrake to pick up that time-dependent boundary
conditions need to be re-applied. The
interpolation/projection is encoded in method, which is
either `f.interpolate(expr-c*u0)` or `f.project(expr-c*u0)`, depending
on whether the function space for f supports interpolation or
not.
"""
if bc_type is None:
bc_type = "DAE"
Expand Down Expand Up @@ -191,7 +140,6 @@ def getForm(F, butch, t, dt, u0, bcs=None, bc_type=None, splitting=AI,
Fnew += replace(F, repl)

bcnew = []
gblah = []

if bcs is None:
bcs = []
Expand Down Expand Up @@ -227,13 +175,12 @@ def bc2gcur(bc, i):
# set up the new BCs for either method
for bc in bcs:
for i in range(num_stages):
Vsp, Vbigi = stage2spaces4bc(bc, V, Vbig, i)
Vsp = bc2space(bc, V)
Vbigi = stage2spaces4bc(bc, V, Vbig, i)
gcur = bc2gcur(bc, i)
blah = BCStageData(Vsp, gcur, u0, u0_mult, i, t, dt)
gdat, gcr, gmethod = blah.gstuff
gblah.append((gdat, gcr, gmethod))
bcnew.append(DirichletBC(Vbigi, gdat, bc.sub_domain))
gdat = BCStageData(Vsp, gcur, u0, u0_mult, i, t, dt)
bcnew.append(bc.reconstruct(V=Vbigi, g=gdat))

nspnew = getNullspace(V, Vbig, butch, nullspace)

return Fnew, w, bcnew, nspnew, gblah
return Fnew, w, bcnew, nspnew
5 changes: 1 addition & 4 deletions irksome/imex.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,14 @@ def __init__(self, F, Fexp, butcher_tableau,

# Since this assumes stiff accuracy, we drop
# the update information on the floor.
Fbig, _, UU, bigBCs, gblah, nsp = getFormStage(
Fbig, _, UU, bigBCs, nsp = getFormStage(
F, butcher_tableau, u0, t, dt, bcs,
splitting=splitting, nullspace=nullspace)

self.UU = UU
self.UU_old = UU_old = Function(UU.function_space())
self.UU_old_split = UU_old.subfunctions
self.bigBCs = bigBCs
self.bcdat = gblah

Fit, Fprop = getFormExplicit(
Fexp, butcher_tableau, u0, UU_old, t, dt, splitting)
Expand Down Expand Up @@ -283,8 +282,6 @@ def propagate(self):
for i, u0bit in enumerate(u0split):
u0bit.assign(self.UU_old_split[(ns-1)*nf + i])

for gdat, gcur, gmethod in self.bcdat:
gmethod(gdat, gcur)
push_parent(self.u0.function_space().dm, self.UU.function_space().dm)

ps = self.prop_solver
Expand Down
8 changes: 4 additions & 4 deletions irksome/pc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def form(self, pc, test, trial):
# which getForm do I need to get?

if stage_type in ("deriv", None):
Fnew, w, bcnew, bignsp, _ = \
Fnew, w, bcnew, bignsp = \
getForm(F, butcher_new, t, dt, u0, bcs,
bc_type, splitting, nullspace)
elif stage_type == "value":
Fnew, _, w, bcnew, _, bignsp = \
Fnew, _, w, bcnew, bignsp = \
getFormStage(F, butcher_new, u0, t, dt, bcs,
splitting, nullspace)
# Now we get the Jacobian for the modified system,
Expand Down Expand Up @@ -125,11 +125,11 @@ def form(self, pc, test, trial):
# which getForm do I need to get?

if stage_type in ("deriv", None):
Fnew, w, bcnew, bignsp, _ = \
Fnew, w, bcnew, bignsp = \
getForm(F, butcher_tableau, t, dt, u0, bcs,
bc_type, splitting, nullspace)
elif stage_type == "value":
Fnew, _, w, bcnew, _, bignsp = \
Fnew, _, w, bcnew, bignsp = \
getFormStage(F, butcher_tableau, u0, t, dt, bcs,
splitting, nullspace)
# Now we get the Jacobian for the modified system,
Expand Down
Loading

0 comments on commit dc9f91c

Please sign in to comment.