Skip to content

Commit

Permalink
BUG #394 fix prox and regularization in SR3 (#544)
Browse files Browse the repository at this point in the history
Now, prox and regularization have a plainer API: weights are either scalars or must
match the shape of the optimization variable.  This removes odd use cases that only
worked because of broadcasting, which constrained our ability to simplify the implementation.

It also fixes the calculation of regularizers in SR3 `_calculate_penalty`. The aim is for
this helper to eventually be replaced by get_prox and get_regularization, once they are
able to handle CVXPY (or JAX) expressions (arrays).
  • Loading branch information
himkwtn authored Sep 10, 2024
1 parent 69e51e8 commit 3d791e2
Show file tree
Hide file tree
Showing 9 changed files with 849 additions and 848 deletions.
1,182 changes: 594 additions & 588 deletions examples/1_feature_overview/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/1_feature_overview/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def f(x):
model.print()

# With thresholds matrix
thresholds = 2 * np.ones((10, 3))
thresholds = 2 * np.ones((3, 10))
thresholds[4:, :] = 0.1
sr3_optimizer = ps.SR3(thresholder="weighted_l0", thresholds=thresholds)
model = ps.SINDy(optimizer=sr3_optimizer).fit(x_train, t=dt)
Expand Down
99 changes: 33 additions & 66 deletions pysindy/optimizers/constrained_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from scipy.linalg import cho_factor
from sklearn.exceptions import ConvergenceWarning

from ..utils import get_regularization
from ..utils import reorder_constraints
from .sr3 import SR3

