Skip to content

Commit

Permalink
more changes to mixed options with replacing test functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Jan 16, 2024
1 parent 5f05367 commit 5c13db1
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 106 deletions.
93 changes: 40 additions & 53 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from gusto.configuration import EmbeddedDGOptions, RecoveryOptions
from gusto.labels import (time_derivative, prognostic, physics_label,
implicit, explicit)
implicit, explicit, transport)
from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual
from gusto.wrappers import *

Expand All @@ -43,15 +43,6 @@ def new_apply(self, x_out, x_in):

self.wrapper.pre_apply(x_in)
original_apply(self, self.wrapper.x_out, self.wrapper.x_in)

#if type(self.wrapper) == MixedOptions:
#print('x_out', x_out)
#print('self.wrapper.x_out', self.wrapper.x_out)
#original_apply(self, x_out, x_in)
# original_apply(self, self.wrapper.x_out, self.wrapper.x_in)
#else:
# original_apply(self, self.wrapper.x_out, self.wrapper.x_in)

self.wrapper.post_apply(x_out)

return new_apply(self, x_out, x_in)
Expand Down Expand Up @@ -97,12 +88,8 @@ def __init__(self, domain, field_name=None, solver_parameters=None,

if options is not None:
if type(options) == MixedOptions:
#print('jahjah')
#print(options)
self.wrapper = options
#self.subwrappers = {}
# Or do I need to initialise everything here
# like is done for a single wrapper?

Check failure on line 92 in gusto/time_discretisation.py

View workflow job for this annotation

GitHub Actions / Run linter

W293

gusto/time_discretisation.py:92:1: W293 blank line contains whitespace
for field, suboption in self.wrapper.suboptions.items():
print(field)
print(suboption)
Expand All @@ -116,10 +103,7 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')

Check failure on line 105 in gusto/time_discretisation.py

View workflow job for this annotation

GitHub Actions / Run linter

E122

gusto/time_discretisation.py:105:25: E122 continuation line missing indentation or outdented
#print(self.wrapper)
#print(self.wrapper.suboptions)
#print(self.wrapper.subwrappers)
#self.wrapper_name = 'mixed'

Check failure on line 106 in gusto/time_discretisation.py

View workflow job for this annotation

GitHub Actions / Run linter

W293

gusto/time_discretisation.py:106:1: W293 blank line contains whitespace
else:
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
Expand Down Expand Up @@ -194,58 +178,61 @@ def setup(self, equation, apply_bcs=True, *active_labels):

if self.wrapper is not None:
if type(self.wrapper) == MixedOptions:
#if self.wrapper_name == 'mixed':
# Subwrappers are defined.
# Set these up with ?

# Give more than one fs?
#fields = []

print(self.wrapper.wrapper_spaces)

#print(self.wrapper.wrapper_spaces)

for field, subwrapper in self.wrapper.subwrappers.items():
field_idx = equation.field_names.index(field)
self.wrapper.subwrappers[field].idx = field_idx
self.wrapper.subwrappers[field].mixed_options = True

subwrapper.idx = field_idx
subwrapper.mixed_options = True

# Store the original space of the tracer
self.wrapper.subwrappers[field].tracer_fs = self.equation.spaces[equation.field_names.index(field)]
subwrapper.tracer_fs = self.equation.spaces[field_idx]

self.wrapper.subwrappers[field].setup()
subwrapper.setup()

# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = self.wrapper.subwrappers[field].function_space
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space

# Store test space?
#self.wrapper.test_spaces[field_idx] = subwrapper.function_space

# Replace the test function space
# This currently won't work: supg test
# is a function, not a space
if self.wrapper.suboptions[field].name == "supg":
self.wrapper.test_spaces[field_idx] = self.wrapper.subwrappers[field].test
new_test = subwrapper.test
else:
self.wrapper.test_spaces[field_idx] = self.wrapper.subwrappers[field].test_space
new_test = TestFunction(subwrapper.test_space)

self.residual = self.residual.label_map(
lambda t: t.has_label(transport) and t.get(prognostic) == field,
map_if_true=replace_test_function(new_test, old_idx=field_idx))

self.residual = subwrapper.label_terms(self.residual)

# Currently can only use one set of solver parameters ...
if self.solver_parameters is None:
self.solver_parameters = subwrapper.solver_parameters

self.wrapper.setup()

# Check if mixed function spcae has changed:
if self.wrapper.function_space == equation.function_space:
print('same')
else:
print('different')

self.fs = self.wrapper.function_space

# Or replace test functions here??
new_test = TestFunction(self.wrapper.test_space)

self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))

