Commit
Merge pull request #562 from dynamicslab/simultaneous
Extract common functionality from `SINDy` to `_BaseSINDy`
Jacob-Stevens-Haas authored Oct 18, 2024
2 parents c70acd3 + 22329dd commit 2ca37cb
Showing 11 changed files with 181 additions and 156 deletions.
pyproject.toml (2 additions, 1 deletion)
@@ -26,8 +26,10 @@ classifiers = [
 ]
 readme = "README.rst"
 dependencies = [
+    "jax>=0.4,<0.5",
     "scikit-learn>=1.1, !=1.5.0",
     "derivative>=0.6.2",
+    "typing_extensions",
 ]

 [project.optional-dependencies]
@@ -63,7 +65,6 @@ cvxpy = [
 ]
 sbr = [
     "numpyro",
-    "jax",
     "arviz==0.17.1",
     "scipy<1.13.0"
 ]
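
Net effect: jax moves from the sbr extra into the core dependency list, with an upper version bound. A quick sanity check of the pin (a sketch, assuming an environment installed from this commit; not part of the diff):

    import jax

    major, minor = (int(p) for p in jax.__version__.split(".")[:2])
    # Mirrors the pyproject.toml constraint "jax>=0.4,<0.5"
    assert (0, 4) <= (major, minor) < (0, 5), jax.__version__
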
pysindy/_typing.py (8 additions, 1 deletion)
@@ -4,4 +4,11 @@
 # In python 3.12, use type statement
 # https://docs.python.org/3/reference/simple_stmts.html#the-type-statement
 NpFlt = np.floating[npt.NBitBase]
-Float2D = np.ndarray[tuple[int, int], np.dtype[NpFlt]]
+FloatDType = np.dtype[np.floating[npt.NBitBase]]
+Int1D = np.ndarray[tuple[int], np.dtype[np.int_]]
+Float1D = np.ndarray[tuple[int], FloatDType]
+Float2D = np.ndarray[tuple[int, int], FloatDType]
+Float3D = np.ndarray[tuple[int, int, int], FloatDType]
+Float4D = np.ndarray[tuple[int, int, int, int], FloatDType]
+Float5D = np.ndarray[tuple[int, int, int, int, int], FloatDType]
+FloatND = npt.NDArray[NpFlt]
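
These aliases encode array rank and dtype in a single annotation. A sketch of downstream use (the function and data here are illustrative, not from this commit):

    import numpy as np

    from pysindy._typing import Float1D, Float2D


    def column_norms(x: Float2D) -> Float1D:
        # Norm over the sample axis of an (n_samples, n_features) array,
        # one value per feature column.
        return np.linalg.norm(x, axis=0)


    norms = column_norms(np.random.default_rng(0).random((100, 3)))  # shape (3,)
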
pysindy/feature_library/base.py (14 additions, 4 deletions)
@@ -8,6 +8,7 @@
 from typing import Optional
 from typing import Sequence

+import jax
 import numpy as np
 from scipy import sparse
 from sklearn.base import TransformerMixin
@@ -144,19 +145,28 @@ def x_sequence_or_item(wrapped_func):
     @wraps(wrapped_func)
     def func(self, x, *args, **kwargs):
         if isinstance(x, Sequence):
-            xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x]
+            if isinstance(x[0], jax.Array):
+                xs = x
+            else:
+                xs = [AxesArray(xi, comprehend_axes(xi)) for xi in x]
             result = wrapped_func(self, xs, *args, **kwargs)
-            if isinstance(result, Sequence):  # e.g. transform() returns x
+            # if transform() is a normal "return x"
+            if isinstance(result, Sequence) and isinstance(result[0], np.ndarray):
                 return [AxesArray(xp, comprehend_axes(xp)) for xp in result]
             return result  # e.g. fit() returns self
         else:
-            if not sparse.issparse(x):
+            if isinstance(x, jax.Array):
+
+                def reconstructor(x):
+                    return x
+
+            elif not sparse.issparse(x) and isinstance(x, np.ndarray):
                 x = AxesArray(x, comprehend_axes(x))

                 def reconstructor(x):
                     return x

-            else:  # sparse arrays
+            else:  # sparse
                 reconstructor = type(x)
                 axes = comprehend_axes(x)
                 wrap_axes(axes, x)
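
The decorator now lets jax arrays pass through untouched, while numpy inputs still get the AxesArray wrapping that carries axis metadata. A simplified, standalone sketch of that dispatch rule (wrap stands in for the real AxesArray(x, comprehend_axes(x)) call):

    import jax
    import jax.numpy as jnp
    import numpy as np


    def wrap(x):
        # Stand-in for AxesArray(x, comprehend_axes(x)); jax arrays are
        # returned as-is because the wrapping is numpy-specific.
        if isinstance(x, jax.Array):
            return x
        return np.asarray(x)


    print(type(wrap(jnp.ones((3, 2)))))  # jax array, unchanged
    print(type(wrap([[1.0, 2.0]])))      # coerced on the numpy path
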
pysindy/feature_library/polynomial_library.py (4 additions, 3 deletions)
@@ -12,6 +12,7 @@
 from ..utils import AxesArray
 from ..utils import comprehend_axes
 from ..utils import wrap_axes
+from ..utils._axis_conventions import AX_COORD
 from .base import BaseFeatureLibrary
 from .base import x_sequence_or_item

@@ -160,7 +161,7 @@ def get_feature_names(self, input_features=None):
         return feature_names

     @x_sequence_or_item
-    def fit(self, x_full, y=None):
+    def fit(self, x_full: list[AxesArray], y=None):
        """
        Compute number of output features.
@@ -180,7 +181,7 @@ def fit(self, x_full: list[AxesArray], y=None):
                 "Can't have include_interaction be False and interaction_only"
                 " be True"
             )
-        n_features = x_full[0].shape[x_full[0].ax_coord]
+        n_features = x_full[0].shape[AX_COORD]
         combinations = self._combinations(
             n_features,
             self.degree,
@@ -217,7 +218,7 @@ def transform(self, x_full):
             axes = comprehend_axes(x)
             x = x.asformat("csc")
             wrap_axes(axes, x)
-            n_features = x.shape[x.ax_coord]
+            n_features = x.shape[AX_COORD]
             if n_features != self.n_features_in_:
                 raise ValueError("x shape does not match training shape")
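
The feature count now comes from the shared AX_COORD constant rather than each array's ax_coord attribute, so the lookup no longer requires an AxesArray instance. A small illustration, assuming pysindy's convention that the trailing axis holds coordinates:

    import numpy as np

    from pysindy.utils._axis_conventions import AX_COORD

    x = np.random.default_rng(0).random((50, 3))  # 50 samples, 3 coordinates
    n_features = x.shape[AX_COORD]  # 3 under the last-axis convention
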
pysindy/optimizers/base.py (18 additions, 2 deletions)
@@ -4,20 +4,27 @@
 import abc
 import warnings
 from typing import Callable
+from typing import NewType
 from typing import Optional
 from typing import Tuple

 import numpy as np
 from scipy import sparse
+from sklearn.base import BaseEstimator
 from sklearn.linear_model import LinearRegression
 from sklearn.linear_model._base import _preprocess_data
 from sklearn.utils.extmath import safe_sparse_dot
 from sklearn.utils.validation import check_is_fitted
 from sklearn.utils.validation import check_X_y

+from .._typing import Float2D
+from .._typing import FloatDType
 from ..utils import AxesArray
 from ..utils import drop_nan_samples
+
+NFeat = NewType("NFeat", int)
+NTarget = NewType("NTarget", int)


 def _rescale_data(X, y, sample_weight):
     """Rescale data so as to support sample_weight"""
@@ -32,14 +39,17 @@ def _rescale_data(X, y, sample_weight):
     return X, y


-class ComplexityMixin:
+class _BaseOptimizer(BaseEstimator, abc.ABC):
+    coef_: np.ndarray[tuple[NTarget, NFeat], FloatDType]
+    intercept_: np.ndarray[tuple[NTarget], FloatDType]
+
     @property
     def complexity(self):
         check_is_fitted(self)
         return np.count_nonzero(self.coef_) + np.count_nonzero(self.intercept_)


-class BaseOptimizer(LinearRegression, ComplexityMixin):
+class BaseOptimizer(LinearRegression, _BaseOptimizer):
     """
     Base class for SINDy optimizers. Subclasses must implement
     a _reduce method for carrying out the bulk of the work of
@@ -89,6 +99,12 @@ class BaseOptimizer(LinearRegression, ComplexityMixin):
     """

+    max_iter: int
+    normalize_columns: bool
+    initial_guess: Optional[np.ndarray[tuple[NTarget, NFeat], FloatDType]]
+    copy_X: bool
+    unbias: bool
+
     def __init__(
         self,
         max_iter=20,
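
NFeat and NTarget are typing.NewType labels over int: they document which axis of coef_ is targets and which is features, and are erased at runtime. The inherited complexity property reduces to counting nonzero terms; a minimal standalone equivalent with made-up numbers:

    import numpy as np

    coef_ = np.array([[0.0, 1.3, 0.0], [2.1, 0.0, 0.4]])  # (n_targets, n_features)
    intercept_ = np.array([0.0, 0.7])                     # (n_targets,)
    complexity = np.count_nonzero(coef_) + np.count_nonzero(intercept_)
    print(complexity)  # 4
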
pysindy/optimizers/trapping_sr3.py (11 additions, 15 deletions)
@@ -5,33 +5,29 @@
 from itertools import repeat
 from math import comb
 from typing import cast
-from typing import NewType
 from typing import Optional
 from typing import TypeVar
 from typing import Union

 import cvxpy as cp
 import numpy as np
-from numpy.typing import NBitBase
-from numpy.typing import NDArray
 from sklearn.exceptions import ConvergenceWarning

+from .._typing import Float1D
+from .._typing import Float2D
+from .._typing import Float3D
+from .._typing import Float4D
+from .._typing import Float5D
+from .._typing import Int1D
 from ..feature_library.polynomial_library import n_poly_features
 from ..feature_library.polynomial_library import PolynomialLibrary
 from ..utils import reorder_constraints
+from .base import FloatDType
+from .base import NFeat
+from .base import NTarget
 from .constrained_sr3 import ConstrainedSR3

-AnyFloat = np.dtype[np.floating[NBitBase]]
-Int1D = np.ndarray[tuple[int], np.dtype[np.int_]]
-Float1D = np.ndarray[tuple[int], AnyFloat]
-Float2D = np.ndarray[tuple[int, int], AnyFloat]
-Float3D = np.ndarray[tuple[int, int, int], AnyFloat]
-Float4D = np.ndarray[tuple[int, int, int, int], AnyFloat]
-Float5D = np.ndarray[tuple[int, int, int, int, int], AnyFloat]
-FloatND = NDArray[np.floating[NBitBase]]
-NFeat = NewType("NFeat", int)
-NTarget = NewType("NTarget", int)


 class EnstrophyMat:
     """Pre-compute some useful factors of an enstrophy matrix
@@ -601,7 +597,7 @@ def _solve_m_relax_and_split(
         self,
         trap_ctr: Float1D,
         prev_A: Float2D,
-        coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
+        coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType],
     ) -> tuple[Float1D, Float2D]:
         r"""Updates the trap center
@@ -693,7 +689,7 @@ def _reduce(self, x, y):
             self.constraint_lhs = reorder_constraints(
                 self.constraint_lhs, n_features, output_order="feature"
             )
-        coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat] = self.coef_.T
+        coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType] = self.coef_.T

         # Print initial values for each term in the optimization
         if self.verbose:
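
With the local alias block removed, this module reuses the shared definitions from pysindy._typing and optimizers.base instead of redefining them. A sketch of the same annotation style using the now-shared names (values are illustrative):

    import numpy as np

    from pysindy._typing import FloatDType
    from pysindy.optimizers.base import NFeat, NTarget

    # Same annotation form the diff uses for coef_sparse:
    coef_sparse: np.ndarray[tuple[NFeat, NTarget], FloatDType] = np.zeros((6, 2))
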
(Diffs for the remaining 5 changed files are not shown here.)
