|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "classifier/svm/LibSVMMultiClass.h" 00012 #include "lib/io.h" 00013 00014 using namespace shogun; 00015 00016 CLibSVMMultiClass::CLibSVMMultiClass(LIBSVM_SOLVER_TYPE st) 00017 : CMultiClassSVM(ONE_VS_ONE), model(NULL), solver_type(st) 00018 { 00019 } 00020 00021 CLibSVMMultiClass::CLibSVMMultiClass(float64_t C, CKernel* k, CLabels* lab) 00022 : CMultiClassSVM(ONE_VS_ONE, C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC) 00023 { 00024 } 00025 00026 CLibSVMMultiClass::~CLibSVMMultiClass() 00027 { 00028 //SG_PRINT("deleting LibSVM\n"); 00029 } 00030 00031 bool CLibSVMMultiClass::train(CFeatures* data) 00032 { 00033 struct svm_node* x_space; 00034 00035 problem = svm_problem(); 00036 00037 ASSERT(labels && labels->get_num_labels()); 00038 int32_t num_classes = labels->get_num_classes(); 00039 problem.l=labels->get_num_labels(); 00040 SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes); 00041 00042 if (data) 00043 { 00044 if (labels->get_num_labels() != data->get_num_vectors()) 00045 SG_ERROR("Number of training vectors does not match number of labels\n"); 00046 kernel->init(data, data); 00047 } 00048 00049 problem.y=new float64_t[problem.l]; 00050 problem.x=new struct svm_node*[problem.l]; 00051 problem.pv=new float64_t[problem.l]; 00052 problem.C=new float64_t[problem.l]; 00053 00054 x_space=new struct svm_node[2*problem.l]; 00055 00056 for (int32_t i=0; i<problem.l; i++) 00057 { 00058 problem.pv[i]=-1.0; 00059 problem.y[i]=labels->get_label(i); 00060 problem.x[i]=&x_space[2*i]; 00061 x_space[2*i].index=i; 00062 x_space[2*i+1].index=-1; 00063 } 00064 00065 ASSERT(kernel); 00066 00067 param.svm_type=solver_type; // C SVM or NU_SVM 00068 param.kernel_type = LINEAR; 00069 param.degree = 3; 00070 param.gamma = 0; // 1/k 00071 param.coef0 = 0; 00072 param.nu = get_nu(); // Nu 00073 param.kernel=kernel; 00074 param.cache_size = kernel->get_cache_size(); 00075 param.max_train_time = max_train_time; 00076 param.C = get_C1(); 00077 param.eps = epsilon; 00078 param.p = 0.1; 00079 param.shrinking = 1; 00080 param.nr_weight = 0; 00081 param.weight_label = NULL; 00082 param.weight = NULL; 00083 param.use_bias = get_bias_enabled(); 00084 00085 const char* error_msg = svm_check_parameter(&problem,¶m); 00086 00087 if(error_msg) 00088 SG_ERROR("Error: %s\n",error_msg); 00089 00090 model = svm_train(&problem, ¶m); 00091 00092 if (model) 00093 { 00094 ASSERT(model->nr_class==num_classes); 00095 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef)); 00096 create_multiclass_svm(num_classes); 00097 00098 int32_t* offsets=new int32_t[num_classes]; 00099 offsets[0]=0; 00100 00101 for (int32_t i=1; i<num_classes; i++) 00102 offsets[i] = offsets[i-1]+model->nSV[i-1]; 00103 00104 int32_t s=0; 00105 for (int32_t i=0; i<num_classes; i++) 00106 { 00107 for (int32_t j=i+1; j<num_classes; j++) 00108 { 00109 int32_t k, l; 00110 00111 float64_t sgn=1; 00112 if (model->label[i]>model->label[j]) 00113 sgn=-1; 00114 00115 int32_t num_sv=model->nSV[i]+model->nSV[j]; 00116 float64_t bias=-model->rho[s]; 00117 00118 ASSERT(num_sv>0); 00119 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]); 00120 00121 CSVM* svm=new CSVM(num_sv); 00122 00123 svm->set_bias(sgn*bias); 00124 00125 int32_t sv_idx=0; 00126 for (k=0; k<model->nSV[i]; k++) 00127 { 00128 svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index); 00129 svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]); 00130 sv_idx++; 00131 } 00132 00133 for (k=0; k<model->nSV[j]; k++) 00134 { 00135 svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index); 00136 svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]); 00137 sv_idx++; 00138 } 00139 00140 int32_t idx=0; 00141 00142 if (sgn>0) 00143 { 00144 for (k=0; k<model->label[i]; k++) 00145 idx+=num_classes-k-1; 00146 00147 for (l=model->label[i]+1; l<model->label[j]; l++) 00148 idx++; 00149 } 00150 else 00151 { 00152 for (k=0; k<model->label[j]; k++) 00153 idx+=num_classes-k-1; 00154 00155 for (l=model->label[j]+1; l<model->label[i]; l++) 00156 idx++; 00157 } 00158 00159 00160 // if (sgn>0) 00161 // idx=((num_classes-1)*model->label[i]+model->label[j])/2; 00162 // else 00163 // idx=((num_classes-1)*model->label[j]+model->label[i])/2; 00164 // 00165 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f label:(%d,%d) -> svm[%d]\n", s, num_sv, model->l, bias, model->label[i], model->label[j], idx); 00166 00167 set_svm(idx, svm); 00168 s++; 00169 } 00170 } 00171 00172 CSVM::set_objective(model->objective); 00173 00174 delete[] offsets; 00175 delete[] problem.x; 00176 delete[] problem.y; 00177 delete[] x_space; 00178 00179 svm_destroy_model(model); 00180 model=NULL; 00181 00182 return true; 00183 } 00184 else 00185 return false; 00186 } 00187