diff --git a/examples/compressible/mountain_hydrostatic.py b/examples/compressible/mountain_hydrostatic.py index 2fe88206d..640713627 100644 --- a/examples/compressible/mountain_hydrostatic.py +++ b/examples/compressible/mountain_hydrostatic.py @@ -6,7 +6,7 @@ from gusto import * from firedrake import (as_vector, VectorFunctionSpace, PeriodicIntervalMesh, ExtrudedMesh, SpatialCoordinate, - exp, pi, cos, Function, conditional, Mesh, op2, sqrt) + exp, pi, cos, Function, conditional, Mesh, sqrt) import sys # ---------------------------------------------------------------------------- # @@ -157,21 +157,13 @@ top=True, exner_boundary=0.5, params=exner_params) +# Use kernel as a parallel-safe method of computing minimum +min_kernel = kernels.MinKernel() -def minimum(f): - fmin = op2.Global(1, [1000], dtype=float) - op2.par_loop(op2.Kernel(""" -static void minify(double *a, double *b) { - a[0] = a[0] > fabs(b[0]) ? fabs(b[0]) : a[0]; -} - """, "minify"), f.dof_dset.set, fmin(op2.MIN), f.dat(op2.READ)) - return fmin.data[0] - - -p0 = minimum(exner) +p0 = min_kernel.apply(exner) compressible_hydrostatic_balance(eqns, theta_b, rho_b, exner, top=True, params=exner_params) -p1 = minimum(exner) +p1 = min_kernel.apply(exner) alpha = 2.*(p1-p0) beta = p1-alpha exner_top = (1.-beta)/alpha diff --git a/gusto/diagnostics.py b/gusto/diagnostics.py index b0969dfd8..0612255cb 100644 --- a/gusto/diagnostics.py +++ b/gusto/diagnostics.py @@ -1,6 +1,6 @@ """Common diagnostic fields.""" -from firedrake import op2, assemble, dot, dx, Function, sqrt, \ +from firedrake import assemble, dot, dx, Function, sqrt, \ TestFunction, TrialFunction, Constant, grad, inner, curl, \ LinearVariationalProblem, LinearVariationalSolver, FacetNormal, \ ds_b, ds_v, ds_t, dS_h, dS_v, ds, dS, div, avg, jump, pi, \ @@ -15,6 +15,7 @@ from gusto.equations import CompressibleEulerEquations from gusto.active_tracers import TracerVariableType, Phases from gusto.logging import logger +from gusto.kernels import MinKernel, MaxKernel import numpy as np __all__ = ["Diagnostics", "CourantNumber", "Gradient", "XComponent", "YComponent", @@ -62,39 +63,25 @@ def register(self, *fields): @staticmethod def min(f): - # TODO check that this is correct. Maybe move the kernel elsewhere? """ Finds the global minimum DoF value of a field. Args: f (:class:`Function`): field to compute diagnostic for. """ - - fmin = op2.Global(1, np.finfo(float).max, dtype=float, comm=f._comm) - op2.par_loop(op2.Kernel(""" -static void minify(double *a, double *b) { - a[0] = a[0] > fabs(b[0]) ? fabs(b[0]) : a[0]; -} -""", "minify"), f.dof_dset.set, fmin(op2.MIN), f.dat(op2.READ)) - return fmin.data[0] + min_kernel = MinKernel() + return min_kernel.apply(f) @staticmethod def max(f): - # TODO check that this is correct. Maybe move the kernel elsewhere? """ Finds the global maximum DoF value of a field. Args: f (:class:`Function`): field to compute diagnostic for. """ - - fmax = op2.Global(1, np.finfo(float).min, dtype=float, comm=f._comm) - op2.par_loop(op2.Kernel(""" -static void maxify(double *a, double *b) { - a[0] = a[0] < fabs(b[0]) ? fabs(b[0]) : a[0]; -} -""", "maxify"), f.dof_dset.set, fmax(op2.MAX), f.dat(op2.READ)) - return fmax.data[0] + max_kernel = MaxKernel() + return max_kernel.apply(f) @staticmethod def rms(f): diff --git a/gusto/kernels.py b/gusto/kernels.py index 8ca72b363..ec9e80bf4 100644 --- a/gusto/kernels.py +++ b/gusto/kernels.py @@ -10,7 +10,8 @@ """ from firedrake import dx -from firedrake.parloops import par_loop, READ, WRITE +from firedrake.parloops import par_loop, READ, WRITE, MIN, MAX, op2 +import numpy as np class LimitMidpoints(): @@ -112,3 +113,61 @@ def apply(self, field, field_in): {"field": (field, WRITE), "field_in": (field_in, READ)}, is_loopy_kernel=True) + + +class MinKernel(): + """Finds the minimum DoF value of a field.""" + + def __init__(self): + + self._kernel = op2.Kernel(""" + static void minify(double *a, double *b) { + a[0] = a[0] > b[0] ? b[0] : a[0]; + } + """, "minify") + + def apply(self, field): + """ + Performs the par loop. + + Args: + field (:class:`Function`): The field to take the minimum of. + + Returns: + The minimum DoF value of the field. + """ + + fmin = op2.Global(1, np.finfo(float).max, dtype=float, comm=field._comm) + + op2.par_loop(self._kernel, field.dof_dset.set, fmin(MIN), field.dat(READ)) + + return fmin.data[0] + + +class MaxKernel(): + """Finds the maximum DoF value of a field.""" + + def __init__(self): + + self._kernel = op2.Kernel(""" + static void maxify(double *a, double *b) { + a[0] = a[0] < b[0] ? b[0] : a[0]; + } + """, "maxify") + + def apply(self, field): + """ + Performs the par loop. + + Args: + field (:class:`Function`): The field to take the maximum of. + + Returns: + The maximum DoF value of the field. + """ + + fmax = op2.Global(1, np.finfo(float).min, dtype=float, comm=field._comm) + + op2.par_loop(self._kernel, field.dof_dset.set, fmax(MAX), field.dat(READ)) + + return fmax.data[0] diff --git a/unit-tests/kernel_tests/test_max_kernel.py b/unit-tests/kernel_tests/test_max_kernel.py new file mode 100644 index 000000000..d80a9c981 --- /dev/null +++ b/unit-tests/kernel_tests/test_max_kernel.py @@ -0,0 +1,47 @@ +""" +A test of the MaxKernel kernel, which finds the global maximum of a field. +""" + +from firedrake import UnitSquareMesh, Function, FunctionSpace, SpatialCoordinate +from gusto import kernels +import numpy as np + + +def test_max_kernel(): + + # ------------------------------------------------------------------------ # + # Set up meshes and spaces + # ------------------------------------------------------------------------ # + + mesh = UnitSquareMesh(3, 3) + + DG1 = FunctionSpace(mesh, "DG", 1) + + field = Function(DG1) + + # ------------------------------------------------------------------------ # + # Initial conditions + # ------------------------------------------------------------------------ # + + x, y = SpatialCoordinate(mesh) + + # Some random expression + init_expr = (20./3.)*x*y + 300. + field.interpolate(init_expr) + + # Set a maximum value + max_val = 40069.18 + field.dat.data[5] = max_val + + # ------------------------------------------------------------------------ # + # Apply kernel + # ------------------------------------------------------------------------ # + + kernel = kernels.MaxKernel() + new_max = kernel.apply(field) + + # ------------------------------------------------------------------------ # + # Check values + # ------------------------------------------------------------------------ # + + assert np.isclose(new_max, max_val), 'maximum kernel is not correct' diff --git a/unit-tests/kernel_tests/test_min_kernel.py b/unit-tests/kernel_tests/test_min_kernel.py new file mode 100644 index 000000000..010a5c97c --- /dev/null +++ b/unit-tests/kernel_tests/test_min_kernel.py @@ -0,0 +1,47 @@ +""" +A test of the MinKernel kernel, which finds the global minimum of a field. +""" + +from firedrake import UnitSquareMesh, Function, FunctionSpace, SpatialCoordinate +from gusto import kernels +import numpy as np + + +def test_min_kernel(): + + # ------------------------------------------------------------------------ # + # Set up meshes and spaces + # ------------------------------------------------------------------------ # + + mesh = UnitSquareMesh(3, 3) + + DG1 = FunctionSpace(mesh, "DG", 1) + + field = Function(DG1) + + # ------------------------------------------------------------------------ # + # Initial conditions + # ------------------------------------------------------------------------ # + + x, y = SpatialCoordinate(mesh) + + # Some random expression + init_expr = (20./3.)*x*y + 300. + field.interpolate(init_expr) + + # Set a minimum value + min_val = -400.18 + field.dat.data[5] = min_val + + # ------------------------------------------------------------------------ # + # Apply kernel + # ------------------------------------------------------------------------ # + + kernel = kernels.MinKernel() + new_min = kernel.apply(field) + + # ------------------------------------------------------------------------ # + # Check values + # ------------------------------------------------------------------------ # + + assert np.isclose(new_min, min_val), 'Minimum kernel is not correct'