|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 2 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2009 Christian Widmer 00008 * Copyright (C) 2007-2009 Max-Planck-Society 00009 */ 00010 00011 #include "lib/config.h" 00012 00013 #ifdef USE_SVMLIGHT 00014 00015 #include "classifier/svm/DomainAdaptationSVM.h" 00016 #include "lib/io.h" 00017 #include <iostream> 00018 #include <vector> 00019 00020 #ifdef HAVE_BOOST_SERIALIZATION 00021 #include <boost/serialization/export.hpp> 00022 BOOST_CLASS_EXPORT(shogun::CDomainAdaptationSVM); 00023 #endif //HAVE_BOOST_SERIALIZATION 00024 00025 using namespace shogun; 00026 00027 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight() 00028 { 00029 } 00030 00031 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab) 00032 { 00033 init(pre_svm, B_param); 00034 } 00035 00036 CDomainAdaptationSVM::~CDomainAdaptationSVM() 00037 { 00038 SG_UNREF(presvm); 00039 SG_DEBUG("deleting DomainAdaptationSVM\n"); 00040 } 00041 00042 00043 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param) 00044 { 00045 // increase reference counts 00046 SG_REF(pre_svm); 00047 00048 this->presvm=pre_svm; 00049 this->B=B_param; 00050 this->train_factor=1.0; 00051 00052 // set bias of parent svm to zero 00053 this->presvm->set_bias(0.0); 00054 00055 // invoke sanity check 00056 is_presvm_sane(); 00057 } 00058 00059 bool CDomainAdaptationSVM::is_presvm_sane() 00060 { 00061 00062 if (!presvm) { 00063 SG_ERROR("presvm is null"); 00064 } 00065 00066 if (presvm->get_num_support_vectors() == 0) { 00067 SG_ERROR("presvm has no support vectors, please train first"); 00068 } 00069 00070 if (presvm->get_bias() != 0) { 00071 SG_ERROR("presvm bias not set to zero"); 00072 } 00073 00074 if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) { 00075 SG_ERROR("kernel types do not agree"); 00076 } 00077 00078 if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) { 00079 SG_ERROR("feature types do not agree"); 00080 } 00081 00082 return true; 00083 00084 } 00085 00086 00087 bool CDomainAdaptationSVM::train(CFeatures* data) 00088 { 00089 00090 if (data) 00091 { 00092 if (labels->get_num_labels() != data->get_num_vectors()) 00093 SG_ERROR("Number of training vectors does not match number of labels\n"); 00094 kernel->init(data, data); 00095 } 00096 00097 int32_t num_training_points = get_labels()->get_num_labels(); 00098 00099 00100 std::vector<float64_t> lin_term = std::vector<float64_t>(num_training_points); 00101 00102 // grab current training features 00103 CFeatures* train_data = get_kernel()->get_lhs(); 00104 00105 // bias of parent SVM was set to zero in constructor, already contains B 00106 CLabels* parent_svm_out = presvm->classify(train_data); 00107 00108 // pre-compute linear term 00109 for (int32_t i=0; i!=num_training_points; i++) 00110 { 00111 lin_term[i] = (- B*(get_label(i) * parent_svm_out->get_label(i)))*train_factor - 1.0; 00112 } 00113 00114 //set linear term for QP 00115 this->set_linear_term(lin_term); 00116 00117 //train SVM 00118 bool success = CSVMLight::train(); 00119 00120 ASSERT(presvm) 00121 00122 return success; 00123 00124 } 00125 00126 00127 CSVM* CDomainAdaptationSVM::get_presvm() 00128 { 00129 return presvm; 00130 } 00131 00132 00133 float64_t CDomainAdaptationSVM::get_B() 00134 { 00135 return B; 00136 } 00137 00138 00139 float64_t CDomainAdaptationSVM::get_train_factor() 00140 { 00141 return train_factor; 00142 } 00143 00144 00145 void CDomainAdaptationSVM::set_train_factor(float64_t factor) 00146 { 00147 train_factor = factor; 00148 } 00149 00150 00151 CLabels* CDomainAdaptationSVM::classify(CFeatures* data) 00152 { 00153 00154 ASSERT(presvm->get_bias()==0.0); 00155 00156 int32_t num_examples = data->get_num_vectors(); 00157 00158 CLabels* out_current = CSVMLight::classify(data); 00159 00160 // recursive call if used on DomainAdaptationSVM object 00161 CLabels* out_presvm = presvm->classify(data); 00162 00163 00164 // combine outputs 00165 for (int32_t i=0; i!=num_examples; i++) 00166 { 00167 float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i); 00168 out_current->set_label(i, out_combined); 00169 } 00170 00171 return out_current; 00172 00173 } 00174 00175 #endif //USE_SVMLIGHT