Skip to content

Commit

Permalink
slate: fix kernel argument ordering (#4003)
Browse files Browse the repository at this point in the history
* slate: fix kernel argument ordering

* slate: add gusto test

Co-authored-by: Josh Hope-Collins <[email protected]>

* slate: add test for checking count robustness

---------

Co-authored-by: Thomas Bendall <[email protected]>
Co-authored-by: Josh Hope-Collins <[email protected]>
  • Loading branch information
3 people authored Feb 6, 2025
1 parent 5391c50 commit 9650136
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 17 deletions.
32 changes: 15 additions & 17 deletions firedrake/slate/slac/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,24 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap

# Pick the coefficients associated with a Tensor()/TSFC kernel
tsfc_coefficients = {tsfc_coefficients[i]: indices for i, indices in kinfo.coefficient_numbers}
for c, cinfo in wrapper_coefficients.items():
if c in tsfc_coefficients:
if isinstance(cinfo, tuple):
if tsfc_coefficients[c]:
ind, = tsfc_coefficients[c]
if ind != 0:
raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}")
kernel_data.append((c, cinfo[0]))
else:
for ind, (c_, info) in enumerate(cinfo.items()):
if ind in tsfc_coefficients[c]:
kernel_data.append((c_, info[0]))
for c in tsfc_coefficients:
cinfo = wrapper_coefficients[c]
if isinstance(cinfo, tuple):
if tsfc_coefficients[c]:
ind, = tsfc_coefficients[c]
if ind != 0:
raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}")
kernel_data.append((c, cinfo[0]))
else:
for ind, (c_, info) in enumerate(cinfo.items()):
if ind in tsfc_coefficients[c]:
kernel_data.append((c_, info[0]))

# Pick the constants associated with a Tensor()/TSFC kernel
tsfc_constants = tuple(tsfc_constants[i] for i in kinfo.constant_numbers)
kernel_data.extend([
(constant, constant_name)
for constant, constant_name in wrapper_constants
if constant in tsfc_constants
])
wrapper_constants = dict(wrapper_constants)
for c in tsfc_constants:
kernel_data.append((c, wrapper_constants[c]))
return kernel_data

def loopify_tsfc_kernel_data(self, kernel_data):
Expand Down
162 changes: 162 additions & 0 deletions tests/firedrake/slate/test_slate_hybridization.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,165 @@ def test_mixed_poisson_approximated_schur_jacobi_prec(setup_poisson):

assert sigma_err < 1e-8
assert u_err < 1e-8


def test_slate_hybridization_gusto():

# ---------------------------------------------------------------------------- #
# Set up core objects
# ---------------------------------------------------------------------------- #

degree = 1
radius = Constant(6371220.)
dt = Constant(3000.0)
H = Constant(3e4/9.80616)
g = Constant(9.80616)
Omega = Constant(7.292e-5)
alpha = Constant(0.5)
beta_u = alpha*dt
beta_d = alpha*dt

mesh = IcosahedralSphereMesh(float(radius), refinement_level=0, degree=2)
x = SpatialCoordinate(mesh)

# Function spaces
VDG = FunctionSpace(mesh, "DG", degree)
VHDiv = FunctionSpace(mesh, "BDM", degree+1)
mesh.init_cell_orientations(x)

VCG1 = FunctionSpace(mesh, 'CG', 1)
Vmixed = MixedFunctionSpace((VHDiv, VDG, VDG))
Vreduced = MixedFunctionSpace((VHDiv, VDG))

# Mixed RHS and LHS functions
x_rhs = Function(Vmixed)

# Components of various mixed functions
u_rhs, D_rhs, b_rhs = split(x_rhs)

uD_lhs = Function(Vreduced)
u_test, D_test = TestFunctions(Vreduced)
u_trial, D_trial = TrialFunctions(Vreduced)
u_lhsr, D_lhsr = uD_lhs.subfunctions

# Reference profiles
uDb_bar = Function(Vmixed)
D_bar = split(uDb_bar)[1]
b_bar = split(uDb_bar)[2]
D_bar_subfunc = uDb_bar.subfunctions[1]
b_bar_subfunc = uDb_bar.subfunctions[2]

