|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _LINEARHMM_H__ 00013 #define _LINEARHMM_H__ 00014 00015 #include "features/StringFeatures.h" 00016 #include "features/Labels.h" 00017 #include "distributions/Distribution.h" 00018 00019 namespace shogun 00020 { 00039 class CLinearHMM : public CDistribution 00040 { 00041 public: 00046 CLinearHMM(CStringFeatures<uint16_t>* f); 00047 00053 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols); 00054 virtual ~CLinearHMM(); 00055 00064 virtual bool train(CFeatures* data=NULL); 00065 00073 bool train( 00074 const int32_t* indizes, int32_t num_indizes, 00075 float64_t pseudo_count); 00076 00083 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len); 00084 00091 float64_t get_likelihood_example(uint16_t* vector, int32_t len); 00092 00098 virtual float64_t get_log_likelihood_example(int32_t num_example); 00099 00106 virtual float64_t get_log_derivative( 00107 int32_t num_param, int32_t num_example); 00108 00115 virtual inline float64_t get_log_derivative_obsolete( 00116 uint16_t obs, int32_t pos) 00117 { 00118 return 1.0/transition_probs[pos*num_symbols+obs]; 00119 } 00120 00127 virtual inline float64_t get_derivative_obsolete( 00128 uint16_t* vector, int32_t len, int32_t pos) 00129 { 00130 ASSERT(pos<len); 00131 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]]; 00132 } 00133 00138 virtual inline int32_t get_sequence_length() { return sequence_length; } 00139 00144 virtual inline int32_t get_num_symbols() { return num_symbols; } 00145 00150 virtual inline int32_t get_num_model_parameters() { return num_params; } 00151 00158 virtual inline float64_t get_positional_log_parameter( 00159 uint16_t obs, int32_t position) 00160 { 00161 return log_transition_probs[position*num_symbols+obs]; 00162 } 00163 00169 virtual inline float64_t get_log_model_parameter(int32_t num_param) 00170 { 00171 ASSERT(log_transition_probs); 00172 ASSERT(num_param<num_params); 00173 00174 return log_transition_probs[num_param]; 00175 } 00176 00184 virtual void get_log_transition_probs(float64_t** dst, int32_t* num); 00185 00192 virtual bool set_log_transition_probs( 00193 const float64_t* src, int32_t num); 00194 00200 virtual void get_transition_probs(float64_t** dst, int32_t* num); 00201 00208 virtual bool set_transition_probs(const float64_t* src, int32_t num); 00209 00211 inline virtual const char* get_name() const { return "LinearHMM"; } 00212 00213 protected: 00215 int32_t sequence_length; 00217 int32_t num_symbols; 00219 int32_t num_params; 00221 float64_t* transition_probs; 00223 float64_t* log_transition_probs; 00224 }; 00225 } 00226 #endif