|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _LABELS__H__ 00013 #define _LABELS__H__ 00014 00015 #include "lib/common.h" 00016 #include "lib/io.h" 00017 #include "lib/File.h" 00018 #include "base/SGObject.h" 00019 00020 namespace shogun 00021 { 00022 00023 class CFile; 00024 00030 class CLabels : public CSGObject 00031 { 00032 public: 00034 CLabels(); 00035 00040 CLabels(int32_t num_labels); 00041 00047 CLabels(float64_t* src, int32_t len); 00048 00054 CLabels(float64_t* in_confidences, int32_t in_num_labels, int32_t in_num_classes); 00055 00060 CLabels(CFile* loader); 00061 virtual ~CLabels(); 00062 00067 virtual void load(CFile* loader); 00068 00073 virtual void save(CFile* writer); 00074 00081 inline bool set_label(int32_t idx, float64_t label) 00082 { 00083 if (labels && idx<num_labels) 00084 { 00085 labels[idx]=label; 00086 return true; 00087 } 00088 else 00089 return false; 00090 } 00091 00098 inline bool set_int_label(int32_t idx, int32_t label) 00099 { 00100 if (labels && idx<num_labels) 00101 { 00102 labels[idx]= (float64_t) label; 00103 return true; 00104 } 00105 else 00106 return false; 00107 } 00108 00114 inline float64_t get_label(int32_t idx) 00115 { 00116 if (labels && idx<num_labels) 00117 return labels[idx]; 00118 else 00119 return -1; 00120 } 00121 00127 inline int32_t get_int_label(int32_t idx) 00128 { 00129 if (labels && idx<num_labels) 00130 { 00131 ASSERT(labels[idx]== ((float64_t) ((int32_t) labels[idx]))); 00132 return ((int32_t) labels[idx]); 00133 } 00134 else 00135 return -1; 00136 } 00137 00142 bool is_two_class_labeling(); 00143 00150 int32_t get_num_classes(); 00151 00158 float64_t* get_labels(int32_t &len); 00159 00165 void get_labels(float64_t** dst, int32_t* len); 00166 00172 void set_labels(float64_t* src, int32_t len); 00173 00179 void set_confidences(float64_t* in_confidences, int32_t in_num_labels, int32_t in_num_classes); 00180 00186 float64_t* get_confidences(int32_t& out_num_labels, int32_t& out_num_classes); 00187 00193 void get_confidences(float64_t** dst, int32_t* out_num_labels, int32_t* out_num_classes); 00194 00200 float64_t* get_sample_confidences(const int32_t& in_sample_index, int32_t& out_num_classes); 00201 00208 int32_t* get_int_labels(int32_t &len); 00209 00216 void set_int_labels(int32_t *labels, int32_t len) ; 00217 00222 inline int32_t get_num_labels() { return num_labels; } 00223 00225 inline virtual const char* get_name() const { return "Labels"; } 00226 00227 protected: 00229 void find_labels(); 00230 00231 #ifdef HAVE_BOOST_SERIALIZATION 00232 private: 00233 00234 // serialization needs to split up in save/load because 00235 // the serialization of pointers to natives (int* & friends) 00236 // requires a workaround 00237 friend class ::boost::serialization::access; 00238 template<class Archive> 00239 void save(Archive & ar, const unsigned int archive_version) const 00240 { 00241 00242 SG_DEBUG("archiving Labels\n"); 00243 00244 ar & ::boost::serialization::base_object<CSGObject>(*this); 00245 00246 ar & num_labels; 00247 for (int32_t i=0; i < num_labels; ++i) 00248 { 00249 ar & labels[i]; 00250 } 00251 00252 SG_DEBUG("done with Labels\n"); 00253 00254 } 00255 00256 template<class Archive> 00257 void load(Archive & ar, const unsigned int archive_version) 00258 { 00259 00260 SG_DEBUG("archiving Labels\n"); 00261 00262 ar & ::boost::serialization::base_object<CSGObject>(*this); 00263 00264 ar & num_labels; 00265 00266 SG_DEBUG("num_labels: %i\n", num_labels); 00267 00268 if (num_labels > 0) 00269 { 00270 00271 labels = new float64_t[num_labels]; 00272 for (int32_t i=0; i< num_labels; ++i) 00273 { 00274 ar & labels[i]; 00275 } 00276 00277 } 00278 00279 SG_DEBUG("done with Labels\n"); 00280 00281 } 00282 00283 GLOBAL_BOOST_SERIALIZATION_SPLIT_MEMBER(); 00284 00285 00286 public: 00287 00288 virtual std::string toString() const 00289 { 00290 std::ostringstream s; 00291 00292 ::boost::archive::text_oarchive oa(s); 00293 00294 oa << *this; 00295 00296 return s.str(); 00297 } 00298 00299 00300 virtual void fromString(std::string str) 00301 { 00302 00303 std::istringstream is(str); 00304 00305 ::boost::archive::text_iarchive ia(is); 00306 00307 ia >> *this; 00308 00309 } 00310 #endif //HAVE_BOOST_SERIALIZATION 00311 00312 protected: 00314 int32_t num_labels; 00316 float64_t* labels; 00317 00319 int16_t m_num_classes; 00320 00322 float64_t* m_confidences; 00323 00324 }; 00325 } 00326 #endif