Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extract_subelement_component #122

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions finat/ufl/finiteelementbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@
from ufl.utils.sequences import product


# Dict of supported pullback names and their ufl representation
supported_pullbacks = {
"identity": pullback.identity_pullback,
"L2 Piola": pullback.l2_piola,
"covariant Piola": pullback.covariant_piola,
"contravariant Piola": pullback.contravariant_piola,
"double covariant Piola": pullback.double_covariant_piola,
"double contravariant Piola": pullback.double_contravariant_piola,
"covariant contravariant Piola": pullback.covariant_contravariant_piola,
"custom": pullback.custom_pullback,
"physical": pullback.physical_pullback,
}


class FiniteElementBase(AbstractFiniteElement):
"""Base class for all finite elements."""
__slots__ = ("_family", "_cell", "_degree", "_quad_scheme",
Expand Down Expand Up @@ -121,6 +135,14 @@ def is_cellwise_constant(self, component=None):
"""Return whether the basis functions of this element is spatially constant over each cell."""
return self._is_globally_constant() or self.degree() == 0

def value_shape(self, domain=None):
"""Return the shape of the value space on a physical domain."""
return self.pullback.physical_value_shape(self, domain)

def value_size(self, domain=None):
"""Return the integer product of the value shape on a physical domain."""
return product(self.value_shape(domain))

@property
def reference_value_shape(self):
"""Return the shape of the value space on the reference cell."""
Expand All @@ -131,7 +153,7 @@ def reference_value_size(self):
"""Return the integer product of the reference value shape."""
return product(self.reference_value_shape)

def symmetry(self): # FIXME: different approach
def symmetry(self, domain=None):
r"""Return the symmetry dict.

This is a mapping :math:`c_0 \\to c_1`
Expand All @@ -141,37 +163,37 @@ def symmetry(self): # FIXME: different approach
"""
return {}

def _check_component(self, domain, i):
def _check_component(self, i, domain=None):
"""Check that component index i is valid."""
sh = self.value_shape(domain.geometric_dimension())
sh = self.value_shape(domain)
r = len(sh)
if not (len(i) == r and all(j < k for (j, k) in zip(i, sh))):
if not (len(i) == r and all(int(j) < k for (j, k) in zip(i, sh))):
raise ValueError(
f"Illegal component index {i} (value rank {len(i)}) "
f"for element (value rank {r}).")

def extract_subelement_component(self, domain, i):
def extract_subelement_component(self, i, domain=None):
"""Extract direct subelement index and subelement relative component index for a given component index."""
if isinstance(i, int):
i = (i,)
self._check_component(domain, i)
self._check_component(i, domain)
return (None, i)

def extract_component(self, domain, i):
def extract_component(self, i, domain=None):
"""Recursively extract component index relative to a (simple) element.

and that element for given value component index.
"""
if isinstance(i, int):
i = (i,)
self._check_component(domain, i)
self._check_component(i, domain)
return (i, self)

def _check_reference_component(self, i):
"""Check that reference component index i is valid."""
sh = self.reference_value_shape
r = len(sh)
if not (len(i) == r and all(j < k for (j, k) in zip(i, sh))):
if not (len(i) == r and all(int(j) < k for (j, k) in zip(i, sh))):
raise ValueError(
f"Illegal component index {i} (value rank {len(i)}) "
f"for element (value rank {r}).")
Expand Down Expand Up @@ -246,23 +268,7 @@ def embedded_subdegree(self):
@property
def pullback(self):
"""Get the pull back."""
if self.mapping() == "identity":
return pullback.identity_pullback
elif self.mapping() == "L2 Piola":
return pullback.l2_piola
elif self.mapping() == "covariant Piola":
return pullback.covariant_piola
elif self.mapping() == "contravariant Piola":
return pullback.contravariant_piola
elif self.mapping() == "double covariant Piola":
return pullback.double_covariant_piola
elif self.mapping() == "double contravariant Piola":
return pullback.double_contravariant_piola
elif self.mapping() == "covariant contravariant Piola":
return pullback.covariant_contravariant_piola
elif self.mapping() == "custom":
return pullback.custom_pullback
elif self.mapping() == "physical":
return pullback.physical_pullback

raise ValueError(f"Unsupported mapping: {self.mapping()}")
try:
return supported_pullbacks[self.mapping()]
except KeyError:
raise ValueError(f"Unsupported mapping: {self.mapping()}")
71 changes: 31 additions & 40 deletions finat/ufl/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, *elements, **kwargs):
if not all(e.quadrature_scheme() == quad_scheme for e in elements):
raise ValueError("Quadrature scheme mismatch for sub elements of mixed element.")

