Skip to content

Commit

Permalink
added simple save/load pickle #2
Browse files Browse the repository at this point in the history
  • Loading branch information
thangbui committed Mar 31, 2017
1 parent e56a31b commit c5ba7f3
Show file tree
Hide file tree
Showing 4 changed files with 386 additions and 215 deletions.
45 changes: 42 additions & 3 deletions examples/gpssm_hodgkin_huxley.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,21 +119,22 @@ def model_V_n():
hypers = model_aep.init_hypers(y)
for key in params.keys():
hypers[key] = params[key]
model_aep.update_hypers(hypers, alpha)
model_aep.update_hypers(hypers)
# optimise
# model_aep.set_fixed_params(['R', 'sn', 'sf'])
model_aep.set_fixed_params(['sf'])
model_aep.optimise(method='L-BFGS-B', alpha=alpha, maxiter=10000, reinit_hypers=False)
opt_hypers = model_aep.get_hypers()
plot_model(model_aep, 'AEP %.3f'%alpha)
plt.show()
model_aep.save_model('/tmp/gpssm_hh_VN.pickle')


def model_all():
# load dataset
data = np.loadtxt('./sandbox/hh_data.txt')
# use the voltage and potasisum current
y = data
y = data[:, :4]
y = y / np.std(y, axis=0)
# init hypers
Dlatent = 2
Expand All @@ -152,7 +153,7 @@ def model_all():
hypers = model_aep.init_hypers(y)
for key in params.keys():
hypers[key] = params[key]
model_aep.update_hypers(hypers, alpha)
model_aep.update_hypers(hypers)
# optimise
# model_aep.set_fixed_params(['R', 'sn', 'sf'])
model_aep.set_fixed_params(['sf'])
Expand All @@ -161,8 +162,46 @@ def model_all():
opt_hypers = model_aep.get_hypers()
plot_model(model_aep, 'AEP %.3f'%alpha)
plt.show()
model_aep.save_model('/tmp/gpssm_hh_all.pickle')


def model_all_with_control():
# TODO: predict with control
# load dataset
data = np.loadtxt('./sandbox/hh_data.txt')
# use the voltage and potasisum current
y = data[:, :4]
xc = data[:, [-1]]
y = y / np.std(y, axis=0)
# init hypers
Dlatent = 2
Dobs = y.shape[1]
M = 30
T = y.shape[0]
R = np.ones(Dobs)*np.log(0.01)/2
lsn = np.log(0.01)/2
params = {'sn': lsn, 'R': R}

alpha = 0.4
print 'alpha = %.3f' % alpha
# create AEP model
model_aep = aep.SGPSSM(y, Dlatent, M,
lik='Gaussian', prior_mean=0, prior_var=1000)
hypers = model_aep.init_hypers(y)
for key in params.keys():
hypers[key] = params[key]
model_aep.update_hypers(hypers)
# optimise
# model_aep.set_fixed_params(['R', 'sn', 'sf'])
model_aep.set_fixed_params(['sf'])
model_aep.optimise(method='L-BFGS-B', alpha=alpha, maxiter=30000, reinit_hypers=False)
# model_aep.optimise(method='adam', alpha=alpha, maxiter=10000, reinit_hypers=False, adam_lr=0.05)
opt_hypers = model_aep.get_hypers()
plot_model(model_aep, 'AEP %.3f'%alpha)
plt.show()


if __name__ == '__main__':
model_all()
# model_all_with_control()
# model_V_n()
2 changes: 1 addition & 1 deletion exps/gpssm/kink_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def kink(T, process_noise, obs_noise, xprev=None):
hypers = model.init_hypers(y_train)
for key in params.keys():
hypers[key] = params[key]
model.update_hypers(hypers, alpha)
model.update_hypers(hypers)
model.set_fixed_params(['C'])
model.optimise(method='L-BFGS-B', alpha=alpha, maxiter=np.inf, reinit_hypers=False)

Expand Down
2 changes: 1 addition & 1 deletion exps/gpssm/lin_cos_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def func(T, process_noise, obs_noise, xprev=None):
hypers = model.init_hypers(y_train)
for key in params.keys():
hypers[key] = params[key]
model.update_hypers(hypers, alpha)
model.update_hypers(hypers)
model.set_fixed_params(['C', 'R'])
model.optimise(method='L-BFGS-B', alpha=alpha, maxiter=10000, reinit_hypers=False)

Expand Down
Loading

0 comments on commit c5ba7f3

Please sign in to comment.