Merge pull request #573 from dynamicslab/fast_sbr_test
Speed up SBR tests, and add NUTS kwargs
Jacob-Stevens-Haas authored Oct 16, 2024
2 parents 4d27145 + ddf8d6d commit c70acd3
Showing 4 changed files with 55 additions and 38 deletions.
7 changes: 7 additions & 0 deletions pysindy/_typing.py
@@ -0,0 +1,7 @@
+import numpy as np
+import numpy.typing as npt
+
+# In python 3.12, use type statement
+# https://docs.python.org/3/reference/simple_stmts.html#the-type-statement
+NpFlt = np.floating[npt.NBitBase]
+Float2D = np.ndarray[tuple[int, int], np.dtype[NpFlt]]
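For orientation, a small standalone sketch (not part of the commit) of how aliases like these can annotate array-valued code; the gram_matrix helper and the random data are made up for illustration:

import numpy as np
import numpy.typing as npt

# Mirrors the aliases added in pysindy/_typing.py above.
NpFlt = np.floating[npt.NBitBase]
Float2D = np.ndarray[tuple[int, int], np.dtype[NpFlt]]


def gram_matrix(x: Float2D) -> Float2D:
    # Hypothetical helper annotated with the new alias: returns X^T X.
    return x.T @ x


features = np.random.default_rng(0).standard_normal((100, 3))
print(gram_matrix(features).shape)  # (3, 3)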
16 changes: 11 additions & 5 deletions pysindy/optimizers/base.py
@@ -14,6 +14,7 @@
 from sklearn.utils.validation import check_is_fitted
 from sklearn.utils.validation import check_X_y
 
+from .._typing import Float2D
 from ..utils import AxesArray
 from ..utils import drop_nan_samples
 
@@ -178,8 +179,7 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
 
         x_normed = np.copy(x)
         if self.normalize_columns:
-            reg = 1 / np.linalg.norm(x, 2, axis=0)
-            x_normed = x * reg
+            feat_norms, x_normed = _normalize_features(x_normed)
 
         if self.initial_guess is None:
             self.coef_ = np.linalg.lstsq(x_normed, y, rcond=None)[0].T
@@ -203,11 +203,11 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
 
         # Rescale coefficients to original units
         if self.normalize_columns:
-            self.coef_ = np.multiply(reg, self.coef_)
+            self.coef_ = self.coef_ / feat_norms
             if hasattr(self, "coef_full_"):
-                self.coef_full_ = np.multiply(reg, self.coef_full_)
+                self.coef_full_ = self.coef_full_ / feat_norms
             for i in range(np.shape(self.history_)[0]):
-                self.history_[i] = np.multiply(reg, self.history_[i])
+                self.history_[i] = self.history_[i] / feat_norms
 
         self._set_intercept(X_offset, y_offset, X_scale)
         return self
@@ -395,3 +395,9 @@ def _drop_random_samples(
     x_dot_new = np.take(x_dot, rand_inds, axis=x.ax_sample)
 
     return x_new, x_dot_new
+
+
+def _normalize_features(x: Float2D) -> Float2D:
+    "Calculate the length of vectors and normalize them"
+    lengths = np.linalg.norm(x, 2, axis=0)
+    return lengths, x / lengths
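For intuition, a standalone sketch (not from the commit) of the normalize-then-rescale pattern that fit() now routes through _normalize_features; the helper body is copied from above, and the lstsq call merely stands in for an optimizer's _reduce step on made-up data:

import numpy as np


def _normalize_features(x):
    # Copy of the helper added above: column 2-norms and the column-normalized matrix.
    lengths = np.linalg.norm(x, 2, axis=0)
    return lengths, x / lengths


# Made-up, badly scaled design matrix and target.
rng = np.random.default_rng(0)
x = rng.standard_normal((50, 3)) * np.array([1.0, 10.0, 100.0])
y = x @ np.array([0.5, -2.0, 0.03]) + 0.01 * rng.standard_normal(50)

feat_norms, x_normed = _normalize_features(x)
coef_normed = np.linalg.lstsq(x_normed, y, rcond=None)[0]

# Rescale back to the units of the original features, as fit() now does
# for coef_, coef_full_, and each entry of history_.
coef = coef_normed / feat_norms
print(np.round(coef, 3))  # approximately [0.5, -2.0, 0.03]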
18 changes: 13 additions & 5 deletions pysindy/optimizers/sbr.py
@@ -96,6 +96,7 @@ def __init__(
         num_warmup: int = 1000,
         num_samples: int = 5000,
         mcmc_kwargs: Optional[dict] = None,
+        nuts_kwargs: Optional[dict] = None,
         unbias: bool = False,
         **kwargs,
     ):
@@ -135,10 +136,14 @@ def __init__(
             self.mcmc_kwargs = mcmc_kwargs
         else:
             self.mcmc_kwargs = {}
+        if nuts_kwargs is not None:
+            self.nuts_kwargs = nuts_kwargs
+        else:
+            self.nuts_kwargs = {}
 
     def _reduce(self, x, y):
         # set up a sparse regression and sample.
-        self.mcmc_ = self._run_mcmc(x, y, **self.mcmc_kwargs)
+        self.mcmc_ = self._run_mcmc(x, y, self.nuts_kwargs, self.mcmc_kwargs)
 
         # set the mean values as the coefficients.
         self.coef_ = np.array(self.mcmc_.get_samples()["beta"].mean(axis=0))
@@ -165,15 +170,18 @@ def _numpyro_model(self, x, y):
         sigma = numpyro.sample("sigma", Exponential(self.noise_hyper_lambda))
         numpyro.sample("obs", Normal(mu, sigma), obs=y)
 
-    def _run_mcmc(self, x, y, **kwargs):
+    def _run_mcmc(self, x, y, nuts_kwargs, mcmc_kwargs):
         # set up a jax random key.
-        seed = kwargs.pop("seed", 0)
+        seed = mcmc_kwargs.pop("seed", 0)
         rng_key = random.PRNGKey(seed)
 
         # run the MCMC
-        kernel = NUTS(self._numpyro_model)
+        kernel = NUTS(self._numpyro_model, **nuts_kwargs)
         mcmc = MCMC(
-            kernel, num_warmup=self.num_warmup, num_samples=self.num_samples, **kwargs
+            kernel,
+            num_warmup=self.num_warmup,
+            num_samples=self.num_samples,
+            **mcmc_kwargs,
         )
         mcmc.run(rng_key, x=x, y=y)
 
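A possible usage sketch of the new keyword (not part of the commit, assuming numpyro and jax are installed): nuts_kwargs is forwarded to the NUTS kernel and mcmc_kwargs to the MCMC driver, with "seed" popped out as _run_mcmc shows above. The option names below come from the new test and numpyro's API; the data and the AxesArray wrapping (mirroring the test helpers) are illustrative only.

import numpy as np

from pysindy.optimizers import SBR
from pysindy.utils import AxesArray

# Synthetic single-target problem: y = 2*x0 - x1 plus a little noise.
rng = np.random.default_rng(0)
x_raw = rng.standard_normal((200, 2))
y = x_raw @ np.array([2.0, -1.0]) + 0.01 * rng.standard_normal(200)
x = AxesArray(x_raw, {"ax_sample": 0, "ax_coord": 1})

opt = SBR(
    num_warmup=100,
    num_samples=200,
    # Forwarded to numpyro.infer.NUTS(...); option names as in the new test below.
    nuts_kwargs={"max_tree_depth": 6, "target_accept_prob": 0.9},
    # Forwarded to numpyro.infer.MCMC(...); "seed" is popped for the PRNG key.
    mcmc_kwargs={"seed": 1, "progress_bar": False},
)
opt.fit(x, y)
print(opt.coef_)  # posterior-mean coefficients, roughly [2.0, -1.0]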
52 changes: 24 additions & 28 deletions test/test_optimizers/test_optimizers.py
@@ -35,6 +35,7 @@
 from pysindy.optimizers import STLSQ
 from pysindy.optimizers import TrappingSR3
 from pysindy.optimizers import WrappedOptimizer
+from pysindy.optimizers.base import _normalize_features
 from pysindy.optimizers.ssr import _ind_inflection
 from pysindy.optimizers.stlsq import _remove_and_decrement
 from pysindy.utils import supports_multiple_targets
@@ -125,7 +126,7 @@ def data(request):
         ElasticNet(fit_intercept=False),
         DummyLinearModel(),
         MIOSR(),
-        SBR(),
+        SBR(num_warmup=10, num_samples=10),
     ],
     ids=lambda param: type(param),
 )
@@ -154,7 +155,19 @@ def test_not_fitted(optimizer):
         optimizer.predict(np.ones((1, 3)))
 
 
-@pytest.mark.parametrize("optimizer", [STLSQ(), SR3(), SBR()])
+@pytest.mark.parametrize(
+    "optimizer",
+    [
+        STLSQ(),
+        SR3(),
+        SBR(
+            num_warmup=1,
+            num_samples=1,
+            nuts_kwargs={"max_tree_depth": 1, "target_accept_prob": 0.1},
+        ),
+    ],
+    ids=type,
+)
 def test_complexity_not_fitted(optimizer, data_derivative_2d):
     with pytest.raises(NotFittedError):
         optimizer.complexity
@@ -1022,33 +1035,16 @@ def test_inequality_constraints_reqs():
         )
 
 
-@pytest.mark.parametrize(
-    "optimizer",
-    [
-        STLSQ,
-        SSR,
-        FROLS,
-        SR3,
-        ConstrainedSR3,
-        StableLinearSR3,
-        TrappingSR3,
-        MIOSR,
-        SBR,
-    ],
-)
-def test_normalize_columns(data_derivative_1d, optimizer):
+def test_normalize_columns(data_derivative_1d):
     x, x_dot = data_derivative_1d
-    if len(x.shape) == 1:
-        x = x.reshape(-1, 1)
-    opt = optimizer(normalize_columns=True)
-    opt, x = _align_optimizer_and_1dfeatures(opt, x)
-    opt.fit(x, x_dot)
-    check_is_fitted(opt)
-    assert opt.complexity >= 0
-    if len(x_dot.shape) > 1:
-        assert opt.coef_.shape == (x.shape[1], x_dot.shape[1])
-    else:
-        assert opt.coef_.shape == (1, x.shape[1])
+    x = np.reshape(x, (-1, 1))
+    x_dot = np.reshape(x_dot, (-1, 1))
+    cols = np.hstack((x, x_dot))
+    norm, ncols = _normalize_features(cols)
+    result = np.linalg.norm(ncols, axis=0)
+    expected = [1.0, 1.0]
+    np.testing.assert_allclose(result, expected)
+    np.testing.assert_allclose(ncols * norm, cols)
 
 
 @pytest.mark.parametrize(
