Skip to content

Commit

Permalink
more mixed options changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Jan 9, 2024
1 parent 4a13df3 commit 6165057
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 29 deletions.
44 changes: 42 additions & 2 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def get_apply(self, x_out, x_in):
def new_apply(self, x_out, x_in):

self.wrapper.pre_apply(x_in)
original_apply(self, self.wrapper.x_out, self.wrapper.x_in)
if type(self.wrapper) == MixedOptions:
original_apply(self, x_out, x_in)
else:
original_apply(self, self.wrapper.x_out, self.wrapper.x_in)
self.wrapper.post_apply(x_out)

return new_apply(self, x_out, x_in)
Expand Down Expand Up @@ -103,9 +106,39 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
if options is not None:
if type(options) == MixedOptions:
print('jahjah')
print(options)
self.wrapper = options
#self.subwrappers = {}
# Or do I need to initialise everything here
# like is done for a single wrapper?
for field, suboption in self.wrapper.suboptions.items():
print(field)
print(suboption)

#Replace options with a wrapper?

#if suboption.name == 'embedded_dg':
# self.suboptions[field].wrapper = EmbeddedDGWrapper(self, suboption)
#elif suboption.name == "recovered":
# self.suboptions[field].wrapper = RecoveryWrapper(self, suboption)
#elif suboption.name == "supg":
# self.suboptions[field].wrapper = SUPGWrapper(self, suboption)
#else:
# raise RuntimeError(
# f'Time discretisation: suboption wrapper {wrapper_name} not implemented')
if suboption.name == 'embedded_dg':
self.wrapper.subwrappers.update({field:EmbeddedDGWrapper(self, suboption)})
elif suboption.name == "recovered":
self.wrapper.subwrappers.update({field:RecoveryWrapper(self, suboption)})
elif suboption.name == "supg":
self.wrapper.subwrappers.update({field:SUPGWrapper(self, suboption)})
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')
print(self.wrapper)
print(self.wrapper.suboptions)
print(self.wrapper.subwrappers)
#self.wrapper_name = 'mixed'
else:
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
Expand Down Expand Up @@ -180,10 +213,17 @@ def setup(self, equation, apply_bcs=True, *active_labels):

if self.wrapper is not None:
if type(self.wrapper) == MixedOptions:
#if self.wrapper_name == 'mixed':
# Subwrappers are defined.
# Set these up with ?
for _, subwrapper in self.wrapper.suboptions.items():
for field, subwrapper in self.wrapper.subwrappers.items():
#Set up field idxs here.
print(field)
print(subwrapper)
self.wrapper.subwrappers[field].idx = equation.field_names.index(field)
self.wrapper.subwrappers[field].mixed_options = True
self.wrapper.subwrappers[field].setup()
self.fs = self.wrapper.subwrappers[field].function_space
pass
else:
self.wrapper.setup()
Expand Down
69 changes: 45 additions & 24 deletions gusto/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def __init__(self, time_discretisation, wrapper_options):
self.time_discretisation = time_discretisation
self.options = wrapper_options
self.solver_parameters = None
self.field_name = None
#self.field_name = None
self.idx = None
self.mixed_options = False

@abstractmethod
def setup(self):
Expand Down Expand Up @@ -162,8 +164,10 @@ class RecoveryWrapper(Wrapper):
def setup(self):
"""Sets up function spaces and fields needed for this wrapper."""

print(self.options)

assert isinstance(self.options, RecoveryOptions), \
'Embedded DG wrapper can only be used with Recovery Options'
'Recovery wrapper can only be used with Recovery Options'

original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
Expand All @@ -177,6 +181,7 @@ def setup(self):
V_elt = BrokenElement(original_space.ufl_element())
self.function_space = FunctionSpace(domain.mesh, V_elt)
else:
print('using embedded space')
self.function_space = self.options.embedding_space

self.test_space = self.function_space
Expand All @@ -185,11 +190,23 @@ def setup(self):
# Internal variables to be used
# -------------------------------------------------------------------- #

self.x_in_tmp = Function(self.time_discretisation.fs)
if self.mixed_options == True:
self.x_in_tmp = Function(self.function_space)
else:
self.x_in_tmp = Function(self.time_discretisation.fs)

print(self.function_space)
print(self.time_discretisation.fs)

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

#self.x_out = Function(equation.function_space)
#self.x_out = Function(self.time_discretisation.fs)

