Skip to content

Commit

Permalink
handle restrictions on MixedMesh
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Aug 20, 2024
1 parent 52bd425 commit d3e72fb
Show file tree
Hide file tree
Showing 12 changed files with 559 additions and 38 deletions.
55 changes: 50 additions & 5 deletions test/test_mixed_function_space_with_mixed_mesh.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ufl import (triangle, Mesh, MixedMesh, FunctionSpace, TestFunction, TrialFunction, Coefficient,
Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, split, )
Measure, SpatialCoordinate, FacetNormal, CellVolume, FacetArea, inner, grad, div, split, )
from ufl.algorithms import compute_form_data
from ufl.finiteelement import FiniteElement, MixedElement
from ufl.pullback import identity_pullback, contravariant_piola
Expand Down Expand Up @@ -27,8 +27,9 @@ def test_mixed_function_space_with_mixed_mesh_basic():
f0, f1, f2 = split(f)
g0, g1, g2 = split(g)
dx1 = Measure("dx", mesh1)
ds2 = Measure("ds", mesh2)
x = SpatialCoordinate(mesh1)
form = x[1] * f0 * inner(grad(u0), v1) * dx1(999)
form = x[1] * f0 * inner(grad(u0), v1) * dx1(999) + div(f1) * g2 * inner(u1, grad(v2)) * ds2(888)
fd = compute_form_data(form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
Expand All @@ -37,16 +38,60 @@ def test_mixed_function_space_with_mixed_mesh_basic():
do_apply_restrictions=True,
do_estimate_degrees=True,
complex_mode=False)
id0, = fd.integral_data
id0, id1 = fd.integral_data
assert fd.preprocessed_form.arguments() == (v, u)
assert fd.reduced_coefficients == [f]
assert fd.reduced_coefficients == [f, g]
assert form.coefficients()[fd.original_coefficient_positions[0]] is f
assert form.coefficients()[fd.original_coefficient_positions[1]] is g
assert id0.domain is mesh1
assert id0.integral_type == 'cell'
assert id0.subdomain_id == (999, )
assert fd.original_form.domain_numbering()[id0.domain] == 0
assert id0.integral_coefficients == set([f])
assert id0.enabled_coefficients == [True]
assert id0.enabled_coefficients == [True, False]
assert id1.domain is mesh2
assert id1.integral_type == 'exterior_facet'
assert id1.subdomain_id == (888, )
assert fd.original_form.domain_numbering()[id1.domain] == 1
assert id1.integral_coefficients == set([f, g])
assert id1.enabled_coefficients == [True, True]


def test_mixed_function_space_with_mixed_mesh_restriction():
cell = triangle
elem0 = FiniteElement("Lagrange", cell, 1, (), identity_pullback, H1)
elem1 = FiniteElement("Brezzi-Douglas-Marini", cell, 1, (2, ), contravariant_piola, HDiv)
elem2 = FiniteElement("Discontinuous Lagrange", cell, 0, (), identity_pullback, L2)
elem = MixedElement([elem0, elem1, elem2])
mesh0 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=100)
mesh1 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=101)
mesh2 = Mesh(FiniteElement("Lagrange", cell, 1, (2, ), identity_pullback, H1), ufl_id=102)
domain = MixedMesh(mesh0, mesh1, mesh2)
V = FunctionSpace(domain, elem)
V1 = FunctionSpace(mesh1, elem1)
V2 = FunctionSpace(mesh2, elem2)
u1 = TrialFunction(V1)
v2 = TestFunction(V2)
f = Coefficient(V, count=1000)
g = Coefficient(V, count=2000)
f0, f1, f2 = split(f)
g0, g1, g2 = split(g)
dS1 = Measure("dS", mesh1)
x2 = SpatialCoordinate(mesh2)
form = inner(x2, g1) * g2 * inner(u1('-'), grad(v2('|'))) * dS1(999)
fd = compute_form_data(form,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(CellVolume, FacetArea),
do_apply_restrictions=True,
do_estimate_degrees=True,
do_split_coefficients=(f, g),
do_assume_single_integral_type=False,
complex_mode=False)
integral_data, = fd.integral_data
assert integral_data.domain_integral_type_map[mesh1] == "interior_facet"
assert integral_data.domain_integral_type_map[mesh2] == "exterior_facet"


