-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
IMEX Multistage #453
IMEX Multistage #453
Changes from 3 commits
1a56e1b
fae09a9
97ff98f
1a1fdc0
aa2d599
b1b4cf5
00f9966
a6b28a1
9fef4b2
c8207bc
78f5557
177c308
3a34def
564cafb
31ddf80
0c6fba4
85e57a7
10ed951
e866193
cccc433
5878f23
958e002
6daae50
9191901
d29d137
7a14129
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,15 +7,18 @@ | |
|
||
from abc import ABCMeta, abstractmethod, abstractproperty | ||
from firedrake import (Function, TestFunction, NonlinearVariationalProblem, | ||
NonlinearVariationalSolver, DirichletBC, Constant) | ||
NonlinearVariationalSolver, DirichletBC, Constant, split, | ||
div, dx) | ||
from firedrake.formmanipulation import split_form | ||
from firedrake.utils import cached_property | ||
|
||
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions | ||
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, TransportEquationType | ||
from gusto.fml import ( | ||
replace_subject, replace_test_function, Term, all_terms, drop | ||
replace_subject, replace_test_function, Term, all_terms, drop, subject | ||
) | ||
from gusto.labels import time_derivative, prognostic, physics_label, transport, implicit, explicit | ||
from gusto.labels import (time_derivative, prognostic, physics_label, | ||
transport, implicit, explicit, transporting_velocity) | ||
from gusto.common_forms import advection_form | ||
from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual | ||
from gusto.wrappers import * | ||
import numpy as np | ||
|
@@ -252,15 +255,15 @@ class IMEXMultistage(TimeDiscretisation): | |
|
||
For each i = 1, s in an s stage method | ||
we compute the intermediate solutions: \n | ||
y_i = y^n + dt*(a_i1*F(y_1) + a_i2*F(y_2)+ ... + a_ii*F(y_i)) \n | ||
y_i = y^n + dt*(a_i1*F(y_1) + a_i2*F(y_2)+ ... + a_ii*F(y_i)) \n | ||
+ dt*(d_i1*S(y_1) + d_i2*S(y_2)+ ... + d_{i,i-1}*S(y_{i-1})) | ||
|
||
At the last stage, compute the new solution by: \n | ||
y^{n+1} = y^n + dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s)) \n | ||
+ dt*(e_1*S(y_1) + e_2*S(y_2) + .... + e_s*S(y_s)) \n | ||
|
||
""" | ||
# --------------------------------------------------------------------------- | ||
# -------------------------------------------------------------------------- | ||
# Butcher tableaus for a s-th order | ||
# diagonally implicit scheme (left) and explicit scheme (right): | ||
# c_0 | a_00 0 . 0 f_0 | 0 0 . 0 | ||
|
@@ -279,34 +282,31 @@ class IMEXMultistage(TimeDiscretisation): | |
# [a_20 a_21 a_22 . 0 ] [d_20 d_21 0 . 0 ] | ||
# [ . . . . . ] [ . . . . . ] | ||
# [ b_0 b_1 . b_s] [ e_0 e_1 . . e_s] | ||
# --------------------------------------------------------------------------- | ||
# | ||
# -------------------------------------------------------------------------- | ||
|
||
def __init__(self, domain, butcher_imp, butcher_exp, field_name=None, | ||
solver_parameters=None, limiter=None, options=None): | ||
""" | ||
Args: | ||
domain (:class:`Domain`): the model's domain object, containing the | ||
mesh and the compatible function spaces. | ||
butcher_imp (numpy array): A matrix containing the coefficients of | ||
butcher_imp (:class:`numpy.ndarray`): A matrix containing the coefficients of | ||
a butcher tableau defining a given implicit Runge Kutta time discretisation. | ||
butcher_exp (numpy array): A matrix containing the coefficients of | ||
butcher_exp (:class:`numpy.ndarray`): A matrix containing the coefficients of | ||
a butcher tableau defining a given explicit Runge Kutta time discretisation. | ||
field_name (str, optional): name of the field to be evolved. | ||
Defaults to None. | ||
subcycles (int, optional): the number of sub-steps to perform. | ||
Defaults to None. | ||
solver_parameters (dict, optional): dictionary of parameters to | ||
pass to the underlying solver. Defaults to None. | ||
limiter (:class:`Limiter` object, optional): a limiter to apply to | ||
the evolving field to enforce monotonicity. Defaults to None. | ||
options (:class:`AdvectionOptions`, optional): an object containing | ||
options to either be passed to the spatial discretisation, or | ||
to control the "wrapper" methods, such as Embedded DG or a | ||
recovery method. Defaults to None. | ||
""" | ||
super().__init__(domain, field_name=field_name, | ||
solver_parameters=solver_parameters, | ||
limiter=limiter, options=options) | ||
options=options) | ||
self.butcher_imp = butcher_imp | ||
self.butcher_exp = butcher_exp | ||
self.nStages = int(np.shape(self.butcher_imp)[1]) | ||
|
@@ -323,13 +323,37 @@ def setup(self, equation, apply_bcs=True, *active_labels): | |
|
||
super().setup(equation, apply_bcs, *active_labels) | ||
|
||
# Get continuity form transport term | ||
for t in self.residual: | ||
if(t.get(transport) == TransportEquationType.conservative): | ||
# Split continuity form term | ||
test = t.form.arguments()[0] | ||
subj = t.get(subject) | ||
prognostic_field_name = t.get(prognostic) | ||
idx = self.equation.field_names.index(prognostic_field_name) | ||
transported_field = split(subj)[idx] | ||
# u_idx = self.equation.field_names.index('u') | ||
# uadv = split(self.equation.X)[u_idx] | ||
# breakpoint() | ||
uadv = t.get(transporting_velocity) | ||
breakpoint() | ||
new_transport_term = prognostic(subject(advection_form(test, transported_field, uadv) + test*transported_field*div(uadv)*dx, subj, prognostic_field_name)) | ||
# Add onto residual and drop old term | ||
self.residual = self.residual.label_map( | ||
lambda t: t.get(transport) == TransportEquationType.conservative, | ||
map_if_true=drop) | ||
self.residual += new_transport_term.form | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you should just add the labelled form to the residual, not just the form part, i.e. |
||
|
||
# Label transport terms as explicit, all other terms as implicit | ||
self.residual = self.residual.label_map( | ||
lambda t: any(t.has_label(time_derivative, transport)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should transport be included here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it is a map if false, so all non-time derivative and transport terms are implicit There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We've discussed this offline: we realised that we can label terms as But here we should have a check that throws an error if a term hasn't been labelled as implicit/explicit (or is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An error checker has been added. I raised a "NotImplementedError" - is this the correct choice..? |
||
map_if_false=lambda t: implicit(t)) | ||
|
||
self.residual = self.residual.label_map( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to enforce transport terms being explicit? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a good point, it could be left up to the user to label terms as explicit or implicit at the script level.. What do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that in the long term we would like users to be able to specify which terms they would like to make implicit / explicit and this could be via a dictionary that is passed in. But for now this seems like the most sensible option. I suggested to @atb1995 that he should add a warning message using the logger to tell people that transport terms are treated explicitly and all others are treated implicitly. |
||
lambda t: t.has_label(transport), | ||
map_if_true=lambda t: explicit(t)) | ||
|
||
logger.warning("Default IMEX Multistage treats transport terms explicitly, and all other terms implicitly") | ||
|
||
self.xs = [Function(self.fs) for i in range(self.nStages)] | ||
|
||
|
@@ -344,13 +368,19 @@ def rhs(self): | |
return super(IMEXMultistage, self).rhs | ||
|
||
def res(self, stage): | ||
"""Set up the discretisation's residual for a given stage.""" | ||
# Add time derivative terms y_s - y^n for stage s | ||
mass_form = self.residual.label_map( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a docstring here to explain this routine? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will do! Will also add for final_res and the various solvers |
||
lambda t: t.has_label(time_derivative), | ||
map_if_false=drop) | ||
residual = mass_form.label_map(all_terms, | ||
map_if_true=replace_subject(self.x_out, old_idx=self.idx)) | ||
residual -= mass_form.label_map(all_terms, | ||
map_if_true=replace_subject(self.x1, old_idx=self.idx)) | ||
# Loop through stages up to s-1 and calcualte/sum | ||
# dt*(a_s1*F(y_1) + a_s2*F(y_2)+ ... + a_{s,s-1}*F(y_{s-1})) | ||
# and | ||
# dt*(d_s1*S(y_1) + d_s2*S(y_2)+ ... + d_{s,s-1}*S(y_{s-1})) | ||
for i in range(stage): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can follow what is going on in this for loop but maybe you could add some comments to explain it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point! Comments added |
||
r_exp = self.residual.label_map( | ||
lambda t: t.has_label(explicit), | ||
|
@@ -368,6 +398,7 @@ def res(self, stage): | |
map_if_false=lambda t: Constant(self.butcher_imp[stage, i])*self.dt*t) | ||
residual += r_imp | ||
residual += r_exp | ||
# Calculate and add on dt*a_ss*F(y_s) | ||
r_imp = self.residual.label_map( | ||
lambda t: t.has_label(implicit), | ||
map_if_true=replace_subject(self.x_out, old_idx=self.idx), | ||
|
@@ -380,12 +411,18 @@ def res(self, stage): | |
|
||
@property | ||
def final_res(self): | ||
"""Set up the discretisation's final residual.""" | ||
# Add time derivative terms y^{n+1} - y^n | ||
mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can't spot what is different about this compared with the normal There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The final residual loops across all stages, and also loops up to the final stage for explicit terms. |
||
map_if_false=drop) | ||
residual = mass_form.label_map(all_terms, | ||
map_if_true=replace_subject(self.x_out, old_idx=self.idx)) | ||
residual -= mass_form.label_map(all_terms, | ||
map_if_true=replace_subject(self.x1, old_idx=self.idx)) | ||
# Loop through stages up to s-1 and calcualte/sum | ||
# dt*(b_1*F(y_1) + b_2*F(y_2) + .... + b_s*F(y_s)) | ||
# and | ||
# dt*(e_1*S(y_1) + e_2*S(y_2) + .... + e_s*S(y_s)) | ||
for i in range(self.nStages): | ||
r_exp = self.residual.label_map( | ||
lambda t: t.has_label(explicit), | ||
|
@@ -407,6 +444,7 @@ def final_res(self): | |
|
||
@cached_property | ||
def solvers(self): | ||
"""Set up a list of solvers for each problem at a stage.""" | ||
solvers = [] | ||
for stage in range(self.nStages): | ||
# setup solver using residual defined in derived class | ||
|
@@ -417,6 +455,7 @@ def solvers(self): | |
|
||
@cached_property | ||
def final_solver(self): | ||
"""Set up a solver for the final solve to evaluate time level n+1.""" | ||
# setup solver using lhs and rhs defined in derived class | ||
problem = NonlinearVariationalProblem(self.final_res, self.x_out, bcs=self.bcs) | ||
solver_name = self.field_name+self.__class__.__name__ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we discussed offline: it might be best to move this code to break up the continuity form out of time discretistations to a new routine (we suggested that it would be in
common_forms.py
). Then aTimeDiscretisation
can be agnostic to the terms that it works on, and this code than then be reused for other things.