Skip to content

Commit

Permalink
Update FiredrakeNamedConstant with new upstream implementation, which…
Browse files Browse the repository at this point in the history
… provides a nontrivial __new__ to deprecate a former usage
  • Loading branch information
francesco-ballarin committed Sep 11, 2023
1 parent 7b58fa4 commit 4d8c8c1
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 24 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ max-line-length = 120
show-source = True
docstring-convention = numpy
inline-quotes = double
imr241_exclude = ufl4rom, ufl4rom.*
imr241_exclude = __future__, ufl4rom, ufl4rom.*
imr245_include = *
imr245_exclude = ufl4rom, ufl4rom.*
imr245_exclude = __future__, ufl4rom, ufl4rom.*
ignore = ANN101, W503
exclude = .eggs, build, dist
per-file-ignores =
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/utils/test_name_scalar_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ def test_name_scalar_13_firedrake() -> None:
+ c1 * f1 * u * v * dx
)
if np.issubdtype(petsc4py.PETSc.ScalarType, np.complexfloating): # names differ due to different c2 dtype
expected_name = "39e3bcaa7bbefe05ce8e1f0a19db7e0b6b6085ff"
expected_name = "3cd044e2a94c786d61061c6b893dd744a7ba95df"
else:
expected_name = "e869ce69d844731a95d97a5d560cd833c61335d1"
expected_name = "bf19de5b633b22e5c9c738ac7f96882048e792e1"
assert ufl4rom.utils.name(a13) == expected_name


Expand Down
2 changes: 1 addition & 1 deletion ufl4rom/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
from ufl4rom.utils.expand_sum import expand_sum
from ufl4rom.utils.name import name
from ufl4rom.utils.named_coefficient import NamedCoefficient
from ufl4rom.utils.named_constant import DolfinxNamedConstant, FiredrakeNamedConstant, NamedConstant
from ufl4rom.utils.named_constant import DolfinxNamedConstant, FiredrakeNamedConstant, NamedConstant, NamedConstantValue
from ufl4rom.utils.rewrite_quotients import rewrite_quotients
16 changes: 13 additions & 3 deletions ufl4rom/utils/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Import specialization of UFL classes from dolfinx and firedrake backends."""

from __future__ import annotations

import typing

import ufl
Expand Down Expand Up @@ -43,12 +45,20 @@ def name(self) -> str: # pragma: no cover
except ImportError:
FiredrakeScalarType = float

class FiredrakeConstant(ufl.Constant): # type: ignore[misc, no-any-unimported]
class FiredrakeConstant(ufl.constantvalue.ConstantValue): # type: ignore[misc, no-any-unimported]
"""Mock firedrake.Constant class."""

def __new__( # type: ignore[no-any-unimported]
cls: typing.Type[FiredrakeConstant],
value: typing.Union[FiredrakeScalarType, typing.Iterable[FiredrakeScalarType]],
domain: typing.Optional[ufl.AbstractDomain] = None, name: typing.Optional[str] = None
) -> FiredrakeConstant: # pragma: no cover
"""Create a new constant."""
raise RuntimeError("Cannot use a firedrake constant when firedrake is not installed")

def __init__( # type: ignore[no-any-unimported]
self, value: typing.Union[FiredrakeScalarType, typing.Iterable[FiredrakeScalarType]],
domain: typing.Optional[ufl.AbstractDomain] = None
domain: typing.Optional[ufl.AbstractDomain] = None, name: typing.Optional[str] = None
) -> None: # pragma: no cover
raise RuntimeError("Cannot use a firedrake constant when firedrake is not installed")

Expand All @@ -57,7 +67,7 @@ class FiredrakeFunction(ufl.Coefficient): # type: ignore[misc, no-any-unimporte

def name(self) -> str: # pragma: no cover
"""Get function name."""
raise RuntimeError("Cannot use a firedrake function when dolfin is not installed")
raise RuntimeError("Cannot use a firedrake function when firedrake is not installed")
else:
import petsc4py.PETSc

Expand Down
40 changes: 27 additions & 13 deletions ufl4rom/utils/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from ufl4rom.utils.backends import DolfinxConstant, DolfinxFunction, FiredrakeConstant, FiredrakeFunction
from ufl4rom.utils.named_coefficient import NamedCoefficient
from ufl4rom.utils.named_constant import DolfinxNamedConstant, FiredrakeNamedConstant, NamedConstant
from ufl4rom.utils.named_constant import DolfinxNamedConstant, FiredrakeNamedConstant, NamedConstant, NamedConstantValue


def name( # type: ignore[no-any-unimported]
Expand Down Expand Up @@ -84,7 +84,7 @@ def constant(self, o: ufl.Constant) -> NamedConstant: # type: ignore[no-any-uni
Note that the following backends:
* firedrake
actually implement their Constant objects inheriting from ufl Coefficient, so these cases are considered
actually implement their Constant objects inheriting from ufl ConstantValue, so these cases are considered
in the method below.
Raises an error when an unnamed ufl Constant is provided.
Expand All @@ -102,6 +102,31 @@ def constant(self, o: ufl.Constant) -> NamedConstant: # type: ignore[no-any-uni
raise RuntimeError(
"The case of plain UFL constants is not handled, because its value cannot be extracted")

def constant_value( # type: ignore[no-any-unimported]
self, o: ufl.constantvalue.ConstantValue
) -> ufl.constantvalue.ConstantValue:
"""
Replace a ufl ConstantValue with a ufl4rom NamedConstantValue, when possible.
Processes the following backends:
* firedrake: preserves name if provided, otherwise sets the name to the current value of the constant
Return the constant value unchanged when an unnamed ufl ConstantValue object is provided,
since it is safe to do so because that object is not counted.
"""
if isinstance(o, NamedConstantValue):
return o
elif isinstance(o, FiredrakeConstant):
# firedrake subclass ufl.ConstantValue (rather than ufl.Constant) in their definition
# of a Constant value.
if isinstance(o, FiredrakeNamedConstant):
return NamedConstantValue(str(o), o._ufl_shape)
else:
# Both of them provide a values attribute: use it to define a new NamedCoefficient
return NamedConstantValue(str(o.values()), o._ufl_shape)
else:
return o

def coefficient(self, o: ufl.Coefficient) -> NamedCoefficient: # type: ignore[no-any-unimported]
"""
Replace a ufl Coefficient with a ufl4rom NamedCoefficient, when possible.
Expand All @@ -110,21 +135,10 @@ def coefficient(self, o: ufl.Coefficient) -> NamedCoefficient: # type: ignore[n
* dolfinx: preserves name if provided, otherwise raises an error
* firedrake: preserves name if provided, otherwise raises an error
Processes the following backends which use Coefficient to implement actually Constant objects:
* firedrake: preserves name if provided, otherwise sets the name to the current value of the constant
Raises an error also when an unnamed ufl Coefficient is provided.
"""
if isinstance(o, NamedCoefficient):
return o
elif isinstance(o, FiredrakeConstant):
# firedrake subclass ufl.Coefficient (rather than ufl.Constant) in their definition
# of a Constant value.
if isinstance(o, FiredrakeNamedConstant):
return NamedCoefficient(str(o), o._ufl_function_space)
else:
# Both of them provide a values attribute: use it to define a new NamedCoefficient
return NamedCoefficient(str(o.values()), o._ufl_function_space)
elif isinstance(o, DolfinxFunction):
# dolfinx default name for functions is f
assert not o.name == "f", "Please provide a name to the Function"
Expand Down
47 changes: 44 additions & 3 deletions ufl4rom/utils/named_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Add a name to ufl.Constant and its specialization offered by backends."""

