Skip to content

Commit

Permalink
first version of mean mixing ratio working
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Jan 9, 2025
1 parent 97b8917 commit bb5cbe1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 123 deletions.
153 changes: 35 additions & 118 deletions gusto/spatial_methods/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class MeanMixingRatio(Augmentation):
Args:
domain (:class:`Domain`): The domain object.
eqns (:class:`PrognosticEquationSet`): The overarching equation set.
mixing_ratio (:class: list): List of mixing ratios that
mixing_ratio (:class: list): A list of mixing ratios that

Check failure on line 258 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

W291

gusto/spatial_methods/augmentation.py:258:66: W291 trailing whitespace
are to have augmented mean mixing ratio fields.
OR, keep as a single mixing ratio, but define

Check failure on line 260 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

W291

gusto/spatial_methods/augmentation.py:260:54: W291 trailing whitespace
multiple augmentations?
Expand All @@ -267,15 +267,16 @@ def __init__(
):

self.name = 'mean_mixing_ratio'

Check failure on line 270 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

W293

gusto/spatial_methods/augmentation.py:270:1: W293 blank line contains whitespace
self.mX_name = mX_name
self.mean_name = 'mean_'+mX_name
print(self.mean_name)

self.eqn_orig = eqns
self.domain=domain

Check failure on line 275 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

E225

gusto/spatial_methods/augmentation.py:275:20: E225 missing whitespace around operator
exist_spaces = eqns.spaces

self.idx_orig = len(exist_spaces)

# Change this to adjust for a list
mean_idx = self.idx_orig
self.mean_idx = mean_idx

Expand All @@ -285,12 +286,14 @@ def __init__(

# Define the mean mixing ratio on the DG0 space
DG0 = FunctionSpace(domain.mesh, "DG", 0)
self.DG0 = Function(DG0)
mX_space = eqns.spaces[mean_idx]
mX_space = eqns.spaces[mX_idx]

# But, what if the mixing ratios are in

Check failure on line 291 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

W291

gusto/spatial_methods/augmentation.py:291:48: W291 trailing whitespace
# different function spaces???

self.mX_in = Function(mX_space)
self.mean_in = Function(mean_space)
self.compute_mean = Projector(mX_space, DG0)
self.mean = Function(DG0)
self.compute_mean = Projector(self.mX_in, self.mean)

# Set up the scheme for the mean mixing ratio

Check failure on line 298 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

E114

gusto/spatial_methods/augmentation.py:298:10: E114 indentation is not a multiple of 4 (comment)

Expand All @@ -302,101 +305,13 @@ def __init__(
self.X = Function(self.fs)
self.tests = TestFunctions(self.fs)

print(self.X)

self.bcs = None

self.x_in = Function(self.fs)
self.x_out = Function(self.fs)

# Compute the new mean mass weighted term,
# IF this is conservatively transported.

#old_residual = eqns.residual

#mean_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == mX_name,
# map_if_false=drop
#)
#mean_residual = prognostic.update_value(mean_residual, self.mean_name)

# Replace trial functions with those in the new mixed function space
#for term in eqns.residual:
# print('\n')
# print(term.form)

# If I do this later on with the transport terms, maybe I don't need to do this here?

#for idx in range(self.idx_orig):
# field = eqns.field_names[idx]
# Seperate logic if mass-weighted or not?
# print(idx)
# print(field)

# prog = split(self.X)[idx]

# print('\n residual term before change')
# print(old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_false=drop
# ).form)

# old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx)
# )
# old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx)
# )

# print('\n residual term after change')
# print(old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_false=drop
# ).form)

#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_subject(self.X[idx], old_idx=idx)
#)
#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_test_function(self.tests[idx], old_idx=idx)
#)

# Define the mean mixing ratio residual
#mean_residual = mX_residual.label_map(
# lambda t: t.has_label(mass_weighted),
#map_if_true=replace_subject(mean_mass, old_idx=mX_idx),
# map_if_false=replace_subject(self.X[mean_idx], old_idx = mX_idx)
#)

#print('\n mean mX residual before change')
#print(mean_residual.form)
#mean_residual = mean_residual.label_map(
# all_terms,
# replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx)
#)
#mean_residual = mean_residual.label_map(
# all_terms,
# replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx)
#)
#print('\n mean mX residual after change')
#print(mean_residual.form)

# Form the new residual
#self.residual = old_residual + mean_residual

#print('\n full residual')
#print(self.residual.form)

#print('yoyoyoy')

#for term in self.residual:
# print(term.get(prognostic))

def setup_transport(self, spatial_methods, equation):
def setup_residual(self, spatial_methods, equation):
# Copy spatial method for the mixing ratio onto the
# mean mixing ratio.

Expand All @@ -416,16 +331,16 @@ def setup_transport(self, spatial_methods, equation):
for idx in range(self.idx_orig):
field = self.eqn_orig.field_names[idx]
# Seperate logic if mass-weighted or not?
print('\n', idx)
print(field)
#print('\n', idx)
#print(field)

prog = split(self.X)[idx]

print('\n residual term before change')
print(old_residual.label_map(
lambda t: t.get(prognostic) == field,
map_if_false=drop
).form)
#print('\n residual term before change')
#print(old_residual.label_map(
# lambda t: t.get(prognostic) == field,
# map_if_false=drop
#).form)

old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field,
Expand All @@ -435,17 +350,17 @@ def setup_transport(self, spatial_methods, equation):
lambda t: t.get(prognostic) == field,
map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx)
)
print('\n residual term after change')
print(old_residual.label_map(
lambda t: t.get(prognostic) == field,
map_if_false=drop
).form)
#print('\n residual term after change')
#print(old_residual.label_map(
# lambda t: t.get(prognostic) == field,
# map_if_false=drop
#).form)

