Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper of options #471

Merged
merged 26 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
573ad52
started MixedOptions and an integration test for this
ta440 Dec 20, 2023
e04c73a
MixedOptions dictionary added
ta440 Dec 21, 2023
4a13df3
improvements to MixedOptions
ta440 Dec 22, 2023
6165057
more mixed options changes
ta440 Jan 9, 2024
775d3b2
more changes to mixed options
ta440 Jan 10, 2024
793768e
working on new test functions with mixed options
ta440 Jan 10, 2024
5f05367
more changes to test functions for multiple wrappers
ta440 Jan 11, 2024
5c13db1
more changes to mixed options with replacing test functions
ta440 Jan 16, 2024
806a49e
mixed options works for embedded dg and recovery
ta440 Jan 17, 2024
2be1354
moved DG1-DG1 equispaced test from test_limiters to test_mixed_fs_opt…
ta440 Jan 22, 2024
6d2fc9b
finalising mixed options test script
ta440 Jan 22, 2024
25ed243
lint
ta440 Jan 22, 2024
f410d2c
lint
ta440 Jan 22, 2024
d166a22
lint
ta440 Jan 22, 2024
aa3d7c0
lint
ta440 Jan 22, 2024
f1e958a
lint
ta440 Jan 22, 2024
51f6f7b
remove dubugging statements
ta440 Jan 23, 2024
55e0856
separate MixedFSOptions and MixedFSWrapper to align with exising code
ta440 Feb 20, 2024
164041c
lint
ta440 Feb 20, 2024
edcc929
lint
ta440 Feb 20, 2024
b317f39
lint
ta440 Feb 20, 2024
0700381
set up original spaces for MixedFSWrapper within the subwrappers
ta440 Feb 28, 2024
2f517e1
make original_space a base wrapper property
ta440 Feb 28, 2024
5c2980d
make original_space a required argument for wrapper.setup()
ta440 Feb 29, 2024
222c396
lint
ta440 Feb 29, 2024
66ce750
Merge branch 'main' into wrapper_of_options
tommbendall Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion gusto/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__all__ = [
"IntegrateByParts", "TransportEquationType", "OutputParameters",
"CompressibleParameters", "ShallowWaterParameters",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions", "MixedFSOptions",
"SpongeLayerParameters", "DiffusionParameters", "BoundaryLayerParameters"
]

Expand Down Expand Up @@ -172,6 +172,15 @@ class SUPGOptions(WrapperOptions):
ibp = IntegrateByParts.TWICE


class MixedFSOptions(WrapperOptions):
"""Specifies options for a mixed finite element formulation
where different suboptions are applied to different
prognostic variables."""

name = "mixed_options"
suboptions = {}


class SpongeLayerParameters(Configuration):
"""Specifies parameters describing a 'sponge' (damping) layer."""

Expand Down
77 changes: 61 additions & 16 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from firedrake import (
Function, TestFunction, NonlinearVariationalProblem,
Function, TestFunction, TestFunctions, NonlinearVariationalProblem,
NonlinearVariationalSolver, DirichletBC, split, Constant
)
from firedrake.fml import (
Expand Down Expand Up @@ -88,7 +88,21 @@ def __init__(self, domain, field_name=None, solver_parameters=None,

if options is not None:
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
if self.wrapper_name == "mixed_options":
self.wrapper = MixedFSWrapper()

for field, suboption in options.suboptions.items():
if suboption.name == 'embedded_dg':
self.wrapper.subwrappers.update({field: EmbeddedDGWrapper(self, suboption)})
elif suboption.name == "recovered":
self.wrapper.subwrappers.update({field: RecoveryWrapper(self, suboption)})
elif suboption.name == "supg":
raise RuntimeError(
'Time discretisation: suboption SUPG is currently not implemented within MixedOptions')
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')
elif self.wrapper_name == "embedded_dg":
self.wrapper = EmbeddedDGWrapper(self, options)
elif self.wrapper_name == "recovered":
self.wrapper = RecoveryWrapper(self, options)
Expand Down Expand Up @@ -159,21 +173,52 @@ def setup(self, equation, apply_bcs=True, *active_labels):
# -------------------------------------------------------------------- #

if self.wrapper is not None:
self.wrapper.setup()
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))
if self.wrapper_name == "mixed_options":

self.wrapper.wrapper_spaces = equation.spaces
self.wrapper.field_names = equation.field_names

for field, subwrapper in self.wrapper.subwrappers.items():

if field not in equation.field_names:
raise ValueError(f"The option defined for {field} is for a field that does not exist in the equation set")

field_idx = equation.field_names.index(field)

# Store the original space of the tracer
subwrapper.tracer_fs = self.equation.spaces[field_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
subwrapper.tracer_fs = self.equation.spaces[field_idx]


self.residual = self.wrapper.label_terms(self.residual)
subwrapper.setup()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
subwrapper.setup()
subwrapper.setup(self.equation.spaces[field_idx])


# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space

self.wrapper.setup()
self.fs = self.wrapper.function_space
new_test_mixed = TestFunctions(self.fs)

# Replace the original test function with one from the new
# function space defined by the subwrappers
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test_mixed))

