-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscikit_train_supervised.m
68 lines (58 loc) · 2.53 KB
/
scikit_train_supervised.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
% This file is part of scikit-from-matlab.
%
% scikit-from-matlab is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% scikit-from-matlab is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with scikit-from-matlab. If not, see <https://www.gnu.org/licenses/>.
% Author: Abhishek Jaiantilal ([email protected])
% scikit-from-matlab 0.0.1
function model = scikit_train_supervised(Xtrn, Ytrn, algo_name, algo_params, CV_strategy, CV_params_for_algo, CV_params)
%SCIKIT_TRAIN_SUPERVISED Train a model through the Python module scikit_train_predict_supervised.
%   MODEL = SCIKIT_TRAIN_SUPERVISED(XTRN, YTRN, ALGO_NAME) sends the
%   training data to Python and calls scikit_train_predict_supervised.train
%   with an empty py.dict of algorithm parameters.
%
%   MODEL = SCIKIT_TRAIN_SUPERVISED(XTRN, YTRN, ALGO_NAME, ALGO_PARAMS)
%   forwards ALGO_PARAMS to train().
%
%   MODEL = SCIKIT_TRAIN_SUPERVISED(XTRN, YTRN, ALGO_NAME, ALGO_PARAMS,
%   CV_STRATEGY, CV_PARAMS_FOR_ALGO, CV_PARAMS) calls
%   scikit_train_predict_supervised.trainCV instead, which presumably runs
%   cross-validation to choose parameters (the exact behaviour is defined
%   on the Python side -- confirm there).
%
%   Inputs
%     Xtrn               - N-by-D training matrix (N samples, D features).
%     Ytrn               - training targets: a vector, or a matrix whose
%                          rows presumably align with the rows of Xtrn.
%     algo_name          - algorithm name understood by the Python module.
%     algo_params        - optional; defaults to py.dict.
%     CV_strategy        - optional; cross-validation is run only when this
%                          is supplied and non-empty.
%     CV_params_for_algo - optional; defaults to py.dict.
%     CV_params          - optional; defaults to py.dict.
%   Any optional argument passed as [] is treated as if it were omitted.
%
%   Output
%     model - the Python object returned by train() or trainCV().
%
%   The Python module is reloaded on every call.

if nargin < 3
    error('At least supply an algo name')
end

% Omitted or empty ([]) optional arguments fall back to their defaults.
if nargin < 4 || isempty(algo_params)
    algo_params = py.dict;
end
% Cross-validation is decided by CV_strategy alone; the other two CV
% arguments just default to empty dicts when absent.
DoCV = nargin >= 5 && ~isempty(CV_strategy);
if nargin < 6 || isempty(CV_params_for_algo)
    CV_params_for_algo = py.dict;
end
if nargin < 7 || isempty(CV_params)
    CV_params = py.dict;
end

% Re-import and reload the helper module on each call, presumably so that
% edits to the Python file take effect without restarting MATLAB.
% importlib.reload is the Python 3 API; Python 2 has a builtin reload.
pymod = py.importlib.import_module('scikit_train_predict_supervised');
if str2double(pyversion) >= 3
    py.importlib.reload(pymod);
else
    py.reload(pymod);
end

% X and Y are packed identically so the Python side can rebuild both with
% the same reshape.
pyXtrn = local_to_py_matrix(Xtrn);
pyYtrn = local_to_py_matrix(Ytrn);

if DoCV
    model = py.scikit_train_predict_supervised.trainCV(pyXtrn, pyYtrn, algo_name, algo_params, CV_strategy, CV_params_for_algo, CV_params);
else
    model = py.scikit_train_predict_supervised.train(pyXtrn, pyYtrn, algo_name, algo_params);
end
end

function pyA = local_to_py_matrix(A)
%LOCAL_TO_PY_MATRIX Pack a 2-D array as py.list({row-major data, nrows, ncols}).
%   MATLAB stores arrays column-major, so listing the elements of A.' with
%   (:) yields the elements of A in row-major (C) order. Sending these
%   together with A's original dimensions lets the Python side rebuild A
%   with a plain row-major reshape to (nrows, ncols). NOTE(review): this
%   assumes scikit_train_predict_supervised reshapes with numpy's default
%   C order -- confirm there. For vectors the layout is the same either way.
At = A.';
pyA = py.list({At(:).', int32(size(A, 1)), int32(size(A, 2))});
end