-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
234 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
Score Matching for ICA / TICA | ||
|
||
* Trains using minFunc | ||
* Dataset available at http://cs.stanford.edu/~jngiam/data/patches.mat | ||
|
||
Quick Start | ||
=========== | ||
|
||
Run "runExample" in Matlab | ||
|
||
|
||
Notes | ||
===== | ||
|
||
* Uses a smaller dataset (e.g., data = data(:, 1:20000)) for faster training | ||
|
Submodule common
updated
49 files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
%% Gradient checks for icaScoreMatching
% Compares analytically computed derivatives against numerical estimates
% from computeNumericalGradient for three quantities:
%   1. d(energy)/dX        (4th output of icaScoreMatching, "Xg")
%   2. d(Xg)/dX, diagonal  (5th output of icaScoreMatching, "Xg2")
%   3. d(smloss)/dW        (2nd output of icaScoreMatching, "grad")
% Each section prints the two gradients side by side, followed by their
% relative difference.

%% Initialize Checks
nHidden   = 16;
nInput    = 4;
nExamples = 2;

% Sizes are derived from nHidden/nInput so the check stays consistent if
% either dimension is changed.
W = randn(nHidden, nInput);
X = randn(nInput, nExamples);
modelScoreMatchingFunc = @(W, X) icaScoreMatching(W, X, nHidden, nInput);

%% Check Xg
[loss, grad] = smCheckXgLoss(X(:), W, X, modelScoreMatchingFunc);
numgrad = computeNumericalGradient( ...
    @(p) smCheckXgLoss(p, W, X, modelScoreMatchingFunc), X(:));

numgrad = numgrad(:);
grad    = grad(:);

% Use this to visually compare the gradients side by side
disp([numgrad grad]);

% Compare numerically computed gradients with the ones obtained from
% backpropagation. The result is stored in relErr (not "diff") so the
% built-in diff function is not shadowed in the workspace.
relErr = norm(numgrad - grad) / norm(numgrad + grad);
disp(relErr); % Should be small. In our implementation, these values are
              % usually less than 1e-9.

%% Check Xg2
% Output order of the model function:
%   [smloss, grad, energy, Xg, Xg2] = icaScoreMatching(W, X, nHidden, nInput);
% Only the diagonal second derivative is available, so each element of X
% is perturbed and checked on its own.
numgrad = zeros(numel(X), 1);
grad    = zeros(numel(X), 1);
for Xidx = 1:numel(X)
    params = X(Xidx);
    [loss, grad(Xidx)] = smCheckXg2Loss(params, W, X, modelScoreMatchingFunc, Xidx);
    numgrad(Xidx) = computeNumericalGradient( ...
        @(p) smCheckXg2Loss(p, W, X, modelScoreMatchingFunc, Xidx), params);
end

numgrad = numgrad(:);
grad    = grad(:);

disp([numgrad grad]);

relErr = norm(numgrad - grad) / norm(numgrad + grad);
disp(relErr);

%% Check Param Gradient

[loss, grad] = icaScoreMatching(W, X, nHidden, nInput);
numgrad = computeNumericalGradient( ...
    @(p) icaScoreMatching(p, X, nHidden, nInput), W(:));

numgrad = numgrad(:);
grad    = grad(:);

disp([numgrad grad]);

% Compare numerically computed gradients with the ones obtained from backpropagation
relErr = norm(numgrad - grad) / norm(numgrad + grad);
disp(relErr); % Should be small. In our implementation, these values are
              % usually less than 1e-9.
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
function [loss, paramgrad] = smCheckXg2Loss(params, W, X, modelScoreFunc, Xidx)
%SMCHECKXG2LOSS Wrapper used to gradient-check one element of the model's Xg2 output.
%
%   [loss, paramgrad] = smCheckXg2Loss(params, W, X, modelScoreFunc, Xidx)
%
%   Replaces element Xidx (linear index) of X with params, evaluates
%   modelScoreFunc(W, X) requesting five outputs, and returns
%     loss      - element Xidx of the fourth output (Xg)
%     paramgrad - element Xidx of the fifth output (Xg2)
%   so that a numerical derivative of loss with respect to params can be
%   compared against paramgrad.

    perturbedX = X;
    perturbedX(Xidx) = params;

    % The first three outputs are requested only to reach the fourth and
    % fifth; their values are not used here.
    [unusedLoss, unusedGrad, unusedEnergy, firstDeriv, secondDeriv] = ...
        modelScoreFunc(W, perturbedX);

    loss      = firstDeriv(Xidx);
    paramgrad = secondDeriv(Xidx);

