|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2008 Vojtech Franc 00008 * Written (W) 2007-2009 Soeren Sonnenburg 00009 * Copyright (C) 2007-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WDSVMOCAS_H___ 00013 #define _WDSVMOCAS_H___ 00014 00015 #include "lib/common.h" 00016 #include "classifier/Classifier.h" 00017 #include "classifier/svm/SVMOcas.h" 00018 #include "features/StringFeatures.h" 00019 #include "features/Labels.h" 00020 00021 namespace shogun 00022 { 00023 template <class ST> class CStringFeatures; 00024 00026 class CWDSVMOcas : public CClassifier 00027 { 00028 public: 00033 CWDSVMOcas(E_SVM_TYPE type); 00034 00043 CWDSVMOcas( 00044 float64_t C, int32_t d, int32_t from_d, 00045 CStringFeatures<uint8_t>* traindat, CLabels* trainlab); 00046 virtual ~CWDSVMOcas(); 00047 00052 virtual inline EClassifierType get_classifier_type() { return CT_WDSVMOCAS; } 00053 00062 virtual bool train(CFeatures* data=NULL); 00063 00070 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00071 00076 inline float64_t get_C1() { return C1; } 00077 00082 inline float64_t get_C2() { return C2; } 00083 00088 inline void set_epsilon(float64_t eps) { epsilon=eps; } 00089 00094 inline float64_t get_epsilon() { return epsilon; } 00095 00100 inline void set_features(CStringFeatures<uint8_t>* feat) 00101 { 00102 SG_UNREF(features); 00103 SG_REF(feat); 00104 features=feat; 00105 } 00106 00111 inline CStringFeatures<uint8_t>* get_features() 00112 { 00113 SG_REF(features); 00114 return features; 00115 } 00116 00121 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00122 00127 inline bool get_bias_enabled() { return use_bias; } 00128 00133 inline void set_bufsize(int32_t sz) { bufsize=sz; } 00134 00139 inline int32_t get_bufsize() { return bufsize; } 00140 00146 inline void set_degree(int32_t d, int32_t from_d) 00147 { 00148 degree=d; 00149 from_degree=from_d; 00150 } 00151 00156 inline int32_t get_degree() { return degree; } 00157 00162 CLabels* classify(); 00163 00169 virtual CLabels* classify(CFeatures* data); 00170 00176 inline virtual float64_t classify_example(int32_t num) 00177 { 00178 ASSERT(features); 00179 if (!wd_weights) 00180 set_wd_weights(); 00181 00182 int32_t len=0; 00183 float64_t sum=0; 00184 bool free_vec; 00185 uint8_t* vec=features->get_feature_vector(num, len, free_vec); 00186 //SG_INFO("len %d, string_length %d\n", len, string_length); 00187 ASSERT(len==string_length); 00188 00189 for (int32_t j=0; j<string_length; j++) 00190 { 00191 int32_t offs=w_dim_single_char*j; 00192 int32_t val=0; 00193 for (int32_t k=0; (j+k<string_length) && (k<degree); k++) 00194 { 00195 val=val*alphabet_size + vec[j+k]; 00196 sum+=wd_weights[k] * w[offs+val]; 00197 offs+=w_offsets[k]; 00198 } 00199 } 00200 features->free_feature_vector(vec, len, free_vec); 00201 return sum/normalization_const; 00202 } 00203 00205 inline void set_normalization_const() 00206 { 00207 ASSERT(features); 00208 normalization_const=0; 00209 for (int32_t i=0; i<degree; i++) 00210 normalization_const+=(string_length-i)*wd_weights[i]*wd_weights[i]; 00211 00212 normalization_const=CMath::sqrt(normalization_const); 00213 SG_DEBUG("normalization_const:%f\n", normalization_const); 00214 } 00215 00220 inline float64_t get_normalization_const() { return normalization_const; } 00221 00222 00223 protected: 00228 int32_t set_wd_weights(); 00229 00238 static void compute_W( 00239 float64_t *sq_norm_W, float64_t *dp_WoldW, float64_t *alpha, 00240 uint32_t nSel, void* ptr ); 00241 00248 static float64_t update_W(float64_t t, void* ptr ); 00249 00255 static void* add_new_cut_helper(void* ptr); 00256 00265 static void add_new_cut( 00266 float64_t *new_col_H, uint32_t *new_cut, uint32_t cut_length, 00267 uint32_t nSel, void* ptr ); 00268 00274 static void* compute_output_helper(void* ptr); 00275 00281 static void compute_output( float64_t *output, void* ptr ); 00282 00289 static void sort( float64_t* vals, uint32_t* idx, uint32_t size); 00290 00292 inline virtual const char* get_name() const { return "WDSVMOcas"; } 00293 00294 protected: 00296 CStringFeatures<uint8_t>* features; 00298 bool use_bias; 00300 int32_t bufsize; 00302 float64_t C1; 00304 float64_t C2; 00306 float64_t epsilon; 00308 E_SVM_TYPE method; 00309 00311 int32_t degree; 00313 int32_t from_degree; 00315 float32_t* wd_weights; 00317 int32_t num_vec; 00319 int32_t string_length; 00321 int32_t alphabet_size; 00322 00324 float64_t normalization_const; 00325 00327 float64_t bias; 00329 float64_t old_bias; 00331 int32_t* w_offsets; 00333 int32_t w_dim; 00335 int32_t w_dim_single_char; 00337 float32_t* w; 00339 float32_t* old_w; 00341 float64_t* lab; 00342 00344 float32_t** cuts; 00346 float64_t* cp_bias; 00347 }; 00348 } 00349 #endif