def test_mixed_function_space_with_mixed_mesh_signature():
Expand Down
210 changes: 210 additions & 0 deletions ufl/algorithms/apply_coefficient_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Apply coefficient split.
This module contains classes and functions to split coefficients defined on mixed function spaces.
"""

import numpy
from ufl.classes import Restricted
from ufl.corealg.map_dag import map_expr_dag
from ufl.corealg.multifunction import MultiFunction, memoized_handler
from ufl.domain import extract_unique_domain
from ufl.classes import (Coefficient, Form, ReferenceGrad, ReferenceValue,
Indexed, MultiIndex, Index, FixedIndex,
ComponentTensor, ListTensor, Zero,
NegativeRestricted, PositiveRestricted, SingleValueRestricted, ToBeRestricted)
from ufl import indices
from ufl.checks import is_cellwise_constant
from ufl.tensors import as_tensor


class CoefficientSplitter(MultiFunction):

def __init__(self, coefficient_split):
MultiFunction.__init__(self)
self._coefficient_split = coefficient_split

expr = MultiFunction.reuse_if_untouched

def modified_terminal(self, o):
restriction = None
local_derivatives = 0
reference_value = False
t = o
while not t._ufl_is_terminal_:
assert t._ufl_is_terminal_modifier_, f"Got {repr(t)}"
if isinstance(t, ReferenceValue):
assert not reference_value, "Got twice pulled back terminal!"
reference_value = True
t, = t.ufl_operands
elif isinstance(t, ReferenceGrad):
local_derivatives += 1
t, = t.ufl_operands
elif isinstance(t, Restricted):
assert restriction is None, "Got twice restricted terminal!"
restriction = t._side
t, = t.ufl_operands
elif t._ufl_terminal_modifiers_:
raise ValueError("Missing handler for terminal modifier type %s, object is %s." % (type(t), repr(t)))
else:
raise ValueError("Unexpected type %s object %s." % (type(t), repr(t)))
if not isinstance(t, Coefficient):
# Only split coefficients
return o
if t not in self._coefficient_split:
# Only split mixed coefficients
return o
# Reference value expected
assert reference_value
# Derivative indices
beta = indices(local_derivatives)
components = []
for subcoeff in self._coefficient_split[t]:
c = subcoeff
# Apply terminal modifiers onto the subcoefficient
if reference_value:
c = ReferenceValue(c)
for n in range(local_derivatives):
# Return zero if expression is trivially constant. This has to
# happen here because ReferenceGrad has no access to the
# topological dimension of a literal zero.
if is_cellwise_constant(c):
dim = extract_unique_domain(subcoeff).topological_dimension()
c = Zero(c.ufl_shape + (dim,), c.ufl_free_indices, c.ufl_index_dimensions)
else:
c = ReferenceGrad(c)
if restriction == '+':
c = PositiveRestricted(c)
elif restriction == '-':
c = NegativeRestricted(c)
elif restriction == '|':
c = SingleValueRestricted(c)
elif restriction == '?':
c = ToBeRestricted(c)
elif restriction is not None:
raise RuntimeError(f"Got unknown restriction: {restriction}")
# Collect components of the subcoefficient
for alpha in numpy.ndindex(subcoeff.ufl_element().reference_value_shape):
# New modified terminal: component[alpha + beta]
components.append(c[alpha + beta])
# Repack derivative indices to shape
c, = indices(1)
return ComponentTensor(as_tensor(components)[c], MultiIndex((c,) + beta))

positive_restricted = modified_terminal
negative_restricted = modified_terminal
single_value_restricted = modified_terminal
to_be_restricted = modified_terminal
reference_grad = modified_terminal
reference_value = modified_terminal
terminal = modified_terminal


def apply_coefficient_split(expr, coefficient_split):
"""Split mixed coefficients, so mixed elements need not be
implemented.
:arg split: A :py:class:`dict` mapping each mixed coefficient to a
sequence of subcoefficients. If None, calling this
function is a no-op.
"""
if coefficient_split is None:
return expr
splitter = CoefficientSplitter(coefficient_split)
return map_expr_dag(splitter, expr)


class FixedIndexRemover(MultiFunction):

def __init__(self, fimap):
MultiFunction.__init__(self)
self.fimap = fimap
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@memoized_handler
def zero(self, o):
free_indices = []
index_dimensions = []
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
if Index(i) in self.fimap:
ind_j = self.fimap[Index(i)]
if not isinstance(ind_j, FixedIndex):
free_indices.append(ind_j.count())
index_dimensions.append(d)
else:
free_indices.append(i)
index_dimensions.append(d)
return Zero(shape=o.ufl_shape, free_indices=tuple(free_indices), index_dimensions=tuple(index_dimensions))

@memoized_handler
def list_tensor(self, o):
cc = []
for o1 in o.ufl_operands:
comp = map_expr_dag(self, o1)
cc.append(comp)
return ListTensor(*cc)

@memoized_handler
def multi_index(self, o):
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))


class IndexRemover(MultiFunction):

def __init__(self):
MultiFunction.__init__(self)
self._object_cache = {}

expr = MultiFunction.reuse_if_untouched

@memoized_handler
def _zero_simplify(self, o):
operand, = o.ufl_operands
operand = map_expr_dag(self, operand)
if isinstance(operand, Zero):
return Zero(shape=o.ufl_shape, free_indices=o.ufl_free_indices, index_dimensions=o.ufl_index_dimensions)
else:
return o._ufl_expr_reconstruct_(operand)

@memoized_handler
def indexed(self, o):
o1, i1 = o.ufl_operands
if isinstance(o1, ComponentTensor):
o2, i2 = o1.ufl_operands
assert len(i2.indices()) == len(i1.indices())
fimap = dict(zip(i2.indices(), i1.indices()))
rule = FixedIndexRemover(fimap)
v = map_expr_dag(rule, o2)
return map_expr_dag(self, v)
elif isinstance(o1, ListTensor):
if isinstance(i1[0], FixedIndex):
o1 = o1.ufl_operands[i1[0]._value]
if len(i1) > 1:
i1 = MultiIndex(i1[1:])
return map_expr_dag(self, Indexed(o1, i1))
else:
return map_expr_dag(self, o1)
o1 = map_expr_dag(self, o1)
return Indexed(o1, i1)

# Do something nicer
positive_restricted = _zero_simplify
negative_restricted = _zero_simplify
single_value_restricted = _zero_simplify
to_be_restricted = _zero_simplify
reference_grad = _zero_simplify
reference_value = _zero_simplify


def remove_component_and_list_tensors(o):
if isinstance(o, Form):
integrals = []
for integral in o.integrals():
integrand = remove_component_and_list_tensors(integral.integrand())
if not isinstance(integrand, Zero):
integrals.append(integral.reconstruct(integrand=integrand))
return o._ufl_expr_reconstruct_(integrals)
else:
rule = IndexRemover()
return map_expr_dag(rule, o)
Loading

0 comments on commit d3e72fb

Please sign in to comment.