diff --git a/gusto/time_discretisation.py b/gusto/time_discretisation.py index 593bbe830..c537b317f 100644 --- a/gusto/time_discretisation.py +++ b/gusto/time_discretisation.py @@ -56,7 +56,10 @@ def get_apply(self, x_out, x_in): def new_apply(self, x_out, x_in): self.wrapper.pre_apply(x_in) - original_apply(self, self.wrapper.x_out, self.wrapper.x_in) + if type(self.wrapper) == MixedOptions: + original_apply(self, x_out, x_in) + else: + original_apply(self, self.wrapper.x_out, self.wrapper.x_in) self.wrapper.post_apply(x_out) return new_apply(self, x_out, x_in) @@ -103,9 +106,39 @@ def __init__(self, domain, field_name=None, solver_parameters=None, if options is not None: if type(options) == MixedOptions: print('jahjah') + print(options) self.wrapper = options + #self.subwrappers = {} # Or do I need to initialise everything here # like is done for a single wrapper? + for field, suboption in self.wrapper.suboptions.items(): + print(field) + print(suboption) + + #Replace options with a wrapper? + + #if suboption.name == 'embedded_dg': + # self.suboptions[field].wrapper = EmbeddedDGWrapper(self, suboption) + #elif suboption.name == "recovered": + # self.suboptions[field].wrapper = RecoveryWrapper(self, suboption) + #elif suboption.name == "supg": + # self.suboptions[field].wrapper = SUPGWrapper(self, suboption) + #else: + # raise RuntimeError( + # f'Time discretisation: suboption wrapper {wrapper_name} not implemented') + if suboption.name == 'embedded_dg': + self.wrapper.subwrappers.update({field:EmbeddedDGWrapper(self, suboption)}) + elif suboption.name == "recovered": + self.wrapper.subwrappers.update({field:RecoveryWrapper(self, suboption)}) + elif suboption.name == "supg": + self.wrapper.subwrappers.update({field:SUPGWrapper(self, suboption)}) + else: + raise RuntimeError( + f'Time discretisation: suboption wrapper {wrapper_name} not implemented') + print(self.wrapper) + print(self.wrapper.suboptions) + print(self.wrapper.subwrappers) + #self.wrapper_name = 'mixed' else: self.wrapper_name = options.name if self.wrapper_name == "embedded_dg": @@ -180,10 +213,17 @@ def setup(self, equation, apply_bcs=True, *active_labels): if self.wrapper is not None: if type(self.wrapper) == MixedOptions: + #if self.wrapper_name == 'mixed': # Subwrappers are defined. # Set these up with ? - for _, subwrapper in self.wrapper.suboptions.items(): + for field, subwrapper in self.wrapper.subwrappers.items(): + #Set up field idxs here. + print(field) print(subwrapper) + self.wrapper.subwrappers[field].idx = equation.field_names.index(field) + self.wrapper.subwrappers[field].mixed_options = True + self.wrapper.subwrappers[field].setup() + self.fs = self.wrapper.subwrappers[field].function_space pass else: self.wrapper.setup() diff --git a/gusto/wrappers.py b/gusto/wrappers.py index b14a23b6b..52da12e4f 100644 --- a/gusto/wrappers.py +++ b/gusto/wrappers.py @@ -33,7 +33,9 @@ def __init__(self, time_discretisation, wrapper_options): self.time_discretisation = time_discretisation self.options = wrapper_options self.solver_parameters = None - self.field_name = None + #self.field_name = None + self.idx = None + self.mixed_options = False @abstractmethod def setup(self): @@ -162,8 +164,10 @@ class RecoveryWrapper(Wrapper): def setup(self): """Sets up function spaces and fields needed for this wrapper.""" + print(self.options) + assert isinstance(self.options, RecoveryOptions), \ - 'Embedded DG wrapper can only be used with Recovery Options' + 'Recovery wrapper can only be used with Recovery Options' original_space = self.time_discretisation.fs domain = self.time_discretisation.domain @@ -177,6 +181,7 @@ def setup(self): V_elt = BrokenElement(original_space.ufl_element()) self.function_space = FunctionSpace(domain.mesh, V_elt) else: + print('using embedded space') self.function_space = self.options.embedding_space self.test_space = self.function_space @@ -185,11 +190,23 @@ def setup(self): # Internal variables to be used # -------------------------------------------------------------------- # - self.x_in_tmp = Function(self.time_discretisation.fs) + if self.mixed_options == True: + self.x_in_tmp = Function(self.function_space) + else: + self.x_in_tmp = Function(self.time_discretisation.fs) + + print(self.function_space) + print(self.time_discretisation.fs) + self.x_in = Function(self.function_space) self.x_out = Function(self.function_space) + + #self.x_out = Function(equation.function_space) + #self.x_out = Function(self.time_discretisation.fs) + if self.time_discretisation.idx is None: - self.x_projected = Function(equation.function_space) + #self.x_projected = Function(equation.function_space) + self.x_projected = Function(self.function_space) else: self.x_projected = Function(equation.spaces[self.time_discretisation.idx]) @@ -381,18 +398,20 @@ def __init__(self, equation, suboptions): ValueError: If an option is defined for a field that is not in the prognostic variable set """ print(equation.function_space) - self.x_in = Function(equation.function_space) - self.x_out = Function(equation.function_space) + #self.x_in = Function(equation.function_space) + #self.x_out = Function(equation.function_space) print('Initialising MixedOptions') self.suboptions = suboptions - print(suboptions.items()) + #print(suboptions.items()) - self.wrapper = {} + self.subwrappers = {} for field, suboption in suboptions.items(): + print(field) + print(suboption) # Check that the field is in the prognostic variable set: if field not in equation.field_names: raise ValueError(f"The limiter defined for {field} is for a field that does not exist in the equation set") @@ -401,19 +420,19 @@ def __init__(self, equation, suboptions): wrapper_name = suboption.name print(wrapper_name) - if wrapper_name == "embedded_dg": - self.wrapper.update({field:EmbeddedDGWrapper(self, suboption)}) - elif wrapper_name == "recovered": - self.suboptions[field].subwrapper = RecoveryWrapper(self, suboption) - elif wrapper_name == "supg": - self.suboptions[field].subwrapper = SUPGWrapper(self, suboption) - else: - raise RuntimeError( - f'Time discretisation: suboption wrapper {wrapper_name} not implemented') + #if wrapper_name == "embedded_dg": + # self.wrapper.update({field:EmbeddedDGWrapper(self, suboption)}) + #elif wrapper_name == "recovered": + # self.suboptions[field].subwrapper = RecoveryWrapper(self, suboption) + #elif wrapper_name == "supg": + # self.suboptions[field].subwrapper = SUPGWrapper(self, suboption) + #else: + # raise RuntimeError( + # f'Time discretisation: suboption wrapper {wrapper_name} not implemented') #Initialise the wrapper and associate with a field: - - self.suboptions[field].idx = equation.field_names.index(field) + #self.suboptions[field].wrapper_name = suboption.name + #self.suboptions[field].idx = equation.field_names.index(field) #self.suboptions[field].fs = equation.field_names.function_space(field) # self.subwrapper.x_in = Function(self.suboptions[field].fs) @@ -426,8 +445,9 @@ def pre_apply(self, x_in): Perform the pre-applications from all subwrappers """ - for _, subwrapper in self.suboptions.items(): + for _, subwrapper in self.subwrappers.items(): print(subwrapper) + print('pre') field = x_in.subfunctions[subwrapper.idx] subwrapper.pre_apply(field) @@ -436,8 +456,9 @@ def post_apply(self, x_in): Perform the post-applications from all subwrappers """ - for _, suboption in self.suboptions.items(): - print(suboption) - field = x_in.subfunctions[suboption.idx] - suboption.subwrapper.post_apply(field) + for _, subwrapper in self.subwrappers.items(): + print(subwrapper) + print('post') + field = x_in.subfunctions[subwrapper.idx] + subwrapper.post_apply(field) diff --git a/integration-tests/transport/test_mixed_fs_options.py b/integration-tests/transport/test_mixed_fs_options.py index a82136923..9701e3dac 100644 --- a/integration-tests/transport/test_mixed_fs_options.py +++ b/integration-tests/transport/test_mixed_fs_options.py @@ -160,7 +160,7 @@ def setup_limiters(dirname, space_A, space_B): sublimiters.update({'tracerB': VertexBasedLimiter(VDG1_B)}) elif degree == 1: - opts = EmbeddedDGOptions() + suboptions.update({'tracerB': EmbeddedDGOptions()}) sublimiters.update({'tracerB': ThetaLimiter(VB)}) else: raise NotImplementedError @@ -276,8 +276,9 @@ def setup_limiters(dirname, space_A, space_B): return stepper, tmax, true_fieldA, true_fieldB -@pytest.mark.parametrize('space_A', ['Vtheta_degree_0', 'Vtheta_degree_1', 'DG0', - 'DG1', 'DG1_equispaced']) +#@pytest.mark.parametrize('space_A', ['Vtheta_degree_0', 'Vtheta_degree_1', 'DG0', +# 'DG1', 'DG1_equispaced']) +@pytest.mark.parametrize('space_A', ['DG0'])#, 'DG1_equispaced']) # It only makes sense to use the same degree for tracer B @pytest.mark.parametrize('space_B', ['Vtheta', 'DG'])