/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.rewrite.HopRewriteRule;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.utils.Explain;

/**
 * Hop rewrite that reorders chains of matrix multiplications ({@code A %*% B %*% C ...})
 * into the cheapest parenthesization found by the matrix-chain-order dynamic program.
 * <p>
 * Before a chain is collected, direct inputs of the chain root of the form
 * {@code t(A %*% B)} (where {@code A} and {@code B} are transient or persistent reads)
 * are rewritten into {@code t(B) %*% t(A)}, so that the inner product can be absorbed
 * into the surrounding chain.
 */
public class RewriteMatrixMultChainOptimizationTranspose
extends HopRewriteRule {
    /** Enables the {@code t(A %*% B) -> t(B) %*% t(A)} push-down before chain collection. */
    private static final boolean PUSH_DOWN_TRANSPOSE = true;

    /**
     * Optimizes all matrix-multiplication chains reachable from the given roots.
     *
     * @param roots DAG roots; may be {@code null}
     * @param state rewrite status, forwarded to the chain optimization
     * @return the same {@code roots} list (the DAGs are rewritten in place),
     *         or {@code null} if {@code roots} was {@code null}
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (Hop h : roots) {
            rule_OptimizeMMChains(h, state);
        }
        return roots;
    }

    /**
     * Optimizes all matrix-multiplication chains reachable from a single root.
     *
     * @param root  DAG root; may be {@code null}
     * @param state rewrite status, forwarded to the chain optimization
     * @return the same {@code root} (rewritten in place), or {@code null} if it was {@code null}
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        rule_OptimizeMMChains(root, state);
        return root;
    }

    /**
     * Depth-first traversal over unvisited hops. Each unvisited matrix multiply whose
     * {@link AggBinaryOp#hasLeftPMInput()} is false is treated as the root of a chain and
     * optimized before the traversal descends into its (possibly relinked) inputs.
     */
    private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state) {
        if (hop.isVisited()) {
            return;
        }
        if (HopRewriteUtils.isMatrixMultiply(hop) && !((AggBinaryOp) hop).hasLeftPMInput()) {
            prepAndOptimizeMMChain(hop, state);
        }
        for (Hop hi : hop.getInput()) {
            rule_OptimizeMMChains(hi, state);
        }
        hop.setVisited();
    }

    /**
     * Collects the chain of operands rooted at {@code hop} and, if it has more than two
     * operands, re-parenthesizes it.
     * <p>
     * Starting from the root's inputs, every unvisited matrix multiply that is consumed
     * exactly once is expanded into its two inputs. Expanded multiplies are recorded in
     * {@code mmOperators} for reuse when relinking. Collection stops at the first
     * multiply that is shared with other consumers. Every operand that is passed over is
     * marked visited.
     */
    private void prepAndOptimizeMMChain(Hop hop, ProgramRewriteStatus state) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("MM Chain Optimization for HOP: (" + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ", " + hop.getName() + ")");
        }
        ArrayList<Hop> mmOperators = new ArrayList<Hop>();
        mmOperators.add(hop);
        ArrayList<Hop> mmChain = new ArrayList<Hop>(hop.getInput());
        if (PUSH_DOWN_TRANSPOSE) {
            checkChainForTransposeAndRewrite(mmChain, hop);
        }
        int mmChainIndex = 0;
        while (mmChainIndex < mmChain.size()) {
            Hop h = mmChain.get(mmChainIndex);
            boolean expandable = false;
            if (HopRewriteUtils.isMatrixMultiply(h) && !h.isVisited()) {
                // A product may only be inlined into this chain if its single consumer uses it
                // once. An empty parent list (inconsistent DAG) is treated as non-expandable
                // instead of failing with an IndexOutOfBoundsException.
                List<Hop> parents = h.getParent();
                expandable = parents.size() == 1 && inputCount(parents.get(0), h) <= 1;
                if (!expandable) {
                    break;
                }
            }
            h.setVisited();
            if (!expandable) {
                ++mmChainIndex;
                continue;
            }
            List<Hop> tempList = h.getInput();
            if (tempList.size() != 2) {
                throw new HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
            }
            // Replace the product by its two operands in place; the index stays put so that
            // the left operand is examined for further expansion next.
            mmOperators.add(h);
            mmChain.set(mmChainIndex, tempList.get(0));
            mmChain.add(mmChainIndex + 1, tempList.get(1));
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Identified MM Chain: ");
            for (Hop h : mmChain) {
                logTraceHop(h, 1);
            }
        }
        // A chain of two operands has only one possible parenthesization.
        if (mmChain.size() > 2) {
            optimizeMMChain(hop, mmChain, mmOperators, state);
        }
    }

    /**
     * Re-parenthesizes a collected chain if all operand dimensions are known; otherwise
     * the DAG is left untouched.
     *
     * @param hop         chain root (used for error locations and as the new top operator)
     * @param mmChain     chain operands in left-to-right order
     * @param mmOperators multiply operators of the chain; index 0 is the root
     * @param state       rewrite status (not used here)
     */
    protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
        double[] dimsArray = new double[mmChain.size() + 1];
        boolean dimsKnown = getDimsArray(hop, mmChain, dimsArray);
        if (dimsKnown) {
            clearLinksWithinChain(hop, mmOperators);
            int size = mmChain.size();
            int[][] split = mmChainDP(dimsArray, size);
            LOG.trace("Optimal MM Chain: ");
            mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, new MutableInt(1), split, 1);
        }
    }

    /**
     * Matrix-chain-order dynamic program. Operand {@code k} has dimensions
     * {@code dimArray[k] x dimArray[k+1]}, and the cost of a single product is
     * {@code rows * common * cols}.
     *
     * @param dimArray operand dimensions, at least {@code size + 1} entries
     * @param size     number of operands
     * @return {@code split[i][j] = k} means operands {@code i..j} are best computed as
     *         {@code (i..k) %*% (k+1..j)}; entries with {@code i >= j} are {@code -1}
     */
    private static int[][] mmChainDP(double[] dimArray, int size) {
        // Java zero-initializes dpMatrix; only split needs an explicit "no split" marker.
        double[][] dpMatrix = new double[size][size];
        int[][] split = new int[size][size];
        for (int i = 0; i < size; ++i) {
            Arrays.fill(split[i], -1);
        }
        for (int l = 2; l <= size; ++l) {
            for (int i = 0; i < size - l + 1; ++i) {
                int j = i + l - 1;
                dpMatrix[i][j] = Double.MAX_VALUE;
                for (int k = i; k <= j - 1; ++k) {
                    double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] + dimArray[i] * dimArray[k + 1] * dimArray[j + 1];
                    if (cost < dpMatrix[i][j]) {
                        dpMatrix[i][j] = cost;
                        split[i][j] = k;
                    }
                }
                if (LOG.isTraceEnabled()) {
                    LOG.trace("mmchainopt [i=" + (i + 1) + ",j=" + (j + 1) + "]: costs = " + dpMatrix[i][j] + ", split = " + (split[i][j] + 1));
                }
            }
        }
        return split;
    }

    /**
     * Recursively rebuilds the sub-chain {@code i..j} under operator {@code h} according
     * to {@code split}. Single operands are linked directly. Each nested product reuses
     * the next unused operator from {@code mmOperators}, where {@code opIndex} is the
     * next free index. Each rebuilt operator's size information is refreshed bottom-up.
     */
    protected final void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, MutableInt opIndex, int[][] split, int level) {
        if (i == j) {
            logTraceHop(h, level);
            return;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + "(");
        }
        int k = split[i][j];
        if (i == k) {
            h.getInput().add(mmChain.get(i));
            mmChain.get(i).getParent().add(h);
        } else {
            int ix = opIndex.getValue();
            opIndex.increment();
            h.getInput().add(mmOperators.get(ix));
            mmOperators.get(ix).getParent().add(h);
        }
        if (k + 1 == j) {
            h.getInput().add(mmChain.get(j));
            mmChain.get(j).getParent().add(h);
        } else {
            int ix = opIndex.getValue();
            opIndex.increment();
            h.getInput().add(mmOperators.get(ix));
            mmOperators.get(ix).getParent().add(h);
        }
        mmChainRelinkHops(h.getInput(0), i, k, mmChain, mmOperators, opIndex, split, level + 1);
        mmChainRelinkHops(h.getInput(1), k + 1, j, mmChain, mmOperators, opIndex, split, level + 1);
        h.refreshSizeInformation();
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + ")");
        }
    }

    /**
     * Detaches every chain operator from its two inputs, in both directions.
     *
     * @throws HopsException if an operator does not have exactly two inputs, or if a
     *                       non-root operator (index &gt; 0) has more than one parent
     */
    protected static void clearLinksWithinChain(Hop hop, ArrayList<Hop> operators) {
        for (int i = 0; i < operators.size(); ++i) {
            Hop op = operators.get(i);
            if (op.getInput().size() != 2 || i > 0 && op.getParent().size() > 1) {
                throw new HopsException(hop.printErrorLocation() + "Unexpected error while applying optimization on matrix-mult chain. \n");
            }
            Hop input1 = op.getInput(0);
            Hop input2 = op.getInput(1);
            op.getInput().clear();
            input1.getParent().remove(op);
            input2.getParent().remove(op);
        }
    }

    /**
     * Fills {@code dimsArray} so that operand {@code i} is {@code dimsArray[i] x dimsArray[i+1]}.
     *
     * @return {@code false} (leaving {@code dimsArray} unfilled) if any operand has a
     *         non-positive dimension, {@code true} otherwise
     * @throws HopsException if adjacent operands have incompatible dimensions
     */
    protected static boolean getDimsArray(Hop hop, ArrayList<Hop> chain, double[] dimsArray) {
        boolean dimsKnown = true;
        for (Hop value : chain) {
            if (value.getDim1() <= 0L || value.getDim2() <= 0L) {
                dimsKnown = false;
            }
        }
        if (dimsKnown) {
            for (int i = 0; i < chain.size(); ++i) {
                if (i == 0) {
                    dimsArray[i] = chain.get(i).getDim1();
                    if (dimsArray[i] <= 0.0) {
                        throw new HopsException(hop.printErrorLocation() + "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i]);
                    }
                } else if (chain.get(i - 1).getDim2() != chain.get(i).getDim1()) {
                    throw new HopsException(hop.printErrorLocation() + "Hops::optimizeMMChain() : Matrix Dimension Mismatch: " + chain.get(i - 1).getDim2() + " != " + chain.get(i).getDim1());
                }
                dimsArray[i + 1] = chain.get(i).getDim2();
                if (dimsArray[i + 1] <= 0.0) {
                    throw new HopsException(hop.printErrorLocation() + "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]);
                }
            }
        }
        return dimsKnown;
    }

    /** Number of times {@code h} occurs among the inputs of {@code p}. */
    private static int inputCount(Hop p, Hop h) {
        return CollectionUtils.cardinality(h, p.getInput());
    }

    /** Logs name, class, id and dimensions of {@code hop} at trace level, indented by {@code level}. */
    private static void logTraceHop(Hop hop, int level) {
        if (LOG.isTraceEnabled()) {
            String offset = Explain.getIdentation(level);
            LOG.trace(offset + "Hop " + hop.getName() + "(" + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ") " + hop.getDim1() + "x" + hop.getDim2());
        }
    }

    /**
     * Rewrites {@code t(A %*% B)} into {@code t(B) %*% t(A)}, reusing the existing hops.
     * The transpose hop becomes {@code t(A)}, a clone of it becomes {@code t(B)}, and the
     * product hop is returned as the new top of the expression. The caller must relink
     * the returned hop to the transpose's former parent.
     *
     * @throws HopsException if the transpose hop cannot be cloned
     */
    private Hop rewriteChainOnTransposeOperator(Hop transposeHop) {
        Hop matrixMultHop = transposeHop.getInput(0);
        Hop firstMatrix = matrixMultHop.getInput(0);
        Hop secondMatrix = matrixMultHop.getInput(1);
        Hop secondTransposeHop;
        try {
            secondTransposeHop = (Hop) transposeHop.clone();
        }
        catch (CloneNotSupportedException ex) {
            throw new HopsException("Error on cloning transpose operator: " + ex.getMessage(), ex);
        }
        // Re-point only the product's entry in each read's parent list. The reads may have
        // other consumers whose back-links must survive, and A and B may be the same hop.
        replaceParentReference(firstMatrix, matrixMultHop, transposeHop);
        replaceParentReference(secondMatrix, matrixMultHop, secondTransposeHop);
        updateParentOfHop(transposeHop, matrixMultHop);
        updateParentOfHop(secondTransposeHop, matrixMultHop);
        // Transposes first, so the product derives its size from refreshed inputs.
        setInputsAndRefresh(transposeHop, firstMatrix);
        setInputsAndRefresh(secondTransposeHop, secondMatrix);
        setInputsAndRefresh(matrixMultHop, secondTransposeHop, transposeHop);
        return matrixMultHop;
    }

    /**
     * For each operand of the chain root that is a pushable {@code t(A %*% B)}, rewrites it
     * to {@code t(B) %*% t(A)} and links the resulting product directly under
     * {@code parentOfChain}. Both {@code mmChain} and the parent's input list are updated.
     */
    private void checkChainForTransposeAndRewrite(ArrayList<Hop> mmChain, Hop parentOfChain) {
        for (int mmChainIndex = 0; mmChainIndex < mmChain.size(); ++mmChainIndex) {
            Hop currentChainHop = mmChain.get(mmChainIndex);
            if (!isPushableTranspose(currentChainHop)) {
                continue;
            }
            int indexInParentInput = parentOfChain.getInput().indexOf(currentChainHop);
            Hop matrixMultHop = rewriteChainOnTransposeOperator(currentChainHop);
            updateParentOfHop(matrixMultHop, parentOfChain);
            parentOfChain.getInput().set(indexInParentInput, matrixMultHop);
            mmChain.set(mmChainIndex, matrixMultHop);
        }
    }

    /**
     * Checks whether {@code hop} is an unvisited transpose with a single consumer whose
     * only input is a matrix multiply of two reads, used nowhere else. The single-consumer
     * requirement matters because the transpose hop is rewritten in place.
     */
    private boolean isPushableTranspose(Hop hop) {
        if (!HopRewriteUtils.isReorg(hop, Types.ReOrgOp.TRANS) || hop.isVisited()
            || hop.getInput().size() != 1 || hop.getParent().size() != 1) {
            return false;
        }
        Hop child = hop.getInput(0);
        return HopRewriteUtils.isMatrixMultiply(child)
            && hasOnlyTwoReadsAsInput(child)
            && child.getParent().size() == 1;
    }

    /** Replaces the whole parent list of {@code hopToUpdate} with {@code parentToSet}. */
    private void updateParentOfHop(Hop hopToUpdate, Hop parentToSet) {
        hopToUpdate.getParent().clear();
        hopToUpdate.getParent().add(parentToSet);
    }

    /**
     * Replaces one occurrence of {@code oldParent} in the parent list of {@code child} with
     * {@code newParent}. If {@code oldParent} is absent, {@code newParent} is appended.
     */
    private static void replaceParentReference(Hop child, Hop oldParent, Hop newParent) {
        List<Hop> parents = child.getParent();
        int pos = parents.indexOf(oldParent);
        if (pos >= 0) {
            parents.set(pos, newParent);
        } else {
            parents.add(newParent);
        }
    }

    /** Sets the inputs of {@code hop} to exactly {@code inputs} and refreshes its size information. */
    private static void setInputsAndRefresh(Hop hop, Hop... inputs) {
        hop.getInput().clear();
        hop.getInput().addAll(Arrays.asList(inputs));
        hop.refreshSizeInformation();
    }

    /** Checks whether {@code transposeOperatorChild} has exactly two inputs, both transient or persistent reads. */
    private boolean hasOnlyTwoReadsAsInput(Hop transposeOperatorChild) {
        if (transposeOperatorChild.getInput().size() != 2) {
            return false;
        }
        for (Hop hop : transposeOperatorChild.getInput()) {
            if (!HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD, Types.OpOpData.PERSISTENTREAD)) {
                return false;
            }
        }
        return true;
    }
}

