Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

slate: fix kernel argument ordering #4003

Merged
merged 3 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions firedrake/slate/slac/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,24 @@ def collect_tsfc_kernel_data(self, mesh, tsfc_coefficients, tsfc_constants, wrap

# Pick the coefficients associated with a Tensor()/TSFC kernel
tsfc_coefficients = {tsfc_coefficients[i]: indices for i, indices in kinfo.coefficient_numbers}
for c, cinfo in wrapper_coefficients.items():
if c in tsfc_coefficients:
if isinstance(cinfo, tuple):
if tsfc_coefficients[c]:
ind, = tsfc_coefficients[c]
if ind != 0:
raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}")
kernel_data.append((c, cinfo[0]))
else:
for ind, (c_, info) in enumerate(cinfo.items()):
if ind in tsfc_coefficients[c]:
kernel_data.append((c_, info[0]))
for c in tsfc_coefficients:
cinfo = wrapper_coefficients[c]
if isinstance(cinfo, tuple):
if tsfc_coefficients[c]:
ind, = tsfc_coefficients[c]
if ind != 0:
raise ValueError(f"Active indices of non-mixed function must be (0, ), not {tsfc_coefficients[c]}")
kernel_data.append((c, cinfo[0]))
else:
for ind, (c_, info) in enumerate(cinfo.items()):
if ind in tsfc_coefficients[c]:
kernel_data.append((c_, info[0]))

# Pick the constants associated with a Tensor()/TSFC kernel
tsfc_constants = tuple(tsfc_constants[i] for i in kinfo.constant_numbers)
kernel_data.extend([
(constant, constant_name)
for constant, constant_name in wrapper_constants
if constant in tsfc_constants
])
wrapper_constants = dict(wrapper_constants)
for c in tsfc_constants:
kernel_data.append((c, wrapper_constants[c]))
return kernel_data

def loopify_tsfc_kernel_data(self, kernel_data):
Expand Down
162 changes: 162 additions & 0 deletions tests/firedrake/slate/test_slate_hybridization.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,165 @@ def test_mixed_poisson_approximated_schur_jacobi_prec(setup_poisson):

assert sigma_err < 1e-8
assert u_err < 1e-8


def test_slate_hybridization_gusto():

# ---------------------------------------------------------------------------- #
# Set up core objects
# ---------------------------------------------------------------------------- #

degree = 1
radius = Constant(6371220.)
dt = Constant(3000.0)
H = Constant(3e4/9.80616)
g = Constant(9.80616)
Omega = Constant(7.292e-5)
alpha = Constant(0.5)
beta_u = alpha*dt
beta_d = alpha*dt

mesh = IcosahedralSphereMesh(float(radius), refinement_level=0, degree=2)
x = SpatialCoordinate(mesh)

# Function spaces
VDG = FunctionSpace(mesh, "DG", degree)
VHDiv = FunctionSpace(mesh, "BDM", degree+1)
mesh.init_cell_orientations(x)

VCG1 = FunctionSpace(mesh, 'CG', 1)
Vmixed = MixedFunctionSpace((VHDiv, VDG, VDG))
Vreduced = MixedFunctionSpace((VHDiv, VDG))

# Mixed RHS and LHS functions
x_rhs = Function(Vmixed)

# Components of various mixed functions
u_rhs, D_rhs, b_rhs = split(x_rhs)

uD_lhs = Function(Vreduced)
u_test, D_test = TestFunctions(Vreduced)
u_trial, D_trial = TrialFunctions(Vreduced)
u_lhsr, D_lhsr = uD_lhs.subfunctions

# Reference profiles
uDb_bar = Function(Vmixed)
D_bar = split(uDb_bar)[1]
b_bar = split(uDb_bar)[2]
D_bar_subfunc = uDb_bar.subfunctions[1]
b_bar_subfunc = uDb_bar.subfunctions[2]

# Set up perp function
sphere_degree = mesh.coordinates.function_space().ufl_element().degree()
VecDG = VectorFunctionSpace(mesh, 'DG', sphere_degree)
N = Function(VecDG).interpolate(CellNormal(mesh))
perp = lambda u: cross(N, u)

