/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.temporal.ae.feature.selection;

import com.google.common.base.Function;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.LinkedHashSet;
import java.util.Set;
import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.transform.TransformableFeature;

/**
 * Feature selection that ranks feature names by a chi-squared statistic computed
 * from feature/outcome co-occurrence counts and keeps the top-scoring fraction.
 *
 * <p>The threshold is a fraction in {@code [0, 1]}: after training, the
 * {@code round(totalFeatures * threshold)} highest-scoring feature names are
 * selected and the remainder are recorded as discarded.
 *
 * @param <OUTCOME_T> the outcome (class label) type of the training instances
 */
public class Chi2FeatureSelection<OUTCOME_T>
extends FeatureSelection<OUTCOME_T> {
    /** Fraction (0..1) of the ranked features to keep. */
    private double chi2Threshold;
    /** Number of selected features; set by {@link #train} or {@link #load}. */
    private int numFeatures = 0;
    /** Scorer holding the co-occurrence counts; only available after {@link #train}. */
    private Chi2Scorer<OUTCOME_T> chi2Function;
    /** Whether to subtract 0.5 from each absolute deviation (Yates' continuity correction). */
    private boolean yates = false;
    /** Feature names that fell below the cut, in descending score order; set by {@link #train}. */
    private LinkedHashSet<String> discardedFeatureNames;

    /**
     * Creates a selector with a threshold of 0.0 and no Yates correction.
     *
     * @param name the name passed to the superclass constructor
     */
    public Chi2FeatureSelection(String name) {
        this(name, 0.0);
    }

    /**
     * Creates a selector without Yates correction.
     *
     * @param name the name passed to the superclass constructor
     * @param threshold fraction of features to keep; validated in {@link #train}
     */
    public Chi2FeatureSelection(String name, double threshold) {
        super(name);
        this.chi2Threshold = threshold;
    }

    /**
     * Creates a selector.
     *
     * @param name the name passed to the superclass constructor
     * @param threshold fraction of features to keep; validated in {@link #train}
     * @param yates whether to apply Yates' continuity correction when scoring
     */
    public Chi2FeatureSelection(String name, double threshold, boolean yates) {
        super(name);
        this.chi2Threshold = threshold;
        this.yates = yates;
    }

    /**
     * Tests whether the given feature's name is among the selected feature names.
     *
     * @param feature the feature to test
     * @return {@code true} if the feature's name was selected
     */
    @Override
    public boolean apply(Feature feature) {
        return this.selectedFeatureNames.contains(this.getFeatureName(feature));
    }

    /**
     * Collects feature/outcome counts from the instances, ranks all observed feature
     * names by descending chi-squared score, and splits them into selected and
     * discarded sets according to the threshold.
     *
     * <p>Only features for which {@code isTransformable} returns true contribute;
     * each of their wrapped features is counted once per occurrence.
     *
     * @param instances the training instances
     * @throws IllegalArgumentException if the threshold is NaN or outside {@code [0, 1]}
     */
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        // Written as a negated range check so that NaN is rejected as well.
        if (!(this.chi2Threshold >= 0.0 && this.chi2Threshold <= 1.0)) {
            throw new IllegalArgumentException(
                "Feature Selection threshold should be from 0 to 1, but was " + this.chi2Threshold);
        }
        Chi2Scorer<OUTCOME_T> scorer = new Chi2Scorer<OUTCOME_T>(this.yates);
        for (Instance<OUTCOME_T> instance : instances) {
            OUTCOME_T outcome = instance.getOutcome();
            for (Feature feature : instance.getFeatures()) {
                if (!this.isTransformable(feature)) {
                    continue;
                }
                for (Feature untransformedFeature : ((TransformableFeature)feature).getFeatures()) {
                    scorer.update(this.getFeatureName(untransformedFeature), outcome, 1);
                }
            }
        }
        this.chi2Function = scorer;

        Set<String> featureNames = scorer.featValueClassCount.rowKeySet();
        Ordering<String> ordering = Ordering.natural().onResultOf(scorer).reverse();
        int totalFeatures = featureNames.size();
        this.numFeatures = (int)Math.round((double)totalFeatures * this.chi2Threshold);

        // Sort once, then deal the ranked names into the two insertion-ordered sets.
        LinkedHashSet<String> selected = new LinkedHashSet<String>();
        LinkedHashSet<String> discarded = new LinkedHashSet<String>();
        int rank = 0;
        for (String featureName : ordering.immutableSortedCopy(featureNames)) {
            if (rank < this.numFeatures) {
                selected.add(featureName);
            } else {
                discarded.add(featureName);
            }
            ++rank;
        }
        this.selectedFeatureNames = selected;
        this.discardedFeatureNames = discarded;
        this.isTrained = true;
    }

    /**
     * Writes the selected features to {@code uri} and the discarded features to a
     * sibling file whose name is the URI path with its extension (if any) replaced
     * by {@code _discarded.dat}. Each line is {@code <feature-name>\t<score>}.
     *
     * @param uri the file URI to write the selected features to
     * @throws IOException if either file cannot be written
     * @throws IllegalStateException if this selector has not been trained, or if the
     *         scores are unavailable because it was populated via {@link #load}
     */
    public void save(URI uri) throws IOException {
        if (!this.isTrained) {
            throw new IllegalStateException("Cannot save before training");
        }
        if (this.chi2Function == null || this.discardedFeatureNames == null) {
            // load() restores only the selected names, not the counts needed for scores.
            throw new IllegalStateException("Cannot save: feature scores are only available after train()");
        }
        File out = new File(uri);
        String uriPath = uri.getPath();
        int dotIndex = uriPath.lastIndexOf('.');
        // Only a dot in the final path segment is an extension separator; a dot in a
        // directory name must not truncate the path.
        boolean hasExtension = dotIndex > uriPath.lastIndexOf('/');
        String basePath = hasExtension ? uriPath.substring(0, dotIndex) : uriPath;
        File discardOut = new File(basePath + "_discarded.dat");

        try (BufferedWriter writer = new BufferedWriter(new FileWriter(out))) {
            this.writeScores(writer, this.selectedFeatureNames);
        }
        try (BufferedWriter diswriter = new BufferedWriter(new FileWriter(discardOut))) {
            this.writeScores(diswriter, this.discardedFeatureNames);
        }
    }

    /** Writes one {@code <feature-name>\t<score>} line per feature name. */
    private void writeScores(BufferedWriter writer, Iterable<String> featureNames) throws IOException {
        for (String feature : featureNames) {
            writer.append(String.format("%s\t%f\n", feature, this.chi2Function.score(feature)));
        }
    }

    /**
     * Reads selected feature names from a file in the format written by
     * {@link #save}: the text before the first tab on each line is the feature name.
     * Every non-empty line is read, since the file lists exactly the selected features.
     *
     * @param uri the file URI to read the selected features from
     * @throws IOException if the file cannot be read
     */
    public void load(URI uri) throws IOException {
        LinkedHashSet<String> loaded = new LinkedHashSet<String>();
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(uri)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isEmpty()) {
                    continue;
                }
                String[] featureValuePair = line.split("\t");
                loaded.add(featureValuePair[0]);
            }
        }
        this.selectedFeatureNames = loaded;
        this.numFeatures = loaded.size();
        this.isTrained = true;
    }

    /**
     * Accumulates feature/outcome co-occurrence counts and scores a feature name
     * with a chi-squared statistic over the feature-present / feature-absent by
     * outcome contingency table.
     *
     * <p>Note that the per-outcome totals are incremented on every {@link #update}
     * call, i.e. they count feature occurrences per outcome.
     */
    private static class Chi2Scorer<OUTCOME_T>
    implements Function<String, Double> {
        /** Total occurrences recorded per outcome across all features. */
        protected Multiset<OUTCOME_T> classCounts = HashMultiset.create();
        /** Occurrences recorded per (feature name, outcome) pair. */
        protected Table<String, OUTCOME_T, Integer> featValueClassCount = HashBasedTable.create();
        private final boolean yates;

        /**
         * @param yate whether to apply Yates' continuity correction when scoring
         */
        public Chi2Scorer(boolean yate) {
            this.yates = yate;
        }

        /**
         * Adds {@code occurrences} to both the (feature, outcome) cell and the
         * outcome's total.
         */
        public void update(String featureName, OUTCOME_T outcome, int occurrences) {
            Integer count = this.featValueClassCount.get(featureName, outcome);
            int previous = count == null ? 0 : count;
            this.featValueClassCount.put(featureName, outcome, previous + occurrences);
            this.classCounts.add(outcome, occurrences);
        }

        /** Function view of {@link #score(String)}, used for ordering feature names. */
        public Double apply(String featureName) {
            return this.score(featureName);
        }

        /**
         * Computes the chi-squared statistic for the given feature name.
         *
         * @param featureName the feature name to score
         * @return the statistic, or 0.0 when the feature accounts for none or all of
         *         the recorded occurrences (no contrast between present and absent)
         */
        public double score(String featureName) {
            Set<OUTCOME_T> classes = this.classCounts.elementSet();
            int numOfClass = classes.size();
            int[] posiOutcomeCounts = new int[numOfClass];
            int[] outcomeCounts = new int[numOfClass];
            int classId = 0;
            int posiFeatCount = 0;
            for (OUTCOME_T clas : classes) {
                Integer count = this.featValueClassCount.get(featureName, clas);
                posiOutcomeCounts[classId] = count == null ? 0 : count;
                posiFeatCount += posiOutcomeCounts[classId];
                outcomeCounts[classId] = this.classCounts.count(clas);
                ++classId;
            }
            int n = this.classCounts.size();
            if (posiFeatCount == 0 || posiFeatCount == n) {
                return 0.0;
            }
            int negaFeatCount = n - posiFeatCount;
            double chi2val = 0.0;
            for (int lbl = 0; lbl < numOfClass; ++lbl) {
                double classFraction = (double)outcomeCounts[lbl] / (double)n;
                // Feature-present cell, then feature-absent cell, for this outcome.
                chi2val += this.cellTerm(posiOutcomeCounts[lbl], classFraction * (double)posiFeatCount);
                chi2val += this.cellTerm(outcomeCounts[lbl] - posiOutcomeCounts[lbl], classFraction * (double)negaFeatCount);
            }
            return chi2val;
        }

        /**
         * Contribution of one contingency-table cell: {@code (|observed - expected|)^2 / expected},
         * with 0.5 subtracted from the absolute deviation when the Yates correction
         * is enabled. Cells with a non-positive expectation or a non-positive
         * (corrected) deviation contribute nothing.
         */
        private double cellTerm(double observed, double expected) {
            if (!(expected > 0.0)) {
                return 0.0;
            }
            double diff = Math.abs(observed - expected);
            if (this.yates) {
                diff -= 0.5;
            }
            if (!(diff > 0.0)) {
                return 0.0;
            }
            return Math.pow(diff, 2.0) / expected;
        }
    }
}

