Sparse Gaussian processes
krasserm committed Dec 11, 2020
1 parent bffc6dc commit a3479de
Showing 5 changed files with 51,327 additions and 23 deletions.
34 changes: 17 additions & 17 deletions README.md
@@ -4,46 +4,46 @@

 This repository is a collection of notebooks about *Bayesian Machine Learning*. The following links display
 some of the notebooks via [nbviewer](https://nbviewer.jupyter.org/) to ensure a proper rendering of formulas.
+Dependencies are specified in `requirements.txt` files in subdirectories.

 - [Bayesian regression with linear basis function models](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-linear-regression/bayesian_linear_regression.ipynb).
-  Introduction to Bayesian linear regression. Implementation from scratch with plain NumPy as well as usage of scikit-learn
-  for comparison. See also
-  [PyMC4 implementation](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-linear-regression/bayesian_linear_regression_pymc4.ipynb) and
+  Introduction to Bayesian linear regression. Implementation with plain NumPy and scikit-learn. See also
   [PyMC3 implementation](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-linear-regression/bayesian_linear_regression_pymc3.ipynb).

 - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/bayesian-machine-learning/blob/dev/gaussian-processes/gaussian_processes.ipynb)
   [Gaussian processes](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/gaussian-processes/gaussian_processes.ipynb?flush_cache=true).
-  Introduction to Gaussian processes for regression. Example implementations with plain NumPy/SciPy as well as with libraries
-  scikit-learn and GPy ([requirements.txt](gaussian-processes/requirements.txt)).
+  Introduction to Gaussian processes for regression. Implementation with plain NumPy/SciPy as well as with scikit-learn and GPy.

 - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/bayesian-machine-learning/blob/dev/gaussian-processes/gaussian_processes_classification.ipynb)
   [Gaussian processes for classification](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/gaussian-processes/gaussian_processes_classification.ipynb).
-  Introduction to Gaussian processes for classification. Example implementations with plain NumPy/SciPy as well as with
-  scikit-learn ([requirements.txt](gaussian-processes/requirements.txt)).
+  Introduction to Gaussian processes for classification. Implementation with plain NumPy/SciPy as well as with scikit-learn.

+- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/bayesian-machine-learning/blob/dev/gaussian-processes/gaussian_processes_sparse.ipynb)
+  [Sparse Gaussian processes](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/gaussian-processes/gaussian_processes_sparse.ipynb).
+  Introduction to sparse Gaussian processes using a variational approach. Example implementation with JAX.
+
 - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-optimization/bayesian_optimization.ipynb)
   [Bayesian optimization](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-optimization/bayesian_optimization.ipynb).
-  Introduction to Bayesian optimization. Example implementations with plain NumPy/SciPy as well as with libraries
-  scikit-optimize and GPyOpt. Hyper-parameter tuning as application example.
+  Introduction to Bayesian optimization. Implementation with plain NumPy/SciPy as well as with libraries scikit-optimize
+  and GPyOpt. Hyper-parameter tuning as application example.

 - [Variational inference in Bayesian neural networks](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-neural-networks/bayesian_neural_networks.ipynb).
-  Demonstrates how to implement a Bayesian neural network and variational inference of network parameters. Example implementation
-  with Keras ([requirements.txt](bayesian-neural-networks/requirements.txt)). See also
-  [PyMC4 implementation](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-neural-networks/bayesian_neural_networks_pymc4.ipynb).
+  Demonstrates how to implement a Bayesian neural network and variational inference of weights. Example implementation
+  with Keras.

 - [Reliable uncertainty estimates for neural network predictions](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/noise-contrastive-priors/ncp.ipynb).
-  Uses noise contrastive priors in Bayesian neural networks to get more reliable uncertainty estimates for OOD data.
-  Implemented with Tensorflow 2 and Tensorflow Probability ([requirements.txt](noise-contrastive-priors/requirements.txt)).
+  Uses noise contrastive priors for Bayesian neural networks to get more reliable uncertainty estimates for OOD data.
+  Implemented with Tensorflow 2 and Tensorflow Probability.

 - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/bayesian-machine-learning/blob/dev/latent-variable-models/latent_variable_models_part_1.ipynb)
   [Latent variable models, part 1: Gaussian mixture models and the EM algorithm](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/latent-variable-models/latent_variable_models_part_1.ipynb).
-  Introduction to the expectation maximization (EM) algorithm and its application to Gaussian mixture models. Example
-  implementation with plain NumPy/SciPy and scikit-learn for comparison. See also
+  Introduction to the expectation maximization (EM) algorithm and its application to Gaussian mixture models.
+  Implementation with plain NumPy/SciPy and scikit-learn. See also
   [PyMC3 implementation](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/latent-variable-models/latent_variable_models_part_1_pymc3.ipynb).

 - [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/bayesian-machine-learning/blob/dev/latent-variable-models/latent_variable_models_part_2.ipynb)
   [Latent variable models, part 2: Stochastic variational inference and variational autoencoders](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/latent-variable-models/latent_variable_models_part_2.ipynb).
-  Introduction to stochastic variational inference with variational autoencoder as application example. Implementation
+  Introduction to stochastic variational inference with a variational autoencoder as application example. Implementation
   with Tensorflow 2.x.

 - [Deep feature consistent variational autoencoder](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/autoencoder-applications/variational_autoencoder_dfc.ipynb).
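Aside for readers of this commit: the headline change is the new sparse GP entry in the list above. If the "variational approach" it mentions is the collapsed bound of Titsias (2009), a common reading but an assumption here since the new notebook's diff is not rendered on this page, then with m inducing inputs the objective maximized is

$$
\mathcal{L} = \log \mathcal{N}\left(\mathbf{y} \mid \mathbf{0},\, \mathbf{Q}_{nn} + \sigma_y^2\mathbf{I}\right) - \frac{1}{2\sigma_y^2}\,\mathrm{tr}\left(\mathbf{K}_{nn} - \mathbf{Q}_{nn}\right),
\qquad
\mathbf{Q}_{nn} = \mathbf{K}_{nm}\mathbf{K}_{mm}^{-1}\mathbf{K}_{mn},
$$

a lower bound on the log marginal likelihood that costs O(nm²) to evaluate instead of the O(n³) of exact GP regression.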
12 changes: 7 additions & 5 deletions gaussian-processes/gaussian_processes.ipynb
@@ -400,7 +400,8 @@
    }
   ],
   "source": [
-    "from numpy.linalg import cholesky, det, lstsq\n",
+    "from numpy.linalg import cholesky, det\n",
+    "from scipy.linalg import solve_triangular\n",
     "from scipy.optimize import minimize\n",
     "\n",
     "def nll_fn(X_train, Y_train, noise, naive=True):\n",
@@ -437,14 +438,15 @@
     "        # in http://www.gaussianprocess.org/gpml/chapters/RW2.pdf, Section\n",
     "        # 2.2, Algorithm 2.1.\n",
     "        \n",
-    "        def ls(a, b):\n",
-    "            return lstsq(a, b, rcond=-1)[0]\n",
-    "        \n",
     "        K = kernel(X_train, X_train, l=theta[0], sigma_f=theta[1]) + \\\n",
     "            noise**2 * np.eye(len(X_train))\n",
     "        L = cholesky(K)\n",
+    "        \n",
+    "        S1 = solve_triangular(L, Y_train, lower=True)\n",
+    "        S2 = solve_triangular(L.T, S1, lower=False)\n",
+    "        \n",
     "        return np.sum(np.log(np.diagonal(L))) + \\\n",
-    "               0.5 * Y_train.dot(ls(L.T, ls(L, Y_train))) + \\\n",
+    "               0.5 * Y_train.dot(S2) + \\\n",
     "               0.5 * len(X_train) * np.log(2*np.pi)\n",
     "\n",
     "    if naive:\n",
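For readers skimming the diff: the change above replaces two general `lstsq` calls with a forward and a backward substitution on the Cholesky factor, following Algorithm 2.1 in Rasmussen & Williams (GPML, Section 2.2), which the notebook's own comment already cites. A minimal standalone sketch of the resulting computation (function and variable names, kernel, and data are illustrative placeholders, not taken from the notebook):

```python
import numpy as np
from numpy.linalg import cholesky
from scipy.linalg import solve_triangular

def nll_stable(K, y):
    # Negative log marginal likelihood of y ~ N(0, K).
    # With K = L L^T, log|K| = 2 * sum(log(diag(L))), and K^{-1} y is
    # obtained with two triangular solves (GPML, Algorithm 2.1).
    L = cholesky(K)                              # lower-triangular factor
    s1 = solve_triangular(L, y, lower=True)      # solves L s1 = y
    s2 = solve_triangular(L.T, s1, lower=False)  # solves L^T s2 = s1, so s2 = K^{-1} y
    return (np.sum(np.log(np.diagonal(L)))       # 0.5 * log|K|
            + 0.5 * y.dot(s2)                    # 0.5 * y^T K^{-1} y
            + 0.5 * len(y) * np.log(2 * np.pi))

# Illustrative usage with a small RBF kernel matrix plus noise on the diagonal:
X = np.linspace(-1.0, 1.0, 5)
K = np.exp(-0.5 * (X[:, None] - X[None, :])**2) + 0.1**2 * np.eye(5)
print(nll_stable(K, np.sin(X)))
```

Each triangular solve is O(n²) and exploits the structure of L, whereas `lstsq` re-factorizes its argument (via an SVD) on every call; the result is identical up to floating-point error.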
(Diffs for the remaining 3 changed files are not shown.)
