-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsvrFreeParam.m
150 lines (137 loc) · 5.1 KB
/
svrFreeParam.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
% ---------------------------------------------------------------------------------------------------
%
% svrFreeParam.m
%
% This function performs the search of the free parameters of the $\varepsilon$-Huber SVR model.
% The function takes the training and test sets and applies a non-exhaustive
% iterative search strategy. Basically, at each of T iterations,
% a sequential search on every parameter domain is performed by splitting the range of
% each parameter into K linearly or logarithmically equally spaced points. Values of T=3
% and K=20 exhibited good performance in our simulations.
%
% INPUTS:
% Xtr ............. Training set (features in columns, samples in rows).
% Ytr ............. Actual value to be estimated/predicted in the training set.
% x_tst ............. Test set (features in columns, samples in rows).
% y_tst ............. Actual value to be estimated/predicted in the test set.
% yker ............. Label for the kernel type (at this moment only RBF kernel, yker = 'rbf').
% dibuja ............. Label for plotting. If =1 plots the results in each iteration
% see \cite{tr:Camps03a})
%
% OUTPUTS:
% nsv ............. Number of support vectors.
% Xsv ............. Support vector matrix.
% svs ............. (Alpha-Alpha^*).
% bias ............. Bias of the best model (b).
% sigmay ............. Best RBF kernel width.
% C ............. Best penalization factor.
% epsil ............. Best epsilon-insensitivity zone.
% gam ............. Best gamma parameter in the $\varepsilon$-Huber cost function.
%
% José L. Rojo-Álvarez & Gustavo Camps-Valls
% jlrojo@tsc.uc3m.es, gcamps@uv.es
% 2004(c)
%
% ---------------------------------------------------------------------------------------------------
function [nsv,Xsv,svs,bias,sigmay,C,epsil,gam] = svrFreeParam(Xtr,Ytr,x_tst,y_tst,yker,dibuja)
% Sequential (non-exhaustive) search of the free parameters of the
% epsilon-Huber SVR model. On each of K sweeps, each parameter in turn is
% scanned over a grid of Npoints candidates (log-spaced for gam, C and
% sigmay; linear for epsilon), and a candidate is kept only when it lowers
% the test-set MSE below the best error seen so far.
%
% INPUTS:
%   Xtr, Ytr ....... training features (rows = samples) and targets
%   x_tst, y_tst ... test features and targets (used as validation set)
%   yker ........... kernel-type label (only 'rbf' is supported here)
%   dibuja ......... if nonzero, plot the search curves at each step
% OUTPUTS:
%   nsv, Xsv, svs, bias ... final model: #SVs, SV matrix, (alpha-alpha*), bias
%   sigmay, C, epsil, gam . best parameter values found

% Initial parameter values.
% NOTE: semicolons added — the original lines had none and echoed every
% value to the console on each call.
epsil  = 1e-7;   % epsilon-insensitivity zone
sigmay = 1.5;    % RBF kernel width
C      = 1e5;    % penalization factor
gam    = 1e-1;   % gamma of the epsilon-Huber cost function

% Search configuration
Npoints  = 4;    % number of grid points per parameter sweep
oldError = 1e10; % a parameter is updated only if the test MSE decreases

% Search ranges (log10 exponents for the logspace grids; linear for epsilon)
lowC      = 3;    upC      = 7;
lowgam    = -10;  upgam    = -5;
loweps    = 0;    upeps    = 1e-4;
lowsigmay = 1;    upsigmay = 3;

% Main loop: K sweeps of the sequential search
error_parcial = [];
warning off
K = 3; % number of iterations
for vuelta = 1:K
    disp(['Iteration ' num2str(vuelta) ' of ' num2str(K)])

    % ---- Search in gam (log-spaced grid) ----
    err  = zeros(Npoints,1);
    ggam = logspace(lowgam,upgam,Npoints);
    for i = 1:Npoints
        [nsv,svs,bias] = mysvr(Xtr,Ytr,epsil,C,ggam(i),sigmay,yker);
        ypredtest      = svroutput(Xtr,x_tst,yker,svs,bias,sigmay);
        err(i)         = mean((ypredtest-y_tst).^2);
    end
    if dibuja
        subplot(321); loglog(ggam,err); axis tight;
        xlabel('\gamma'), ylabel('MSE'), grid, drawnow
    end
    [kk,m] = min(err);
    if kk < oldError
        gam = ggam(m); oldError = kk;
    end

    % ---- Search in C (log-spaced grid) ----
    err = zeros(Npoints,1);
    CC  = logspace(lowC,upC,Npoints);
    for i = 1:Npoints
        [nsv,svs,bias] = mysvr(Xtr,Ytr,epsil,CC(i),gam,sigmay,yker);
        ypredtest      = svroutput(Xtr,x_tst,yker,svs,bias,sigmay);
        err(i)         = mean((ypredtest-y_tst).^2);
    end
    if dibuja
        subplot(322); loglog(CC,err); axis tight;
        xlabel('C'), ylabel('MSE'), grid, drawnow
    end
    [kk,m] = min(err);
    if kk < oldError
        C = CC(m); oldError = kk;
    end

    % ---- Search in epsilon (linearly spaced grid) ----
    err    = zeros(Npoints,1);
    eepsil = linspace(loweps,upeps,Npoints);
    for i = 1:Npoints
        [nsv,svs,bias] = mysvr(Xtr,Ytr,eepsil(i),C,gam,sigmay,yker);
        ypredtest      = svroutput(Xtr,x_tst,yker,svs,bias,sigmay);
        err(i)         = mean((ypredtest-y_tst).^2);
    end
    if dibuja
        subplot(323);
        plot(eepsil,err); axis tight; xlabel('\epsilon'), ylabel('MSE'), grid, drawnow
    end
    [kk,m] = min(err);
    if kk < oldError
        epsil = eepsil(m); oldError = kk;
    end

    % ---- Search in sigmay (log-spaced between the raw bounds) ----
    err     = zeros(Npoints,1);
    ssigmay = logspace(log10(lowsigmay),log10(upsigmay),Npoints);
    for i = 1:Npoints
        [nsv,svs,bias] = mysvr(Xtr,Ytr,epsil,C,gam,ssigmay(i),yker);
        ypredtest      = svroutput(Xtr,x_tst,yker,svs,bias,ssigmay(i));
        err(i)         = mean((ypredtest-y_tst).^2);
    end
    if dibuja
        subplot(324); loglog(ssigmay,err); axis tight;
        xlabel('\sigma_y'), ylabel('MSE'), grid, drawnow
    end
    [kk,m] = min(err);
    if kk < oldError
        % If the minimum is attained at several grid points, keep the
        % largest kernel width (same tie-break as the original max(find(...))).
        m2 = find(err == kk, 1, 'last');
        sigmay = ssigmay(m2);
        oldError = kk;
    end

    % ---- Track the best error reached so far ----
    error_parcial = [error_parcial oldError]; %#ok<AGROW> (K is tiny)
    if dibuja
        subplot(325); plot(error_parcial,'.-.'); axis tight;
        xlabel('error'), ylabel('MSE'), grid, drawnow;
    end
end
warning on

% Retrain with the best parameters found, then drop numerically-zero
% coefficients so only the true support vectors are returned.
[nsv,svs,bias] = mysvr(Xtr,Ytr,epsil,C,gam,sigmay,yker);
tol = 1e-7;
aux = find(abs(svs) >= tol);
svs = svs(aux);
Xsv = Xtr(aux,:);
nsv = length(svs);