from __future__ import annotations

import re
import typing

Expand Down Expand Up @@ -36,6 +38,37 @@ def __str__(self) -> str: # pragma: no cover
return self._name


class NamedConstantValue(ufl.constantvalue.ConstantValue): # type: ignore[misc, no-any-unimported]
"""An ufl.constantvalue.ConstantValue with an additional name attribute."""

def __init__(
self, name: str, shape: typing.Tuple[int, ...] = ()
) -> None:
super().__init__()
self._name = name
assert not hasattr(self, "_ufl_shape")
self._ufl_shape = shape

# Represent the constant value by its name and its shape
self._repr = "NamedConstantValue({}, {})".format(repr(self._name), repr(self._ufl_shape))
self._repr = re.sub(" +", " ", self._repr)
self._repr = re.sub(r"\[ ", "[", self._repr)
self._repr = re.sub(r" \]", "]", self._repr)

def __str__(self) -> str: # pragma: no cover
"""Return the name of the constant value as a string representation."""
return self._name

def __repr__(self) -> str: # pragma: no cover
"""Return string representation this object can be reconstructed from."""
return self._repr

@property
def ufl_shape(self) -> typing.Tuple[int, ...]: # pragma: no cover
"""Shape of the constant value."""
return self._ufl_shape


class DolfinxNamedConstant(DolfinxConstant):
"""A dolfinx.Constant with an additional name attribute."""

Expand All @@ -62,17 +95,25 @@ def __str__(self) -> str:
class FiredrakeNamedConstant(FiredrakeConstant):
"""A firedrake.Constant with an additional name attribute."""

def __new__( # type: ignore[no-any-unimported]
cls: typing.Type[FiredrakeNamedConstant], name: str,
value: typing.Union[FiredrakeScalarType, typing.Iterable[FiredrakeScalarType]],
domain: typing.Optional[ufl.AbstractDomain] = None
) -> FiredrakeNamedConstant:
"""Create a new constant."""
return typing.cast(FiredrakeNamedConstant, FiredrakeConstant.__new__(cls, value, domain))

def __init__( # type: ignore[no-any-unimported]
self, name: str, value: typing.Union[FiredrakeScalarType, typing.Iterable[FiredrakeScalarType]],
domain: typing.Optional[ufl.AbstractDomain] = None
) -> None:
super().__init__(value, domain)
super().__init__(value, domain, name)
assert domain is None, "Giving Constants a domain has been deprecated in firedrake"
self._name = name

# Neglect the count argument when preparing the representation string, as we aim to
# get a representation which is independent on the internal counter
self._repr = "FiredrakeNamedConstant({}, {}, {})".format(
repr(self._name), repr(self.values()), repr(self._ufl_function_space._ufl_domain))
self._repr = "FiredrakeNamedConstant({}, {})".format(repr(self._name), repr(self.values()))
self._repr = re.sub(" +", " ", self._repr)
self._repr = re.sub(r"\[ ", "[", self._repr)
self._repr = re.sub(r" \]", "]", self._repr)
Expand Down

0 comments on commit 4d8c8c1

Please sign in to comment.