self_taught_learning.m
function acc = self_taught_learning(params)
inputSize = 28 * 28;
numLabels = 5;
hiddenSize = params.hiddenSize;
sparsityParam = params.sparsityParam; % desired average activation of the hidden units
                                      % (denoted by the Greek letter rho, which looks
                                      % like a lower-case "p", in the lecture notes)
lambda = params.lambda; % weight decay parameter
beta = params.beta; % weight of sparsity penalty term
maxIter = 100;
%% ======================================================================
% STEP 1: Load data from the MNIST database
%
% This loads our training and test data from the MNIST database files.
% The data has already been prepared here, so you will not need to
% change this code.
% Load MNIST database files
mnistData = loadMNISTImages('train-images.idx3-ubyte');
mnistLabels = loadMNISTLabels('train-labels.idx1-ubyte');
% Simulate a labeled and unlabeled set: digits 0-4 form the labeled set
% and digits 5-9 form the unlabeled set
labeledSet = find(mnistLabels >= 0 & mnistLabels <= 4);
unlabeledSet = find(mnistLabels >= 5);
numTrain = round(numel(labeledSet)/2);
trainSet = labeledSet(1:numTrain);
testSet = labeledSet(numTrain+1:end);
unlabeledData = mnistData(:, unlabeledSet);
trainData = mnistData(:, trainSet);
trainLabels = mnistLabels(trainSet)' + 1; % Shift Labels to the Range 1-5
testData = mnistData(:, testSet);
testLabels = mnistLabels(testSet)' + 1; % Shift Labels to the Range 1-5
% Output Some Statistics
fprintf('# examples in unlabeled set: %d\n', size(unlabeledData, 2));
fprintf('# examples in supervised training set: %d\n', size(trainData, 2));
fprintf('# examples in supervised testing set: %d\n\n', size(testData, 2));
%% ======================================================================
% STEP 2: Train the sparse autoencoder
% This trains the sparse autoencoder on the unlabeled training
% images.
% Randomly initialize the parameters
theta = initializeParameters(hiddenSize, inputSize);
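% Note: theta is a single column vector; in the UFLDL convention it is
% assumed to pack the parameters as [W1(:); W2(:); b1; b2] (the W1
% reshape further below relies on this layout).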
addpath minFunc/
options.Method = 'lbfgs'; % Here, we use L-BFGS to optimize our cost
                          % function. Generally, for minFunc to work, you
                          % need a function handle with two outputs: the
                          % function value and the gradient. In our problem,
                          % sparseAutoencoderCost.m satisfies this.
options.maxIter = maxIter; % Maximum number of iterations of L-BFGS to run
options.display = 'on';
[opttheta, cost] = minFunc( @(p) sparseAutoencoderCost(p, ...
inputSize, hiddenSize, ...
lambda, sparsityParam, ...
beta, unlabeledData), ...
theta, options);
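% Per minFunc's contract, the handle above returns both the objective
% value and its gradient; the cost function is called as:
%   [cost, grad] = sparseAutoencoderCost(theta, inputSize, hiddenSize, ...
%                                        lambda, sparsityParam, beta, unlabeledData);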
%% -----------------------------------------------------
% Visualize weights
W1 = reshape(opttheta(1:hiddenSize * inputSize), hiddenSize, inputSize);
% display_network(W1');
%%======================================================================
%% STEP 3: Extract Features from the Supervised Dataset
trainFeatures = feedForwardAutoencoder(opttheta, hiddenSize, inputSize, ...
trainData);
testFeatures = feedForwardAutoencoder(opttheta, hiddenSize, inputSize, ...
testData);
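% A minimal sketch of what feedForwardAutoencoder computes, assuming the
% standard UFLDL implementation (the parameter layout mirrors the W1
% reshape above; treat this as an illustration, not the exact code):
%   W1 = reshape(opttheta(1:hiddenSize*inputSize), hiddenSize, inputSize);
%   b1 = opttheta(2*hiddenSize*inputSize+1 : 2*hiddenSize*inputSize+hiddenSize);
%   features = 1 ./ (1 + exp(-(W1*data + repmat(b1, 1, size(data, 2)))));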
%%======================================================================
%% STEP 4: Train the softmax classifier
% Train the softmax classifier on the extracted features, using
% lambda = 1e-4 for the softmax weight regularization
numClasses = 5; % digits 0-4, with labels shifted to the range 1-5
options.maxIter = maxIter;
softmaxModel = softmaxTrain(hiddenSize, numClasses, 1e-4, ...
trainFeatures, trainLabels, options);
%%======================================================================
%% STEP 5: Testing
pred = softmaxPredict(softmaxModel, testFeatures);
%% -----------------------------------------------------
% Classification accuracy on the test set, in percent
acc = 100 * mean(pred(:) == testLabels(:));
end
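
% Example usage (a minimal sketch; the hyperparameter values below are
% illustrative starting points, not tuned results):
%   params.hiddenSize    = 200;
%   params.sparsityParam = 0.1;
%   params.lambda        = 3e-3;
%   params.beta          = 3;
%   acc = self_taught_learning(params);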