Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG #394 fix prox and regularization #544

Merged
merged 31 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
298807e
bug(regularization): fix calculation and confusion regarding threshold
himkwtn Aug 14, 2024
a2df39a
bug: fix calling regularization with wrong argument
himkwtn Aug 14, 2024
b436eab
CLN: linting
himkwtn Aug 14, 2024
070d0c0
ENH: unit tests for get_regularization
himkwtn Aug 14, 2024
58a0ce2
ENH: unit tests for get_prox
himkwtn Aug 16, 2024
14b564a
ENH: add typings for regularization and prox
himkwtn Aug 19, 2024
c2bb3ee
bug: fix duplicate test name
himkwtn Aug 19, 2024
3cad8f6
CLN: remove debug variable
himkwtn Aug 19, 2024
4509ea2
BUG: fix l0 reg
himkwtn Aug 20, 2024
b35f24c
ENH: improve get_regularization test cases
himkwtn Aug 20, 2024
86fef20
BUG: fix cvxpy regularization calculation
himkwtn Aug 20, 2024
9ff815a
ENH: improve get prox/reg
himkwtn Aug 23, 2024
89ead0f
ENH: remove cad
himkwtn Aug 23, 2024
658c300
ENH: create unit test for get_regularization shape validation
himkwtn Aug 23, 2024
1387b81
CLN: clean up util code
himkwtn Aug 26, 2024
ded5b29
revert test cases and add thresholds transpose
himkwtn Aug 26, 2024
d971225
DOC: constrained sr3 method update doc string
himkwtn Aug 26, 2024
de58118
CLN: fix linting
himkwtn Aug 26, 2024
ba3798e
CLN: fix linting
himkwtn Aug 26, 2024
6928455
CLN: fix linting
himkwtn Aug 26, 2024
05ee9e6
change shape validation
himkwtn Aug 28, 2024
2981f77
CLN: merge weighted and non-weighted prox/reg fn
himkwtn Aug 28, 2024
c2237bf
clean up docstring
himkwtn Aug 28, 2024
5ee3bcc
CLN: fix constrained SR3 docstring
himkwtn Aug 30, 2024
870525d
BUG: fix example for using thresholds in SR3
himkwtn Aug 30, 2024
a0475aa
fix according to comments
himkwtn Sep 3, 2024
8a2afeb
fix lint
himkwtn Sep 5, 2024
a513d76
test weighted_prox
himkwtn Sep 5, 2024
98df181
publish notebook
himkwtn Sep 5, 2024
90967e0
manually fix notebook
himkwtn Sep 9, 2024
01ce00d
refactor ConstrainedSR3._calculate_penalty
himkwtn Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/1_feature_overview/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def f(x):
model.print()

