|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "lib/common.h" 00012 #include "lib/io.h" 00013 #include "classifier/svm/MultiClassSVM.h" 00014 00015 using namespace shogun; 00016 00017 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type) 00018 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL) 00019 { 00020 } 00021 00022 CMultiClassSVM::CMultiClassSVM( 00023 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab) 00024 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL) 00025 { 00026 } 00027 00028 CMultiClassSVM::~CMultiClassSVM() 00029 { 00030 cleanup(); 00031 } 00032 00033 void CMultiClassSVM::cleanup() 00034 { 00035 for (int32_t i=0; i<m_num_svms; i++) 00036 SG_UNREF(m_svms[i]); 00037 00038 delete[] m_svms; 00039 m_num_svms=0; 00040 m_svms=NULL; 00041 } 00042 00043 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes) 00044 { 00045 if (num_classes>0) 00046 { 00047 cleanup(); 00048 00049 m_num_classes=num_classes; 00050 00051 if (multiclass_type==ONE_VS_REST) 00052 m_num_svms=num_classes; 00053 else if (multiclass_type==ONE_VS_ONE) 00054 m_num_svms=num_classes*(num_classes-1)/2; 00055 else 00056 SG_ERROR("unknown multiclass type\n"); 00057 00058 m_svms=new CSVM*[m_num_svms]; 00059 if (m_svms) 00060 { 00061 memset(m_svms,0, m_num_svms*sizeof(CSVM*)); 00062 return true; 00063 } 00064 } 00065 return false; 00066 } 00067 00068 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm) 00069 { 00070 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm) 00071 { 00072 SG_REF(svm); 00073 m_svms[num]=svm; 00074 return true; 00075 } 00076 return false; 00077 } 00078 00079 CLabels* CMultiClassSVM::classify() 00080 { 00081 if (multiclass_type==ONE_VS_REST) 00082 return classify_one_vs_rest(); 00083 else if (multiclass_type==ONE_VS_ONE) 00084 return classify_one_vs_one(); 00085 else 00086 SG_ERROR("unknown multiclass type\n"); 00087 00088 return NULL; 00089 } 00090 00091 CLabels* CMultiClassSVM::classify_one_vs_one() 00092 { 00093 ASSERT(m_num_svms>0); 00094 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2); 00095 CLabels* result=NULL; 00096 00097 if (!kernel) 00098 { 00099 SG_ERROR( "SVM can not proceed without kernel!\n"); 00100 return false ; 00101 } 00102 00103 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00104 { 00105 int32_t num_vectors=kernel->get_num_vec_rhs(); 00106 00107 result=new CLabels(num_vectors); 00108 SG_REF(result); 00109 00110 ASSERT(num_vectors==result->get_num_labels()); 00111 CLabels** outputs=new CLabels*[m_num_svms]; 00112 00113 for (int32_t i=0; i<m_num_svms; i++) 00114 { 00115 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]); 00116 ASSERT(m_svms[i]); 00117 m_svms[i]->set_kernel(kernel); 00118 outputs[i]=m_svms[i]->classify(); 00119 } 00120 00121 int32_t* votes=new int32_t[m_num_classes]; 00122 for (int32_t v=0; v<num_vectors; v++) 00123 { 00124 int32_t s=0; 00125 memset(votes, 0, sizeof(int32_t)*m_num_classes); 00126 00127 for (int32_t i=0; i<m_num_classes; i++) 00128 { 00129 for (int32_t j=i+1; j<m_num_classes; j++) 00130 { 00131 if (outputs[s++]->get_label(v)>0) 00132 votes[i]++; 00133 else 00134 votes[j]++; 00135 } 00136 } 00137 00138 int32_t winner=0; 00139 int32_t max_votes=votes[0]; 00140 00141 for (int32_t i=1; i<m_num_classes; i++) 00142 { 00143 if (votes[i]>max_votes) 00144 { 00145 max_votes=votes[i]; 00146 winner=i; 00147 } 00148 } 00149 00150 result->set_label(v, winner); 00151 } 00152 00153 delete[] votes; 00154 00155 for (int32_t i=0; i<m_num_svms; i++) 00156 SG_UNREF(outputs[i]); 00157 delete[] outputs; 00158 } 00159 00160 return result; 00161 } 00162 00163 CLabels* CMultiClassSVM::classify_one_vs_rest() 00164 { 00165 ASSERT(m_num_svms>0); 00166 CLabels* result=NULL; 00167 00168 if (!kernel) 00169 { 00170 SG_ERROR( "SVM can not proceed without kernel!\n"); 00171 return false ; 00172 } 00173 00174 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00175 { 00176 int32_t num_vectors=kernel->get_num_vec_rhs(); 00177 00178 result=new CLabels(num_vectors); 00179 SG_REF(result); 00180 00181 ASSERT(num_vectors==result->get_num_labels()); 00182 CLabels** outputs=new CLabels*[m_num_svms]; 00183 00184 for (int32_t i=0; i<m_num_svms; i++) 00185 { 00186 ASSERT(m_svms[i]); 00187 m_svms[i]->set_kernel(kernel); 00188 outputs[i]=m_svms[i]->classify(); 00189 } 00190 00191 for (int32_t i=0; i<num_vectors; i++) 00192 { 00193 int32_t winner=0; 00194 float64_t max_out=outputs[0]->get_label(i); 00195 00196 for (int32_t j=1; j<m_num_svms; j++) 00197 { 00198 float64_t out=outputs[j]->get_label(i); 00199 00200 if (out>max_out) 00201 { 00202 winner=j; 00203 max_out=out; 00204 } 00205 } 00206 00207 result->set_label(i, winner); 00208 } 00209 00210 for (int32_t i=0; i<m_num_svms; i++) 00211 SG_UNREF(outputs[i]); 00212 00213 delete[] outputs; 00214 } 00215 00216 return result; 00217 } 00218 00219 float64_t CMultiClassSVM::classify_example(int32_t num) 00220 { 00221 if (multiclass_type==ONE_VS_REST) 00222 return classify_example_one_vs_rest(num); 00223 else if (multiclass_type==ONE_VS_ONE) 00224 return classify_example_one_vs_one(num); 00225 else 00226 SG_ERROR("unknown multiclass type\n"); 00227 00228 return 0; 00229 } 00230 00231 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num) 00232 { 00233 ASSERT(m_num_svms>0); 00234 float64_t* outputs=new float64_t[m_num_svms]; 00235 int32_t winner=0; 00236 float64_t max_out=m_svms[0]->classify_example(num); 00237 00238 for (int32_t i=1; i<m_num_svms; i++) 00239 { 00240 outputs[i]=m_svms[i]->classify_example(num); 00241 if (outputs[i]>max_out) 00242 { 00243 winner=i; 00244 max_out=outputs[i]; 00245 } 00246 } 00247 delete[] outputs; 00248 00249 return winner; 00250 } 00251 00252 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num) 00253 { 00254 ASSERT(m_num_svms>0); 00255 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2); 00256 00257 int32_t* votes=new int32_t[m_num_classes]; 00258 int32_t s=0; 00259 00260 for (int32_t i=0; i<m_num_classes; i++) 00261 { 00262 for (int32_t j=i+1; j<m_num_classes; j++) 00263 { 00264 if (m_svms[s++]->classify_example(num)>0) 00265 votes[i]++; 00266 else 00267 votes[j]++; 00268 } 00269 } 00270 00271 int32_t winner=0; 00272 int32_t max_votes=votes[0]; 00273 00274 for (int32_t i=1; i<m_num_classes; i++) 00275 { 00276 if (votes[i]>max_votes) 00277 { 00278 max_votes=votes[i]; 00279 winner=i; 00280 } 00281 } 00282 00283 delete[] votes; 00284 00285 return winner; 00286 } 00287 00288 bool CMultiClassSVM::load(FILE* modelfl) 00289 { 00290 bool result=true; 00291 char char_buffer[1024]; 00292 int32_t int_buffer; 00293 float64_t double_buffer; 00294 int32_t line_number=1; 00295 int32_t svm_idx=-1; 00296 00297 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF) 00298 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00299 else 00300 { 00301 char_buffer[15]='\0'; 00302 if (strcmp("%MultiClassSVM", char_buffer)!=0) 00303 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number); 00304 00305 line_number++; 00306 } 00307 00308 int_buffer=0; 00309 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1) 00310 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00311 00312 if (!feof(modelfl)) 00313 line_number++; 00314 00315 if (int_buffer != multiclass_type) 00316 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type); 00317 00318 int_buffer=0; 00319 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1) 00320 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00321 00322 if (!feof(modelfl)) 00323 line_number++; 00324 00325 if (int_buffer < 2) 00326 SG_ERROR("less than 2 classes - how is this multiclass?\n"); 00327 00328 create_multiclass_svm(int_buffer); 00329 00330 int_buffer=0; 00331 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1) 00332 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00333 00334 if (!feof(modelfl)) 00335 line_number++; 00336 00337 if (m_num_svms != int_buffer) 00338 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer); 00339 00340 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1) 00341 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00342 00343 if (!feof(modelfl)) 00344 line_number++; 00345 00346 for (int32_t n=0; n<m_num_svms; n++) 00347 { 00348 svm_idx=-1; 00349 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF) 00350 { 00351 result=false; 00352 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00353 } 00354 else 00355 { 00356 char_buffer[4]='\0'; 00357 if (strncmp("%SVM", char_buffer, 4)!=0) 00358 { 00359 result=false; 00360 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00361 } 00362 00363 if (svm_idx != n) 00364 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00365 00366 line_number++; 00367 } 00368 00369 int_buffer=0; 00370 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2) 00371 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00372 00373 if (svm_idx != n) 00374 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00375 00376 if (!feof(modelfl)) 00377 line_number++; 00378 00379 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx); 00380 CSVM* svm=new CSVM(int_buffer); 00381 00382 double_buffer=0; 00383 00384 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2) 00385 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00386 00387 if (svm_idx != n) 00388 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00389 00390 if (!feof(modelfl)) 00391 line_number++; 00392 00393 svm->set_bias(double_buffer); 00394 00395 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1) 00396 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00397 00398 if (svm_idx != n) 00399 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00400 00401 if (!feof(modelfl)) 00402 line_number++; 00403 00404 for (int32_t i=0; i<svm->get_num_support_vectors(); i++) 00405 { 00406 double_buffer=0; 00407 int_buffer=0; 00408 00409 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2) 00410 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00411 00412 if (!feof(modelfl)) 00413 line_number++; 00414 00415 svm->set_support_vector(i, int_buffer); 00416 svm->set_alpha(i, double_buffer); 00417 } 00418 00419 if (fscanf(modelfl,"%2s", char_buffer) == EOF) 00420 { 00421 result=false; 00422 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00423 } 00424 else 00425 { 00426 char_buffer[3]='\0'; 00427 if (strcmp("];", char_buffer)!=0) 00428 { 00429 result=false; 00430 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00431 } 00432 line_number++; 00433 } 00434 00435 set_svm(n, svm); 00436 } 00437 00438 svm_loaded=result; 00439 return result; 00440 } 00441 00442 bool CMultiClassSVM::save(FILE* modelfl) 00443 { 00444 if (!kernel) 00445 SG_ERROR("Kernel not defined!\n"); 00446 00447 if (!m_svms || m_num_svms<1 || m_num_classes <=2) 00448 SG_ERROR("Multiclass SVM not trained!\n"); 00449 00450 SG_INFO( "Writing model file..."); 00451 fprintf(modelfl,"%%MultiClassSVM\n"); 00452 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type); 00453 fprintf(modelfl,"num_classes=%d;\n", m_num_classes); 00454 fprintf(modelfl,"num_svms=%d;\n", m_num_svms); 00455 fprintf(modelfl,"kernel='%s';\n", kernel->get_name()); 00456 00457 for (int32_t i=0; i<m_num_svms; i++) 00458 { 00459 CSVM* svm=m_svms[i]; 00460 ASSERT(svm); 00461 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1); 00462 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors()); 00463 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias()); 00464 00465 fprintf(modelfl, "alphas%d=[\n", i); 00466 00467 for(int32_t j=0; j<svm->get_num_support_vectors(); j++) 00468 { 00469 fprintf(modelfl,"\t[%+10.16e,%d];\n", 00470 svm->get_alpha(j), svm->get_support_vector(j)); 00471 } 00472 00473 fprintf(modelfl, "];\n"); 00474 } 00475 00476 SG_DONE(); 00477 return true ; 00478 }