diff --git a/doc/conf.py b/doc/conf.py index 1c69ebfa6..6966a6062 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,8 +46,13 @@ sys._BUILDING_SPHINX_DOCS = True + nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"numpy.bool_"], ["py:class", r"typing_extensions(.+)"], + ["py:class", r"P\.args"], + ["py:class", r"P\.kwargs"], + ["py:class", r"lp\.LoopKernel"], + ["py:class", r"_dtype_any"], ] diff --git a/pyproject.toml b/pyproject.toml index 6b5563242..665a83d81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ extend-select = [ "NPY", # numpy "RUF", "UP", + "TC", ] extend-ignore = [ "E226", diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a274a335a..9f4c4bf54 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,6 @@ THE SOFTWARE. """ -from collections.abc import Mapping from typing import TYPE_CHECKING, Any from loopy.tools import LoopyKeyBuilder @@ -47,12 +46,14 @@ Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.loopy import LoopyCall from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper if TYPE_CHECKING: + from collections.abc import Mapping + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder + from pytato.loopy import LoopyCall __doc__ = """ .. currentmodule:: pytato.analysis diff --git a/pytato/array.py b/pytato/array.py index 1661e6e2d..064a98b4d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -405,7 +405,7 @@ def _dataclass_setstate(self, state): # {{{ assign mapper_method - mm_cls = cast(type[_HasMapperMethod], cls) + mm_cls = cast("type[_HasMapperMethod]", cls) snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower() default_mapper_method_name = f"map_{snake_clsname}" @@ -804,7 +804,7 @@ def conj(self) -> ArrayOrScalar: def __abs__(self) -> Array: import pytato as pt - return cast(Array, pt.abs(self)) + return cast("Array", pt.abs(self)) def __pos__(self) -> Array: return self @@ -1755,7 +1755,7 @@ def shape(self) -> ShapeType: for i_basic_idx in i_basic_indices) adv_idx_shape = get_shape_after_broadcasting([ - cast(Array | Integer, not_none(self.indices[i_idx])) + cast("Array | Integer", not_none(self.indices[i_idx])) for i_idx in i_adv_indices]) # type-ignored because mypy cannot figure out basic-indices only refer @@ -1803,7 +1803,7 @@ def shape(self) -> ShapeType: for i_basic_idx in i_basic_indices) adv_idx_shape = get_shape_after_broadcasting([ - cast(Array | Integer, not_none(self.indices[i_idx])) + cast("Array | Integer", not_none(self.indices[i_idx])) for i_idx in i_adv_indices]) # type-ignored because mypy cannot figure out basic-indices only refer slices @@ -2037,7 +2037,7 @@ def matmul(x1: Array, x2: Array) -> Array: if x1.ndim == x2.ndim == 1: return pt.sum(x1 * x2) elif x1.ndim == 1: - return cast(Array, pt.dot(x1, x2)) + return cast("Array", pt.dot(x1, x2)) elif x2.ndim == 1: x1_indices = index_names[:x1.ndim] return pt.einsum(f"{x1_indices}, {x1_indices[-1]} -> {x1_indices[:-1]}", @@ -2370,7 +2370,7 @@ def full(shape: ConvertibleToShape, fill_value: Scalar | prim.NaN, else: fill_value = conv_dtype.type(fill_value) - return IndexLambda(expr=cast(ArithmeticExpression, fill_value), + return IndexLambda(expr=cast("ArithmeticExpression", fill_value), shape=shape, dtype=conv_dtype, bindings=immutabledict(), tags=_get_default_tags(), diff --git a/pytato/cmath.py b/pytato/cmath.py index d7900ed66..ec981aa32 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -57,14 +57,13 @@ # }}} -from typing import cast +from typing import TYPE_CHECKING, cast import numpy as np from immutabledict import immutabledict import pymbolic.primitives as prim from pymbolic import Scalar, var -from pymbolic.typing import Expression from pytato.array import ( Array, @@ -78,6 +77,10 @@ from pytato.scalar_expr import SCALAR_CLASSES +if TYPE_CHECKING: + from pymbolic.typing import Expression + + def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...], func_name: str, ret_dtype: _dtype_any | None = None, @@ -88,7 +91,7 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...], np_func_name = func_name np_func = getattr(np, np_func_name) - return cast(ArrayOrScalar, np_func(*inputs)) + return cast("ArrayOrScalar", np_func(*inputs)) if not inputs: raise ValueError("at least one argument must be present") @@ -233,7 +236,7 @@ def imag(x: ArrayOrScalar) -> ArrayOrScalar: result_dtype = np.empty(0, dtype=x_dtype).real.dtype else: if np.isscalar(x): - return cast(Scalar, x_dtype.type(0)) + return cast("Scalar", x_dtype.type(0)) else: assert isinstance(x, Array) import pytato as pt diff --git a/pytato/codegen.py b/pytato/codegen.py index aa26e48a9..d08445517 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -38,8 +38,7 @@ """ import dataclasses -from collections.abc import Mapping -from typing import Any +from typing import TYPE_CHECKING, Any from immutabledict import immutabledict @@ -57,10 +56,8 @@ SizeParam, make_dict_of_named_arrays, ) -from pytato.function import NamedCallResult from pytato.loopy import LoopyCall from pytato.scalar_expr import IntegralScalarExpression, is_integral_scalar_expression -from pytato.target import Target from pytato.transform import ( ArrayOrNames, CachedWalkMapper, @@ -70,6 +67,13 @@ from pytato.transform.lower_to_index_lambda import ToIndexLambdaMixin +if TYPE_CHECKING: + from collections.abc import Mapping + + from pytato.function import NamedCallResult + from pytato.target import Target + + SymbolicIndex: TypeAlias = tuple[IntegralScalarExpression, ...] diff --git a/pytato/distributed/execute.py b/pytato/distributed/execute.py index 2630a76c1..3be65c76b 100644 --- a/pytato/distributed/execute.py +++ b/pytato/distributed/execute.py @@ -34,28 +34,30 @@ """ import logging -from collections.abc import Hashable, Mapping from typing import TYPE_CHECKING, Any import numpy as np from pytato.array import make_dict_of_named_arrays -from pytato.distributed.nodes import DistributedRecv, DistributedSend -from pytato.distributed.partition import ( - DistributedGraphPart, - DistributedGraphPartition, - PartId, -) from pytato.scalar_expr import INT_CLASSES -from pytato.target import BoundProgram logger = logging.getLogger(__name__) if TYPE_CHECKING: + from collections.abc import Hashable, Mapping + import mpi4py.MPI + from pytato.distributed.nodes import DistributedRecv, DistributedSend + from pytato.distributed.partition import ( + DistributedGraphPart, + DistributedGraphPartition, + PartId, + ) + from pytato.target import BoundProgram + # {{{ generate_code_for_partition @@ -134,7 +136,7 @@ def execute_distributed_partition( context: dict[str, Any] = input_args.copy() pids_to_execute = set(partition.parts) - pids_executed = set() + pids_executed: set[PartId] = set() recv_names_completed = set() send_requests = [] diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index ca817ad85..c929f750c 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -88,7 +88,6 @@ DistributedSend, DistributedSendRefHolder, ) -from pytato.function import FunctionDefinition, NamedCallResult from pytato.scalar_expr import SCALAR_CLASSES from pytato.transform import ArrayOrNames, CachedWalkMapper, CombineMapper, CopyMapper @@ -96,6 +95,8 @@ if TYPE_CHECKING: import mpi4py.MPI + from pytato.function import FunctionDefinition, NamedCallResult + @dataclasses.dataclass(frozen=True) class CommunicationOpIdentifier: @@ -350,7 +351,7 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames: if name is not None: return self._get_placeholder_for(name, expr) - return cast(ArrayOrNames, super().rec(expr)) + return cast("ArrayOrNames", super().rec(expr)) def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: placeholder = self.partition_input_name_to_placeholder.get(name) @@ -820,7 +821,7 @@ def find_distributed_partition( raise comm_batches_or_exc comm_batches = cast( - Sequence[Set[CommunicationOpIdentifier]], + "Sequence[Set[CommunicationOpIdentifier]]", comm_batches_or_exc) # }}} @@ -921,7 +922,7 @@ def find_distributed_partition( ary: max( (comm_id_to_part_id[ _recv_to_comm_id(local_rank, - cast(DistributedRecv, recvd_ary))] + cast("DistributedRecv", recvd_ary))] for recvd_ary in recvd_array_dep_mapper(ary)), default=-1) for ary in mso_arrays diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index 84d8a21ed..5e8aa526d 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -38,11 +38,8 @@ import dataclasses import logging -from collections.abc import Sequence from typing import TYPE_CHECKING, Any -import numpy as np - from pymbolic.mapper.optimize import optimize_mapper from pytato.array import ( @@ -51,7 +48,6 @@ ShapeType, make_dict_of_named_arrays, ) -from pytato.distributed.nodes import CommTagType, DistributedRecv from pytato.distributed.partition import ( CommunicationOpIdentifier, DistributedGraphPartition, @@ -64,7 +60,12 @@ if TYPE_CHECKING: + from collections.abc import Sequence + import mpi4py.MPI + import numpy as np + + from pytato.distributed.nodes import CommTagType, DistributedRecv # {{{ data structures diff --git a/pytato/equality.py b/pytato/equality.py index 831dc49a6..1ef2a88ed 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -25,7 +25,6 @@ THE SOFTWARE. """ -from collections.abc import Callable from typing import TYPE_CHECKING, Any from pytools import memoize_method @@ -50,11 +49,13 @@ SizeParam, Stack, ) -from pytato.function import Call, FunctionDefinition, NamedCallResult if TYPE_CHECKING: + from collections.abc import Callable + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder + from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall, LoopyCallResult __doc__ = """ diff --git a/pytato/loopy.py b/pytato/loopy.py index 5bb86fcd1..de4c26506 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -27,9 +27,9 @@ import dataclasses -from collections.abc import Iterable, Iterator, Mapping, Sequence from numbers import Number from typing import ( + TYPE_CHECKING, Any, ) @@ -58,6 +58,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Mapping, Sequence + + __doc__ = r""" .. currentmodule:: pytato.loopy diff --git a/pytato/pad.py b/pytato/pad.py index 588070d0f..43ddc4b1d 100644 --- a/pytato/pad.py +++ b/pytato/pad.py @@ -1,5 +1,12 @@ """ .. autofunction:: pad + +Cross-references +---------------- + +.. class:: Integer + + See :mod:`pymbolic.typing`. """ from __future__ import annotations @@ -7,20 +14,24 @@ __copyright__ = "Copyright (C) 2023 Kaushik Kulkarni" import collections.abc as abc -from collections.abc import Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pymbolic.primitives as prim from pymbolic import Scalar -from pymbolic.typing import Integer from pytools import UniqueNameGenerator from pytato.array import Array, IndexLambda from pytato.scalar_expr import INT_CLASSES +if TYPE_CHECKING: + from collections.abc import Sequence + + from pymbolic.typing import Integer + + def _get_constant_padded_idx_lambda( array: Array, pad_widths: Sequence[tuple[Integer, Integer]], diff --git a/pytato/raising.py b/pytato/raising.py index 7b1a6bc09..13efb737c 100644 --- a/pytato/raising.py +++ b/pytato/raising.py @@ -1,19 +1,16 @@ from __future__ import annotations -from collections.abc import Mapping, Sequence from dataclasses import dataclass from enum import Enum, auto, unique -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from immutabledict import immutabledict import pymbolic.primitives as p -from pymbolic.typing import Scalar from pytato.array import Array, ArrayOrScalar, IndexLambda, ShapeType from pytato.diagnostic import UnknownIndexLambdaExpr -from pytato.reductions import ReductionOperation from pytato.scalar_expr import ( SCALAR_CLASSES, IdentityMapper, @@ -28,6 +25,14 @@ ) +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from pymbolic.typing import Scalar + + from pytato.reductions import ReductionOperation + + __doc__ = """ .. autoclass:: HighLevelOp diff --git a/pytato/reductions.py b/pytato/reductions.py index e5dc8d817..0d2c5fc1e 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -30,8 +30,7 @@ from abc import ABC, abstractmethod -from collections.abc import Mapping, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from immutabledict import immutabledict @@ -43,6 +42,10 @@ from pytato.scalar_expr import INT_CLASSES, Reduce, ScalarExpression +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + # {{{ docs __doc__ = """ diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 2e66828cd..e74007c0f 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -134,7 +134,7 @@ def map_reduce(self, expr: Reduce, *args: P.args, **kwargs: P.kwargs) -> Expression: return Reduce( - cast(ArithmeticExpression, + cast("ArithmeticExpression", self.rec(expr.inner_expr, *args, **kwargs)), expr.op, immutabledict({ @@ -148,7 +148,7 @@ def map_reduce(self, def map_type_cast(self, expr: TypeCast, *args: P.args, **kwargs: P.kwargs) -> Expression: return TypeCast(expr.dtype, - cast(ArithmeticExpression, + cast("ArithmeticExpression", self.rec(expr.inner_expr, *args, **kwargs))) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 6b50d67f1..f2e7f066b 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -26,7 +26,7 @@ """ import dataclasses -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import numpy as np from immutabledict import immutabledict @@ -41,11 +41,14 @@ IndexLambda, ReductionDescriptor, ) -from pytato.function import Call, FunctionDefinition -from pytato.loopy import LoopyCall from pytato.transform import Mapper +if TYPE_CHECKING: + from pytato.function import Call, FunctionDefinition + from pytato.loopy import LoopyCall + + __doc__ = """ .. currentmodule:: pytato.stringifier @@ -90,7 +93,7 @@ def map_foreign(self, expr: Any, depth: int) -> str: + ", ".join(f"{key!r}: {self.rec(val, depth)}" for key, val in sorted(expr.items(), - key=lambda k_x_v: cast(str, k_x_v[0]))) + key=lambda k_x_v: cast("str", k_x_v[0]))) + "}") elif isinstance(expr, frozenset | set): return "{" + ", ".join(self.rec(el, depth) for el in expr) + "}" diff --git a/pytato/tags.py b/pytato/tags.py index 8a7c8d3e7..e0a98b7da 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -19,13 +19,17 @@ .. autoclass:: InlineCallTag """ -from collections.abc import Hashable from dataclasses import dataclass from traceback import FrameSummary, StackSummary +from typing import TYPE_CHECKING from pytools.tag import Tag, UniqueTag, tag_dataclass +if TYPE_CHECKING: + from collections.abc import Hashable + + # {{{ pre-defined tag: ImplementationStrategy @tag_dataclass diff --git a/pytato/target/__init__.py b/pytato/target/__init__.py index 6e3e1480e..57ab3940b 100644 --- a/pytato/target/__init__.py +++ b/pytato/target/__init__.py @@ -33,9 +33,12 @@ .. autoclass:: BoundProgram """ -from collections.abc import Mapping from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Mapping class Target: diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index cf8d57a51..c6568ba58 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -50,7 +50,6 @@ """ import sys -from collections.abc import Callable, Mapping from dataclasses import dataclass, field from functools import cached_property from typing import TYPE_CHECKING, Any @@ -64,6 +63,12 @@ from pytato.target import BoundProgram, Target +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + import pyopencl + + class ImplSubstitution(ImplementationStrategy): """ An :class:`~pytato.tags.ImplementationStrategy` that lowers the array @@ -75,7 +80,7 @@ class ImplSubstitution(ImplementationStrategy): if getattr(sys, "_BUILDING_SPHINX_DOCS", False) or TYPE_CHECKING: # Avoid import unless building docs to avoid creating a hard # dependency on pyopencl, when Loopy can run fine without. - import pyopencl + pass class LoopyTarget(Target): diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index ca9ae482d..528de5d0e 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -28,15 +28,13 @@ import sys from abc import ABC, abstractmethod from collections.abc import Mapping +from typing import TYPE_CHECKING import islpy as isl import loopy as lp import pymbolic.primitives as prim -import pytools from pymbolic import ArithmeticExpression, var -from pymbolic.typing import Expression -from pytools.tag import Tag import pytato.reductions as red import pytato.scalar_expr as scalar_expr @@ -59,8 +57,6 @@ normalize_outputs, preprocess, ) -from pytato.function import Call, NamedCallResult -from pytato.loopy import LoopyCall from pytato.scalar_expr import ( INT_CLASSES, ScalarExpression, @@ -73,16 +69,26 @@ Named, PrefixNamed, ) -from pytato.target import BoundProgram from pytato.target.loopy import ImplSubstitution, LoopyPyOpenCLTarget, LoopyTarget from pytato.transform import Mapper +if TYPE_CHECKING: + import pyopencl + import pytools + from pymbolic.typing import Expression + from pytools.tag import Tag + + from pytato.function import Call, NamedCallResult + from pytato.loopy import LoopyCall + from pytato.target import BoundProgram + + # set in doc/conf.py if getattr(sys, "_BUILDING_SPHINX_DOCS", False): # Avoid import unless building docs to avoid creating a hard # dependency on pyopencl, when Loopy can run fine without. - import pyopencl + from pytools.tag import Tag # noqa: TC001 __doc__ = """ .. autoclass:: PersistentExpressionContext diff --git a/pytato/target/python/__init__.py b/pytato/target/python/__init__.py index 8a85405b3..7fbee4e59 100644 --- a/pytato/target/python/__init__.py +++ b/pytato/target/python/__init__.py @@ -37,16 +37,19 @@ """ from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping from dataclasses import dataclass from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from pytato.target import BoundProgram, Target +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + # {{{ abstract types class PythonTarget(Target, ABC): diff --git a/pytato/target/python/jax.py b/pytato/target/python/jax.py index 2e9350625..efe691ac9 100644 --- a/pytato/target/python/jax.py +++ b/pytato/target/python/jax.py @@ -26,13 +26,18 @@ """ import ast -from collections.abc import Mapping +from typing import TYPE_CHECKING -from pytato.array import Array, DictOfNamedArrays from pytato.target.python import BoundJAXPythonProgram, JAXPythonTarget from pytato.target.python.numpy_like import generate_numpy_like +if TYPE_CHECKING: + from collections.abc import Mapping + + from pytato.array import Array, DictOfNamedArrays + + __doc__ = """ .. autofunction:: generate_jax """ diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 28f3f0fc4..0c9bd413c 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -28,8 +28,8 @@ import ast import os import sys -from collections.abc import Callable, Iterable, Mapping from typing import ( + TYPE_CHECKING, TypedDict, TypeVar, cast, @@ -72,11 +72,16 @@ ReductionOperation, SumReductionOperation, ) -from pytato.target.python import BoundPythonProgram, NumpyLikePythonTarget from pytato.transform import CachedMapper from pytato.utils import are_shape_components_equal, get_einsum_specification +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping + + from pytato.target.python import BoundPythonProgram, NumpyLikePythonTarget + + T = TypeVar("T") @@ -410,7 +415,7 @@ def _map_index_base(self, expr: IndexBase) -> str: default=-1, pred=lambda i: not (isinstance(expr.indices[i], NormalizedSlice) and _is_slice_trivial( - cast(NormalizedSlice, expr.indices[i]), + cast("NormalizedSlice", expr.indices[i]), expr.array.shape[i])) ) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 4c5d4a0f1..393583037 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -28,11 +28,12 @@ """ import dataclasses import logging -from collections.abc import Callable, Hashable, Iterable, Mapping from typing import ( + TYPE_CHECKING, Any, Generic, ParamSpec, + TypeAlias, TypeVar, cast, ) @@ -78,7 +79,11 @@ from pytato.tags import ImplStored -ArrayOrNames = Array | AbstractResultWithNamedArrays +if TYPE_CHECKING: + from collections.abc import Callable, Hashable, Iterable, Mapping + + +ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper @@ -131,6 +136,8 @@ Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. class:: ArrayOrNames + .. class:: MappedT A type variable representing the input type of a :class:`Mapper`. @@ -143,6 +150,9 @@ A type variable representing the result type of a :class:`Mapper`. +.. class:: Scalar + + See :data:`pymbolic.Scalar`. """ transform_logger = logging.getLogger(__file__) @@ -211,7 +221,7 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_foreign(expr, *args, **kwargs) assert method is not None - return cast(ResultT, method(expr, *args, **kwargs)) + return cast("ResultT", method(expr, *args, **kwargs)) def __call__(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: @@ -1572,7 +1582,7 @@ def map_and_copy(expr: MappedT, Uses :class:`CachedMapAndCopyMapper` under the hood and because of its caching nature each node is mapped exactly once. """ - return cast(MappedT, CachedMapAndCopyMapper(map_fn)(expr)) + return cast("MappedT", CachedMapAndCopyMapper(map_fn)(expr)) def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 74945a4d4..bc3d69909 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -29,7 +29,8 @@ THE SOFTWARE. """ -from collections.abc import Mapping + +from typing import TYPE_CHECKING from pytato.array import ( AbstractResultWithNamedArrays, @@ -42,6 +43,10 @@ from pytato.transform import ArrayOrNames, CopyMapper +if TYPE_CHECKING: + from collections.abc import Mapping + + # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 0d9a6076f..7a23518c6 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -35,14 +35,11 @@ import dataclasses -from collections.abc import Callable, Mapping -from typing import cast +from typing import TYPE_CHECKING, cast import numpy as np from immutabledict import immutabledict -from pytools.tag import Tag - from pytato.array import ( Array, AxesT, @@ -59,7 +56,6 @@ Roll, Stack, ) -from pytato.raising import HighLevelOp from pytato.transform import ( MappedT, TransformMapperWithExtraArgs, @@ -67,6 +63,14 @@ from pytato.utils import are_shapes_equal +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + from pytools.tag import Tag + + from pytato.raising import HighLevelOp + + class EinsumDistributiveLawDescriptor: """ Abstract-type that informs :func:`apply_distributive_property_to_einsums` @@ -358,4 +362,4 @@ def apply_distributive_property_to_einsums( True """ mapper = EinsumDistributiveLawMapper(how_to_distribute) - return cast(MappedT, mapper(expr, None)) + return cast("MappedT", mapper(expr, None)) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index dd8f6babd..e952e4e9e 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -31,7 +31,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, TypeVar, cast -import numpy as np from immutabledict import immutabledict import pymbolic.primitives as prim @@ -61,6 +60,10 @@ from pytato.transform import Mapper +if TYPE_CHECKING: + import numpy as np + + ToIndexLambdaT = TypeVar("ToIndexLambdaT", Array, AbstractResultWithNamedArrays) @@ -443,7 +446,7 @@ def map_contiguous_advanced_index(self, for i, idx_expr in enumerate(expr.indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast(Array | int | np.integer[Any], expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", expr.indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() @@ -511,7 +514,7 @@ def map_non_contiguous_advanced_index( for i, idx_expr in enumerate(expr.indices) if isinstance(idx_expr, (Array, *INT_CLASSES))) adv_idx_shape = get_shape_after_broadcasting([ - cast(Array | int | np.integer[Any], expr.indices[i_idx]) + cast("Array | int | np.integer[Any]", expr.indices[i_idx]) for i_idx in i_adv_indices]) vng = UniqueNameGenerator() @@ -622,7 +625,7 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: indices[to_index] = prim.Variable(f"_{from_index}") index_expr = prim.Variable("_in0")[ - cast(tuple[ArithmeticExpression], tuple(indices))] + cast("tuple[ArithmeticExpression]", tuple(indices))] return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 139b7bf5b..130db8b7a 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -40,7 +40,6 @@ import logging -from collections.abc import Collection, Mapping from typing import ( TYPE_CHECKING, Any, @@ -73,7 +72,6 @@ ) from pytato.diagnostic import UnknownIndexLambdaExpr from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder -from pytato.function import NamedCallResult from pytato.raising import ( BinaryOp, BroadcastOp, @@ -92,6 +90,9 @@ if TYPE_CHECKING: + from collections.abc import Collection, Mapping + + from pytato.function import NamedCallResult from pytato.loopy import LoopyCall @@ -416,7 +417,7 @@ def map_contiguous_advanced_index(self, for i_idx in i_basic_indices if i_idx > i_adv_indices[-1]]) - indirection_arrays: list[Array] = cast(list[Array], + indirection_arrays: list[Array] = cast("list[Array]", [expr.indices[i_idx] for i_idx in i_adv_indices if isinstance(expr.indices[i_idx], diff --git a/pytato/transform/remove_broadcasts_einsum.py b/pytato/transform/remove_broadcasts_einsum.py index a9997162a..8c7c224fb 100644 --- a/pytato/transform/remove_broadcasts_einsum.py +++ b/pytato/transform/remove_broadcasts_einsum.py @@ -97,6 +97,6 @@ def rewrite_einsums_with_no_broadcasts(expr: MappedT) -> MappedT: alter its value. """ mapper = EinsumWithNoBroadcastsRewriter() - return cast(MappedT, mapper(expr)) + return cast("MappedT", mapper(expr)) # vim:fdm=marker diff --git a/pytato/utils.py b/pytato/utils.py index 7fa60569e..aedd0dc05 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -22,8 +22,8 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from collections.abc import Callable, Iterable, Sequence from typing import ( + TYPE_CHECKING, Any, TypeVar, cast, @@ -36,7 +36,6 @@ import pymbolic.primitives as prim from pymbolic import ArithmeticExpression, Bool, Scalar from pytools import UniqueNameGenerator -from pytools.tag import Tag from pytato.array import ( AdvancedIndexInContiguousAxes, @@ -64,6 +63,12 @@ from pytato.transform import CachedMapper +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Sequence + + from pytools.tag import Tag + + __doc__ = """ Helper routines --------------- @@ -534,7 +539,7 @@ def _index_into( + (slice(None, None, None),) * (ary.ndim - len(indices) + 1) + indices[ellipsis_pos+1:]) - indices = cast(tuple[int, slice, "Array", None], indices) + indices = cast("tuple[int, slice, Array, None]", indices) # }}} diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index f505c900a..7e685aba9 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -29,7 +29,6 @@ import dataclasses import html -from collections.abc import Callable, Mapping from functools import partial from typing import ( TYPE_CHECKING, @@ -38,7 +37,6 @@ from pytools import UniqueNameGenerator from pytools.codegen import remove_common_indentation -from pytools.tag import Tag from pytato.array import ( AbstractResultWithNamedArrays, @@ -60,13 +58,17 @@ PartId, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.loopy import LoopyCall from pytato.tags import FunctionIdentifier from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + from pytools.tag import Tag + from pytato.distributed.nodes import DistributedSendRefHolder + from pytato.loopy import LoopyCall __doc__ = """ diff --git a/pytato/visualization/fancy_placeholder_data_flow.py b/pytato/visualization/fancy_placeholder_data_flow.py index f3ffe430b..3d06309fd 100644 --- a/pytato/visualization/fancy_placeholder_data_flow.py +++ b/pytato/visualization/fancy_placeholder_data_flow.py @@ -5,9 +5,8 @@ """ from __future__ import annotations -from collections.abc import Collection from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any from pytools import UniqueNameGenerator @@ -27,6 +26,10 @@ from pytato.transform import CachedMapper +if TYPE_CHECKING: + from collections.abc import Collection + + # {{{ Graph node colors PLACEHOLDER_COLOR = "lightgrey" diff --git a/test/testlib.py b/test/testlib.py index 02e76ad9a..36857197b 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -2,29 +2,34 @@ import operator import random -import types -from collections.abc import Callable, Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -import pyopencl as cl from pytools.tag import Tag import pytato as pt -from pytato.array import ( - Array, - AxisPermutation, - Concatenate, - DataWrapper, - Placeholder, - Reshape, - Roll, - Stack, -) from pytato.transform import Mapper +if TYPE_CHECKING: + import types + from collections.abc import Callable, Sequence + + import pyopencl as cl + + from pytato.array import ( + Array, + AxisPermutation, + Concatenate, + DataWrapper, + Placeholder, + Reshape, + Roll, + Stack, + ) + + # {{{ tools for comparison to numpy class NumpyBasedEvaluator(Mapper[Any, []]): @@ -296,7 +301,7 @@ def make_dws_placeholder(expr: pt.transform.ArrayOrNames else: return expr - dag = cast(pt.DictOfNamedArrays, + dag = cast("pt.DictOfNamedArrays", pt.transform.map_and_copy(dag, make_dws_placeholder)) return dag