# Compute value sizes in global and reference configurations
# Compute value sizes in reference configuration
reference_value_size_sum = sum(product(s.reference_value_shape) for s in self._sub_elements)

# Default reference value shape: Treated simply as all
Expand All @@ -75,7 +75,7 @@ def __init__(self, *elements, **kwargs):

def __repr__(self):
"""Doc."""
return "MixedElement(" + ", ".join(repr(e) for e in self._sub_elements) + ")"
return "MixedElement(" + ", ".join(map(repr, self._sub_elements)) + ")"

def _is_linear(self):
"""Doc."""
Expand All @@ -87,7 +87,7 @@ def reconstruct_from_elements(self, *elements):
return self
return MixedElement(*elements)

def symmetry(self, domain):
def symmetry(self, domain=None):
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.

meaning that component :math:`c_0` is represented by component
Expand All @@ -103,15 +103,15 @@ def symmetry(self, domain):
st = shape_to_strides(sh)
# Map symmetries of subelement into index space of this
# element
for c0, c1 in e.symmetry().items():
for c0, c1 in e.symmetry(domain).items():
j0 = flatten_multiindex(c0, st) + j
j1 = flatten_multiindex(c1, st) + j
sm[(j0,)] = (j1,)
# Update base index for next element
j += product(sh)
if j != product(self.value_shape(domain)):
raise ValueError("Size mismatch in symmetry algorithm.")
return sm or {}
return sm

@property
def sobolev_space(self):
Expand All @@ -135,20 +135,21 @@ def sub_elements(self):
"""Return list of sub elements."""
return self._sub_elements

def extract_subelement_component(self, domain, i):
def extract_subelement_component(self, i, domain=None):
"""Extract direct subelement index and subelement relative.

component index for a given component index.
"""
if isinstance(i, int):
i = (i,)
self._check_component(i)
self._check_component(i, domain)

# Select between indexing modes
if len(self.value_shape(domain)) == 1:
# Indexing into a long vector of flattened subelement
# shapes
j, = i
j = int(j)

# Find subelement for this index
for sub_element_index, e in enumerate(self._sub_elements):
Expand All @@ -172,13 +173,13 @@ def extract_subelement_component(self, domain, i):
component = i[1:]
return (sub_element_index, component)

def extract_component(self, i):
def extract_component(self, i, domain=None):
"""Recursively extract component index relative to a (simple) element.

and that element for given value component index.
"""
sub_element_index, component = self.extract_subelement_component(i)
return self._sub_elements[sub_element_index].extract_component(component)
sub_element_index, component = self.extract_subelement_component(i, domain)
return self._sub_elements[sub_element_index].extract_component(component, domain)

