From e9b05a285f58ea1d8885083518a41100d384c822 Mon Sep 17 00:00:00 2001 From: Tim Andrews Date: Thu, 2 Jan 2025 16:25:44 +0000 Subject: [PATCH] set the mean mixing ratio augmentation residual --- gusto/spatial_methods/augmentation.py | 107 ++++++++++++++++++++++---- 1 file changed, 93 insertions(+), 14 deletions(-) diff --git a/gusto/spatial_methods/augmentation.py b/gusto/spatial_methods/augmentation.py index 5a5807d7..9734bf21 100644 --- a/gusto/spatial_methods/augmentation.py +++ b/gusto/spatial_methods/augmentation.py @@ -12,7 +12,8 @@ Constant, sqrt, cross, curl, FunctionSpace, assemble, DirichletBC ) from firedrake.fml import ( - subject, all_terms, replace_subject, keep, replace_test_function + subject, all_terms, replace_subject, keep, replace_test_function, + replace_trial_function, drop ) from gusto import ( time_derivative, transport, transporting_velocity, TransportEquationType, @@ -274,7 +275,7 @@ def __init__( # Set up the scheme for the mean mixing ratio - mean_mX = Function(DG0) + mean_mX = Function(DG0, name='mean_mX') mean_space = DG0 exist_spaces.append(mean_space) @@ -282,6 +283,10 @@ def __init__( self.X = Function(self.fs) self.tests = TestFunctions(self.fs) + print(self.X) + + self.bcs = [] + self.x_in = Function(self.fs) self.x_out = Function(self.fs) @@ -289,17 +294,91 @@ def __init__( # IF this is conservatively transported. mX_idx = eqns.field_names.index(mX_name) - #mean_mass = + old_residual = eqns.residual + + mean_residual = old_residual.label_map( + lambda t: t.get(prognostic) == mX_name, + map_if_false=drop + ) + + # Replace trial functions with those in the new mixed function space + for term in eqns.residual: + print('\n') + print(term.form) + for idx in range(self.idx_orig): + field = eqns.field_names[idx] + # Seperate logic if mass-weighted or not? + print(idx) + print(field) + + prog = split(self.X)[idx] + + print('\n residual before change') + print(old_residual.form) + old_residual = old_residual.label_map( + lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx) + ) + old_residual = old_residual.label_map( + lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx) + ) + print('\n residual after change') + print(old_residual.form) + + #old_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_true=replace_subject(self.X[idx], old_idx=idx) + #) + #old_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_true=replace_test_function(self.tests[idx], old_idx=idx) + #) + + # Define the mean mixing ratio residual + #mean_residual = mX_residual.label_map( + # lambda t: t.has_label(mass_weighted), + #map_if_true=replace_subject(mean_mass, old_idx=mX_idx), + # map_if_false=replace_subject(self.X[mean_idx], old_idx = mX_idx) + #) + + print('\n mean mX residual before change') + print(mean_residual.form) + mean_residual = mean_residual.label_map( + all_terms, + replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx) + ) + mean_residual = mean_residual.label_map( + all_terms, + replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx) + ) + print('\n mean mX residual after change') + print(mean_residual.form) + + # Form the new residual + self.residual = old_residual + mean_residual + + print('\n full residual') + print(self.residual.form) + + + + def setup_residual(self): + # Update the residual # Now, extract the existing residual: old_residual = eqns.residual + print(old_residual.form) + # Extract terms relating to the mixing ratio of interest mX_residual = old_residual.label_map( lambda t: t.get(prognostic) == mX_name, - map_if_true=keep + map_if_false=drop ) + print(mX_residual.form) + # Replace trial and test functions with the new mixed # function space # Does this work is the subject is mass-weighted??? @@ -309,15 +388,16 @@ def __init__( print(idx) print(field) + prog = split(self.X)[idx] - old_residual = old_residual.label_map( - lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), - map_if_true=replace_subject(self.X[idx], old_idx=idx) - ) - old_residual = old_residual.label_map( - lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), - map_if_true=replace_test_function(self.tests[idx], old_idx=idx) - ) + #old_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_true=replace_subject(self.X[idx], old_idx=idx) + #) + #old_residual = old_residual.label_map( + # lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted), + # map_if_true=replace_test_function(self.tests[idx], old_idx=idx) + #) # Define the mean mixing ratio residual mean_residual = mX_residual.label_map( @@ -331,10 +411,9 @@ def __init__( map_if_false=replace_test_function(self.tests[mean_idx], old_idx=mX_idx) ) + # Form the new residual self.residual = old_residual + mean_residual - self.bcs = [] - def pre_apply(self, x_in): """