SHOGUN v0.9.3
DomainAdaptationSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 2 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2009 Christian Widmer
00008  * Copyright (C) 2007-2009 Max-Planck-Society
00009  */
00010 
00011 #include "lib/config.h"
00012 
00013 #ifdef USE_SVMLIGHT
00014 
00015 #include "classifier/svm/DomainAdaptationSVM.h"
00016 #include "lib/io.h"
00017 #include <iostream>
00018 #include <vector>
00019 
00020 #ifdef HAVE_BOOST_SERIALIZATION
00021 #include <boost/serialization/export.hpp>
00022 BOOST_CLASS_EXPORT(shogun::CDomainAdaptationSVM);
00023 #endif //HAVE_BOOST_SERIALIZATION
00024 
00025 using namespace shogun;
00026 
00027 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
00028 {
00029 }
00030 
00031 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
00032 {
00033   init(pre_svm, B_param);
00034 }
00035 
00036 CDomainAdaptationSVM::~CDomainAdaptationSVM()
00037 {
00038     SG_UNREF(presvm);
00039     SG_DEBUG("deleting DomainAdaptationSVM\n");
00040 }
00041 
00042 
00043 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
00044 {
00045     // increase reference counts
00046     SG_REF(pre_svm);
00047 
00048     this->presvm=pre_svm;
00049     this->B=B_param;
00050     this->train_factor=1.0;
00051 
00052     // set bias of parent svm to zero
00053     this->presvm->set_bias(0.0);
00054 
00055     // invoke sanity check
00056     is_presvm_sane();
00057 }
00058 
00059 bool CDomainAdaptationSVM::is_presvm_sane()
00060 {
00061 
00062     if (!presvm) {
00063         SG_ERROR("presvm is null");
00064     }
00065 
00066     if (presvm->get_num_support_vectors() == 0) {
00067         SG_ERROR("presvm has no support vectors, please train first");
00068     }
00069 
00070     if (presvm->get_bias() != 0) {
00071         SG_ERROR("presvm bias not set to zero");
00072     }
00073 
00074     if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
00075         SG_ERROR("kernel types do not agree");
00076     }
00077 
00078     if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
00079         SG_ERROR("feature types do not agree");
00080     }
00081 
00082     return true;
00083 
00084 }
00085 
00086 
00087 bool CDomainAdaptationSVM::train(CFeatures* data)
00088 {
00089 
00090     if (data)
00091     {
00092         if (labels->get_num_labels() != data->get_num_vectors())
00093             SG_ERROR("Number of training vectors does not match number of labels\n");
00094         kernel->init(data, data);
00095     }
00096 
00097     int32_t num_training_points = get_labels()->get_num_labels();
00098 
00099 
00100     std::vector<float64_t> lin_term = std::vector<float64_t>(num_training_points);
00101 
00102     // grab current training features
00103     CFeatures* train_data = get_kernel()->get_lhs();
00104 
00105     // bias of parent SVM was set to zero in constructor, already contains B
00106     CLabels* parent_svm_out = presvm->classify(train_data);
00107 
00108     // pre-compute linear term
00109     for (int32_t i=0; i!=num_training_points; i++)
00110     {
00111         lin_term[i] = (- B*(get_label(i) * parent_svm_out->get_label(i)))*train_factor - 1.0;
00112     }
00113 
00114     //set linear term for QP
00115     this->set_linear_term(lin_term);
00116 
00117     //train SVM
00118     bool success = CSVMLight::train();
00119 
00120     ASSERT(presvm)
00121 
00122     return success;
00123 
00124 }
00125 
00126 
00127 CSVM* CDomainAdaptationSVM::get_presvm()
00128 {
00129     return presvm;
00130 }
00131 
00132 
00133 float64_t CDomainAdaptationSVM::get_B()
00134 {
00135     return B;
00136 }
00137 
00138 
00139 float64_t CDomainAdaptationSVM::get_train_factor()
00140 {
00141     return train_factor;
00142 }
00143 
00144 
00145 void CDomainAdaptationSVM::set_train_factor(float64_t factor)
00146 {
00147     train_factor = factor;
00148 }
00149 
00150 
00151 CLabels* CDomainAdaptationSVM::classify(CFeatures* data)
00152 {
00153 
00154     ASSERT(presvm->get_bias()==0.0);
00155 
00156     int32_t num_examples = data->get_num_vectors();
00157 
00158     CLabels* out_current = CSVMLight::classify(data);
00159 
00160     // recursive call if used on DomainAdaptationSVM object
00161     CLabels* out_presvm = presvm->classify(data);
00162 
00163 
00164     // combine outputs
00165     for (int32_t i=0; i!=num_examples; i++)
00166     {
00167         float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
00168         out_current->set_label(i, out_combined);
00169     }
00170 
00171     return out_current;
00172 
00173 }
00174 
00175 #endif //USE_SVMLIGHT
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation