|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Vojtech Franc, xfrancv@cmp.felk.cvut.cz 00008 * Copyright (C) 1999-2008 Center for Machine Perception, CTU FEL Prague 00009 */ 00010 00011 #include "lib/io.h" 00012 #include "classifier/svm/GMNPSVM.h" 00013 #include "classifier/svm/gmnplib.h" 00014 00015 #define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW)) 00016 #define MINUS_INF INT_MIN 00017 #define PLUS_INF INT_MAX 00018 #define KDELTA(A,B) (A==B) 00019 #define KDELTA4(A1,A2,A3,A4) ((A1==A2)||(A1==A3)||(A1==A4)||(A2==A3)||(A2==A4)||(A3==A4)) 00020 00021 using namespace shogun; 00022 00023 CGMNPSVM::CGMNPSVM() 00024 : CMultiClassSVM(ONE_VS_REST) 00025 { 00026 } 00027 00028 CGMNPSVM::CGMNPSVM(float64_t C, CKernel* k, CLabels* lab) 00029 : CMultiClassSVM(ONE_VS_REST, C, k, lab) 00030 { 00031 } 00032 00033 CGMNPSVM::~CGMNPSVM() 00034 { 00035 } 00036 00037 bool CGMNPSVM::train(CFeatures* data) 00038 { 00039 ASSERT(kernel); 00040 ASSERT(labels && labels->get_num_labels()); 00041 00042 if (data) 00043 { 00044 if (data->get_num_vectors() != labels->get_num_labels()) 00045 { 00046 SG_ERROR("Numbert of vectors (%d) does not match number of labels (%d)\n", 00047 data->get_num_vectors(), labels->get_num_labels()); 00048 } 00049 kernel->init(data, data); 00050 } 00051 00052 int32_t num_data = labels->get_num_labels(); 00053 int32_t num_classes = labels->get_num_classes(); 00054 int32_t num_virtual_data= num_data*(num_classes-1); 00055 00056 SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes); 00057 00058 float64_t* vector_y = new float64_t[num_data]; 00059 for (int32_t i=0; i<num_data; i++) 00060 { 00061 vector_y[i]= labels->get_label(i)+1; 00062 00063 } 00064 00065 float64_t C = get_C1(); 00066 int32_t tmax = 1000000000; 00067 float64_t tolabs = 0; 00068 float64_t tolrel = epsilon; 00069 00070 float64_t reg_const=0; 00071 if( C!=0 ) 00072 reg_const = 1/(2*C); 00073 00074 00075 float64_t* alpha = new float64_t[num_virtual_data]; 00076 float64_t* vector_c = new float64_t[num_virtual_data]; 00077 memset(vector_c, 0, num_virtual_data*sizeof(float64_t)); 00078 00079 float64_t thlb = 10000000000.0; 00080 int32_t t = 0; 00081 float64_t* History = NULL; 00082 int32_t verb = 0; 00083 00084 CGMNPLib mnp(vector_y,kernel,num_data, num_virtual_data, num_classes, reg_const); 00085 00086 mnp.gmnp_imdm(vector_c, num_virtual_data, tmax, 00087 tolabs, tolrel, thlb, alpha, &t, &History, verb ); 00088 00089 /* matrix alpha [num_classes x num_data] */ 00090 float64_t* all_alphas= new float64_t[num_classes*num_data]; 00091 memset(all_alphas,0,num_classes*num_data*sizeof(float64_t)); 00092 00093 /* bias vector b [num_classes x 1] */ 00094 float64_t* all_bs=new float64_t[num_classes]; 00095 memset(all_bs,0,num_classes*sizeof(float64_t)); 00096 00097 /* compute alpha/b from virt_data */ 00098 for(int32_t i=0; i < num_classes; i++ ) 00099 { 00100 for(int32_t j=0; j < num_virtual_data; j++ ) 00101 { 00102 int32_t inx1=0; 00103 int32_t inx2=0; 00104 00105 mnp.get_indices2( &inx1, &inx2, j ); 00106 00107 all_alphas[(inx1*num_classes)+i] += 00108 alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2)); 00109 all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2)); 00110 } 00111 } 00112 00113 create_multiclass_svm(num_classes); 00114 00115 for (int32_t i=0; i<num_classes; i++) 00116 { 00117 int32_t num_sv=0; 00118 for (int32_t j=0; j<num_data; j++) 00119 { 00120 if (all_alphas[j*num_classes+i] != 0) 00121 num_sv++; 00122 } 00123 ASSERT(num_sv>0); 00124 SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]); 00125 00126 CSVM* svm=new CSVM(num_sv); 00127 00128 int32_t k=0; 00129 for (int32_t j=0; j<num_data; j++) 00130 { 00131 if (all_alphas[j*num_classes+i] != 0) 00132 { 00133 svm->set_alpha(k, all_alphas[j*num_classes+i]); 00134 svm->set_support_vector(k, j); 00135 k++; 00136 } 00137 } 00138 00139 svm->set_bias(all_bs[i]); 00140 set_svm(i, svm); 00141 } 00142 00143 m_basealphas.resize(num_classes, ::std::vector<float64_t>(num_data,0)); 00144 for(int j=0; j < num_virtual_data; j++ ) 00145 { 00146 int inx1=0; 00147 int inx2=0; 00148 00149 mnp.get_indices2( &inx1, &inx2, j ); 00150 m_basealphas[inx2-1][inx1]=alpha[j]; 00151 } 00152 00153 delete[] vector_c; 00154 delete[] alpha; 00155 delete[] all_alphas; 00156 delete[] all_bs; 00157 delete[] vector_y; 00158 delete[] History; 00159 00160 return true; 00161 } 00162 00163 void CGMNPSVM::getbasealphas(::std::vector< ::std::vector<float64_t> > & basealphas) 00164 { 00165 basealphas=m_basealphas; 00166 }