diff --git a/FIAT/macro.py b/FIAT/macro.py index 77ab8938..6662bdfc 100644 --- a/FIAT/macro.py +++ b/FIAT/macro.py @@ -304,6 +304,13 @@ class AlfeldSplit(PowellSabinSplit): """Splits a simplicial complex by connecting cell vertices to their barycenter. """ + def __new__(cls, ref_el): + try: + return ref_el._split_cache[cls] + except KeyError: + self = super().__new__(cls) + return ref_el._split_cache.setdefault(cls, self) + def __init__(self, ref_el): sd = ref_el.get_spatial_dimension() super().__init__(ref_el, dimension=sd) @@ -313,6 +320,13 @@ class WorseyFarinSplit(PowellSabinSplit): """Splits a simplicial complex by connecting cell and facet vertices to their barycenter. This reduces to Powell-Sabin on the triangle, and Alfeld on the interval. """ + def __new__(cls, ref_el): + try: + return ref_el._split_cache[cls] + except KeyError: + self = super().__new__(cls) + return ref_el._split_cache.setdefault(cls, self) + def __init__(self, ref_el): sd = ref_el.get_spatial_dimension() super().__init__(ref_el, dimension=sd-1) diff --git a/FIAT/reference_element.py b/FIAT/reference_element.py index b4cd2c63..30fa94bb 100644 --- a/FIAT/reference_element.py +++ b/FIAT/reference_element.py @@ -181,6 +181,9 @@ def __init__(self, shape, vertices, topology): d01_entities = tuple(e for d, e in neighbors if d == dim1) self.connectivity[(dim0, dim1)].append(d01_entities) + # Dictionary with derived cells + self._split_cache = {} + def _key(self): """Hashable object key data (excluding type).""" # Default: only type matters diff --git a/test/FIAT/unit/test_macro.py b/test/FIAT/unit/test_macro.py index f1aa683c..1677d777 100644 --- a/test/FIAT/unit/test_macro.py +++ b/test/FIAT/unit/test_macro.py @@ -16,6 +16,15 @@ def cell(request): return ufc_simplex(dim) +def test_split_cache(cell): + A = AlfeldSplit(cell) + B = AlfeldSplit(cell) + assert B is A + fe = Lagrange(cell, 1, variant="alfeld") + C = fe.ref_complex + assert C is A + + @pytest.mark.parametrize("split", (AlfeldSplit, IsoSplit)) def test_split_entity_transform(split, cell): split_cell = split(cell)