From ee07caa259517c14697e285f5029c1cca076d1b9 Mon Sep 17 00:00:00 2001 From: AlexAndorra Date: Sun, 25 Feb 2024 16:26:04 -0300 Subject: [PATCH] Improve API for ProjectedProcess GP --- pymc_experimental/gp/latent_approx.py | 76 +++++++++++++++++++++------ 1 file changed, 60 insertions(+), 16 deletions(-) diff --git a/pymc_experimental/gp/latent_approx.py b/pymc_experimental/gp/latent_approx.py index f71c0f91..ddcbb845 100644 --- a/pymc_experimental/gp/latent_approx.py +++ b/pymc_experimental/gp/latent_approx.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial +from typing import Optional import numpy as np import pymc as pm @@ -31,40 +32,83 @@ class LatentApprox(pm.gp.Latent): class ProjectedProcess(pm.gp.Latent): ## AKA: DTC def __init__( - self, n_inducing, *, mean_func=pm.gp.mean.Zero(), cov_func=pm.gp.cov.Constant(0.0) + self, + n_inducing: Optional[int] = None, + *, + mean_func=pm.gp.mean.Zero(), + cov_func=pm.gp.cov.Constant(0.0), ): self.n_inducing = n_inducing super().__init__(mean_func=mean_func, cov_func=cov_func) - def _build_prior(self, name, X, Xu, jitter=JITTER_DEFAULT, **kwargs): + def _build_prior(self, name, X, X_inducing, jitter=JITTER_DEFAULT, **kwargs): mu = self.mean_func(X) - Kuu = self.cov_func(Xu) + Kuu = self.cov_func(X_inducing) L = cholesky(stabilize(Kuu, jitter)) - n_inducing_points = np.shape(Xu)[0] + n_inducing_points = np.shape(X_inducing)[0] v = pm.Normal(name + "_u_rotated_", mu=0.0, sigma=1.0, size=n_inducing_points, **kwargs) u = pm.Deterministic(name + "_u", L @ v) - Kfu = self.cov_func(X, Xu) + Kfu = self.cov_func(X, X_inducing) Kuuiu = solve_upper(pt.transpose(L), solve_lower(L, u)) return pm.Deterministic(name, mu + Kfu @ Kuuiu), Kuuiu, L - def prior(self, name, X, Xu=None, jitter=JITTER_DEFAULT, **kwargs): - if Xu is None and self.n_inducing is None: - raise ValueError - elif Xu is None: - if isinstance(X, np.ndarray): - Xu = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs) + def prior( + self, + name: str, + X: np.ndarray, + X_inducing: Optional[np.ndarray] = None, + jitter: float = JITTER_DEFAULT, + **kwargs, + ) -> np.ndarray: + """ + Builds the GP prior with optional inducing points locations. + + Parameters: + - name: Name for the GP variable. + - X: Input data. + - X_inducing: Optional. Inducing points for the GP. + - jitter: Jitter to ensure numerical stability. + + Returns: + - GP function + """ + # Check if X is a numpy array + if not isinstance(X, np.ndarray): + raise ValueError("'X' must be a numpy array.") + + # Proceed with provided X_inducing or determine X_inducing based on n_inducing + if X_inducing is not None: + pass # X_inducing is directly used + + elif self.n_inducing is not None: + # Validate n_inducing + if not isinstance(self.n_inducing, int) or self.n_inducing <= 0: + raise ValueError( + "The number of inducing points, 'n_inducing', must be a positive integer." + ) + if self.n_inducing > len(X): + raise ValueError( + "The number of inducing points, 'n_inducing', cannot be greater than the number of data points in 'X'." + ) + # Use k-means to select X_inducing from X based on n_inducing + X_inducing = pm.gp.util.kmeans_inducing_points(self.n_inducing, X, **kwargs) + else: + # Neither X_inducing nor n_inducing provided + raise ValueError( + "Either 'X_inducing' (inducing points) or 'n_inducing' (number of inducing points) must be specified." + ) - f, Kuuiu, L = self._build_prior(name, X, Xu, jitter, **kwargs) - self.X, self.Xu = X, Xu + f, Kuuiu, L = self._build_prior(name, X, X_inducing, jitter, **kwargs) + self.X, self.X_inducing = X, X_inducing self.L, self.Kuuiu = L, Kuuiu self.f = f return f - def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs): - Ksu = self.cov_func(Xnew, Xu) + def _build_conditional(self, name, Xnew, X_inducing, L, Kuuiu, jitter, **kwargs): + Ksu = self.cov_func(Xnew, X_inducing) mu = self.mean_func(Xnew) + Ksu @ Kuuiu tmp = solve_lower(L, pt.transpose(Ksu)) Qss = pt.transpose(tmp) @ tmp # Qss = tt.dot(tt.dot(Ksu, tt.nlinalg.pinv(Kuu)), Ksu.T) @@ -74,7 +118,7 @@ def _build_conditional(self, name, Xnew, Xu, L, Kuuiu, jitter, **kwargs): def conditional(self, name, Xnew, jitter=1e-6, **kwargs): mu, chol = self._build_conditional( - name, Xnew, self.Xu, self.L, self.Kuuiu, jitter, **kwargs + name, Xnew, self.X_inducing, self.L, self.Kuuiu, jitter, **kwargs ) return pm.MvNormal(name, mu=mu, chol=chol)