end
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
function [loss, paramgrad] = smCheckXgLoss(params, W, X, modelScoreFunc)
%SMCHECKXGLOSS Wrapper used to gradient-check the model's Xg output.
%
%   [loss, paramgrad] = smCheckXgLoss(params, W, X, modelScoreFunc)
%
%   params is reshaped to size(X) and passed to modelScoreFunc(W, .) in
%   place of X (X itself is used only for its size). Returns
%     loss      - sum of the third output (energy)
%     paramgrad - the fourth output (Xg), flattened to a column vector

    shapedInput = reshape(params, size(X));

    % Five outputs are requested, matching the model function's full
    % signature; only the third and fourth are used.
    [unusedLoss, unusedGrad, exampleEnergy, inputGrad, unusedXg2] = ...
        modelScoreFunc(W, shapedInput);

    loss      = sum(exampleEnergy);
    paramgrad = inputGrad(:);

end
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
function [smloss, grad, energy, Xg, Xg2] = ...
    icaScoreMatching(W, X, nHidden, nInput)
%ICASCOREMATCHING Score matching objective and gradient for ICA.
%
%   [smloss, grad, energy, Xg, Xg2] = icaScoreMatching(W, X, nHidden, nInput)
%
%   Inputs:
%     W       - weights with nHidden*nInput elements (any shape; reshaped
%               to nHidden x nInput internally)
%     X       - data matrix; W*X is formed, so X has nInput rows. The loss
%               and gradient are divided by size(X, 2).
%     nHidden - number of rows of the weight matrix
%     nInput  - number of columns of the weight matrix
%
%   Outputs:
%     smloss  - scalar score matching loss, divided by size(X, 2)
%     grad    - gradient of smloss with respect to the original weights,
%               returned as a column vector (only computed when requested)
%     energy  - sum(sqrt(epsilon + (W*X).^2)); only computed when at least
%               three outputs are requested, otherwise 0
%     Xg      - derivative of the energy with respect to X
%     Xg2     - diagonal second derivative of the energy with respect to X
%
%   The weights are passed through l2row before use and the gradient is
%   passed through l2rowg before returning. Both helpers are defined
%   elsewhere; presumably l2row normalizes each row to unit L2 norm and
%   l2rowg back-propagates through that normalization (TODO confirm).

    % Save original weights; l2rowg needs them at the end
    oW = reshape(W, nHidden, nInput);

    % Force onto unit ball
    W = l2row(oW);

    lambda  = 0; % weight decay
    epsilon = 1;
    smReg   = 0; % Regularized score matching?

    nExamples = size(X, 2);

    % The product W*X is formed once and reused for absF
    F    = W * X;
    absF = sqrt(epsilon + F.^2);

    %% Compute Energy (only if interested, for checking gradients)
    energy = 0;
    if nargout > 2
        energy = energy + sum(absF);
    end

    %% Compute del E/del X
    oX = F ./ absF;
    Xg = W' * oX;

    %% Score Matching Loss
    smloss = 0.5 * sum(sum(Xg.^2));

    %% Compute del^2 E/del X_i ^2
    oX2 = (1 - oX.^2) ./ absF;
    Xg2 = (W.^2)' * oX2;

    smloss = smloss + sum(sum(-Xg2, 1)) + smReg * sum(sum(Xg2.^2));

    %% Weight Regularization, note that we sum over examples so * size(X,2)
    smloss = smloss + 0.5 * lambda * nExamples * sum(W(:).^2);

    %% Scale smloss to be nicer
    smloss = smloss / nExamples;

    %% Done if we are not interested in gradients
    if nargout <= 1
        return
    end

    % From here on nargout > 1 always holds, so the gradient terms are
    % accumulated unconditionally.

    %% Compute Gradient for 0.5*sum(sum(Xg.^2))
    Wgrad = oX * Xg';

    Xg2a = W * Xg;
    Xg2a = Xg2a .* oX2;

    Wgrad = Wgrad + Xg2a * X';

    %% Compute Gradient for sum(sum(-Xg2, 1)) + smReg * sum(sum(Xg2.^2))
    % Derivative of that term with respect to Xg2
    dXg2 = -1 + 2 * smReg * Xg2;

    % Path through the explicit W.^2 factor in Xg2
    vhg2  = oX2 * dXg2';
    Wgrad = Wgrad + 2 * W .* vhg2;

    % Path through oX2, which depends on W via F = W*X (back prop)
    Xg3a = (W.^2) * dXg2;
    Xg3a = Xg3a .* (-3 * oX .* (1 - oX.^2) ./ (absF.^2));

    Wgrad = Wgrad + Xg3a * X';

    %% Weight Regularization
    Wgrad = Wgrad + lambda * nExamples * W;

    %% Compress for returning
    Wgrad = l2rowg(oW, W, Wgrad);

    grad = Wgrad(:) / nExamples;

