/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.nlp;

import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.component.dep.CDEPBackParser;
import com.googlecode.clearnlp.component.dep.CDEPPassParser;
import com.googlecode.clearnlp.component.pos.CPOSBackTagger;
import com.googlecode.clearnlp.component.pos.CPOSTagger;
import com.googlecode.clearnlp.component.srl.CRolesetClassifier;
import com.googlecode.clearnlp.component.srl.CSRLabeler;
import com.googlecode.clearnlp.component.srl.CSenseClassifier;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.dependency.srl.SRLEval;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.nlp.NLPTrain;
import com.googlecode.clearnlp.reader.JointReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.pair.ObjectDoublePair;
import java.io.FileInputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Random;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/**
 * Command-line driver that trains a statistical component on a training
 * directory and evaluates it on a development directory after every model
 * update, stopping as soon as the development score no longer improves.
 *
 * <p>Training-side helpers used here ({@code getStringTrainSpaces},
 * {@code updateModel}, {@code getComponent}, {@code getLexica}, ...) are not
 * defined in this class; they are inherited.</p>
 */
public class NLPDevelop
extends NLPTrain {
    /** Directory containing the development files. */
    @Option(name="-d", usage="the directory containing development files (required)", required=true, metaVar="<directory>")
    protected String s_devDir;
    /** Seed handed to every {@link Random} created during development. */
    @Option(name="-r", usage="the random seed", required=false, metaVar="<integer>")
    protected int i_rand = 0;
    /** When set, {@link #decode} writes one output file per development file. */
    @Option(name="-g", usage="if set, generate files", required=false, metaVar="<boolean>")
    protected boolean b_generate = false;

    /** Creates an instance without parsing arguments or running anything. */
    public NLPDevelop() {
    }

    /**
     * Parses the command-line arguments and immediately runs {@link #develop}.
     * Any exception raised while developing is printed to standard error and
     * otherwise ignored, so construction itself never fails because of it.
     *
     * @param args the command-line arguments.
     */
    public NLPDevelop(String[] args) {
        this.initArgs(args);
        try {
            this.develop(this.s_configFile, this.s_featureFiles.split(":"), this.s_trainDir, this.s_devDir, this.s_mode);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Dispatches to the development routine that matches {@code mode}.
     * A mode that matches none of the branches below is silently ignored.
     *
     * @param configFile   path to the XML configuration file.
     * @param featureFiles paths to the feature template files.
     * @param trainDir     directory containing the training files.
     * @param devDir       directory containing the development files.
     * @param mode         the component to develop ("pos", "dep", "pred", "role",
     *                     "sense...", "srl", "pos_back" or "dep_back").
     * @throws Exception if reading the configuration, training or decoding fails.
     */
    public void develop(String configFile, String[] featureFiles, String trainDir, String devDir, String mode) throws Exception {
        Element eConfig;
        FileInputStream configStream = new FileInputStream(configFile);
        try {
            eConfig = UTXml.getDocumentElement(configStream);
        }
        finally {
            // The stream is only needed while the document is parsed; release the file handle.
            configStream.close();
        }
        JointFtrXml[] xmls = this.getFeatureTemplates(featureFiles);
        String[] trainFiles = UTFile.getSortedFileListBySize(trainDir, ".*", true);
        String[] devFiles = UTFile.getSortedFileListBySize(devDir, ".*", true);
        JointReader reader = this.getJointReader(UTXml.getFirstElementByTagName(eConfig, "reader"));
        if (mode.equals("pos")) {
            this.developComponent(eConfig, reader, xmls, trainFiles, devFiles, new CPOSTagger(xmls, this.getLowerSimplifiedForms(reader, xmls[0], trainFiles, -1)), mode, -1);
        } else if (mode.equals("dep")) {
            this.developComponentBoot(eConfig, reader, xmls, trainFiles, devFiles, new CDEPPassParser(xmls), mode, -1);
        } else if (mode.equals("pred")) {
            this.decode(reader, this.getTrainedComponent(eConfig, xmls, trainFiles, null, null, mode, 0, -1), devFiles, mode, mode);
        } else if (mode.equals("role")) {
            this.decode(reader, this.getTrainedComponent(eConfig, reader, xmls, trainFiles, new CRolesetClassifier(xmls), mode, -1), devFiles, mode, mode);
        } else if (mode.startsWith("sense")) {
            // Everything after the last '_' in the mode is passed on to the sense classifier.
            this.decode(reader, this.getTrainedComponent(eConfig, reader, xmls, trainFiles, new CSenseClassifier(xmls, mode.substring(mode.lastIndexOf("_") + 1)), mode, -1), devFiles, mode, mode);
        } else if (mode.equals("srl")) {
            this.developComponentBoot(eConfig, reader, xmls, trainFiles, devFiles, new CSRLabeler(xmls), mode, -1);
        } else if (mode.equals("pos_back")) {
            this.developComponentBoot(eConfig, reader, xmls, trainFiles, devFiles, new CPOSBackTagger(xmls, this.getLowerSimplifiedForms(reader, xmls[0], trainFiles, -1)), mode, -1);
        } else if (mode.equals("dep_back")) {
            this.developComponentBoot(eConfig, reader, xmls, trainFiles, devFiles, new CDEPBackParser(xmls), mode, -1);
        }
    }

    /**
     * Repeatedly updates every model once and decodes the development files,
     * until the development score stops strictly increasing.
     *
     * @param eConfig    root element of the configuration document.
     * @param reader     reader used to read the development files.
     * @param xmls       feature templates.
     * @param trainFiles training file paths.
     * @param devFiles   development file paths.
     * @param lexica     lexica passed to the train spaces and the component.
     * @param mode       the component mode; also the tag name of the training element in {@code eConfig}.
     * @param devId      passed through to {@code getStringTrainSpaces}.
     * @return the score of the iteration preceding the first non-improving one
     *         (0.0 if the very first iteration does not score above 0).
     * @throws Exception if training or decoding fails.
     */
    protected double developComponent(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, String[] devFiles, Object[] lexica, String mode, int devId) throws Exception {
        StringTrainSpace[] spaces = this.getStringTrainSpaces(eConfig, xmls, trainFiles, null, lexica, mode, 0, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int mSize = spaces.length;
        int nUpdate = 1;
        StringModel[] models = new StringModel[mSize];
        double prevScore = -1.0;
        double currScore = 0.0;
        // A single generator is shared by all models in this variant.
        Random rand = new Random(this.i_rand);
        int iter = 0;
        do {
            prevScore = currScore;
            for (int i = 0; i < mSize; ++i) {
                // The update counter advances per model, not per outer iteration.
                this.updateModel(eTrain, spaces[i], rand, nUpdate++, i);
                models[i] = (StringModel)spaces[i].getModel();
            }
            AbstractStatisticalComponent processor = this.getComponent(xmls, models, lexica, mode);
            currScore = this.decode(reader, processor, devFiles, mode, Integer.toString(iter));
            ++iter;
        } while (prevScore < currScore);
        return prevScore;
    }

    /**
     * Collects the lexica of {@code component} (or uses {@code null} lexica when
     * {@code component} is {@code null}) and delegates to the lexica-based
     * {@code developComponent} overload.
     *
     * @param component the component to collect lexica from; may be {@code null}.
     * @return the score returned by the delegate.
     * @throws Exception if lexicon collection, training or decoding fails.
     */
    protected double developComponent(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent component, String mode, int devId) throws Exception {
        Object[] lexica = component != null ? this.getLexica(component, reader, xmls, trainFiles, devId) : null;
        return this.developComponent(eConfig, reader, xmls, trainFiles, devFiles, lexica, mode, devId);
    }

    /**
     * Runs bootstrapping rounds: each round develops models starting from the
     * models of the previous round. Rounds continue while the score has not
     * dropped by 0.01 or more relative to the previous round.
     *
     * @param component the component whose lexica are collected before the first round.
     * @throws Exception if lexicon collection, training or decoding fails.
     */
    protected void developComponentBoot(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent component, String mode, int devId) throws Exception {
        double prevScore;
        Object[] lexica = this.getLexica(component, reader, xmls, trainFiles, devId);
        double currScore = 0.0;
        StringModel[] models = null;
        int boot = 0;
        do {
            prevScore = currScore;
            ObjectDoublePair<StringModel[]> p = this.developComponent(eConfig, reader, xmls, trainFiles, devFiles, lexica, models, mode, boot, devId);
            models = (StringModel[])p.o;
            currScore = p.d;
            ++boot;
        } while (-0.01 < currScore - prevScore);
    }

    /**
     * Variant of {@link #developComponentBoot} that develops the models one at a
     * time (see {@code developComponent2}) and continues only while the score
     * strictly increases between rounds.
     *
     * @param component the component whose lexica are collected before the first round.
     * @throws Exception if lexicon collection, training or decoding fails.
     */
    protected void developComponentBoot2(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent component, String mode, int devId) throws Exception {
        double prevScore;
        Object[] lexica = this.getLexica(component, reader, xmls, trainFiles, devId);
        double currScore = 0.0;
        StringModel[] models = null;
        int boot = 0;
        do {
            prevScore = currScore;
            ObjectDoublePair<StringModel[]> p = this.developComponent2(eConfig, reader, xmls, trainFiles, devFiles, lexica, models, mode, boot, devId);
            models = (StringModel[])p.o;
            currScore = p.d;
            ++boot;
        } while (prevScore < currScore);
    }

    /**
     * One bootstrapping round in which all models are updated together.
     * Before each update the current weights of every model are snapshotted;
     * once the score stops strictly increasing, the snapshot taken before the
     * last (non-improving) update is written back into the models.
     *
     * @param models models of the previous round, handed to {@code getStringTrainSpaces};
     *               may be {@code null} for the first round.
     * @param boot   index of the bootstrapping round.
     * @return the models paired with the score of the iteration preceding the
     *         first non-improving one.
     * @throws Exception if training or decoding fails.
     */
    private ObjectDoublePair<StringModel[]> developComponent(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, String[] devFiles, Object[] lexica, StringModel[] models, String mode, int boot, int devId) throws Exception {
        int i;
        StringTrainSpace[] spaces = this.getStringTrainSpaces(eConfig, xmls, trainFiles, models, lexica, mode, boot, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int mSize = spaces.length;
        int nUpdate = 1;
        double prevScore = -1.0;
        double currScore = 0.0;
        Random[] rands = new Random[mSize];
        models = new StringModel[mSize];
        double[][] prevWeights = new double[mSize][];
        // Every model gets its own generator, all seeded identically.
        for (i = 0; i < mSize; ++i) {
            rands[i] = new Random(this.i_rand);
        }
        do {
            prevScore = currScore;
            for (i = 0; i < mSize; ++i) {
                if (models[i] != null) {
                    double[] d = models[i].getWeights();
                    prevWeights[i] = Arrays.copyOf(d, d.length);
                }
                this.updateModel(eTrain, spaces[i], rands[i], nUpdate, i);
                models[i] = (StringModel)spaces[i].getModel();
            }
            AbstractStatisticalComponent component = this.getComponent(xmls, models, lexica, mode);
            currScore = this.decode(reader, component, devFiles, mode, boot + "." + nUpdate + "." + this.i_rand);
            ++nUpdate;
        } while (prevScore < currScore);
        for (i = 0; i < mSize; ++i) {
            // No snapshot exists when the loop stopped after the first update;
            // keep the freshly trained weights instead of overwriting them with null.
            if (prevWeights[i] != null) {
                models[i].setWeights(prevWeights[i]);
            }
        }
        return new ObjectDoublePair<StringModel[]>(models, prevScore);
    }

    /**
     * One bootstrapping round in which the models are developed one after
     * another: model {@code i} is updated and evaluated (together with the
     * already developed models {@code 0..i-1}) until the score stops strictly
     * increasing, then rolled back to the weights held before its last update.
     *
     * @param models models of the previous round, handed to {@code getStringTrainSpaces};
     *               may be {@code null} for the first round.
     * @param boot   index of the bootstrapping round.
     * @return the models paired with the score recorded for the last model
     *         (-1.0 if there are no train spaces).
     * @throws Exception if training or decoding fails.
     */
    private ObjectDoublePair<StringModel[]> developComponent2(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, String[] devFiles, Object[] lexica, StringModel[] models, String mode, int boot, int devId) throws Exception {
        StringTrainSpace[] spaces = this.getStringTrainSpaces(eConfig, xmls, trainFiles, models, lexica, mode, boot, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int mSize = spaces.length;
        double prevScore = -1.0;
        for (int i = 0; i < mSize; ++i) {
            // Grow the model array by one slot, carrying over the models developed so far.
            StringModel[] tmp = models;
            models = new StringModel[i + 1];
            for (int j = 0; j < i; ++j) {
                models[j] = tmp[j];
            }
            Random rand = new Random(this.i_rand);
            double[] prevWeights = null;
            double currScore = 0.0;
            int nUpdate = 1;
            do {
                prevScore = currScore;
                if (models[i] != null) {
                    double[] d = models[i].getWeights();
                    prevWeights = Arrays.copyOf(d, d.length);
                }
                this.updateModel(eTrain, spaces[i], rand, nUpdate, i);
                models[i] = (StringModel)spaces[i].getModel();
                AbstractStatisticalComponent component = this.getComponent(xmls, models, lexica, mode);
                currScore = this.decode(reader, component, devFiles, mode, Integer.toString(100 * boot + nUpdate));
                ++nUpdate;
            } while (prevScore < currScore);
            // No snapshot exists when the loop stopped after the first update;
            // keep the freshly trained weights instead of overwriting them with null.
            if (prevWeights != null) {
                models[i].setWeights(prevWeights);
            }
        }
        return new ObjectDoublePair<StringModel[]>(models, prevScore);
    }

    /**
     * Processes every tree of every development file with {@code component},
     * accumulates accuracy counts and returns the resulting score. When
     * {@link #b_generate} is set, each processed tree is also written to
     * {@code devFile + "." + ext}.
     *
     * <p>The reader and the output stream are closed even when processing a
     * file fails.</p>
     *
     * @param reader    reader used to read the development files.
     * @param component the component to evaluate.
     * @param devFiles  development file paths.
     * @param mode      the component mode; selects the count layout and score formula.
     * @param ext       suffix appended to the names of generated output files.
     * @return the score computed by {@link #getScore}.
     * @throws Exception if reading, processing or writing fails.
     */
    protected double decode(JointReader reader, AbstractStatisticalComponent component, String[] devFiles, String mode, String ext) throws Exception {
        int[] counts = this.getCounts(mode);
        for (String devFile : devFiles) {
            PrintStream fout = null;
            try {
                if (this.b_generate) {
                    fout = UTOutput.createPrintBufferedFileStream(devFile + "." + ext);
                }
                reader.open(UTInput.createBufferedFileReader(devFile));
                try {
                    DEPTree tree;
                    while ((tree = reader.next()) != null) {
                        component.process(tree);
                        component.countAccuracy(counts);
                        if (!this.b_generate) continue;
                        fout.println(this.toString(tree, mode) + "\n");
                    }
                }
                finally {
                    reader.close();
                }
            }
            finally {
                if (fout != null) {
                    fout.close();
                }
            }
        }
        return this.getScore(mode, counts);
    }

    /**
     * Allocates the zero-initialised count array for {@code mode}.
     *
     * @param mode the component mode.
     * @return an array of 2 ("pos...", "role", "sense..."), 4 ("dep"),
     *         3 ("pred", "srl") or 5 ("dep_back") counters, or {@code null}
     *         for any other mode.
     */
    protected int[] getCounts(String mode) {
        if (mode.startsWith("pos") || mode.equals("role") || mode.startsWith("sense")) {
            return new int[2];
        }
        if (mode.equals("dep")) {
            return new int[4];
        }
        if (mode.equals("pred") || mode.equals("srl")) {
            return new int[3];
        }
        if (mode.equals("dep_back")) {
            return new int[5];
        }
        return null;
    }

    /**
     * Prints the evaluation figures for {@code mode} to standard output and
     * returns the figure used to compare iterations. Divisions are carried out
     * in floating point, so a zero denominator yields NaN or infinity rather
     * than an exception.
     *
     * @param mode   the component mode.
     * @param counts the counts accumulated while decoding, laid out as produced by {@link #getCounts}.
     * @return {@code 100 * counts[1] / counts[0]} for "pos...", "role", "sense..." and "dep";
     *         the F1 of {@code 100 * counts[0] / counts[1]} and {@code 100 * counts[0] / counts[2]}
     *         for "pred" and "srl"; {@code 100 * counts[2] / counts[0]} for "dep_back";
     *         0.0 for any other mode.
     */
    protected double getScore(String mode, int[] counts) {
        double score = 0.0;
        if (mode.startsWith("pos") || mode.equals("role") || mode.startsWith("sense")) {
            score = 100.0 * (double)counts[1] / (double)counts[0];
            System.out.printf("- ACC: %5.2f (%d/%d)\n", score, counts[1], counts[0]);
        } else if (mode.equals("dep")) {
            String[] labels = new String[]{"T", "LAS", "UAS", "LS"};
            this.printScores(labels, counts);
            score = 100.0 * (double)counts[1] / (double)counts[0];
        } else if (mode.equals("pred") || mode.equals("srl")) {
            double p = 100.0 * (double)counts[0] / (double)counts[1];
            double r = 100.0 * (double)counts[0] / (double)counts[2];
            score = SRLEval.getF1(p, r);
            System.out.printf("P: %5.2f ", p);
            System.out.printf("R: %5.2f ", r);
            System.out.printf("F1: %5.2f\n", score);
        } else if (mode.equals("dep_back")) {
            String[] labels = new String[]{"T", "POS", "LAS", "UAS", "LS"};
            this.printScores(labels, counts);
            score = 100.0 * (double)counts[2] / (double)counts[0];
        }
        return score;
    }

    /**
     * Prints one line per counter after the first, each as a percentage of
     * {@code counts[0]}.
     *
     * @param labels labels aligned with {@code counts}; {@code labels[0]} is not printed.
     * @param counts the counters; {@code counts[0]} is the denominator.
     */
    private void printScores(String[] labels, int[] counts) {
        int t = counts[0];
        int size = counts.length;
        for (int i = 1; i < size; ++i) {
            System.out.printf("%3s: %5.2f (%d/%d)\n", labels[i], 100.0 * (double)counts[i] / (double)t, counts[i], t);
        }
    }

    /**
     * Command-line entry point; all work happens in the constructor.
     *
     * @param args the command-line arguments.
     */
    public static void main(String[] args) {
        new NLPDevelop(args);
    }
}

