|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _SVM_H___ 00012 #define _SVM_H___ 00013 00014 #include "lib/common.h" 00015 #include "features/Features.h" 00016 #include "kernel/Kernel.h" 00017 #include "classifier/KernelMachine.h" 00018 #include "lib/Parameter.h" 00019 00020 namespace shogun 00021 { 00022 00023 class CMKL; 00024 00047 class CSVM : public CKernelMachine 00048 { 00049 public: 00053 CSVM(int32_t num_sv=0); 00054 00062 CSVM(float64_t C, CKernel* k, CLabels* lab); 00063 virtual ~CSVM(); 00064 00067 void set_defaults(int32_t num_sv=0); 00068 00069 00075 virtual std::vector<float64_t> get_linear_term(); 00076 00077 00083 virtual void set_linear_term(std::vector<float64_t> lin); 00084 00085 00089 bool load(FILE* svm_file); 00090 00094 bool save(FILE* svm_file); 00095 00100 inline void set_nu(float64_t nue) { nu=nue; } 00101 00102 00111 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00112 00113 00118 inline void set_epsilon(float64_t eps) { epsilon=eps; } 00119 00124 inline void set_tube_epsilon(float64_t eps) { tube_epsilon=eps; } 00125 00130 inline float64_t get_tube_epsilon() { return tube_epsilon; } 00131 00136 inline void set_qpsize(int32_t qps) { qpsize=qps; } 00137 00142 inline float64_t get_epsilon() { return epsilon; } 00143 00148 inline float64_t get_nu() { return nu; } 00149 00154 inline float64_t get_C1() { return C1; } 00155 00160 inline float64_t get_C2() { return C2; } 00161 00166 inline int32_t get_qpsize() { return qpsize; } 00167 00172 inline void set_shrinking_enabled(bool enable) 00173 { 00174 use_shrinking=enable; 00175 } 00176 00181 inline bool get_shrinking_enabled() 00182 { 00183 return use_shrinking; 00184 } 00185 00190 float64_t compute_svm_dual_objective(); 00191 00196 float64_t compute_svm_primal_objective(); 00197 00202 inline void set_objective(float64_t v) 00203 { 00204 objective=v; 00205 } 00206 00211 inline float64_t get_objective() 00212 { 00213 return objective; 00214 } 00215 00223 void set_callback_function(CMKL* m, bool (*cb) 00224 (CMKL* mkl, const float64_t* sumw, const float64_t suma)); 00225 00227 inline virtual const char* get_name() const { return "SVM"; } 00228 00229 #ifdef HAVE_BOOST_SERIALIZATION 00230 friend class ::boost::serialization::access; 00231 // When the class Archive corresponds to an output archive, the 00232 // & operator is defined similar to <<. Likewise, when the class Archive 00233 // is a type of input archive the & operator is defined similar to >>. 00234 template<class Archive> 00235 void serialize(Archive & ar, const unsigned int archive_version) 00236 { 00237 00238 SG_DEBUG("archiving CSVM\n"); 00239 00240 ar & ::boost::serialization::base_object<CKernelMachine>(*this); 00241 00242 ar & linear_term; 00243 00244 ar & svm_loaded; 00245 00246 ar & epsilon; 00247 ar & tube_epsilon; 00248 00249 ar & nu; 00250 ar & C1; 00251 ar & C2; 00252 00253 ar & objective; 00254 00255 ar & qpsize; 00256 ar & use_shrinking; 00257 00258 //TODO serialize mkl object 00259 //CMKL* mkl; 00260 00261 SG_DEBUG("done with CSVM\n"); 00262 } 00263 00264 public: 00265 virtual void toFile(std::string filename) const 00266 { 00267 00268 std::ofstream os(filename.c_str(), std::ios::binary); 00269 ::boost::archive::binary_oarchive oa(os); 00270 00271 oa << *this; 00272 00273 } 00274 00275 virtual void fromFile(std::string filename) 00276 { 00277 00278 std::ifstream is(filename.c_str(), std::ios::binary); 00279 ::boost::archive::binary_iarchive ia(is); 00280 00281 ia >> *this; 00282 00283 } 00284 00285 #endif //HAVE_BOOST_SERIALIZATION 00286 00287 protected: 00288 00294 virtual float64_t* get_linear_term_array(); 00295 00297 std::vector<float64_t> linear_term; 00298 00300 bool svm_loaded; 00302 float64_t epsilon; 00304 float64_t tube_epsilon; 00306 float64_t nu; 00308 float64_t C1; 00310 float64_t C2; 00312 float64_t objective; 00314 int32_t qpsize; 00316 bool use_shrinking; 00317 00320 bool (*callback) (CMKL* mkl, const float64_t* sumw, const float64_t suma); 00323 CMKL* mkl; 00324 00325 CParameter parameters; 00326 }; 00327 } 00328 #endif