Skip to content

Commit

Permalink
Improve API for ProjectedProcess GP
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexAndorra authored and ricardoV94 committed Mar 28, 2024
1 parent 47c4b48 commit ee07caa
Showing 1 changed file with 60 additions and 16 deletions.
76 changes: 60 additions & 16 deletions pymc_experimental/gp/latent_approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional

import numpy as np
import pymc as pm
Expand All @@ -31,40 +32,83 @@ class LatentApprox(pm.gp.Latent):
class ProjectedProcess(pm.gp.Latent):
## AKA: DTC
def __init__(
self, n_inducing, *, mean_func=pm.gp.mean.Zero(), cov_func=pm.gp.cov.Constant(0.0)
self,
n_inducing: Optional[int] = None,
*,
mean_func=pm.gp.mean.Zero(),
cov_func=pm.gp.cov.Constant(0.0),
):
self.n_inducing = n_inducing
super().__init__(mean_func=mean_func, cov_func=cov_func)

def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs):
def _build_prior(self, name, X, X_inducing, jitter=JITTER_DEFAULT, **kwargs):
mu = self.mean_func(X)
Kuu = self.cov_func(Xu)
Kuu = self.cov_func(X_inducing)
L = cholesky(stabilize(Kuu, jitter))

n_inducing_points = np.shape(Xu)[0]
n_inducing_points = np.shape(X_inducing)[0]
v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs)
u = pm.Deterministic(name + "_u", L @ v)

Kfu = self.cov_func(X, Xu)
Kfu = self.cov_func(X, X_inducing)
Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u))

return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L

def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs):
if Xu is None and self.n_inducing is None:
raise ValueError
elif Xu is None:
if isinstance(X, np.ndarray):
Xu = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs)
def prior(
self,
name: str,
X: np.ndarray,
X_inducing: Optional[np.ndarray] = None,
jitter: float = JITTER_DEFAULT,
**kwargs,
) -> np.ndarray:
"""
Builds the GP prior with optional inducing points locations.
Parameters:
- name: Name for the GP variable.
- X: Input data.
- X_inducing: Optional. Inducing points for the GP.
- jitter: Jitter to ensure numerical stability.
Returns:
- GP function
"""
# Check if X is a numpy array
if not isinstance(X, np.ndarray):
raise ValueError("'X' must be a numpy array.")

# Proceed with provided X_inducing or determine X_inducing based on n_inducing
if X_inducing is not None:
pass # X_inducing is directly used

elif self.n_inducing is not None:
# Validate n_inducing
if not isinstance(self.n_inducing, int) or self.n_inducing <= 0:
raise ValueError(
"The number of inducing points, 'n_inducing', must be a positive integer."
)
if self.n_inducing > len(X):
raise ValueError(
"The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'."
)
# Use k-means to select X_inducing from X based on n_inducing
X_inducing = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs)
else:
# Neither X_inducing nor n_inducing provided
raise ValueError(
"Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified."
)

f, Kuuiu, L = self._build_prior(name, X, Xu, jitter, **kwargs)
self.X, self.Xu = X, Xu
f, Kuuiu, L = self._build_prior(name, X, X_inducing, jitter, **kwargs)
self.X, self.X_inducing = X, X_inducing
self.L, self.Kuuiu = L, Kuuiu
self.f = f
return f

def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):
Ksu = self.cov_func(Xnew, Xu)
def _build_conditional(self, name, Xnew, X_inducing, L, Kuuiu, jitter, **kwargs):
Ksu = self.cov_func(Xnew, X_inducing)
mu = self.mean_func(Xnew) + Ksu @ Kuuiu
tmp = solve_lower(L, pt.transpose(Ksu))
Qss = pt.transpose(tmp) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T)
Expand All @@ -74,7 +118,7 @@ def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs):

def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
mu, chol = self._build_conditional(
name, Xnew, self.Xu, self.L, self.Kuuiu, jitter, **kwargs
name, Xnew, self.X_inducing, self.L, self.Kuuiu, jitter, **kwargs
)
return pm.MvNormal(name, mu=mu, chol=chol)

Expand Down

0 comments on commit ee07caa

Please sign in to comment.