Skip to content

Commit

Permalink
PR #471: from firedrakeproject/wrapper_of_options
Browse files Browse the repository at this point in the history
Implement a wrapper for a mixed function space, that can take subwrappers for individual components
  • Loading branch information
tommbendall authored Mar 1, 2024
2 parents 6ac3f4b + 66ce750 commit 6d76ca7
Show file tree
Hide file tree
Showing 5 changed files with 488 additions and 133 deletions.
11 changes: 10 additions & 1 deletion gusto/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__all__ = [
"IntegrateByParts", "TransportEquationType", "OutputParameters",
"CompressibleParameters", "ShallowWaterParameters",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions", "MixedFSOptions",
"SpongeLayerParameters", "DiffusionParameters", "BoundaryLayerParameters"
]

Expand Down Expand Up @@ -172,6 +172,15 @@ class SUPGOptions(WrapperOptions):
ibp = IntegrateByParts.TWICE


class MixedFSOptions(WrapperOptions):
"""Specifies options for a mixed finite element formulation
where different suboptions are applied to different
prognostic variables."""

name = "mixed_options"
suboptions = {}


class SpongeLayerParameters(Configuration):
"""Specifies parameters describing a 'sponge' (damping) layer."""

Expand Down
76 changes: 60 additions & 16 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from firedrake import (
Function, TestFunction, NonlinearVariationalProblem,
Function, TestFunction, TestFunctions, NonlinearVariationalProblem,
NonlinearVariationalSolver, DirichletBC, split, Constant
)
from firedrake.fml import (
Expand Down Expand Up @@ -88,7 +88,21 @@ def __init__(self, domain, field_name=None, solver_parameters=None,

if options is not None:
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
if self.wrapper_name == "mixed_options":
self.wrapper = MixedFSWrapper()

for field, suboption in options.suboptions.items():
if suboption.name == 'embedded_dg':
self.wrapper.subwrappers.update({field: EmbeddedDGWrapper(self, suboption)})
elif suboption.name == "recovered":
self.wrapper.subwrappers.update({field: RecoveryWrapper(self, suboption)})
elif suboption.name == "supg":
raise RuntimeError(
'Time discretisation: suboption SUPG is currently not implemented within MixedOptions')
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')
elif self.wrapper_name == "embedded_dg":
self.wrapper = EmbeddedDGWrapper(self, options)
elif self.wrapper_name == "recovered":
self.wrapper = RecoveryWrapper(self, options)
Expand Down Expand Up @@ -159,21 +173,51 @@ def setup(self, equation, apply_bcs=True, *active_labels):
# -------------------------------------------------------------------- #

if self.wrapper is not None:
self.wrapper.setup()
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))
if self.wrapper_name == "mixed_options":

self.wrapper.wrapper_spaces = equation.spaces
self.wrapper.field_names = equation.field_names

for field, subwrapper in self.wrapper.subwrappers.items():

if field not in equation.field_names:
raise ValueError(f"The option defined for {field} is for a field that does not exist in the equation set")

field_idx = equation.field_names.index(field)
subwrapper.setup(equation.spaces[field_idx])

self.residual = self.wrapper.label_terms(self.residual)
# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space

self.wrapper.setup()
self.fs = self.wrapper.function_space
new_test_mixed = TestFunctions(self.fs)

# Replace the original test function with one from the new
# function space defined by the subwrappers
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test_mixed))

else:
if self.wrapper_name == "supg":
self.wrapper.setup()
else:
self.wrapper.setup(self.fs)
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))

self.residual = self.wrapper.label_terms(self.residual)

# -------------------------------------------------------------------- #
# Make boundary conditions
Expand Down
103 changes: 87 additions & 16 deletions gusto/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from abc import ABCMeta, abstractmethod
from firedrake import (
FunctionSpace, Function, BrokenElement, Projector, Interpolator,
VectorElement, Constant, as_ufl, dot, grad, TestFunction
VectorElement, Constant, as_ufl, dot, grad, TestFunction, MixedFunctionSpace
)
from firedrake.fml import Term
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions
from gusto.recovery import Recoverer, ReversibleRecoverer
from gusto.labels import transporting_velocity
import ufl

__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper"]
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedFSWrapper"]


class Wrapper(object, metaclass=ABCMeta):
Expand All @@ -33,14 +33,23 @@ def __init__(self, time_discretisation, wrapper_options):
self.time_discretisation = time_discretisation
self.options = wrapper_options
self.solver_parameters = None
self.original_space = None

@abstractmethod
def setup(self):
def setup(self, original_space):
"""
Performs standard set up routines, and is to be called by the setup
method of the underlying time discretisation.
Store the original function space of the prognostic variable.
Within each child wrapper, setup performs standard set up routines,
and is to be called by the setup method of the underlying
time discretisation.
Args:
original_space (:class:`FunctionSpace`): the space that the
prognostic variable is defined on. This is a subset space of
a mixed function space when using a MixedFSWrapper.
"""
pass
self.original_space = original_space

@abstractmethod
def pre_apply(self):
Expand Down Expand Up @@ -76,13 +85,14 @@ class EmbeddedDGWrapper(Wrapper):
the original space.
"""

def setup(self):
def setup(self, original_space):
"""Sets up function spaces and fields needed for this wrapper."""

assert isinstance(self.options, EmbeddedDGOptions), \
'Embedded DG wrapper can only be used with Embedded DG Options'

original_space = self.time_discretisation.fs
super().setup(original_space)

domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

Expand All @@ -91,7 +101,7 @@ def setup(self):
# -------------------------------------------------------------------- #

if self.options.embedding_space is None:
V_elt = BrokenElement(original_space.ufl_element())
V_elt = BrokenElement(self.original_space.ufl_element())
self.function_space = FunctionSpace(domain.mesh, V_elt)
else:
self.function_space = self.options.embedding_space
Expand All @@ -104,8 +114,9 @@ def setup(self):

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

if self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
self.x_projected = Function(self.original_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])

Expand Down Expand Up @@ -158,13 +169,14 @@ class RecoveryWrapper(Wrapper):
field is then returned to the original space.
"""

def setup(self):
def setup(self, original_space):
"""Sets up function spaces and fields needed for this wrapper."""

assert isinstance(self.options, RecoveryOptions), \
'Embedded DG wrapper can only be used with Recovery Options'
'Recovery wrapper can only be used with Recovery Options'

super().setup(original_space)

original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

Expand All @@ -173,7 +185,7 @@ def setup(self):
# -------------------------------------------------------------------- #

if self.options.embedding_space is None:
V_elt = BrokenElement(original_space.ufl_element())
V_elt = BrokenElement(self.original_space.ufl_element())
self.function_space = FunctionSpace(domain.mesh, V_elt)
else:
self.function_space = self.options.embedding_space
Expand All @@ -184,11 +196,12 @@ def setup(self):
# Internal variables to be used
# -------------------------------------------------------------------- #

self.x_in_tmp = Function(self.time_discretisation.fs)
self.x_in_tmp = Function(self.original_space)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

if self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
self.x_projected = Function(self.original_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])

Expand Down Expand Up @@ -361,3 +374,61 @@ def label_terms(self, residual):
new_residual = transporting_velocity.update_value(new_residual, self.transporting_velocity)

return new_residual


class MixedFSWrapper(object):
"""
An object to hold a subwrapper dictionary with different wrappers for
different tracers. This means that different tracers can be solved
simultaneously using a CoupledTransportEquation, whilst being in
different spaces and needing different implementation options.
"""

def __init__(self):

self.wrapper_spaces = None
self.field_names = None
self.subwrappers = {}

def setup(self):
""" Compute the new mixed function space from the subwrappers """

self.function_space = MixedFunctionSpace(self.wrapper_spaces)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

def pre_apply(self, x_in):
"""
Perform the pre-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = x_in.subfunctions[field_idx]
x_in_sub = self.x_in.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.pre_apply(field)
x_in_sub.assign(subwrapper.x_in)
else:
x_in_sub.assign(field)

def post_apply(self, x_out):
"""
Perform the post-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = self.x_out.subfunctions[field_idx]
x_out_sub = x_out.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.x_out.assign(field)
subwrapper.post_apply(x_out_sub)
else:
x_out_sub.assign(field)
Loading

0 comments on commit 6d76ca7

Please sign in to comment.