Skip to content

Commit

Permalink
no thing much, profiling, see #3
Browse files Browse the repository at this point in the history
  • Loading branch information
thangbui committed Mar 28, 2017
1 parent 17338c1 commit 0a515d7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 26 deletions.
27 changes: 16 additions & 11 deletions examples/gplvm_aep_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@
import numpy as np
import pdb
import matplotlib.pylab as plt
from .context import aep
from scipy import special

# from .context import aep

import sys
import os
sys.path.insert(0, os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
import geepee.aep_models as aep

np.random.seed(42)

def run_cluster():
import GPy
# create dataset
print "creating dataset..."
N = 50
N = 100
k1 = GPy.kern.RBF(5, variance=1, lengthscale=1. /
np.random.dirichlet(np.r_[10, 10, 10, 0.1, 0.1]), ARD=True)
k2 = GPy.kern.RBF(5, variance=1, lengthscale=1. /
Expand All @@ -28,7 +36,7 @@ def run_cluster():

# inference
print "inference ..."
M = 20
M = 30
D = 5
lvm = aep.SGPLVM(Y, D, M, lik='Gaussian')

Expand Down Expand Up @@ -121,11 +129,7 @@ def run_mnist():


def run_oil():

# data_path = '/Users/thangbui/Desktop/gplvm/tmp/data/'
# data_path = '../tmp/data/'
data_path = './tmp/data/'

data_path = '/scratch/tdb40/datasets/lvm/three_phase_oil_flow/'

def oil(data_set='oil'):
"""The three phase oil data from Bishop and James (1993)."""
Expand Down Expand Up @@ -389,6 +393,7 @@ def run_frey():

if __name__ == '__main__':
run_cluster()
run_semicircle()
run_pinwheel()
run_xor()
# run_semicircle()
# run_pinwheel()
# run_xor()
# run_oil()
3 changes: 2 additions & 1 deletion geepee/aep_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _forward_prop_random_thru_post_mm(self, mx, vx):
vout = psi0 + Bpsi2 - mout**2
return mout, vout

# @profile
@profile
def backprop_grads_lvm(self, m, v, dm, dv, psi1, psi2, mx, vx, alpha=1.0):
N = self.N
M = self.M
Expand Down Expand Up @@ -713,6 +713,7 @@ def __init__(self, y_train, hidden_size, no_pseudo,
self.x_post_1 = np.zeros((N, Din))
self.x_post_2 = np.zeros((N, Din))

@profile
def objective_function(self, params, idxs, alpha=1.0):
N = self.N
yb = self.y_train[idxs, :]
Expand Down
28 changes: 14 additions & 14 deletions tests/test_grads_aep.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,20 +1284,20 @@ def kink(T, process_noise, obs_noise, xprev=None):


if __name__ == '__main__':
# test_gplvm_aep_gaussian()
# test_gplvm_aep_probit()
# test_gplvm_aep_gaussian_scipy()
# test_gplvm_aep_probit_scipy()

test_gpr_aep_gaussian()
test_gpr_aep_probit()
test_gpr_aep_gaussian_scipy()
test_gpr_aep_probit_scipy()

test_dgpr_aep_gaussian()
test_dgpr_aep_probit()
test_dgpr_aep_gaussian_scipy()
test_dgpr_aep_probit_scipy()
test_gplvm_aep_gaussian()
test_gplvm_aep_probit()
test_gplvm_aep_gaussian_scipy()
test_gplvm_aep_probit_scipy()

# test_gpr_aep_gaussian()
# test_gpr_aep_probit()
# test_gpr_aep_gaussian_scipy()
# test_gpr_aep_probit_scipy()

# test_dgpr_aep_gaussian()
# test_dgpr_aep_probit()
# test_dgpr_aep_gaussian_scipy()
# test_dgpr_aep_probit_scipy()

# test_gpssm_aep_gaussian()
# np.random.seed(42)
Expand Down

0 comments on commit 0a515d7

Please sign in to comment.