end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
%% ICA with score matching: end-to-end example
% Loads image patches, PCA-whitens them, trains an ICA model by minimizing
% icaScoreMatching with minFunc (L-BFGS), and displays the learned filters.

%% Startup (addpaths)
startup

%% Clear
clear ; close all ; clc ;

%% Load Data
% You can obtain patches.mat from
% http://cs.stanford.edu/~jngiam/data/patches.mat

fprintf('Loading Data\n');

% Loads a variable data (size 256x50000)
load patches.mat

% Reduce dataset size for faster training
data = data(:, 1:20000);

%% PCA Whitening
fprintf('\nPCA Whitening\n');

% Remove DC
data = bsxfun(@minus, data, mean(data, 1));

% Remove the "mean" patch
data = bsxfun(@minus, data, mean(data, 2));

% Compute Covariance Matrix and Eigenstuff
% (stored as covData so the built-in cov function is not shadowed in the
% workspace after this script has run)
covData = data * data' / size(data, 2);
[E, D] = eig(covData);
d = diag(D);

% Sort eigenvalues in descending order
[dsort, idx] = sort(d, 'descend');

% PCA Whitening (and pick top 99% of eigenvalues)
dsum = cumsum(dsort);
dcutoff = find(dsum > 0.99 * dsum(end), 1);
E = E(:, idx(1:dcutoff));
d = d(idx(1:dcutoff));
V = diag(1 ./ sqrt(d + 1e-6)) * E'; % 1e-6 keeps the scaling finite for tiny eigenvalues

%% Whiten the data
whiteData = V * data;

%% Run the optimization with minFunc (ICA)
fprintf('\nTraining ICA (w/ Score Matching)\n\n');
nHidden = 400; nInput = size(whiteData, 1);
W = randn(nHidden, nInput);
options.Method = 'lbfgs';
options.maxIter = 100; % Maximum number of iterations of L-BFGS to run
options.display = 'on';

tic
[optW, cost] = minFunc( @icaScoreMatching, ...
                        W(:), options, whiteData, ...
                        nHidden, nInput);
toc

%% Display Results
% optW acts on whitened data; multiplying by V maps the filters back to
% the original input space for display.
optW = reshape(optW, nHidden, nInput);
displayData(optW * V);

fprintf('ICA Training Completed.\n');
fprintf('Press Enter to Continue.\n\n');
pause
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
%% Add Paths
% Puts the project folders (relative to the current directory) on the
% MATLAB search path. Folders are listed top-of-path first, which gives
% the same final precedence as adding them one at a time in the order
% common, common/minFunc, gradientChecks, objFunc.
addpath([pwd filesep 'objFunc'], ...
        [pwd filesep 'gradientChecks'], ...
        [pwd filesep 'common/minFunc'], ...
        [pwd filesep 'common']);