diff --git a/sgptools/utils/gpflow.py b/sgptools/utils/gpflow.py
index 42995fb..98f695e 100644
--- a/sgptools/utils/gpflow.py
+++ b/sgptools/utils/gpflow.py
@@ -50,6 +50,7 @@ def get_model_params(X_train, y_train,
                      variance=1.0,
                      noise_variance=0.1,
                      kernel=None,
+                     return_gp=False,
                      **kwargs):
     """Train a GP on the given training set

@@ -64,29 +65,36 @@ def get_model_params(X_train, y_train,
         variance (float): Kernel variance
         noise_variance (float): Data noise variance
         kernel (gpflow.kernels.Kernel): gpflow kernel function
+        return_gp (bool): If True, returns the trained GP model

     Returns:
         loss (list): Loss values obtained during training
         variance (float): Optimized data noise variance
         kernel (gpflow.kernels.Kernel): Optimized gpflow kernel function
+        gp (gpflow.models.GPR): Optimized gpflow GP model.
+                                Returned only if ```return_gp=True```.
+
     """
     if kernel is None:
         kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
                                                    variance=variance)

-    gpr_gt = gpflow.models.GPR(data=(X_train, y_train),
-                               kernel=kernel,
-                               noise_variance=noise_variance)
+    gpr = gpflow.models.GPR(data=(X_train, y_train),
+                            kernel=kernel,
+                            noise_variance=noise_variance)

     if max_steps > 0:
-        loss = optimize_model(gpr_gt, max_steps=max_steps, lr=lr, **kwargs)
+        loss = optimize_model(gpr, max_steps=max_steps, lr=lr, **kwargs)
     else:
         loss = 0

     if print_params:
-        print_summary(gpr_gt)
+        print_summary(gpr)

-    return loss, gpr_gt.likelihood.variance, kernel
+    if return_gp:
+        return loss, gpr.likelihood.variance, kernel, gpr
+    else:
+        return loss, gpr.likelihood.variance, kernel


 class TraceInducingPts(gpflow.monitor.MonitorTask):
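Usage sketch for reviewers: a minimal example of the new `return_gp` flag, based only on what the patch shows. The toy dataset, the `max_steps`/`print_params` values, and the final `predict_f` call (standard gpflow `GPR` API) are illustrative, not part of this change:

```python
import numpy as np
from sgptools.utils.gpflow import get_model_params

# Toy 1D dataset (illustrative only)
X_train = np.linspace(0.0, 1.0, 50).reshape(-1, 1)
y_train = np.sin(2.0 * np.pi * X_train) + 0.1 * np.random.randn(50, 1)

# Existing callers are unaffected: return_gp defaults to False,
# preserving the original three-value return signature
loss, noise_variance, kernel = get_model_params(
    X_train, y_train, max_steps=100, print_params=False)

# With return_gp=True, the trained GPR model is returned as well,
# so it can be reused (e.g. for prediction) without refitting
loss, noise_variance, kernel, gpr = get_model_params(
    X_train, y_train, max_steps=100, print_params=False, return_gp=True)
mean, var = gpr.predict_f(X_train)  # standard gpflow GPR prediction
```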