Skip to content

Commit

Permalink
Merge pull request #444 from firedrakeproject/adaptive_subcycling
Browse files Browse the repository at this point in the history
Adaptive subcycling
  • Loading branch information
jshipton authored Oct 20, 2023
2 parents 7c8d413 + dfebaf7 commit cfd17a9
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 36 deletions.
4 changes: 2 additions & 2 deletions examples/compressible/dcmip_3_1_meanflow_quads.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@

# Transport schemes
transported_fields = [TrapeziumRule(domain, "u"),
SSPRK3(domain, "rho", subcycles=2),
SSPRK3(domain, "theta", options=SUPGOptions(), subcycles=2)]
SSPRK3(domain, "rho", fixed_subcycles=2),
SSPRK3(domain, "theta", options=SUPGOptions(), fixed_subcycles=2)]
transport_methods = [DGUpwind(eqns, field) for field in ["u", "rho", "theta"]]

# Linear solver
Expand Down
2 changes: 1 addition & 1 deletion examples/shallow_water/williamson_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@

# Transport schemes
transported_fields = [TrapeziumRule(domain, "u"),
SSPRK3(domain, "D", subcycles=2)]
SSPRK3(domain, "D", fixed_subcycles=2)]
transport_methods = [DGUpwind(eqns, "u"), DGUpwind(eqns, "D")]

# Time stepper
Expand Down
12 changes: 11 additions & 1 deletion gusto/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from gusto.diagnostics import Diagnostics, CourantNumber
from gusto.meshes import get_flat_latlon_mesh
from firedrake import (Function, functionspaceimpl, File,
from firedrake import (Function, functionspaceimpl, File, Constant,
DumbCheckpoint, FILE_CREATE, FILE_READ, CheckpointFile)
from pyop2.mpi import MPI
import numpy as np
Expand Down Expand Up @@ -233,6 +233,9 @@ def __init__(self, domain, output, diagnostics=None, diagnostic_fields=None):
self.dumpfile = None
self.to_pick_up = None

if output.log_courant:
self.courant_max = Constant(0.0)

def log_parameters(self, equation):
"""
Logs an equation's physical parameters that take non-default values.
Expand Down Expand Up @@ -309,6 +312,13 @@ def log_courant(self, state_fields, name='u', component="whole", message=None):
else:
logger.info(f'Max Courant {message}: {courant_max:.2e}')

if component == 'whole':
# TODO: this will update the Courant number more than we need to
# and possibly with the wrong Courant number
# we could make self.courant_max a dict with keys depending on
# the field to take the Courant number of
self.courant_max.assign(courant_max)

def setup_diagnostics(self, state_fields):
"""
Prepares the I/O for computing the model's global diagnostics and
Expand Down
90 changes: 65 additions & 25 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from abc import ABCMeta, abstractmethod, abstractproperty
from firedrake import (Function, TestFunction, NonlinearVariationalProblem,
NonlinearVariationalSolver, DirichletBC)
NonlinearVariationalSolver, DirichletBC, Constant)
from firedrake.formmanipulation import split_form
from firedrake.utils import cached_property

Expand All @@ -18,6 +18,7 @@
from gusto.labels import time_derivative, prognostic, physics_label
from gusto.logging import logger, DEBUG, logging_ksp_monitor_true_residual
from gusto.wrappers import *
import math
import numpy as np


Expand Down Expand Up @@ -71,9 +72,13 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
self.field_name = field_name
self.equation = None

self.dt = domain.dt
self.dt = Constant(0.0)
self.dt.assign(domain.dt)
self.original_dt = Constant(0.0)
self.original_dt.assign(self.dt)
self.options = options
self.limiter = limiter
self.courant_max = None

if options is not None:
self.wrapper_name = options.name
Expand Down Expand Up @@ -241,16 +246,24 @@ def apply(self, x_out, x_in):
class ExplicitTimeDiscretisation(TimeDiscretisation):
"""Base class for explicit time discretisations."""

