Skip to content

Commit

Permalink
Enable, fix ruff TC checks
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 10, 2024
1 parent 27cfbec commit fe3a8fe
Show file tree
Hide file tree
Showing 33 changed files with 226 additions and 112 deletions.
5 changes: 5 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@

sys._BUILDING_SPHINX_DOCS = True


nitpick_ignore_regex = [
["py:class", r"numpy.(u?)int[\d]+"],
["py:class", r"numpy.bool_"],
["py:class", r"typing_extensions(.+)"],
["py:class", r"P\.args"],
["py:class", r"P\.kwargs"],
["py:class", r"lp\.LoopKernel"],
["py:class", r"_dtype_any"],
]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ extend-select = [
"NPY", # numpy
"RUF",
"UP",
"TC",
]
extend-ignore = [
"E226",
Expand Down
5 changes: 3 additions & 2 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

from loopy.tools import LoopyKeyBuilder
Expand All @@ -47,12 +46,14 @@
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper


if TYPE_CHECKING:
from collections.abc import Mapping

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.loopy import LoopyCall

__doc__ = """
.. currentmodule:: pytato.analysis
Expand Down
12 changes: 6 additions & 6 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def _dataclass_setstate(self, state):

# {{{ assign mapper_method

mm_cls = cast(type[_HasMapperMethod], cls)
mm_cls = cast("type[_HasMapperMethod]", cls)

snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower()
default_mapper_method_name = f"map_{snake_clsname}"
Expand Down Expand Up @@ -804,7 +804,7 @@ def conj(self) -> ArrayOrScalar:

def __abs__(self) -> Array:
import pytato as pt
return cast(Array, pt.abs(self))
return cast("Array", pt.abs(self))

def __pos__(self) -> Array:
return self
Expand Down Expand Up @@ -1755,7 +1755,7 @@ def shape(self) -> ShapeType:
for i_basic_idx in i_basic_indices)

adv_idx_shape = get_shape_after_broadcasting([
cast(Array | Integer, not_none(self.indices[i_idx]))
cast("Array | Integer", not_none(self.indices[i_idx]))
for i_idx in i_adv_indices])

# type-ignored because mypy cannot figure out basic-indices only refer
Expand Down Expand Up @@ -1803,7 +1803,7 @@ def shape(self) -> ShapeType:
for i_basic_idx in i_basic_indices)

adv_idx_shape = get_shape_after_broadcasting([
cast(Array | Integer, not_none(self.indices[i_idx]))
cast("Array | Integer", not_none(self.indices[i_idx]))
for i_idx in i_adv_indices])

# type-ignored because mypy cannot figure out basic-indices only refer slices
Expand Down Expand Up @@ -2037,7 +2037,7 @@ def matmul(x1: Array, x2: Array) -> Array:
if x1.ndim == x2.ndim == 1:
return pt.sum(x1 * x2)
elif x1.ndim == 1:
return cast(Array, pt.dot(x1, x2))
return cast("Array", pt.dot(x1, x2))
elif x2.ndim == 1:
x1_indices = index_names[:x1.ndim]
return pt.einsum(f"{x1_indices}, {x1_indices[-1]} -> {x1_indices[:-1]}",
Expand Down Expand Up @@ -2370,7 +2370,7 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN,
else:
fill_value = conv_dtype.type(fill_value)

return IndexLambda(expr=cast(ArithmeticExpression, fill_value),
return IndexLambda(expr=cast("ArithmeticExpression", fill_value),
shape=shape, dtype=conv_dtype,
bindings=immutabledict(),
tags=_get_default_tags(),
Expand Down
11 changes: 7 additions & 4 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,13 @@
# }}}


from typing import cast
from typing import TYPE_CHECKING, cast

import numpy as np
from immutabledict import immutabledict

import pymbolic.primitives as prim
from pymbolic import Scalar, var
from pymbolic.typing import Expression

from pytato.array import (
Array,
Expand All @@ -78,6 +77,10 @@
from pytato.scalar_expr import SCALAR_CLASSES


if TYPE_CHECKING:
from pymbolic.typing import Expression


def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
func_name: str,
ret_dtype: _dtype_any | None = None,
Expand All @@ -88,7 +91,7 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...],
np_func_name = func_name

np_func = getattr(np, np_func_name)
return cast(ArrayOrScalar, np_func(*inputs))
return cast("ArrayOrScalar", np_func(*inputs))

if not inputs:
raise ValueError("at least one argument must be present")
Expand Down Expand Up @@ -233,7 +236,7 @@ def imag(x: ArrayOrScalar) -> ArrayOrScalar:
result_dtype = np.empty(0, dtype=x_dtype).real.dtype
else:
if np.isscalar(x):
return cast(Scalar, x_dtype.type(0))
return cast("Scalar", x_dtype.type(0))
else:
assert isinstance(x, Array)
import pytato as pt
Expand Down
12 changes: 8 additions & 4 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
"""

import dataclasses
from collections.abc import Mapping
from typing import Any
from typing import TYPE_CHECKING, Any

from immutabledict import immutabledict

Expand All @@ -57,10 +56,8 @@
SizeParam,
make_dict_of_named_arrays,
)
from pytato.function import NamedCallResult
from pytato.loopy import LoopyCall
from pytato.scalar_expr import IntegralScalarExpression, is_integral_scalar_expression
from pytato.target import Target
from pytato.transform import (
ArrayOrNames,
CachedWalkMapper,
Expand All @@ -70,6 +67,13 @@
from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin


if TYPE_CHECKING:
from collections.abc import Mapping

from pytato.function import NamedCallResult
from pytato.target import Target


SymbolicIndex: TypeAlias = tuple[IntegralScalarExpression, ...]


Expand Down
20 changes: 11 additions & 9 deletions pytato/distributed/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,30 @@
"""

import logging
from collections.abc import Hashable, Mapping
from typing import TYPE_CHECKING, Any

import numpy as np

from pytato.array import make_dict_of_named_arrays
from pytato.distributed.nodes import DistributedRecv, DistributedSend
from pytato.distributed.partition import (
DistributedGraphPart,
DistributedGraphPartition,
PartId,
)
from pytato.scalar_expr import INT_CLASSES
from pytato.target import BoundProgram


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping

import mpi4py.MPI

from pytato.distributed.nodes import DistributedRecv, DistributedSend
from pytato.distributed.partition import (
DistributedGraphPart,
DistributedGraphPartition,
PartId,
)
from pytato.target import BoundProgram


# {{{ generate_code_for_partition

Expand Down Expand Up @@ -134,7 +136,7 @@ def execute_distributed_partition(
context: dict[str, Any] = input_args.copy()

pids_to_execute = set(partition.parts)
pids_executed = set()
pids_executed: set[PartId] = set()
recv_names_completed = set()
send_requests = []

Expand Down
9 changes: 5 additions & 4 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,15 @@
DistributedSend,
DistributedSendRefHolder,
)
from pytato.function import FunctionDefinition, NamedCallResult
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper


if TYPE_CHECKING:
import mpi4py.MPI

from pytato.function import FunctionDefinition, NamedCallResult


@dataclasses.dataclass(frozen=True)
class CommunicationOpIdentifier:
Expand Down Expand Up @@ -350,7 +351,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
if name is not None:
return self._get_placeholder_for(name, expr)

return cast(ArrayOrNames, super().rec(expr))
return cast("ArrayOrNames", super().rec(expr))

def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder:
placeholder = self.partition_input_name_to_placeholder.get(name)
Expand Down Expand Up @@ -820,7 +821,7 @@ def find_distributed_partition(
raise comm_batches_or_exc

comm_batches = cast(
Sequence[Set[CommunicationOpIdentifier]],
"Sequence[Set[CommunicationOpIdentifier]]",
comm_batches_or_exc)

# }}}
Expand Down Expand Up @@ -921,7 +922,7 @@ def find_distributed_partition(
ary: max(
(comm_id_to_part_id[
_recv_to_comm_id(local_rank,
cast(DistributedRecv, recvd_ary))]
cast("DistributedRecv", recvd_ary))]
for recvd_ary in recvd_array_dep_mapper(ary)),
default=-1)
for ary in mso_arrays
Expand Down
9 changes: 5 additions & 4 deletions pytato/distributed/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,8 @@

import dataclasses
import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import numpy as np

from pymbolic.mapper.optimize import optimize_mapper

from pytato.array import (
Expand All @@ -51,7 +48,6 @@
ShapeType,
make_dict_of_named_arrays,
)
from pytato.distributed.nodes import CommTagType, DistributedRecv
from pytato.distributed.partition import (
CommunicationOpIdentifier,
DistributedGraphPartition,
Expand All @@ -64,7 +60,12 @@


if TYPE_CHECKING:
from collections.abc import Sequence

import mpi4py.MPI
import numpy as np

from pytato.distributed.nodes import CommTagType, DistributedRecv


# {{{ data structures
Expand Down
5 changes: 3 additions & 2 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
THE SOFTWARE.
"""

from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from pytools import memoize_method
Expand All @@ -50,11 +49,13 @@
SizeParam,
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult


if TYPE_CHECKING:
from collections.abc import Callable

from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall, LoopyCallResult

__doc__ = """
Expand Down
6 changes: 5 additions & 1 deletion pytato/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@


import dataclasses
from collections.abc import Iterable, Iterator, Mapping, Sequence
from numbers import Number
from typing import (
TYPE_CHECKING,
Any,
)

Expand Down Expand Up @@ -58,6 +58,10 @@
)


if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence


__doc__ = r"""
.. currentmodule:: pytato.loopy
Expand Down
17 changes: 14 additions & 3 deletions pytato/pad.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
"""
.. autofunction:: pad
Cross-references
----------------
.. class:: Integer
See :mod:`pymbolic.typing`.
"""
from __future__ import annotations


__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni"

import collections.abc as abc
from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np

import pymbolic.primitives as prim
from pymbolic import Scalar
from pymbolic.typing import Integer
from pytools import UniqueNameGenerator

from pytato.array import Array, IndexLambda
from pytato.scalar_expr import INT_CLASSES


if TYPE_CHECKING:
from collections.abc import Sequence

from pymbolic.typing import Integer


def _get_constant_padded_idx_lambda(
array: Array,
pad_widths: Sequence[tuple[Integer, Integer]],
Expand Down
Loading

0 comments on commit fe3a8fe

Please sign in to comment.