From 912273c23ebe3a8c7b2633a912bbb8f1229cdb1f Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Fri, 31 Jan 2025 14:45:23 +0000 Subject: [PATCH 1/3] slate: fix kernel argument ordering --- firedrake/slate/slac/kernel_builder.py | 32 ++++++++++++-------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/firedrake/slate/slac/kernel_builder.py b/firedrake/slate/slac/kernel_builder.py index 419931232f..8cf27b5298 100644 --- a/firedrake/slate/slac/kernel_builder.py +++ b/firedrake/slate/slac/kernel_builder.py @@ -144,26 +144,24 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap # Pick the coefficients associated with a Tensor()/TSFC kernel tsfc_coefficients = {tsfc_coefficients[i]: indices for i, indices in kinfo.coefficient_numbers} - for c, cinfo in wrapper_coefficients.items(): - if c in tsfc_coefficients: - if isinstance(cinfo, tuple): - if tsfc_coefficients[c]: - ind, = tsfc_coefficients[c] - if ind != 0: - raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}") - kernel_data.append((c, cinfo[0])) - else: - for ind, (c_, info) in enumerate(cinfo.items()): - if ind in tsfc_coefficients[c]: - kernel_data.append((c_, info[0])) + for c in tsfc_coefficients: + cinfo = wrapper_coefficients[c] + if isinstance(cinfo, tuple): + if tsfc_coefficients[c]: + ind, = tsfc_coefficients[c] + if ind != 0: + raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}") + kernel_data.append((c, cinfo[0])) + else: + for ind, (c_, info) in enumerate(cinfo.items()): + if ind in tsfc_coefficients[c]: + kernel_data.append((c_, info[0])) # Pick the constants associated with a Tensor()/TSFC kernel tsfc_constants = tuple(tsfc_constants[i] for i in kinfo.constant_numbers) - kernel_data.extend([ - (constant, constant_name) - for constant, constant_name in wrapper_constants - if constant in tsfc_constants - ]) + wrapper_constants = dict(wrapper_constants) + for c in tsfc_constants: + kernel_data.append((c, wrapper_constants[c])) return kernel_data def loopify_tsfc_kernel_data(self, kernel_data): From 31dde189b555817b6256fc0b53f10c5893ab98d3 Mon Sep 17 00:00:00 2001 From: Thomas Bendall Date: Fri, 31 Jan 2025 16:33:53 +0000 Subject: [PATCH 2/3] slate: add gusto test Co-authored-by: Josh Hope-Collins --- .../slate/test_slate_hybridization.py | 135 ++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/tests/firedrake/slate/test_slate_hybridization.py b/tests/firedrake/slate/test_slate_hybridization.py index ab7d4d4415..52b67f2bf6 100644 --- a/tests/firedrake/slate/test_slate_hybridization.py +++ b/tests/firedrake/slate/test_slate_hybridization.py @@ -462,3 +462,138 @@ def test_mixed_poisson_approximated_schur_jacobi_prec(setup_poisson): assert sigma_err < 1e-8 assert u_err < 1e-8 + + +def test_slate_hybridization_gusto(): + + # ---------------------------------------------------------------------------- # + # Set up core objects + # ---------------------------------------------------------------------------- # + + degree = 1 + radius = Constant(6371220.) + dt = Constant(3000.0) + H = Constant(3e4/9.80616) + g = Constant(9.80616) + Omega = Constant(7.292e-5) + alpha = Constant(0.5) + beta_u = alpha*dt + beta_d = alpha*dt + + mesh = IcosahedralSphereMesh(float(radius), refinement_level=0, degree=2) + x = SpatialCoordinate(mesh) + + # Function spaces + VDG = FunctionSpace(mesh, "DG", degree) + VHDiv = FunctionSpace(mesh, "BDM", degree+1) + mesh.init_cell_orientations(x) + + VCG1 = FunctionSpace(mesh, 'CG', 1) + Vmixed = MixedFunctionSpace((VHDiv, VDG, VDG)) + Vreduced = MixedFunctionSpace((VHDiv, VDG)) + + # Mixed RHS and LHS functions + x_rhs = Function(Vmixed) + + # Components of various mixed functions + u_rhs, D_rhs, b_rhs = split(x_rhs) + + uD_lhs = Function(Vreduced) + u_test, D_test = TestFunctions(Vreduced) + u_trial, D_trial = TrialFunctions(Vreduced) + u_lhsr, D_lhsr = uD_lhs.subfunctions + + # Reference profiles + uDb_bar = Function(Vmixed) + D_bar = split(uDb_bar)[1] + b_bar = split(uDb_bar)[2] + D_bar_subfunc = uDb_bar.subfunctions[1] + b_bar_subfunc = uDb_bar.subfunctions[2] + + # Set up perp function + sphere_degree = mesh.coordinates.function_space().ufl_element().degree() + VecDG = VectorFunctionSpace(mesh, 'DG', sphere_degree) + N = Function(VecDG).interpolate(CellNormal(mesh)) + perp = lambda u: cross(N, u) + + # Coriolis + f = Function(VCG1) + + # ---------------------------------------------------------------------------- # + # Set up problem and solvers + # ---------------------------------------------------------------------------- # + + # We have a 3-function mixed space for wind, depth and buoyancy + # Strategy is to approximately eliminate buoyancy, leaving a wind-depth problem + # The wind-depth problem is solved using the hydridization preconditioner + + # Approximate elimination of b + b = -0.5 * dt * dot(u_trial, grad(b_bar)) + b_rhs + + solver_parameters = { + 'ksp_type': 'preonly', + 'mat_type': 'matfree', + 'pc_type': 'python', + 'pc_python_type': 'firedrake.HybridizationPC', + 'hybridization': { + 'ksp_type': 'cg', + 'pc_type': 'gamg', + 'ksp_rtol': 1e-1, + 'mg_levels': { + 'ksp_type': 'chebyshev', + 'ksp_max_it': 1, + 'pc_type': 'bjacobi', + 'sub_pc_type': 'ilu' + } + } + } + + n = FacetNormal(mesh) + + # Linear thermal shallow water problem + uD_eqn = ( + inner(u_test, (u_trial - u_rhs)) * dx + - beta_u * (D_trial - D_bar) * div(u_test * b_bar) * dx + + beta_u * jump(u_test * b_bar, n) * avg(D_trial - D_bar) * dS + - beta_u * 0.5 * D_bar * b_bar * div(u_test) * dx + - beta_u * 0.5 * D_bar * b * div(u_test) * dx + - beta_u * 0.5 * b_bar * div(u_test*(D_trial - D_bar)) * dx + + beta_u * 0.5 * jump((D_trial - D_bar)*u_test, n) * avg(b_bar) * dS + + inner(D_test, (D_trial - D_rhs)) * dx + + beta_d * D_test * div(D_bar*u_trial) * dx + + beta_u * f * inner(u_test, perp(u_trial)) * dx + ) + + # Boundary conditions + bcs = [] + + # Solver for u, D + uD_problem = LinearVariationalProblem( + lhs(uD_eqn), rhs(uD_eqn), uD_lhs, bcs=bcs, constant_jacobian=True + ) + + # Provide callback for the nullspace of the trace system + def trace_nullsp(T): + return VectorSpaceBasis(constant=True) + + appctx = {"trace_nullspace": trace_nullsp} + uD_solver = LinearVariationalSolver( + uD_problem, solver_parameters=solver_parameters, appctx=appctx + ) + + # ---------------------------------------------------------------------------- # + # Set some initial conditions + # ---------------------------------------------------------------------------- # + + f.interpolate(2*Omega*x[2]/radius) + D_bar_subfunc.assign(H) + b_bar_subfunc.assign(g) + + # Simplest test is to give a right-hand side that is zero + x_rhs.assign(1.0) + + # ---------------------------------------------------------------------------- # + # Run + # ---------------------------------------------------------------------------- # + + uD_solver.solve() From 2a6ab2317af1d7fcf83a8ab7dfc4060896ad6497 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Thu, 6 Feb 2025 14:12:04 +0000 Subject: [PATCH 3/3] slate: add test for checking count robustness --- .../slate/test_slate_hybridization.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/firedrake/slate/test_slate_hybridization.py b/tests/firedrake/slate/test_slate_hybridization.py index 52b67f2bf6..239b7f4db5 100644 --- a/tests/firedrake/slate/test_slate_hybridization.py +++ b/tests/firedrake/slate/test_slate_hybridization.py @@ -597,3 +597,30 @@ def trace_nullsp(T): # ---------------------------------------------------------------------------- # uD_solver.solve() + + +@pytest.mark.parametrize('counts', [(10001, 10002), (10002, 10001)]) +def test_slate_hybridization_count_safe(counts): + g_count, c_count = counts + mesh = UnitTriangleMesh() + BDM = FunctionSpace(mesh, "BDM", 2) + DG = FunctionSpace(mesh, "DG", 1) + V = BDM * DG + VectorDG = VectorFunctionSpace(mesh, 'DG', 0) + u = TrialFunction(V) + v = TestFunction(V) + g = Function(VectorDG, count=g_count) + c = Function(DG, count=c_count) + a = ( + inner(u, v) * dx + + inner(g[0] * u[0], v[0]) * dx + + inner(c * grad(u[0]), grad(v[0])) * dx + ) + sol = Function(V) + solver_parameters = { + 'mat_type': 'matfree', + 'ksp_type': 'preonly', + 'pc_type': 'python', + 'pc_python_type': 'firedrake.HybridizationPC', + } + solve(lhs(a) == rhs(a), sol, solver_parameters=solver_parameters)