diff --git a/examples/shallow_water/thermal_williamson2.py b/examples/shallow_water/thermal_williamson2.py index e9fcd94ac..342ee8b07 100644 --- a/examples/shallow_water/thermal_williamson2.py +++ b/examples/shallow_water/thermal_williamson2.py @@ -6,7 +6,7 @@ # Test case parameters # ----------------------------------------------------------------- # -dt = 100 +dt = 4000 if '--running-tests' in sys.argv: tmax = dt @@ -54,14 +54,25 @@ ShallowWaterPotentialEnergy(params), ShallowWaterPotentialEnstrophy(), SteadyStateError('u'), SteadyStateError('D'), - MeridionalComponent('u'), ZonalComponent('u')] + SteadyStateError('b'), MeridionalComponent('u'), + ZonalComponent('u')] io = IO(domain, output, diagnostic_fields=diagnostic_fields) + +# Transport schemes +transported_fields = [TrapeziumRule(domain, "u"), + SSPRK3(domain, "D", fixed_subcycles=2), + SSPRK3(domain, "b", fixed_subcycles=2)] transport_methods = [DGUpwind(eqns, "u"), DGUpwind(eqns, "D"), DGUpwind(eqns, "b")] +# Linear solver +linear_solver = ThermalSWSolver(eqns) + # Time stepper -stepper = Timestepper(eqns, RK4(domain), io, spatial_methods=transport_methods) +stepper = SemiImplicitQuasiNewton(eqns, io, transported_fields, + transport_methods, + linear_solver=linear_solver) # ----------------------------------------------------------------- # # Initial conditions @@ -92,6 +103,11 @@ D0.interpolate(Dexpr) b0.interpolate(bexpr) +# Set reference profiles +Dbar = Function(D0.function_space()).assign(H) +bbar = Function(b0.function_space()).interpolate(bexpr) +stepper.set_reference_profiles([('D', Dbar), ('b', bbar)]) + # ----------------------------------------------------------------- # # Run # ----------------------------------------------------------------- # diff --git a/gusto/linear_solvers.py b/gusto/linear_solvers.py index 55bee0616..d1017aa28 100644 --- a/gusto/linear_solvers.py +++ b/gusto/linear_solvers.py @@ -24,7 +24,7 @@ from abc import ABCMeta, abstractmethod, abstractproperty -__all__ = ["IncompressibleSolver", "LinearTimesteppingSolver", "CompressibleSolver"] +__all__ = ["IncompressibleSolver", "LinearTimesteppingSolver", "CompressibleSolver", "ThermalSWSolver"] class TimesteppingSolver(object, metaclass=ABCMeta): @@ -550,6 +550,141 @@ def solve(self, xrhs, dy): b.assign(self.b) +class ThermalSWSolver(TimesteppingSolver): + """ + Linear solver object for the thermal shallow water equations. + + This solves a linear problem for the thermal shallow water equations with + prognostic variables u (velocity), D (depth) and b (buoyancy). It follows + the following strategy: + + (1) Eliminate b + (2) Solve the resulting system for (u, D) using a hybrid-mixed method + (3) Reconstruct b + """ + + solver_parameters = { + 'ksp_type': 'preonly', + 'mat_type': 'matfree', + 'pc_type': 'python', + 'pc_python_type': 'firedrake.HybridizationPC', + 'hybridization': {'ksp_type': 'cg', + 'pc_type': 'gamg', + 'ksp_rtol': 1e-8, + 'mg_levels': {'ksp_type': 'chebyshev', + 'ksp_max_it': 2, + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu'}} + } + + @timed_function("Gusto:SolverSetup") + def _setup_solver(self): + equation = self.equations # just cutting down line length a bit + dt = self.dt + beta_ = dt*self.alpha + Vu = equation.domain.spaces("HDiv") + VD = equation.domain.spaces("DG") + Vb = equation.domain.spaces("DG") + + # Store time-stepping coefficients as UFL Constants + beta = Constant(beta_) + + # Split up the rhs vector + self.xrhs = Function(self.equations.function_space) + u_in = split(self.xrhs)[0] + D_in = split(self.xrhs)[1] + b_in = split(self.xrhs)[2] + + # Build the reduced function space for u, D + M = MixedFunctionSpace((Vu, VD)) + w, phi = TestFunctions(M) + u, D = TrialFunctions(M) + + # Get background buoyancy and depth + bbar = split(equation.X_ref)[2] + Dbar = split(equation.X_ref)[1] + + # Approximate elimination of b + b = -dot(u, grad(bbar))*beta + b_in + + eqn = ( + inner(w, (u - u_in)) * dx + - beta * (D - Dbar) * div(w*bbar) * dx + + beta * 0.5 * Dbar * inner(w, grad(bbar)) * dx + - beta * 0.5 * Dbar * b * div(w) * dx + + beta * 0.5 * (D - Dbar) * inner(w, grad(bbar)) * dx + + inner(phi, (D - D_in)) * dx + + beta * phi * Dbar * div(u) * dx + ) + + aeqn = lhs(eqn) + Leqn = rhs(eqn) + + # Place to put results of (u,D) solver + self.uD = Function(M) + + # Boundary conditions + bcs = [DirichletBC(M.sub(0), bc.function_arg, bc.sub_domain) for bc in self.equations.bcs['u']] + + # Solver for u, D + uD_problem = LinearVariationalProblem(aeqn, Leqn, self.uD, bcs=bcs) + + # Provide callback for the nullspace of the trace system + def trace_nullsp(T): + return VectorSpaceBasis(constant=True) + + appctx = {"trace_nullspace": trace_nullsp} + self.uD_solver = LinearVariationalSolver(uD_problem, + solver_parameters=self.solver_parameters, + appctx=appctx) + # Reconstruction of b + b = TrialFunction(Vb) + gamma = TestFunction(Vb) + + u, D = self.uD.subfunctions + self.b = Function(Vb) + + b_eqn = gamma*(b - b_in + inner(u, grad(bbar))*beta) * dx + + b_problem = LinearVariationalProblem(lhs(b_eqn), + rhs(b_eqn), + self.b) + self.b_solver = LinearVariationalSolver(b_problem) + + # Log residuals on hybridized solver + self.log_ksp_residuals(self.uD_solver.snes.ksp) + + @timed_function("Gusto:LinearSolve") + def solve(self, xrhs, dy): + """ + Solve the linear problem. + + Args: + xrhs (:class:`Function`): the right-hand side field in the + appropriate :class:`MixedFunctionSpace`. + dy (:class:`Function`): the resulting field in the appropriate + :class:`MixedFunctionSpace`. + """ + self.xrhs.assign(xrhs) + + with timed_region("Gusto:VelocityDepthSolve"): + logger.info('Thermal linear solver: mixed solve') + self.uD_solver.solve() + + u1, D1 = self.uD.subfunctions + u = dy.subfunctions[0] + D = dy.subfunctions[1] + b = dy.subfunctions[2] + u.assign(u1) + D.assign(D1) + + with timed_region("Gusto:BuoyancyRecon"): + logger.info('Thermal linear solver: buoyancy reconstruction') + self.b_solver.solve() + + b.assign(self.b) + + class LinearTimesteppingSolver(object): """ A general object for solving mixed finite element linear problems.