def extract_subelement_reference_component(self, i):
"""Extract direct subelement index and subelement relative.
Expand All @@ -193,6 +194,7 @@ def extract_subelement_reference_component(self, i):
assert len(self.reference_value_shape) == 1
# Indexing into a long vector of flattened subelement shapes
j, = i
j = int(j)

# Find subelement for this index
for sub_element_index, e in enumerate(self._sub_elements):
Expand All @@ -217,20 +219,20 @@ def extract_reference_component(self, i):
sub_element_index, reference_component = self.extract_subelement_reference_component(i)
return self._sub_elements[sub_element_index].extract_reference_component(reference_component)

def is_cellwise_constant(self, component=None):
def is_cellwise_constant(self, component=None, domain=None):
"""Return whether the basis functions of this element is spatially constant over each cell."""
if component is None:
return all(e.is_cellwise_constant() for e in self.sub_elements)
else:
i, e = self.extract_component(component)
i, e = self.extract_component(component, domain)
return e.is_cellwise_constant()

def degree(self, component=None):
def degree(self, component=None, domain=None):
"""Return polynomial degree of finite element."""
if component is None:
return self._degree # from FiniteElementBase, computed as max of subelements in __init__
else:
i, e = self.extract_component(component)
i, e = self.extract_component(component, domain)
return e.degree()

@property
Expand All @@ -244,11 +246,12 @@ def embedded_superdegree(self):
return max(e.embedded_superdegree for e in self.sub_elements)

def reconstruct(self, **kwargs):
"""Doc."""
return MixedElement(*[e.reconstruct(**kwargs) for e in self.sub_elements])
"""Construct a new FiniteElement object with some properties replaced with new values."""
elements = (e.reconstruct(**kwargs) for e in self.sub_elements)
return self.reconstruct_from_elements(*elements)

def variant(self):
"""Doc."""
"""Return the common variant to all subelements."""
try:
variant, = {e.variant() for e in self.sub_elements}
return variant
Expand All @@ -257,7 +260,7 @@ def variant(self):

def __str__(self):
"""Format as string for pretty printing."""
tmp = ", ".join(str(element) for element in self._sub_elements)
tmp = ", ".join(map(str, self._sub_elements))
return "<Mixed element: (" + tmp + ")>"

def shortstr(self):
Expand All @@ -282,7 +285,6 @@ def __init__(self, family, cell=None, degree=None, dim=None,
if isinstance(family, FiniteElementBase):
sub_element = family
cell = sub_element.cell
variant = sub_element.variant()
else:
if cell is not None:
cell = as_cell(cell)
Expand Down Expand Up @@ -315,13 +317,8 @@ def __init__(self, family, cell=None, degree=None, dim=None,

self._sub_element = sub_element

if variant is None:
var_str = ""
else:
var_str = ", variant='" + variant + "'"

# Cache repr string
self._repr = f"VectorElement({repr(sub_element)}, dim={dim}{var_str})"
self._repr = f"VectorElement({repr(sub_element)}, dim={dim})"

def __repr__(self):
"""Doc."""
Expand Down Expand Up @@ -369,7 +366,6 @@ def __init__(self, family, cell=None, degree=None, shape=None,
if isinstance(family, FiniteElementBase):
sub_element = family
cell = sub_element.cell
variant = sub_element.variant()
else:
if cell is not None:
cell = as_cell(cell)
Expand Down Expand Up @@ -416,7 +412,7 @@ def __init__(self, family, cell=None, degree=None, shape=None,
if index in symmetry:
continue
sub_element_mapping[index] = len(sub_elements)
sub_elements += [sub_element]
sub_elements.append(sub_element)

# Update mapping for symmetry
for index in indices:
Expand Down Expand Up @@ -445,14 +441,9 @@ def __init__(self, family, cell=None, degree=None, shape=None,
self._sub_element_mapping = sub_element_mapping
self._flattened_sub_element_mapping = flattened_sub_element_mapping

if variant is None:
var_str = ""
else:
var_str = ", variant='" + variant + "'"

# Cache repr string
self._repr = (f"TensorElement({repr(sub_element)}, shape={shape}, "
f"symmetry={symmetry}{var_str})")
f"symmetry={symmetry})")

@property
def pullback(self):
Expand Down Expand Up @@ -490,25 +481,25 @@ def flattened_sub_element_mapping(self):
"""Doc."""
return self._flattened_sub_element_mapping

def extract_subelement_component(self, i):
def extract_subelement_component(self, i, domain=None):
"""Extract direct subelement index and subelement relative.

component index for a given component index.
"""
if isinstance(i, int):
i = (i,)
self._check_component(i)
self._check_component(i, domain)

i = self.symmetry().get(i, i)
l = len(self._shape) # noqa: E741
ii = i[:l]
jj = i[l:]
i = self.symmetry(domain).get(i, i)
rank = len(self._shape)
ii = i[:rank]
jj = i[rank:]
if ii not in self._sub_element_mapping:
raise ValueError(f"Illegal component index {i}.")
k = self._sub_element_mapping[ii]
return (k, jj)

def symmetry(self):
def symmetry(self, domain=None):
r"""Return the symmetry dict, which is a mapping :math:`c_0 \\to c_1`.

meaning that component :math:`c_0` is represented by component
Expand Down
21 changes: 21 additions & 0 deletions test/finat/test_ufl_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import ufl
import finat.ufl


def test_extract_subelement_component():
cell = ufl.triangle
domain = ufl.Mesh(finat.ufl.VectorElement(finat.ufl.FiniteElement("Lagrange", cell, 1)))

V = finat.ufl.VectorElement(finat.ufl.FiniteElement("Lagrange", cell, 2))
Q = finat.ufl.FiniteElement("Lagrange", cell, 1)
Z = V * Q

space = ufl.FunctionSpace(domain, Z)
test = ufl.TestFunction(space)

for i in range(3):
expr = test[i]
_, multiindex = expr.ufl_operands
subindex, _ = Z.extract_subelement_component(multiindex, domain)
sub_elem = Z.sub_elements[subindex]
assert sub_elem is (Q if i == 2 else V)
Loading