Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the linearized function for maximizing ellipsoid volume. #15

Merged
merged 1 commit into from
Nov 5, 2023

Conversation

hongkai-dai
Copy link
Owner

@hongkai-dai hongkai-dai commented Oct 30, 2023

This change is Reviewable

Copy link
Owner Author

@hongkai-dai hongkai-dai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+@Chuanruijiang for review please, thanks!

Reviewable status: 0 of 4 files reviewed, all discussions resolved (waiting on @Chuanruijiang)

Copy link
Collaborator

@Chuanruijiang Chuanruijiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed 2 of 4 files at r1, all commit messages.
Reviewable status: 2 of 4 files reviewed, 4 unresolved discussions (waiting on @hongkai-dai)


doc/maximize_inner_ellipsoid.md line 51 at r1 (raw file):

\log(b^TS^{-1}b/4-c) = \log\left(-\text{det}\begin{bmatrix}c & b^T/2 \\b/2 & S\end{bmatrix}\right) - \log \text{det}(S)
$$

Here I have a quick question, hwo can we make sure that the det[c, b/2; b/2, S] is always negative?


tests/test_ellipsoid_utils.py line 18 at r1 (raw file):

    """
    n = S.shape[0]
    return jnp.log(b.dot(jnp.linalg.solve(S, b)) / 4 - c) - 1.0 / n * jnp.log(

Alright, linalg.solve(S,b) returns the solution to the equation "Sx=b", and "x=S^(-1)b". So, is this means that using "linalg.solve(S,b)" instead of " linalg.inv(S) dot (b)" is faster?


tests/test_ellipsoid_utils.py line 44 at r1 (raw file):

    S_grad, b_grad, c_grad = jax.grad(eval_max_volume_cost, argnums=(0, 1, 2))(
        S_bar_jnp, b_bar_jnp, c_bar
    )

OK, I see. So we linearize the cost fucntion at the "_bar" point and evaluate them at the "_val" point.


tests/test_ellipsoid_utils.py line 73 at r1 (raw file):

    c_bar = -10.0

    cost = mut.add_max_volume_linear_cost(prog, S, b, c, S_bar, b_bar, c_bar)

OK, so in this test, we wanted to see whether the cost (8) in the .md is close to the linearization of (4), which is the original cost, in the .md near the point "(S_bar, b_bar, c_bar)". Am I correct?

@hongkai-dai hongkai-dai force-pushed the maximize_ellipsoid_volume branch from 0586195 to 1365a3b Compare November 5, 2023 02:10
Copy link
Owner Author

@hongkai-dai hongkai-dai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 2 of 4 files reviewed, 4 unresolved discussions (waiting on @Chuanruijiang)


doc/maximize_inner_ellipsoid.md line 51 at r1 (raw file):

Previously, Chuanruijiang wrote…

Here I have a quick question, hwo can we make sure that the det[c, b/2; b/2, S] is always negative?

The constraint that det [c b/2; b/2 S] < 0 is not a convex constraint. This determinant being negative, together with S is psd would mean that the set {x | x'*S*x + b'*x + c <= 0} is an ellipsoid. We impose a sufficient condition as a given point x_bar satisfies this inequality x_bar' * S * x_bar + b' * x_bar + c <= 0 (you can also prove that this condition guarantees that the determinant is negative). We mentioned this constraint below in the markdown file.


tests/test_ellipsoid_utils.py line 18 at r1 (raw file):

Previously, Chuanruijiang wrote…

Alright, linalg.solve(S,b) returns the solution to the equation "Sx=b", and "x=S^(-1)b". So, is this means that using "linalg.solve(S,b)" instead of " linalg.inv(S) dot (b)" is faster?

Yes, using linalg.solve(S, b) is generally much faster than doing the inverse linalg.inv(S). There are many good solvers to solve the linear system of equations without needing to inverting the matrix.


tests/test_ellipsoid_utils.py line 44 at r1 (raw file):

Previously, Chuanruijiang wrote…

OK, I see. So we linearize the cost fucntion at the "_bar" point and evaluate them at the "_val" point.

That is right.


tests/test_ellipsoid_utils.py line 73 at r1 (raw file):

Previously, Chuanruijiang wrote…

OK, so in this test, we wanted to see whether the cost (8) in the .md is close to the linearization of (4), which is the original cost, in the .md near the point "(S_bar, b_bar, c_bar)". Am I correct?

It is slightly different. As explained in the documentation of check_add_max_volume_linear_cost, the idea is to make sure that the linearization I did in the markdown file is correct. There are two ways to compute the linearization, either analytically (as I did in the markdown file and also implemented in add_max_volume_linear_cost()), or numerically, as done in check_add_max_volume_linear_cost, using the jax package.

I expanded the documentation in check_add_max_volume_linear_cost to make it more clear.

Copy link
Collaborator

@Chuanruijiang Chuanruijiang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:lgtm:

Reviewed 1 of 4 files at r1, 1 of 1 files at r2, all commit messages.
Reviewable status: :shipit: complete! all files reviewed, all discussions resolved (waiting on @hongkai-dai)

@hongkai-dai hongkai-dai merged commit bc99fdc into main Nov 5, 2023
@hongkai-dai hongkai-dai deleted the maximize_ellipsoid_volume branch January 3, 2024 06:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants