Skip to content

Commit

Permalink
more augmentation changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Jan 3, 2025
1 parent e9b05a2 commit 9adf64e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 51 deletions.
77 changes: 26 additions & 51 deletions gusto/spatial_methods/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
time_derivative, transport, transporting_velocity, TransportEquationType,
logger, prognostic, mass_weighted
)
import copy


class Augmentation(object, metaclass=ABCMeta):
Expand Down Expand Up @@ -76,6 +77,8 @@ def __init__(
self, domain, eqns, transpose_commutator=True, supg=False
):

self.name = 'vorticity'

V_vel = domain.spaces('HDiv')
V_vort = domain.spaces('H1')

Expand Down Expand Up @@ -261,6 +264,11 @@ def __init__(
self, domain, eqns, mX_name
):

self.name = 'mean_mixing_ratio'
self.mX_name = mX_name
self.mean_name = 'mean_'+mX_name
print(self.mean_name)

self.eqns = eqns
exist_spaces = eqns.spaces

Expand All @@ -275,7 +283,7 @@ def __init__(

# Set up the scheme for the mean mixing ratio

mean_mX = Function(DG0, name='mean_mX')
mean_mX = Function(DG0, name=self.mean_name)
mean_space = DG0
exist_spaces.append(mean_space)

Expand All @@ -300,6 +308,7 @@ def __init__(
lambda t: t.get(prognostic) == mX_name,
map_if_false=drop
)
mean_residual = prognostic.update_value(mean_residual, self.mean_name)

# Replace trial functions with those in the new mixed function space
for term in eqns.residual:
Expand Down Expand Up @@ -362,58 +371,24 @@ def __init__(
print('\n full residual')
print(self.residual.form)



def setup_residual(self):
# Update the residual
# Now, extract the existing residual:
old_residual = eqns.residual

print(old_residual.form)

# Extract terms relating to the mixing ratio of interest
mX_residual = old_residual.label_map(
lambda t: t.get(prognostic) == mX_name,
map_if_false=drop
)

print(mX_residual.form)

# Replace trial and test functions with the new mixed
# function space
# Does this work is the subject is mass-weighted???
for idx in range(self.idx_orig):
field = eqns.field_names[idx]
# Seperate logic if mass-weighted or not?
print(idx)
print(field)

prog = split(self.X)[idx]

#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_subject(self.X[idx], old_idx=idx)
#)
#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_test_function(self.tests[idx], old_idx=idx)
#)

# Define the mean mixing ratio residual
mean_residual = mX_residual.label_map(
lambda t: t.has_label(mass_weighted),
#map_if_true=replace_subject(mean_mass, old_idx=mX_idx),
map_if_false=replace_subject(self.X[mean_idx], old_idx = mX_idx)
)
print('yoyoyoy')

mean_residual = mean_residual.label_map(
lambda t: t.has_label(mass_weighted),
map_if_false=replace_test_function(self.tests[mean_idx], old_idx=mX_idx)
)

# Form the new residual
self.residual = old_residual + mean_residual
for term in self.residual:
print(term.get(prognostic))

def setup_transport(self, spatial_methods):
# Copy spatial method for the mixing ratio onto the
# mean mixing ratio.
mX_spatial_method = next(method for method in spatial_methods if method.variable == self.mX_name)

mean_spatial_method = copy.copy(mX_spatial_method)
mean_spatial_method.variable = self.mean_name
self.spatial_methods = copy.copy(spatial_methods)
self.spatial_methods.append(mean_spatial_method)
for method in self.spatial_methods:
print(method.variable)
#method.equation.residual = self.residual
print(len(method.equation.residual))

def pre_apply(self, x_in):
"""
Expand Down
10 changes: 10 additions & 0 deletions gusto/timestepping/split_timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,17 @@ def transporting_velocity(self):
return self.fields('u')

def setup_scheme(self):
print('Setting up base equation')
self.setup_equation(self.equation)

# If there is an augmentation, set up these transport terms
if self.scheme.augmentation is not None:
if self.scheme.augmentation.name == 'mean_mixing_ratio':
self.scheme.augmentation.setup_transport(self.spatial_methods)
print('Setting up augmented equation')
self.setup_equation(self.scheme.augmentation)


# Go through and label all non-physics terms with a "dynamics" label
dynamics = Label('dynamics')
self.equation.label_terms(lambda t: not any(t.has_label(time_derivative, physics_label)), dynamics)
Expand Down

0 comments on commit 9adf64e

Please sign in to comment.