/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.transform.encode;

import java.io.IOException;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.Mean;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.matrix.data.FrameBlock;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.transform.TfUtils;
import org.apache.sysml.runtime.transform.encode.Encoder;
import org.apache.sysml.runtime.transform.meta.TfMetaUtils;
import org.apache.sysml.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Encoder that imputes missing values in a configured set of columns.
 *
 * <p>For every configured column a replacement string is maintained in
 * {@code _replacementList}. Depending on the column's {@link MVMethod} the
 * replacement is the running column mean, the most frequent value (mode), or a
 * constant taken from the transform specification. {@link #apply} writes the
 * replacement into every {@code NaN} cell of the output matrix.
 *
 * <p>Two construction paths exist: the three-argument constructor parses an
 * {@code "impute"} array of per-column objects, while the four-argument
 * constructor parses {@code "impute"}/{@code "scale"} objects made of parallel
 * arrays and is the one that sets up the state used by {@link #prepare}.
 */
public class EncoderMVImpute
extends Encoder {
    private static final long serialVersionUID = 9057868620144662194L;

    /** Imputation method per imputed column, parallel to {@code _colList}. */
    private MVMethod[] _mvMethodList = null;
    /** Scaling method per imputed column (only set for columns that are also scaled). */
    private MVMethod[] _mvscMethodList = null;
    /** Bit i is set when imputed column i is also listed under "scale". */
    private BitSet _isMVScaled = null;
    private CM _varFn = CM.getCMFnObject(CMOperator.AggregateOperationTypes.VARIANCE);
    private Mean _meanFn = Mean.getMeanFnObject();
    /** Running mean accumulator per imputed column. */
    private KahanObject[] _meanList = null;
    /** Number of values folded into the statistics per imputed column. */
    private long[] _countList = null;
    /** Running variance accumulator per imputed column (allocated only for scaled columns). */
    private CM_COV_Object[] _varList = null;
    /** Column IDs that are scaled but not imputed, plus their parallel statistics below. */
    private int[] _scnomvList = null;
    private MVMethod[] _scnomvMethodList = null;
    private KahanObject[] _scnomvMeanList = null;
    private long[] _scnomvCountList = null;
    private CM_COV_Object[] _scnomvVarList = null;
    /** Replacement value (as string) per imputed column. */
    private String[] _replacementList = null;
    /** Strings that {@link #prepare} treats as missing values. */
    private String[] _NAstrings = null;
    /** IDs of recoded columns, see {@link #initRecodeIDList}. */
    private List<Integer> _rcList = null;
    /** Per-column value histograms used to determine the mode; keyed by column ID. */
    private HashMap<Integer, HashMap<String, Long>> _hist = null;

    /** @return the replacement strings, parallel to the imputed column list (internal array, not a copy) */
    public String[] getReplacements() {
        return _replacementList;
    }

    /** @return the mean accumulators of the imputed columns (internal array, not a copy) */
    public KahanObject[] getMeans() {
        return _meanList;
    }

    /** @return the variance accumulators of the imputed columns (internal array, not a copy) */
    public CM_COV_Object[] getVars() {
        return _varList;
    }

    /** @return the mean accumulators of the columns that are scaled but not imputed */
    public KahanObject[] getMeans_scnomv() {
        return _scnomvMeanList;
    }

    /** @return the variance accumulators of the columns that are scaled but not imputed */
    public CM_COV_Object[] getVars_scnomv() {
        return _scnomvVarList;
    }

    /**
     * Creates an encoder from a specification whose {@code "impute"} entry is an
     * array of per-column objects (each with a {@code "method"} and, for constant
     * imputation, a {@code "value"}).
     *
     * @param parsedSpec parsed transform specification
     * @param colnames column names used to resolve the imputed column IDs
     * @param clen number of columns, forwarded to the {@link Encoder} constructor
     * @throws JSONException if the specification cannot be read
     */
    public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen) throws JSONException {
        super(null, clen);
        int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, "impute");
        initColList(collist);
        parseMethodsAndReplacments(parsedSpec);
        _hist = new HashMap<Integer, HashMap<String, Long>>();
    }

    /**
     * Creates an encoder from a specification whose {@code "impute"} and
     * {@code "scale"} entries are objects holding parallel {@code "attributes"},
     * {@code "methods"} (numeric codes indexing {@link MVMethod#values()}) and,
     * for imputation, {@code "constants"} arrays.
     *
     * <p>This constructor does not create the mode histogram map, so
     * {@link #getHistogram} returns {@code null} for encoders built this way.
     *
     * @param parsedSpec parsed transform specification
     * @param colnames unused by this constructor
     * @param NAstrings strings that {@link #prepare} treats as missing values
     * @param clen number of columns, forwarded to the {@link Encoder} constructor
     * @throws JSONException if the specification cannot be read
     */
    public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, String[] NAstrings, int clen) throws JSONException {
        super(null, clen);
        boolean isMV = parsedSpec.containsKey("impute");
        boolean isSC = parsedSpec.containsKey("scale");
        _NAstrings = NAstrings;

        if (!isMV) {
            _colList = null;
            _mvMethodList = null;
            _meanList = null;
            _countList = null;
            _replacementList = null;
        } else {
            JSONObject mvobj = (JSONObject)parsedSpec.get("impute");
            JSONArray mvattrs = (JSONArray)mvobj.get("attributes");
            JSONArray mvmthds = (JSONArray)mvobj.get("methods");
            int mvLength = mvattrs.size();
            _colList = new int[mvLength];
            _mvMethodList = new MVMethod[mvLength];
            _meanList = new KahanObject[mvLength];
            _countList = new long[mvLength];
            _varList = new CM_COV_Object[mvLength];
            _isMVScaled = new BitSet(_colList.length);
            _isMVScaled.clear();
            for (int i = 0; i < _colList.length; ++i) {
                _colList[i] = UtilFunctions.toInt(mvattrs.get(i));
                _mvMethodList[i] = MVMethod.values()[UtilFunctions.toInt(mvmthds.get(i))];
                _meanList[i] = new KahanObject(0.0, 0.0);
            }
            _replacementList = new String[mvLength];
            JSONArray constants = (JSONArray)mvobj.get("constants");
            for (int i = 0; i < constants.size(); ++i) {
                // a missing constant is represented by the string "NaN"
                _replacementList[i] = constants.get(i) == null ? "NaN" : constants.get(i).toString();
            }
        }

        if (!isSC) {
            _scnomvCountList = null;
            _scnomvMeanList = null;
            _scnomvVarList = null;
        } else {
            if (_colList != null) {
                _mvscMethodList = new MVMethod[_colList.length];
            }
            JSONObject scobj = (JSONObject)parsedSpec.get("scale");
            JSONArray scattrs = (JSONArray)scobj.get("attributes");
            JSONArray scmthds = (JSONArray)scobj.get("methods");
            int scLength = scattrs.size();

            // first pass: mark imputed columns that are also scaled and count the rest
            int scnomv = 0;
            for (int i = 0; i < scLength; ++i) {
                int colID = UtilFunctions.toInt(scattrs.get(i));
                byte mthd = (byte)UtilFunctions.toInt(scmthds.get(i));
                int mvidx = isApplicable(colID);
                if (mvidx != -1) {
                    _isMVScaled.set(mvidx);
                    _mvscMethodList[mvidx] = MVMethod.values()[mthd];
                    _varList[mvidx] = new CM_COV_Object();
                } else {
                    ++scnomv;
                }
            }

            // second pass: allocate and fill the state of scaled-but-not-imputed columns
            if (scnomv > 0) {
                _scnomvList = new int[scnomv];
                _scnomvMethodList = new MVMethod[scnomv];
                _scnomvMeanList = new KahanObject[scnomv];
                _scnomvCountList = new long[scnomv];
                _scnomvVarList = new CM_COV_Object[scnomv];
                int idx = 0;
                for (int i = 0; i < scLength; ++i) {
                    int colID = UtilFunctions.toInt(scattrs.get(i));
                    byte mthd = (byte)UtilFunctions.toInt(scmthds.get(i));
                    if (isApplicable(colID) != -1) {
                        continue;
                    }
                    _scnomvList[idx] = colID;
                    _scnomvMethodList[idx] = MVMethod.values()[mthd];
                    _scnomvMeanList[idx] = new KahanObject(0.0, 0.0);
                    _scnomvVarList[idx] = new CM_COV_Object();
                    ++idx;
                }
            }
        }
    }

    /**
     * Reads the per-column imputation method (and the constant replacement for
     * {@link MVMethod#CONSTANT}) from the {@code "impute"} array and allocates
     * the per-column statistics.
     *
     * @param parsedSpec parsed transform specification
     * @throws JSONException if the specification cannot be read
     */
    private void parseMethodsAndReplacments(JSONObject parsedSpec) throws JSONException {
        JSONArray mvspec = (JSONArray)parsedSpec.get("impute");
        int len = mvspec.size();
        _mvMethodList = new MVMethod[len];
        _replacementList = new String[len];
        _meanList = new KahanObject[len];
        _countList = new long[len];
        for (int i = 0; i < len; ++i) {
            JSONObject mvobj = (JSONObject)mvspec.get(i);
            _mvMethodList[i] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase());
            if (_mvMethodList[i] == MVMethod.CONSTANT) {
                _replacementList[i] = mvobj.getString("value");
            }
            _meanList[i] = new KahanObject(0.0, 0.0);
        }
    }

    /**
     * Folds one input row into the running per-column statistics.
     *
     * <p>For imputed columns, values matching the configured NA strings are
     * skipped; all other values are counted and, where a mean is needed
     * (mean imputation or scaling), parsed as a number. Columns that are scaled
     * but not imputed are always parsed and accumulated.
     *
     * @param words the cells of one row; column ID {@code c} is read from {@code words[c - 1]}
     * @throws IOException wrapping any failure, including non-numeric values in
     *         columns that require numeric statistics
     */
    public void prepare(String[] words) throws IOException {
        try {
            if (_colList != null) {
                for (int i = 0; i < _colList.length; ++i) {
                    int colID = _colList[i];
                    String w = UtilFunctions.unquote(words[colID - 1].trim());
                    try {
                        if (TfUtils.isNA(_NAstrings, w)) {
                            continue;
                        }
                        _countList[i]++;
                        boolean computeMean = _mvMethodList[i] == MVMethod.GLOBAL_MEAN || _isMVScaled.get(i);
                        if (!computeMean) {
                            continue;
                        }
                        double d = UtilFunctions.parseToDouble(w);
                        _meanFn.execute2(_meanList[i], d, _countList[i]);
                        // NOTE(review): scaling methods are stored as MVMethod values; the
                        // GLOBAL_MODE ordinal appears to denote a variance-based scaling here — confirm.
                        if (_isMVScaled.get(i) && _mvscMethodList[i] == MVMethod.GLOBAL_MODE) {
                            _varFn.execute((Data)_varList[i], d);
                        }
                    }
                    catch (NumberFormatException e) {
                        throw new RuntimeException("Encountered \"" + w + "\" in column ID \"" + colID + "\", when expecting a numeric value. Consider adding \"" + w + "\" to na.strings, along with an appropriate imputation method.", e);
                    }
                }
            }
            if (_scnomvList != null) {
                for (int i = 0; i < _scnomvList.length; ++i) {
                    int colID = _scnomvList[i];
                    String w = UtilFunctions.unquote(words[colID - 1].trim());
                    double d = UtilFunctions.parseToDouble(w);
                    _scnomvCountList[i]++;
                    _meanFn.execute2(_scnomvMeanList[i], d, _scnomvCountList[i]);
                    if (_scnomvMethodList[i] == MVMethod.GLOBAL_MODE) {
                        _varFn.execute((Data)_scnomvVarList[i], d);
                    }
                }
            }
        }
        catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * @param colID column ID to look up
     * @return the imputation method of the column, or {@link MVMethod#INVALID}
     *         if the column is not handled by this encoder
     */
    public MVMethod getMethod(int colID) {
        int idx = isApplicable(colID);
        if (idx == -1) {
            return MVMethod.INVALID;
        }
        return _mvMethodList[idx];
    }

    /**
     * @param colID column ID to look up
     * @return the value count tracked for the column, or 0 if the column is not
     *         handled by this encoder
     */
    public long getNonMVCount(int colID) {
        int idx = isApplicable(colID);
        return idx == -1 ? 0L : _countList[idx];
    }

    /**
     * @param colID column ID to look up
     * @return the replacement string of the column, or {@code null} if the
     *         column is not handled by this encoder
     */
    public String getReplacement(int colID) {
        int idx = isApplicable(colID);
        return idx == -1 ? null : _replacementList[idx];
    }

    /**
     * Builds the replacement values from {@code in} and then applies them to {@code out}.
     *
     * @param in input frame
     * @param out output matrix whose {@code NaN} cells are replaced
     * @return {@code out}
     */
    @Override
    public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
        build(in);
        return apply(in, out);
    }

    /**
     * Updates the replacement values of mean- and mode-imputed columns from the
     * rows of {@code in}. Statistics accumulate across calls: the mean continues
     * from the previous count and the mode histogram is extended.
     *
     * @param in input frame
     * @throws RuntimeException wrapping any failure during the update
     */
    @Override
    public void build(FrameBlock in) {
        try {
            for (int j = 0; j < _colList.length; ++j) {
                int colID = _colList[j];
                if (_mvMethodList[j] == MVMethod.GLOBAL_MEAN) {
                    // continue the running mean from the number of values seen so far
                    long off = _countList[j];
                    for (int i = 0; i < in.getNumRows(); ++i) {
                        _meanFn.execute2(_meanList[j], UtilFunctions.objectToDouble(in.getSchema()[colID - 1], in.get(i, colID - 1)), off + (long)i + 1L);
                    }
                    _replacementList[j] = String.valueOf(_meanList[j]._sum);
                    _countList[j] += (long)in.getNumRows();
                } else if (_mvMethodList[j] == MVMethod.GLOBAL_MODE) {
                    HashMap<String, Long> hist = _hist.get(colID);
                    if (hist == null) {
                        hist = new HashMap<String, Long>();
                    }
                    for (int i = 0; i < in.getNumRows(); ++i) {
                        // The raw cell must be checked for null BEFORE conversion:
                        // String.valueOf(Object) turns null into the string "null", which
                        // would otherwise be counted as a regular category.
                        Object cell = in.get(i, colID - 1);
                        if (cell == null) {
                            continue;
                        }
                        String key = String.valueOf(cell);
                        if (key.isEmpty()) {
                            continue;
                        }
                        Long val = hist.get(key);
                        hist.put(key, val != null ? val + 1L : 1L);
                    }
                    _hist.put(colID, hist);

                    // the most frequent value becomes the replacement
                    long max = Long.MIN_VALUE;
                    for (Map.Entry<String, Long> e : hist.entrySet()) {
                        if (e.getValue() > max) {
                            _replacementList[j] = e.getKey();
                            max = e.getValue();
                        }
                    }
                }
            }
        }
        catch (Exception ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Replaces every {@code NaN} cell of the imputed columns in {@code out} with
     * the column's replacement value parsed as a double.
     *
     * @param in input frame; only its row count is used
     * @param out matrix that is modified in place
     * @return {@code out}
     */
    @Override
    public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
        for (int i = 0; i < in.getNumRows(); ++i) {
            for (int j = 0; j < _colList.length; ++j) {
                int colID = _colList[j];
                if (Double.isNaN(out.quickGetValue(i, colID - 1))) {
                    out.quickSetValue(i, colID - 1, Double.parseDouble(_replacementList[j]));
                }
            }
        }
        return out;
    }

    /**
     * Stores the replacement value of every imputed column in the column
     * metadata of {@code out}.
     *
     * @param out frame whose column metadata is updated
     * @return {@code out}
     */
    @Override
    public FrameBlock getMetaData(FrameBlock out) {
        for (int j = 0; j < _colList.length; ++j) {
            out.getColumnMetadata(_colList[j] - 1).setMvValue(_replacementList[j]);
        }
        return out;
    }

    /**
     * Initializes the replacement values from the column metadata of {@code meta}.
     * For columns registered via {@link #initRecodeIDList} the stored value is
     * translated through the column's recode map; a {@code null} recode list is
     * treated as "no recoded columns".
     *
     * @param meta frame holding the column metadata and recode maps
     * @throws RuntimeException if a recoded column has no recode entry for its replacement value
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        for (int j = 0; j < _colList.length; ++j) {
            int colID = _colList[j];
            String mvVal = UtilFunctions.unquote(meta.getColumnMetadata(colID - 1).getMvValue());
            if (_rcList != null && _rcList.contains(colID)) {
                Long mvVal2 = meta.getRecodeMap(colID - 1).get(mvVal);
                if (mvVal2 == null) {
                    throw new RuntimeException("Missing recode value for impute value '" + mvVal + "' (colID=" + colID + ").");
                }
                _replacementList[j] = mvVal2.toString();
            } else {
                _replacementList[j] = mvVal;
            }
        }
    }

    /**
     * Registers the IDs of recoded columns, consulted by {@link #initMetaData}.
     *
     * @param rcList IDs of the recoded columns (stored by reference)
     */
    public void initRecodeIDList(List<Integer> rcList) {
        _rcList = rcList;
    }

    /**
     * @param colID column ID to look up
     * @return the value histogram collected for the column by {@link #build},
     *         or {@code null} if none exists
     */
    public HashMap<String, Long> getHistogram(int colID) {
        return _hist == null ? null : _hist.get(colID);
    }

    /**
     * Supported imputation methods. The declaration order is significant: the
     * four-argument constructor maps numeric method codes to constants via
     * {@code MVMethod.values()[code]}.
     */
    public enum MVMethod {
        INVALID,
        GLOBAL_MEAN,
        GLOBAL_MODE,
        CONSTANT;

    }
}

