Skip to content

Commit

Permalink
implementation for a single mixing ratio
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Jan 7, 2025
1 parent 9adf64e commit 97b8917
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 48 deletions.
204 changes: 158 additions & 46 deletions gusto/spatial_methods/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
LinearVariationalProblem, LinearVariationalSolver, lhs, rhs, dot,
ds_b, ds_v, ds_t, ds, FacetNormal, TestFunction, TrialFunction,
transpose, nabla_grad, outer, dS, dS_h, dS_v, sign, jump, div,
Constant, sqrt, cross, curl, FunctionSpace, assemble, DirichletBC
Constant, sqrt, cross, curl, FunctionSpace, assemble, DirichletBC,
Projector
)
from firedrake.fml import (
subject, all_terms, replace_subject, keep, replace_test_function,
Expand All @@ -19,6 +20,7 @@
time_derivative, transport, transporting_velocity, TransportEquationType,
logger, prognostic, mass_weighted
)
from gusto.spatial_methods.transport_methods import DGUpwind
import copy


Expand Down Expand Up @@ -269,17 +271,26 @@ def __init__(
self.mean_name = 'mean_'+mX_name
print(self.mean_name)

self.eqns = eqns
self.eqn_orig = eqns
self.domain=domain
exist_spaces = eqns.spaces

self.idx_orig = len(exist_spaces)
mean_idx = self.idx_orig
self.mean_idx = mean_idx

# Extract the mixing ratio in question:
mX_idx = eqns.field_names.index(mX_name)
self.mX_idx = mX_idx

# Define the mean mixing ratio on the DG0 space
DG0 = FunctionSpace(domain.mesh, "DG", 0)
self.DG0 = Function(DG0)
mX_space = eqns.spaces[mean_idx]

self.mX_in = Function(mX_space)
self.mean_in = Function(mean_space)
self.compute_mean = Projector(mX_space, DG0)

# Set up the scheme for the mean mixing ratio

Expand All @@ -293,48 +304,57 @@ def __init__(

print(self.X)

self.bcs = []
self.bcs = None

self.x_in = Function(self.fs)
self.x_out = Function(self.fs)

# Compute the new mean mass weighted term,
# IF this is conservatively transported.
mX_idx = eqns.field_names.index(mX_name)

old_residual = eqns.residual
#old_residual = eqns.residual

mean_residual = old_residual.label_map(
lambda t: t.get(prognostic) == mX_name,
map_if_false=drop
)
mean_residual = prognostic.update_value(mean_residual, self.mean_name)
#mean_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == mX_name,
# map_if_false=drop
#)
#mean_residual = prognostic.update_value(mean_residual, self.mean_name)

# Replace trial functions with those in the new mixed function space
for term in eqns.residual:
print('\n')
print(term.form)
#for term in eqns.residual:
# print('\n')
# print(term.form)

for idx in range(self.idx_orig):
field = eqns.field_names[idx]
# Seperate logic if mass-weighted or not?
print(idx)
print(field)

prog = split(self.X)[idx]
# If I do this later on with the transport terms, maybe I don't need to do this here?

print('\n residual before change')
print(old_residual.form)
old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx)
)
old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx)
)
print('\n residual after change')
print(old_residual.form)
#for idx in range(self.idx_orig):
# field = eqns.field_names[idx]
# Seperate logic if mass-weighted or not?
# print(idx)
# print(field)

# prog = split(self.X)[idx]

# print('\n residual term before change')
# print(old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_false=drop
# ).form)

# old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx)
# )
# old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx)
# )

# print('\n residual term after change')
# print(old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_false=drop
# ).form)

#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
Expand All @@ -352,33 +372,101 @@ def __init__(
# map_if_false=replace_subject(self.X[mean_idx], old_idx = mX_idx)
#)

print('\n mean mX residual before change')
#print('\n mean mX residual before change')
#print(mean_residual.form)
#mean_residual = mean_residual.label_map(
# all_terms,
# replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx)
#)
#mean_residual = mean_residual.label_map(
# all_terms,
# replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx)
#)
#print('\n mean mX residual after change')
#print(mean_residual.form)

# Form the new residual
#self.residual = old_residual + mean_residual

#print('\n full residual')
#print(self.residual.form)

#print('yoyoyoy')

#for term in self.residual:
# print(term.get(prognostic))

def setup_transport(self, spatial_methods, equation):
# Copy spatial method for the mixing ratio onto the
# mean mixing ratio.

old_residual = equation.residual

# Copy the mean mixing ratio residual terms:
mean_residual = old_residual.label_map(
lambda t: t.get(prognostic) == self.mX_name,
map_if_false=drop
)
mean_residual = prognostic.update_value(mean_residual, self.mean_name)

print('\n in setup_transport')

# Replace the tests and trial functions for all terms
# of the fields in the original equation
for idx in range(self.idx_orig):
field = self.eqn_orig.field_names[idx]
# Seperate logic if mass-weighted or not?
print('\n', idx)
print(field)

prog = split(self.X)[idx]

print('\n residual term before change')
print(old_residual.label_map(
lambda t: t.get(prognostic) == field,
map_if_false=drop
).form)

old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field,
map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx)
)
old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field,
map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx)
)
print('\n residual term after change')
print(old_residual.label_map(
lambda t: t.get(prognostic) == field,
map_if_false=drop
).form)

