From d3e72fb9f05263824f134ed9346df5ba90665f81 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Tue, 20 Aug 2024 23:11:49 +0100 Subject: [PATCH] handle restrictions on MixedMesh --- ...st_mixed_function_space_with_mixed_mesh.py | 55 ++++- ufl/algorithms/apply_coefficient_split.py | 210 ++++++++++++++++++ ufl/algorithms/apply_restrictions.py | 196 ++++++++++++++-- ufl/algorithms/balancing.py | 9 +- ufl/algorithms/check_arities.py | 2 + ufl/algorithms/compute_form_data.py | 75 ++++++- ufl/algorithms/domain_analysis.py | 4 +- ufl/algorithms/estimate_degrees.py | 8 + ufl/algorithms/formtransformations.py | 2 + ufl/exproperators.py | 10 +- ufl/formatting/ufl2unicode.py | 10 + ufl/restriction.py | 16 ++ 12 files changed, 559 insertions(+), 38 deletions(-) create mode 100644 ufl/algorithms/apply_coefficient_split.py diff --git a/test/test_mixed_function_space_with_mixed_mesh.py b/test/test_mixed_function_space_with_mixed_mesh.py index 679d1ff14..b986e7e26 100644 --- a/test/test_mixed_function_space_with_mixed_mesh.py +++ b/test/test_mixed_function_space_with_mixed_mesh.py @@ -1,5 +1,5 @@ from ufl import (triangle, Mesh, MixedMesh, FunctionSpace, TestFunction, TrialFunction, Coefficient, - Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, split, ) + Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, div, split, ) from ufl.algorithms import compute_form_data from ufl.finiteelement import FiniteElement, MixedElement from ufl.pullback import identity_pullback, contravariant_piola @@ -27,8 +27,9 @@ def test_mixed_function_space_with_mixed_mesh_basic(): f0, f1, f2 = split(f) g0, g1, g2 = split(g) dx1 = Measure("dx", mesh1) + ds2 = Measure("ds", mesh2) x = SpatialCoordinate(mesh1) - form = x[1] * f0 * inner(grad(u0), v1) * dx1(999) + form = x[1] * f0 * inner(grad(u0), v1) * dx1(999) + div(f1) * g2 * inner(u1, grad(v2)) * ds2(888) fd = compute_form_data(form, do_apply_function_pullbacks=True, do_apply_integral_scaling=True, @@ -37,16 +38,60 @@ def test_mixed_function_space_with_mixed_mesh_basic(): do_apply_restrictions=True, do_estimate_degrees=True, complex_mode=False) - id0, = fd.integral_data + id0, id1 = fd.integral_data assert fd.preprocessed_form.arguments() == (v, u) - assert fd.reduced_coefficients == [f] + assert fd.reduced_coefficients == [f, g] assert form.coefficients()[fd.original_coefficient_positions[0]] is f + assert form.coefficients()[fd.original_coefficient_positions[1]] is g assert id0.domain is mesh1 assert id0.integral_type == 'cell' assert id0.subdomain_id == (999, ) assert fd.original_form.domain_numbering()[id0.domain] == 0 assert id0.integral_coefficients == set([f]) - assert id0.enabled_coefficients == [True] + assert id0.enabled_coefficients == [True, False] + assert id1.domain is mesh2 + assert id1.integral_type == 'exterior_facet' + assert id1.subdomain_id == (888, ) + assert fd.original_form.domain_numbering()[id1.domain] == 1 + assert id1.integral_coefficients == set([f, g]) + assert id1.enabled_coefficients == [True, True] + + +def test_mixed_function_space_with_mixed_mesh_restriction(): + cell = triangle + elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1) + elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2, ), contravariant_piola, HDiv) + elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2) + elem = MixedElement([elem0, elem1, elem2]) + mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=100) + mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=101) + mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=102) + domain = MixedMesh(mesh0, mesh1, mesh2) + V = FunctionSpace(domain, elem) + V1 = FunctionSpace(mesh1, elem1) + V2 = FunctionSpace(mesh2, elem2) + u1 = TrialFunction(V1) + v2 = TestFunction(V2) + f = Coefficient(V, count=1000) + g = Coefficient(V, count=2000) + f0, f1, f2 = split(f) + g0, g1, g2 = split(g) + dS1 = Measure("dS", mesh1) + x2 = SpatialCoordinate(mesh2) + form = inner(x2, g1) * g2 * inner(u1('-'), grad(v2('|'))) * dS1(999) + fd = compute_form_data(form, + do_apply_function_pullbacks=True, + do_apply_integral_scaling=True, + do_apply_geometry_lowering=True, + preserve_geometry_types=(CellVolume, FacetArea), + do_apply_restrictions=True, + do_estimate_degrees=True, + do_split_coefficients=(f, g), + do_assume_single_integral_type=False, + complex_mode=False) + integral_data, = fd.integral_data + assert integral_data.domain_integral_type_map[mesh1] == "interior_facet" + assert integral_data.domain_integral_type_map[mesh2] == "exterior_facet" def test_mixed_function_space_with_mixed_mesh_signature(): diff --git a/ufl/algorithms/apply_coefficient_split.py b/ufl/algorithms/apply_coefficient_split.py new file mode 100644 index 000000000..4e7d1200d --- /dev/null +++ b/ufl/algorithms/apply_coefficient_split.py @@ -0,0 +1,210 @@ +"""Apply coefficient split. + +This module contains classes and functions to split coefficients defined on mixed function spaces. +""" + +import numpy +from ufl.classes import Restricted +from ufl.corealg.map_dag import map_expr_dag +from ufl.corealg.multifunction import MultiFunction, memoized_handler +from ufl.domain import extract_unique_domain +from ufl.classes import (Coefficient, Form, ReferenceGrad, ReferenceValue, + Indexed, MultiIndex, Index, FixedIndex, + ComponentTensor, ListTensor, Zero, + NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted) +from ufl import indices +from ufl.checks import is_cellwise_constant +from ufl.tensors import as_tensor + + +class CoefficientSplitter(MultiFunction): + + def __init__(self, coefficient_split): + MultiFunction.__init__(self) + self._coefficient_split = coefficient_split + + expr = MultiFunction.reuse_if_untouched + + def modified_terminal(self, o): + restriction = None + local_derivatives = 0 + reference_value = False + t = o + while not t._ufl_is_terminal_: + assert t._ufl_is_terminal_modifier_, f"Got {repr(t)}" + if isinstance(t, ReferenceValue): + assert not reference_value, "Got twice pulled back terminal!" + reference_value = True + t, = t.ufl_operands + elif isinstance(t, ReferenceGrad): + local_derivatives += 1 + t, = t.ufl_operands + elif isinstance(t, Restricted): + assert restriction is None, "Got twice restricted terminal!" + restriction = t._side + t, = t.ufl_operands + elif t._ufl_terminal_modifiers_: + raise ValueError("Missing handler for terminal modifier type %s, object is %s." % (type(t), repr(t))) + else: + raise ValueError("Unexpected type %s object %s." % (type(t), repr(t))) + if not isinstance(t, Coefficient): + # Only split coefficients + return o + if t not in self._coefficient_split: + # Only split mixed coefficients + return o + # Reference value expected + assert reference_value + # Derivative indices + beta = indices(local_derivatives) + components = [] + for subcoeff in self._coefficient_split[t]: + c = subcoeff + # Apply terminal modifiers onto the subcoefficient + if reference_value: + c = ReferenceValue(c) + for n in range(local_derivatives): + # Return zero if expression is trivially constant. This has to + # happen here because ReferenceGrad has no access to the + # topological dimension of a literal zero. + if is_cellwise_constant(c): + dim = extract_unique_domain(subcoeff).topological_dimension() + c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions) + else: + c = ReferenceGrad(c) + if restriction == '+': + c = PositiveRestricted(c) + elif restriction == '-': + c = NegativeRestricted(c) + elif restriction == '|': + c = SingleValueRestricted(c) + elif restriction == '?': + c = ToBeRestricted(c) + elif restriction is not None: + raise RuntimeError(f"Got unknown restriction: {restriction}") + # Collect components of the subcoefficient + for alpha in numpy.ndindex(subcoeff.ufl_element().reference_value_shape): + # New modified terminal: component[alpha + beta] + components.append(c[alpha + beta]) + # Repack derivative indices to shape + c, = indices(1) + return ComponentTensor(as_tensor(components)[c], MultiIndex((c,) + beta)) + + positive_restricted = modified_terminal + negative_restricted = modified_terminal + single_value_restricted = modified_terminal + to_be_restricted = modified_terminal + reference_grad = modified_terminal + reference_value = modified_terminal + terminal = modified_terminal + + +def apply_coefficient_split(expr, coefficient_split): + """Split mixed coefficients, so mixed elements need not be + implemented. + + :arg split: A :py:class:`dict` mapping each mixed coefficient to a + sequence of subcoefficients. If None, calling this + function is a no-op. + """ + if coefficient_split is None: + return expr + splitter = CoefficientSplitter(coefficient_split) + return map_expr_dag(splitter, expr) + + +class FixedIndexRemover(MultiFunction): + + def __init__(self, fimap): + MultiFunction.__init__(self) + self.fimap = fimap + self._object_cache = {} + + expr = MultiFunction.reuse_if_untouched + + @memoized_handler + def zero(self, o): + free_indices = [] + index_dimensions = [] + for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions): + if Index(i) in self.fimap: + ind_j = self.fimap[Index(i)] + if not isinstance(ind_j, FixedIndex): + free_indices.append(ind_j.count()) + index_dimensions.append(d) + else: + free_indices.append(i) + index_dimensions.append(d) + return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions)) + + @memoized_handler + def list_tensor(self, o): + cc = [] + for o1 in o.ufl_operands: + comp = map_expr_dag(self, o1) + cc.append(comp) + return ListTensor(*cc) + + @memoized_handler + def multi_index(self, o): + return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices())) + + +class IndexRemover(MultiFunction): + + def __init__(self): + MultiFunction.__init__(self) + self._object_cache = {} + + expr = MultiFunction.reuse_if_untouched + + @memoized_handler + def _zero_simplify(self, o): + operand, = o.ufl_operands + operand = map_expr_dag(self, operand) + if isinstance(operand, Zero): + return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions) + else: + return o._ufl_expr_reconstruct_(operand) + + @memoized_handler + def indexed(self, o): + o1, i1 = o.ufl_operands + if isinstance(o1, ComponentTensor): + o2, i2 = o1.ufl_operands + assert len(i2.indices()) == len(i1.indices()) + fimap = dict(zip(i2.indices(), i1.indices())) + rule = FixedIndexRemover(fimap) + v = map_expr_dag(rule, o2) + return map_expr_dag(self, v) + elif isinstance(o1, ListTensor): + if isinstance(i1[0], FixedIndex): + o1 = o1.ufl_operands[i1[0]._value] + if len(i1) > 1: + i1 = MultiIndex(i1[1:]) + return map_expr_dag(self, Indexed(o1, i1)) + else: + return map_expr_dag(self, o1) + o1 = map_expr_dag(self, o1) + return Indexed(o1, i1) + + # Do something nicer + positive_restricted = _zero_simplify + negative_restricted = _zero_simplify + single_value_restricted = _zero_simplify + to_be_restricted = _zero_simplify + reference_grad = _zero_simplify + reference_value = _zero_simplify + + +def remove_component_and_list_tensors(o): + if isinstance(o, Form): + integrals = [] + for integral in o.integrals(): + integrand = remove_component_and_list_tensors(integral.integrand()) + if not isinstance(integrand, Zero): + integrals.append(integral.reconstruct(integrand=integrand)) + return o._ufl_expr_reconstruct_(integrals) + else: + rule = IndexRemover() + return map_expr_dag(rule, o) diff --git a/ufl/algorithms/apply_restrictions.py b/ufl/algorithms/apply_restrictions.py index 8f788a009..18353306d 100644 --- a/ufl/algorithms/apply_restrictions.py +++ b/ufl/algorithms/apply_restrictions.py @@ -10,30 +10,34 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later - from ufl.algorithms.map_integrands import map_integrand_dags from ufl.classes import Restricted from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction -from ufl.domain import extract_unique_domain +from ufl.domain import extract_unique_domain, MixedMesh from ufl.measure import integral_type_to_measure_name from ufl.sobolevspace import H1 +from ufl.classes import ReferenceGrad, ReferenceValue +from ufl.restriction import PositiveRestricted, SingleValueRestricted class RestrictionPropagator(MultiFunction): """Restriction propagator.""" - def __init__(self, side=None): + def __init__(self, side=None, assume_single_integral_type=True): """Initialise.""" MultiFunction.__init__(self) self.current_restriction = side - self.default_restriction = "+" + self.default_restriction = "+" if assume_single_integral_type else "?" # Caches for propagating the restriction with map_expr_dag - self.vcaches = {"+": {}, "-": {}} - self.rcaches = {"+": {}, "-": {}} + self.vcaches = {"+": {}, "-": {}, "|": {}, "?": {}} + self.rcaches = {"+": {}, "-": {}, "|": {}, "?": {}} if self.current_restriction is None: - self._rp = {"+": RestrictionPropagator("+"), - "-": RestrictionPropagator("-")} + self._rp = {"+": RestrictionPropagator("+", assume_single_integral_type), + "-": RestrictionPropagator("-", assume_single_integral_type), + "|": RestrictionPropagator("|", assume_single_integral_type), + "?": RestrictionPropagator("?", assume_single_integral_type)} + self.assume_single_integral_type = assume_single_integral_type def restricted(self, o): """When hitting a restricted quantity, visit child with a separate restriction algorithm.""" @@ -55,9 +59,12 @@ def _ignore_restriction(self, o): def _require_restriction(self, o): """Restrict a discontinuous quantity to current side, require a side to be set.""" - if self.current_restriction is None: + if self.current_restriction is not None: + return o(self.current_restriction) + elif not self.assume_single_integral_type: + return o + else: raise ValueError(f"Discontinuous type {o._ufl_class_.__name__} must be restricted.") - return o(self.current_restriction) def _default_restricted(self, o): """Restrict a continuous quantity to default side if no current restriction is set.""" @@ -172,11 +179,17 @@ def facet_normal(self, o): return self._require_restriction(o) -def apply_restrictions(expression): +def apply_restrictions(expression, assume_single_integral_type=True): """Propagate restriction nodes to wrap differential terminals directly.""" - integral_types = [k for k in integral_type_to_measure_name.keys() - if k.startswith("interior_facet")] - rules = RestrictionPropagator() + if assume_single_integral_type: + integral_types = [k for k in integral_type_to_measure_name.keys() + if k.startswith("interior_facet")] + else: + # Integration type of the integral is not necessarily the same as + # the integral type of a given function; e.g., the former can be + # ``exterior_facet`` and the latter ``interior_facet``. + integral_types = None + rules = RestrictionPropagator(assume_single_integral_type=assume_single_integral_type) return map_integrand_dags(rules, expression, only_integral_type=integral_types) @@ -184,14 +197,15 @@ def apply_restrictions(expression): class DefaultRestrictionApplier(MultiFunction): """Default restriction applier.""" - def __init__(self, side=None): + def __init__(self, side=None, assume_single_integral_type=True): """Initialise.""" MultiFunction.__init__(self) self.current_restriction = side - self.default_restriction = "+" - if self.current_restriction is None: - self._rp = {"+": DefaultRestrictionApplier("+"), - "-": DefaultRestrictionApplier("-")} + # If multiple domains exist, the restriction on a function defined on + # a certain domain can not be determined by merely inspecting the + # local part of the DAG. "?" restrictions will be replaced with the + # appropriate ones later using ``replace_to_be_restricted`` function. + self.default_restriction = "+" if assume_single_integral_type else "?" def terminal(self, o): """Apply to terminal.""" @@ -236,13 +250,149 @@ def _default_restricted(self, o): facet_origin = _default_restricted # FIXME: Is this valid for quads? -def apply_default_restrictions(expression): +def apply_default_restrictions(expression, assume_single_integral_type=True): """Some terminals can be restricted from either side. This applies a default restriction to such terminals if unrestricted. """ - integral_types = [k for k in integral_type_to_measure_name.keys() - if k.startswith("interior_facet")] - rules = DefaultRestrictionApplier() + if assume_single_integral_type: + integral_types = [k for k in integral_type_to_measure_name.keys() + if k.startswith("interior_facet")] + else: + integral_types = None + rules = DefaultRestrictionApplier(assume_single_integral_type=assume_single_integral_type) return map_integrand_dags(rules, expression, only_integral_type=integral_types) + + +class DomainRestrictionMapMaker(MultiFunction): + """Make a map from domains to restriction(s). + + Inspect the DAG and collect domain-restrictions map. + This must be done per integral_data. + """ + + def __init__(self, domain_restriction_map): + MultiFunction.__init__(self) + self._domain_restriction_map = domain_restriction_map + + expr = MultiFunction.reuse_if_untouched + + def _modifier(self, o): + restriction = None + local_derivatives = 0 + reference_value = False + t = o + while not t._ufl_is_terminal_: + assert t._ufl_is_terminal_modifier_, f"Expecting a terminal modifier: got {repr(t)}" + if isinstance(t, ReferenceValue): + assert not reference_value, "Got twice pulled back terminal" + reference_value = True + t, = t.ufl_operands + elif isinstance(t, ReferenceGrad): + local_derivatives += 1 + t, = t.ufl_operands + elif isinstance(t, Restricted): + assert restriction is None, "Got twice restricted terminal" + restriction = t._side + t, = t.ufl_operands + elif t._ufl_terminal_modifiers_: + raise ValueError("Missing handler for terminal modifier type %s, object is %s." % (type(t), repr(t))) + else: + raise ValueError("Unexpected type %s object %s." % (type(t), repr(t))) + domain = extract_unique_domain(t, expand_mixed_mesh=False) + if isinstance(domain, MixedMesh): + raise RuntimeError(f"Not expecting a terminal object on a mixed mesh at this stage: found {repr(t)}") + if domain is not None: + if domain not in self._domain_restriction_map: + self._domain_restriction_map[domain] = set() + if restriction in ['+', '-', '|']: + self._domain_restriction_map[domain].add(restriction) + elif restriction not in ['?', None]: + raise RuntimeError + return o + + reference_value = _modifier + reference_grad = _modifier + positive_restricted = _modifier + negative_restricted = _modifier + single_value_restricted = _modifier + to_be_restricted = _modifier + terminal = _modifier + + +def make_domain_restriction_map(integral_data): + """Make domain-restriction map for the given integral_data.""" + domain_restriction_map = {} + rule = DomainRestrictionMapMaker(domain_restriction_map) + for integral in integral_data.integrals: + _ = map_expr_dag(rule, integral.integrand()) + return domain_restriction_map + + +def make_domain_integral_type_map(integral_data): + domain_restriction_map = make_domain_restriction_map(integral_data) + integration_domain = integral_data.domain + integration_type = integral_data.integral_type + domain_integral_type_dict = {} + for d, rs in domain_restriction_map.items(): + if rs in [{'+'}, {'-'}, {'+', '-'}]: + domain_integral_type_dict[d] = "interior_facet" + elif rs == {'|'}: + domain_integral_type_dict[d] = "exterior_facet" + elif rs == set(): + if d.topological_dimension() == integration_domain.topological_dimension(): + if integration_type == "cell": + domain_integral_type_dict[d] = "cell" + elif integration_type in ["exterior_facet", "interior_facet"]: + domain_integral_type_dict[d] = "exterior_facet" + else: + raise NotImplementedError + else: + raise NotImplementedError + else: + raise RuntimeError(f"Found inconsistent restrictions {rs} for domain {d}") + if integration_domain in domain_integral_type_dict: + if domain_integral_type_dict[integration_domain] != integration_type: + raise RuntimeError(f"""Found inconsistent integral types for the integration domain ({integration_domain}) : + {domain_integral_type_dict[integration_domain]} != {integration_type}""") + else: + domain_integral_type_dict[integration_domain] = integration_type + return domain_integral_type_dict + + +class ToBeRestrectedReplacer(MultiFunction): + """Replace ``?`` restrictions.""" + + def __init__(self, domain_integral_type_map): + MultiFunction.__init__(self) + self.domain_integral_type_map = domain_integral_type_map + + expr = MultiFunction.reuse_if_untouched + + def to_be_restricted(self, o): + mt, = o.ufl_operands + domain = extract_unique_domain(mt) + if isinstance(domain, MixedMesh): + raise RuntimeError(f"""Not expecting a (modified) terminal object on a mixed mesh at this stage : + got {repr(o)}""") + if domain not in self.domain_integral_type_map: + raise RuntimeError(f"Integral type on {domain} not known") + integral_type = self.domain_integral_type_map[domain] + if integral_type == "cell": + return mt + elif integral_type == "exterior_facet": + return SingleValueRestricted(mt) + elif integral_type == "interial_facet": + return PositiveRestricted(mt) + else: + raise RuntimeError(f"Unknown integral type: {integral_type}") + + +def replace_to_be_restricted(integral_data): + new_integrals = [] + rule = ToBeRestrectedReplacer(integral_data.domain_integral_type_map) + for integral in integral_data.integrals: + integrand = map_expr_dag(rule, integral.integrand()) + new_integrals.append(integral.reconstruct(integrand=integrand)) + return new_integrals diff --git a/ufl/algorithms/balancing.py b/ufl/algorithms/balancing.py index 477ec3f6f..19d1d7ba5 100644 --- a/ufl/algorithms/balancing.py +++ b/ufl/algorithms/balancing.py @@ -6,14 +6,15 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later -from ufl.classes import (CellAvg, FacetAvg, Grad, Indexed, NegativeRestricted, PositiveRestricted, ReferenceGrad, - ReferenceValue) +from ufl.classes import (CellAvg, FacetAvg, Grad, Indexed, + NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted, + ReferenceGrad, ReferenceValue) from ufl.corealg.map_dag import map_expr_dag from ufl.corealg.multifunction import MultiFunction modifier_precedence = [ ReferenceValue, ReferenceGrad, Grad, CellAvg, FacetAvg, PositiveRestricted, - NegativeRestricted, Indexed + NegativeRestricted, SingleValueRestricted, ToBeRestricted, Indexed ] modifier_precedence = { @@ -76,6 +77,8 @@ def _modifier(self, expr, *ops): facet_avg = _modifier positive_restricted = _modifier negative_restricted = _modifier + single_value_restricted = _modifier + to_be_restricted = _modifier def balance_modifiers(expr): diff --git a/ufl/algorithms/check_arities.py b/ufl/algorithms/check_arities.py index c93727ad8..1c661add6 100644 --- a/ufl/algorithms/check_arities.py +++ b/ufl/algorithms/check_arities.py @@ -101,6 +101,8 @@ def linear_operator(self, o, a): # Positive and negative restrictions behave as linear operators positive_restricted = linear_operator negative_restricted = linear_operator + single_value_restricted = linear_operator + to_be_restricted = linear_operator # Cell and facet average are linear operators cell_avg = linear_operator diff --git a/ufl/algorithms/compute_form_data.py b/ufl/algorithms/compute_form_data.py index a99361e21..f67ce8fc4 100644 --- a/ufl/algorithms/compute_form_data.py +++ b/ufl/algorithms/compute_form_data.py @@ -18,7 +18,9 @@ from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering from ufl.algorithms.apply_integral_scaling import apply_integral_scaling -from ufl.algorithms.apply_restrictions import apply_default_restrictions, apply_restrictions +from ufl.algorithms.apply_restrictions import (apply_default_restrictions, apply_restrictions, + replace_to_be_restricted, make_domain_integral_type_map) +from ufl.algorithms.apply_coefficient_split import apply_coefficient_split, remove_component_and_list_tensors from ufl.algorithms.check_arities import check_form_arity from ufl.algorithms.comparison_checker import do_comparison_check # See TODOs at the call sites of these below: @@ -28,10 +30,12 @@ from ufl.algorithms.formdata import FormData from ufl.algorithms.formtransformations import compute_form_arities from ufl.algorithms.remove_complex_nodes import remove_complex_nodes +from ufl.algorithms.replace import replace from ufl.classes import Coefficient, Form, FunctionSpace, GeometricFacetQuantity from ufl.corealg.traversal import traverse_unique_terminals from ufl.domain import extract_unique_domain, extract_domains, MixedMesh from ufl.utils.sequences import max_degree +from ufl.constantvalue import Zero def _auto_select_degree(elements): @@ -248,6 +252,8 @@ def compute_form_data( do_apply_geometry_lowering=False, preserve_geometry_types=(), do_apply_default_restrictions=True, do_apply_restrictions=True, do_estimate_degrees=True, do_append_everywhere_integrals=True, + do_assume_single_integral_type=True, + do_split_coefficients=None, complex_mode=False, ): """Compute form data. @@ -295,9 +301,17 @@ def compute_form_data( if do_apply_integral_scaling: form = apply_integral_scaling(form) + # Can allow for some simplifications if there indeed is only a single domain + if not do_assume_single_integral_type: + have_single_domain = len(extract_domains(form)) == 1 + # Apply default restriction to fully continuous terminals if do_apply_default_restrictions: - form = apply_default_restrictions(form) + if do_assume_single_integral_type: + form = apply_default_restrictions(form) + else: + # Apply '?' restrictions in general multi-domain problems + form = apply_default_restrictions(form, assume_single_integral_type=have_single_domain) # Lower abstractions for geometric quantities into a smaller set # of quantities, allowing the form compiler to deal with a smaller @@ -323,7 +337,10 @@ def compute_form_data( # Propagate restrictions to terminals if do_apply_restrictions: - form = apply_restrictions(form) + if do_assume_single_integral_type: + form = apply_restrictions(form) + else: + form = apply_restrictions(form, assume_single_integral_type=have_single_domain) # If in real mode, remove any complex nodes introduced during form processing. if not complex_mode: @@ -401,6 +418,58 @@ def compute_form_data( # compatible data structure. self.max_subdomain_ids = _compute_max_subdomain_ids(self.integral_data) + # Split coefficients that are contained in ``do_split_coefficients`` tuple + # into components and store a dict in ``self`` that maps + # each coefficient to its components + if do_split_coefficients is not None: + coefficient_split = {} + for o in self.reduced_coefficients: + c = self.function_replace_map[o] + elem = c.ufl_element() + mesh = extract_unique_domain(c, expand_mixed_mesh=False) + # Use MixedMesh as an indicator for MixedElement as + # the followings would be ambiguous: + # -- elem.num_sub_elements > 1 + # -- isinstance(elem.pullback, MixedPullback) + if isinstance(mesh, MixedMesh) and o in do_split_coefficients: + assert len(mesh) == len(elem.sub_elements) + coefficient_split[c] = [Coefficient(FunctionSpace(m, e)) + for m, e in zip(mesh, elem.sub_elements)] + self.coefficient_split = coefficient_split + for itg_data in self.integral_data: + new_integrals = [] + for integral in itg_data.integrals: + integrand = replace(integral.integrand(), self.function_replace_map) + integrand = remove_component_and_list_tensors(integrand) + integrand = apply_coefficient_split(integrand, self.coefficient_split) + integrand = remove_component_and_list_tensors(integrand) + if not isinstance(integrand, Zero): + new_integrals.append(integral.reconstruct(integrand=integrand)) + itg_data.integrals = new_integrals + else: + self.coefficient_split = {} + + # Make ``itg_data.domain_integral_type_map``; this is only significant + # when we handle general multi-domain problems + if do_assume_single_integral_type: + for itg_data in self.integral_data: + itg_data.domain_integral_type_map = {itg_data.domain: itg_data.integral_type} + else: + if have_single_domain: + # Make a short-cut; there is no '?' restrictions by construction + for itg_data in self.integral_data: + itg_data.domain_integral_type_map = {itg_data.domain: itg_data.integral_type} + else: + # Inspect the form and replacce all '?' restrictions with appropriate ones + # in general multi-domain problems; we must have split coefficients into components + # to simplify the DAG and facilitate this inspection + if do_split_coefficients is None: + raise ValueError("""Need to pass 'do_split_coefficients=tuple_of_coefficients_to_splilt' + for general multi-domain problems""") + for itg_data in self.integral_data: + itg_data.domain_integral_type_map = make_domain_integral_type_map(itg_data) + itg_data.integrals = replace_to_be_restricted(itg_data) + # --- Checks _check_elements(self) _check_facet_geometry(self.integral_data) diff --git a/ufl/algorithms/domain_analysis.py b/ufl/algorithms/domain_analysis.py index 3a11b123a..ff756e260 100644 --- a/ufl/algorithms/domain_analysis.py +++ b/ufl/algorithms/domain_analysis.py @@ -29,7 +29,8 @@ class IntegralData(object): __slots__ = ('domain', 'integral_type', 'subdomain_id', 'integrals', 'metadata', 'integral_coefficients', - 'enabled_coefficients') + 'enabled_coefficients', + 'domain_integral_type_map') def __init__(self, domain, integral_type, subdomain_id, integrals, metadata): @@ -51,6 +52,7 @@ def __init__(self, domain, integral_type, subdomain_id, integrals, # this stage: self.integral_coefficients = None self.enabled_coefficients = None + self.domain_integral_type_map = None # TODO: I think we can get rid of this with some refactoring # in ffc: diff --git a/ufl/algorithms/estimate_degrees.py b/ufl/algorithms/estimate_degrees.py index e26be163c..f288a9364 100644 --- a/ufl/algorithms/estimate_degrees.py +++ b/ufl/algorithms/estimate_degrees.py @@ -177,6 +177,14 @@ def negative_restricted(self, v, a): """Apply to negative_restricted.""" return a + def single_value_restricted(self, v, a): + """Apply to single_value_restricted.""" + return a + + def to_be_restricted(self, v, a): + """Apply to to_be_restricted.""" + return a + def conj(self, v, a): """Apply to conj.""" return a diff --git a/ufl/algorithms/formtransformations.py b/ufl/algorithms/formtransformations.py index 58d168c14..d5ed76dd3 100644 --- a/ufl/algorithms/formtransformations.py +++ b/ufl/algorithms/formtransformations.py @@ -240,6 +240,8 @@ def linear_operator(self, x, arg): # Positive and negative restrictions behave as linear operators positive_restricted = linear_operator negative_restricted = linear_operator + single_value_restricted = linear_operator + to_be_restricted = linear_operator # Cell and facet average are linear operators cell_avg = linear_operator diff --git a/ufl/exproperators.py b/ufl/exproperators.py index 60eb87566..590319cf6 100644 --- a/ufl/exproperators.py +++ b/ufl/exproperators.py @@ -24,7 +24,7 @@ from ufl.index_combination_utils import create_slice_indices, merge_overlapping_indices from ufl.indexed import Indexed from ufl.indexsum import IndexSum -from ufl.restriction import NegativeRestricted, PositiveRestricted +from ufl.restriction import NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted from ufl.tensoralgebra import Inner, Transposed from ufl.tensors import ComponentTensor, as_tensor from ufl.utils.stacks import StackDict @@ -305,7 +305,7 @@ def _abs(self): Expr.__abs__ = _abs -# --- Extend Expr with restiction operators a("+"), a("-") --- +# --- Extend Expr with restiction operators a("+"), a("-"), a("|"), a("?") --- def _restrict(self, side): """Restrict.""" @@ -313,6 +313,10 @@ def _restrict(self, side): return PositiveRestricted(self) if side == "-": return NegativeRestricted(self) + if side == "|": + return SingleValueRestricted(self) + if side == "?": + return ToBeRestricted(self) raise ValueError(f"Invalid side '{side}' in restriction operator.") @@ -335,7 +339,7 @@ def _eval(self, coord, mapping=None, component=()): def _call(self, arg, mapping=None, component=()): """Take the restriction or evaluate depending on argument.""" - if arg in ("+", "-"): + if arg in ("+", "-", "|", "?"): if mapping is not None: raise ValueError("Not expecting a mapping when taking restriction.") return _restrict(self, arg) diff --git a/ufl/formatting/ufl2unicode.py b/ufl/formatting/ufl2unicode.py index 5c193da40..27e255f13 100644 --- a/ufl/formatting/ufl2unicode.py +++ b/ufl/formatting/ufl2unicode.py @@ -128,6 +128,8 @@ class UC: superscript_plus = u"⁺" superscript_minus = u"⁻" + superscript_vertical_bar = u"|" + superscript_question_mark = u"?" superscript_equals = u"⁼" superscript_left_paren = u"⁽" superscript_right_paren = u"⁾" @@ -745,6 +747,14 @@ def negative_restricted(self, o, f): """Format a negative_restriced.""" return f"{par(f)}{UC.superscript_minus}" + def single_value_restricted(self, o, f): + """Format a sigle_value_restriced.""" + return f"{par(f)}{UC.superscript_vertical_bar}" + + def to_be_restricted(self, o, f): + """Format a to_be_restriced.""" + return f"{par(f)}{UC.superscript_question_mark}" + def cell_avg(self, o, f): """Format a cell_avg.""" f = overline_string(f) diff --git a/ufl/restriction.py b/ufl/restriction.py index 2871cd53f..05eae3f04 100644 --- a/ufl/restriction.py +++ b/ufl/restriction.py @@ -56,3 +56,19 @@ class NegativeRestricted(Restricted): __slots__ = () _side = "-" + + +@ufl_type(is_terminal_modifier=True) +class SingleValueRestricted(Restricted): + """Single value restriction.""" + + __slots__ = () + _side = "|" + + +@ufl_type(is_terminal_modifier=True) +class ToBeRestricted(Restricted): + """Single value restriction.""" + + __slots__ = () + _side = "?"