/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.dependency.srl;

import com.carrotsearch.hppc.IntOpenHashSet;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.dependency.DEPArc;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.dependency.srl.AbstractSRLabeler;
import com.googlecode.clearnlp.dependency.srl.SRLLib;
import com.googlecode.clearnlp.feature.xml.FtrToken;
import com.googlecode.clearnlp.feature.xml.SRLFtrXml;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import com.googlecode.clearnlp.util.pair.IntIntPair;
import com.googlecode.clearnlp.util.pair.StringIntPair;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;

/**
 * Labels predicate-argument relations ("semantic heads") on a dependency tree.
 *
 * <p>For every predicate in the tree the labeler walks up the chain of syntactic heads
 * (the current node on that chain is kept in {@link #d_lca}) and, at each step, visits the
 * node itself and its dependents as argument candidates. Each candidate receives a label;
 * the label {@link #LB_NO_ARG} means "not an argument".
 *
 * <p>The behavior depends on the mode passed to the super constructor and read back through
 * {@code i_flag}:
 * <ul>
 *   <li>0 (no-arg constructor): only collects dependency-label paths between predicates and
 *       their arguments into {@link #m_down} / {@link #m_up};</li>
 *   <li>1: adds gold-labeled instances to the train spaces;</li>
 *   <li>2: predicts labels with the models;</li>
 *   <li>3: adds gold-labeled instances to the train spaces and continues with predicted labels;</li>
 *   <li>4: uses gold labels and prints every transition to {@link #f_trans}.</li>
 * </ul>
 *
 * <p>Instances keep per-tree state in fields, so one instance processes one tree at a time.
 */
