|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _SQRTDIAGKERNELNORMALIZER_H___ 00012 #define _SQRTDIAGKERNELNORMALIZER_H___ 00013 00014 #include "kernel/KernelNormalizer.h" 00015 #include "kernel/CommWordStringKernel.h" 00016 00017 namespace shogun 00018 { 00029 class CSqrtDiagKernelNormalizer : public CKernelNormalizer 00030 { 00031 public: 00036 CSqrtDiagKernelNormalizer(bool use_opt_diag=false): sqrtdiag_lhs(NULL), 00037 sqrtdiag_rhs(NULL), use_optimized_diagonal_computation(use_opt_diag) 00038 { 00039 } 00040 00042 virtual ~CSqrtDiagKernelNormalizer() 00043 { 00044 delete[] sqrtdiag_lhs; 00045 delete[] sqrtdiag_rhs; 00046 } 00047 00050 virtual bool init(CKernel* k) 00051 { 00052 ASSERT(k); 00053 int32_t num_lhs=k->get_num_vec_lhs(); 00054 int32_t num_rhs=k->get_num_vec_rhs(); 00055 ASSERT(num_lhs>0); 00056 ASSERT(num_rhs>0); 00057 00058 CFeatures* old_lhs=k->lhs; 00059 CFeatures* old_rhs=k->rhs; 00060 00061 k->lhs=old_lhs; 00062 k->rhs=old_lhs; 00063 bool r1=alloc_and_compute_diag(k, sqrtdiag_lhs, num_lhs); 00064 00065 k->lhs=old_rhs; 00066 k->rhs=old_rhs; 00067 bool r2=alloc_and_compute_diag(k, sqrtdiag_rhs, num_rhs); 00068 00069 k->lhs=old_lhs; 00070 k->rhs=old_rhs; 00071 00072 return r1 && r2; 00073 } 00074 00080 inline virtual float64_t normalize( 00081 float64_t value, int32_t idx_lhs, int32_t idx_rhs) 00082 { 00083 float64_t sqrt_both=sqrtdiag_lhs[idx_lhs]*sqrtdiag_rhs[idx_rhs]; 00084 return value/sqrt_both; 00085 } 00086 00091 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00092 { 00093 return value/sqrtdiag_lhs[idx_lhs]; 00094 } 00095 00100 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00101 { 00102 return value/sqrtdiag_rhs[idx_rhs]; 00103 } 00104 00105 public: 00110 bool alloc_and_compute_diag(CKernel* k, float64_t* &v, int32_t num) 00111 { 00112 delete[] v; 00113 v=new float64_t[num]; 00114 00115 for (int32_t i=0; i<num; i++) 00116 { 00117 if (k->get_kernel_type() == K_COMMWORDSTRING) 00118 { 00119 if (use_optimized_diagonal_computation) 00120 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_diag(i)); 00121 else 00122 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_helper(i,i, true)); 00123 } 00124 else 00125 v[i]=sqrt(k->compute(i,i)); 00126 00127 if (v[i]==0.0) 00128 v[i]=1e-16; /* avoid divide by zero exception */ 00129 } 00130 00131 return (v!=NULL); 00132 } 00133 00135 inline virtual const char* get_name() const { return "SqrtDiagKernelNormalizer"; } 00136 00137 protected: 00139 float64_t* sqrtdiag_lhs; 00141 float64_t* sqrtdiag_rhs; 00143 bool use_optimized_diagonal_computation; 00144 }; 00145 } 00146 #endif