|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Written (W) 2009 Marius Kloft 00009 * Copyright (C) 2009 TU Berlin and Max-Planck-Society 00010 */ 00011 00012 #include "classifier/svm/ScatterSVM.h" 00013 #include "lib/io.h" 00014 00015 using namespace shogun; 00016 00017 CScatterSVM::CScatterSVM() 00018 : CMultiClassSVM(ONE_VS_REST), model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0) 00019 { 00020 } 00021 00022 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab) 00023 : CMultiClassSVM(ONE_VS_REST, C, k, lab), model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0) 00024 { 00025 } 00026 00027 CScatterSVM::~CScatterSVM() 00028 { 00029 delete[] norm_wc; 00030 delete[] norm_wcw; 00031 //SG_PRINT("deleting ScatterSVM\n"); 00032 } 00033 00034 bool CScatterSVM::train(CFeatures* data) 00035 { 00036 struct svm_node* x_space; 00037 00038 ASSERT(labels && labels->get_num_labels()); 00039 int32_t num_classes = labels->get_num_classes(); 00040 00041 if (data) 00042 { 00043 if (labels->get_num_labels() != data->get_num_vectors()) 00044 SG_ERROR("Number of training vectors does not match number of labels\n"); 00045 kernel->init(data, data); 00046 } 00047 00048 problem.l=labels->get_num_labels(); 00049 SG_INFO( "%d trainlabels\n", problem.l); 00050 00051 problem.y=new float64_t[problem.l]; 00052 problem.x=new struct svm_node*[problem.l]; 00053 x_space=new struct svm_node[2*problem.l]; 00054 00055 for (int32_t i=0; i<problem.l; i++) 00056 { 00057 problem.y[i]=labels->get_label(i); 00058 problem.x[i]=&x_space[2*i]; 00059 x_space[2*i].index=i; 00060 x_space[2*i+1].index=-1; 00061 } 00062 00063 int32_t weights_label[2]={-1,+1}; 00064 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00065 00066 ASSERT(kernel && kernel->has_features()); 00067 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00068 00069 param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM 00070 param.kernel_type = LINEAR; 00071 param.degree = 3; 00072 param.gamma = 0; // 1/k 00073 param.coef0 = 0; 00074 param.nu = get_nu(); // Nu 00075 param.kernel=kernel; 00076 param.cache_size = kernel->get_cache_size(); 00077 param.C = 0; 00078 param.eps = epsilon; 00079 param.p = 0.1; 00080 param.shrinking = 0; 00081 param.nr_weight = 2; 00082 param.weight_label = weights_label; 00083 param.weight = weights; 00084 param.nr_class=num_classes; 00085 param.use_bias = get_bias_enabled(); 00086 00087 int32_t* numc=new int32_t[num_classes]; 00088 CMath::fill_vector(numc, num_classes, 0); 00089 00090 for (int32_t i=0; i<problem.l; i++) 00091 numc[(int32_t) problem.y[i]]++; 00092 00093 int32_t Nc=0; 00094 int32_t Nmin=problem.l; 00095 for (int32_t i=0; i<num_classes; i++) 00096 { 00097 if (numc[i]>0) 00098 { 00099 Nc++; 00100 Nmin=CMath::min(Nmin, numc[i]); 00101 } 00102 00103 } 00104 00105 float64_t nu_min=((float64_t) Nc)/problem.l; 00106 float64_t nu_max=((float64_t) Nc)*Nmin/problem.l; 00107 00108 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max); 00109 00110 if (param.nu<nu_min || param.nu>nu_max) 00111 SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max); 00112 00113 const char* error_msg = svm_check_parameter(&problem,¶m); 00114 00115 if(error_msg) 00116 SG_ERROR("Error: %s\n",error_msg); 00117 00118 model = svm_train(&problem, ¶m); 00119 00120 if (model) 00121 { 00122 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00123 00124 ASSERT(model->nr_class==num_classes); 00125 create_multiclass_svm(num_classes); 00126 00127 rho=model->rho[0]; 00128 00129 delete[] norm_wcw; 00130 norm_wcw = new float64_t[m_num_svms]; 00131 00132 for (int32_t i=0; i<num_classes; i++) 00133 { 00134 int32_t num_sv=model->nSV[i]; 00135 00136 CSVM* svm=new CSVM(num_sv); 00137 svm->set_bias(model->rho[i+1]); 00138 norm_wcw[i]=model->normwcw[i]; 00139 00140 00141 for (int32_t j=0; j<num_sv; j++) 00142 { 00143 svm->set_alpha(j, model->sv_coef[i][j]); 00144 svm->set_support_vector(j, model->SV[i][j].index); 00145 } 00146 00147 set_svm(i, svm); 00148 } 00149 00150 delete[] problem.x; 00151 delete[] problem.y; 00152 delete[] x_space; 00153 for (int32_t i=0; i<num_classes; i++) 00154 { 00155 free(model->SV[i]); 00156 model->SV[i]=NULL; 00157 } 00158 svm_destroy_model(model); 00159 compute_norm_wc(); 00160 00161 model=NULL; 00162 return true; 00163 } 00164 else 00165 return false; 00166 } 00167 00168 void CScatterSVM::compute_norm_wc() 00169 { 00170 delete[] norm_wc; 00171 norm_wc = new float64_t[m_num_svms]; 00172 for (int32_t i=0; i<m_num_svms; i++) 00173 norm_wc[i]=0; 00174 00175 00176 for (int c=0; c<m_num_svms; c++) 00177 { 00178 CSVM* svm=m_svms[c]; 00179 int32_t num_sv = svm->get_num_support_vectors(); 00180 00181 for (int32_t i=0; i<num_sv; i++) 00182 { 00183 int32_t ii=svm->get_support_vector(i); 00184 for (int32_t j=0; j<num_sv; j++) 00185 { 00186 int32_t jj=svm->get_support_vector(j); 00187 norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j); 00188 } 00189 } 00190 } 00191 00192 for (int32_t i=0; i<m_num_svms; i++) 00193 norm_wc[i]=CMath::sqrt(norm_wc[i]); 00194 00195 CMath::display_vector(norm_wc, m_num_svms, "norm_wc"); 00196 } 00197 00198 CLabels* CScatterSVM::classify_one_vs_rest() 00199 { 00200 ASSERT(m_num_svms>0); 00201 CLabels* output=NULL; 00202 if (!kernel) 00203 { 00204 SG_ERROR( "SVM can not proceed without kernel!\n"); 00205 return false ; 00206 } 00207 00208 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00209 { 00210 int32_t num_vectors=kernel->get_num_vec_rhs(); 00211 00212 output=new CLabels(num_vectors); 00213 SG_REF(output); 00214 00215 for (int32_t i=0; i<num_vectors; i++) 00216 { 00217 output->set_label(i, classify_example(i)); 00218 } 00219 /* 00220 ASSERT(num_vectors==output->get_num_labels()); 00221 CLabels** outputs=new CLabels*[m_num_svms]; 00222 00223 for (int32_t i=0; i<m_num_svms; i++) 00224 { 00225 ASSERT(m_svms[i]); 00226 m_svms[i]->set_kernel(kernel); 00227 m_svms[i]->set_labels(labels); 00228 outputs[i]=m_svms[i]->classify(); 00229 } 00230 00231 for (int32_t i=0; i<num_vectors; i++) 00232 { 00233 int32_t winner=0; 00234 float64_t max_out=outputs[0]->get_label(i)/norm_wc[0]; 00235 00236 for (int32_t j=1; j<m_num_svms; j++) 00237 { 00238 float64_t out=outputs[j]->get_label(i)/norm_wc[j]; 00239 00240 if (out>max_out) 00241 { 00242 winner=j; 00243 max_out=out; 00244 } 00245 } 00246 00247 output->set_label(i, winner); 00248 } 00249 00250 for (int32_t i=0; i<m_num_svms; i++) 00251 SG_UNREF(outputs[i]); 00252 00253 delete[] outputs; 00254 */ 00255 } 00256 00257 return output; 00258 } 00259 00260 float64_t CScatterSVM::classify_example(int32_t num) 00261 { 00262 /* 00263 ASSERT(m_num_svms>0); 00264 float64_t* outputs=new float64_t[m_num_svms]; 00265 int32_t winner=0; 00266 float64_t max_out=m_svms[0]->classify_example(num)/norm_wc[0]; 00267 00268 for (int32_t i=1; i<m_num_svms; i++) 00269 { 00270 outputs[i]=m_svms[i]->classify_example(num)/norm_wc[i]; 00271 if (outputs[i]>max_out) 00272 { 00273 winner=i; 00274 max_out=outputs[i]; 00275 } 00276 } 00277 delete[] outputs; 00278 00279 return winner; 00280 */ 00281 00282 ASSERT(m_num_svms>0); 00283 float64_t* outputs=new float64_t[m_num_svms]; 00284 int32_t winner=0; 00285 00286 for (int32_t c=0; c<m_num_svms; c++) 00287 outputs[c]=m_svms[c]->get_bias()-rho; 00288 00289 for (int32_t c=0; c<m_num_svms; c++) 00290 { 00291 float64_t v=0; 00292 00293 for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++) 00294 { 00295 float64_t alpha=m_svms[c]->get_alpha(i); 00296 int32_t svidx=m_svms[c]->get_support_vector(i); 00297 v += alpha*kernel->kernel(svidx, num); 00298 } 00299 00300 outputs[c] += v; 00301 for (int32_t j=0; j<m_num_svms; j++) 00302 outputs[j] -= v/m_num_svms; 00303 } 00304 00305 for (int32_t j=0; j<m_num_svms; j++) 00306 outputs[j]/=norm_wcw[j]; 00307 00308 float64_t max_out=outputs[0]; 00309 for (int32_t j=0; j<m_num_svms; j++) 00310 { 00311 if (outputs[j]>max_out) 00312 { 00313 max_out=outputs[j]; 00314 winner=j; 00315 } 00316 } 00317 00318 delete[] outputs; 00319 00320 //SG_PRINT("winner = %d\n", winner); 00321 00322 return winner; 00323 }