Merge pull request #906 from gchq/fix/pyright2
Type annotation fixes to make coreax package pass Pyright
qh681248 authored Dec 24, 2024
2 parents f784720 + fa06f3b commit 4252551
Showing 5 changed files with 75 additions and 21 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -32,13 +32,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- `MMD.compute` no longer returns `nan`. (https://github.com/gchq/coreax/issues/855)
- Corrected an implementation error in `coreax.solvers.CaratheodoryRecombination`
- Corrected an implementation error in `coreax.solvers.CaratheodoryRecombination`,
which caused numerical instability when using either `CaratheodoryRecombination`
or `TreeRecombination` on GPU machines. (https://github.com/gchq/coreax/pull/874, see
also https://github.com/gchq/coreax/issues/852 and
https://github.com/gchq/coreax/issues/853)
- `KernelHerding.refine` correctly computes a refinement of an existing coreset. (https://github.com/gchq/coreax/issues/870)
- Pylint pre-commit hook is now configured as the Pylint docs recommend (https://github.com/gchq/coreax/pull/899)
- Pylint pre-commit hook is now configured as the Pylint docs recommend. (https://github.com/gchq/coreax/pull/899)
- Type annotations so that core coreax package passes Pyright. (https://github.com/gchq/coreax/pull/906)

### Changed

18 changes: 13 additions & 5 deletions coreax/coreset.py
@@ -51,7 +51,7 @@ class Coreset(eqx.Module, Generic[_Data]):
nodes: _Data
pre_coreset_data: _Data

def __init__(self, nodes: _Data, pre_coreset_data: _Data):
def __init__(self, nodes: _Data, pre_coreset_data: _Data) -> None:
"""Handle type conversion of ``nodes`` and ``pre_coreset_data``."""
if isinstance(nodes, Array):
self.nodes = as_data(nodes)
@@ -67,15 +67,15 @@ def __init__(self, nodes: _Data, pre_coreset_data: _Data):
else:
self.pre_coreset_data = pre_coreset_data

def __check_init__(self):
def __check_init__(self) -> None:
"""Check that coreset has fewer 'nodes' than the 'pre_coreset_data'."""
if len(self.nodes) > len(self.pre_coreset_data):
raise ValueError(
"'len(nodes)' cannot be greater than 'len(pre_coreset_data)' "
"by definition of a Coreset"
)

def __len__(self):
def __len__(self) -> int:
"""Return Coreset size/length."""
return len(self.nodes)

@@ -96,7 +96,7 @@ def compute_metric(
return metric.compute(self.pre_coreset_data, self.coreset, **metric_kwargs)


class Coresubset(Coreset[Data], Generic[_Data]):
class Coresubset(Coreset[_Data], Generic[_Data]):
r"""
Data structure for representing a coresubset.
@@ -121,9 +121,17 @@ class Coresubset(Coreset[Data], Generic[_Data]):
:param pre_coreset_data: The dataset :math:`X` used to construct the coreset.
"""

# Unlike on Coreset, contains indices of coreset rather than coreset itself
nodes: Data

def __init__(self, nodes: Data, pre_coreset_data: _Data):
"""Handle typing of ``nodes`` being a `Data` instance."""
super().__init__(nodes, pre_coreset_data)
# nodes type can't technically be cast to _Data but do so anyway to avoid a
# significant amount of boilerplate just for type checking
super().__init__(
nodes, # pyright: ignore [reportArgumentType]
pre_coreset_data,
)

@property
def coreset(self) -> _Data:
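The `pyright: ignore [reportArgumentType]` above exists because `Coresubset` always stores indices (`Data`) in a slot the generic base class types as `_Data`. A minimal, self-contained sketch of the same pattern, using hypothetical `Container`/`IndexContainer` classes rather than the real coreax types:

```python
# Minimal sketch (hypothetical classes, not coreax's API): a subclass stores a
# concrete type in a field the generic base types as T, so the call into the
# base initialiser is silenced with a narrowly scoped Pyright ignore.
from typing import Generic, TypeVar

T = TypeVar("T")


class Container(Generic[T]):
    def __init__(self, payload: T) -> None:
        self.payload = payload


class IndexContainer(Container[T], Generic[T]):
    def __init__(self, indices: list[int]) -> None:
        # list[int] is not T, but reusing the base initialiser avoids
        # boilerplate; the known mismatch is acknowledged explicitly.
        super().__init__(indices)  # pyright: ignore[reportArgumentType]
```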
5 changes: 4 additions & 1 deletion coreax/kernels/scalar_valued.py
@@ -774,7 +774,10 @@ class SteinKernel(UniCompositeKernel):
:math:`\mathbb{R}^d \to \mathbb{R}^d`
"""

score_function: Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]
score_function: Callable[
[Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int]],
Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]],
]

@override
def compute_elementwise(self, x, y):
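The widened `score_function` annotation above accepts scalars (`float`, `int`, or 0-d arrays) as well as `(n, d)` arrays. A hypothetical score function compatible with that shape contract, using only standard JAX calls — the score of a standard Gaussian, which is simply `-x`:

```python
import jax.numpy as jnp
from jax import Array


def gaussian_score(x) -> Array:
    """Hypothetical score function: gradient of log N(0, I), i.e. -x."""
    x_2d = jnp.atleast_2d(jnp.asarray(x))  # floats/ints/0-d arrays become (1, 1)
    return -x_2d
```

Under the annotation above, a function like this could be supplied as a `SteinKernel` `score_function` for both batched and scalar inputs.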
50 changes: 41 additions & 9 deletions coreax/score_matching.py
@@ -34,7 +34,7 @@
from abc import abstractmethod
from collections.abc import Callable, Sequence
from functools import partial
from typing import Union
from typing import Union, overload

import equinox as eqx
import numpy as np
@@ -52,6 +52,7 @@
from jaxtyping import Shaped
from optax import adamw
from tqdm import tqdm as LoudTQDM # noqa:N812
from typing_extensions import override

from coreax.kernels import ScalarValuedKernel, SquaredExponentialKernel, SteinKernel
from coreax.networks import ScoreNetwork, _LearningRateOptimiser, create_train_state
@@ -74,9 +75,30 @@ class ScoreMatching(eqx.Module):
"""

@abstractmethod
@overload
def match(
self, x: Union[Shaped[Array, " 1 1"], Shaped[Array, ""], float, int]
) -> Callable[
[Union[Shaped[Array, " 1 1"], Shaped[Array, ""], float, int]],
Shaped[Array, " 1 1"],
]: ...

@abstractmethod
@overload
def match( # pyright: ignore[reportOverlappingOverload]
self, x: Shaped[Array, " n d"]
) -> Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]:
) -> Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]: ...

@abstractmethod
def match(
self, x: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int]
) -> Union[
Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]],
Callable[
[Union[Shaped[Array, " 1 1"], Shaped[Array, ""], float, int]],
Shaped[Array, " 1 1"],
],
]:
r"""
Match some model score function to dataset :math:`X\in\mathbb{R}^{n \times d}`.
@@ -404,9 +426,8 @@ def loss(params):
state = state.apply_gradients(grads=grads)
return state, val

def match( # noqa: C901, PLR0912
self, x: Shaped[Array, " n d"]
) -> Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]:
@override
def match(self, x): # noqa: C901, PLR0912
r"""
Learn a sliced score matching function via :cite:`song2020ssm`.
@@ -523,9 +544,8 @@ def __init__(self, length_scale: float):
)
super().__init__()

def match(
self, x: Shaped[Array, " n d"]
) -> Callable[[Shaped[Array, " n d"]], Shaped[Array, " n d"]]:
@override
def match(self, x):
r"""
Learn a score function using kernel density estimation to model a distribution.
@@ -542,7 +562,19 @@ def match(
"""
kde_data = x

def score_function(x_: Shaped[Array, " n d"]) -> Shaped[Array, " n d"]:
@overload
def score_function(
x_: Union[Shaped[Array, " 1 1"], Shaped[Array, ""], float, int],
) -> Shaped[Array, " 1 1"]: ...

@overload
def score_function( # pyright: ignore[reportOverlappingOverload]
x_: Shaped[Array, " n d"],
) -> Shaped[Array, " n d"]: ...

def score_function(
x_: Union[Shaped[Array, " n d"], Shaped[Array, ""], float, int],
) -> Union[Shaped[Array, " n d"], Shaped[Array, " 1 1"]]:
r"""
Compute the score function using a kernel density estimation.
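The `@overload`/`@override` combination above lets the abstract `match` (and the inner `score_function`) advertise precise input/output shapes while the concrete implementations stay unannotated. A self-contained sketch of that pattern, assuming hypothetical `Matcher`/`Doubler` classes rather than coreax's own:

```python
from abc import ABC, abstractmethod
from typing import Union, overload

from jax import Array
from typing_extensions import override


class Matcher(ABC):  # hypothetical stand-in, not coreax's ScoreMatching
    @abstractmethod
    @overload
    def match(self, x: float) -> float: ...

    @abstractmethod
    @overload
    def match(self, x: Array) -> Array: ...

    @abstractmethod
    def match(self, x: Union[float, Array]) -> Union[float, Array]:
        """The overloads above give callers exact return types per input type."""


class Doubler(Matcher):
    @override
    def match(self, x):  # callers still see the base class's overloads
        return x * 2
```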
18 changes: 14 additions & 4 deletions coreax/solvers/composite.py
@@ -16,15 +16,15 @@

import math
import warnings
from typing import Generic, Optional, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union, overload

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jax import Array
from jax.typing import ArrayLike
from jaxtyping import Integer
from sklearn.neighbors import BallTree, KDTree
from typing_extensions import TypeAlias, override

@@ -37,7 +37,7 @@
_Data = TypeVar("_Data", bound=Data)
_Coreset = TypeVar("_Coreset", Coreset, Coresubset)
_State = TypeVar("_State")
_Indices = TypeVar("_Indices", ArrayLike, None)
_Indices = Integer[Array, "..."]


class CompositeSolver(
@@ -128,9 +128,19 @@ def reduce(
# There is no obvious way to use state information here.
del solver_state

@overload
def _reduce_coreset(
data: _Data, _indices: _Indices
) -> tuple[_Coreset, _State, _Indices]: ...

@overload
def _reduce_coreset(
data: _Data, _indices: Optional[_Indices] = None
) -> tuple[_Coreset, _State, Optional[_Indices]]: ...

def _reduce_coreset(
data: _Data, _indices: Optional[_Indices] = None
) -> tuple[_Coreset, _State, _Indices]:
) -> tuple[_Coreset, _State, Optional[_Indices]]:
if len(data) <= self.leaf_size:
coreset, state = self.base_solver.reduce(data)
if _indices is not None:
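The `_reduce_coreset` overloads above track whether `_indices` can be `None`: calls that pass indices get a non-`Optional` result back, while the defaulted call keeps the `Optional`. A hypothetical standalone helper, `shift_indices`, sketching the same `None`-propagating overload pattern together with the `Integer[Array, "..."]` alias:

```python
from typing import Optional, overload

import jax.numpy as jnp
from jax import Array
from jaxtyping import Integer

_Indices = Integer[Array, "..."]  # same alias style as introduced above


@overload
def shift_indices(indices: _Indices) -> _Indices: ...
@overload
def shift_indices(indices: None) -> None: ...
def shift_indices(indices: Optional[_Indices]) -> Optional[_Indices]:
    """Hypothetical helper: offset indices by one, passing None straight through."""
    if indices is None:
        return None
    return indices + 1


# Pyright can now tell these apart: the first result is an Array, the second is None.
shifted = shift_indices(jnp.arange(3))
nothing = shift_indices(None)
```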
