Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix quadrature rule hash #132

Merged
merged 5 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion FIAT/reference_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from math import factorial

import numpy
from gem.utils import safe_repr
from recursivenodes.nodes import _decode_family, _recursive

from FIAT.orientation_utils import (
Expand Down Expand Up @@ -126,7 +127,7 @@ def linalg_subspace_intersection(A, B):
return U[:, :rank_c]


class Cell(object):
class Cell:
"""Abstract class for a reference cell. Provides accessors for
geometry (vertex coordinates) as well as topology (orderings of
vertices that make up edges, faces, etc."""
Expand Down Expand Up @@ -184,6 +185,9 @@ def __init__(self, shape, vertices, topology):
# Dictionary with derived cells
self._split_cache = {}

def __repr__(self):
return f"{type(self).__name__}({self.shape!r}, {safe_repr(self.vertices)}, {self.topology!r})"

def _key(self):
"""Hashable object key data (excluding type)."""
# Default: only type matters
Expand Down Expand Up @@ -1130,6 +1134,9 @@ def __init__(self, *cells):
super().__init__(TENSORPRODUCT, vertices, topology)
self.cells = tuple(cells)

def __repr__(self):
return f"{type(self).__name__}({self.cells!r})"

def _key(self):
return self.cells

Expand Down
35 changes: 29 additions & 6 deletions finat/point_set.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
from abc import ABCMeta, abstractproperty
import abc
import hashlib
from functools import cached_property
from itertools import chain, product

import numpy

import gem
from gem.utils import cached_property
from gem.utils import safe_repr


class AbstractPointSet(metaclass=ABCMeta):
class AbstractPointSet(abc.ABC):
"""A way of specifying a known set of points, perhaps with some
(tensor) structure.

Points, when stored, have shape point_set_shape + (point_dimension,)
where point_set_shape is () for scalar, (N,) for N element vector,
(N, M) for N x M matrix etc.
"""
def __hash__(self):
return int.from_bytes(hashlib.md5(repr(self).encode()).digest(), byteorder="big")

@abstractproperty
@abc.abstractmethod
def __repr__(self):
pass

@property
@abc.abstractmethod
def points(self):
"""A flattened numpy array of points or ``UnknownPointsArray``
object with shape (# of points, point dimension)."""
Expand All @@ -27,12 +36,14 @@ def dimension(self):
_, dim = self.points.shape
return dim

@abstractproperty
@property
@abc.abstractmethod
def indices(self):
"""GEM indices with matching shape and extent to the structure of the
point set."""

@abstractproperty
@property
@abc.abstractmethod
def expression(self):
"""GEM expression describing the points, with free indices
``self.indices`` and shape (point dimension,)."""
Expand All @@ -53,6 +64,9 @@ def __init__(self, point):
assert len(point.shape) == 1
self.point = point

def __repr__(self):
return f"{type(self).__name__}({safe_repr(self.point)})"

@cached_property
def points(self):
# Make sure we conform to the expected (# of points, point dimension)
Expand Down Expand Up @@ -106,6 +120,9 @@ def __init__(self, points_expr):
assert len(points_expr.shape) == 2
self._points_expr = points_expr

def __repr__(self):
return f"{type(self).__name__}({self._points_expr!r})"

@cached_property
def points(self):
return UnknownPointsArray(self._points_expr.shape)
Expand Down Expand Up @@ -133,6 +150,9 @@ def __init__(self, points):
assert len(points.shape) == 2
self.points = points

def __repr__(self):
return f"{type(self).__name__}({self.points!r})"

@cached_property
def points(self):
pass # set at initialisation
Expand Down Expand Up @@ -177,6 +197,9 @@ class TensorPointSet(AbstractPointSet):
def __init__(self, factors):
self.factors = tuple(factors)

def __repr__(self):
return f"{type(self).__name__}({self.factors!r})"

@cached_property
def points(self):
return numpy.array([list(chain(*pt_tuple))
Expand Down
36 changes: 31 additions & 5 deletions finat/quadrature.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABCMeta, abstractproperty
from functools import reduce
import hashlib
from abc import ABCMeta, abstractmethod
from functools import cached_property, reduce

import gem
import numpy
from FIAT.quadrature import GaussLegendreQuadratureLineRule
from FIAT.quadrature_schemes import create_quadrature as fiat_scheme
from FIAT.reference_element import LINE, QUADRILATERAL, TENSORPRODUCT
from gem.utils import cached_property
from gem.utils import safe_repr

from finat.point_set import GaussLegendrePointSet, PointSet, TensorPointSet

Expand Down Expand Up @@ -60,11 +61,23 @@ class AbstractQuadratureRule(metaclass=ABCMeta):
"""Abstract class representing a quadrature rule as point set and a
corresponding set of weights."""

@abstractproperty
def __hash__(self):
return int.from_bytes(hashlib.md5(repr(self).encode()).digest(), byteorder="big")

def __eq__(self, other):
return type(other) is type(self) and repr(other) == repr(self)

@abstractmethod
def __repr__(self):
pass

@property
@abstractmethod
def point_set(self):
"""Point set object representing the quadrature points."""

@abstractproperty
@property
@abstractmethod
def weight_expression(self):
"""GEM expression describing the weights, with the same free indices
as the point set."""
Expand Down Expand Up @@ -110,6 +123,16 @@ def __init__(self, point_set, weights, ref_el=None, io_ornt_map_tuple=(None, )):
self.weights = numpy.asarray(weights)
self._intrinsic_orientation_permutation_map_tuple = io_ornt_map_tuple

def __repr__(self):
return (
f"{type(self).__name__}("
f"{self.point_set!r}, "
f"{safe_repr(self.weights)}, "
f"{self.ref_el!r}, "
f"{self._intrinsic_orientation_permutation_map_tuple!r}"
")"
)

@cached_property
def point_set(self):
pass # set at initialisation
Expand All @@ -131,6 +154,9 @@ def __init__(self, factors, ref_el=None):
for m in factor._intrinsic_orientation_permutation_map_tuple
)

def __repr__(self):
return f"{type(self).__name__}({self.factors!r}, {self.ref_el!r})"

@cached_property
def point_set(self):
return TensorPointSet(q.point_set for q in self.factors)
Expand Down
54 changes: 54 additions & 0 deletions gem/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import collections
import functools
import numbers
from functools import cached_property # noqa: F401
from typing import Any

import numpy as np
from ufl.constantvalue import format_float


def groupby(iterable, key=None):
Expand Down Expand Up @@ -88,3 +94,51 @@ def __exit__(self, exc_type, exc_value, traceback):
assert self.state is variable._head
value, variable._head = variable._head
self.state = None


@functools.singledispatch
def safe_repr(obj: Any) -> str:
"""Return a 'safe' repr for an object, accounting for floating point error.

Parameters
----------
obj :
The object to produce a repr for.

Returns
-------
str :
A repr for the object.

"""
raise TypeError(f"Cannot provide a safe repr for {type(obj).__name__}")


@safe_repr.register(str)
def _(text: str) -> str:
return text


@safe_repr.register(numbers.Integral)
def _(num: numbers.Integral) -> str:
return repr(num)


@safe_repr.register(numbers.Real)
def _(num: numbers.Real) -> str:
return format_float(num)
ksagiyam marked this conversation as resolved.
Show resolved Hide resolved


@safe_repr.register(np.ndarray)
def _(array: np.ndarray) -> str:
return f"{type(array).__name__}([{', '.join(map(safe_repr, array))}])"


@safe_repr.register(list)
def _(list_: list) -> str:
return f"[{', '.join(map(safe_repr, list_))}]"


@safe_repr.register(tuple)
def _(tuple_: tuple) -> str:
return f"({', '.join(map(safe_repr, tuple_))})"
5 changes: 5 additions & 0 deletions test/finat/test_create_fiat_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def tensor_name(request):
ids=lambda x: x.cellname(),
scope="module")
def ufl_A(request, tensor_name):
if request.param == ufl.quadrilateral:
if tensor_name == "DG":
tensor_name = "DQ"
elif tensor_name == "DG L2":
tensor_name = "DQ L2"
connorjward marked this conversation as resolved.
Show resolved Hide resolved
return finat.ufl.FiniteElement(tensor_name, request.param, 1)


Expand Down
19 changes: 19 additions & 0 deletions test/finat/test_quadrature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from FIAT import ufc_cell
from finat.quadrature import make_quadrature


@pytest.mark.parametrize(
"cell_name",
["interval", "triangle", "interval * interval", "triangle * interval"]
)
def test_quadrature_rules_are_hashable(cell_name):
ref_cell = ufc_cell(cell_name)
quadrature1 = make_quadrature(ref_cell, 3)
quadrature2 = make_quadrature(ref_cell, 3)

assert quadrature1 is not quadrature2
assert hash(quadrature1) == hash(quadrature2)
assert repr(quadrature1) == repr(quadrature2)
assert quadrature1 == quadrature2