public class SRLabeler
extends AbstractSRLabeler {
    public static final int MODEL_SIZE = 2;
    public static final int MODEL_LEFT = 0;
    public static final int MODEL_RIGHT = 1;
    protected static final int PATH_ALL = 0;
    protected static final int PATH_UP = 1;
    protected static final int PATH_DOWN = 2;
    protected static final int SUBCAT_ALL = 0;
    protected static final int SUBCAT_LEFT = 1;
    protected static final int SUBCAT_RIGHT = 2;
    protected static final String LB_NO_ARG = "N";

    /** Mode: collect down/up path statistics only. */
    private static final byte MODE_COLLECT = 0;
    /** Mode: add gold-labeled training instances. */
    private static final byte MODE_TRAIN = 1;
    /** Mode: predict labels with the models. */
    private static final byte MODE_PREDICT = 2;
    /** Mode: add training instances and continue with predicted labels. */
    private static final byte MODE_BOOTSTRAP = 3;
    /** Mode: follow gold labels and print each transition. */
    private static final byte MODE_PRINT = 4;

    protected SRLFtrXml f_xml;
    protected StringTrainSpace[] s_spaces;
    protected StringModel[] s_models;
    protected DEPTree d_tree;
    /** ID of the predicate currently being processed. */
    protected int i_pred;
    /** ID of the argument candidate currently being processed. */
    protected int i_arg;
    /** Number of predicates processed in the current tree. */
    protected int n_preds;
    /** {@code i1}: predicates visited; {@code i2}: argument candidates labeled (current tree). */
    protected IntIntPair n_trans;
    protected PrintStream f_trans;
    /** Semantic heads of the tree as returned by {@code DEPTree.getSHeads()}, indexed by node ID. */
    protected StringIntPair[][] g_heads;
    /** Leftmost dependent of each node (only if it precedes the node), indexed by node ID. */
    protected DEPNode[] lm_deps;
    /** Rightmost dependent of each node (only if it follows the node), indexed by node ID. */
    protected DEPNode[] rm_deps;
    /** Left-nearest sibling of each node, indexed by node ID. */
    protected DEPNode[] ln_sibs;
    /** Right-nearest sibling of each node, indexed by node ID. */
    protected DEPNode[] rn_sibs;
    /** Node on the predicate's head chain whose subtree is currently being explored. */
    protected DEPNode d_lca;
    /** IDs of nodes that must not be (re)visited for the current predicate. */
    protected IntOpenHashSet s_skip;
    /** Numbered-argument labels assigned so far for the current predicate, in assignment order. */
    protected List<String> l_argns;
    protected Prob1DMap m_down;
    protected Prob1DMap m_up;
    protected Set<String> s_down;
    protected Set<String> s_up;

    /** Creates a labeler in path-collection mode (flag 0). */
    public SRLabeler() {
        super(MODE_COLLECT);
        this.m_down = new Prob1DMap();
        this.m_up = new Prob1DMap();
    }

    /**
     * Creates a labeler in training mode (flag 1).
     *
     * @param xml feature template used to build feature vectors
     * @param spaces train spaces indexed by {@link #MODEL_LEFT} / {@link #MODEL_RIGHT}
     * @param sDown set of down-paths along which dependents are explored recursively
     * @param sUp set of up-paths (stored; not consulted by the code in this class)
     */
    public SRLabeler(SRLFtrXml xml, StringTrainSpace[] spaces, Set<String> sDown, Set<String> sUp) {
        super(MODE_TRAIN);
        this.f_xml = xml;
        this.s_spaces = spaces;
        this.s_down = sDown;
        this.s_up = sUp;
    }

    /**
     * Creates a labeler in prediction mode (flag 2).
     *
     * @param xml feature template used to build feature vectors
     * @param models models indexed by {@link #MODEL_LEFT} / {@link #MODEL_RIGHT}
     * @param sDown set of down-paths along which dependents are explored recursively
     * @param sUp set of up-paths (stored; not consulted by the code in this class)
     */
    public SRLabeler(SRLFtrXml xml, StringModel[] models, Set<String> sDown, Set<String> sUp) {
        super(MODE_PREDICT);
        this.f_xml = xml;
        this.s_models = models;
        this.s_down = sDown;
        this.s_up = sUp;
    }

    /**
     * Creates a labeler in bootstrap mode (flag 3): gold instances are added to
     * {@code spaces} while the labels actually applied come from {@code models}.
     *
     * @param xml feature template used to build feature vectors
     * @param models models indexed by {@link #MODEL_LEFT} / {@link #MODEL_RIGHT}
     * @param spaces train spaces indexed by {@link #MODEL_LEFT} / {@link #MODEL_RIGHT}
     * @param sDown set of down-paths along which dependents are explored recursively
     * @param sUp set of up-paths (stored; not consulted by the code in this class)
     */
    public SRLabeler(SRLFtrXml xml, StringModel[] models, StringTrainSpace[] spaces, Set<String> sDown, Set<String> sUp) {
        super(MODE_BOOTSTRAP);
        this.f_xml = xml;
        this.s_models = models;
        this.s_spaces = spaces;
        this.s_down = sDown;
        this.s_up = sUp;
    }

    /**
     * Creates a labeler in transition-printing mode (flag 4).
     *
     * @param fout stream that receives one line per transition and a blank line per tree
     */
    public SRLabeler(PrintStream fout) {
        super(MODE_PRINT);
        this.f_trans = fout;
    }

    /** Intentionally does nothing; use {@link #saveModel(PrintStream, int)} to save a model. */
    @Override
    public void saveModel(PrintStream fout) {
    }

    /**
     * Saves the model at the given index.
     *
     * @param fout destination stream
     * @param idx index into the model array ({@link #MODEL_LEFT} or {@link #MODEL_RIGHT})
     */
    public void saveModel(PrintStream fout, int idx) {
        this.s_models[idx].save(fout);
    }

    /** Prints the down-path set to {@code fout}. */
    public void saveDownSet(PrintStream fout) {
        UTOutput.printSet(fout, this.s_down);
    }

    /** Prints the up-path set to {@code fout}. */
    public void saveUpSet(PrintStream fout) {
        UTOutput.printSet(fout, this.s_up);
    }

    /** @return the model array passed at construction (may be {@code null} in other modes) */
    public StringModel[] getModels() {
        return this.s_models;
    }

    /**
     * Resets all per-tree state for {@code tree}: locates the first predicate, clears the
     * counters, snapshots the tree's semantic heads (except in prediction mode), builds the
     * dependent/sibling caches, and finally clears the semantic heads stored on the tree.
     *
     * @param tree the dependency tree to be labeled
     */
    public void init(DEPTree tree) {
        this.d_tree = tree;
        this.i_pred = this.getNextPredId(0);
        this.s_skip = new IntOpenHashSet();
        this.l_argns = new ArrayList<String>();
        this.n_trans = new IntIntPair(0, 0);
        this.n_preds = 0;
        if (this.i_flag != MODE_PREDICT) {
            // Must be captured before clearSHeads() below wipes them from the tree.
            this.g_heads = tree.getSHeads();
        }
        this.initArcs();
        tree.clearSHeads();
    }

    /** Returns the ID of the next predicate after {@code prevId}, or the tree size if none. */
    private int getNextPredId(int prevId) {
        DEPNode pred = this.d_tree.getNextPredicate(prevId);
        return pred != null ? pred.id : this.d_tree.size();
    }

    /**
     * Builds the leftmost/rightmost-dependent and nearest-sibling caches for the current
     * tree. Node 0 is skipped as a head.
     */
    protected void initArcs() {
        int size = this.d_tree.size();
        this.lm_deps = new DEPNode[size];
        this.rm_deps = new DEPNode[size];
        this.ln_sibs = new DEPNode[size];
        this.rn_sibs = new DEPNode[size];
        this.d_tree.setDependents();
        for (int i = 1; i < size; ++i) {
            List<DEPArc> deps = this.d_tree.get(i).getDependents();
            if (deps.isEmpty()) {
                continue;
            }
            int len = deps.size();
            DEPNode leftmost = deps.get(0).getNode();
            DEPNode rightmost = deps.get(len - 1).getNode();
            // Only record a leftmost/rightmost dependent that is really on that side of the head.
            if (leftmost.id < i) {
                this.lm_deps[i] = leftmost;
            }
            if (rightmost.id > i) {
                this.rm_deps[i] = rightmost;
            }
            for (int j = 1; j < len; ++j) {
                DEPNode curr = deps.get(j).getNode();
                DEPNode prev = deps.get(j - 1).getNode();
                // Keep an existing left sibling if it is at least as close as prev.
                if (this.ln_sibs[curr.id] == null || this.ln_sibs[curr.id].id < prev.id) {
                    this.ln_sibs[curr.id] = prev;
                }
            }
            for (int j = 0; j < len - 1; ++j) {
                DEPNode curr = deps.get(j).getNode();
                DEPNode next = deps.get(j + 1).getNode();
                // Keep an existing right sibling if it is at least as close as next.
                if (this.rn_sibs[curr.id] == null || this.rn_sibs[curr.id].id > next.id) {
                    this.rn_sibs[curr.id] = next;
                }
            }
        }
    }

    /**
     * Collects down- and up-path statistics for every predicate in {@code tree}
     * (used in collection mode).
     */
    private void collect(DEPTree tree) {
        DEPNode pred = tree.getNextPredicate(0);
        tree.setDependents();
        while (pred != null) {
            for (DEPArc arc : pred.getGrandDependents()) {
                this.collectDown(pred, arc.getNode());
            }
            DEPNode head = pred.getHead();
            if (head != null) {
                this.collectUp(pred, head.getHead());
            }
            pred = tree.getNextPredicate(pred.id);
        }
    }

    /**
     * Recursively visits {@code arg} and its descendants; whenever a node is an argument of
     * {@code pred}, counts every path from {@code pred} down to the node's ancestors.
     */
    private void collectDown(DEPNode pred, DEPNode arg) {
        if (arg.isArgumentOf(pred)) {
            for (String path : this.getDUPathList(pred, arg.getHead())) {
                this.m_down.add(path);
            }
        }
        for (DEPArc arc : arg.getDependents()) {
            this.collectDown(pred, arc.getNode());
        }
    }

    /**
     * Walks from {@code head} up to the root; for every dependent of an ancestor that is an
     * argument of {@code pred}, counts the paths between that ancestor and {@code pred}.
     */
    private void collectUp(DEPNode pred, DEPNode head) {
        if (head == null) {
            return;
        }
        for (DEPArc arc : head.getDependents()) {
            if (!arc.getNode().isArgumentOf(pred)) {
                continue;
            }
            for (String path : this.getDUPathList(head, pred)) {
                this.m_up.add(path);
            }
        }
        this.collectUp(pred, head.getHead());
    }

    /**
     * Returns the dependency-label path from {@code bottom} up to (excluding) {@code top}.
     * {@code top} must be a proper ancestor of {@code bottom}.
     */
    private String getDUPath(DEPNode top, DEPNode bottom) {
        return this.getPathAux(top, bottom, "d", "|", true);
    }

    /**
     * Returns the dependency-label paths from {@code bottom} and from each of its ancestors
     * below {@code top}. Returns an empty list if {@code bottom == top}; {@code top} must
     * otherwise be an ancestor of {@code bottom}.
     */
    private List<String> getDUPathList(DEPNode top, DEPNode bottom) {
        List<String> paths = new ArrayList<String>();
        for (DEPNode node = bottom; node != top; node = node.getHead()) {
            paths.add(this.getDUPath(top, node));
        }
        return paths;
    }

    /**
     * Returns the collected down-paths filtered by {@code cutoff}.
     * Only meaningful for an instance created in collection mode, where {@link #m_down} is set.
     */
    public Set<String> getDownSet(int cutoff) {
        return this.m_down.toSet(cutoff);
    }

    /**
     * Returns the collected up-paths filtered by {@code cutoff}.
     * Only meaningful for an instance created in collection mode, where {@link #m_up} is set.
     */
    public Set<String> getUpSet(int cutoff) {
        return this.m_up.toSet(cutoff);
    }

    /** @return transition counters of the most recently labeled tree */
    public IntIntPair getNumTransitions() {
        return this.n_trans;
    }

    /** @return number of predicates processed in the most recently labeled tree */
    public int getNumPredicates() {
        return this.n_preds;
    }

    /**
     * Processes one tree according to the labeler's mode: collects path statistics in
     * collection mode, otherwise initializes per-tree state and labels every predicate.
     * In transition-printing mode a blank line is printed after the tree.
     */
    @Override
    public void label(DEPTree tree) {
        if (this.i_flag == MODE_COLLECT) {
            this.collect(tree);
            return;
        }
        this.init(tree);
        this.labelAux();
        if (this.i_flag == MODE_PRINT) {
            this.f_trans.println();
        }
    }

    /** Iterates over all predicates of the current tree and labels their argument candidates. */
    private void labelAux() {
        int size = this.d_tree.size();
        while (this.i_pred < size) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            ++this.n_trans.i1;
            // The predicate itself and node 0 are never argument candidates.
            this.s_skip.clear();
            this.s_skip.add(this.i_pred);
            this.s_skip.add(0);
            this.l_argns.clear();
            // Climb from the predicate to the top of its head chain.
            this.d_lca = pred;
            do {
                this.labelAux(pred, this.d_lca);
                this.d_lca = this.d_lca.getHead();
            } while (this.d_lca != null);
            ++this.n_preds;
            this.i_pred = this.getNextPredId(this.i_pred);
        }
    }

    /** Labels {@code head} itself (unless skipped) and then its dependents. */
    private void labelAux(DEPNode pred, DEPNode head) {
        if (!this.s_skip.contains(head.id)) {
            this.i_arg = head.id;
            this.addArgument(this.getLabel(this.getDirIndex()));
        }
        this.labelDown(pred, head.getDependents());
    }

    /**
     * Labels every not-yet-visited node in {@code arcs}. While the predicate itself is the
     * current {@link #d_lca}, recursion continues into a node's dependents if the path from
     * the predicate to that node is in {@link #s_down}.
     */
    private void labelDown(DEPNode pred, List<DEPArc> arcs) {
        for (DEPArc arc : arcs) {
            DEPNode arg = arc.getNode();
            if (this.s_skip.contains(arg.id)) {
                continue;
            }
            this.i_arg = arg.id;
            this.addArgument(this.getLabel(this.getDirIndex()));
            // s_down is never assigned by the PrintStream constructor; treat a missing set as
            // empty instead of failing with a NullPointerException.
            if (this.i_pred == this.d_lca.id
                    && this.s_down != null
                    && this.s_down.contains(this.getDUPath(pred, arg))) {
                this.labelDown(pred, arg.getDependents());
            }
        }
    }

    /** @return {@link #MODEL_LEFT} if the candidate precedes the predicate, else {@link #MODEL_RIGHT} */
    private int getDirIndex() {
        return this.i_arg < this.i_pred ? MODEL_LEFT : MODEL_RIGHT;
    }

    /**
     * Determines the label for the current predicate/candidate pair according to the mode,
     * adding a training instance where the mode requires it.
     *
     * @param idx model / train-space index from {@link #getDirIndex()}
     */
    private String getLabel(int idx) {
        // No feature template exists in transition-printing mode.
        StringFeatureVector vector = this.i_flag != MODE_PRINT ? this.getFeatureVector(this.f_xml) : null;
        String label;
        if (this.i_flag == MODE_TRAIN) {
            label = this.getGoldArgLabel();
            this.s_spaces[idx].addInstance(label, vector);
        } else if (this.i_flag == MODE_PREDICT) {
            label = this.getAutoLabel(idx, vector);
        } else if (this.i_flag == MODE_BOOTSTRAP) {
            this.s_spaces[idx].addInstance(this.getGoldArgLabel(), vector);
            label = this.getAutoLabel(idx, vector);
        } else {
            label = this.getGoldArgLabel();
        }
        return label;
    }

    /** Returns the gold label linking the current candidate to the current predicate, or {@link #LB_NO_ARG}. */
    private String getGoldArgLabel() {
        for (StringIntPair head : this.g_heads[this.i_arg]) {
            if (head.i == this.i_pred) {
                return head.s;
            }
        }
        return LB_NO_ARG;
    }

    /** Returns the best label predicted by the model at {@code idx}. */
    private String getAutoLabel(int idx, StringFeatureVector vector) {
        return this.s_models[idx].predictBest((StringFeatureVector)vector).label;
    }

    /**
     * Records the decision for the current candidate: marks it visited, counts the
     * transition, prints it in transition-printing mode, and, unless the label is
     * {@link #LB_NO_ARG}, attaches the predicate as a semantic head of the candidate.
     */
    private void addArgument(String label) {
        this.s_skip.add(this.i_arg);
        ++this.n_trans.i2;
        if (this.i_flag == MODE_PRINT) {
            this.printState(label);
        }
        if (LB_NO_ARG.equals(label)) {
            return;
        }
        DEPNode pred = this.d_tree.get(this.i_pred);
        DEPNode arg = this.d_tree.get(this.i_arg);
        arg.addSHead(pred, label);
        if (SRLLib.isNumberedArgument(label)) {
            this.l_argns.add(label);
        }
    }

    /** Prints the current transition as {@code "<pred> -<label>-> <arg>"}. */
    private void printState(String label) {
        StringBuilder build = new StringBuilder();
        build.append(this.i_pred).append(" -").append(label).append("-> ").append(this.i_arg);
        this.f_trans.println(build.toString());
    }

    /**
     * Extracts a single-valued feature for {@code token}.
     *
     * @return the feature value, or {@code null} if the token's node or the value is unavailable
     */
    @Override
    protected String getField(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("f")) {
            return node.form;
        }
        if (token.isField("m")) {
            return node.lemma;
        }
        if (token.isField("p")) {
            return node.pos;
        }
        if (token.isField("d")) {
            return node.getLabel();
        }
        if (token.isField("n")) {
            return this.getDistance(node);
        }
        Matcher m = SRLFtrXml.P_ARGN.matcher(token.field);
        if (m.find()) {
            // Group 1 counts backwards from the most recently assigned numbered argument.
            int idx = this.l_argns.size() - Integer.parseInt(m.group(1)) - 1;
            return idx >= 0 ? this.l_argns.get(idx) : null;
        }
        m = SRLFtrXml.P_PATH.matcher(token.field);
        if (m.find()) {
            String type = m.group(1);
            int dir = Integer.parseInt(m.group(2));
            return this.getPath(type, dir);
        }
        m = SRLFtrXml.P_SUBCAT.matcher(token.field);
        if (m.find()) {
            String type = m.group(1);
            int dir = Integer.parseInt(m.group(2));
            return this.getSubcat(node, type, dir);
        }
        m = SRLFtrXml.P_FEAT.matcher(token.field);
        if (m.find()) {
            return node.getFeat(m.group(1));
        }
        m = SRLFtrXml.P_BOOLEAN.matcher(token.field);
        if (m.find()) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            int field = Integer.parseInt(m.group(1));
            // A boolean feature is emitted as the field name itself when true, else omitted.
            boolean value;
            switch (field) {
                case 0: {
                    value = node.isDependentOf(pred);
                    break;
                }
                case 1: {
                    value = pred.isDependentOf(node);
                    break;
                }
                case 2: {
                    value = pred.isDependentOf(this.d_lca);
                    break;
                }
                case 3: {
                    value = pred == this.d_lca;
                    break;
                }
                case 4: {
                    value = node == this.d_lca;
                    break;
                }
                default: {
                    return null;
                }
            }
            return value ? token.field : null;
        }
        return null;
    }

    /**
     * Extracts a multi-valued feature for {@code token}: the distinct dependency labels of
     * the node's dependents ({@code "ds"}) or grand-dependents ({@code "gds"}).
     *
     * @return the values, or {@code null} if the node or the values are unavailable
     */
    @Override
    protected String[] getFields(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("ds")) {
            return this.getDeprelSet(node.getDependents());
        }
        if (token.isField("gds")) {
            return this.getDeprelSet(node.getGrandDependents());
        }
        return null;
    }

    /** Returns the distinct labels of {@code deps} (unspecified order), or {@code null} if empty. */
    private String[] getDeprelSet(List<DEPArc> deps) {
        if (deps.isEmpty()) {
            return null;
        }
        Set<String> set = new HashSet<String>();
        for (DEPArc arc : deps) {
            set.add(arc.getLabel());
        }
        return set.toArray(new String[set.size()]);
    }

    /** Buckets the ID distance between the predicate and {@code node} into "0".."3". */
    private String getDistance(DEPNode node) {
        int dist = Math.abs(this.i_pred - node.id);
        if (dist <= 5) {
            return "0";
        }
        if (dist <= 10) {
            return "1";
        }
        if (dist <= 15) {
            return "2";
        }
        return "3";
    }

    /**
     * Builds a path feature between the current predicate and argument candidate.
     *
     * @param type "p" (POS tags), "d" (dependency labels) or "n" (distance)
     * @param dir {@link #PATH_UP} (predicate up to {@link #d_lca}), {@link #PATH_DOWN}
     *            ({@link #d_lca} down to candidate), anything else for the whole path
     * @return the path, or {@code null} if the requested segment is empty
     */
    private String getPath(String type, int dir) {
        DEPNode pred = this.d_tree.get(this.i_pred);
        DEPNode arg = this.d_tree.get(this.i_arg);
        if (dir == PATH_UP) {
            return this.d_lca != pred ? this.getPathAux(this.d_lca, pred, type, "^", true) : null;
        }
        if (dir == PATH_DOWN) {
            return this.d_lca != arg ? this.getPathAux(this.d_lca, arg, type, "|", true) : null;
        }
        if (pred == this.d_lca) {
            return this.getPathAux(pred, arg, type, "|", true);
        }
        if (pred.isDescendentOf(arg)) {
            return this.getPathAux(arg, pred, type, "^", true);
        }
        // Predicate and candidate are in different branches: go up to d_lca, then down,
        // emitting d_lca only once.
        String path = this.getPathAux(this.d_lca, pred, type, "^", true);
        path = path + this.getPathAux(this.d_lca, arg, type, "|", false);
        return path;
    }

    /**
     * Builds the path from {@code bottom} up to {@code top}.
     *
     * <p>{@code top} must be a proper ancestor of {@code bottom}; otherwise the walk runs
     * past the root and fails.
     *
     * @param type "p": POS tags of the nodes; "d": dependency labels of the nodes below
     *             {@code top}; "n": number of nodes below {@code top} on the path
     * @param delim delimiter prepended to every element
     * @param includeTop whether {@code top}'s POS tag is appended (only used for type "p")
     * @return the path, or {@code null} if nothing was appended
     */
    private String getPathAux(DEPNode top, DEPNode bottom, String type, String delim, boolean includeTop) {
        StringBuilder build = new StringBuilder();
        DEPNode head = bottom;
        int dist = 0;
        do {
            if (type.equals("p")) {
                build.append(delim).append(head.pos);
            } else if (type.equals("d")) {
                build.append(delim).append(head.getLabel());
            } else if (type.equals("n")) {
                ++dist;
            }
            head = head.getHead();
        } while (head != top);
        if (type.equals("p")) {
            if (includeTop) {
                build.append(delim).append(top.pos);
            }
        } else if (type.equals("n")) {
            build.append(delim).append(dist);
        }
        return build.length() == 0 ? null : build.toString();
    }

    /**
     * Builds the sub-categorization feature of {@code node}: the POS tags or dependency
     * labels of its dependents joined by "_".
     *
     * @param type "p" (POS tags) or "d" (dependency labels)
     * @param dir {@link #SUBCAT_LEFT} (dependents before the node, left to right),
     *            {@link #SUBCAT_RIGHT} (dependents after the node, right to left),
     *            anything else for all dependents
     * @return the joined string, or {@code null} if no dependent qualifies
     */
    private String getSubcat(DEPNode node, String type, int dir) {
        List<DEPArc> deps = node.getDependents();
        StringBuilder build = new StringBuilder();
        int size = deps.size();
        if (dir == SUBCAT_LEFT) {
            for (int i = 0; i < size; ++i) {
                DEPNode dep = deps.get(i).getNode();
                if (dep.id > node.id) {
                    break;
                }
                this.getSubcatAux(build, dep, type);
            }
        } else if (dir == SUBCAT_RIGHT) {
            for (int i = size - 1; i >= 0; --i) {
                DEPNode dep = deps.get(i).getNode();
                if (dep.id < node.id) {
                    break;
                }
                this.getSubcatAux(build, dep, type);
            }
        } else {
            for (int i = 0; i < size; ++i) {
                this.getSubcatAux(build, deps.get(i).getNode(), type);
            }
        }
        // Every element was prefixed with "_"; drop the leading one.
        return build.length() == 0 ? null : build.substring("_".length());
    }

    /** Appends "_" followed by the node's POS tag ("p") or dependency label ("d"). */
    private void getSubcatAux(StringBuilder build, DEPNode node, String type) {
        build.append("_");
        if (type.equals("p")) {
            build.append(node.pos);
        } else if (type.equals("d")) {
            build.append(node.getLabel());
        }
    }

    /**
     * Resolves the node a feature token refers to: the predicate ('p') or the argument
     * candidate ('a'), optionally moved along the token's relation (head, leftmost or
     * rightmost dependent, left- or right-nearest sibling).
     *
     * @return the node, or {@code null} if the source is unknown or the related node does not exist
     */
    private DEPNode getNode(FtrToken token) {
        DEPNode node = null;
        switch (token.source) {
            case 'p': {
                node = this.d_tree.get(this.i_pred);
                break;
            }
            case 'a': {
                node = this.d_tree.get(this.i_arg);
                break;
            }
        }
        // Guard against an unknown source: dereferencing node below would otherwise throw.
        if (node != null && token.relation != null) {
            if (token.isRelation("h")) {
                node = node.getHead();
            } else if (token.isRelation("lmd")) {
                node = this.lm_deps[node.id];
            } else if (token.isRelation("rmd")) {
                node = this.rm_deps[node.id];
            } else if (token.isRelation("lns")) {
                node = this.ln_sibs[node.id];
            } else if (token.isRelation("rns")) {
                node = this.rn_sibs[node.id];
            }
        }
        return node;
    }
}

