Add flag to return trained GP from get_model_params
itskalvik committed Dec 17, 2024
Parent: 7ff3542 · Commit: 63f6591
Showing 1 changed file with 14 additions and 6 deletions.
sgptools/utils/gpflow.py
@@ -50,6 +50,7 @@ def get_model_params(X_train, y_train,
                      variance=1.0,
                      noise_variance=0.1,
                      kernel=None,
+                     return_gp=False,
                      **kwargs):
     """Train a GP on the given training set
@@ -64,29 +65,36 @@ def get_model_params(X_train, y_train,
         variance (float): Kernel variance
         noise_variance (float): Data noise variance
         kernel (gpflow.kernels.Kernel): gpflow kernel function
+        return_gp (bool): If True, returns the trained GP model
 
     Returns:
         loss (list): Loss values obtained during training
         variance (float): Optimized data noise variance
         kernel (gpflow.kernels.Kernel): Optimized gpflow kernel function
+        gp (gpflow.models.GPR): Optimized gpflow GP model.
+                                Returned only if ```return_gp=True```.
     """
     if kernel is None:
         kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
                                                    variance=variance)
 
-    gpr_gt = gpflow.models.GPR(data=(X_train, y_train),
-                               kernel=kernel,
-                               noise_variance=noise_variance)
+    gpr = gpflow.models.GPR(data=(X_train, y_train),
+                            kernel=kernel,
+                            noise_variance=noise_variance)
 
     if max_steps > 0:
-        loss = optimize_model(gpr_gt, max_steps=max_steps, lr=lr, **kwargs)
+        loss = optimize_model(gpr, max_steps=max_steps, lr=lr, **kwargs)
     else:
         loss = 0
 
     if print_params:
-        print_summary(gpr_gt)
+        print_summary(gpr)
 
-    return loss, gpr_gt.likelihood.variance, kernel
+    if return_gp:
+        return loss, gpr.likelihood.variance, kernel, gpr
+    else:
+        return loss, gpr.likelihood.variance, kernel
 
 
 class TraceInducingPts(gpflow.monitor.MonitorTask):
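Below is a minimal usage sketch (not part of the commit) showing how a caller might use the new return_gp flag. The synthetic dataset and the max_steps value are illustrative assumptions; only the return_gp flag and the return values come from the diff above.

import numpy as np
from sgptools.utils.gpflow import get_model_params

# Toy dataset: 100 points in 2D with scalar targets, shaped (n, d) and
# (n, 1) as gpflow's GPR expects.
X_train = np.random.rand(100, 2)
y_train = np.sin(X_train.sum(axis=1, keepdims=True))

# New behavior: with return_gp=True, the trained GPR model is returned
# as a fourth value alongside the loss, noise variance, and kernel.
loss, noise_variance, kernel, gp = get_model_params(X_train, y_train,
                                                    max_steps=500,
                                                    return_gp=True)

# Default behavior is unchanged: callers that omit return_gp still get
# the original three return values.
loss, noise_variance, kernel = get_model_params(X_train, y_train,
                                                max_steps=500)

Since return_gp defaults to False, the original three-value return is preserved and existing callers are unaffected.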
