diff --git a/pytato/array.py b/pytato/array.py index 31fe0b126..21eec352d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -449,9 +449,9 @@ class Array(Taggable): tags: FrozenSet[Tag] = attrs.field(kw_only=True) # These are automatically excluded from equality in EqualityComparer - non_equality_tags: FrozenSet[Optional[Tag]] = attrs.field(kw_only=True, - hash=False, - default=frozenset()) + non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, + hash=False, + default=frozenset()) _mapper_method: ClassVar[str] diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index ac7cc5ac2..ec2d17417 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -148,7 +148,7 @@ class DistributedSendRefHolder(Array): def __init__(self, send: DistributedSend, passthrough_data: Array, tags: FrozenSet[Tag] = frozenset(), - non_equality_tags: FrozenSet[Optional[Tag]] = frozenset()) -> None: + non_equality_tags: FrozenSet[Tag] = frozenset()) -> None: super().__init__(axes=passthrough_data.axes, tags=tags, non_equality_tags=non_equality_tags) object.__setattr__(self, "send", send) @@ -232,7 +232,7 @@ def make_distributed_send_ref_holder( send: DistributedSend, passthrough_data: Array, tags: FrozenSet[Tag] = frozenset(), - non_equality_tags: FrozenSet[Optional[Tag]] = frozenset(), + non_equality_tags: FrozenSet[Tag] = frozenset(), ) -> DistributedSendRefHolder: """Make a :class:`DistributedSendRefHolder` object.""" if not non_equality_tags: diff --git a/pytato/utils.py b/pytato/utils.py index e5463d8f5..2097514d3 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -180,7 +180,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 tags: FrozenSet[Tag], - non_equality_tags: FrozenSet[Optional[Tag]], + non_equality_tags: FrozenSet[Tag], ) -> ArrayOrScalar: from pytato.array import _get_default_axes @@ -481,7 +481,7 @@ def _index_into( ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...], tags: FrozenSet[Tag], - non_equality_tags: FrozenSet[Optional[Tag]]) -> Array: + non_equality_tags: FrozenSet[Tag]) -> Array: from pytato.diagnostic import CannotBroadcastError from pytato.array import _get_default_axes