Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IMEX Multistage #453

Merged
merged 26 commits into from
Nov 8, 2023
Merged

IMEX Multistage #453

merged 26 commits into from
Nov 8, 2023

Conversation

atb1995
Copy link
Collaborator

@atb1995 atb1995 commented Oct 19, 2023

This code change adds IMEX Multistage class which allows for general split diagonally implicit and explicit Runge-Kutta schemes by passing in a pair of Butcher Tableaus. It also adds IMEX Euler, ARK2, Trap2, SSP3 and ARS3. The butcher tableaus are of the form:

[a_00 0 . 0 ]
[a_10 a_11 . 0 ]
[ . . . . ]
[ b_0 b_1 . b_s]

and

[0 0 . 0 ]
[d_10 0 . 0 ]
[ . . . . ]
[ e_0 e_1 . e_s]

I ran convergence tests for Williamson 5 & 6 (see attached). This pull request is blocked by #440.

@atb1995
Copy link
Collaborator Author

atb1995 commented Oct 19, 2023

w6_imex_rk

w5_imex_rk

Copy link
Contributor

@tommbendall tommbendall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks very exciting, thanks Alex. My suggestions are mainly about adding comments

@pytest.mark.parametrize("scheme", ["ssprk", "implicit_midpoint",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth", "Leapfrog", "AdamsMoulton"])
@pytest.mark.parametrize("scheme", ["ssprk", "TrapeziumRule",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add some of the IMEX schemes here to test them?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue here is the test is a prescribed transport test, hence no option for implicit and explicit terms.. Could change or add another test?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed offline: adding a new test_imex.py which could test these on something simple like a wave equation with an analytic solution

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added test_IMEX.py. It uses a continuity form and splits the term to treat q*div(u) implicitly and u.grad(q) explicitly.

@@ -90,6 +90,8 @@ def __call__(self, target, value=None):
# ---------------------------------------------------------------------------- #

time_derivative = Label("time_derivative")
implicit = Label("implicit")
Copy link
Contributor

@tommbendall tommbendall Oct 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to see these labels added!

@@ -582,7 +582,7 @@ def __init__(self, domain, parameters, fexpr=None, bexpr=None,
space_names=None, linearisation_map='default',
u_transport_option='vector_invariant_form',
no_normal_flow_bc_ids=None, active_tracers=None,
thermal=False):
thermal=False, conservative_depth=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, I don't think we want to add this argument to the equation

a butcher tableau defining a given explicit Runge Kutta time discretisation.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
subcycles (int, optional): the number of sub-steps to perform.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't seem to have included this argument

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed!

mesh and the compatible function spaces.
butcher_imp (numpy array): A matrix containing the coefficients of
a butcher tableau defining a given implicit Runge Kutta time discretisation.
butcher_exp (numpy array): A matrix containing the coefficients of
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type should be :class:`numpy.ndarray`

super().setup(equation, apply_bcs, *active_labels)

self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, transport)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should transport be included here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is a map if false, so all non-time derivative and transport terms are implicit

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've discussed this offline: we realised that we can label terms as explicit or implicit in any example file using the equation.label_map(...) routine.

But here we should have a check that throws an error if a term hasn't been labelled as implicit/explicit (or is the time_derivative)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An error checker has been added. I raised a "NotImplementedError" - is this the correct choice..?

lambda t: any(t.has_label(time_derivative, transport)),
map_if_false=lambda t: implicit(t))

self.residual = self.residual.label_map(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to enforce transport terms being explicit?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point, it could be left up to the user to label terms as explicit or implicit at the script level.. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that in the long term we would like users to be able to specify which terms they would like to make implicit / explicit and this could be via a dictionary that is passed in. But for now this seems like the most sensible option. I suggested to @atb1995 that he should add a warning message using the logger to tell people that transport terms are treated explicitly and all others are treated implicitly.

map_if_true=replace_subject(self.x_out, old_idx=self.idx))
residual -= mass_form.label_map(all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
for i in range(stage):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can follow what is going on in this for loop but maybe you could add some comments to explain it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Comments added

return super(IMEXMultistage, self).rhs

def res(self, stage):
mass_form = self.residual.label_map(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a docstring here to explain this routine?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will do! Will also add for final_res and the various solvers


@property
def final_res(self):
mass_form = self.residual.label_map(lambda t: t.has_label(time_derivative),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't spot what is different about this compared with the normal res routine!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The final residual loops across all stages, and also loops up to the final stage for explicit terms.

self.residual = self.residual.label_map(
lambda t: t.get(transport) == TransportEquationType.conservative,
map_if_true=drop)
self.residual += new_transport_term.form
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should just add the labelled form to the residual, not just the form part, i.e. self.residual += new_transport_term

Copy link
Contributor

@tommbendall tommbendall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to be asking for more changes!

super().setup(equation, apply_bcs, *active_labels)

self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, transport)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've discussed this offline: we realised that we can label terms as explicit or implicit in any example file using the equation.label_map(...) routine.

But here we should have a check that throws an error if a term hasn't been labelled as implicit/explicit (or is the time_derivative)

@@ -323,13 +323,37 @@ def setup(self, equation, apply_bcs=True, *active_labels):

super().setup(equation, apply_bcs, *active_labels)

# Get continuity form transport term
for t in self.residual:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed offline: it might be best to move this code to break up the continuity form out of time discretistations to a new routine (we suggested that it would be in common_forms.py). Then a TimeDiscretisation can be agnostic to the terms that it works on, and this code than then be reused for other things.

@pytest.mark.parametrize("scheme", ["ssprk", "implicit_midpoint",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth", "Leapfrog", "AdamsMoulton"])
@pytest.mark.parametrize("scheme", ["ssprk", "TrapeziumRule",
"RK4", "Heun", "BDF2", "TR_BDF2", "AdamsBashforth",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed offline: adding a new test_imex.py which could test these on something simple like a wave equation with an analytic solution

Copy link
Contributor

@tommbendall tommbendall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test looks really good, just some questions about what the split_continuity_form routine does when we don't have a mixed function space

if hasattr(equation, "field_names"):
idx = equation.field_names.index(prognostic_field_name)
else:
idx = equation.field_name.index(prognostic_field_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this won't give what we expect! Is field_name a string and not a list?

In the case that we don't have field_names, we don't want an index and don't want to split anything up (because we won't have a mixed function space)

Copy link
Contributor

@tommbendall tommbendall left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two tiny changes and then I think that this is ready!


def split_continuity_form(equation):
u"""
Loops through terms in a given equation, and splits continuity terms into
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you mention that this will split up all conservative terms in the equation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Thanks Tom

gusto/common_forms.py Outdated Show resolved Hide resolved
@tommbendall tommbendall merged commit c209deb into main Nov 8, 2023
4 checks passed
@tommbendall tommbendall deleted the IMEX_multistage branch November 8, 2023 09:12
@jshipton
Copy link
Contributor

jshipton commented Nov 8, 2023 via email

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants