From 70683e0e5e06ec8373c18d05a9f54f09c7c551ce Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Wed, 15 Jan 2025 20:34:19 +0000 Subject: [PATCH] Cache split cells (#128) * Cache split cells --- FIAT/macro.py | 14 ++++++++++++++ FIAT/reference_element.py | 3 +++ test/FIAT/unit/test_macro.py | 9 +++++++++ 3 files changed, 26 insertions(+) diff --git a/FIAT/macro.py b/FIAT/macro.py index 77ab89382..6662bdfcb 100644 --- a/FIAT/macro.py +++ b/FIAT/macro.py @@ -304,6 +304,13 @@ class AlfeldSplit(PowellSabinSplit): """Splits a simplicial complex by connecting cell vertices to their barycenter. """ + def __new__(cls, ref_el): + try: + return ref_el._split_cache[cls] + except KeyError: + self = super().__new__(cls) + return ref_el._split_cache.setdefault(cls, self) + def __init__(self, ref_el): sd = ref_el.get_spatial_dimension() super().__init__(ref_el, dimension=sd) @@ -313,6 +320,13 @@ class WorseyFarinSplit(PowellSabinSplit): """Splits a simplicial complex by connecting cell and facet vertices to their barycenter. This reduces to Powell-Sabin on the triangle, and Alfeld on the interval. """ + def __new__(cls, ref_el): + try: + return ref_el._split_cache[cls] + except KeyError: + self = super().__new__(cls) + return ref_el._split_cache.setdefault(cls, self) + def __init__(self, ref_el): sd = ref_el.get_spatial_dimension() super().__init__(ref_el, dimension=sd-1) diff --git a/FIAT/reference_element.py b/FIAT/reference_element.py index b4cd2c63e..30fa94bb8 100644 --- a/FIAT/reference_element.py +++ b/FIAT/reference_element.py @@ -181,6 +181,9 @@ def __init__(self, shape, vertices, topology): d01_entities = tuple(e for d, e in neighbors if d == dim1) self.connectivity[(dim0, dim1)].append(d01_entities) + # Dictionary with derived cells + self._split_cache = {} + def _key(self): """Hashable object key data (excluding type).""" # Default: only type matters diff --git a/test/FIAT/unit/test_macro.py b/test/FIAT/unit/test_macro.py index f1aa683c3..1677d7779 100644 --- a/test/FIAT/unit/test_macro.py +++ b/test/FIAT/unit/test_macro.py @@ -16,6 +16,15 @@ def cell(request): return ufc_simplex(dim) +def test_split_cache(cell): + A = AlfeldSplit(cell) + B = AlfeldSplit(cell) + assert B is A + fe = Lagrange(cell, 1, variant="alfeld") + C = fe.ref_complex + assert C is A + + @pytest.mark.parametrize("split", (AlfeldSplit, IsoSplit)) def test_split_entity_transform(split, cell): split_cell = split(cell)