Merge pull request #393 from gchq/feature/beartype
Add `beartype` runtime type checker
db091756 authored Sep 4, 2024
2 parents 10f1b61 + 4688297 commit f2939ea
Showing 78 changed files with 402 additions and 277 deletions.
19 changes: 18 additions & 1 deletion .github/workflows/unittests.yml
@@ -14,6 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
# Note that we only do type checking on Python 3.12.
python-version:
- "3.8"
- "3.9"
@@ -35,8 +36,24 @@ jobs:
run: python -m pip install --upgrade pip
- name: Install latest test dependencies
run: pip install -e .[test] --upgrade --upgrade-strategy eager
- name: Test with pytest
- name: Run tests incompatible with type checking
run: >
pytest
--ignore=tests/test_examples.py
-W "error::vanguard.utils.UnseededRandomWarning"
-m "no_beartype"
if: matrix.python-version == '3.12'
- name: Test with pytest (and type checking)
run: >
pytest
--ignore=tests/test_examples.py
-W "error::vanguard.utils.UnseededRandomWarning"
-m "not no_beartype"
--beartype-packages="vanguard"
if: matrix.python-version == '3.12'
- name: Run all tests
run: >
pytest
--ignore=tests/test_examples.py
-W "error::vanguard.utils.UnseededRandomWarning"
if: matrix.python-version != '3.12'
2 changes: 1 addition & 1 deletion .github/workflows/unittests_from_requirements.yml
@@ -26,7 +26,7 @@ jobs:
run: python -m pip install --upgrade pip
- name: Install test dependencies
run: pip install -r requirements.txt --no-deps --upgrade
- name: Test with pytest
- name: Run tests
run: >
pytest
--ignore=tests/test_examples.py
39 changes: 26 additions & 13 deletions README.md
@@ -25,25 +25,38 @@ If the code is not running properly, recreate the environment with `pip install

## Tests

Vanguard's tests are contained in the `tests/` directory, and can be run with `unittest` or `pytest`.

Unit tests are in `tests/units`. There are two additional test files that dynamically run additional tests:
- `test_doctests.py` finds and runs all doctests.
- `test_examples.py` runs all notebooks under `examples/` as tests. These require `nbconvert` and `nbformat` to run,
Vanguard's tests are contained in the `tests/` directory, and can be run with `pytest`. The tests are arranged
as follows:
- `tests/units` contains unit tests. These should be fairly quick to run.
- `tests/integration` contains integration tests, which may take longer to run.
- `tests/test_doctests.py` finds and runs all doctests. This should be fairly quick to run.
- `tests/test_examples.py` runs all notebooks under `examples/` as tests. These require `nbconvert` and `nbformat` to run,
and can take a significant amount of time to complete, so consider excluding `test_examples.py` from your test
discovery.

```shell
# Unittest:
$ python -m unittest discover -s tests/units # run unit tests
$ python -m unittest tests/test_doctests.py # run doctests
$ python -m unittest tests/test_examples.py # run example tests (slow)

# Pytest:
```shell
$ pytest # run all tests (slow)
$ pytest tests/units # run unit tests
$ pytest tests/integration # run integration tests (slow)
$ pytest tests/test_doctests.py # run doctests
$ pytest tests/test_examples.py # run example tests (slow)
```

Note that some tests are non-deterministic and as such may occasionally fail due to randomness.
Please try running them again before raising an issue.
Our PR workflows run our tests with the `pytest-beartype` plugin. This is a runtime type checker that ensures all
our type hints are correct. In order to run with these checks locally, add
`--beartype-packages="vanguard" -m "not no_beartype"` to your pytest invocation. You should then separately run pytest
with `-m no_beartype` to ensure that all tests are run. The reason for this separation is that some of our tests check
that our handling of inputs of invalid type are correct, but `beartype` catches these errors before we get a chance to
look at them, causing the tests to fail; thus, these tests need to be run separately _without_ beartype.

Since different Python versions have different versions of standard library and third-party modules, we can't guarantee
that type hints are 100% correct on all Python versions. Type hints are only tested for correctness on the latest
version of Python (3.12).

For example, to run the unit tests with type checking:

```shell
$ pytest tests/units --beartype-packages="vanguard" -m "not no_beartype" # run unit tests with type checking
$ pytest tests/units -m no_beartype # run unit tests that are incompatible with beartype
```
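
The README hunk above explains that a runtime type checker intercepts badly typed arguments before the library's own validation runs, which is why tests asserting on that validation are marked `no_beartype` and run separately. The following is a minimal sketch of that effect. It does not use `beartype` itself: `stand_in_type_checker`, `StandInTypeViolation`, `scale` and `raises_type_error` are hypothetical names invented for this illustration, and the toy decorator only handles plain class annotations.

```python
import functools
import inspect


class StandInTypeViolation(Exception):
    """Hypothetical stand-in for the error a runtime type checker raises."""


def stand_in_type_checker(func):
    """Toy decorator (not beartype): check arguments against plain class annotations."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = signature.parameters[name].annotation
            if expected is inspect.Parameter.empty:
                continue  # unannotated parameter: nothing to check
            if isinstance(expected, type) and not isinstance(value, expected):
                # The checker fails here, before the wrapped function body runs.
                raise StandInTypeViolation(f"{name}={value!r} is not {expected.__name__}")
        return func(*args, **kwargs)

    return wrapper


def scale(value: float) -> float:
    """Hypothetical library function that validates its own input."""
    if not isinstance(value, float):
        raise TypeError("value must be a float")
    return 2.0 * value


def raises_type_error(func) -> bool:
    """Return True only if func("oops") raises the library's own TypeError."""
    try:
        func("oops")
    except TypeError:
        return True
    except StandInTypeViolation:
        return False
    return False


# Without the checker, the library's own TypeError is observed.
unchecked_ok = raises_type_error(scale)
# With the checker, the stand-in violation is raised first, so the check fails.
checked_ok = raises_type_error(stand_in_type_checker(scale))
```

In this sketch `unchecked_ok` is `True` and `checked_ok` is `False`, mirroring why a test that expects the library's own error has to run without the type checker active.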
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -133,6 +133,7 @@
nitpicky_ignore_mapping: Dict[str, List[str]] = {
"py:class": [
"torch.Size",
"gpytorch.likelihoods.gaussian_likelihood._GaussianLikelihoodBase",
],
"py:meth": [
"activate",
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
# Run unit tests
test = [
"pytest-cov",
"pytest-beartype"
]
# Compile documentation
doc = [
@@ -135,6 +136,10 @@ exclude_also = [
]

[tool.pytest.ini_options]
markers = [
"no_beartype: for tests incompatible with beartype (e.g. checking for TypeErrors)",
]

# TODO: fix as many of these as possible, and for those we can't fix, suppress as many as possible at the point that
# they're emitted using the catch_warnings context manager. Suppress globally here only as a last resort.
# https://github.com/gchq/Vanguard/issues/281
@@ -164,4 +169,6 @@ filterwarnings = [
# TODO: replace with sparse_coo_tensor
# https://github.com/gchq/Vanguard/issues/278
"ignore:torch.sparse.SparseTensor\\(indices, values, shape, \\*, device=\\) is deprecated:UserWarning",
# TODO: replace with some alternative (e.g. beartype.typing.*)
"ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning"
]
2 changes: 2 additions & 0 deletions tests/integration/test_base_usage.py
@@ -85,6 +85,8 @@ def test_basic_gp(self, batch_size: Optional[int]) -> None:
_prediction_ci_median, prediction_ci_lower, prediction_ci_upper = posterior.confidence_interval(
alpha=self.confidence_interval_alpha
)
prediction_ci_lower = prediction_ci_lower.numpy()
prediction_ci_upper = prediction_ci_upper.numpy()

# Sense check the outputs
assert np.all(prediction_means <= prediction_ci_upper)
6 changes: 3 additions & 3 deletions tests/integration/test_classification_usage.py
@@ -83,7 +83,7 @@ class BinaryClassifier(GaussianGPController):
train_x=x[train_indices],
train_y=y[train_indices],
kernel_class=ScaledRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=BernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
rng=self.rng,
@@ -133,7 +133,7 @@ class CategoricalClassifier(GaussianGPController):
train_x=x[train_indices],
train_y=y[train_indices, 0],
kernel_class=ScaledRBFKernel,
y_std=0,
y_std=0.0,
kernel_kwargs={"batch_shape": (3,)},
likelihood_class=DirichletClassificationLikelihood,
rng=self.rng,
@@ -183,7 +183,7 @@ class CategoricalClassifier(GaussianGPController):
train_x=x[train_indices],
train_y=y[train_indices, 0],
kernel_class=ScaledRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=DirichletKernelClassifierLikelihood,
marginal_log_likelihood_class=GenericExactMarginalLogLikelihood,
rng=self.rng,
4 changes: 2 additions & 2 deletions tests/integration/test_distribute_usage.py
@@ -113,7 +113,7 @@ class BinaryClassifier(GaussianGPController):
train_x=self.x[self.train_indices],
train_y=self.y[self.train_indices],
kernel_class=ScaledRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=BernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
rng=self.rng,
@@ -167,7 +167,7 @@ class BinaryClassifier(GaussianGPController):
train_x=self.x[self.train_indices],
train_y=self.y[self.train_indices],
kernel_class=ScaledRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=BernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
partitioner_kwargs=partitioner_kwargs,
9 changes: 5 additions & 4 deletions tests/integration/test_hierarchical_usage.py
@@ -19,6 +19,7 @@
import unittest

import numpy as np
import torch
from gpytorch.kernels import RBFKernel

from tests.cases import get_default_rng
@@ -91,8 +92,8 @@ class BayesianRBFKernel(RBFKernel):

# Sense check the outputs. Note that we do not check confidence interval quality here,
# just that they can be created, due to highly varying quality of the resulting intervals,
self.assertTrue(np.all(prediction_medians <= prediction_ci_upper))
self.assertTrue(np.all(prediction_medians >= prediction_ci_lower))
self.assertTrue(torch.all(prediction_medians <= prediction_ci_upper))
self.assertTrue(torch.all(prediction_medians >= prediction_ci_lower))

def test_gp_variational_hierarchical(self):
"""
@@ -131,8 +132,8 @@ class BayesianRBFKernel(RBFKernel):

# Sense check the outputs. Note that we do not check confidence interval quality here,
# just that they can be created, due to highly varying quality of the resulting intervals,
self.assertTrue(np.all(prediction_medians <= prediction_ci_upper))
self.assertTrue(np.all(prediction_medians >= prediction_ci_lower))
self.assertTrue(torch.all(prediction_medians <= prediction_ci_upper))
self.assertTrue(torch.all(prediction_medians >= prediction_ci_lower))


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion tests/integration/test_multitask_usage.py
@@ -85,7 +85,7 @@ class MultiTaskCategoricalClassifier(GaussianGPController):
train_x=x[train_indices],
train_y=y[train_indices],
kernel_class=ScaledRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=MultitaskBernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
)
13 changes: 10 additions & 3 deletions tests/units/base/posteriors/test_collection.py
@@ -24,8 +24,10 @@

import numpy as np
import torch
from gpytorch.distributions import Distribution, MultivariateNormal
from gpytorch.distributions import MultivariateNormal
from scipy.stats import multivariate_normal
from torch import Tensor
from torch.distributions import Distribution

from tests.cases import get_default_rng
from vanguard.base.posteriors import MonteCarloPosteriorCollection, Posterior
@@ -182,12 +184,17 @@ def generate_mostly_invalid_posteriors(num_failures_between_successes: int) -> G
class MockDistribution:
"""Mock of Distribution class."""

__class__ = MultivariateNormal # lying to the type checker

def __init__(self, mean: torch.Tensor, covariance: torch.Tensor):
self.mean = mean
self.covariance_matrix = covariance

def __getattr__(self, _):
return Mock()
def add_jitter(self, _):
return self

def rsample(self, *_, **__):
return Mock(Tensor)

class MockPosterior:
"""Mock of Posterior class."""
2 changes: 1 addition & 1 deletion tests/units/base/posteriors/test_posterior.py
@@ -117,7 +117,7 @@ def test_1_dim_mean_log_probability_order(self) -> None:
def test_sample(self) -> None:
"""Test that the `sample()` function simply returns a sample from the internal distribution."""
# Set up a mock distribution
mock_distribution = Mock()
mock_distribution = Mock(torch.distributions.Distribution)
del mock_distribution.covariance_matrix # no covariance matrix, so no jitter is added

# Define a function to generate and record random samples, and replace our mock distribution's `rsample`
2 changes: 1 addition & 1 deletion tests/units/base/test_base.py
@@ -221,7 +221,7 @@ def __init__(
self.x_test = scaled_test_x
self.y_test = function(self.x_test)

self.y_std = 1
self.y_std = 1.0

self.dataset = UniformSyntheticDataset(lambda x: np.sin(10 * x), 100, 100 // 4, self.y_std, rng=self.rng)

8 changes: 4 additions & 4 deletions tests/units/base/test_metrics.py
@@ -43,7 +43,7 @@ def setUp(self) -> None:
"""Code to run before each test."""
self.tracker = MetricsTracker(loss)
for loss_value in range(100):
self.tracker.run_metrics(loss_value=loss_value, controller=None)
self.tracker.run_metrics(loss_value=float(loss_value), controller=None)

def test_get_item(self) -> None:
"""Items should be correct."""
@@ -77,11 +77,11 @@ def setUp(self) -> None:
"""Code to run before each test."""
self.tracker = MetricsTracker()
for loss_value in range(50):
self.tracker.run_metrics(loss_value=loss_value, controller=None)
self.tracker.run_metrics(loss_value=float(loss_value), controller=None)

self.tracker.add_metrics(loss)
for loss_value in range(50, 100):
self.tracker.run_metrics(loss_value=loss_value, controller=None)
self.tracker.run_metrics(loss_value=float(loss_value), controller=None)

def test_get_item_before_50(self) -> None:
"""Items should be correct."""
@@ -107,7 +107,7 @@ def setUp(self) -> None:
"""Code to run before each test."""
self.tracker = MetricsTracker(loss, loss_squared)
for loss_value in range(100):
self.tracker.run_metrics(loss_value=loss_value, controller=None)
self.tracker.run_metrics(loss_value=float(loss_value), controller=None)

def test_get_item_before_50(self) -> None:
"""Items should be correct."""
4 changes: 2 additions & 2 deletions tests/units/classification/case.py
@@ -31,7 +31,7 @@ class BatchScaledRBFKernel(ScaleKernel):
The recommended starting place for a kernel.
"""

def __init__(self, batch_shape: torch.Size) -> None:
def __init__(self, batch_shape: Union[int, torch.Size]) -> None:
batch_shape = batch_shape if isinstance(batch_shape, torch.Size) else torch.Size([batch_shape])
super().__init__(RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)

@@ -41,7 +41,7 @@ class BatchScaledMean(ZeroMean):
A basic mean with batch shape to match the above kernel.
"""

def __init__(self, batch_shape: torch.Size) -> None:
def __init__(self, batch_shape: Union[int, torch.Size]) -> None:
batch_shape = batch_shape if isinstance(batch_shape, torch.Size) else torch.Size([batch_shape])
super().__init__(batch_shape=batch_shape)

8 changes: 4 additions & 4 deletions tests/units/classification/test_binary.py
@@ -49,7 +49,7 @@ def setUp(self) -> None:
self.dataset.train_x,
self.dataset.train_y,
kernel_class=PeriodicRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=BernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
rng=self.rng,
@@ -72,7 +72,7 @@ class IllegalLikelihoodClass:
self.dataset.train_x,
self.dataset.train_y,
kernel_class=PeriodicRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=IllegalLikelihoodClass,
marginal_log_likelihood_class=VariationalELBO,
rng=self.rng,
@@ -130,7 +130,7 @@ def test_fuzzy_predictions_monte_carlo(self) -> None:
dataset.train_x,
dataset.train_y,
kernel_class=PeriodicRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=BernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
rng=self.rng,
@@ -164,7 +164,7 @@ class UncertaintyBinaryClassifier(GaussianUncertaintyGPController):
train_x_std,
dataset.train_y,
kernel_class=PeriodicRBFKernel,
y_std=0,
y_std=0.0,
likelihood_class=BernoulliLikelihood,
marginal_log_likelihood_class=VariationalELBO,
rng=self.rng,