print('\n now setting up mean mixing ratio residual terms')
#print('\n now setting up mean mixing ratio residual terms')


print('\n mean mX residual after change')
print(mean_residual.form)
#print('\n mean mX residual after change')
#print(mean_residual.form)
mean_residual = mean_residual.label_map(
all_terms,
replace_subject(self.X, old_idx=self.mX_idx, new_idx=self.mean_idx)
Expand All @@ -454,15 +369,17 @@ def setup_transport(self, spatial_methods, equation):
all_terms,
replace_test_function(self.tests, old_idx=self.mX_idx, new_idx=self.mean_idx)
)
print('\n mean mX residual after change')
print(mean_residual.form)
#print('\n mean mX residual after change')
#print(mean_residual.form)

# Form the new residual
self.residual = old_residual + mean_residual
residual = old_residual + mean_residual
self.residual = subject(residual, self.X)

#Check these two forms
print('\n Original equation with residual of length, ', len(equation.residual))
print('\n Augmented equation with residual of length, ', len(self.residual))
#print('\n Original equation with residual of length, ', len(equation.residual))
#print('\n Augmented equation with residual of length, ', len(self.residual))




Expand Down Expand Up @@ -529,10 +446,10 @@ def update(self, x_in_mixed):

# Compute the mean mixing ratio
# How do I do this?
self.mX_in.assign(x_in_mixed[self.mX_idx])
self.mX_in.assign(x_in_mixed.subfunctions[self.mX_idx])

# Project this into the lowest order space:
self.compute_mean.project()

self.x_in.subfunctions[self.mean_idx].assign(self.mean_out)
self.x_in.subfunctions[self.mean_idx].assign(self.mean)
pass
19 changes: 14 additions & 5 deletions gusto/timestepping/split_timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from firedrake.fml import Label, drop
from pyop2.profiling import timed_stage
from gusto.core import TimeLevelFields, StateFields
from gusto.core.labels import time_derivative, physics_label
from gusto.core.labels import time_derivative, physics_label, dynamics_label
from gusto.time_discretisation.time_discretisation import ExplicitTimeDiscretisation
from gusto.timestepping.timestepper import BaseTimestepper, Timestepper
from numpy import ones
Expand Down Expand Up @@ -299,13 +299,21 @@ def setup_scheme(self):
print('Setting up base equation')
self.setup_equation(self.equation)

# If there is an augmentation, set up these transport terms
# Or, perhaps set up the whole residual now ... ?
# If there is an augmentation, set up the residual now
if self.scheme.augmentation is not None:
if self.scheme.augmentation.name == 'mean_mixing_ratio':
self.scheme.augmentation.setup_transport(self.spatial_methods, self.equation)
self.scheme.augmentation.setup_residual(self.spatial_methods, self.equation)
print('Setting up augmented equation')

# Go through and label all non-physics terms with a "dynamics" label
dynamics = Label('dynamics')
self.scheme.augmentation.residual = self.scheme.augmentation.residual.label_map(
lambda t: not any(t.has_label(time_derivative, physics_label)),
map_if_true=lambda t: dynamics(t)
)
print(len(self.scheme.augmentation.residual.label_map(
lambda t: t.has_label(dynamics),
map_if_false=drop
)))

# Go through and label all non-physics terms with a "dynamics" label
dynamics = Label('dynamics')
Expand All @@ -316,6 +324,7 @@ def setup_scheme(self):
if self.io.output.log_courant:
self.scheme.courant_max = self.io.courant_max


def setup_prescribed_expr(self, expr_func):
"""
Sets up the prescribed transporting velocity, through a python function
Expand Down

0 comments on commit bb5cbe1

Please sign in to comment.