Skip to content

Commit

Permalink
attempt to fix type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jun 11, 2024
1 parent 6258011 commit 4d96e18
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,9 @@ class Array(Taggable):
tags: FrozenSet[Tag] = attrs.field(kw_only=True)

# These are automatically excluded from equality in EqualityComparer
non_equality_tags: FrozenSet[Optional[Tag]] = attrs.field(kw_only=True,
hash=False,
default=frozenset())
non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True,
hash=False,
default=frozenset())

_mapper_method: ClassVar[str]

Expand Down
4 changes: 2 additions & 2 deletions pytato/distributed/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class DistributedSendRefHolder(Array):

def __init__(self, send: DistributedSend, passthrough_data: Array,
tags: FrozenSet[Tag] = frozenset(),
non_equality_tags: FrozenSet[Optional[Tag]] = frozenset()) -> None:
non_equality_tags: FrozenSet[Tag] = frozenset()) -> None:
super().__init__(axes=passthrough_data.axes, tags=tags,
non_equality_tags=non_equality_tags)
object.__setattr__(self, "send", send)
Expand Down Expand Up @@ -232,7 +232,7 @@ def make_distributed_send_ref_holder(
send: DistributedSend,
passthrough_data: Array,
tags: FrozenSet[Tag] = frozenset(),
non_equality_tags: FrozenSet[Optional[Tag]] = frozenset(),
non_equality_tags: FrozenSet[Tag] = frozenset(),
) -> DistributedSendRefHolder:
"""Make a :class:`DistributedSendRefHolder` object."""
if not non_equality_tags:
Expand Down
4 changes: 2 additions & 2 deletions pytato/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar,
op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501
get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501
tags: FrozenSet[Tag],
non_equality_tags: FrozenSet[Optional[Tag]],
non_equality_tags: FrozenSet[Tag],
) -> ArrayOrScalar:
from pytato.array import _get_default_axes

Expand Down Expand Up @@ -481,7 +481,7 @@ def _index_into(
ary: Array,
indices: Tuple[ConvertibleToIndexExpr, ...],
tags: FrozenSet[Tag],
non_equality_tags: FrozenSet[Optional[Tag]]) -> Array:
non_equality_tags: FrozenSet[Tag]) -> Array:
from pytato.diagnostic import CannotBroadcastError
from pytato.array import _get_default_axes

Expand Down

0 comments on commit 4d96e18

Please sign in to comment.