Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IMEX Multistage #453

Merged
merged 26 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1a56e1b
add unit test
alexbrown1995 Sep 7, 2023
fae09a9
Needs tidying, but IMEX_Euler3 is a seemingly working version of IMEX…
alexbrown1995 Sep 14, 2023
97ff98f
Tidy and adding new schemes
alexbrown1995 Sep 18, 2023
1a1fdc0
Merge branch 'main' into IMEX_multistage
alexbrown1995 Sep 18, 2023
aa2d599
Add back in courant logging, tidy time disc
alexbrown1995 Sep 19, 2023
b1b4cf5
change to split advective and div term in SWE
alexbrown1995 Sep 29, 2023
00f9966
SSP3 correction
alexbrown1995 Oct 5, 2023
a6b28a1
Tidy of code, lint corrections
alexbrown1995 Oct 19, 2023
9fef4b2
Merge branch 'main' into IMEX_multistage
alexbrown1995 Oct 19, 2023
c8207bc
revert explicitmultistage changes
alexbrown1995 Oct 19, 2023
78f5557
revert test changes
alexbrown1995 Oct 19, 2023
177c308
adressed most review comments
alexbrown1995 Oct 25, 2023
3a34def
Warning for explicit implicit splitting added
alexbrown1995 Oct 25, 2023
564cafb
Comments added and equations refactor moved into time_discretisation.…
alexbrown1995 Oct 26, 2023
31ddf80
Splitting on conservative term added. Need to make it work in case of…
alexbrown1995 Oct 27, 2023
0c6fba4
Merge to main
alexbrown1995 Oct 27, 2023
85e57a7
Lint fix
alexbrown1995 Oct 27, 2023
10ed951
Labelling of implicit and explicit terms moved outside of time_disc, …
alexbrown1995 Nov 1, 2023
e866193
change error message
alexbrown1995 Nov 1, 2023
cccc433
Merge to main
alexbrown1995 Nov 1, 2023
5878f23
lint fix + firedrake fml move
alexbrown1995 Nov 1, 2023
958e002
Error switched to Runtime
alexbrown1995 Nov 1, 2023
6daae50
linearisation terms added
alexbrown1995 Nov 1, 2023
9191901
bug fix to test func / q
alexbrown1995 Nov 1, 2023
d29d137
Update gusto/common_forms.py
atb1995 Nov 7, 2023
7a14129
Change in description
alexbrown1995 Nov 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions gusto/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,
space_names=None, linearisation_map='default',
u_transport_option='vector_invariant_form',
no_normal_flow_bc_ids=None, active_tracers=None,
thermal=False, conservative_depth=True):
thermal=False):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand Down Expand Up @@ -676,10 +676,7 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,
raise ValueError("Invalid u_transport_option: %s" % u_transport_option)

# Depth transport term
if (conservative_depth):
D_adv = prognostic(continuity_form(phi, D, u), 'D')
else:
D_adv = prognostic(advection_form(phi, D, u), 'D')
D_adv = prognostic(continuity_form(phi, D, u), 'D')

# Transport term needs special linearisation
if self.linearisation_map(D_adv.terms[0]):
Expand Down Expand Up @@ -711,11 +708,6 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,

residual = (mass_form + adv_form + pressure_gradient_form)

# Add divergence term
if (not conservative_depth):
geo_grad_form = subject(prognostic(phi*D*div(u)*dx), self.X)
residual += geo_grad_form

# -------------------------------------------------------------------- #
# Extra Terms (Coriolis, Topography and Thermal)
# -------------------------------------------------------------------- #
Expand Down
67 changes: 53 additions & 14 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@

from abc import ABCMeta, abstractmethod, abstractproperty
from firedrake import (Function, TestFunction, NonlinearVariationalProblem,
NonlinearVariationalSolver, DirichletBC, Constant)
NonlinearVariationalSolver, DirichletBC, Constant, split,
div, dx)
from firedrake.formmanipulation import split_form
from firedrake.utils import cached_property

