|
SHOGUN v0.9.3
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Written (W) 2009 Soeren Sonnnenburg 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include "lib/common.h" 00013 #include "lib/Mathematics.h" 00014 #include "kernel/AUCKernel.h" 00015 #include "features/SimpleFeatures.h" 00016 #include "lib/io.h" 00017 00018 using namespace shogun; 00019 00020 CAUCKernel::CAUCKernel(int32_t size, CKernel* s) 00021 : CSimpleKernel<uint16_t>(size), subkernel(s) 00022 { 00023 SG_REF(subkernel); 00024 } 00025 00026 CAUCKernel::~CAUCKernel() 00027 { 00028 SG_UNREF(subkernel); 00029 cleanup(); 00030 } 00031 00032 CLabels* CAUCKernel::setup_auc_maximization(CLabels* labels) 00033 { 00034 SG_INFO( "setting up AUC maximization\n") ; 00035 ASSERT(labels); 00036 ASSERT(labels->is_two_class_labeling()); 00037 00038 // get the original labels 00039 int32_t num=0; 00040 ASSERT(labels); 00041 int32_t* int_labels=labels->get_int_labels(num); 00042 ASSERT(subkernel->get_num_vec_rhs()==num); 00043 00044 // count positive and negative 00045 int32_t num_pos=0; 00046 int32_t num_neg=0; 00047 00048 for (int32_t i=0; i<num; i++) 00049 { 00050 if (int_labels[i]==1) 00051 num_pos++; 00052 else 00053 num_neg++; 00054 } 00055 00056 // create AUC features and labels (alternate labels) 00057 int32_t num_auc = num_pos*num_neg; 00058 SG_INFO("num_pos: %i num_neg: %i num_auc: %i\n", num_pos, num_neg, num_auc); 00059 00060 uint16_t* features_auc = new uint16_t[num_auc*2]; 00061 int32_t* labels_auc = new int32_t[num_auc]; 00062 int32_t n=0 ; 00063 00064 for (int32_t i=0; i<num; i++) 00065 { 00066 if (int_labels[i]!=1) 00067 continue; 00068 00069 for (int32_t j=0; j<num; j++) 00070 { 00071 if (int_labels[j]!=-1) 00072 continue; 00073 00074 // create about as many positively as negatively labeled examples 00075 if (n%2==0) 00076 { 00077 features_auc[n*2]=i; 00078 features_auc[n*2+1]=j; 00079 labels_auc[n]=1; 00080 } 00081 else 00082 { 00083 features_auc[n*2]=j; 00084 features_auc[n*2+1]=i; 00085 labels_auc[n]=-1; 00086 } 00087 00088 n++; 00089 ASSERT(n<=num_auc); 00090 } 00091 } 00092 00093 // create label object and attach it to svm 00094 CLabels* lab_auc = new CLabels(num_auc); 00095 lab_auc->set_int_labels(labels_auc, num_auc); 00096 SG_REF(lab_auc); 00097 00098 // create feature object 00099 CSimpleFeatures<uint16_t>* f = new CSimpleFeatures<uint16_t>(0); 00100 f->set_feature_matrix(features_auc, 2, num_auc); 00101 00102 // create AUC kernel and attach the features 00103 init(f,f); 00104 00105 delete[] int_labels; 00106 delete[] labels_auc; 00107 00108 return lab_auc; 00109 } 00110 00111 00112 bool CAUCKernel::init(CFeatures* l, CFeatures* r) 00113 { 00114 CSimpleKernel<uint16_t>::init(l, r); 00115 init_normalizer(); 00116 return true; 00117 } 00118 00119 float64_t CAUCKernel::compute(int32_t idx_a, int32_t idx_b) 00120 { 00121 int32_t alen, blen; 00122 bool afree, bfree; 00123 00124 uint16_t* avec=((CSimpleFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, afree); 00125 uint16_t* bvec=((CSimpleFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, bfree); 00126 00127 ASSERT(alen==2); 00128 ASSERT(blen==2); 00129 00130 ASSERT(subkernel && subkernel->has_features()); 00131 00132 float64_t k11,k12,k21,k22; 00133 int32_t idx_a1=avec[0], idx_a2=avec[1], idx_b1=bvec[0], idx_b2=bvec[1]; 00134 00135 k11 = subkernel->kernel(idx_a1,idx_b1); 00136 k12 = subkernel->kernel(idx_a1,idx_b2); 00137 k21 = subkernel->kernel(idx_a2,idx_b1); 00138 k22 = subkernel->kernel(idx_a2,idx_b2); 00139 00140 float64_t result = k11+k22-k21-k12; 00141 00142 ((CSimpleFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, afree); 00143 ((CSimpleFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, bfree); 00144 00145 return result; 00146 }