if self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
#self.x_projected = Function(equation.function_space)
self.x_projected = Function(self.function_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])

Expand Down Expand Up @@ -381,18 +398,20 @@ def __init__(self, equation, suboptions):
ValueError: If an option is defined for a field that is not in the prognostic variable set
"""
print(equation.function_space)
self.x_in = Function(equation.function_space)
self.x_out = Function(equation.function_space)
#self.x_in = Function(equation.function_space)
#self.x_out = Function(equation.function_space)

print('Initialising MixedOptions')

self.suboptions = suboptions

print(suboptions.items())
#print(suboptions.items())

self.wrapper = {}
self.subwrappers = {}

for field, suboption in suboptions.items():
print(field)
print(suboption)
# Check that the field is in the prognostic variable set:
if field not in equation.field_names:
raise ValueError(f"The limiter defined for {field} is for a field that does not exist in the equation set")
Expand All @@ -401,19 +420,19 @@ def __init__(self, equation, suboptions):
wrapper_name = suboption.name
print(wrapper_name)

if wrapper_name == "embedded_dg":
self.wrapper.update({field:EmbeddedDGWrapper(self, suboption)})
elif wrapper_name == "recovered":
self.suboptions[field].subwrapper = RecoveryWrapper(self, suboption)
elif wrapper_name == "supg":
self.suboptions[field].subwrapper = SUPGWrapper(self, suboption)
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')
#if wrapper_name == "embedded_dg":
# self.wrapper.update({field:EmbeddedDGWrapper(self, suboption)})
#elif wrapper_name == "recovered":
# self.suboptions[field].subwrapper = RecoveryWrapper(self, suboption)
#elif wrapper_name == "supg":
# self.suboptions[field].subwrapper = SUPGWrapper(self, suboption)
#else:
# raise RuntimeError(
# f'Time discretisation: suboption wrapper {wrapper_name} not implemented')

#Initialise the wrapper and associate with a field:

self.suboptions[field].idx = equation.field_names.index(field)
#self.suboptions[field].wrapper_name = suboption.name
#self.suboptions[field].idx = equation.field_names.index(field)
#self.suboptions[field].fs = equation.field_names.function_space(field)
# self.subwrapper.x_in = Function(self.suboptions[field].fs)

Expand All @@ -426,8 +445,9 @@ def pre_apply(self, x_in):
Perform the pre-applications from all subwrappers
"""

for _, subwrapper in self.suboptions.items():
for _, subwrapper in self.subwrappers.items():
print(subwrapper)
print('pre')
field = x_in.subfunctions[subwrapper.idx]
subwrapper.pre_apply(field)

Expand All @@ -436,8 +456,9 @@ def post_apply(self, x_in):
Perform the post-applications from all subwrappers
"""

for _, suboption in self.suboptions.items():
print(suboption)
field = x_in.subfunctions[suboption.idx]
suboption.subwrapper.post_apply(field)
for _, subwrapper in self.subwrappers.items():
print(subwrapper)
print('post')
field = x_in.subfunctions[subwrapper.idx]
subwrapper.post_apply(field)

7 changes: 4 additions & 3 deletions integration-tests/transport/test_mixed_fs_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def setup_limiters(dirname, space_A, space_B):

sublimiters.update({'tracerB': VertexBasedLimiter(VDG1_B)})
elif degree == 1:
opts = EmbeddedDGOptions()
suboptions.update({'tracerB': EmbeddedDGOptions()})
sublimiters.update({'tracerB': ThetaLimiter(VB)})
else:
raise NotImplementedError
Expand Down Expand Up @@ -276,8 +276,9 @@ def setup_limiters(dirname, space_A, space_B):
return stepper, tmax, true_fieldA, true_fieldB


@pytest.mark.parametrize('space_A', ['Vtheta_degree_0', 'Vtheta_degree_1', 'DG0',
'DG1', 'DG1_equispaced'])
#@pytest.mark.parametrize('space_A', ['Vtheta_degree_0', 'Vtheta_degree_1', 'DG0',
# 'DG1', 'DG1_equispaced'])
@pytest.mark.parametrize('space_A', ['DG0'])#, 'DG1_equispaced'])
# It only makes sense to use the same degree for tracer B
@pytest.mark.parametrize('space_B', ['Vtheta', 'DG'])

Expand Down

0 comments on commit 6165057

Please sign in to comment.