|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include "lib/common.h" 00013 #include "kernel/HistogramWordStringKernel.h" 00014 #include "features/Features.h" 00015 #include "features/StringFeatures.h" 00016 #include "classifier/PluginEstimate.h" 00017 #include "lib/io.h" 00018 00019 using namespace shogun; 00020 00021 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie) 00022 : CStringKernel<uint16_t>(size), estimate(pie), mean(NULL), variance(NULL), 00023 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL), 00024 ld_mean_lhs(NULL), ld_mean_rhs(NULL), 00025 plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0), 00026 num_symbols(0), sum_m2_s2(0), initialized(false) 00027 { 00028 } 00029 00030 CHistogramWordStringKernel::CHistogramWordStringKernel( 00031 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie) 00032 : CStringKernel<uint16_t>(10), estimate(pie), mean(NULL), variance(NULL), 00033 sqrtdiag_lhs(NULL), sqrtdiag_rhs(NULL), 00034 ld_mean_lhs(NULL), ld_mean_rhs(NULL), 00035 plo_lhs(NULL), plo_rhs(NULL), num_params(0), num_params2(0), 00036 num_symbols(0), sum_m2_s2(0), initialized(false) 00037 { 00038 init(l, r); 00039 } 00040 00041 CHistogramWordStringKernel::~CHistogramWordStringKernel() 00042 { 00043 delete[] variance; 00044 delete[] mean; 00045 if (sqrtdiag_lhs != sqrtdiag_rhs) 00046 delete[] sqrtdiag_rhs; 00047 delete[] sqrtdiag_lhs; 00048 if (ld_mean_lhs!=ld_mean_rhs) 00049 delete[] ld_mean_rhs ; 00050 delete[] ld_mean_lhs ; 00051 if (plo_lhs!=plo_rhs) 00052 delete[] plo_rhs ; 00053 delete[] plo_lhs ; 00054 } 00055 00056 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r) 00057 { 00058 CStringKernel<uint16_t>::init(p_l,p_r); 00059 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l; 00060 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r; 00061 ASSERT(l); 00062 ASSERT(r); 00063 00064 SG_DEBUG( "init: lhs: %ld rhs: %ld\n", l, r) ; 00065 int32_t i; 00066 initialized=false; 00067 00068 if (sqrtdiag_lhs != sqrtdiag_rhs) 00069 delete[] sqrtdiag_rhs; 00070 sqrtdiag_rhs=NULL ; 00071 delete[] sqrtdiag_lhs; 00072 sqrtdiag_lhs=NULL ; 00073 if (ld_mean_lhs!=ld_mean_rhs) 00074 delete[] ld_mean_rhs ; 00075 ld_mean_rhs=NULL ; 00076 delete[] ld_mean_lhs ; 00077 ld_mean_lhs=NULL ; 00078 if (plo_lhs!=plo_rhs) 00079 delete[] plo_rhs ; 00080 plo_rhs=NULL ; 00081 delete[] plo_lhs ; 00082 plo_lhs=NULL ; 00083 00084 sqrtdiag_lhs= new float64_t[l->get_num_vectors()]; 00085 ld_mean_lhs = new float64_t[l->get_num_vectors()]; 00086 plo_lhs = new float64_t[l->get_num_vectors()]; 00087 00088 for (i=0; i<l->get_num_vectors(); i++) 00089 sqrtdiag_lhs[i]=1; 00090 00091 if (l==r) 00092 { 00093 sqrtdiag_rhs=sqrtdiag_lhs; 00094 ld_mean_rhs=ld_mean_lhs; 00095 plo_rhs=plo_lhs; 00096 } 00097 else 00098 { 00099 sqrtdiag_rhs=new float64_t[r->get_num_vectors()]; 00100 for (i=0; i<r->get_num_vectors(); i++) 00101 sqrtdiag_rhs[i]=1; 00102 00103 ld_mean_rhs=new float64_t[r->get_num_vectors()]; 00104 plo_rhs=new float64_t[r->get_num_vectors()]; 00105 } 00106 00107 float64_t* l_plo_lhs=plo_lhs; 00108 float64_t* l_plo_rhs=plo_rhs; 00109 float64_t* l_ld_mean_lhs=ld_mean_lhs; 00110 float64_t* l_ld_mean_rhs=ld_mean_rhs; 00111 00112 //from our knowledge first normalize variance to 1 and then norm=1 does the job 00113 if (!initialized) 00114 { 00115 int32_t num_vectors=l->get_num_vectors(); 00116 num_symbols=(int32_t) l->get_num_symbols(); 00117 int32_t llen=l->get_vector_length(0); 00118 int32_t rlen=r->get_vector_length(0); 00119 num_params=llen*((int32_t) l->get_num_symbols()); 00120 num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols()); 00121 00122 if ((!estimate) || (!estimate->check_models())) 00123 { 00124 SG_ERROR( "no estimate available\n"); 00125 return false ; 00126 } ; 00127 if (num_params2!=estimate->get_num_params()) 00128 { 00129 SG_ERROR( "number of parameters of estimate and feature representation do not match\n"); 00130 return false ; 00131 } ; 00132 00133 //add 1 as we have the 'bias' also in this vector 00134 num_params2++; 00135 00136 delete[] mean; 00137 mean=new float64_t[num_params2]; 00138 delete[] variance; 00139 variance=new float64_t[num_params2]; 00140 00141 for (i=0; i<num_params2; i++) 00142 { 00143 mean[i]=0; 00144 variance[i]=0; 00145 } 00146 00147 // compute mean 00148 for (i=0; i<num_vectors; i++) 00149 { 00150 int32_t len; 00151 bool free_vec; 00152 uint16_t* vec=l->get_feature_vector(i, len, free_vec); 00153 00154 mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors; 00155 00156 for (int32_t j=0; j<len; j++) 00157 { 00158 int32_t idx=compute_index(j, vec[j]); 00159 mean[idx] += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors; 00160 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors; 00161 } 00162 00163 l->free_feature_vector(vec, i, free_vec); 00164 } 00165 00166 // compute variance 00167 for (i=0; i<num_vectors; i++) 00168 { 00169 int32_t len; 00170 bool free_vec; 00171 uint16_t* vec=l->get_feature_vector(i, len, free_vec); 00172 00173 variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors; 00174 00175 for (int32_t j=0; j<len; j++) 00176 { 00177 for (int32_t k=0; k<4; k++) 00178 { 00179 int32_t idx=compute_index(j, k); 00180 if (k!=vec[j]) 00181 { 00182 variance[idx]+=mean[idx]*mean[idx]/num_vectors; 00183 variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors; 00184 } 00185 else 00186 { 00187 variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j) 00188 -mean[idx])/num_vectors; 00189 variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j) 00190 -mean[idx+num_params])/num_vectors; 00191 } 00192 } 00193 } 00194 00195 l->free_feature_vector(vec, i, free_vec); 00196 } 00197 00198 00199 // compute sum_i m_i^2/s_i^2 00200 sum_m2_s2=0 ; 00201 for (i=1; i<num_params2; i++) 00202 { 00203 if (variance[i]<1e-14) // then it is likely to be numerical inaccuracy 00204 variance[i]=1 ; 00205 00206 //fprintf(stderr, "%i: mean=%1.2e std=%1.2e\n", i, mean[i], std[i]) ; 00207 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ; 00208 } ; 00209 } 00210 00211 // compute sum of 00212 //result -= estimate->log_derivative_pos(avec[i], i)*mean[a_idx]/variance[a_idx] ; 00213 //result -= estimate->log_derivative_neg(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00214 for (i=0; i<l->get_num_vectors(); i++) 00215 { 00216 int32_t alen; 00217 bool free_avec; 00218 uint16_t* avec = l->get_feature_vector(i, alen, free_avec); 00219 00220 float64_t result=0 ; 00221 for (int32_t j=0; j<alen; j++) 00222 { 00223 int32_t a_idx = compute_index(j, avec[j]) ; 00224 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; 00225 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00226 } 00227 ld_mean_lhs[i]=result ; 00228 00229 // precompute posterior-log-odds 00230 plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; 00231 l->free_feature_vector(avec, alen, free_avec); 00232 } ; 00233 00234 if (ld_mean_lhs!=ld_mean_rhs) 00235 { 00236 // compute sum of 00237 //result -= estimate->log_derivative_pos(bvec[i], i)*mean[b_idx]/variance[b_idx] ; 00238 //result -= estimate->log_derivative_neg(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ; 00239 for (i=0; i < r->get_num_vectors(); i++) 00240 { 00241 int32_t alen; 00242 bool free_avec; 00243 uint16_t* avec=r->get_feature_vector(i, alen, free_avec); 00244 00245 float64_t result=0 ; 00246 for (int32_t j=0; j<alen; j++) 00247 { 00248 int32_t a_idx = compute_index(j, avec[j]) ; 00249 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; 00250 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00251 } 00252 ld_mean_rhs[i]=result ; 00253 00254 // precompute posterior-log-odds 00255 plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; 00256 r->free_feature_vector(avec, alen, free_avec); 00257 } ; 00258 } ; 00259 00260 //warning hacky 00261 // 00262 this->lhs=l; 00263 this->rhs=l; 00264 plo_lhs = l_plo_lhs ; 00265 plo_rhs = l_plo_lhs ; 00266 ld_mean_lhs = l_ld_mean_lhs ; 00267 ld_mean_rhs = l_ld_mean_lhs ; 00268 00269 //compute normalize to 1 values 00270 for (i=0; i<l->get_num_vectors(); i++) 00271 { 00272 sqrtdiag_lhs[i]=sqrt(compute(i,i)); 00273 00274 //trap divide by zero exception 00275 if (sqrtdiag_lhs[i]==0) 00276 sqrtdiag_lhs[i]=1e-16; 00277 } 00278 00279 // if lhs is different from rhs (train/test data) 00280 // compute also the normalization for rhs 00281 if (sqrtdiag_lhs!=sqrtdiag_rhs) 00282 { 00283 this->lhs=r; 00284 this->rhs=r; 00285 plo_lhs = l_plo_rhs ; 00286 plo_rhs = l_plo_rhs ; 00287 ld_mean_lhs = l_ld_mean_rhs ; 00288 ld_mean_rhs = l_ld_mean_rhs ; 00289 00290 //compute normalize to 1 values 00291 for (i=0; i<r->get_num_vectors(); i++) 00292 { 00293 sqrtdiag_rhs[i]=sqrt(compute(i,i)); 00294 00295 //trap divide by zero exception 00296 if (sqrtdiag_rhs[i]==0) 00297 sqrtdiag_rhs[i]=1e-16; 00298 } 00299 } 00300 00301 this->lhs=l; 00302 this->rhs=r; 00303 plo_lhs = l_plo_lhs ; 00304 plo_rhs = l_plo_rhs ; 00305 ld_mean_lhs = l_ld_mean_lhs ; 00306 ld_mean_rhs = l_ld_mean_rhs ; 00307 00308 initialized = true ; 00309 return init_normalizer(); 00310 } 00311 00312 void CHistogramWordStringKernel::cleanup() 00313 { 00314 delete[] variance; 00315 variance=NULL; 00316 00317 delete[] mean; 00318 mean=NULL; 00319 00320 if (sqrtdiag_lhs != sqrtdiag_rhs) 00321 delete[] sqrtdiag_rhs; 00322 sqrtdiag_rhs=NULL; 00323 00324 delete[] sqrtdiag_lhs; 00325 sqrtdiag_lhs=NULL; 00326 00327 if (ld_mean_lhs!=ld_mean_rhs) 00328 delete[] ld_mean_rhs ; 00329 ld_mean_rhs=NULL; 00330 00331 delete[] ld_mean_lhs ; 00332 ld_mean_lhs=NULL; 00333 00334 if (plo_lhs!=plo_rhs) 00335 delete[] plo_rhs ; 00336 plo_rhs=NULL; 00337 00338 delete[] plo_lhs ; 00339 plo_lhs=NULL; 00340 00341 num_params2=0; 00342 num_params=0; 00343 num_symbols=0; 00344 sum_m2_s2=0; 00345 initialized = false; 00346 00347 CKernel::cleanup(); 00348 } 00349 00350 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b) 00351 { 00352 int32_t alen, blen; 00353 bool free_avec, free_bvec; 00354 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec); 00355 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec); 00356 // can only deal with strings of same length 00357 ASSERT(alen==blen); 00358 00359 float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0]; 00360 result+= sum_m2_s2 ; // does not contain 0-th element 00361 00362 for (int32_t i=0; i<alen; i++) 00363 { 00364 if (avec[i]==bvec[i]) 00365 { 00366 int32_t a_idx = compute_index(i, avec[i]) ; 00367 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ; 00368 result += dd*dd/variance[a_idx] ; 00369 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ; 00370 result += dd*dd/variance[a_idx+num_params] ; 00371 } ; 00372 } 00373 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ; 00374 00375 if (initialized) 00376 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; 00377 00378 #ifdef BLABLA 00379 float64_t result2 = compute_slow(idx_a, idx_b) ; 00380 if (fabs(result - result2)>1e-10) 00381 SG_ERROR("new=%e old = %e diff = %e\n", result, result2, result - result2); 00382 #endif 00383 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec); 00384 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec); 00385 return result; 00386 } 00387 00388 #ifdef BLABLA 00389 00390 float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b) 00391 { 00392 int32_t alen, blen; 00393 bool free_avec, free_bvec; 00394 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec); 00395 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec); 00396 // can only deal with strings of same length 00397 ASSERT(alen==blen); 00398 00399 float64_t result=(estimate->posterior_log_odds_obsolete(avec, alen)-mean[0])* 00400 (estimate->posterior_log_odds_obsolete(bvec, blen)-mean[0])/(variance[0]); 00401 result+= sum_m2_s2 ; // does not contain 0-th element 00402 00403 for (int32_t i=0; i<alen; i++) 00404 { 00405 int32_t a_idx = compute_index(i, avec[i]) ; 00406 int32_t b_idx = compute_index(i, bvec[i]) ; 00407 00408 if (avec[i]==bvec[i]) 00409 { 00410 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ; 00411 result += dd*dd/variance[a_idx] ; 00412 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ; 00413 result += dd*dd/variance[a_idx+num_params] ; 00414 } ; 00415 00416 result -= estimate->log_derivative_pos_obsolete(avec[i], i)*mean[a_idx]/variance[a_idx] ; 00417 result -= estimate->log_derivative_pos_obsolete(bvec[i], i)*mean[b_idx]/variance[b_idx] ; 00418 result -= estimate->log_derivative_neg_obsolete(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00419 result -= estimate->log_derivative_neg_obsolete(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ; 00420 } 00421 00422 if (initialized) 00423 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; 00424 00425 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec); 00426 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec); 00427 return result; 00428 } 00429 00430 #endif