Skip to content

Commit

Permalink
Merge pull request #878 from gchq/feat/jaxtyping_for_data
Browse files Browse the repository at this point in the history
feat: `jaxtyping` support for `coreax.data.Data`
  • Loading branch information
rg936672 authored Dec 19, 2024
2 parents 39e9ef1 + 6efb0c8 commit c7ef0f7
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 169 deletions.
3 changes: 3 additions & 0 deletions .cspell/library_terms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ autodoc
automodule
autosectionlabel
autoupdate
beartype
bibfiles
bibtex
bmatrix
Expand Down Expand Up @@ -57,6 +58,7 @@ jacrev
jax
jaxlib
jaxopt
jaxtyped
jaxtyping
jumanjihouse
keepends
Expand Down Expand Up @@ -139,6 +141,7 @@ tqdm
triu
ttest
tuplegetter
typeguard
typehints
umap
undoc
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
(https://github.com/gchq/coreax/pull/887)
- **[BREAKING CHANGE]** Equinox dependency version is changed from `<0.11.8` to `>=0.
11.5`. (https://github.com/gchq/coreax/pull/898)
- **[BREAKING CHANGE]** The `jaxtyping` version is now lower bounded at `v0.2.31` to enable `coreax.data.Data` jaxtyping compatibility.

### Removed

Expand All @@ -62,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`tests.unit.test_solvers`. (https://github.com/gchq/coreax/pull/822)
- Added a unit test for RPCholesky to check whether the coreset has duplicates.
(https://github.com/gchq/coreax/pull/836)
- Enabled `jaxtyping` compatible type hinting for `coreax.data.Data`, to indicate the
expected type and shape of a `Data` objects `Data.data` array attribute. For example
`Bool[Data, "n d"]` indicates `Data.data` should be an `n d` array of bools.

### Fixed

Expand Down
17 changes: 17 additions & 0 deletions coreax/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ class Data(eqx.Module):
`n`-vector inputs for `data` are interpreted as `n` points in 1-dimension and
converted to a `(n, 1)` array.
Compatible with :func:`jaxtyping.jaxtyped` -- :class:`Data` is interpreted as an
array type, whose shape is the expected shape of :attr:`Data.data`.
.. note::
A `Data` object whose :attr:`Data.data` is expected to be a floating point array
with shape `a b`, can be type hinted as `x: Float[Data, " a b"] = ...`.
:param data: An :math:`n \times d` array defining the features of the unsupervised
dataset
:param weights: An :math:`n`-vector of weights where each element of the weights
Expand Down Expand Up @@ -164,6 +171,16 @@ def __len__(self) -> int:
"""Return data length."""
return len(self.data)

@property
def dtype(self):
"""Return dtype of data; used for jaxtyping annotations."""
return self.data.dtype

@property
def shape(self):
"""Return shape of data; used for jaxtyping annotations."""
return self.data.shape

def normalize(self, *, preserve_zeros: bool = False) -> Self:
"""
Return a copy of ``self`` with ``weights`` that sum to one.
Expand Down
1 change: 1 addition & 0 deletions documentation/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
"python": ("https://docs.python.org/3", None),
"jax": ("https://jax.readthedocs.io/en/latest", None),
"jaxopt": ("https://jaxopt.github.io/stable", None),
"jaxtyping": ("https://docs.kidger.site/jaxtyping", None),
"flax": ("https://flax-linen.readthedocs.io/en/latest", None),
"optax": ("https://optax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"flax",
"jax",
"jaxopt",
"jaxtyping",
"jaxtyping>0.2.31",
"optax",
"scikit-learn",
"tqdm",
Expand All @@ -48,6 +48,7 @@ benchmark = [
# Run unit tests with coverage assessment
test = [
"coreax[benchmark]",
"beartype",
"imageio",
"matplotlib",
"numpy",
Expand Down
10 changes: 5 additions & 5 deletions requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ alabaster==0.7.16 ; python_full_version < '3.10'
alabaster==1.0.0 ; python_full_version >= '3.10'
apeye==1.4.1
apeye-core==1.1.5
attrs==24.2.0
attrs==24.3.0
autodocsumm==0.2.14
babel==2.16.0
beautifulsoup4==4.12.3
cachecontrol==0.14.1
certifi==2024.12.14
charset-normalizer==3.4.0
chex==0.1.88
colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32'
colorama==0.4.6 ; (platform_system != 'Darwin' and platform_system != 'Linux' and sys_platform == 'win32') or (platform_system == 'Windows' and sys_platform != 'win32')
cssutils==2.11.1
dict2css==0.3.0.post1
docutils==0.21.2
Expand All @@ -33,9 +33,9 @@ imagesize==1.4.1
importlib-metadata==8.5.0 ; python_full_version < '3.10'
importlib-resources==6.4.5
jax==0.4.30 ; python_full_version < '3.10'
jax==0.4.37 ; python_full_version >= '3.10'
jax==0.4.38 ; python_full_version >= '3.10'
jaxlib==0.4.30 ; python_full_version < '3.10'
jaxlib==0.4.36 ; python_full_version >= '3.10'
jaxlib==0.4.38 ; python_full_version >= '3.10'
jaxopt==0.8.3
jaxtyping==0.2.36
jinja2==3.1.4
Expand All @@ -58,7 +58,7 @@ orbax-checkpoint==0.6.4 ; python_full_version < '3.10'
orbax-checkpoint==0.10.2 ; python_full_version >= '3.10'
packaging==24.2
platformdirs==4.3.6
protobuf==5.29.1
protobuf==5.29.2
pybtex==0.24.0
pybtex-docutils==1.0.3
pygments==2.18.0
Expand Down
55 changes: 51 additions & 4 deletions tests/unit/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
produce the expected results on simple examples.
"""

from functools import partial

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
from beartype.door import is_bearable
from jax import Array
from jaxtyping import Float, Int, Shaped

import coreax.data

Expand Down Expand Up @@ -99,8 +99,11 @@ def test_atleast_2d_consistent(arrays: tuple[Array]) -> None:
@pytest.mark.parametrize(
"data_type",
[
partial(coreax.data.Data, DATA_ARRAY),
partial(coreax.data.SupervisedData, DATA_ARRAY, SUPERVISION),
pytest.param(jtu.Partial(coreax.data.Data, DATA_ARRAY), id="Data"),
pytest.param(
jtu.Partial(coreax.data.SupervisedData, DATA_ARRAY, SUPERVISION),
id="SupervisedData",
),
],
)
class TestData:
Expand Down Expand Up @@ -147,6 +150,16 @@ def test_len(self, data_type):
_data = data_type()
assert len(_data) == len(_data.data)

def test_dtype(self, data_type):
"""Test dtype property; required for jaxtyping annotations."""
_data = data_type()
assert _data.data.dtype == _data.dtype

def test_shape(self, data_type):
"""Test shape property; required for jaxtyping annotations."""
_data = data_type()
assert _data.data.shape == _data.shape

@pytest.mark.parametrize("weights", (None, 0, 3, DATA_ARRAY.reshape(-1)))
def test_normalize(self, data_type, weights):
"""Test weight normalization."""
Expand All @@ -160,6 +173,40 @@ def test_normalize(self, data_type, weights):
normalized_with_zeros.weights, jnp.nan_to_num(expected_weights)
)

@pytest.mark.parametrize(
"dtype, valid_jax_type, invalid_jax_type",
[(jnp.int32, Int, Float), (jnp.float32, Float, Int)],
)
def test_jaxtyping_compatibility(
self, data_type, dtype, valid_jax_type, invalid_jax_type
):
"""
Test `Data` compatibility with jaxtyping annotations.
Checks the following cases:
- Correct narrowed shape,
- Correct narrowed shape and narrowed data type,
- Correct narrowed shape and incorrect narrowed data type,
- Incorrect narrowed shape
- Incorrectly narrowed instance type
"""
data_factory = eqx.tree_at(
lambda x: x.args,
data_type,
replace=jtu.tree_map(lambda y: jnp.astype(y, dtype), data_type.args),
)
data = data_factory()
valid_shape = " ".join(str(dim) for dim in data.shape)
invalid_shape = " ".join(str(dim + 1) for dim in data.shape)

assert is_bearable(data, Shaped[coreax.data.Data, valid_shape])
assert is_bearable(data, valid_jax_type[coreax.data.Data, valid_shape])
assert not is_bearable(data, invalid_jax_type[coreax.data.Data, invalid_shape])
assert not is_bearable(data, Shaped[coreax.data.Data, invalid_shape])
if not isinstance(data, coreax.data.SupervisedData):
incorrect_instance_type = Shaped[coreax.data.SupervisedData, "..."]
assert not is_bearable(data, incorrect_instance_type)


class TestSupervisedData:
"""Test operation of SupervisedData class."""
Expand Down
Loading

0 comments on commit c7ef0f7

Please sign in to comment.