from gusto.configuration import EmbeddedDGOptions, RecoveryOptions
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, TransportEquationType
from gusto.fml import (
replace_subject, replace_test_function, Term, all_terms, drop
replace_subject, replace_test_function, Term, all_terms, drop, subject
)
from gusto.labels import time_derivative, prognostic, physics_label, transport, implicit, explicit
from gusto.labels import (time_derivative, prognostic, physics_label,
transport, implicit, explicit, transporting_velocity)
from gusto.common_forms import advection_form
from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual
from gusto.wrappers import *
import numpy as np
Expand Down Expand Up @@ -252,15 +255,15 @@ class IMEXMultistage(TimeDiscretisation):

For each i = 1, s in an s stage method
we compute the intermediate solutions: \n
y_i = y^n + dt*(a_i1*F(y_1) + a_i2*F(y_2)+ ... + a_ii*F(y_i)) \n
y_i = y^n + dt*(a_i1*F(y_1) + a_i2*F(y_2)+ ... + a_ii*F(y_i)) \n
+ dt*(d_i1*S(y_1) + d_i2*S(y_2)+ ... + d_{i,i-1}*S(y_{i-1}))

At the last stage, compute the new solution by: \n
y^{n+1} = y^n + dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s)) \n
+ dt*(e_1*S(y_1) + e_2*S(y_2) + .... + e_s*S(y_s)) \n

"""
# ---------------------------------------------------------------------------
# --------------------------------------------------------------------------
# Butcher tableaus for a s-th order
# diagonally implicit scheme (left) and explicit scheme (right):
# c_0 | a_00 0 . 0 f_0 | 0 0 . 0
Expand All @@ -279,34 +282,31 @@ class IMEXMultistage(TimeDiscretisation):
# [a_20 a_21 a_22 . 0 ] [d_20 d_21 0 . 0 ]
# [ . . . . . ] [ . . . . . ]
# [ b_0 b_1 . b_s] [ e_0 e_1 . . e_s]
# ---------------------------------------------------------------------------
#
# --------------------------------------------------------------------------

def __init__(self, domain, butcher_imp, butcher_exp, field_name=None,
solver_parameters=None, limiter=None, options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
butcher_imp (numpy array): A matrix containing the coefficients of
butcher_imp (:class:`numpy.ndarray`): A matrix containing the coefficients of
a butcher tableau defining a given implicit Runge Kutta time discretisation.
butcher_exp (numpy array): A matrix containing the coefficients of
butcher_exp (:class:`numpy.ndarray`): A matrix containing the coefficients of
a butcher tableau defining a given explicit Runge Kutta time discretisation.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
subcycles (int, optional): the number of sub-steps to perform.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
the evolving field to enforce monotonicity. Defaults to None.
options (:class:`AdvectionOptions`, optional): an object containing
options to either be passed to the spatial discretisation, or
to control the "wrapper" methods, such as Embedded DG or a
recovery method. Defaults to None.
"""
super().__init__(domain, field_name=field_name,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
options=options)
self.butcher_imp = butcher_imp
self.butcher_exp = butcher_exp
self.nStages = int(np.shape(self.butcher_imp)[1])
Expand All @@ -323,13 +323,37 @@ def setup(self, equation, apply_bcs=True, *active_labels):

super().setup(equation, apply_bcs, *active_labels)

# Get continuity form transport term
for t in self.residual:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed offline: it might be best to move this code to break up the continuity form out of time discretistations to a new routine (we suggested that it would be in common_forms.py). Then a TimeDiscretisation can be agnostic to the terms that it works on, and this code than then be reused for other things.

if(t.get(transport) == TransportEquationType.conservative):
# Split continuity form term
test = t.form.arguments()[0]
subj = t.get(subject)
prognostic_field_name = t.get(prognostic)
idx = self.equation.field_names.index(prognostic_field_name)
transported_field = split(subj)[idx]
# u_idx = self.equation.field_names.index('u')
# uadv = split(self.equation.X)[u_idx]
# breakpoint()
uadv = t.get(transporting_velocity)
breakpoint()
new_transport_term = prognostic(subject(advection_form(test, transported_field, uadv) + test*transported_field*div(uadv)*dx, subj, prognostic_field_name))
# Add onto residual and drop old term
self.residual = self.residual.label_map(
lambda t: t.get(transport) == TransportEquationType.conservative,
map_if_true=drop)
self.residual += new_transport_term.form
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should just add the labelled form to the residual, not just the form part, i.e. self.residual += new_transport_term


