Skip to content

Commit

Permalink
set the mean mixing ratio augmentation residual
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Jan 2, 2025
1 parent fa1062d commit e9b05a2
Showing 1 changed file with 93 additions and 14 deletions.
107 changes: 93 additions & 14 deletions gusto/spatial_methods/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
Constant, sqrt, cross, curl, FunctionSpace, assemble, DirichletBC
)
from firedrake.fml import (
subject, all_terms, replace_subject, keep, replace_test_function
subject, all_terms, replace_subject, keep, replace_test_function,
replace_trial_function, drop
)
from gusto import (
time_derivative, transport, transporting_velocity, TransportEquationType,
Expand Down Expand Up @@ -274,32 +275,110 @@ def __init__(

# Set up the scheme for the mean mixing ratio

mean_mX = Function(DG0)
mean_mX = Function(DG0, name='mean_mX')
mean_space = DG0
exist_spaces.append(mean_space)

self.fs = MixedFunctionSpace(exist_spaces)
self.X = Function(self.fs)
self.tests = TestFunctions(self.fs)

print(self.X)

self.bcs = []

self.x_in = Function(self.fs)
self.x_out = Function(self.fs)

# Compute the new mean mass weighted term,
# IF this is conservatively transported.
mX_idx = eqns.field_names.index(mX_name)

#mean_mass =
old_residual = eqns.residual

mean_residual = old_residual.label_map(
lambda t: t.get(prognostic) == mX_name,
map_if_false=drop
)

# Replace trial functions with those in the new mixed function space
for term in eqns.residual:
print('\n')
print(term.form)

for idx in range(self.idx_orig):
field = eqns.field_names[idx]
# Seperate logic if mass-weighted or not?
print(idx)
print(field)

prog = split(self.X)[idx]

print('\n residual before change')
print(old_residual.form)
old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
map_if_true=replace_subject(self.X, old_idx=idx, new_idx = idx)
)
old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
map_if_true=replace_test_function(self.tests, old_idx=idx, new_idx=idx)
)
print('\n residual after change')
print(old_residual.form)

#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_subject(self.X[idx], old_idx=idx)
#)
#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_test_function(self.tests[idx], old_idx=idx)
#)

# Define the mean mixing ratio residual
#mean_residual = mX_residual.label_map(
# lambda t: t.has_label(mass_weighted),
#map_if_true=replace_subject(mean_mass, old_idx=mX_idx),
# map_if_false=replace_subject(self.X[mean_idx], old_idx = mX_idx)
#)

print('\n mean mX residual before change')
print(mean_residual.form)
mean_residual = mean_residual.label_map(
all_terms,
replace_subject(self.X, old_idx=mX_idx, new_idx = mean_idx)
)
mean_residual = mean_residual.label_map(
all_terms,
replace_test_function(self.tests, old_idx=mX_idx, new_idx = mean_idx)
)
print('\n mean mX residual after change')
print(mean_residual.form)

# Form the new residual
self.residual = old_residual + mean_residual

print('\n full residual')
print(self.residual.form)



def setup_residual(self):
# Update the residual
# Now, extract the existing residual:
old_residual = eqns.residual

print(old_residual.form)

# Extract terms relating to the mixing ratio of interest
mX_residual = old_residual.label_map(
lambda t: t.get(prognostic) == mX_name,
map_if_true=keep
map_if_false=drop
)

print(mX_residual.form)

# Replace trial and test functions with the new mixed
# function space
# Does this work is the subject is mass-weighted???
Expand All @@ -309,15 +388,16 @@ def __init__(
print(idx)
print(field)

prog = split(self.X)[idx]

old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
map_if_true=replace_subject(self.X[idx], old_idx=idx)
)
old_residual = old_residual.label_map(
lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
map_if_true=replace_test_function(self.tests[idx], old_idx=idx)
)
#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_subject(self.X[idx], old_idx=idx)
#)
#old_residual = old_residual.label_map(
# lambda t: t.get(prognostic) == field and not t.has_label(mass_weighted),
# map_if_true=replace_test_function(self.tests[idx], old_idx=idx)
#)

# Define the mean mixing ratio residual
mean_residual = mX_residual.label_map(
Expand All @@ -331,10 +411,9 @@ def __init__(
map_if_false=replace_test_function(self.tests[mean_idx], old_idx=mX_idx)
)

# Form the new residual
self.residual = old_residual + mean_residual

self.bcs = []


def pre_apply(self, x_in):
"""
Expand Down

0 comments on commit e9b05a2

Please sign in to comment.