From 35b4d5352d1ccb5107d0a2f52de3bb80b952cb62 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 25 Jul 2024 15:23:05 +0100 Subject: [PATCH] BaseForm: ensure that subclasses implement ufl_domains() --- ufl/action.py | 10 ++++++++- ufl/adjoint.py | 12 +++++++++- ufl/form.py | 59 ++++++++++++++++++++++++++------------------------ 3 files changed, 51 insertions(+), 30 deletions(-) diff --git a/ufl/action.py b/ufl/action.py index 78489ff6b..fb6416865 100644 --- a/ufl/action.py +++ b/ufl/action.py @@ -122,7 +122,15 @@ def _analyze_domains(self): from ufl.domain import join_domains # Collect domains - self._domains = join_domains(chain.from_iterable(e.ufl_domain() for e in self.ufl_operands)) + self._domains = join_domains( + chain.from_iterable(e.ufl_domains() for e in self.ufl_operands) + ) + + def ufl_domains(self): + """Return all domains found in the base form.""" + if self._domains is None: + self._analyze_domains() + return self._domains def equals(self, other): """Check if two Actions are equal.""" diff --git a/ufl/adjoint.py b/ufl/adjoint.py index 7c1d5c63f..987372a73 100644 --- a/ufl/adjoint.py +++ b/ufl/adjoint.py @@ -8,6 +8,8 @@ # # Modified by Nacime Bouziani, 2021-2022. +from itertools import chain + from ufl.argument import Coargument from ufl.core.ufl_type import ufl_type from ufl.form import BaseForm, FormSum, ZeroBaseForm @@ -97,7 +99,15 @@ def _analyze_domains(self): from ufl.domain import join_domains # Collect unique domains - self._domains = join_domains([e.ufl_domain() for e in self.ufl_operands]) + self._domains = join_domains( + chain.from_iterable(e.ufl_domains() for e in self.ufl_operands) + ) + + def ufl_domains(self): + """Return all domains found in the base form.""" + if self._domains is None: + self._analyze_domains() + return self._domains def equals(self, other): """Check if two Adjoints are equal.""" diff --git a/ufl/form.py b/ufl/form.py index 620fb4e1e..8aadd57d5 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -113,13 +113,12 @@ def ufl_domain(self): Fails if multiple domains are found. """ - if self._domains is None: - self._analyze_domains() - - if len(self._domains) > 1: + try: + (domain,) = set(self.ufl_domains()) + except ValueError: raise ValueError("%s must have exactly one domain." % type(self).__name__) - # Return the single geometric domain - return self._domains[0] + # Return the one and only domain + return domain # --- Operator implementations --- @@ -139,7 +138,7 @@ def __radd__(self, other): def __add__(self, other): """Add.""" - if isinstance(other, (int, float)) and other == 0: + if isinstance(other, numbers.Number) and other == 0: # Allow adding 0 or 0.0 as a no-op, needed for sum([a,b]) return self elif isinstance(other, Zero): @@ -329,26 +328,6 @@ def ufl_cell(self): """ return self.ufl_domain().ufl_cell() - def ufl_domain(self): - """Return the single geometric integration domain occuring in the form. - - Fails if multiple domains are found. - - NB! This does not include domains of coefficients defined on - other meshes, look at form data for that additional information. - """ - # Collect all domains - domains = self.ufl_domains() - # Check that all are equal TODO: don't return more than one if - # all are equal? - if not all(domain == domains[0] for domain in domains): - raise ValueError( - "Calling Form.ufl_domain() is only valid if all integrals share domain." - ) - - # Return the one and only domain - return domains[0] - def geometric_dimension(self): """Return the geometric dimension shared by all domains and functions in this form.""" gdims = tuple(set(domain.geometric_dimension() for domain in self.ufl_domains())) @@ -807,7 +786,15 @@ def _analyze_domains(self): from ufl.domain import join_domains # Collect unique domains - self._domains = join_domains([component.ufl_domain() for component in self._components]) + self._domains = join_domains( + chain.from_iterable(e.ufl_domains() for e in self.ufl_operands) + ) + + def ufl_domains(self): + """Return all domains found in the base form.""" + if self._domains is None: + self._analyze_domains() + return self._domains def __hash__(self): """Hash.""" @@ -857,6 +844,7 @@ class ZeroBaseForm(BaseForm): "_arguments", "_coefficients", "ufl_operands", + "_domains", "_hash", # Pyadjoint compatibility "form", @@ -875,6 +863,21 @@ def _analyze_form_arguments(self): # `self._arguments` is already set in `BaseForm.__init__` self._coefficients = () + def _analyze_domains(self): + """Analyze which domains can be found in ZeroBaseForm.""" + from ufl.domain import join_domains + + # Collect unique domains + self._domains = join_domains( + chain.from_iterable(e.ufl_domains() for e in self.ufl_operands) + ) + + def ufl_domains(self): + """Return all domains found in the base form.""" + if self._domains is None: + self._analyze_domains() + return self._domains + def __ne__(self, other): """Overwrite BaseForm.__neq__ which relies on `equals`.""" return not self == other