Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thermal sw linear solver #466

Merged
merged 10 commits into from
Nov 30, 2023
22 changes: 19 additions & 3 deletions examples/shallow_water/thermal_williamson2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Test case parameters
# ----------------------------------------------------------------- #

dt = 100
dt = 4000

if '--running-tests' in sys.argv:
tmax = dt
Expand Down Expand Up @@ -54,14 +54,25 @@
ShallowWaterPotentialEnergy(params),
ShallowWaterPotentialEnstrophy(),
SteadyStateError('u'), SteadyStateError('D'),
MeridionalComponent('u'), ZonalComponent('u')]
SteadyStateError('b'), MeridionalComponent('u'),
ZonalComponent('u')]
io = IO(domain, output, diagnostic_fields=diagnostic_fields)

# Transport schemes
transported_fields = [TrapeziumRule(domain, "u"),
SSPRK3(domain, "D", fixed_subcycles=2),
SSPRK3(domain, "b", fixed_subcycles=2)]
transport_methods = [DGUpwind(eqns, "u"),
DGUpwind(eqns, "D"),
DGUpwind(eqns, "b")]

# Linear solver
linear_solver = ThermalSWSolver(eqns)

# Time stepper
stepper = Timestepper(eqns, RK4(domain), io, spatial_methods=transport_methods)
stepper = SemiImplicitQuasiNewton(eqns, io, transported_fields,
transport_methods,
linear_solver=linear_solver)

# ----------------------------------------------------------------- #
# Initial conditions
Expand Down Expand Up @@ -92,6 +103,11 @@
D0.interpolate(Dexpr)
b0.interpolate(bexpr)

# Set reference profiles
Dbar = Function(D0.function_space()).assign(H)
bbar = Function(b0.function_space()).interpolate(bexpr)
stepper.set_reference_profiles([('D', Dbar), ('b', bbar)])

# ----------------------------------------------------------------- #
# Run
# ----------------------------------------------------------------- #
Expand Down
137 changes: 136 additions & 1 deletion gusto/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from abc import ABCMeta, abstractmethod, abstractproperty


__all__ = ["IncompressibleSolver", "LinearTimesteppingSolver", "CompressibleSolver"]
__all__ = ["IncompressibleSolver", "LinearTimesteppingSolver", "CompressibleSolver", "ThermalSWSolver"]


class TimesteppingSolver(object, metaclass=ABCMeta):
Expand Down Expand Up @@ -550,6 +550,141 @@ def solve(self, xrhs, dy):
b.assign(self.b)


class ThermalSWSolver(TimesteppingSolver):
"""
Linear solver object for the thermal shallow water equations.

This solves a linear problem for the thermal shallow water equations with
prognostic variables u (velocity), D (depth) and b (buoyancy). It follows
the following strategy:

(1) Eliminate b
(2) Solve the resulting system for (u, D) using a hybrid-mixed method
(3) Reconstruct b
"""

solver_parameters = {
'ksp_type': 'preonly',
'mat_type': 'matfree',
'pc_type': 'python',
'pc_python_type': 'firedrake.HybridizationPC',
'hybridization': {'ksp_type': 'cg',
'pc_type': 'gamg',
'ksp_rtol': 1e-8,
'mg_levels': {'ksp_type': 'chebyshev',
'ksp_max_it': 2,
'pc_type': 'bjacobi',
'sub_pc_type': 'ilu'}}
}

@timed_function("Gusto:SolverSetup")
def _setup_solver(self):
equation = self.equations # just cutting down line length a bit
dt = self.dt
beta_ = dt*self.alpha
Vu = equation.domain.spaces("HDiv")
VD = equation.domain.spaces("DG")
Vb = equation.domain.spaces("DG")

# Store time-stepping coefficients as UFL Constants
beta = Constant(beta_)

# Split up the rhs vector
self.xrhs = Function(self.equations.function_space)
u_in = split(self.xrhs)[0]
D_in = split(self.xrhs)[1]
b_in = split(self.xrhs)[2]

# Build the reduced function space for u, D
M = MixedFunctionSpace((Vu, VD))
w, phi = TestFunctions(M)
u, D = TrialFunctions(M)

# Get background buoyancy and depth
bbar = split(equation.X_ref)[2]
Dbar = split(equation.X_ref)[1]

# Approximate elimination of b
b = -dot(u, grad(bbar))*beta + b_in

eqn = (
inner(w, (u - u_in)) * dx
- beta * (D - Dbar) * div(w*bbar) * dx
+ beta * 0.5 * Dbar * inner(w, grad(bbar)) * dx
- beta * 0.5 * Dbar * b * div(w) * dx
+ beta * 0.5 * (D - Dbar) * inner(w, grad(bbar)) * dx
+ inner(phi, (D - D_in)) * dx
+ beta * phi * Dbar * div(u) * dx
)

aeqn = lhs(eqn)
Leqn = rhs(eqn)

# Place to put results of (u,D) solver
self.uD = Function(M)

# Boundary conditions
bcs = [DirichletBC(M.sub(0), bc.function_arg, bc.sub_domain) for bc in self.equations.bcs['u']]

# Solver for u, D
uD_problem = LinearVariationalProblem(aeqn, Leqn, self.uD, bcs=bcs)

# Provide callback for the nullspace of the trace system
def trace_nullsp(T):
return VectorSpaceBasis(constant=True)

appctx = {"trace_nullspace": trace_nullsp}
self.uD_solver = LinearVariationalSolver(uD_problem,
solver_parameters=self.solver_parameters,
appctx=appctx)
# Reconstruction of b
b = TrialFunction(Vb)
gamma = TestFunction(Vb)

u, D = self.uD.subfunctions
self.b = Function(Vb)

b_eqn = gamma*(b - b_in + inner(u, grad(bbar))*beta) * dx

b_problem = LinearVariationalProblem(lhs(b_eqn),
rhs(b_eqn),
self.b)
self.b_solver = LinearVariationalSolver(b_problem)

# Log residuals on hybridized solver
self.log_ksp_residuals(self.uD_solver.snes.ksp)

@timed_function("Gusto:LinearSolve")
def solve(self, xrhs, dy):
"""
Solve the linear problem.

Args:
xrhs (:class:`Function`): the right-hand side field in the
appropriate :class:`MixedFunctionSpace`.
dy (:class:`Function`): the resulting field in the appropriate
:class:`MixedFunctionSpace`.
"""
self.xrhs.assign(xrhs)

with timed_region("Gusto:VelocityDepthSolve"):
logger.info('Thermal linear solver: mixed solve')
self.uD_solver.solve()

u1, D1 = self.uD.subfunctions
u = dy.subfunctions[0]
D = dy.subfunctions[1]
b = dy.subfunctions[2]
u.assign(u1)
D.assign(D1)

with timed_region("Gusto:BuoyancyRecon"):
logger.info('Thermal linear solver: buoyancy reconstruction')
self.b_solver.solve()

b.assign(self.b)


class LinearTimesteppingSolver(object):
"""
A general object for solving mixed finite element linear problems.
Expand Down
Loading