diff --git a/gusto/spatial_methods/augmentation.py b/gusto/spatial_methods/augmentation.py index c0d59098..740afe31 100644 --- a/gusto/spatial_methods/augmentation.py +++ b/gusto/spatial_methods/augmentation.py @@ -9,7 +9,8 @@ LinearVariationalProblem, LinearVariationalSolver, lhs, rhs, dot, ds_b, ds_v, ds_t, ds, FacetNormal, TestFunction, TrialFunction, transpose, nabla_grad, outer, dS, dS_h, dS_v, sign, jump, div, - Constant, sqrt, cross, curl, FunctionSpace, assemble, DirichletBC + Constant, sqrt, cross, curl, FunctionSpace, assemble, DirichletBC, + Projector ) from firedrake.fml import ( subject, all_terms, replace_subject, keep, replace_test_function, @@ -19,6 +20,7 @@ time_derivative, transport, transporting_velocity, TransportEquationType, logger, prognostic, mass_weighted ) +from gusto.spatial_methods.transport_methods import DGUpwind import copy @@ -269,17 +271,26 @@ def __init__( self.mean_name = 'mean_'+mX_name print(self.mean_name) - self.eqns = eqns + self.eqn_orig = eqns + self.domain=domain exist_spaces = eqns.spaces self.idx_orig = len(exist_spaces) mean_idx = self.idx_orig + self.mean_idx = mean_idx # Extract the mixing ratio in question: mX_idx = eqns.field_names.index(mX_name) + self.mX_idx = mX_idx # Define the mean mixing ratio on the DG0 space DG0 = FunctionSpace(domain.mesh, "DG", 0) + self.DG0 = Function(DG0) + mX_space = eqns.spaces[mean_idx] + + self.mX_in = Function(mX_space) + self.mean_in = Function(mean_space) + self.compute_mean = Projector(mX_space, DG0) # Set up the scheme for the mean mixing ratio @@ -293,48 +304,57 @@ def __init__( print(self.X) - self.bcs = [] + self.bcs = None self.x_in = Function(self.fs) self.x_out = Function(self.fs) # Compute the new mean mass weighted term, # IF this is conservatively transported. - mX_idx = eqns.field_names.index(mX_name) - old_residual = eqns.residual + #old_residual = eqns.residual - mean_residual = old_residual.label_map( - lambda t: t.get(prognostic) == mX_name, - map_if_false=drop - ) - mean_residual = prognostic.update_value(mean_residual, self.mean_name) + #mean_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == mX_name, + # map_if_false=drop + #) + #mean_residual = prognostic.update_value(mean_residual, self.mean_name) # Replace trial functions with those in the new mixed function space - for term in eqns.residual: - print('\n') - print(term.form) + #for term in eqns.residual: + # print('\n') + # print(term.form) - for idx in range(self.idx_orig): - field = eqns.field_names[idx] - # Seperate logic if mass-weighted or not? - print(idx) - print(field) - - prog = split(self.X)[idx] + # If I do this later on with the transport terms, maybe I don't need to do this here? - print('\n residual before change') - print(old_residual.form) - old_residual = old_residual.label_map( - lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), - map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx) - ) - old_residual = old_residual.label_map( - lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), - map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx) - ) - print('\n residual after change') - print(old_residual.form) + #for idx in range(self.idx_orig): + # field = eqns.field_names[idx] + # Seperate logic if mass-weighted or not? + # print(idx) + # print(field) + + # prog = split(self.X)[idx] + + # print('\n residual term before change') + # print(old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_false=drop + # ).form) + + # old_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx) + # ) + # old_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx) + # ) + + # print('\n residual term after change') + # print(old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_false=drop + # ).form) #old_residual = old_residual.label_map( # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), @@ -352,15 +372,87 @@ def __init__( # map_if_false=replace_subject(self.X[mean_idx], old_idx = mX_idx) #) - print('\n mean mX residual before change') + #print('\n mean mX residual before change') + #print(mean_residual.form) + #mean_residual = mean_residual.label_map( + # all_terms, + # replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx) + #) + #mean_residual = mean_residual.label_map( + # all_terms, + # replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx) + #) + #print('\n mean mX residual after change') + #print(mean_residual.form) + + # Form the new residual + #self.residual = old_residual + mean_residual + + #print('\n full residual') + #print(self.residual.form) + + #print('yoyoyoy') + + #for term in self.residual: + # print(term.get(prognostic)) + + def setup_transport(self, spatial_methods, equation): + # Copy spatial method for the mixing ratio onto the + # mean mixing ratio. + + old_residual = equation.residual + + # Copy the mean mixing ratio residual terms: + mean_residual = old_residual.label_map( + lambda t: t.get(prognostic) == self.mX_name, + map_if_false=drop + ) + mean_residual = prognostic.update_value(mean_residual, self.mean_name) + + print('\n in setup_transport') + + # Replace the tests and trial functions for all terms + # of the fields in the original equation + for idx in range(self.idx_orig): + field = self.eqn_orig.field_names[idx] + # Seperate logic if mass-weighted or not? + print('\n', idx) + print(field) + + prog = split(self.X)[idx] + + print('\n residual term before change') + print(old_residual.label_map( + lambda t: t.get(prognostic) == field, + map_if_false=drop + ).form) + + old_residual = old_residual.label_map( + lambda t: t.get(prognostic) == field, + map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx) + ) + old_residual = old_residual.label_map( + lambda t: t.get(prognostic) == field, + map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx) + ) + print('\n residual term after change') + print(old_residual.label_map( + lambda t: t.get(prognostic) == field, + map_if_false=drop + ).form) + + print('\n now setting up mean mixing ratio residual terms') + + + print('\n mean mX residual after change') print(mean_residual.form) mean_residual = mean_residual.label_map( all_terms, - replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx) + replace_subject(self.X, old_idx=self.mX_idx, new_idx=self.mean_idx) ) mean_residual = mean_residual.label_map( all_terms, - replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx) + replace_test_function(self.tests, old_idx=self.mX_idx, new_idx=self.mean_idx) ) print('\n mean mX residual after change') print(mean_residual.form) @@ -368,17 +460,13 @@ def __init__( # Form the new residual self.residual = old_residual + mean_residual - print('\n full residual') - print(self.residual.form) + #Check these two forms + print('\n Original equation with residual of length, ', len(equation.residual)) + print('\n Augmented equation with residual of length, ', len(self.residual)) - print('yoyoyoy') - for term in self.residual: - print(term.get(prognostic)) - def setup_transport(self, spatial_methods): - # Copy spatial method for the mixing ratio onto the - # mean mixing ratio. + def setup_transport_old(self, spatial_methods): mX_spatial_method = next(method for method in spatial_methods if method.variable == self.mX_name) mean_spatial_method = copy.copy(mX_spatial_method) @@ -387,9 +475,19 @@ def setup_transport(self, spatial_methods): self.spatial_methods.append(mean_spatial_method) for method in self.spatial_methods: print(method.variable) - #method.equation.residual = self.residual + method.equation.residual = self.residual + print(method.form.form) print(len(method.equation.residual)) + # Alternatively, redo all the spatial methods + # using the new mixed function space. + # So, want to make a new list of spatial methods + new_spatial_methods = [] + for method in self.spatial_methods: + # Determine the tye of transport method: + new_method = DGUpwind(self, method.variable) + new_spatial_methods.append(new_method) + def pre_apply(self, x_in): """ Sets the original fields, i.e. not the mean mixing ratios @@ -403,6 +501,12 @@ def pre_apply(self, x_in): for idx in range(self.idx_orig): self.x_in.subfunctions[idx].assign(x_in.subfunctions[idx]) + # Set the mean mixing ratio to be zero, just because + #DG0 = FunctionSpace(self.domain.mesh, "DG", 0) + #mean_mX = Function(DG0, name=self.mean_name) + + #self.x_in.subfunctions[self.mean_idx].assign(mean_mX) + def post_apply(self, x_out): """ @@ -412,7 +516,7 @@ def post_apply(self, x_out): x_out (:class:`Function`): The output fields """ - for idx, field in enumerate(self.eqn.field_names): + for idx in range(self.idx_orig): x_out.subfunctions[idx].assign(self.x_out.subfunctions[idx]) def update(self, x_in_mixed): @@ -422,5 +526,13 @@ def update(self, x_in_mixed): Args: x_in_mixed (:class:`Function`): The mixed function to update. """ - + + # Compute the mean mixing ratio + # How do I do this? + self.mX_in.assign(x_in_mixed[self.mX_idx]) + + # Project this into the lowest order space: + self.compute_mean.project() + + self.x_in.subfunctions[self.mean_idx].assign(self.mean_out) pass \ No newline at end of file diff --git a/gusto/timestepping/split_timestepper.py b/gusto/timestepping/split_timestepper.py index 5d988643..98b9cb84 100644 --- a/gusto/timestepping/split_timestepper.py +++ b/gusto/timestepping/split_timestepper.py @@ -300,11 +300,11 @@ def setup_scheme(self): self.setup_equation(self.equation) # If there is an augmentation, set up these transport terms + # Or, perhaps set up the whole residual now ... ? if self.scheme.augmentation is not None: if self.scheme.augmentation.name == 'mean_mixing_ratio': - self.scheme.augmentation.setup_transport(self.spatial_methods) + self.scheme.augmentation.setup_transport(self.spatial_methods, self.equation) print('Setting up augmented equation') - self.setup_equation(self.scheme.augmentation) # Go through and label all non-physics terms with a "dynamics" label