diff --git a/loopy/kernel/function_interface.py b/loopy/kernel/function_interface.py index 40d9969bf..33dfd73f2 100644 --- a/loopy/kernel/function_interface.py +++ b/loopy/kernel/function_interface.py @@ -23,8 +23,8 @@ THE SOFTWARE. """ from abc import ABC, abstractmethod -from collections.abc import Collection, Mapping, Sequence -from dataclasses import dataclass, fields, replace +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, replace from typing import TYPE_CHECKING, Any, Callable, FrozenSet, TypeVar from warnings import warn @@ -304,9 +304,7 @@ def get_kw_pos_association(kernel): # {{{ template class -# not frozen for Firedrake compatibility -# not eq to avoid having __hash__ set to None in subclasses -@dataclass(init=False, eq=False) +@dataclass(frozen=True, init=False) class InKernelCallable(ABC): """ An abstract interface to define a callable encountered in a kernel. @@ -370,51 +368,9 @@ def __init__(self, def name(self) -> str: raise NotImplementedError() - # {{{ hackery to avoid breaking Firedrake - - def _all_attrs(self) -> Collection[str]: - dc_attrs = { - fld.name for fld in fields(self) - } - legacy_fields: Collection[str] = getattr(self, "fields", []) - return dc_attrs | set(legacy_fields) - def copy(self, **kwargs: Any) -> Self: - present_kwargs = { - name: getattr(self, name) - for name in self._all_attrs() - } - kwargs = { - **present_kwargs, - **kwargs, - } - return replace(self, **kwargs) - def update_persistent_hash(self, key_hash, key_builder) -> None: - for field_name in self._all_attrs(): - key_builder.rec(key_hash, getattr(self, field_name)) - - def __eq__(self, other: object): - if type(self) is not type(other): - return False - - for f in self._all_attrs(): - if getattr(self, f) != getattr(other, f): - return False - - return True - - def __hash__(self): - import hashlib - - from loopy.tools import LoopyKeyBuilder - key_hash = hashlib.sha256() - self.update_persistent_hash(key_hash, LoopyKeyBuilder()) - return hash(key_hash.digest()) - - # }}} - def with_types(self, arg_id_to_dtype, clbl_inf_ctx): """ :arg arg_id_to_type: a mapping from argument identifiers (integers for @@ -565,8 +521,7 @@ def is_type_specialized(self): # {{{ scalar callable -# not frozen, not eq for Firedrake compatibility -@dataclass(init=False, eq=False) +@dataclass(frozen=True, init=False) class ScalarCallable(InKernelCallable): """ An abstract interface to a scalar callable encountered in a kernel. @@ -744,8 +699,7 @@ def is_type_specialized(self): # {{{ callable kernel -# not frozen, not eq for Firedrake compatibility -@dataclass(init=False, eq=False) +@dataclass(frozen=True, init=False) class CallableKernel(InKernelCallable): """ Records information about a callee kernel. Also provides interface through diff --git a/loopy/library/random123.py b/loopy/library/random123.py index cde0b093a..329770e05 100644 --- a/loopy/library/random123.py +++ b/loopy/library/random123.py @@ -176,8 +176,7 @@ def full_name(self) -> str: # }}} -# not frozen, not eq for Firedrake compatibility -@dataclass(init=False, eq=False) +@dataclass(frozen=True, init=False) class Random123Callable(ScalarCallable): """ Records information about for the random123 functions.