|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "classifier/svm/LibSVM.h" 00012 #include "lib/io.h" 00013 00014 using namespace shogun; 00015 00016 #ifdef HAVE_BOOST_SERIALIZATION 00017 #include <boost/serialization/export.hpp> 00018 BOOST_CLASS_EXPORT(CLibSVM); 00019 #endif //HAVE_BOOST_SERIALIZATION 00020 00021 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st) 00022 : CSVM(), model(NULL), solver_type(st) 00023 { 00024 } 00025 00026 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab) 00027 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC) 00028 { 00029 problem = svm_problem(); 00030 } 00031 00032 CLibSVM::~CLibSVM() 00033 { 00034 } 00035 00036 00037 bool CLibSVM::train(CFeatures* data) 00038 { 00039 struct svm_node* x_space; 00040 00041 ASSERT(labels && labels->get_num_labels()); 00042 ASSERT(labels->is_two_class_labeling()); 00043 00044 if (data) 00045 { 00046 if (labels->get_num_labels() != data->get_num_vectors()) 00047 SG_ERROR("Number of training vectors does not match number of labels\n"); 00048 kernel->init(data, data); 00049 } 00050 00051 problem.l=labels->get_num_labels(); 00052 SG_INFO( "%d trainlabels\n", problem.l); 00053 00054 00055 // check length of linear term 00056 if (!linear_term.empty() && 00057 labels->get_num_labels() != (int32_t)linear_term.size()) 00058 { 00059 SG_ERROR("Number of training vectors does not match length of linear term\n"); 00060 } 00061 00062 // set linear term 00063 if (!linear_term.empty()) 00064 { 00065 // set with linear term from base class 00066 problem.pv = get_linear_term_array(); 00067 00068 } 00069 else 00070 { 00071 // fill with minus ones 00072 problem.pv = new float64_t[problem.l]; 00073 00074 for (int i=0; i!=problem.l; i++) 00075 problem.pv[i] = -1.0; 00076 } 00077 00078 problem.y=new float64_t[problem.l]; 00079 problem.x=new struct svm_node*[problem.l]; 00080 problem.C=new float64_t[problem.l]; 00081 00082 x_space=new struct svm_node[2*problem.l]; 00083 00084 for (int32_t i=0; i<problem.l; i++) 00085 { 00086 problem.y[i]=labels->get_label(i); 00087 problem.x[i]=&x_space[2*i]; 00088 x_space[2*i].index=i; 00089 x_space[2*i+1].index=-1; 00090 } 00091 00092 int32_t weights_label[2]={-1,+1}; 00093 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00094 00095 ASSERT(kernel && kernel->has_features()); 00096 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00097 00098 param.svm_type=solver_type; // C SVM or NU_SVM 00099 param.kernel_type = LINEAR; 00100 param.degree = 3; 00101 param.gamma = 0; // 1/k 00102 param.coef0 = 0; 00103 param.nu = get_nu(); 00104 param.kernel=kernel; 00105 param.cache_size = kernel->get_cache_size(); 00106 param.max_train_time = max_train_time; 00107 param.C = get_C1(); 00108 param.eps = epsilon; 00109 param.p = 0.1; 00110 param.shrinking = 1; 00111 param.nr_weight = 2; 00112 param.weight_label = weights_label; 00113 param.weight = weights; 00114 param.use_bias = get_bias_enabled(); 00115 00116 const char* error_msg = svm_check_parameter(&problem, ¶m); 00117 00118 if(error_msg) 00119 SG_ERROR("Error: %s\n",error_msg); 00120 00121 model = svm_train(&problem, ¶m); 00122 00123 if (model) 00124 { 00125 ASSERT(model->nr_class==2); 00126 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0])); 00127 00128 int32_t num_sv=model->l; 00129 00130 create_new_model(num_sv); 00131 CSVM::set_objective(model->objective); 00132 00133 float64_t sgn=model->label[0]; 00134 00135 set_bias(-sgn*model->rho[0]); 00136 00137 for (int32_t i=0; i<num_sv; i++) 00138 { 00139 set_support_vector(i, (model->SV[i])->index); 00140 set_alpha(i, sgn*model->sv_coef[0][i]); 00141 } 00142 00143 delete[] problem.x; 00144 delete[] problem.y; 00145 delete[] problem.pv; 00146 delete[] problem.C; 00147 00148 00149 delete[] x_space; 00150 00151 svm_destroy_model(model); 00152 model=NULL; 00153 return true; 00154 } 00155 else 00156 return false; 00157 }