# Coriolis
f = Function(VCG1)

# ---------------------------------------------------------------------------- #
# Set up problem and solvers
# ---------------------------------------------------------------------------- #

# We have a 3-function mixed space for wind, depth and buoyancy
# Strategy is to approximately eliminate buoyancy, leaving a wind-depth problem
# The wind-depth problem is solved using the hydridization preconditioner

# Approximate elimination of b
b = -0.5 * dt * dot(u_trial, grad(b_bar)) + b_rhs

solver_parameters = {
'ksp_type': 'preonly',
'mat_type': 'matfree',
'pc_type': 'python',
'pc_python_type': 'firedrake.HybridizationPC',
'hybridization': {
'ksp_type': 'cg',
'pc_type': 'gamg',
'ksp_rtol': 1e-1,
'mg_levels': {
'ksp_type': 'chebyshev',
'ksp_max_it': 1,
'pc_type': 'bjacobi',
'sub_pc_type': 'ilu'
}
}
}

n = FacetNormal(mesh)

# Linear thermal shallow water problem
uD_eqn = (
inner(u_test, (u_trial - u_rhs)) * dx
- beta_u * (D_trial - D_bar) * div(u_test * b_bar) * dx
+ beta_u * jump(u_test * b_bar, n) * avg(D_trial - D_bar) * dS
- beta_u * 0.5 * D_bar * b_bar * div(u_test) * dx
- beta_u * 0.5 * D_bar * b * div(u_test) * dx
- beta_u * 0.5 * b_bar * div(u_test*(D_trial - D_bar)) * dx
+ beta_u * 0.5 * jump((D_trial - D_bar)*u_test, n) * avg(b_bar) * dS
+ inner(D_test, (D_trial - D_rhs)) * dx
+ beta_d * D_test * div(D_bar*u_trial) * dx
+ beta_u * f * inner(u_test, perp(u_trial)) * dx
)

# Boundary conditions
bcs = []

# Solver for u, D
uD_problem = LinearVariationalProblem(
lhs(uD_eqn), rhs(uD_eqn), uD_lhs, bcs=bcs, constant_jacobian=True
)

# Provide callback for the nullspace of the trace system
def trace_nullsp(T):
return VectorSpaceBasis(constant=True)

appctx = {"trace_nullspace": trace_nullsp}
uD_solver = LinearVariationalSolver(
uD_problem, solver_parameters=solver_parameters, appctx=appctx
)

# ---------------------------------------------------------------------------- #
# Set some initial conditions
# ---------------------------------------------------------------------------- #

f.interpolate(2*Omega*x[2]/radius)
D_bar_subfunc.assign(H)
b_bar_subfunc.assign(g)

# Simplest test is to give a right-hand side that is zero
x_rhs.assign(1.0)

# ---------------------------------------------------------------------------- #
# Run
# ---------------------------------------------------------------------------- #

uD_solver.solve()


@pytest.mark.parametrize('counts', [(10001, 10002), (10002, 10001)])
def test_slate_hybridization_count_safe(counts):
g_count, c_count = counts
mesh = UnitTriangleMesh()
BDM = FunctionSpace(mesh, "BDM", 2)
DG = FunctionSpace(mesh, "DG", 1)
V = BDM * DG
VectorDG = VectorFunctionSpace(mesh, 'DG', 0)
u = TrialFunction(V)
v = TestFunction(V)
g = Function(VectorDG, count=g_count)
c = Function(DG, count=c_count)
a = (
inner(u, v) * dx
+ inner(g[0] * u[0], v[0]) * dx
+ inner(c * grad(u[0]), grad(v[0])) * dx
)
sol = Function(V)
solver_parameters = {
'mat_type': 'matfree',
'ksp_type': 'preonly',
'pc_type': 'python',
'pc_python_type': 'firedrake.HybridizationPC',
}
solve(lhs(a) == rhs(a), sol, solver_parameters=solver_parameters)
Loading