diff --git a/gusto/linear_solvers.py b/gusto/linear_solvers.py index 4804010c6..5388112b7 100644 --- a/gusto/linear_solvers.py +++ b/gusto/linear_solvers.py @@ -8,7 +8,7 @@ from firedrake import ( split, LinearVariationalProblem, Constant, LinearVariationalSolver, TestFunctions, TrialFunctions, TestFunction, TrialFunction, lhs, - rhs, FacetNormal, div, dx, jump, avg, dS_v, dS_h, ds_v, ds_t, ds_b, + rhs, FacetNormal, div, dx, jump, avg, dS, dS_v, dS_h, ds_v, ds_t, ds_b, ds_tb, inner, action, dot, grad, Function, VectorSpaceBasis, BrokenElement, FunctionSpace, MixedFunctionSpace, DirichletBC ) @@ -586,6 +586,10 @@ def _setup_solver(self): VD = equation.domain.spaces("DG") Vb = equation.domain.spaces("DG") + # Check that the third field is buoyancy + if not equation.field_names[2] == 'b': + raise NotImplementedError("Field 'b' must exist to use the thermal linear solver in the SIQN scheme") + # Store time-stepping coefficients as UFL Constants beta = Constant(beta_) @@ -601,18 +605,22 @@ def _setup_solver(self): u, D = TrialFunctions(M) # Get background buoyancy and depth - bbar = split(equation.X_ref)[2] Dbar = split(equation.X_ref)[1] + bbar = split(equation.X_ref)[2] # Approximate elimination of b b = -dot(u, grad(bbar))*beta + b_in + n = FacetNormal(equation.domain.mesh) + eqn = ( inner(w, (u - u_in)) * dx - beta * (D - Dbar) * div(w*bbar) * dx - + beta * 0.5 * Dbar * inner(w, grad(bbar)) * dx + + beta * jump(w*bbar, n) * avg(D-Dbar) * dS + - beta * 0.5 * Dbar * bbar * div(w) * dx - beta * 0.5 * Dbar * b * div(w) * dx - + beta * 0.5 * (D - Dbar) * inner(w, grad(bbar)) * dx + - beta * 0.5 * bbar * div(w*(D-Dbar)) * dx + + beta * 0.5 * jump((D-Dbar)*w, n) * avg(bbar) * dS + inner(phi, (D - D_in)) * dx + beta * phi * Dbar * div(u) * dx ) @@ -667,6 +675,13 @@ def solve(self, xrhs, dy): """ self.xrhs.assign(xrhs) + # Check that the b reference profile has been set + bbar = split(self.equations.X_ref)[2] + b = dy.subfunctions[2] + bbar_func = Function(b.function_space()).interpolate(bbar) + if bbar_func.dat.data.max() == 0 and bbar_func.dat.data.min() == 0: + logger.warning("The reference profile for b in the linear solver is zero. To set a non-zero profile add b to the set_reference_profiles argument.") + with timed_region("Gusto:VelocityDepthSolve"): logger.info('Thermal linear solver: mixed solve') self.uD_solver.solve()