Update Black and JAX #146

Merged: 3 commits, Feb 28, 2024
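
In brief: the PR bumps Black from 23.3.0 to 24.2.0, whose stable style inserts a blank line after module docstrings (the bulk of the one-line additions below); it migrates PRNG-key type annotations from the deprecated `jax.random.PRNGKeyArray` to `jax.Array`; and it rewrites one numerical test to use `np.unique` on a smaller sample.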

Changes from all commits
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
- id: isort
name: isort (python)
- repo: https://github.com/psf/black
-  rev: 23.3.0
+  rev: 24.2.0
hooks:
- id: black
name: black (python)

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ tensorflow-probability = {extras = ["jax"], version = "^0.20.1"}
optional = true

[tool.poetry.group.dev.dependencies]
-black = "^23.3.0"
+black = "24.2.0"
flake8 = "^6.0.0"
isort = "^5.12.0"
pre-commit = "^3.2.2"
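
Why so many files change: Black 24 adopted the 2024 stable style, which enforces a single blank line between a module docstring and the code that follows, and that one rule produces the one-line additions in nearly every file below. The dev dependency is also pinned exactly ("24.2.0", no caret), presumably to keep the Poetry environment in lockstep with the pre-commit hook above. A minimal before/after sketch on a hypothetical module (not a file from this repo):

```python
# Before: Black 23.x accepted an import flush against the module docstring.
"""Frobnicate widgets."""
import argparse

# After: Black 24.x enforces a separating blank line.
"""Frobnicate widgets."""

import argparse
```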

1 change: 1 addition & 0 deletions scripts/Beyond_Normal/experimental/ksg_spiral.py
@@ -1,4 +1,5 @@
"""Tests whether KSG estimator is invariant to the "spiral" diffeomorphism."""

import argparse

import numpy as np

1 change: 1 addition & 0 deletions scripts/Beyond_Normal/experimental/ksg_time_comparison.py
@@ -1,4 +1,5 @@
"""Simple comparison of the speed of our KSG estimators."""

import argparse
import time


1 change: 1 addition & 0 deletions scripts/Beyond_Normal/figures/concept/1v1.py
@@ -1,4 +1,5 @@
"""Main figure for the paper, like a visual abstract."""

from typing import Optional

import jax

1 change: 1 addition & 0 deletions scripts/Beyond_Normal/figures/student/correction.py
@@ -7,6 +7,7 @@
This script plots the correction term as a function
of the degrees of freedom for several dimensions.
"""

from typing import cast

import matplotlib.pyplot as plt

1 change: 1 addition & 0 deletions scripts/Beyond_Normal/figures/visualise_distributions.py
@@ -1,4 +1,5 @@
"""Visualisation of several samplers used."""

import matplotlib.pyplot as plt

import bmi.benchmark.task_list as tl

1 change: 1 addition & 0 deletions scripts/Mixtures/plot_appearing_and_vanishing_mi.py
@@ -1,4 +1,5 @@
"""Appearing and vanishing mutual information illustration."""

import matplotlib.pyplot as plt
import numpy as np


1 change: 1 addition & 0 deletions scripts/Mixtures/validate_chi2_combination.py
@@ -1,5 +1,6 @@
"""This script is used to validate the chi2 combination for
the multivariate normal PMI profile."""

import jax
import matplotlib.pyplot as plt
import numpy as np

1 change: 1 addition & 0 deletions src/bmi/__init__.py
@@ -1,4 +1,5 @@
"""Benchmarking mutual information package."""

import bmi.benchmark as benchmark
import bmi.estimators as estimators
import bmi.samplers as samplers

1 change: 1 addition & 0 deletions src/bmi/benchmark/utils/timer.py
@@ -1,4 +1,5 @@
"""Creates a Timer class, a convenient thing to measure the elapsed time."""

import time



1 change: 1 addition & 0 deletions src/bmi/estimators/_histogram.py
@@ -3,6 +3,7 @@
Note: if a bin p(x, y) has zero counts, we assign zero contribution from it to the MI:
MI \\approx \\sum p(x, y) \\log( p(x, y) / p(x)p(y) )
"""

from itertools import product
from typing import Optional


1 change: 1 addition & 0 deletions src/bmi/estimators/_kde.py
@@ -1,5 +1,6 @@
"""Estimation of mutual information via kernel density estimation
of the differential entropy."""

from typing import Literal, Optional, Union

import numpy as np

1 change: 1 addition & 0 deletions src/bmi/estimators/ksg.py
@@ -1,4 +1,5 @@
"""Kraskov estimators."""

from typing import Literal, Optional, Sequence, cast

import numpy as np

1 change: 1 addition & 0 deletions src/bmi/estimators/neural/_backend_linear_memory.py
@@ -3,6 +3,7 @@

In particular, we can use larger batches.
"""

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

1 change: 1 addition & 0 deletions src/bmi/estimators/neural/_backend_quadratic_memory.py
@@ -5,6 +5,7 @@
They require quadratic, that is O(batch size ** 2), memory
so they cannot be used with large batches.
"""

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

7 changes: 3 additions & 4 deletions src/bmi/estimators/neural/_basic_training.py
@@ -1,4 +1,5 @@
"""Basic training loop used for most neural estimators."""

from typing import Callable, Optional

import equinox as eqx
@@ -9,9 +10,7 @@
from bmi.estimators.neural._types import BatchedPoints, Critic, Point


-def get_batch(
-xs: BatchedPoints, ys: BatchedPoints, key: jax.random.PRNGKeyArray, batch_size: Optional[int]
-):
+def get_batch(xs: BatchedPoints, ys: BatchedPoints, key: jax.Array, batch_size: Optional[int]):
if batch_size is not None:
batch_indices = jax.random.choice(
key,
@@ -25,7 +24,7 @@ def get_batch(


def basic_training(
-rng: jax.random.PRNGKeyArray,
+rng: jax.Array,
critic: eqx.Module,
mi_formula: Callable[[Critic, Point, Point], float],
xs: BatchedPoints,
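
The signature changes above are the JAX half of this PR and repeat through the remaining files: `jax.random.PRNGKeyArray` was deprecated upstream (and removed in recent releases) in favor of annotating PRNG keys as plain `jax.Array` values, since keys are themselves ordinary arrays. A minimal sketch of the migration, using a hypothetical function rather than repo code:

```python
import jax
import jax.numpy as jnp


# Old annotation, no longer available in recent JAX releases:
#   def shuffle(key: jax.random.PRNGKeyArray, xs: jnp.ndarray) -> jnp.ndarray: ...
# New annotation: a key returned by jax.random.PRNGKey is an ordinary jax.Array.
def shuffle(key: jax.Array, xs: jnp.ndarray) -> jnp.ndarray:
    return jax.random.permutation(key, xs)


key = jax.random.PRNGKey(0)
print(shuffle(key, jnp.arange(5)))
```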
3 changes: 2 additions & 1 deletion src/bmi/estimators/neural/_critics.py
@@ -1,4 +1,5 @@
"""Module with neural networks used as critics."""

from typing import Sequence

import equinox as eqx
@@ -16,7 +17,7 @@ class MLP(eqx.Module):

def __init__(
self,
-key: jax.random.PRNGKeyArray,
+key: jax.Array,
dim_x: int,
dim_y: int,
hidden_layers: Sequence[int] = (5,),

5 changes: 3 additions & 2 deletions src/bmi/estimators/neural/_estimators.py
@@ -1,4 +1,5 @@
"""API of the neural estimators implemented in JAX."""

from typing import Any, Callable, Literal, Optional, Sequence

import equinox as eqx
@@ -43,14 +44,14 @@ def train_test_split(
xs: BatchedPoints,
ys: BatchedPoints,
train_size: Optional[float],
-key: jax.random.PRNGKeyArray,
+key: jax.Array,
) -> tuple[BatchedPoints, BatchedPoints, BatchedPoints, BatchedPoints]:
if train_size is None:
return xs, xs, ys, ys

else:
# get random int from jax key
-random_state = int(jax.random.randint(key, (1,), 0, 1000))
+random_state = int(jax.random.randint(key, shape=(), minval=0, maxval=1000))

xs_train, xs_test, ys_train, ys_test = msel.train_test_split(
xs,
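
The `randint` change above is small but deliberate: the old call drew a shape-(1,) array positionally and relied on `int()` to squeeze it, while the new call names its arguments and draws a scalar directly. A standalone sketch of the difference (not repo code):

```python
import jax

key = jax.random.PRNGKey(42)

# Old style: positional arguments, a shape-(1,) draw that int() must squeeze.
old = jax.random.randint(key, (1,), 0, 1000)
print(old.shape)  # (1,)

# New style: keyword arguments, a genuinely scalar draw.
new = jax.random.randint(key, shape=(), minval=0, maxval=1000)
print(new.shape)  # ()
print(int(new))   # a plain Python int in [0, 1000)
```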
7 changes: 4 additions & 3 deletions src/bmi/estimators/neural/_mine_estimator.py
@@ -7,6 +7,7 @@
The expression for the gradient
is given by Equation (12) in Section 3.2.
"""

from typing import Optional, Sequence

import equinox as eqx
@@ -100,7 +101,7 @@ def _mine_value_neg_grad_log_denom(


def _sample_paired_unpaired(
-key: jax.random.PRNGKeyArray,
+key: jax.Array,
xs: BatchedPoints,
ys: BatchedPoints,
batch_size: Optional[int],
@@ -132,7 +133,7 @@ def _sample_paired_unpaired(


def mine_training(
-rng: jax.random.PRNGKeyArray,
+rng: jax.Array,
critic: eqx.Module,
xs: BatchedPoints,
ys: BatchedPoints,
@@ -312,7 +313,7 @@ def trained_critic(self) -> Optional[eqx.Module]:
def parameters(self) -> MINEParams:
return self._params

-def _create_critic(self, dim_x: int, dim_y: int, key: jax.random.PRNGKeyArray) -> MLP:
+def _create_critic(self, dim_x: int, dim_y: int, key: jax.Array) -> MLP:
return MLP(dim_x=dim_x, dim_y=dim_y, key=key, hidden_layers=self._params.hidden_layers)

def estimate_with_info(self, x: ArrayLike, y: ArrayLike) -> EstimateResult:

1 change: 1 addition & 0 deletions src/bmi/estimators/neural/_training_log.py
@@ -1,4 +1,5 @@
"""Utility class for keeping information about training and displaying tqdm."""

from typing import Union

import jax

3 changes: 2 additions & 1 deletion src/bmi/interface.py
@@ -5,6 +5,7 @@
This restriction is to ensure that any subpackage can import from
the `interface` module and that we do not run into the circular imports issue.
"""

import pathlib
from abc import abstractmethod
from typing import Any, Optional, Protocol, Union
@@ -25,7 +26,7 @@ class BaseModel(pydantic.BaseModel):  # pytype: disable=invalid-annotation
pass


-# This should be updated to the PRNGKeyArray (or possibly union with Any)
+# This should be updated to the Array (or possibly union with Any)
# when it becomes a part of public JAX API
KeyArray = Any
Pathlike = Union[str, pathlib.Path]

1 change: 1 addition & 0 deletions src/bmi/plot_utils/subplots_from_axsize.py
@@ -1,4 +1,5 @@
"""Allows creating """

from collections.abc import Iterable
from typing import Union


1 change: 1 addition & 0 deletions src/bmi/samplers/__init__.py
@@ -1,5 +1,6 @@
"""Subpackage with different samplers, which can be used to define benchmark tasks
or used directly to sample from distributions with known mutual information."""

from bmi.samplers._additive_uniform import AdditiveUniformSampler

# isort: off

1 change: 1 addition & 0 deletions src/bmi/samplers/_matrix_utils.py
@@ -1,4 +1,5 @@
"""Utilities for creating dispersion matrices."""

import dataclasses
from typing import Optional


8 changes: 3 additions & 5 deletions src/bmi/samplers/_tfp/_core.py
@@ -32,9 +32,7 @@ class JointDistribution:
dim_y: int
analytic_mi: Optional[float] = None

-def sample(
-self, n_points: int, key: jax.random.PRNGKeyArray
-) -> tuple[jnp.ndarray, jnp.ndarray]:
+def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Sample from the joint distribution $P_{XY}$.

Args:
@@ -152,7 +150,7 @@ def transform(
)


-def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -> jnp.ndarray:
+def pmi_profile(key: jax.Array, dist: JointDistribution, n: int) -> jnp.ndarray:
"""Monte Carlo draws a sample of size `n` from the PMI distribution.

Args:
@@ -168,7 +166,7 @@ def pmi_profile(key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int) -


def monte_carlo_mi_estimate(
-key: jax.random.PRNGKeyArray, dist: JointDistribution, n: int
+key: jax.Array, dist: JointDistribution, n: int
) -> tuple[float, float]:
"""Estimates the mutual information $I(X; Y)$ using Monte Carlo sampling.

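For context on the two functions whose signatures change here: a PMI profile is a Monte Carlo sample of log p(x, y) - log p(x) - log p(y), and its mean (with a standard error) gives the MI estimate. A self-contained sketch of that pattern for a correlated bivariate normal; this is illustrative only and does not use the repo's JointDistribution API:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

rho = 0.8
cov = jnp.array([[1.0, rho], [rho, 1.0]])


def pmi(x, y):
    # Pointwise mutual information: log p(x, y) - log p(x) - log p(y).
    log_joint = multivariate_normal.logpdf(jnp.stack([x, y]), jnp.zeros(2), cov)
    return log_joint - norm.logpdf(x) - norm.logpdf(y)


key = jax.random.PRNGKey(0)
xy = jax.random.multivariate_normal(key, jnp.zeros(2), cov, shape=(10_000,))
profile = jax.vmap(pmi)(xy[:, 0], xy[:, 1])
mi_estimate = profile.mean()
standard_error = profile.std() / jnp.sqrt(profile.shape[0])
# For this distribution the exact MI is -0.5 * log(1 - rho**2) ≈ 0.51.
print(mi_estimate, standard_error)
```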
1 change: 1 addition & 0 deletions src/bmi/samplers/_tfp/_wrapper.py
@@ -1,4 +1,5 @@
"""A wrapper from TFP distributions to BMI samplers."""

from typing import Optional, Union

import jax

1 change: 1 addition & 0 deletions src/bmi/samplers/base.py
@@ -1,4 +1,5 @@
"""Partial implementation of the ISampler interface, convenient to inherit from."""

from typing import Union

import numpy as np

1 change: 1 addition & 0 deletions src/bmi/utils.py
@@ -5,6 +5,7 @@
When we refactor the code and observe patterns,
the methods implemented here will move.
"""

from typing import Generator, Union

import numpy as np

7 changes: 4 additions & 3 deletions tests/benchmark/tasks/test_bimodal_gaussians.py
@@ -1,9 +1,10 @@
+import numpy as np
import pytest

import bmi.benchmark.tasks.bimodal_gaussians as bimodal_gaussians


-@pytest.mark.parametrize("n_samples", [10_000])
+@pytest.mark.parametrize("n_samples", [1_000])
@pytest.mark.parametrize("seed", [0])
def test_numerical_inversion(n_samples, seed):
task = bimodal_gaussians.task_bimodal_gaussians()
@@ -14,5 +15,5 @@ def test_numerical_inversion(n_samples, seed):
print(samples_x.max())

# there should be almost no repeats
-assert len(set(map(float, samples_x))) > 0.999 * n_samples
-assert len(set(map(float, samples_y))) > 0.999 * n_samples
+assert np.unique(samples_x).shape[0] > 0.99 * n_samples
+assert np.unique(samples_y).shape[0] > 0.99 * n_samples
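
The rewritten test trades the Python-level `set(map(float, ...))` for a vectorized `np.unique` pass and relaxes the check (1,000 samples, 99% unique) so it stays cheap. A self-contained sketch of the same uniqueness check on hypothetical data:

```python
import numpy as np

samples = np.random.default_rng(0).normal(size=1_000)
# Counts distinct values in one vectorized pass; equivalent in spirit to
# len(set(map(float, samples))) without building a Python float per element.
n_unique = np.unique(samples).shape[0]
assert n_unique > 0.99 * samples.shape[0]
```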
1 change: 1 addition & 0 deletions tests/estimators/neural/test_backend.py
@@ -1,4 +1,5 @@
"""Tests of different backends."""

from typing import Tuple

import jax.numpy as jnp

1 change: 1 addition & 0 deletions tests/estimators/neural/test_neural.py
@@ -1,4 +1,5 @@
"""Tests of the neural estimators on simple distributions."""

import equinox as eqx
import numpy as np
import pytest

1 change: 1 addition & 0 deletions tests/estimators/test_kde.py
@@ -1,4 +1,5 @@
"""Tests of the kernel density estimator."""

import numpy as np
import pytest


1 change: 1 addition & 0 deletions tests/test_api.py
@@ -1,4 +1,5 @@
"""The tests for public API."""

import pytest



1 change: 1 addition & 0 deletions tests/test_utils.py
@@ -1,4 +1,5 @@
"""Tests of the utilities."""

import numpy as np
import numpy.testing as nptest
import pytest

1 change: 1 addition & 0 deletions workflows/Beyond_Normal/_high_mi_utils.py
@@ -1,5 +1,6 @@
"""Utilities for generating tasks
for high mutual information plot."""

from typing import Callable, Optional

import numpy as np

1 change: 1 addition & 0 deletions workflows/Beyond_Normal/scripts/plot_spiral.py
@@ -1,5 +1,6 @@
"""This script visualises the 2D spirals, sampling some points
and applying spiral diffeomorphism with different speed parameter."""

import jax
import matplotlib.pyplot as plt
import numpy as np

1 change: 1 addition & 0 deletions workflows/Mixtures/example_distributions.py
@@ -1,4 +1,5 @@
"""The example distributions used in the paper."""

import dataclasses

import jax.numpy as jnp