diff --git a/gusto/preconditioners.py b/gusto/preconditioners.py index b400e2898..a471b18bc 100644 --- a/gusto/preconditioners.py +++ b/gusto/preconditioners.py @@ -9,6 +9,7 @@ from gusto.recovery.recovery_kernels import AverageKernel, AverageWeightings from pyop2.profiling import timed_region, timed_function from pyop2.utils import as_tuple +from functools import partial __all__ = ["VerticalHybridizationPC"] @@ -49,8 +50,7 @@ def initialize(self, pc): FiniteElement, TensorProductElement, TrialFunction, TrialFunctions, TestFunction, DirichletBC, interval, MixedElement, BrokenElement) - from firedrake.assemble import (allocate_matrix, OneFormAssembler, - TwoFormAssembler) + from firedrake.assemble import get_assembler from firedrake.formmanipulation import split_form from ufl.algorithms.replace import replace from ufl.cell import TensorProductCell @@ -176,22 +176,21 @@ def initialize(self, pc): # Assemble the Schur complement operator and right-hand side self.schur_rhs = Cofunction(Vv_tr.dual()) - self._assemble_Srhs = OneFormAssembler( + self._assemble_Srhs = partial(get_assembler( K * Atilde.inv * AssembledVector(self.broken_residual), - tensor=self.schur_rhs, - form_compiler_parameters=self.ctx.fc_params).assemble + form_compiler_parameters=self.ctx.fc_params).assemble, tensor=self.schur_rhs) mat_type = PETSc.Options().getString(prefix + "mat_type", "aij") schur_comp = K * Atilde.inv * K.T - self.S = allocate_matrix(schur_comp, bcs=trace_bcs, - form_compiler_parameters=self.ctx.fc_params, - mat_type=mat_type, - options_prefix=prefix) - self._assemble_S = TwoFormAssembler(schur_comp, - tensor=self.S, - bcs=trace_bcs, - form_compiler_parameters=self.ctx.fc_params).assemble + self.S = get_assembler(schur_comp, bcs=trace_bcs, + form_compiler_parameters=self.ctx.fc_params, + mat_type=mat_type, + options_prefix=prefix).allocate() + self._assemble_S = partial(get_assembler( + schur_comp, + bcs=trace_bcs, + form_compiler_parameters=self.ctx.fc_params).assemble, tensor=self.S) self._assemble_S() Smat = self.S.petscmat @@ -228,7 +227,7 @@ def _reconstruction_calls(self, split_mixed_op, split_trace_op): split_trace_op (dict): a ``dict`` of split forms that make up the trace contribution in the hybridized mixed system. """ - from firedrake.assemble import OneFormAssembler + from firedrake.assemble import get_assembler # We always eliminate the velocity block first id0, id1 = (self.vidx, self.pidx) @@ -256,15 +255,15 @@ def _reconstruction_calls(self, split_mixed_op, split_trace_op): R = K_1.T - C * A.inv * K_0.T u_rec = M.solve(f - C * A.inv * g - R * lambdar, decomposition="PartialPivLU") - self._sub_unknown = OneFormAssembler(u_rec, - tensor=u, - form_compiler_parameters=self.ctx.fc_params).assemble + self._sub_unknown = partial(get_assembler( + u_rec, + form_compiler_parameters=self.ctx.fc_params).assemble, tensor=u) sigma_rec = A.solve(g - B * AssembledVector(u) - K_0.T * lambdar, decomposition="PartialPivLU") - self._elim_unknown = OneFormAssembler(sigma_rec, - tensor=sigma, - form_compiler_parameters=self.ctx.fc_params).assemble + self._elim_unknown = partial(get_assembler( + sigma_rec, + form_compiler_parameters=self.ctx.fc_params).assemble, tensor=sigma) @timed_function("VertHybridRecon") def _reconstruct(self):