Skip to content

Commit

Permalink
Merge pull request #92 from memimo/predict
Browse files Browse the repository at this point in the history
Predict function for logistic regression.
  • Loading branch information
nouiz committed Jul 2, 2015
2 parents 15c5442 + efa74a1 commit 55db5cd
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 0 deletions.
3 changes: 3 additions & 0 deletions code/convolutional_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2, 2)):
# store parameters of this layer
self.params = [self.W, self.b]

# keep track of model input
self.input = input


def evaluate_lenet5(learning_rate=0.1, n_epochs=200,
dataset='mnist.pkl.gz',
Expand Down
3 changes: 3 additions & 0 deletions code/logistic_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __init__(self, input, n_in, n_out):
# symbolic form
self.y_pred = T.argmax(self.p_y_given_x, axis=1)

# keep track of model input
self.input = input

def negative_log_likelihood(self, y):
"""Return the negative log-likelihood of the prediction of this model
under a given target distribution.
Expand Down
33 changes: 33 additions & 0 deletions code/logistic_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def __init__(self, input, n_in, n_out):
# parameters of the model
self.params = [self.W, self.b]

# keep track of model input
self.input = input

def negative_log_likelihood(self, y):
"""Return the mean of the negative log-likelihood of the prediction
of this model under a given target distribution.
Expand Down Expand Up @@ -415,6 +418,10 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
)
)

# save the best model
with open('best_model.pkl', 'w') as f:
cPickle.dump(classifier, f)

if patience <= iter:
done_looping = True
break
Expand All @@ -433,5 +440,31 @@ def sgd_optimization_mnist(learning_rate=0.13, n_epochs=1000,
os.path.split(__file__)[1] +
' ran for %.1fs' % ((end_time - start_time)))


def predict():
"""
An example of how to load a trained model and use it
to predict labels.
"""

# load the saved model
classifier = cPickle.load(open('best_model.pkl'))

# compile a predictor function
predict_model = theano.function(
inputs=[classifier.input],
outputs=classifier.y_pred)

# We can test it on some examples from test test
dataset='mnist.pkl.gz'
datasets = load_data(dataset)
test_set_x, test_set_y = datasets[2]
test_set_x = test_set_x.get_value()

predicted_values = predict_model(test_set_x[:10])
print ("Predicted values for the first 10 examples in test set:")
print predicted_values


if __name__ == '__main__':
sgd_optimization_mnist()
3 changes: 3 additions & 0 deletions code/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def __init__(self, rng, input, n_in, n_hidden, n_out):
self.params = self.hiddenLayer.params + self.logRegressionLayer.params
# end-snippet-3

# keep track of model input
self.input = input


def test_mlp(learning_rate=0.01, L1_reg=0.00, L2_reg=0.0001, n_epochs=1000,
dataset='mnist.pkl.gz', batch_size=20, n_hidden=500):
Expand Down
13 changes: 13 additions & 0 deletions doc/logreg.txt
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ approximately 1.936 epochs/sec and it took 75 epochs to reach a test
error of 7.489%. On the GPU the code does almost 10.0 epochs/sec. For this
instance we used a batch size of 600.


Prediction Using a Trained Model
+++++++++++++++++++++++++++++++

``sgd_optimization_mnist`` serialize and pickle the model each time new
lowest validation error is reached. We can reload this model and predict
labels of new data. ``predict`` function shows an example of how
this could be done.

.. literalinclude:: ../code/logistic_sgd.py
:pyobject: predict


.. rubric:: Footnotes

.. [#f1] For smaller datasets and simpler models, more sophisticated descent
Expand Down

0 comments on commit 55db5cd

Please sign in to comment.