else:
self.wrapper.setup()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.wrapper.setup()
self.wrapper.setup(self.fs)

I think this should be self.fs but I'm not sure...? Maybe it is just None

Copy link
Collaborator Author

@ta440 ta440 Feb 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we would want this to be None, as we want self.original_space=None when not using the mixed wrapper.

self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))

self.residual = self.wrapper.label_terms(self.residual)

# -------------------------------------------------------------------- #
# Make boundary conditions
Expand Down
95 changes: 87 additions & 8 deletions gusto/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from abc import ABCMeta, abstractmethod
from firedrake import (
FunctionSpace, Function, BrokenElement, Projector, Interpolator,
VectorElement, Constant, as_ufl, dot, grad, TestFunction
VectorElement, Constant, as_ufl, dot, grad, TestFunction, MixedFunctionSpace
)
from firedrake.fml import Term
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions
from gusto.recovery import Recoverer, ReversibleRecoverer
from gusto.labels import transporting_velocity
import ufl

__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper"]
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedFSWrapper"]


class Wrapper(object, metaclass=ABCMeta):
Expand All @@ -33,6 +33,7 @@ def __init__(self, time_discretisation, wrapper_options):
self.time_discretisation = time_discretisation
self.options = wrapper_options
self.solver_parameters = None
self.tracer_fs = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.tracer_fs = None
self.original_space = None


@abstractmethod
def setup(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add original_space as an argument here

Expand Down Expand Up @@ -82,10 +83,14 @@ def setup(self):
assert isinstance(self.options, EmbeddedDGOptions), \
'Embedded DG wrapper can only be used with Embedded DG Options'

original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

if self.tracer_fs is not None:
original_space = self.tracer_fs
else:
original_space = self.time_discretisation.fs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this snippet changes to be:
self.original_space = original_space

# -------------------------------------------------------------------- #
# Set up spaces to be used with wrapper
# -------------------------------------------------------------------- #
Expand All @@ -104,7 +109,10 @@ def setup(self):

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)
if self.time_discretisation.idx is None:

if self.tracer_fs is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.tracer_fs is not None:
if self.original_space is not None:

self.x_projected = Function(self.tracer_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.x_projected = Function(self.tracer_fs)
self.x_projected = Function(self.original_space)

?

elif self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])
Expand Down Expand Up @@ -161,13 +169,19 @@ class RecoveryWrapper(Wrapper):
def setup(self):
"""Sets up function spaces and fields needed for this wrapper."""

print(self.options)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just needs removing!


assert isinstance(self.options, RecoveryOptions), \
'Embedded DG wrapper can only be used with Recovery Options'
'Recovery wrapper can only be used with Recovery Options'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for correcting this!


original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

if self.tracer_fs is not None:
original_space = self.tracer_fs
else:
original_space = self.time_discretisation.fs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this snippet changes to

self.original_space = original_space

# -------------------------------------------------------------------- #
# Set up spaces to be used with wrapper
# -------------------------------------------------------------------- #
Expand All @@ -184,10 +198,17 @@ def setup(self):
# Internal variables to be used
# -------------------------------------------------------------------- #

self.x_in_tmp = Function(self.time_discretisation.fs)
if self.tracer_fs is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.tracer_fs is not None:
if self.original_space is not None:

self.x_in_tmp = Function(self.tracer_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.x_in_tmp = Function(self.tracer_fs)
self.x_in_tmp = Function(self.original_space)

else:
self.x_in_tmp = Function(self.time_discretisation.fs)

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)
if self.time_discretisation.idx is None:

if self.tracer_fs is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self.tracer_fs is not None:
if self.original_space is not None:

self.x_projected = Function(self.tracer_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.x_projected = Function(self.tracer_fs)
self.x_projected = Function(self.original_space)

elif self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])
Expand Down Expand Up @@ -361,3 +382,61 @@ def label_terms(self, residual):
new_residual = transporting_velocity.update_value(new_residual, self.transporting_velocity)

return new_residual


class MixedFSWrapper(object):
"""
An object to hold a subwrapper dictionary with different wrappers for
different tracers. This means that different tracers can be solved
simultaneously using a CoupledTransportEquation, whilst being in
different spaces and needing different implementation options.
"""

def __init__(self):

self.wrapper_spaces = None
self.field_names = None
self.subwrappers = {}

def setup(self):
""" Compute the new mixed function space from the subwrappers """

self.function_space = MixedFunctionSpace(self.wrapper_spaces)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

def pre_apply(self, x_in):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks exactly right to me

Perform the pre-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = x_in.subfunctions[field_idx]
x_in_sub = self.x_in.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.pre_apply(field)
x_in_sub.assign(subwrapper.x_in)
else:
x_in_sub.assign(field)

def post_apply(self, x_out):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks exactly right to me

"""
Perform the post-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = self.x_out.subfunctions[field_idx]
x_out_sub = x_out.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.x_out.assign(field)
subwrapper.post_apply(x_out_sub)
else:
x_out_sub.assign(field)
Loading
Loading