Skip to content

Commit

Permalink
Cache split cells (#128)
Browse files Browse the repository at this point in the history
* Cache split cells
  • Loading branch information
pbrubeck authored Jan 15, 2025
1 parent 7aa40e6 commit 70683e0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 0 deletions.
14 changes: 14 additions & 0 deletions FIAT/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ class AlfeldSplit(PowellSabinSplit):
"""Splits a simplicial complex by connecting cell vertices to their
barycenter.
"""
def __new__(cls, ref_el):
try:
return ref_el._split_cache[cls]
except KeyError:
self = super().__new__(cls)
return ref_el._split_cache.setdefault(cls, self)

def __init__(self, ref_el):
sd = ref_el.get_spatial_dimension()
super().__init__(ref_el, dimension=sd)
Expand All @@ -313,6 +320,13 @@ class WorseyFarinSplit(PowellSabinSplit):
"""Splits a simplicial complex by connecting cell and facet vertices to their
barycenter. This reduces to Powell-Sabin on the triangle, and Alfeld on the interval.
"""
def __new__(cls, ref_el):
try:
return ref_el._split_cache[cls]
except KeyError:
self = super().__new__(cls)
return ref_el._split_cache.setdefault(cls, self)

def __init__(self, ref_el):
sd = ref_el.get_spatial_dimension()
super().__init__(ref_el, dimension=sd-1)
Expand Down
3 changes: 3 additions & 0 deletions FIAT/reference_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def __init__(self, shape, vertices, topology):
d01_entities = tuple(e for d, e in neighbors if d == dim1)
self.connectivity[(dim0, dim1)].append(d01_entities)

# Dictionary with derived cells
self._split_cache = {}

def _key(self):
"""Hashable object key data (excluding type)."""
# Default: only type matters
Expand Down
9 changes: 9 additions & 0 deletions test/FIAT/unit/test_macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def cell(request):
return ufc_simplex(dim)


def test_split_cache(cell):
A = AlfeldSplit(cell)
B = AlfeldSplit(cell)
assert B is A
fe = Lagrange(cell, 1, variant="alfeld")
C = fe.ref_complex
assert C is A


@pytest.mark.parametrize("split", (AlfeldSplit, IsoSplit))
def test_split_entity_transform(split, cell):
split_cell = split(cell)
Expand Down

0 comments on commit 70683e0

Please sign in to comment.