# Set up perp function
sphere_degree = mesh.coordinates.function_space().ufl_element().degree()
VecDG = VectorFunctionSpace(mesh, 'DG', sphere_degree)
N = Function(VecDG).interpolate(CellNormal(mesh))
perp = lambda u: cross(N, u)

# Coriolis
f = Function(VCG1)

# ---------------------------------------------------------------------------- #
# Set up problem and solvers
# ---------------------------------------------------------------------------- #

# We have a 3-function mixed space for wind, depth and buoyancy
# Strategy is to approximately eliminate buoyancy, leaving a wind-depth problem
# The wind-depth problem is solved using the hydridization preconditioner

# Approximate elimination of b
b = -0.5 * dt * dot(u_trial, grad(b_bar)) + b_rhs

solver_parameters = {
'ksp_type': 'preonly',
'mat_type': 'matfree',
'pc_type': 'python',
'pc_python_type': 'firedrake.HybridizationPC',
'hybridization': {
'ksp_type': 'cg',
'pc_type': 'gamg',
'ksp_rtol': 1e-1,
'mg_levels': {
'ksp_type': 'chebyshev',
'ksp_max_it': 1,
'pc_type': 'bjacobi',
'sub_pc_type': 'ilu'
}
}
}

n = FacetNormal(mesh)

# Linear thermal shallow water problem
uD_eqn = (
inner(u_test, (u_trial - u_rhs)) * dx
- beta_u * (D_trial - D_bar) * div(u_test * b_bar) * dx
+ beta_u * jump(u_test * b_bar, n) * avg(D_trial - D_bar) * dS
- beta_u * 0.5 * D_bar * b_bar * div(u_test) * dx
- beta_u * 0.5 * D_bar * b * div(u_test) * dx
- beta_u * 0.5 * b_bar * div(u_test*(D_trial - D_bar)) * dx
+ beta_u * 0.5 * jump((D_trial - D_bar)*u_test, n) * avg(b_bar) * dS
+ inner(D_test, (D_trial - D_rhs)) * dx
+ beta_d * D_test * div(D_bar*u_trial) * dx
+ beta_u * f * inner(u_test, perp(u_trial)) * dx
)

# Boundary conditions
bcs = []

# Solver for u, D
uD_problem = LinearVariationalProblem(
lhs(uD_eqn), rhs(uD_eqn), uD_lhs, bcs=bcs, constant_jacobian=True
)

# Provide callback for the nullspace of the trace system
def trace_nullsp(T):
return VectorSpaceBasis(constant=True)

appctx = {"trace_nullspace": trace_nullsp}
uD_solver = LinearVariationalSolver(
uD_problem, solver_parameters=solver_parameters, appctx=appctx
)

# ---------------------------------------------------------------------------- #
# Set some initial conditions
# ---------------------------------------------------------------------------- #

f.interpolate(2*Omega*x[2]/radius)
D_bar_subfunc.assign(H)
b_bar_subfunc.assign(g)

# Simplest test is to give a right-hand side that is zero
x_rhs.assign(1.0)

# ---------------------------------------------------------------------------- #
# Run
# ---------------------------------------------------------------------------- #

uD_solver.solve()


@pytest.mark.parametrize('counts', [(10001, 10002), (10002, 10001)])
def test_slate_hybridization_count_safe(counts):
g_count, c_count = counts
mesh = UnitTriangleMesh()
BDM = FunctionSpace(mesh, "BDM", 2)
DG = FunctionSpace(mesh, "DG", 1)
V = BDM * DG
VectorDG = VectorFunctionSpace(mesh, 'DG', 0)
u = TrialFunction(V)
v = TestFunction(V)
g = Function(VectorDG, count=g_count)
c = Function(DG, count=c_count)
a = (
inner(u, v) * dx
+ inner(g[0] * u[0], v[0]) * dx
+ inner(c * grad(u[0]), grad(v[0])) * dx
)
sol = Function(V)
solver_parameters = {
'mat_type': 'matfree',
'ksp_type': 'preonly',
'pc_type': 'python',
'pc_python_type': 'firedrake.HybridizationPC',
}
solve(lhs(a) == rhs(a), sol, solver_parameters=solver_parameters)

0 comments on commit 9650136

Please sign in to comment.