# Only call this if with SUPG, so should put this into the
# previous section
# self.residual = self.wrapper.label_terms(self.residual)

#new_test_mixed = TestFunction(self.fs)

#for field, subwrapper in self.wrapper.subwrappers.items():
# field_idx = equation.field_names.index(field)

# if self.wrapper.suboptions[field].name == "supg":
# new_test_mixed[field_idx] = subwrapper.test

# self.residual = self.residual.label_map(
# lambda t: t.has_label(transport) and t.get(prognostic) == field,
# map_if_true=replace_test_function(new_test_mixed[field_idx], old_idx=field_idx))

# self.residual = subwrapper.label_terms(self.residual)



else:
self.wrapper.setup()
Expand Down
105 changes: 57 additions & 48 deletions gusto/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,39 +414,21 @@ def __init__(self, equation, suboptions):
Raises:
ValueError: If an option is defined for a field that is not in the prognostic variable set
"""
#print(equation.function_space)
#print(equation.active_tracers)
#print(equation.active_tracers[0])
#print(equation.active_tracers[1])
#print(equation.space_names)
#print(equation.spaces)

#self.wrapper_spaces = equation.space_names
self.wrapper_spaces = equation.spaces
self.test_spaces = equation.spaces
print(len(equation.spaces))

#self.x_in = Function(equation.function_space)
#self.x_out = Function(equation.function_space)

print('Initialising MixedOptions')

self.field_names = equation.field_names
self.suboptions = suboptions

#print(suboptions.items())

self.subwrappers = {}

for field, suboption in suboptions.items():
#print(field)
#print(suboption)

# Check that the field is in the prognostic variable set:
if field not in equation.field_names:
raise ValueError(f"The limiter defined for {field} is for a field that does not exist in the equation set")
else:
#else:
# check that a valid wrapper has been given
wrapper_name = suboption.name

# wrapper_name = suboption.name
#print(wrapper_name)

# Extract the space?
Expand All @@ -469,46 +451,73 @@ def __init__(self, equation, suboptions):
# self.subwrapper.x_in = Function(self.suboptions[field].fs)

def setup(self):
# This is done in the suboption wrappers themselves
# Or, determine the new mixed function space

#Loop over all active variables?
# Compute the new mixed function space
self.function_space = MixedFunctionSpace(self.wrapper_spaces)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)
self.test_space = MixedFunctionSpace(self.test_spaces)

def pre_apply(self, x_in):
"""
Perform the pre-applications from all subwrappers
Perform the pre-applications for all fields
with an associated subwrapper.
"""
#self.x_in.assign(x_in)
self.x_in = x_in

for _, subwrapper in self.subwrappers.items():
print(subwrapper)
print('pre')
field = x_in.subfunctions[subwrapper.idx]
subwrapper.pre_apply(field)
#x_in.subfunctions[subwrapper.idx] = subwrapper.x_in

x_in_sub = self.x_in.subfunctions[subwrapper.idx]
#x_in_sub.assign(subwrapper.x_in)
x_in_sub = subwrapper.x_in
#self.x_in.subfunctions[subwrapper.idx] = subwrapper.x_in
for field_name in self.field_names:
#print(field_name)
#print(self.subwrappers)

field_idx = self.field_names.index(field_name)
#print(field_idx)

field = x_in.subfunctions[field_idx]
x_in_sub = self.x_in.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
print(subwrapper)
print('pre')
#field = x_in.subfunctions[subwrapper.idx]
subwrapper.pre_apply(field)
#x_in.subfunctions[subwrapper.idx] = subwrapper.x_in

#x_in_sub = self.x_in.subfunctions[subwrapper.idx]
x_in_sub.assign(subwrapper.x_in)
#x_in_sub = subwrapper.x_in
#self.x_in.subfunctions[subwrapper.idx] = subwrapper.x_in
else:
x_in_sub.assign(field)


def post_apply(self, x_out):
"""
Perform the post-applications from all subwrappers
Perform the post-applications for all fields
with an associated subwrapper.
"""
x_out.assign(self.x_out)
#x_out.assign(self.x_out)

for _, subwrapper in self.subwrappers.items():
print(subwrapper)
print('post')
field = x_out.subfunctions[subwrapper.idx]
subwrapper.post_apply(field)
#for _, subwrapper in self.subwrappers.items():
# print(subwrapper)
# print('post')
# field = x_out.subfunctions[subwrapper.idx]
# subwrapper.post_apply(field)

x_out_sub = x_out.subfunctions[subwrapper.idx]
x_out_sub.assign(subwrapper.x_out)
# x_out_sub = x_out.subfunctions[subwrapper.idx]
# x_out_sub.assign(subwrapper.x_out)

for field_name in self.field_names:

field_idx = self.field_names.index(field_name)

field = self.x_out.subfunctions[field_idx]
x_out_sub = x_out.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
print(subwrapper)
print('post')
subwrapper.post_apply(field)
x_out_sub.assign(subwrapper.x_out)
else:
x_out_sub.assign(field)
10 changes: 5 additions & 5 deletions integration-tests/transport/test_mixed_fs_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,25 +64,23 @@ def setup_limiters(dirname, space_A, space_B):

# Tracer B spaces
if space_B == 'DG':
space_B_string = 'DG'
if degree == 0:
VB = domain.spaces('DG')
VCG1_B = FunctionSpace(mesh, 'CG', 1)
VDG1_B = domain.spaces('DG1_equispaced')
space_B_string = 'DG'
elif degree == 1:
VB = domain.spaces('DG')
space_B_string = 'DG'
else:
raise NotImplementedError
elif space_B == 'Vtheta':
space_B_string = 'theta'
if degree == 0:
VB = domain.spaces('theta')
VCG1_B = FunctionSpace(mesh, 'CG', 1)
VDG1_B = domain.spaces('DG1_equispaced')
space_B_string = 'theta'
elif degree == 1:
VB = domain.spaces('theta')
space_B_string = 'theta'
else:
raise NotImplementedError
else:
Expand Down Expand Up @@ -178,6 +176,9 @@ def setup_limiters(dirname, space_A, space_B):
# DG Upwind transport for both tracers:
transport_method = [DGUpwind(eqn, 'tracerA'), DGUpwind(eqn, 'tracerB')]

# Need to give SUPG options to the above, if using supg ...
# Need to test SUPG here!

# Build time stepper
stepper = PrescribedTransport(eqn, transport_schemes, io, transport_method)

Expand Down Expand Up @@ -279,7 +280,6 @@ def setup_limiters(dirname, space_A, space_B):

@pytest.mark.parametrize('space_A', ['Vtheta_degree_0', 'Vtheta_degree_1', 'DG0',
'DG1', 'DG1_equispaced'])
#@pytest.mark.parametrize('space_A', ['DG0'])#, 'DG1_equispaced'])
# It only makes sense to use the same degree for tracer B
@pytest.mark.parametrize('space_B', ['Vtheta', 'DG'])

Expand Down

0 comments on commit 5c13db1

Please sign in to comment.