diff --git a/pylearn2/datasets/mnist_augmented.py b/pylearn2/datasets/mnist_augmented.py new file mode 100644 index 0000000000..cba3c07168 --- /dev/null +++ b/pylearn2/datasets/mnist_augmented.py @@ -0,0 +1,80 @@ +""" +Augmented MNIST wrapper class +""" + +import os +import numpy as np + +from pylearn2.datasets.dense_design_matrix import DenseDesignMatrix +from pylearn2.scripts.dbm.augment_input import augment_input +from pylearn2.utils import serial + + +class MNIST_AUGMENTED(DenseDesignMatrix): + + """ + Loads MNIST dataset and builds augmented dataset + for DBM discriminative finetuning. + + Parameters + ---------- + dataset : `pylearn2.datasets.dataset.Dataset` + which_set : str + Select between training and test set. + model : `pylearn2.models.model.Model` + The DBM to be finetuned. + mf_steps : int + Number of mean field updates for data augmentation. + one_hot : bool, optional + Enable or disable one-hot configuration for + label matrix. + start : int, optional + First index of dataset to be finetuned. + stop : int, optional + Last index of dataset to be finetuned. + save_aug : bool, optional + Select whether to save the augmented dataset + in a pkl file or not. + """ + + def __init__(self, dataset, which_set, model, mf_steps, one_hot=True, + start=None, stop=None, save_aug=False): + + self.path = os.path.join('${PYLEARN2_DATA_PATH}', 'mnist') + self.path = serial.preprocess(self.path) + + try: + if which_set == 'train': + path = os.path.join(self.path, 'aug_train_dump.pkl.gz') + datasets = serial.load(filepath=path) + augmented_X, y = datasets[0], datasets[1] + else: + path = os.path.join(self.path, 'aug_test_dump.pkl.gz') + datasets = serial.load(filepath=path) + augmented_X, y = datasets[0], datasets[1] + augmented_X, y = augmented_X[start:stop], y[start:stop] + except: + X = dataset.X + if one_hot is True: + one_hot = np.zeros((dataset.y.shape[0], 10), dtype='float32') + for i in range(dataset.y.shape[0]): + label = dataset.y[i] + one_hot[i, label] = 1. + y = one_hot + else: + y = dataset.y + + # BUILD AUGMENTED INPUT FOR FINETUNING + X, y = X[start:stop], y[start:stop] + augmented_X = augment_input(X, model, mf_steps) + + if save_aug is True: + datasets = augmented_X, y + if which_set == 'train': + path = os.path.join(self.path, 'aug_train_dump.pkl.gz') + serial.save(filepath=path, obj=datasets) + else: + path = os.path.join(self.path, 'aug_test_dump.pkl.gz') + serial.save(filepath=path, obj=datasets) + + super(MNIST_AUGMENTED, self).__init__(X=augmented_X, y=y) diff --git a/pylearn2/scripts/dbm/augment_input.py b/pylearn2/scripts/dbm/augment_input.py new file mode 100644 index 0000000000..a2c52dd2b9 --- /dev/null +++ b/pylearn2/scripts/dbm/augment_input.py @@ -0,0 +1,68 @@ +""" +This module augments the dataset in order to make it suitable for +DBM discriminative finetuning. +For each example in the dataset, using the provided trained DBM, +it performs n mean-field updates initializing the state of the second +hidden layer of the DBM and augments the example with this state. +It returns a dataset where each example is composed of its previous +value concatenated with the respective initialization of the second +hidden layer of the DBM. +""" + +from pylearn2.utils import sharedX +from theano import function +import numpy + + +def augment_input(X, model, mf_steps): + + """ + Input augmentation script. + + Parameters + ---------- + X : ndarray, 2-dimensional + A matrix containing the initial dataset. + model : DBM + The DBM model to be finetuned. 
It is used for
+        mean field updates.
+    mf_steps : int
+        The number of mean field updates.
+
+    Returns
+    -------
+    final_data : ndarray, 2-dimensional
+        The final augmented dataset.
+
+    References
+    ----------
+    Salakhutdinov, Ruslan and Hinton, Geoffrey. "An efficient
+    learning procedure for deep Boltzmann machines".
+    Neural Computation, 2012.
+    """
+
+    print("\nAugmenting data...\n")
+
+    i = 0
+    init_data = model.visible_layer.space.get_origin_batch(batch_size=1,
+                                                           dtype='float32')
+
+    for x in X[:]:
+        init_data[0] = x
+        data = sharedX(init_data, name='v')
+        # mean field inference of the second hidden layer
+        # (niter: number of mean field updates)
+        marginal_posterior = model.mf(V=data, niter=mf_steps)[1]
+        mp = function([], marginal_posterior)
+        mp = mp()[0][0]
+        # each augmented example is the h2 posterior concatenated with
+        # the original example
+        if i == 0:
+            final_data = numpy.asarray([numpy.concatenate((mp, x))])
+        else:
+            final_data = numpy.append(final_data,
+                                      [numpy.concatenate((mp, x))],
+                                      axis=0)
+
+        i += 1
+
+    print("Data augmentation complete!")
+
+    return final_data
diff --git a/pylearn2/scripts/papers/dbm/README b/pylearn2/scripts/papers/dbm/README
new file mode 100644
index 0000000000..677c92f903
--- /dev/null
+++ b/pylearn2/scripts/papers/dbm/README
@@ -0,0 +1,25 @@
+The files in this directory recreate the experiment reported in the
+paper
+
+An efficient learning procedure for Deep Boltzmann Machines. R. Salakhutdinov and G. Hinton.
+
+The procedure is divided into three phases: pretraining of the RBMs, training of the DBM, and
+finetuning. The test_dbm_mnist script lets you enable each phase of training separately, select
+whether the DBM has a softmax layer on top, and select whether the MLP finetuning uses dropout.
+This implementation only works for DBMs with 2 hidden layers: stacking RBMs to compose a deeper
+DBM requires changes to the Contrastive Divergence algorithm that have not been implemented here.
+However, it has been shown that using more than 2 hidden layers in a DBM is not guaranteed to
+improve performance.
+
+As explained in the paper, the finetuning procedure feeds the MLP an augmented input. This
+implementation builds it with augment_input.py and with mnist_augmented.py in pylearn2/datasets/.
+The latter takes the MNIST dataset and augments it, then saves the augmented dataset to .pkl
+files, because data augmentation is a time-consuming operation.
+
+There are two tests in /tests. The script that runs the whole procedure, with all the right
+parameters, reaches the result published by Hinton & Salakhutdinov. The fast version is suitable
+for running on Travis; it does not perform well because it uses a very small training set and a
+very small number of epochs.
+
+NO DROPOUT RESULTS:
+The test reaches a 0.94% test error without a softmax layer on top of the DBM and without dropout.
+DROPOUT RESULTS:
+The test reaches a 0.84% test error with a softmax layer on top of the DBM and with dropout.
+
+Experiments were performed on Ubuntu 14.04 LTS with an NVIDIA Tesla C1060 GPU and an 8-core
+Intel(R) Core(TM) i7 CPU 920 @ 2.67GHz, using openblas-base, numpy 1.9.0, scipy 0.13.3,
+theano 0.6.0 and a pylearn2 checkout at 6264 commits.
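+
+USAGE SKETCH (illustrative, not part of the test scripts):
+The augmentation step can also be driven directly from Python. The snippet below is a rough
+sketch only; the pickle path and the start/stop values are placeholders, and it assumes a DBM
+has already been trained with dbm_mnist.yaml.
+
+    from pylearn2.utils import serial
+    from pylearn2.datasets.mnist import MNIST
+    from pylearn2.datasets.mnist_augmented import MNIST_AUGMENTED
+
+    # DBM produced by dbm_mnist.yaml (placeholder path)
+    dbm = serial.load('dbm_mnist.pkl')
+    raw = MNIST(which_set='train', start=0, stop=1000)
+
+    # Each example is extended with the mean-field posterior of the DBM's
+    # second hidden layer (a single mean field step here).
+    aug = MNIST_AUGMENTED(dataset=raw, which_set='train', model=dbm,
+                          mf_steps=1, start=0, stop=1000)
+    print(aug.X.shape)  # (1000, n_h2 + 784): h2 posterior + raw pixels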
+ diff --git a/pylearn2/scripts/papers/dbm/dbm_mnist.yaml b/pylearn2/scripts/papers/dbm/dbm_mnist.yaml new file mode 100644 index 0000000000..9e402b610b --- /dev/null +++ b/pylearn2/scripts/papers/dbm/dbm_mnist.yaml @@ -0,0 +1,71 @@ +!obj:pylearn2.train.Train { + dataset: &data !obj:pylearn2.datasets.binarizer.Binarizer { + raw: &raw_train !obj:pylearn2.datasets.mnist.MNIST { + which_set: "train", + start: 0, + stop: %(train_stop)i + } + }, + model: !obj:pylearn2.models.dbm.DBM { + batch_size: %(batch_size)i, + niter: 10, + inference_procedure: !obj:pylearn2.models.dbm.WeightDoubling {}, + visible_layer: !obj:pylearn2.models.dbm.BinaryVector { + nvis: 784, + }, + hidden_layers: [ + !obj:pylearn2.models.dbm.BinaryVectorMaxPool { + layer_name: 'h1', + detector_layer_dim: %(n_h1)i, + pool_size: 1, + irange: 0.001, + }, + !obj:pylearn2.models.dbm.BinaryVectorMaxPool { + layer_name: 'h2', + detector_layer_dim: %(n_h2)i, + pool_size: 1, + irange: 0.001, + }, + ] + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + learning_rate: 0.005, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: 0.5, + }, + monitoring_batches: %(monitoring_batches)i, + monitoring_dataset: *data, + cost : !obj:pylearn2.costs.cost.SumOfCosts { + costs: [ + !obj:pylearn2.costs.dbm.VariationalPCD { + num_chains: 100, + num_gibbs_steps: 5, + }, + !obj:pylearn2.costs.dbm.WeightDecay { + coeffs: [ .0002, .0002 ], + }, + !obj:pylearn2.costs.dbm.TorontoSparsity { + targets: [ .2, .1 ], + coeffs: [ .001, .001 ], + } + ] + }, + termination_criterion: !obj:pylearn2.termination_criteria.EpochCounter { + max_epochs: %(max_epochs)i + }, + update_callbacks: [ + !obj:pylearn2.training_algorithms.sgd.CustomizedLROverEpoch { + } + ] + }, + extensions: [ + !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { + final_momentum: 0.9, + start: 1, + saturate: 6, + }, + ], + save_path: "%(save_path)s/dbm_mnist.pkl", + save_freq : %(max_epochs)i +} + diff --git a/pylearn2/scripts/papers/dbm/dbm_mnist_l1.yaml b/pylearn2/scripts/papers/dbm/dbm_mnist_l1.yaml new file mode 100644 index 0000000000..95bbd04c85 --- /dev/null +++ b/pylearn2/scripts/papers/dbm/dbm_mnist_l1.yaml @@ -0,0 +1,39 @@ +!obj:pylearn2.train.Train { + dataset: &data !obj:pylearn2.datasets.binarizer.Binarizer { + raw: &raw_train !obj:pylearn2.datasets.mnist.MNIST { + which_set: "train", + start: 0, + stop: %(train_stop)i + } + }, + model: !obj:pylearn2.models.rbm.RBM { + nvis : 784, + nhid : %(nhid)i, + irange : 0.001, + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + learning_rate : 0.05, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: 0.5, + }, + batch_size : %(batch_size)i, + monitoring_batches : %(monitoring_batches)i, + monitoring_dataset : *data, + cost: !obj:pylearn2.costs.ebm_estimation.CDk { + nsteps : 1, + }, + termination_criterion : !obj:pylearn2.termination_criteria.EpochCounter { + max_epochs: %(max_epochs)i, + }, + }, + extensions: [ + !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { + start: 1, + saturate: 6, + final_momentum: 0.9, + }, + ], + save_path: "%(save_path)s/dbm_mnist_l1.pkl", + save_freq: %(max_epochs)i +} + diff --git a/pylearn2/scripts/papers/dbm/dbm_mnist_l2.yaml b/pylearn2/scripts/papers/dbm/dbm_mnist_l2.yaml new file mode 100644 index 0000000000..76d765e78c --- /dev/null +++ b/pylearn2/scripts/papers/dbm/dbm_mnist_l2.yaml @@ -0,0 +1,42 @@ +!obj:pylearn2.train.Train { + dataset: &train 
!obj:pylearn2.datasets.binarizer.Binarizer { + raw: !obj:pylearn2.datasets.transformer_dataset.TransformerDataset { + raw: !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'train', + start: 0, + stop: %(train_stop)i + }, + transformer: !pkl: "%(save_path)s/dbm_mnist_l1.pkl" + }, + }, + model: !obj:pylearn2.models.rbm.RBM { + nvis : %(nvis)i, + nhid : %(nhid)i, + irange : 0.01, + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + learning_rate : 0.05, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: 0.5, + }, + batch_size : %(batch_size)i, + monitoring_batches : %(monitoring_batches)i, + monitoring_dataset : *train, + cost : !obj:pylearn2.costs.ebm_estimation.CDk { + nsteps : 5, + }, + termination_criterion : !obj:pylearn2.termination_criteria.EpochCounter { + max_epochs: %(max_epochs)i, + }, + }, + extensions: [ + !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { + start: 1, + saturate: 6, + final_momentum: 0.9, + }, + ], + save_path: "%(save_path)s/dbm_mnist_l2.pkl", + save_freq: %(max_epochs)i +} + diff --git a/pylearn2/scripts/papers/dbm/dbm_mnist_mlp.yaml b/pylearn2/scripts/papers/dbm/dbm_mnist_mlp.yaml new file mode 100644 index 0000000000..e467a9127b --- /dev/null +++ b/pylearn2/scripts/papers/dbm/dbm_mnist_mlp.yaml @@ -0,0 +1,54 @@ +!obj:pylearn2.train.Train { + dataset: &train !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'train', + start: 0, + stop: %(train_stop)i + }, + model: !obj:pylearn2.models.mlp.MLP { + batch_size: %(batch_size)i, + layers: [ + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'h0', + dim: %(n_h0)i, + sparse_init: 15, + }, + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'h1', + dim: %(n_h1)i, + sparse_init: 15, + }, + !obj:pylearn2.models.mlp.Softmax { + layer_name: 'y', + n_classes: 10, + irange: 0.05 + } + ], + nvis: %(nvis)i, + }, + algorithm: !obj:pylearn2.training_algorithms.bgd.BGD { + conjugate: 1, + line_search_mode: 'exhaustive', + updates_per_batch: 6, + monitoring_dataset: { + 'test': !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'test', + }, + }, + cost: !obj:pylearn2.costs.mlp.Default {}, + termination_criterion: !obj:pylearn2.termination_criteria.And { + criteria: [ + !obj:pylearn2.termination_criteria.EpochCounter { + max_epochs: %(max_epochs)i + } + ] + }, + }, + extensions: [ + !obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { + channel_name: 'test_y_misclass', + save_path: "%(save_path)s/dbm_mnist_mlp.pkl", + store_best_model: True + }, + ] +} + diff --git a/pylearn2/scripts/papers/dbm/dbm_mnist_mlp_dropout.yaml b/pylearn2/scripts/papers/dbm/dbm_mnist_mlp_dropout.yaml new file mode 100644 index 0000000000..1ac524dd21 --- /dev/null +++ b/pylearn2/scripts/papers/dbm/dbm_mnist_mlp_dropout.yaml @@ -0,0 +1,62 @@ +!obj:pylearn2.train.Train { + dataset: !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'train', + start: 0, + stop: %(train_stop)i, + }, + model: !obj:pylearn2.models.mlp.MLP { + batch_size: %(batch_size)i, + layers: [ + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'h0', + dim: %(n_h0)i, + sparse_init: 15, + }, + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'h1', + dim: %(n_h1)i, + sparse_init: 15, + }, + !obj:pylearn2.models.mlp.Softmax { + layer_name: 'y', + n_classes: 10, + istdev: 0.01, + } + ], + nvis: %(nvis)i, + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + learning_rate: 1, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: 0.5, + }, + monitoring_dataset: { + 
'test': !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'test', + }, + }, + cost: !obj:pylearn2.costs.mlp.dropout.Dropout { + input_include_probs: {'h0': .8, 'h1': 0.5, 'y': 0.5}, + }, + termination_criterion: !obj:pylearn2.termination_criteria.And { + criteria: [ + !obj:pylearn2.termination_criteria.EpochCounter { + max_epochs: %(max_epochs)i + } + ] + } + }, + extensions: [ + !obj:pylearn2.training_algorithms.sgd.MomentumAdjustor { + start: 1, + saturate: 500, + final_momentum: 0.99, + }, + !obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { + channel_name: 'test_y_misclass', + save_path: "%(save_path)s/dbm_mnist_mlp_dropout.pkl", + store_best_model: True + }, + ] +} + diff --git a/pylearn2/scripts/papers/dbm/dbm_mnist_softmax.yaml b/pylearn2/scripts/papers/dbm/dbm_mnist_softmax.yaml new file mode 100644 index 0000000000..eb46913176 --- /dev/null +++ b/pylearn2/scripts/papers/dbm/dbm_mnist_softmax.yaml @@ -0,0 +1,84 @@ +!obj:pylearn2.train.Train { + dataset: &data !obj:pylearn2.datasets.binarizer.Binarizer { + raw: &raw_train !obj:pylearn2.datasets.mnist.MNIST { + which_set: "train", + start: 0, + stop: %(train_stop)i + } + }, + model: !obj:pylearn2.models.dbm.DBM { + batch_size: %(batch_size)i, + niter: 10, + inference_procedure: !obj:pylearn2.models.dbm.WeightDoubling {}, + visible_layer: !obj:pylearn2.models.dbm.BinaryVector { + nvis: 784, + }, + hidden_layers: [ + !obj:pylearn2.models.dbm.BinaryVectorMaxPool { + layer_name: 'h1', + detector_layer_dim: %(n_h1)i, + pool_size: 1, + irange: 0.001, + }, + !obj:pylearn2.models.dbm.BinaryVectorMaxPool { + layer_name: 'h2', + detector_layer_dim: %(n_h2)i, + pool_size: 1, + irange: 0.001, + }, + !obj:pylearn2.models.dbm.Softmax { + layer_name: 'y', + irange: 0.001, + n_classes: 10 + } + ] + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + learning_rate: 0.005, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: 0.5, + }, + monitoring_batches: %(monitoring_batches)i, + monitoring_dataset: { + 'test': + !obj:pylearn2.datasets.binarizer.Binarizer { + raw: !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'test', + }, + } + }, + cost : !obj:pylearn2.costs.cost.SumOfCosts { + costs: [ + !obj:pylearn2.costs.dbm.VariationalPCD { + num_chains: 100, + num_gibbs_steps: 5, + supervised: 1, + toronto_neg: 1 + }, + !obj:pylearn2.costs.dbm.WeightDecay { + coeffs: [ .0002, .0002, 0.0002 ], + }, + ] + }, + termination_criterion: !obj:pylearn2.termination_criteria.EpochCounter { + max_epochs: %(max_epochs)i + }, + update_callbacks: [ + !obj:pylearn2.training_algorithms.sgd.CustomizedLROverEpoch { + } + ] + }, + extensions: [ + !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { + final_momentum: 0.9, + start: 1, + saturate: 6, + }, + !obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { + channel_name: 'test_misclass', + save_path: "%(save_path)s/dbm_mnist_softmax.pkl", + store_best_model: True + }, + ], +} + diff --git a/pylearn2/scripts/papers/dbm/tests/test_dbm_mnist.py b/pylearn2/scripts/papers/dbm/tests/test_dbm_mnist.py new file mode 100644 index 0000000000..204014406d --- /dev/null +++ b/pylearn2/scripts/papers/dbm/tests/test_dbm_mnist.py @@ -0,0 +1,247 @@ +""" +This is the test version that achieves the +Hinton & Salakhutdinov's results. 
+""" + +__authors__ = "Carlo D'Eramo, Francesco Visin, Matteo Matteucci" +__copyright__ = "Copyright 2014-2015, Politecnico di Milano" +__credits__ = ["Carlo D'Eramo, Francesco Visin, Matteo Matteucci"] +__license__ = "3-clause BSD" +__maintainer__ = "AIR-lab" +__email__ = "carlo.deramo@mail.polimi.it, francesco.visin@polimi.it, \ +matteo.matteucci@polimi.it" + +import os +import numpy + +from pylearn2.utils import serial +from pylearn2.config import yaml_parse +from pylearn2.testing import no_debug_mode +from pylearn2.datasets import mnist_augmented +from theano import function + +# PARAMETERS +N_HIDDEN_0 = 500 +N_HIDDEN_1 = 1000 + +PRETRAINING = 1 +TRAINING = 1 +FINETUNING = 1 + +# PRETRAINING +MAX_EPOCHS_L1 = 100 +MAX_EPOCHS_L2 = 200 + +# TRAINING +MAX_EPOCHS_DBM = 500 +SOFTMAX = 0 + +# FINETUNING +MAX_EPOCHS_MLP = 500 +DROPOUT = 0 +MF_STEPS = 1 # mf_steps for data augmentation + + +@no_debug_mode +def test_train_example(): + + """ + Test script. + + Parameters + ---------- + WRITEME + """ + + # path definition + cwd = os.getcwd() + train_path = cwd # training path is the current working directory + try: + os.chdir(train_path) + + # START PRETRAINING + # load and train first layer + train_yaml_path = os.path.join(train_path, 'dbm_mnist_l1.yaml') + layer1_yaml = open(train_yaml_path, 'r').read() + hyper_params_l1 = {'train_stop': 60000, + 'batch_size': 100, + 'monitoring_batches': 5, + 'nhid': N_HIDDEN_0, + 'max_epochs': MAX_EPOCHS_L1, + 'save_path': train_path + } + + if PRETRAINING: + + layer1_yaml = layer1_yaml % (hyper_params_l1) + train = yaml_parse.load(layer1_yaml) + + print("\n-----------------------------------" + " Unsupervised pre-training " + "-----------------------------------\n") + + print("\nPre-Training first layer...\n") + train.main_loop() + + # load and train second layer + train_yaml_path = os.path.join(train_path, 'dbm_mnist_l2.yaml') + layer2_yaml = open(train_yaml_path, 'r').read() + hyper_params_l2 = {'train_stop': 60000, + 'batch_size': 100, + 'monitoring_batches': 5, + 'nvis': hyper_params_l1['nhid'], + 'nhid': N_HIDDEN_1, + 'max_epochs': MAX_EPOCHS_L2, + 'save_path': train_path + } + + if PRETRAINING: + + layer2_yaml = layer2_yaml % (hyper_params_l2) + train = yaml_parse.load(layer2_yaml) + + print("\n...Pre-training second layer...\n") + train.main_loop() + + if TRAINING: + + # START TRAINING + if SOFTMAX: + train_yaml_path = os.path.join(train_path, + 'dbm_mnist_softmax.yaml') + else: + train_yaml_path = os.path.join(train_path, 'dbm_mnist.yaml') + yaml = open(train_yaml_path, 'r').read() + hyper_params_dbm = {'train_stop': 60000, + 'valid_stop': 60000, + 'batch_size': 100, + 'n_h1': hyper_params_l1['nhid'], + 'n_h2': hyper_params_l2['nhid'], + 'monitoring_batches': 5, + 'max_epochs': MAX_EPOCHS_DBM, + 'save_path': train_path + } + + yaml = yaml % (hyper_params_dbm) + train = yaml_parse.load(yaml) + + rbm1 = serial.load(os.path.join(train_path, 'dbm_mnist_l1.pkl')) + rbm2 = serial.load(os.path.join(train_path, 'dbm_mnist_l2.pkl')) + pretrained_rbms = [rbm1, rbm2] + + # clamp pretrained weights into respective dbm layers + for h, l in zip(train.model.hidden_layers, pretrained_rbms): + h.set_weights(l.get_weights()) + + # clamp pretrained biases into respective dbm layers + bias_param = pretrained_rbms[0].get_params()[1] + fun = function([], bias_param) + cuda_bias = fun() + bias = numpy.asarray(cuda_bias) + train.model.visible_layer.set_biases(bias) + bias_param = pretrained_rbms[-1].get_params()[1] + fun = function([], bias_param) + cuda_bias = fun() + bias 
= numpy.asarray(cuda_bias)
+            train.model.hidden_layers[0].set_biases(bias)
+            bias_param = pretrained_rbms[-1].get_params()[2]
+            fun = function([], bias_param)
+            cuda_bias = fun()
+            bias = numpy.asarray(cuda_bias)
+            train.model.hidden_layers[1].set_biases(bias)
+
+            print("\nAll layers weights and biases have been clamped "
+                  "to the respective layers of the DBM")
+
+            print("\n-----------------------------------"
+                  " Unsupervised training "
+                  "-----------------------------------\n")
+
+            print("\nTraining phase...")
+            train.main_loop()
+
+        if FINETUNING:
+
+            # START SUPERVISED TRAINING WITH BACKPROPAGATION
+            print("\n-----------------------------------"
+                  " Supervised training "
+                  "-----------------------------------\n")
+
+            # load the DBM as an MLP
+            if DROPOUT:
+                train_yaml_path = os.path.join(train_path,
+                                               'dbm_mnist_mlp_dropout.yaml')
+            else:
+                train_yaml_path = os.path.join(train_path,
+                                               'dbm_mnist_mlp.yaml')
+            mlp_yaml = open(train_yaml_path, 'r').read()
+            hyper_params_mlp = {'train_stop': 60000,
+                                'valid_stop': 60000,
+                                'batch_size': 5000,
+                                'nvis': 784 + hyper_params_l2['nhid'],
+                                'n_h0': hyper_params_l1['nhid'],
+                                'n_h1': hyper_params_l2['nhid'],
+                                'max_epochs': MAX_EPOCHS_MLP,
+                                'save_path': train_path
+                                }
+
+            mlp_yaml = mlp_yaml % (hyper_params_mlp)
+            train = yaml_parse.load(mlp_yaml)
+
+            if SOFTMAX:
+                dbm = serial.load(os.path.join(train_path,
+                                               'dbm_mnist_softmax.pkl'))
+            else:
+                dbm = serial.load(os.path.join(train_path,
+                                               'dbm_mnist.pkl'))
+
+            train.dataset = mnist_augmented.MNIST_AUGMENTED(
+                dataset=train.dataset,
+                which_set='train',
+                one_hot=1,
+                model=dbm, start=0,
+                stop=hyper_params_mlp['train_stop'],
+                mf_steps=MF_STEPS)
+            train.algorithm.monitoring_dataset = {
+                # The 'valid' monitoring set is disabled; re-enable it by
+                # uncommenting the entry below.
+                # 'valid': mnist_augmented.MNIST_AUGMENTED(
+                #     dataset=train.algorithm.monitoring_dataset['valid'],
+                #     which_set='train', one_hot=1, model=dbm,
+                #     start=hyper_params_mlp['train_stop'],
+                #     stop=hyper_params_mlp['valid_stop'],
+                #     mf_steps=MF_STEPS),
+                'test': mnist_augmented.MNIST_AUGMENTED(
+                    dataset=train.algorithm.monitoring_dataset['test'],
+                    which_set='test', one_hot=1, model=dbm,
+                    mf_steps=MF_STEPS)}
+
+            # DBM TRAINED WEIGHTS CLAMPED FOR FINETUNING AS
+            # EXPLAINED BY HINTON
+
+            # Concatenate weights between first and second hidden
+            # layer & weights between visible and first hidden layer
+            train.model.layers[0].set_weights(
+                numpy.concatenate((
+                    dbm.hidden_layers[1].get_weights().transpose(),
+                    dbm.hidden_layers[0].get_weights())))
+
+            # then clamp all the others normally
+            for l, h in zip(train.model.layers[1:], dbm.hidden_layers[1:]):
+                l.set_weights(h.get_weights())
+
+            # clamp biases
+            for l, h in zip(train.model.layers, dbm.hidden_layers):
+                l.set_biases(h.get_biases())
+
+            print("\nDBM trained weights and biases have been "
+                  "clamped in the MLP.")
+
+            print("\n...Finetuning...\n")
+            train.main_loop()
+
+    finally:
+        os.chdir(cwd)
+
+if __name__ == '__main__':
+    test_train_example()
diff --git a/pylearn2/scripts/papers/dbm/tests/test_dbm_mnist_fast.py b/pylearn2/scripts/papers/dbm/tests/test_dbm_mnist_fast.py
new file mode 100644
index 0000000000..9a441ba75c
--- /dev/null
+++ b/pylearn2/scripts/papers/dbm/tests/test_dbm_mnist_fast.py
@@ -0,0 +1,205 @@
+"""
+This is a fast version of the test script
+for DBM training.
+""" + +__authors__ = "Carlo D'Eramo, Francesco Visin, Matteo Matteucci" +__copyright__ = "Copyright 2014-2015, Politecnico di Milano" +__credits__ = ["Carlo D'Eramo, Francesco Visin, Matteo Matteucci"] +__license__ = "3-clause BSD" +__maintainer__ = "AIR-lab" +__email__ = "carlo.deramo@mail.polimi.it, francesco.visin@polimi.it, \ +matteo.matteucci@polimi.it" + +import os +import numpy + +from pylearn2.utils import serial +from pylearn2.config import yaml_parse +from pylearn2.testing import no_debug_mode +from pylearn2.datasets import mnist_augmented +from theano import function + + +@no_debug_mode +def test_train_example(): + + """ + Fast test script. + + Parameters + ---------- + WRITEME + """ + + # path definition + cwd = os.getcwd() + train_path = cwd # train path is the current working directory + try: + os.chdir(train_path) + + # START PRETRAINING + # load and train first layer + train_yaml_path = os.path.join(train_path, '..', 'dbm_mnist_l1.yaml') + layer1_yaml = open(train_yaml_path, 'r').read() + hyper_params_l1 = {'train_stop': 20, + 'batch_size': 5, + 'monitoring_batches': 5, + 'nhid': 100, + 'max_epochs': 10, + 'save_path': train_path + } + + layer1_yaml = layer1_yaml % (hyper_params_l1) + train = yaml_parse.load(layer1_yaml) + + print("\n-----------------------------------" + " Unsupervised pre-training' " + "-----------------------------------\n") + + print("\nPre-Training first layer...\n") + train.main_loop() + + # load and train second layer + train_yaml_path = os.path.join(train_path, '..', 'dbm_mnist_l2.yaml') + layer2_yaml = open(train_yaml_path, 'r').read() + hyper_params_l2 = {'train_stop': 20, + 'batch_size': 5, + 'monitoring_batches': 5, + 'nvis': hyper_params_l1['nhid'], + 'nhid': 200, + 'max_epochs': 20, + 'save_path': train_path + } + + layer2_yaml = layer2_yaml % (hyper_params_l2) + train = yaml_parse.load(layer2_yaml) + + print("\n...Pre-training second layer...\n") + train.main_loop() + + # START TRAINING + train_yaml_path = os.path.join(train_path, '..', 'dbm_mnist.yaml') + yaml = open(train_yaml_path, 'r').read() + hyper_params_dbm = {'train_stop': 20, + 'valid_stop': 20, + 'batch_size': 5, + 'n_h1': hyper_params_l1['nhid'], + 'n_h2': hyper_params_l2['nhid'], + 'monitoring_batches': 5, + 'max_epochs': 5, + 'save_path': train_path + } + + yaml = yaml % (hyper_params_dbm) + train = yaml_parse.load(yaml) + + rbm1 = serial.load(os.path.join(train_path, 'dbm_mnist_l1.pkl')) + rbm2 = serial.load(os.path.join(train_path, 'dbm_mnist_l2.pkl')) + pretrained_rbms = [rbm1, rbm2] + + # clamp pretrained weights into respective dbm layers + for h, l in zip(train.model.hidden_layers, pretrained_rbms): + h.set_weights(l.get_weights()) + + # clamp pretrained biases into respective dbm layers + bias_param = pretrained_rbms[0].get_params()[1] + fun = function([], bias_param) + cuda_bias = fun() + bias = numpy.asarray(cuda_bias) + train.model.visible_layer.set_biases(bias) + bias_param = pretrained_rbms[-1].get_params()[1] + fun = function([], bias_param) + cuda_bias = fun() + bias = numpy.asarray(cuda_bias) + train.model.hidden_layers[0].set_biases(bias) + bias_param = pretrained_rbms[-1].get_params()[2] + fun = function([], bias_param) + cuda_bias = fun() + bias = numpy.asarray(cuda_bias) + train.model.hidden_layers[1].set_biases(bias) + + print("\nAll layers weights and biases have been clamped " + "to the respective layers of the DBM") + + print("\n-----------------------------------" + " Unsupervised training' " + "-----------------------------------\n") + + 
print("\nTraining phase...") + train.main_loop() + + # START SUPERVISED TRAINING WITH BACKPROPAGATION + print("\n-----------------------------------" + " Supervised training' " + "-----------------------------------\n") + + # load dbm as a mlp + train_yaml_path = os.path.join(train_path, '..', 'dbm_mnist_mlp.yaml') + + mlp_yaml = open(train_yaml_path, 'r').read() + hyper_params_mlp = {'train_stop': 20, + 'valid_stop': 20, + 'batch_size': 5, + 'nvis': 784 + hyper_params_l2['nhid'], + 'n_h0': hyper_params_l1['nhid'], + 'n_h1': hyper_params_l2['nhid'], + 'max_epochs': 20, + 'save_path': train_path + } + + mlp_yaml = mlp_yaml % (hyper_params_mlp) + train = yaml_parse.load(mlp_yaml) + + dbm = serial.load(os.path.join(train_path, 'dbm_mnist.pkl')) + + train.dataset = mnist_augmented.MNIST_AUGMENTED( + dataset=train.dataset, + which_set='train', + one_hot=1, + model=dbm, + start=0, + stop=hyper_params_mlp['train_stop'], + mf_steps=1) + train.algorithm.monitoring_dataset = { + ''' + 'valid': mnist_augmented.MNIST_AUGMENTED( + dataset=train.algorithm.monitoring_dataset['valid'], + which_set='train', one_hot=1, model=dbm, + start=hyper_params_mlp['train_stop'], + stop=hyper_params_mlp['valid_stop'], mf_steps=1), + ''' + 'test': mnist_augmented.MNIST_AUGMENTED( + dataset=train.algorithm.monitoring_dataset['test'], + which_set='test', one_hot=1, model=dbm, start=0, + stop=20, mf_steps=1) + } + + # DBM TRAINED WEIGHTS CLAMPED FOR FINETUNING AS + # EXPLAINED BY HINTON + + # Concatenate weights between first and second hidden layer & + # weights between visible and first hidden layer + train.model.layers[0].set_weights( + numpy.concatenate(( + dbm.hidden_layers[1].get_weights().transpose(), + dbm.hidden_layers[0].get_weights()))) + + # then clamp all the others normally + for l, h in zip(train.model.layers[1:], dbm.hidden_layers[1:]): + l.set_weights(h.get_weights()) + + # clamp biases + for l, h in zip(train.model.layers, dbm.hidden_layers): + l.set_biases(h.get_biases()) + + print("\nDBM trained weights and biases have been clamped in the MLP.") + + print("\n...Finetuning...\n") + train.main_loop() + + finally: + os.chdir(cwd) + +if __name__ == '__main__': + test_train_example() diff --git a/pylearn2/training_algorithms/sgd.py b/pylearn2/training_algorithms/sgd.py index ce24008e80..a42d3738b1 100644 --- a/pylearn2/training_algorithms/sgd.py +++ b/pylearn2/training_algorithms/sgd.py @@ -1143,3 +1143,26 @@ def on_monitor(self, model, dataset, algorithm): for param in model.get_params(): param.set_value(saved_params[param]) self._count += 1 + +class CustomizedLROverEpoch(object): + """ + Assigns a different value to learning rate at each + training epoch. This is used when the value to be assigned + is not reproducible with linear or exponential decay. + """ + def __init__(self): + self.__dict__.update(locals()) + del self.self + self._count = 0 + self.new_lr = 0 + + def __call__(self, algorithm): + if self._count == 0: + self._new_lr = algorithm.learning_rate.get_value() + else: + self.new_lr = 10 / (2000 + self._count) # new learning rate will have this value + + self._count += 1 + new_lr = np.cast[config.floatX](self.new_lr) + algorithm.learning_rate.set_value(new_lr) +