/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ctakes.temporal.ae.feature.selection;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;
import org.cleartk.ml.feature.transform.OneToOneTrainableExtractor_ImplBase;

/**
 * Trainable feature transformer that rescales numeric features to z-scores,
 * {@code (value - mean) / stdDev}, using per-feature statistics that are either
 * computed by {@link #train(Iterable)} or read back by {@link #load(URI)}.
 *
 * <p>Statistics are keyed by feature name. Features whose value is not a
 * {@link Number} are returned unchanged by {@link #transform(Feature)}.
 *
 * @param <OUTCOME_T> outcome type of the instances passed to {@code train}
 * @param <FOCUS_T> annotation type accepted by {@code extract}
 */
public class ZscoreNormalizationExtractor<OUTCOME_T, FOCUS_T extends Annotation>
extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T> {
    /** Prefix prepended to the name of every normalized feature. */
    private static final String NORMALIZED_NAME_PREFIX = "Zscore_NORMED_";

    /**
     * Value emitted for a numeric feature whose name has no stored statistics.
     * NOTE(review): 0.5 is kept for compatibility with existing models; for a
     * z-score 0.0 would be the neutral value - confirm before changing.
     */
    private static final double UNSEEN_FEATURE_VALUE = 0.5;

    /** Set to true by {@code train} and {@code load}; not read inside this class. */
    private boolean isTrained = false;

    /** Feature name to (mean, standard deviation); null until {@code train} or {@code load} ran. */
    private Map<String, MeanStdPair> meanStdMap;

    /**
     * Creates an extractor with the given name.
     *
     * @param name name handed to the superclass constructor
     */
    public ZscoreNormalizationExtractor(String name) {
        super(name);
    }

    /**
     * Normalizes a single feature.
     *
     * @param feature the feature to normalize
     * @return the same feature instance when its value is not a {@link Number};
     *         otherwise a new feature named {@code "Zscore_NORMED_" + name} whose
     *         value is the z-score, {@code 0.0} when the stored standard deviation
     *         is not positive, or {@code 0.5} when the feature name is unknown
     * @throws IllegalStateException if a numeric feature is passed before
     *         {@link #train(Iterable)} or {@link #load(URI)} has been called
     */
    public Feature transform(Feature feature) {
        Object featureValue = feature.getValue();
        if (!(featureValue instanceof Number)) {
            return feature;
        }
        this.requireStatistics();
        String featureName = feature.getName();
        double value = ((Number)featureValue).doubleValue();
        double normalized = UNSEEN_FEATURE_VALUE;
        MeanStdPair stats = this.meanStdMap.get(featureName);
        if (stats != null) {
            // A constant feature has zero spread; dividing would yield NaN or
            // +/-Infinity. The comparison is also false for a NaN deviation.
            normalized = stats.std > 0.0 ? (value - stats.mean) / stats.std : 0.0;
        }
        return new Feature(NORMALIZED_NAME_PREFIX + featureName, (Object)normalized);
    }

    /**
     * Computes mean and standard deviation for every numeric feature name seen
     * in {@code instances} and replaces any previously stored statistics.
     * A message is printed to {@code System.err} for each non-numeric feature.
     *
     * @param instances training instances whose features are scanned
     */
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        Map<String, ZscoreRunningStat> featureStatsMap = new HashMap<String, ZscoreRunningStat>();
        for (Instance<OUTCOME_T> instance : instances) {
            for (Feature feature : instance.getFeatures()) {
                String featureName = feature.getName();
                Object featureValue = feature.getValue();
                if (!(featureValue instanceof Number)) {
                    System.err.println("Ignore non-numeric feature from normalization: " + featureName + " with Value: " + featureValue);
                    continue;
                }
                ZscoreRunningStat stats = featureStatsMap.get(featureName);
                if (stats == null) {
                    stats = new ZscoreRunningStat();
                    featureStatsMap.put(featureName, stats);
                }
                stats.add(((Number)featureValue).doubleValue());
            }
        }
        Map<String, MeanStdPair> trained = new HashMap<String, MeanStdPair>();
        for (Map.Entry<String, ZscoreRunningStat> entry : featureStatsMap.entrySet()) {
            ZscoreRunningStat stats = entry.getValue();
            trained.put(entry.getKey(), new MeanStdPair(stats.getMean(), stats.getStdDev()));
        }
        this.meanStdMap = trained;
        this.isTrained = true;
    }

    /**
     * Writes the stored statistics as one {@code name<TAB>mean<TAB>std} line per
     * feature. Numbers are written with {@code %f}, i.e. six decimal places.
     *
     * @param zmusDataUri file URI to write to
     * @throws IOException if the file cannot be written
     * @throws IllegalStateException if no statistics have been trained or loaded
     */
    public void save(URI zmusDataUri) throws IOException {
        // Check before opening so a missing model does not create or truncate the file.
        this.requireStatistics();
        File out = new File(zmusDataUri);
        BufferedWriter writer = new BufferedWriter(new FileWriter(out));
        try {
            for (Map.Entry<String, MeanStdPair> entry : this.meanStdMap.entrySet()) {
                MeanStdPair pair = entry.getValue();
                writer.append(String.format(Locale.ROOT, "%s\t%f\t%f\n", entry.getKey(), pair.mean, pair.std));
            }
        } finally {
            // Release the file handle even when writing fails part-way.
            writer.close();
        }
    }

    /**
     * Reads statistics in the format written by {@link #save(URI)} and replaces
     * any previously stored statistics. Empty lines are skipped. The stored
     * statistics are only replaced once the whole file has been read.
     *
     * @param zmusDataUri file URI to read from
     * @throws IOException if the file cannot be read, a non-empty line has fewer
     *         than three tab-separated fields, or a mean/std field is not a number
     */
    public void load(URI zmusDataUri) throws IOException {
        File in = new File(zmusDataUri);
        Map<String, MeanStdPair> loaded = new HashMap<String, MeanStdPair>();
        BufferedReader reader = new BufferedReader(new FileReader(in));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.length() == 0) {
                    continue;
                }
                String[] featureMeanStddev = line.split("\\t");
                if (featureMeanStddev.length < 3) {
                    throw new IOException("Malformed z-score statistics line in " + in + ": " + line);
                }
                try {
                    loaded.put(featureMeanStddev[0], new MeanStdPair(Double.parseDouble(featureMeanStddev[1]), Double.parseDouble(featureMeanStddev[2])));
                } catch (NumberFormatException e) {
                    throw new IOException("Non-numeric mean or standard deviation in " + in + ": " + line, e);
                }
            }
        } finally {
            // Release the file handle even when parsing fails.
            reader.close();
        }
        this.meanStdMap = loaded;
        this.isTrained = true;
    }

    /**
     * Stub: performs no extraction.
     * NOTE(review): returns {@code null} rather than an empty list; kept as is
     * because callers may test for {@code null} - confirm before changing.
     *
     * @param view ignored
     * @param focusAnnotation ignored
     * @return always {@code null}
     * @throws CleartkExtractorException never thrown by this implementation
     */
    public List<Feature> extract(JCas view, FOCUS_T focusAnnotation) throws CleartkExtractorException {
        return null;
    }

    /** Fails fast with a descriptive message when no statistics are available. */
    private void requireStatistics() {
        if (this.meanStdMap == null) {
            throw new IllegalStateException("No z-score statistics available: call train() or load() first");
        }
    }

    /**
     * Accumulates samples and reports their mean and population standard
     * deviation. Every added sample is retained in memory.
     */
    public static class ZscoreRunningStat
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private List<Double> data;
        private double sum;
        private double mean;
        private int n;

        /** Creates an empty accumulator. */
        public ZscoreRunningStat() {
            this.clear();
        }

        /**
         * Adds one sample and updates the running mean.
         *
         * @param x the sample value
         */
        public void add(double x) {
            // The sample must be retained: getVariance() iterates over data, and
            // without this line the variance is always reported as zero.
            this.data.add(x);
            ++this.n;
            this.sum += x;
            this.mean = this.sum / (double)this.n;
        }

        /** Discards all samples and resets the sum, count and mean to zero. */
        public void clear() {
            this.data = new ArrayList<Double>();
            this.sum = 0.0;
            this.n = 0;
            this.mean = 0.0;
        }

        /**
         * @return the number of samples added since construction or the last {@link #clear()}
         */
        public int getNumSamples() {
            return this.n;
        }

        /**
         * Population variance: the sum of squared deviations from the mean
         * divided by {@code n}. Evaluates to NaN when no samples were added.
         */
        private double getVariance() {
            double squaredDeviations = 0.0;
            for (double sample : this.data) {
                double deviation = this.mean - sample;
                squaredDeviations += deviation * deviation;
            }
            return squaredDeviations / (double)this.n;
        }

        /**
         * @return the population standard deviation of the samples; NaN when no samples were added
         */
        public double getStdDev() {
            return Math.sqrt(this.getVariance());
        }

        /**
         * @return the arithmetic mean of the samples; {@code 0.0} when no samples were added
         */
        public double getMean() {
            return this.mean;
        }
    }

    /** Immutable holder for the mean and standard deviation of one feature. */
    private static class MeanStdPair {
        public final double mean;
        public final double std;

        public MeanStdPair(double mean, double std) {
            this.mean = mean;
            this.std = std;
        }
    }
}