def __init__(self, domain, field_name=None, subcycles=None,
solver_parameters=None, limiter=None, options=None):
def __init__(self, domain, field_name=None, fixed_subcycles=None,
subcycle_by_courant=None, solver_parameters=None, limiter=None,
options=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
subcycles (int, optional): the number of sub-steps to perform.
Defaults to None.
fixed_subcycles (int, optional): the fixed number of sub-steps to
perform. This option cannot be specified with the
`subcycle_by_courant` argument. Defaults to None.
subcycle_by_courant (float, optional): specifying this option will
make the scheme perform adaptive sub-cycling based on the
Courant number. The specified argument is the maximum Courant
for one sub-cycle. Defaults to None, in which case adaptive
sub-cycling is not used. This option cannot be specified with the
`fixed_subcycles` argument.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
limiter (:class:`Limiter` object, optional): a limiter to apply to
Expand All @@ -264,7 +277,11 @@ def __init__(self, domain, field_name=None, subcycles=None,
solver_parameters=solver_parameters,
limiter=limiter, options=options)

self.subcycles = subcycles
if fixed_subcycles is not None and subcycle_by_courant is not None:
raise ValueError('Cannot specify both subcycle and subcycle_by '
+ 'arguments to a time discretisation')
self.fixed_subcycles = fixed_subcycles
self.subcycle_by_courant = subcycle_by_courant

def setup(self, equation, apply_bcs=True, *active_labels):
"""
Expand All @@ -279,11 +296,11 @@ def setup(self, equation, apply_bcs=True, *active_labels):
"""
super().setup(equation, apply_bcs, *active_labels)

# if user has specified a number of subcycles, then save this
# if user has specified a number of fixed subcycles, then save this
# and rescale dt accordingly; else perform just one cycle using dt
if self.subcycles is not None:
self.dt = self.dt/self.subcycles
self.ncycles = self.subcycles
if self.fixed_subcycles is not None:
self.dt.assign(self.dt/self.fixed_subcycles)
self.ncycles = self.fixed_subcycles
else:
self.dt = self.dt
self.ncycles = 1
Expand Down Expand Up @@ -331,6 +348,11 @@ def apply(self, x_out, x_in):
x_out (:class:`Function`): the output field to be computed.
x_in (:class:`Function`): the input field.
"""
# If doing adaptive subcycles, update dt and ncycles here
if self.subcycle_by_courant is not None:
self.ncycles = math.ceil(float(self.courant_max)/self.subcycle_by_courant)
self.dt.assign(self.original_dt/self.ncycles)

self.x0.assign(x_in)
for i in range(self.ncycles):
self.apply_cycle(self.x1, self.x0)
Expand Down Expand Up @@ -375,9 +397,13 @@ class ExplicitMultistage(ExplicitTimeDiscretisation):
"""

def __init__(self, domain, field_name=None, subcycles=None, solver_parameters=None,
def __init__(self, domain, field_name=None, fixed_subcycles=None,
subcycle_by_courant=None, solver_parameters=None,
limiter=None, options=None, butcher_matrix=None):
super().__init__(domain, field_name=field_name, subcycles=subcycles,

super().__init__(domain, field_name=field_name,
fixed_subcycles=fixed_subcycles,
subcycle_by_courant=subcycle_by_courant,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
if butcher_matrix is not None:
Expand Down Expand Up @@ -466,11 +492,15 @@ class ForwardEuler(ExplicitMultistage):
k0 = F[y^n]
y^(n+1) = y^n + dt*k0
"""
def __init__(self, domain, field_name=None, subcycles=None, solver_parameters=None,
def __init__(self, domain, field_name=None, fixed_subcycles=None,
subcycle_by_courant=None, solver_parameters=None,
limiter=None, options=None, butcher_matrix=None):
super().__init__(domain, field_name=field_name, subcycles=subcycles,
super().__init__(domain, field_name=field_name,
fixed_subcycles=fixed_subcycles,
subcycle_by_courant=subcycle_by_courant,
solver_parameters=solver_parameters,
limiter=limiter, options=options, butcher_matrix=butcher_matrix)
limiter=limiter, options=options,
butcher_matrix=butcher_matrix)
self.butcher_matrix = np.array([1.]).reshape(1, 1)
self.nbutcher = int(np.shape(self.butcher_matrix)[0])

Expand All @@ -485,11 +515,15 @@ class SSPRK3(ExplicitMultistage):
k2 = F[y^n + (1/4)*dt*(k0+k1)]
y^(n+1) = y^n + (1/6)*dt*(k0 + k1 + 4*k2)
"""
def __init__(self, domain, field_name=None, subcycles=None, solver_parameters=None,
def __init__(self, domain, field_name=None, fixed_subcycles=None,
subcycle_by_courant=None, solver_parameters=None,
limiter=None, options=None, butcher_matrix=None):
super().__init__(domain, field_name=field_name, subcycles=subcycles,
super().__init__(domain, field_name=field_name,
fixed_subcycles=fixed_subcycles,
subcycle_by_courant=subcycle_by_courant,
solver_parameters=solver_parameters,
limiter=limiter, options=options, butcher_matrix=butcher_matrix)
limiter=limiter, options=options,
butcher_matrix=butcher_matrix)
self.butcher_matrix = np.array([[1., 0., 0.], [1./4., 1./4., 0.], [1./6., 1./6., 2./3.]])
self.nbutcher = int(np.shape(self.butcher_matrix)[0])

Expand All @@ -509,11 +543,15 @@ class RK4(ExplicitMultistage):
where superscripts indicate the time-level.
"""
def __init__(self, domain, field_name=None, subcycles=None, solver_parameters=None,
def __init__(self, domain, field_name=None, fixed_subcycles=None,
subcycle_by_courant=None, solver_parameters=None,
limiter=None, options=None, butcher_matrix=None):
super().__init__(domain, field_name=field_name, subcycles=subcycles,
super().__init__(domain, field_name=field_name,
fixed_subcycles=fixed_subcycles,
subcycle_by_courant=subcycle_by_courant,
solver_parameters=solver_parameters,
limiter=limiter, options=options, butcher_matrix=butcher_matrix)
limiter=limiter, options=options,
butcher_matrix=butcher_matrix)
self.butcher_matrix = np.array([[0.5, 0., 0., 0.], [0., 0.5, 0., 0.], [0., 0., 1., 0.], [1./6., 1./3., 1./3., 1./6.]])
self.nbutcher = int(np.shape(self.butcher_matrix)[0])

Expand All @@ -531,9 +569,11 @@ class Heun(ExplicitMultistage):
where superscripts indicate the time-level and subscripts indicate the stage
number.
"""
def __init__(self, domain, field_name=None, subcycles=None, solver_parameters=None,
def __init__(self, domain, field_name=None, fixed_subcycles=None,
subcycle_by_courant=None, solver_parameters=None,
limiter=None, options=None, butcher_matrix=None):
super().__init__(domain, field_name,
super().__init__(domain, field_name, fixed_subcycles=fixed_subcycles,
subcycle_by_courant=subcycle_by_courant,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
self.butcher_matrix = np.array([[1., 0.], [0.5, 0.5]])
Expand All @@ -555,7 +595,7 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
mesh and the compatible function spaces.
field_name (str, optional): name of the field to be evolved.
Defaults to None.
subcycles (int, optional): the number of sub-steps to perform.
fixed_subcycles (int, optional): the number of sub-steps to perform.
Defaults to None.
solver_parameters (dict, optional): dictionary of parameters to
pass to the underlying solver. Defaults to None.
Expand Down
3 changes: 3 additions & 0 deletions gusto/timeloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def setup_scheme(self):
self.setup_equation(self.equation)
self.scheme.setup(self.equation)
self.setup_transporting_velocity(self.scheme)
self.scheme.courant_max = self.io.courant_max

def timestep(self):
"""
Expand Down Expand Up @@ -364,6 +365,7 @@ def setup_scheme(self):
apply_bcs = True
self.scheme.setup(self.equation, apply_bcs, dynamics)
self.setup_transporting_velocity(self.scheme)
self.scheme.courant_max = self.io.courant_max

def timestep(self):

Expand Down Expand Up @@ -567,6 +569,7 @@ def setup_scheme(self):
for _, scheme in self.active_transport:
scheme.setup(self.equation, apply_bcs, transport)
self.setup_transporting_velocity(scheme)
scheme.courant_max = self.io.courant_max

apply_bcs = True
for _, scheme in self.diffusion_schemes:
Expand Down
14 changes: 7 additions & 7 deletions integration-tests/transport/test_subcycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ def run(timestepper, tmax, f_end):
return norm(timestepper.fields("f") - f_end) / norm(f_end)


@pytest.mark.parametrize("equation_form", ["advective", "continuity"])
def test_subcyling(tmpdir, equation_form, tracer_setup):
@pytest.mark.parametrize("subcycling", ["fixed", "adaptive"])
def test_subcyling(tmpdir, subcycling, tracer_setup):
geometry = "slice"
setup = tracer_setup(tmpdir, geometry)
domain = setup.domain
V = domain.spaces("DG")
if equation_form == "advective":
eqn = AdvectionEquation(domain, V, "f")
else:
eqn = ContinuityEquation(domain, V, "f")
eqn = AdvectionEquation(domain, V, "f")

transport_scheme = SSPRK3(domain, subcycles=2)
if subcycling == "fixed":
transport_scheme = SSPRK3(domain, fixed_subcycles=2)
elif subcycling == "adaptive":
transport_scheme = SSPRK3(domain, subcycle_by_courant=0.25)
transport_method = DGUpwind(eqn, "f")

timestepper = PrescribedTransport(eqn, transport_scheme, setup.io, transport_method)
Expand Down

0 comments on commit cfd17a9

Please sign in to comment.