|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Implementation of SVM-Ocas solver. 00008 * 00009 * Linear unbiased binary SVM solver. 00010 * 00011 * Written (W) 1999-2009 Vojtech Franc 00012 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00013 * 00014 * Modifications: 00015 * 23-oct-2007, VF 00016 * 10-oct-2007, VF, created. 00017 * 14-nov-2007, VF, updates 00018 * ----------------------------------------------------------------------*/ 00019 00020 #include <stdlib.h> 00021 #include <string.h> 00022 #include <math.h> 00023 #include <sys/time.h> 00024 #include <time.h> 00025 #include <stdio.h> 00026 #include <stdint.h> 00027 00028 #include "classifier/svm/libocas.h" 00029 #include "classifier/svm/libocas_common.h" 00030 #include "classifier/svm/qpssvmlib.h" 00031 00032 namespace shogun 00033 { 00034 00035 static const uint32_t QPSolverMaxIter = 10000000; 00036 00037 static float64_t *H; 00038 static uint32_t BufSize; 00039 00040 /*---------------------------------------------------------------------- 00041 Returns pointer at i-th column of Hessian matrix. 00042 ----------------------------------------------------------------------*/ 00043 static const void *get_col( uint32_t i) 00044 { 00045 return( &H[ BufSize*i ] ); 00046 } 00047 00048 /*---------------------------------------------------------------------- 00049 Returns time of the day in seconds. 00050 ----------------------------------------------------------------------*/ 00051 static float64_t get_time() 00052 { 00053 struct timeval tv; 00054 if (gettimeofday(&tv, NULL)==0) 00055 return (float64_t) (tv.tv_sec+((double)(tv.tv_usec))/1e6); 00056 else 00057 return 0.0; 00058 } 00059 00060 /*---------------------------------------------------------------------- 00061 SVM-Ocas solver. 00062 ----------------------------------------------------------------------*/ 00063 ocas_return_value_T svm_ocas_solver( 00064 float64_t C, 00065 uint32_t nData, 00066 float64_t TolRel, 00067 float64_t TolAbs, 00068 float64_t QPBound, 00069 uint32_t _BufSize, 00070 uint8_t Method, 00071 void (*compute_W)(float64_t*, float64_t*, float64_t*, uint32_t, void*), 00072 float64_t (*update_W)(float64_t, void*), 00073 void (*add_new_cut)(float64_t*, uint32_t*, uint32_t, uint32_t, void*), 00074 void (*compute_output)(float64_t*, void* ), 00075 void (*sort)(float64_t*, uint32_t*, uint32_t), 00076 void* user_data) 00077 { 00078 ocas_return_value_T ocas; 00079 float64_t *b, *alpha, *diag_H; 00080 float64_t *output, *old_output; 00081 float64_t xi, sq_norm_W, QPSolverTolRel, dot_prod_WoldW, dummy, sq_norm_oldW; 00082 float64_t A0, B0, GradVal, t, t1=0, t2=0, *Ci, *Bi, *hpf; 00083 float64_t start_time; 00084 uint32_t *hpi; 00085 uint32_t cut_length; 00086 uint32_t i, *new_cut; 00087 uint16_t *I; 00088 int8_t qp_exitflag; 00089 float64_t gap; 00090 00091 ocas.ocas_time = get_time(); 00092 ocas.solver_time = 0; 00093 ocas.output_time = 0; 00094 ocas.sort_time = 0; 00095 ocas.add_time = 0; 00096 ocas.w_time = 0; 00097 00098 BufSize = _BufSize; 00099 00100 QPSolverTolRel = TolRel*0.5; 00101 00102 H=NULL; 00103 b=NULL; 00104 alpha=NULL; 00105 new_cut=NULL; 00106 I=NULL; 00107 diag_H=NULL; 00108 output=NULL; 00109 old_output=NULL; 00110 hpf=NULL; 00111 hpi=NULL; 00112 Ci=NULL; 00113 Bi=NULL; 00114 00115 /* Hessian matrix contains dot product of normal vectors of selected cutting planes */ 00116 H = (float64_t*)OCAS_CALLOC(BufSize*BufSize,sizeof(float64_t)); 00117 if(H == NULL) 00118 { 00119 ocas.exitflag=-2; 00120 goto cleanup; 00121 } 00122 00123 /* bias of cutting planes */ 00124 b = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t)); 00125 if(b == NULL) 00126 { 00127 ocas.exitflag=-2; 00128 goto cleanup; 00129 } 00130 00131 alpha = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t)); 00132 if(alpha == NULL) 00133 { 00134 ocas.exitflag=-2; 00135 goto cleanup; 00136 } 00137 00138 /* indices of examples which define a new cut */ 00139 new_cut = (uint32_t*)OCAS_CALLOC(nData,sizeof(uint32_t)); 00140 if(new_cut == NULL) 00141 { 00142 ocas.exitflag=-2; 00143 goto cleanup; 00144 } 00145 00146 I = (uint16_t*)OCAS_CALLOC(BufSize,sizeof(uint16_t)); 00147 if(I == NULL) 00148 { 00149 ocas.exitflag=-2; 00150 goto cleanup; 00151 } 00152 00153 for(i=0; i< BufSize; i++) I[i] = 1; 00154 00155 diag_H = (float64_t*)OCAS_CALLOC(BufSize,sizeof(float64_t)); 00156 if(diag_H == NULL) 00157 { 00158 ocas.exitflag=-2; 00159 goto cleanup; 00160 } 00161 00162 output = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t)); 00163 if(output == NULL) 00164 { 00165 ocas.exitflag=-2; 00166 goto cleanup; 00167 } 00168 00169 old_output = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t)); 00170 if(old_output == NULL) 00171 { 00172 ocas.exitflag=-2; 00173 goto cleanup; 00174 } 00175 00176 /* array of hinge points used in line-serach */ 00177 hpf = (float64_t*) OCAS_CALLOC(nData, sizeof(hpf[0])); 00178 if(hpf == NULL) 00179 { 00180 ocas.exitflag=-2; 00181 goto cleanup; 00182 } 00183 00184 hpi = (uint32_t*) OCAS_CALLOC(nData, sizeof(hpi[0])); 00185 if(hpi == NULL) 00186 { 00187 ocas.exitflag=-2; 00188 goto cleanup; 00189 } 00190 00191 /* vectors Ci, Bi are used in the line search procedure */ 00192 Ci = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t)); 00193 if(Ci == NULL) 00194 { 00195 ocas.exitflag=-2; 00196 goto cleanup; 00197 } 00198 00199 Bi = (float64_t*)OCAS_CALLOC(nData,sizeof(float64_t)); 00200 if(Bi == NULL) 00201 { 00202 ocas.exitflag=-2; 00203 goto cleanup; 00204 } 00205 00206 ocas.nCutPlanes = 0; 00207 ocas.exitflag = 0; 00208 ocas.nIter = 0; 00209 00210 /* Compute initial value of Q_P assuming that W is zero vector.*/ 00211 sq_norm_W = 0; 00212 xi = nData; 00213 ocas.Q_P = 0.5*sq_norm_W + C*xi; 00214 ocas.Q_D = 0; 00215 00216 /* Compute the initial cutting plane */ 00217 cut_length = nData; 00218 for(i=0; i < nData; i++) 00219 new_cut[i] = i; 00220 00221 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P); 00222 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6); 00223 00224 /* main loop */ 00225 while( ocas.exitflag == 0 ) 00226 { 00227 ocas.nIter++; 00228 00229 /* append a new cut to the buffer and update H */ 00230 b[ocas.nCutPlanes] = -(float64_t)cut_length; 00231 00232 start_time = get_time(); 00233 00234 add_new_cut( &H[INDEX2(0,ocas.nCutPlanes,BufSize)], new_cut, cut_length, ocas.nCutPlanes, user_data ); 00235 00236 ocas.add_time += get_time() - start_time; 00237 00238 /* copy new added row: H(ocas.nCutPlanes,ocas.nCutPlanes,1:ocas.nCutPlanes-1) = H(1:ocas.nCutPlanes-1:ocas.nCutPlanes)' */ 00239 diag_H[ocas.nCutPlanes] = H[INDEX2(ocas.nCutPlanes,ocas.nCutPlanes,BufSize)]; 00240 for(i=0; i < ocas.nCutPlanes; i++) { 00241 H[INDEX2(ocas.nCutPlanes,i,BufSize)] = H[INDEX2(i,ocas.nCutPlanes,BufSize)]; 00242 } 00243 00244 ocas.nCutPlanes++; 00245 00246 /* call inner QP solver */ 00247 start_time = get_time(); 00248 00249 qp_exitflag = qpssvm_solver( &get_col, diag_H, b, C, I, alpha, 00250 ocas.nCutPlanes, QPSolverMaxIter, 0.0, QPSolverTolRel, &ocas.Q_D, &dummy, 0 ); 00251 00252 ocas.solver_time += get_time() - start_time; 00253 00254 ocas.Q_D = -ocas.Q_D; 00255 00256 ocas.nNZAlpha = 0; 00257 for(i=0; i < ocas.nCutPlanes; i++) { 00258 if( alpha[i] != 0) ocas.nNZAlpha++; 00259 } 00260 00261 sq_norm_oldW = sq_norm_W; 00262 start_time = get_time(); 00263 compute_W( &sq_norm_W, &dot_prod_WoldW, alpha, ocas.nCutPlanes, user_data ); 00264 ocas.w_time += get_time() - start_time; 00265 00266 /* select a new cut */ 00267 switch( Method ) 00268 { 00269 /* cutting plane algorithm implemented in SVMperf and BMRM */ 00270 case 0: 00271 00272 start_time = get_time(); 00273 compute_output( output, user_data ); 00274 ocas.output_time += get_time()-start_time; 00275 00276 xi = 0; 00277 cut_length = 0; 00278 ocas.trn_err = 0; 00279 for(i=0; i < nData; i++) 00280 { 00281 if(output[i] <= 0) ocas.trn_err++; 00282 00283 if(output[i] <= 1) { 00284 xi += 1 - output[i]; 00285 new_cut[cut_length] = i; 00286 cut_length++; 00287 } 00288 } 00289 ocas.Q_P = 0.5*sq_norm_W + C*xi; 00290 00291 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P); 00292 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6); 00293 00294 break; 00295 00296 00297 /* Ocas strategy */ 00298 case 1: 00299 00300 /* Linesearch */ 00301 A0 = sq_norm_W -2*dot_prod_WoldW + sq_norm_oldW; 00302 B0 = dot_prod_WoldW - sq_norm_oldW; 00303 00304 memcpy( old_output, output, sizeof(float64_t)*nData ); 00305 00306 start_time = get_time(); 00307 compute_output( output, user_data ); 00308 ocas.output_time += get_time()-start_time; 00309 00310 uint32_t num_hp = 0; 00311 GradVal = B0; 00312 for(i=0; i< nData; i++) { 00313 00314 Ci[i] = C*(1-old_output[i]); 00315 Bi[i] = C*(old_output[i] - output[i]); 00316 00317 float64_t val; 00318 if(Bi[i] != 0) 00319 val = -Ci[i]/Bi[i]; 00320 else 00321 val = -OCAS_PLUS_INF; 00322 00323 if (val>0) 00324 { 00325 hpi[num_hp] = i; 00326 hpf[num_hp] = val; 00327 num_hp++; 00328 } 00329 00330 if( (Bi[i] < 0 && val > 0) || (Bi[i] > 0 && val <= 0)) 00331 GradVal += Bi[i]; 00332 00333 } 00334 00335 t = 0; 00336 if( GradVal < 0 ) 00337 { 00338 start_time = get_time(); 00339 sort(hpf, hpi, num_hp); 00340 ocas.sort_time += get_time() - start_time; 00341 00342 float64_t t_new, GradVal_new; 00343 i = 0; 00344 while( GradVal < 0 && i < num_hp ) 00345 { 00346 t_new = hpf[i]; 00347 GradVal_new = GradVal + CMath::abs(Bi[hpi[i]]) + A0*(t_new-t); 00348 00349 if( GradVal_new >= 0 ) 00350 { 00351 t = t + GradVal*(t-t_new)/(GradVal_new - GradVal); 00352 } 00353 else 00354 { 00355 t = t_new; 00356 i++; 00357 } 00358 00359 GradVal = GradVal_new; 00360 } 00361 } 00362 00363 /* 00364 t = hpf[0] - 1; 00365 i = 0; 00366 GradVal = t*A0 + Bsum; 00367 while( GradVal < 0 && i < num_hp && hpf[i] < OCAS_PLUS_INF ) { 00368 t = hpf[i]; 00369 Bsum = Bsum + CMath::abs(Bi[hpi[i]]); 00370 GradVal = t*A0 + Bsum; 00371 i++; 00372 } 00373 */ 00374 t = CMath::max(t,0.0); /* just sanity check; t < 0 should not ocure */ 00375 00376 t1 = t; /* new (best so far) W */ 00377 t2 = t+(1.0-t)/10.0; /* new cutting plane */ 00378 00379 /* update W to be the best so far solution */ 00380 sq_norm_W = update_W( t1, user_data ); 00381 00382 /* select a new cut */ 00383 xi = 0; 00384 cut_length = 0; 00385 ocas.trn_err = 0; 00386 for(i=0; i < nData; i++ ) { 00387 00388 if( (old_output[i]*(1-t2) + t2*output[i]) <= 1 ) 00389 { 00390 new_cut[cut_length] = i; 00391 cut_length++; 00392 } 00393 00394 output[i] = old_output[i]*(1-t1) + t1*output[i]; 00395 00396 if( output[i] <= 1) xi += 1-output[i]; 00397 if( output[i] <= 0) ocas.trn_err++; 00398 00399 } 00400 00401 ocas.Q_P = 0.5*sq_norm_W + C*xi; 00402 00403 gap=(ocas.Q_P-ocas.Q_D)/CMath::abs(ocas.Q_P); 00404 SG_SABS_PROGRESS(gap, -CMath::log10(gap), -CMath::log10(1), -CMath::log10(TolRel), 6); 00405 00406 break; 00407 } 00408 00409 /* Stopping conditions */ 00410 if( ocas.Q_P - ocas.Q_D <= TolRel*CMath::abs(ocas.Q_P)) ocas.exitflag = 1; 00411 if( ocas.Q_P - ocas.Q_D <= TolAbs) ocas.exitflag = 2; 00412 if( ocas.Q_P <= QPBound) ocas.exitflag = 3; 00413 if(ocas.nCutPlanes >= BufSize) ocas.exitflag = -1; 00414 00415 } /* end of the main loop */ 00416 00417 cleanup: 00418 00419 OCAS_FREE(H); 00420 OCAS_FREE(b); 00421 OCAS_FREE(alpha); 00422 OCAS_FREE(new_cut); 00423 OCAS_FREE(I); 00424 OCAS_FREE(diag_H); 00425 OCAS_FREE(output); 00426 OCAS_FREE(old_output); 00427 OCAS_FREE(hpf); 00428 OCAS_FREE(hpi); 00429 OCAS_FREE(Ci); 00430 OCAS_FREE(Bi); 00431 00432 ocas.ocas_time = get_time() - ocas.ocas_time; 00433 00434 return(ocas); 00435 } 00436 }