Skip to content

Commit

Permalink
Merge pull request #901 from gchq/fix/pyright_pass1
Browse files Browse the repository at this point in the history
Typing fixes
  • Loading branch information
rg936672 authored Dec 20, 2024
2 parents db7cae0 + f171eef commit c747ca0
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 15 deletions.
45 changes: 35 additions & 10 deletions coreax/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Data-structures for representing weighted and/or supervised data."""

from typing import List, Optional, Sequence, Union, overload
from typing import Optional, Sequence, Union, overload

import equinox as eqx
import jax.numpy as jnp
Expand All @@ -26,8 +26,14 @@

@overload
def _atleast_2d_consistent(
arrays: Sequence[Shaped[Array, ""]],
) -> List[Shaped[Array, " 1 1"]]: ...
arrays: Union[Shaped[Array, ""], float, int],
) -> Shaped[Array, " 1 1"]: ...


@overload
def _atleast_2d_consistent(
arrays: Sequence[Union[Shaped[Array, ""], float, int]],
) -> list[Shaped[Array, " 1 1"]]: ...


@overload
Expand All @@ -45,9 +51,15 @@ def _atleast_2d_consistent( # pyright:ignore[reportOverlappingOverload]
@overload
def _atleast_2d_consistent( # pyright:ignore[reportOverlappingOverload]
arrays: Sequence[
Union[Shaped[Array, " _n _d _*p"], Shaped[Array, " _n"], Shaped[Array, ""]]
Union[
Shaped[Array, " _n _d _*p"],
Shaped[Array, " _n"],
Shaped[Array, ""],
float,
int,
]
],
) -> List[
) -> list[
Union[Shaped[Array, " _n _d _*p"], Shaped[Array, " _n 1"], Shaped[Array, " 1 1"]]
]: ...

Expand All @@ -58,25 +70,38 @@ def _atleast_2d_consistent( # pyright:ignore[reportOverlappingOverload]
Shaped[Array, " n d *p"],
Shaped[Array, " n"],
Shaped[Array, ""],
float,
int,
Sequence[
Union[Shaped[Array, " _n _d _*p"], Shaped[Array, " _n"], Shaped[Array, ""]]
Union[
Shaped[Array, " _n _d _*p"],
Shaped[Array, " _n"],
Shaped[Array, ""],
float,
int,
]
],
],
) -> Union[
Shaped[Array, " n d *p"],
Shaped[Array, " n 1"],
Shaped[Array, " 1 1"],
List[Union[Shaped[Array, " _n _d _*p"], Shaped[Array, " _n"], Shaped[Array, ""]]],
list[Union[Shaped[Array, " _n _d _*p"], Shaped[Array, " _n"], Shaped[Array, ""]]],
]:
r"""
Given an array or sequence of arrays ensure they are at least 2-dimensional.
Float and integer types are cast as zero-dimensional input arrays, giving 1x1
output.
.. note::
This function differs from `jax.numpy.atleast_2d` in that it converts
1-dimensional `n`-vectors into arrays of shape `(n, 1)` rather than `(1, n)`.
This function differs from :func:`jax.numpy.atleast_2d` in that it converts
1-dimensional ``n``-vectors into arrays of shape ``(n, 1)`` rather than
``(1, n)``.
:param arrays: Singular array or sequence of arrays
:return: at least 2-dimensional array or list of at least 2-dimensional arrays
:return: At least 2-dimensional array or list of at least 2-dimensional arrays
"""
# If we have been given just one array, return as an array, not list
if len(arrays) == 1:
Expand Down
3 changes: 2 additions & 1 deletion coreax/solvers/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
import jax.tree_util as jtu
import numpy as np
from jax import Array
from jax.typing import ArrayLike
from sklearn.neighbors import BallTree, KDTree
from typing_extensions import TypeAlias, override

from coreax.coreset import Coreset, Coresubset
from coreax.data import Data
from coreax.solvers.base import ExplicitSizeSolver, PaddingInvariantSolver, Solver
from coreax.util import ArrayLike, tree_zero_pad_leading_axis
from coreax.util import tree_zero_pad_leading_axis

BinaryTree: TypeAlias = Union[KDTree, BallTree]
_Data = TypeVar("_Data", bound=Data)
Expand Down
2 changes: 1 addition & 1 deletion coreax/solvers/coresubset.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class SteinThinning(
score_matching: Optional[ScoreMatching] = None
unique: bool = True
regularise: bool = True
regulariser_lambda: float = None
regulariser_lambda: Optional[float] = None
block_size: Optional[Union[int, tuple[Optional[int], Optional[int]]]] = None
unroll: Union[int, bool, tuple[Union[int, bool], Union[int, bool]]] = 1

Expand Down
4 changes: 2 additions & 2 deletions coreax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class factories and checks for numerical precision.
import jax.random as jr
import jax.tree_util as jtu
from jax import Array, block_until_ready, jit, vmap
from jax.typing import ArrayLike
from jaxtyping import Shaped
from typing_extensions import TypeAlias, deprecated

Expand All @@ -56,7 +55,8 @@ class factories and checks for numerical precision.

#: JAX random key type annotations.
KeyArray: TypeAlias = Array
KeyArrayLike: TypeAlias = ArrayLike
# jax.random functions crash if passed a scalar, so can't use ArrayLike
KeyArrayLike: TypeAlias = Array


class NotCalculatedError(Exception):
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_as_supervised_data():
(jnp.array([[1], [1]]), jnp.array([[1], [1]])),
(jnp.array([[[1]], [[1]]]),),
(jnp.array([[[1]], [[1]]]), jnp.array([[[1]], [[1]]])),
(1.0,),
(1.0, 1.0),
(1,),
(1, 1),
],
ids=[
"single_zero_dimensional_array",
Expand All @@ -68,6 +72,10 @@ def test_as_supervised_data():
"multiple_two_dimensional_arrays",
"single_three_dimensional_array",
"multiple_three_dimensional_arrays",
"single_float",
"multiple_floats",
"single_int",
"multiple_ints",
],
)
def test_atleast_2d_consistent(arrays: tuple[Array]) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import numpy as np
import pytest
from jax import Array, jacfwd, vmap
from jax.typing import ArrayLike

from coreax.data import Data
from coreax.kernels import (
Expand All @@ -40,7 +41,7 @@
)
from coreax.metrics import KSD, MMD
from coreax.score_matching import convert_stein_kernel
from coreax.util import ArrayLike, pairwise
from coreax.util import pairwise


class _MetricProblem(NamedTuple):
Expand Down

0 comments on commit c747ca0

Please sign in to comment.