Skip to content

Commit

Permalink
BaseForm: ensure that subclasses implement ufl_domains()
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jul 29, 2024
1 parent c1a8afb commit df5c67f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 35 deletions.
10 changes: 9 additions & 1 deletion ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,15 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect domains
self._domains = join_domains(chain.from_iterable(e.ufl_domains() for e in self.ufl_operands))
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)

def ufl_domains(self):
"""Return all domains found in the base form."""
if self._domains is None:
self._analyze_domains()
return self._domains

def equals(self, other):
"""Check if two Actions are equal."""
Expand Down
12 changes: 11 additions & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#
# Modified by Nacime Bouziani, 2021-2022.

from itertools import chain

from ufl.argument import Coargument
from ufl.core.ufl_type import ufl_type
from ufl.form import BaseForm, FormSum, ZeroBaseForm
Expand Down Expand Up @@ -90,7 +92,15 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect unique domains
self._domains = join_domains([e.ufl_domain() for e in self.ufl_operands])
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)

def ufl_domains(self):
"""Return all domains found in the base form."""
if self._domains is None:
self._analyze_domains()
return self._domains

def equals(self, other):
"""Check if two Adjoints are equal."""
Expand Down
72 changes: 39 additions & 33 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,12 @@ def ufl_domain(self):
Fails if multiple domains are found.
"""
if self._domains is None:
self._analyze_domains()

if len(self._domains) > 1:
try:
(domain,) = set(self.ufl_domains())
except ValueError:
raise ValueError("%s must have exactly one domain." % type(self).__name__)
# Return the single geometric domain
return self._domains[0]
# Return the one and only domain
return domain

# --- Operator implementations ---

Expand All @@ -134,7 +133,7 @@ def __radd__(self, other):

def __add__(self, other):
"""Add."""
if isinstance(other, (int, float)) and other == 0:
if isinstance(other, numbers.Number) and other == 0:
# Allow adding 0 or 0.0 as a no-op, needed for sum([a,b])
return self
elif isinstance(other, Zero) and not (other.ufl_shape or other.ufl_free_indices):
Expand Down Expand Up @@ -318,25 +317,6 @@ def ufl_cell(self):
"""
return self.ufl_domain().ufl_cell()

def ufl_domain(self):
"""Return the single geometric integration domain occuring in the form.
Fails if multiple domains are found.
NB! This does not include domains of coefficients defined on
other meshes, look at form data for that additional
information.
"""
# Collect all domains
domains = self.ufl_domains()
# Check that all are equal TODO: don't return more than one if
# all are equal?
if not all(domain == domains[0] for domain in domains):
raise ValueError("Calling Form.ufl_domain() is only valid if all integrals share domain.")

# Return the one and only domain
return domains[0]

def geometric_dimension(self):
"""Return the geometric dimension shared by all domains and functions in this form."""
gdims = tuple(
Expand Down Expand Up @@ -795,7 +775,15 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect unique domains
self._domains = join_domains([component.ufl_domain() for component in self._components])
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)

def ufl_domains(self):
"""Return all domains found in the base form."""
if self._domains is None:
self._analyze_domains()
return self._domains

def __hash__(self):
"""Hash."""
Expand Down Expand Up @@ -838,12 +826,15 @@ class ZeroBaseForm(BaseForm):
used for sake of simplifying base-form expressions.
"""

__slots__ = ("_arguments",
"_coefficients",
"ufl_operands",
"_hash",
# Pyadjoint compatibility
"form")
__slots__ = (
"_arguments",
"_coefficients",
"ufl_operands",
"_domains",
"_hash",
# Pyadjoint compatibility
"form",
)

def __init__(self, arguments):
"""Initialise."""
Expand All @@ -858,6 +849,21 @@ def _analyze_form_arguments(self):
# `self._arguments` is already set in `BaseForm.__init__`
self._coefficients = ()

def _analyze_domains(self):
"""Analyze which domains can be found in ZeroBaseForm."""
from ufl.domain import join_domains

# Collect unique domains
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)

def ufl_domains(self):
"""Return all domains found in the base form."""
if self._domains is None:
self._analyze_domains()
return self._domains

def __ne__(self, other):
"""Overwrite BaseForm.__neq__ which relies on `equals`."""
return not self == other
Expand Down

0 comments on commit df5c67f

Please sign in to comment.