-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathdemo_ext.m
89 lines (66 loc) · 2.75 KB
/
demo_ext.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
function demo_ext()
% DEMO_EXT  Extended demonstration file for SGDLibrary.
%
% Illustrates how to use this library on a binary logistic regression
% problem built from synthetic data. SGD and SVRG are run from the same
% initial point and step size. The demo then:
%   - displays cost and optimality gap versus the number of gradient
%     evaluations,
%   - prints and displays the classification accuracy of each solution,
%   - draws a convergence animation from the stored iterates.
%
% Takes no inputs and returns nothing; all output goes to the command
% window and figures.
%
% This file is part of SGDLibrary.
%
% Created by H.Kasai on Oct. 24, 2016
% Modified by H.Kasai on Nov. 03, 2016

    clc;
    close all;
    % NOTE: the former `clear;` was removed: inside a function it only
    % clears this function's (still empty) workspace, so it had no effect.

    %% generate synthetic data
    d = 2;      % number of dimensions
    n = 300;    % number of samples
    data = logistic_regression_data_generator(n, d);

    %% define problem definitions
    problem = logistic_regression(data.x_train, data.y_train, data.x_test, data.y_test);

    %% calculate optimal solution for optimality gap
    % Reference solution; its cost is the baseline for the optimality gap.
    w_opt = problem.calc_solution(1000);
    options.f_opt = problem.cost(w_opt);

    %% set options for convergence animation
    options.max_epoch = 100;
    % Keep the iterate history in info.w; the animation at the end reads it.
    options.store_w = true;

    %% perform algorithms SGD and SVRG (same start point and step size)
    options.w_init = data.w_init;
    options.step_init = 0.01;
    [w_sgd, info_sgd] = sgd(problem, options);
    [w_svrg, info_svrg] = svrg(problem, options);

    algo_names = {'SGD', 'SVRG'};

    %% display cost/optimality gap vs number of gradient evaluations
    display_graph('grad_calc_count', 'cost', algo_names, {w_sgd, w_svrg}, {info_sgd, info_svrg});
    display_graph('grad_calc_count', 'optimality_gap', algo_names, {w_sgd, w_svrg}, {info_sgd, info_svrg});

    %% calculate classification accuracy
    % Accuracy is computed on {1,-1} predictions before they are converted
    % to the {1,2} form used for display.
    [y_pred_sgd, accuracy_sgd] = predict_and_report(problem, w_sgd, 'SGD');
    [y_pred_svrg, accuracy_svrg] = predict_and_report(problem, w_svrg, 'SVRG');

    %% display classification results
    % The display routine is fed labels in {1,2} form.
    data.y_train = to_class_index(data.y_train);
    data.y_test = to_class_index(data.y_test);
    display_classification_result(problem, algo_names, {w_sgd, w_svrg}, ...
        {y_pred_sgd, y_pred_svrg}, {accuracy_sgd, accuracy_svrg}, ...
        data.x_train, data.y_train, data.x_test, data.y_test);

    %% display convergence animation
    draw_convergence_animation(problem, algo_names, {info_sgd.w, info_svrg.w}, options.max_epoch, 0.1);
end

function [y_pred, accuracy] = predict_and_report(problem, w, algo_name)
% Predict labels with solution w and print the resulting accuracy.
%
% Returns the predictions converted from {1,-1} to {1,2}, together with
% the accuracy computed on the unconverted {1,-1} predictions.
    y_pred = problem.prediction(w);
    accuracy = problem.accuracy(y_pred);
    fprintf('Classification accuracy: %s: %.4f\n', algo_name, accuracy);
    y_pred = to_class_index(y_pred);
end

function y = to_class_index(y)
% Convert labels from {1,-1} to {1,2}: -1 becomes 2 and +1 stays 1.
% (Mapping 1 -> 1 explicitly would be a no-op, so it is omitted.)
    y(y == -1) = 2;
end