print('\n now setting up mean mixing ratio residual terms')


print('\n mean mX residual after change')
print(mean_residual.form)
mean_residual = mean_residual.label_map(
all_terms,
replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx)
replace_subject(self.X, old_idx=self.mX_idx, new_idx=self.mean_idx)
)
mean_residual = mean_residual.label_map(
all_terms,
replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx)
replace_test_function(self.tests, old_idx=self.mX_idx, new_idx=self.mean_idx)
)
print('\n mean mX residual after change')
print(mean_residual.form)

# Form the new residual
self.residual = old_residual + mean_residual

print('\n full residual')
print(self.residual.form)
#Check these two forms
print('\n Original equation with residual of length, ', len(equation.residual))
print('\n Augmented equation with residual of length, ', len(self.residual))

print('yoyoyoy')

for term in self.residual:
print(term.get(prognostic))

def setup_transport(self, spatial_methods):
# Copy spatial method for the mixing ratio onto the
# mean mixing ratio.
def setup_transport_old(self, spatial_methods):
mX_spatial_method = next(method for method in spatial_methods if method.variable == self.mX_name)

mean_spatial_method = copy.copy(mX_spatial_method)
Expand All @@ -387,9 +475,19 @@ def setup_transport(self, spatial_methods):
self.spatial_methods.append(mean_spatial_method)
for method in self.spatial_methods:
print(method.variable)
#method.equation.residual = self.residual
method.equation.residual = self.residual
print(method.form.form)
print(len(method.equation.residual))

# Alternatively, redo all the spatial methods
# using the new mixed function space.
# So, want to make a new list of spatial methods
new_spatial_methods = []
for method in self.spatial_methods:
# Determine the tye of transport method:
new_method = DGUpwind(self, method.variable)
new_spatial_methods.append(new_method)

def pre_apply(self, x_in):
"""
Sets the original fields, i.e. not the mean mixing ratios
Expand All @@ -403,6 +501,12 @@ def pre_apply(self, x_in):
for idx in range(self.idx_orig):
self.x_in.subfunctions[idx].assign(x_in.subfunctions[idx])

# Set the mean mixing ratio to be zero, just because
#DG0 = FunctionSpace(self.domain.mesh, "DG", 0)
#mean_mX = Function(DG0, name=self.mean_name)

#self.x_in.subfunctions[self.mean_idx].assign(mean_mX)


def post_apply(self, x_out):
"""
Expand All @@ -412,7 +516,7 @@ def post_apply(self, x_out):
x_out (:class:`Function`): The output fields
"""

for idx, field in enumerate(self.eqn.field_names):
for idx in range(self.idx_orig):
x_out.subfunctions[idx].assign(self.x_out.subfunctions[idx])

def update(self, x_in_mixed):
Expand All @@ -422,5 +526,13 @@ def update(self, x_in_mixed):
Args:
x_in_mixed (:class:`Function`): The mixed function to update.
"""


# Compute the mean mixing ratio
# How do I do this?
self.mX_in.assign(x_in_mixed[self.mX_idx])

# Project this into the lowest order space:
self.compute_mean.project()

self.x_in.subfunctions[self.mean_idx].assign(self.mean_out)
pass
4 changes: 2 additions & 2 deletions gusto/timestepping/split_timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,11 @@ def setup_scheme(self):
self.setup_equation(self.equation)

# If there is an augmentation, set up these transport terms
# Or, perhaps set up the whole residual now ... ?
if self.scheme.augmentation is not None:
if self.scheme.augmentation.name == 'mean_mixing_ratio':
self.scheme.augmentation.setup_transport(self.spatial_methods)
self.scheme.augmentation.setup_transport(self.spatial_methods, self.equation)
print('Setting up augmented equation')
self.setup_equation(self.scheme.augmentation)


# Go through and label all non-physics terms with a "dynamics" label
Expand Down

0 comments on commit 97b8917

Please sign in to comment.