# With thresholds matrix
thresholds = 2 * np.ones((10, 3))
thresholds = 2 * np.ones((3, 10))
thresholds[4:, :] = 0.1
sr3_optimizer = ps.SR3(thresholder="weighted_l0", thresholds=thresholds)
model = ps.SINDy(optimizer=sr3_optimizer).fit(x_train, t=dt)
Expand Down
4 changes: 2 additions & 2 deletions pysindy/optimizers/constrained_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ def _calculate_penalty(
regularization: 'l0' | 'weighted_l0' | 'l1' | 'weighted_l1' |
'l2' | 'weighted_l2'
regularization_weight: float | np.array, can be a scalar
or an array of shape (n_targets, n_features)
xi: cp.Variable
or an array of shape (m, n)
xi: cp.Variable of length m*n
himkwtn marked this conversation as resolved.
Show resolved Hide resolved

Returns:
--------
Expand Down
103 changes: 42 additions & 61 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def validate_prox_and_reg_inputs(func, regularization):
def _validate_prox_and_reg_inputs(func, regularization):
def wrapper(x, regularization_weight):
if regularization[:8] == "weighted":
if not isinstance(regularization_weight, np.ndarray):
Expand All @@ -163,15 +163,11 @@ def wrapper(x, regularization_weight):
weight_shape = regularization_weight.shape
if weight_shape != x.shape:
raise ValueError(
f"Invalid shape for 'regularization_weight': \
{weight_shape}. Must be the same shape as x: {x.shape}."
f"Invalid shape for 'regularization_weight':"
f"{weight_shape}. Must be the same shape as x: {x.shape}."
)
else:
if not isinstance(regularization_weight, (int, float)) and (
isinstance(regularization_weight, np.ndarray)
and regularization_weight.shape not in [(1, 1), (1,)]
):
raise ValueError("'regularization_weight' must be a scalar")
elif not isinstance(regularization_weight, (int, float)):
raise ValueError("'regularization_weight' must be a scalar")
return func(x, regularization_weight)

return wrapper
Expand All @@ -180,7 +176,7 @@ def wrapper(x, regularization_weight):
def get_prox(
regularization: str,
) -> Callable[
[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], NDArray[np.float64]
[NDArray[np.float64], Union[float, NDArray[np.float64]]], NDArray[np.float64]
]:
"""
Args:
Expand All @@ -190,59 +186,47 @@ def get_prox(
Returns:
--------
proximal_operator: (x: np.array, reg_weight: float | np.array) -> np.array
A function that takes an input x of shape (n_targets, n_features)
A function that takes an input x of shape (m, n)
and regularization weight factor which can be a scalar or
an array of shape (n_targets, n_features),
and returns an array of shape (n_targets, n_features)
an array of shape (m, n),
and returns an array of shape (m, n)
Jacob-Stevens-Haas marked this conversation as resolved.
Show resolved Hide resolved
"""

def prox_l0(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L0 regularization."""
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)

def prox_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def prox_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
"""Proximal operator for weighted l0 regularization."""
threshold = np.sqrt(2 * regularization_weight)
return x * (np.abs(x) > threshold)

def prox_l1(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for L1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def prox_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
"""Proximal operator for weighted l1 regularization."""
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_l2(x: NDArray[np.float64], regularization_weight: np.float64):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)
return np.sign(x) * np.maximum(np.abs(x) - regularization_weight, 0)

def prox_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def prox_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
"""Proximal operator for ridge regularization."""
return x / (1 + 2 * regularization_weight)

prox = {
"l0": prox_l0,
"weighted_l0": prox_weighted_l0,
"weighted_l0": prox_l0,
"l1": prox_l1,
"weighted_l1": prox_weighted_l1,
"weighted_l1": prox_l1,
"l2": prox_l2,
"weighted_l2": prox_weighted_l2,
"weighted_l2": prox_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(prox[regularization], regularization)
return _validate_prox_and_reg_inputs(prox[regularization], regularization)


def get_regularization(
regularization: str,
) -> Callable[[NDArray[np.float64], Union[np.float64, NDArray[np.float64]]], float]:
) -> Callable[[NDArray[np.float64], Union[float, NDArray[np.float64]]], float]:
"""
Args:
-----
Expand All @@ -251,46 +235,43 @@ def get_regularization(
Returns:
--------
regularization_function: (x: np.array, reg_weight: float | np.array) -> float
A function that takes an input x of shape (n_targets, n_features)
A function that takes an input x of shape (m, n)
and regularization weight factor which can be a scalar or
an array of shape (n_targets, n_features),
an array of shape (m, n),
Jacob-Stevens-Haas marked this conversation as resolved.
Show resolved Hide resolved
and returns a float
"""

def regularization_l0(x: NDArray[np.float64], regularization_weight: np.float64):
return regularization_weight * np.count_nonzero(x)

def regualization_weighted_l0(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def regularization_l0(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
return np.sum(regularization_weight[np.nonzero(x)])

def regularization_l1(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * np.abs(x))
return np.sum(regularization_weight * (x != 0))

def regualization_weighted_l1(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def regularization_l1(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):
return np.sum(regularization_weight * np.abs(x))

def regularization_l2(x: NDArray[np.float64], regularization_weight: np.float64):
return np.sum(regularization_weight * x**2)
return np.sum(regularization_weight * np.abs(x))

def regualization_weighted_l2(
x: NDArray[np.float64], regularization_weight: NDArray[np.float64]
def regularization_l2(
x: NDArray[np.float64],
regularization_weight: Union[float, NDArray[np.float64]],
):

return np.sum(regularization_weight * x**2)

regularization_fn = {
"l0": regularization_l0,
"weighted_l0": regualization_weighted_l0,
"weighted_l0": regularization_l0,
"l1": regularization_l1,
"weighted_l1": regualization_weighted_l1,
"weighted_l1": regularization_l1,
"l2": regularization_l2,
"weighted_l2": regualization_weighted_l2,
"weighted_l2": regularization_l2,
}
regularization = regularization.lower()
return validate_prox_and_reg_inputs(
return _validate_prox_and_reg_inputs(
regularization_fn[regularization], regularization
)

Expand Down
45 changes: 17 additions & 28 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,35 +83,16 @@ def test_get_regularization(regularization, lam, expected):
assert result == expected


@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize("lam", [1, np.array([1]), np.array([[1]])])
def test_get_prox_and_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
reg_result = reg(data, lam)
prox = get_prox(regularization)
prox_result = prox(data, lam)
assert reg_result is not None
assert prox_result is not None


@pytest.mark.parametrize(
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize("lam", [np.array([[1, 2]]).T])
def test_get_weighted_prox_and_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
reg_result = reg(data, lam)
prox = get_prox(regularization)
prox_result = prox(data, lam)
assert reg_result is not None
assert prox_result is not None


@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize(
"lam", [np.array([[1, 2]]), np.array([1, 2]), np.array([[1, 2]]).T]
"lam",
[
np.array([[1, 2]]),
np.array([1, 2]),
np.array([[1, 2]]).T,
np.array([1]),
np.array([[1]]),
],
Jacob-Stevens-Haas marked this conversation as resolved.
Show resolved Hide resolved
)
def test_get_prox_and_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
himkwtn marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -127,7 +108,15 @@ def test_get_prox_and_regularization_bad_shape(regularization, lam):
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
@pytest.mark.parametrize(
"lam", [np.array([[1, 2]]), np.array([1, 2, 3]), np.array([[1, 2, 3]]).T, 1]
"lam",
[
np.array([[1, 2]]),
np.array([1, 2, 3]),
np.array([[1, 2, 3]]).T,
1,
np.array([1]),
np.array([[1]]),
],
Jacob-Stevens-Haas marked this conversation as resolved.
Show resolved Hide resolved
)
def test_get_weighted_prox_and_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
Expand Down
Loading