# Label transport terms as explicit, all other terms as implicit
self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, transport)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should transport be included here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is a map if false, so all non-time derivative and transport terms are implicit

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've discussed this offline: we realised that we can label terms as explicit or implicit in any example file using the equation.label_map(...) routine.

But here we should have a check that throws an error if a term hasn't been labelled as implicit/explicit (or is the time_derivative)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An error checker has been added. I raised a "NotImplementedError" - is this the correct choice..?

map_if_false=lambda t: implicit(t))

self.residual = self.residual.label_map(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to enforce transport terms being explicit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, it could be left up to the user to label terms as explicit or implicit at the script level.. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that in the long term we would like users to be able to specify which terms they would like to make implicit / explicit and this could be via a dictionary that is passed in. But for now this seems like the most sensible option. I suggested to @atb1995 that he should add a warning message using the logger to tell people that transport terms are treated explicitly and all others are treated implicitly.

lambda t: t.has_label(transport),
map_if_true=lambda t: explicit(t))

logger.warning("Default IMEX Multistage treats transport terms explicitly, and all other terms implicitly")

self.xs = [Function(self.fs) for i in range(self.nStages)]

Expand All @@ -344,13 +368,19 @@ def rhs(self):
return super(IMEXMultistage, self).rhs

def res(self, stage):
"""Set up the discretisation's residual for a given stage."""
# Add time derivative terms y_s - y^n for stage s
mass_form = self.residual.label_map(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a docstring here to explain this routine?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do! Will also add for final_res and the various solvers

lambda t: t.has_label(time_derivative),
map_if_false=drop)
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
# dt*(a_s1*F(y_1) + a_s2*F(y_2)+ ... + a_{s,s-1}*F(y_{s-1}))
# and
# dt*(d_s1*S(y_1) + d_s2*S(y_2)+ ... + d_{s,s-1}*S(y_{s-1}))
for i in range(stage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can follow what is going on in this for loop but maybe you could add some comments to explain it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Comments added

r_exp = self.residual.label_map(
lambda t: t.has_label(explicit),
Expand All @@ -368,6 +398,7 @@ def res(self, stage):
map_if_false=lambda t: Constant(self.butcher_imp[stage, i])*self.dt*t)
residual += r_imp
residual += r_exp
# Calculate and add on dt*a_ss*F(y_s)
r_imp = self.residual.label_map(
lambda t: t.has_label(implicit),
map_if_true=replace_subject(self.x_out, old_idx=self.idx),
Expand All @@ -380,12 +411,18 @@ def res(self, stage):

@property
def final_res(self):
"""Set up the discretisation's final residual."""
# Add time derivative terms y^{n+1} - y^n
mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't spot what is different about this compared with the normal res routine!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The final residual loops across all stages, and also loops up to the final stage for explicit terms.

map_if_false=drop)
residual = mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
# Loop through stages up to s-1 and calcualte/sum
# dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s))
# and
# dt*(e_1*S(y_1) + e_2*S(y_2) + .... + e_s*S(y_s))
for i in range(self.nStages):
r_exp = self.residual.label_map(
lambda t: t.has_label(explicit),
Expand All @@ -407,6 +444,7 @@ def final_res(self):

@cached_property
def solvers(self):
"""Set up a list of solvers for each problem at a stage."""
solvers = []
for stage in range(self.nStages):
# setup solver using residual defined in derived class
Expand All @@ -417,6 +455,7 @@ def solvers(self):

@cached_property
def final_solver(self):
"""Set up a solver for the final solve to evaluate time level n+1."""
# setup solver using lhs and rhs defined in derived class
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs)
solver_name = self.field_name+self.__class__.__name__
Expand Down