-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Wrapper of options #471
Wrapper of options #471
Changes from 21 commits
573ad52
e04c73a
4a13df3
6165057
775d3b2
793768e
5f05367
5c13db1
806a49e
2be1354
6d2fc9b
25ed243
f410d2c
d166a22
aa3d7c0
f1e958a
51f6f7b
55e0856
164041c
edcc929
b317f39
0700381
2f517e1
5c2980d
222c396
66ce750
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -10,7 +10,7 @@ | |||||
import numpy as np | ||||||
|
||||||
from firedrake import ( | ||||||
Function, TestFunction, NonlinearVariationalProblem, | ||||||
Function, TestFunction, TestFunctions, NonlinearVariationalProblem, | ||||||
NonlinearVariationalSolver, DirichletBC, split, Constant | ||||||
) | ||||||
from firedrake.fml import ( | ||||||
|
@@ -88,7 +88,21 @@ def __init__(self, domain, field_name=None, solver_parameters=None, | |||||
|
||||||
if options is not None: | ||||||
self.wrapper_name = options.name | ||||||
if self.wrapper_name == "embedded_dg": | ||||||
if self.wrapper_name == "mixed_options": | ||||||
self.wrapper = MixedFSWrapper() | ||||||
|
||||||
for field, suboption in options.suboptions.items(): | ||||||
if suboption.name == 'embedded_dg': | ||||||
self.wrapper.subwrappers.update({field: EmbeddedDGWrapper(self, suboption)}) | ||||||
elif suboption.name == "recovered": | ||||||
self.wrapper.subwrappers.update({field: RecoveryWrapper(self, suboption)}) | ||||||
elif suboption.name == "supg": | ||||||
raise RuntimeError( | ||||||
'Time discretisation: suboption SUPG is currently not implemented within MixedOptions') | ||||||
else: | ||||||
raise RuntimeError( | ||||||
f'Time discretisation: suboption wrapper {wrapper_name} not implemented') | ||||||
elif self.wrapper_name == "embedded_dg": | ||||||
self.wrapper = EmbeddedDGWrapper(self, options) | ||||||
elif self.wrapper_name == "recovered": | ||||||
self.wrapper = RecoveryWrapper(self, options) | ||||||
|
@@ -159,21 +173,52 @@ def setup(self, equation, apply_bcs=True, *active_labels): | |||||
# -------------------------------------------------------------------- # | ||||||
|
||||||
if self.wrapper is not None: | ||||||
self.wrapper.setup() | ||||||
self.fs = self.wrapper.function_space | ||||||
if self.solver_parameters is None: | ||||||
self.solver_parameters = self.wrapper.solver_parameters | ||||||
new_test = TestFunction(self.wrapper.test_space) | ||||||
# SUPG has a special wrapper | ||||||
if self.wrapper_name == "supg": | ||||||
new_test = self.wrapper.test | ||||||
|
||||||
# Replace the original test function with the one from the wrapper | ||||||
self.residual = self.residual.label_map( | ||||||
all_terms, | ||||||
map_if_true=replace_test_function(new_test)) | ||||||
if self.wrapper_name == "mixed_options": | ||||||
|
||||||
self.wrapper.wrapper_spaces = equation.spaces | ||||||
self.wrapper.field_names = equation.field_names | ||||||
|
||||||
for field, subwrapper in self.wrapper.subwrappers.items(): | ||||||
|
||||||
if field not in equation.field_names: | ||||||
raise ValueError(f"The option defined for {field} is for a field that does not exist in the equation set") | ||||||
|
||||||
field_idx = equation.field_names.index(field) | ||||||
|
||||||
# Store the original space of the tracer | ||||||
subwrapper.tracer_fs = self.equation.spaces[field_idx] | ||||||
|
||||||
self.residual = self.wrapper.label_terms(self.residual) | ||||||
subwrapper.setup() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# Update the function space to that needed by the wrapper | ||||||
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space | ||||||
|
||||||
self.wrapper.setup() | ||||||
self.fs = self.wrapper.function_space | ||||||
new_test_mixed = TestFunctions(self.fs) | ||||||
|
||||||
# Replace the original test function with one from the new | ||||||
# function space defined by the subwrappers | ||||||
self.residual = self.residual.label_map( | ||||||
all_terms, | ||||||
map_if_true=replace_test_function(new_test_mixed)) | ||||||
|
||||||
else: | ||||||
self.wrapper.setup() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think this should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we would want this to be |
||||||
self.fs = self.wrapper.function_space | ||||||
if self.solver_parameters is None: | ||||||
self.solver_parameters = self.wrapper.solver_parameters | ||||||
new_test = TestFunction(self.wrapper.test_space) | ||||||
# SUPG has a special wrapper | ||||||
if self.wrapper_name == "supg": | ||||||
new_test = self.wrapper.test | ||||||
|
||||||
# Replace the original test function with the one from the wrapper | ||||||
self.residual = self.residual.label_map( | ||||||
all_terms, | ||||||
map_if_true=replace_test_function(new_test)) | ||||||
|
||||||
self.residual = self.wrapper.label_terms(self.residual) | ||||||
|
||||||
# -------------------------------------------------------------------- # | ||||||
# Make boundary conditions | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,15 +7,15 @@ | |||||
from abc import ABCMeta, abstractmethod | ||||||
from firedrake import ( | ||||||
FunctionSpace, Function, BrokenElement, Projector, Interpolator, | ||||||
VectorElement, Constant, as_ufl, dot, grad, TestFunction | ||||||
VectorElement, Constant, as_ufl, dot, grad, TestFunction, MixedFunctionSpace | ||||||
) | ||||||
from firedrake.fml import Term | ||||||
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions | ||||||
from gusto.recovery import Recoverer, ReversibleRecoverer | ||||||
from gusto.labels import transporting_velocity | ||||||
import ufl | ||||||
|
||||||
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper"] | ||||||
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedFSWrapper"] | ||||||
|
||||||
|
||||||
class Wrapper(object, metaclass=ABCMeta): | ||||||
|
@@ -33,6 +33,7 @@ def __init__(self, time_discretisation, wrapper_options): | |||||
self.time_discretisation = time_discretisation | ||||||
self.options = wrapper_options | ||||||
self.solver_parameters = None | ||||||
self.tracer_fs = None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
@abstractmethod | ||||||
def setup(self): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||||||
|
@@ -82,10 +83,14 @@ def setup(self): | |||||
assert isinstance(self.options, EmbeddedDGOptions), \ | ||||||
'Embedded DG wrapper can only be used with Embedded DG Options' | ||||||
|
||||||
original_space = self.time_discretisation.fs | ||||||
domain = self.time_discretisation.domain | ||||||
equation = self.time_discretisation.equation | ||||||
|
||||||
if self.tracer_fs is not None: | ||||||
original_space = self.tracer_fs | ||||||
else: | ||||||
original_space = self.time_discretisation.fs | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this snippet changes to be: |
||||||
# -------------------------------------------------------------------- # | ||||||
# Set up spaces to be used with wrapper | ||||||
# -------------------------------------------------------------------- # | ||||||
|
@@ -104,7 +109,10 @@ def setup(self): | |||||
|
||||||
self.x_in = Function(self.function_space) | ||||||
self.x_out = Function(self.function_space) | ||||||
if self.time_discretisation.idx is None: | ||||||
|
||||||
if self.tracer_fs is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
self.x_projected = Function(self.tracer_fs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? |
||||||
elif self.time_discretisation.idx is None: | ||||||
self.x_projected = Function(equation.function_space) | ||||||
else: | ||||||
self.x_projected = Function(equation.spaces[self.time_discretisation.idx]) | ||||||
|
@@ -161,13 +169,19 @@ class RecoveryWrapper(Wrapper): | |||||
def setup(self): | ||||||
"""Sets up function spaces and fields needed for this wrapper.""" | ||||||
|
||||||
print(self.options) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This just needs removing! |
||||||
|
||||||
assert isinstance(self.options, RecoveryOptions), \ | ||||||
'Embedded DG wrapper can only be used with Recovery Options' | ||||||
'Recovery wrapper can only be used with Recovery Options' | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for correcting this! |
||||||
|
||||||
original_space = self.time_discretisation.fs | ||||||
domain = self.time_discretisation.domain | ||||||
equation = self.time_discretisation.equation | ||||||
|
||||||
if self.tracer_fs is not None: | ||||||
original_space = self.tracer_fs | ||||||
else: | ||||||
original_space = self.time_discretisation.fs | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again this snippet changes to
|
||||||
# -------------------------------------------------------------------- # | ||||||
# Set up spaces to be used with wrapper | ||||||
# -------------------------------------------------------------------- # | ||||||
|
@@ -184,10 +198,17 @@ def setup(self): | |||||
# Internal variables to be used | ||||||
# -------------------------------------------------------------------- # | ||||||
|
||||||
self.x_in_tmp = Function(self.time_discretisation.fs) | ||||||
if self.tracer_fs is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
self.x_in_tmp = Function(self.tracer_fs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
else: | ||||||
self.x_in_tmp = Function(self.time_discretisation.fs) | ||||||
|
||||||
self.x_in = Function(self.function_space) | ||||||
self.x_out = Function(self.function_space) | ||||||
if self.time_discretisation.idx is None: | ||||||
|
||||||
if self.tracer_fs is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
self.x_projected = Function(self.tracer_fs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
elif self.time_discretisation.idx is None: | ||||||
self.x_projected = Function(equation.function_space) | ||||||
else: | ||||||
self.x_projected = Function(equation.spaces[self.time_discretisation.idx]) | ||||||
|
@@ -361,3 +382,61 @@ def label_terms(self, residual): | |||||
new_residual = transporting_velocity.update_value(new_residual, self.transporting_velocity) | ||||||
|
||||||
return new_residual | ||||||
|
||||||
|
||||||
class MixedFSWrapper(object): | ||||||
""" | ||||||
An object to hold a subwrapper dictionary with different wrappers for | ||||||
different tracers. This means that different tracers can be solved | ||||||
simultaneously using a CoupledTransportEquation, whilst being in | ||||||
different spaces and needing different implementation options. | ||||||
""" | ||||||
|
||||||
def __init__(self): | ||||||
|
||||||
self.wrapper_spaces = None | ||||||
self.field_names = None | ||||||
self.subwrappers = {} | ||||||
|
||||||
def setup(self): | ||||||
""" Compute the new mixed function space from the subwrappers """ | ||||||
|
||||||
self.function_space = MixedFunctionSpace(self.wrapper_spaces) | ||||||
self.x_in = Function(self.function_space) | ||||||
self.x_out = Function(self.function_space) | ||||||
|
||||||
def pre_apply(self, x_in): | ||||||
""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method looks exactly right to me |
||||||
Perform the pre-applications for all fields | ||||||
with an associated subwrapper. | ||||||
""" | ||||||
|
||||||
for field_name in self.field_names: | ||||||
field_idx = self.field_names.index(field_name) | ||||||
field = x_in.subfunctions[field_idx] | ||||||
x_in_sub = self.x_in.subfunctions[field_idx] | ||||||
|
||||||
if field_name in self.subwrappers: | ||||||
subwrapper = self.subwrappers[field_name] | ||||||
subwrapper.pre_apply(field) | ||||||
x_in_sub.assign(subwrapper.x_in) | ||||||
else: | ||||||
x_in_sub.assign(field) | ||||||
|
||||||
def post_apply(self, x_out): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method looks exactly right to me |
||||||
""" | ||||||
Perform the post-applications for all fields | ||||||
with an associated subwrapper. | ||||||
""" | ||||||
|
||||||
for field_name in self.field_names: | ||||||
field_idx = self.field_names.index(field_name) | ||||||
field = self.x_out.subfunctions[field_idx] | ||||||
x_out_sub = x_out.subfunctions[field_idx] | ||||||
|
||||||
if field_name in self.subwrappers: | ||||||
subwrapper = self.subwrappers[field_name] | ||||||
subwrapper.x_out.assign(field) | ||||||
subwrapper.post_apply(x_out_sub) | ||||||
else: | ||||||
x_out_sub.assign(field) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.