Expand Down Expand Up @@ -65,9 +64,9 @@ class ConstrainedSR3(SR3):
thresholder : string, optional (default 'l0')
Regularization function to use. Currently implemented options
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation), 'weighted_l0' (weighted l0 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
are 'l0' (l0 norm), 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l0' (weighted l0 norm), 'weighted_l1' (weighted l1 norm),
and 'weighted_l2' (weighted l2 norm).
max_iter : int, optional (default 30)
Maximum iterations of the optimization algorithm.
Expand Down Expand Up @@ -192,7 +191,6 @@ def __init__(
)

self.verbose_cvxpy = verbose_cvxpy
self.reg = get_regularization(thresholder)
self.constraint_lhs = constraint_lhs
self.constraint_rhs = constraint_rhs
self.constraint_order = constraint_order
Expand Down Expand Up @@ -271,20 +269,41 @@ def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
rhs = rhs.reshape(g.shape)
return inv1.dot(rhs)

@staticmethod
def _calculate_penalty(
regularization: str, regularization_weight, xi: cp.Variable
) -> cp.Expression:
"""
Args:
-----
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' |
'l2' | 'weighted_l2'
regularization_weight: float | np.array, can be a scalar
or an array of the same shape as xi
xi: cp.Variable
Returns:
--------
cp.Expression
"""
regularization = regularization.lower()
if regularization == "l1":
return regularization_weight * cp.sum(cp.abs(xi))
elif regularization == "weighted_l1":
return cp.sum(cp.multiply(regularization_weight, cp.abs(xi)))
elif regularization == "l2":
return regularization_weight * cp.sum(xi**2)
elif regularization == "weighted_l2":
return cp.sum(cp.multiply(regularization_weight, xi**2))

def _create_var_and_part_cost(
self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
) -> Tuple[cp.Variable, cp.Expression]:
xi = cp.Variable(var_len)
cost = cp.sum_squares(x_expanded @ xi - y.flatten())
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
elif self.thresholder.lower() == "weighted_l1":
cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
elif self.thresholder.lower() == "l2":
cost = cost + self.threshold * cp.norm2(xi) ** 2
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
return xi, cost
threshold = self.thresholds if self.thresholds is not None else self.threshold
penalty = self._calculate_penalty(self.thresholder, np.ravel(threshold), xi)
return xi, cost + penalty

def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
if self.use_constraints:
Expand Down Expand Up @@ -342,58 +361,6 @@ def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
coef_new = (xi.value).reshape(coef_prev.shape)
return coef_new

def _update_sparse_coef(self, coef_full):
"""Update the regularized weight vector"""
if self.thresholds is None:
return super(ConstrainedSR3, self)._update_sparse_coef(coef_full)
else:
coef_sparse = self.prox(coef_full, self.thresholds.T)
self.history_.append(coef_sparse.T)
return coef_sparse

def _objective(self, x, y, q, coef_full, coef_sparse, trimming_array=None):
"""Objective function"""
if q != 0:
print_ind = q % (self.max_iter // 10.0)
else:
print_ind = q
R2 = (y - np.dot(x, coef_full)) ** 2
D2 = (coef_full - coef_sparse) ** 2
if self.use_trimming:
assert trimming_array is not None
R2 *= trimming_array.reshape(x.shape[0], 1)

if self.thresholds is None:
regularization = self.reg(coef_full, self.threshold**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
else:
regularization = self.reg(coef_full, self.thresholds**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu

def _reduce(self, x, y):
"""
Perform at most ``self.max_iter`` iterations of the SR3 algorithm
Expand Down
60 changes: 22 additions & 38 deletions pysindy/optimizers/sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class SR3(BaseOptimizer):
thresholder : string, optional (default 'L0')
Regularization function to use. Currently implemented options
are 'L0' (L0 norm), 'L1' (L1 norm), 'L2' (L2 norm) and 'CAD' (clipped
absolute deviation). Note by 'L2 norm' we really mean
are 'L0' (L0 norm), 'L1' (L1 norm) and 'L2' (L2 norm).
Note by 'L2 norm' we really mean
the squared L2 norm, i.e. ridge regression
trimming_fraction : float, optional (default 0.0)
Expand Down Expand Up @@ -179,10 +179,9 @@ def __init__(
"weighted_l0",
"weighted_l1",
"weighted_l2",
"cad",
):
raise NotImplementedError(
"Please use a valid thresholder, l0, l1, l2, cad, "
"Please use a valid thresholder, l0, l1, l2, "
"weighted_l0, weighted_l1, weighted_l2."
)
if thresholder[:8].lower() == "weighted" and thresholds is None:
Expand Down Expand Up @@ -212,6 +211,10 @@ def __init__(
self.threshold = threshold
self.thresholds = thresholds
self.nu = nu
if thresholds is not None:
self.lam = thresholds.T**2 / (2 * nu)
else:
self.lam = threshold**2 / (2 * nu)
self.tol = tol
self.thresholder = thresholder
self.reg = get_regularization(thresholder)
Expand Down Expand Up @@ -253,36 +256,20 @@ def _objective(self, x, y, q, coef_full, coef_sparse, trimming_array=None):
if self.use_trimming:
assert trimming_array is not None
R2 *= trimming_array.reshape(x.shape[0], 1)
if self.thresholds is None:
regularization = self.reg(coef_full, self.threshold**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
else:
regularization = self.reg(coef_full, self.thresholds**2 / self.nu)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu
regularization = self.reg(coef_full, self.lam)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu

def _update_full_coef(self, cho, x_transpose_y, coef_sparse):
"""Update the unregularized weight vector"""
Expand All @@ -293,10 +280,7 @@ def _update_full_coef(self, cho, x_transpose_y, coef_sparse):

def _update_sparse_coef(self, coef_full):
"""Update the regularized weight vector"""
if self.thresholds is None:
coef_sparse = self.prox(coef_full, self.threshold)
else:
coef_sparse = self.prox(coef_full, self.thresholds.T)
coef_sparse = self.prox(coef_full, self.lam * self.nu)
return coef_sparse

def _update_trimming_array(self, coef_full, trimming_array, trimming_grad):
Expand Down
48 changes: 4 additions & 44 deletions pysindy/optimizers/stable_linear_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class StableLinearSR3(ConstrainedSR3):
thresholder : string, optional (default 'l1')
Regularization function to use. Currently implemented options
are 'l1' (l1 norm), 'l2' (l2 norm), 'cad' (clipped
absolute deviation),
are 'l1' (l1 norm), 'l2' (l2 norm),
'weighted_l1' (weighted l1 norm), and 'weighted_l2' (weighted l2 norm).
Note that the thresholder must be convex here.
Expand Down Expand Up @@ -211,15 +210,9 @@ def _create_var_and_part_cost(
xi = cp.Variable(coef_sparse.shape[0] * coef_sparse.shape[1])
cost = cp.sum_squares(x @ xi - y.flatten())
cost = cost + cp.sum_squares(xi - coef_neg_def.flatten()) / (2 * self.nu)
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
elif self.thresholder.lower() == "weighted_l1":
cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
elif self.thresholder.lower() == "l2":
cost = cost + self.threshold * cp.norm2(xi)
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi)
return xi, cost
threshold = self.thresholds if self.thresholds is not None else self.threshold
penalty = self._calculate_penalty(self.thresholder, np.ravel(threshold), xi)
return xi, cost + penalty

def _update_coef_cvxpy(self, x, y, coef_sparse, coef_negative_definite):
"""
Expand Down Expand Up @@ -310,39 +303,6 @@ def _update_A(self, A_old, coef_sparse):
A_temp[r:, :r] = A_old[r:, :r]
return A_temp.T

def _objective(
self, x, y, q, coef_negative_definite, coef_sparse, trimming_array=None
):
"""Objective function"""
if q != 0:
print_ind = q % (self.max_iter // 10.0)
else:
print_ind = q
R2 = (y - np.dot(x, coef_negative_definite)) ** 2
D2 = (coef_negative_definite - coef_sparse) ** 2
if self.use_trimming:
assert trimming_array is not None
R2 *= trimming_array.reshape(x.shape[0], 1)

regularization = self.reg(
coef_negative_definite,
(self.threshold**2 if self.thresholds is None else self.thresholds**2)
/ self.nu,
)
if print_ind == 0 and self.verbose:
row = [
q,
np.sum(R2),
np.sum(D2) / self.nu,
regularization,
np.sum(R2) + np.sum(D2) + regularization,
]
print(
"{0:10d} ... {1:10.4e} ... {2:10.4e} ... {3:10.4e}"
" ... {4:10.4e}".format(*row)
)
return 0.5 * np.sum(R2) + 0.5 * regularization + 0.5 * np.sum(D2) / self.nu

def _reduce(self, x, y):
"""
Perform at most ``self.max_iter`` iterations of the SR3 algorithm
Expand Down
14 changes: 0 additions & 14 deletions pysindy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from .base import get_prox
from .base import get_regularization
from .base import print_model
from .base import prox_cad
from .base import prox_l0
from .base import prox_l1
from .base import prox_l2
from .base import prox_weighted_l0
from .base import prox_weighted_l1
from .base import prox_weighted_l2
from .base import reorder_constraints
from .base import supports_multiple_targets
from .base import validate_control_variables
Expand Down Expand Up @@ -66,13 +59,6 @@
"get_prox",
"get_regularization",
"print_model",
"prox_cad",
"prox_l0",
"prox_weighted_l0",
"prox_l1",
"prox_weighted_l1",
"prox_l2",
"prox_weighted_l2",
"reorder_constraints",
"supports_multiple_targets",
"validate_control_variables",
Expand Down
Loading

0 comments on commit 3d791e2

Please sign in to comment.