/*
 * Decompiled with CFR 0.152.
 */
package hex.util;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import Jama.QRDecomposition;
import hex.DataInfo;
import hex.FrameTask;
import hex.Interaction;
import hex.ToEigenVec;
import hex.gam.MatrixFrameUtils.TriDiagonalMatrix;
import hex.gram.Gram;
import hex.util.EigenPair;
import java.util.ArrayList;
import java.util.Arrays;
import jsr166y.ForkJoinTask;
import jsr166y.RecursiveAction;
import org.apache.commons.lang.ArrayUtils;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.Log;

public class LinearAlgebraUtils {
    /** Exact double nearest to 1/3 (same value as the original inlined constant 0.3333333333333333). */
    private static final double ONE_THIRD = 1.0 / 3.0;
    /** Exact double nearest to 1/6 (same value as the original inlined constant 0.16666666666666666). */
    private static final double ONE_SIXTH = 1.0 / 6.0;

    /** Adapter exposing {@link #toEigen(Vec)} through the {@link ToEigenVec} interface. */
    public static ToEigenVec toEigen = new ToEigenVec() {
        @Override
        public Vec toEigenVec(Vec src) {
            return LinearAlgebraUtils.toEigen(src);
        }
    };

    /**
     * Solves {@code L x = b} by forward substitution.
     *
     * @param L square lower-triangular matrix (row-major); entries above the diagonal are not read
     * @param b right-hand side, same length as {@code L}
     * @return a newly allocated solution vector {@code x}
     */
    public static double[] forwardSolve(double[][] L, double[] b) {
        assert L != null && L.length == b.length;
        final double[] res = new double[b.length];
        for (int i = 0; i < b.length; ++i) {
            double acc = b[i];
            for (int j = 0; j < i; ++j) {
                acc -= L[i][j] * res[j];
            }
            res[i] = acc / L[i][i];
        }
        return res;
    }

    /**
     * Returns the element-wise square root of the main diagonal of a square matrix.
     *
     * @param aMat square matrix
     * @return {@code sqrt(aMat[i][i])} for every row {@code i}
     */
    public static double[] sqrtDiag(double[][] aMat) {
        final int matrixSize = aMat.length;
        final double[] answer = new double[matrixSize];
        for (int index = 0; index < matrixSize; ++index) {
            answer[index] = Math.sqrt(aMat[index][index]);
        }
        return answer;
    }

    /**
     * Inverts a symmetric positive-definite matrix from its Cholesky factor.
     *
     * <p>With {@code upperTriag == true}, {@code cholR} is the upper factor R and the result is
     * {@code (Rᵀ R)⁻¹}. Otherwise {@code cholR} is the lower factor L and the result is {@code (L Lᵀ)⁻¹}.
     * Column {@code i} of the inverse is obtained by one forward and one backward solve against
     * the unit vector {@code e_i}. Because the inverse is symmetric, it is stored as row {@code i}.
     *
     * @param cholR      triangular Cholesky factor
     * @param upperTriag whether {@code cholR} is upper (true) or lower (false) triangular
     * @return the inverse matrix as a new array
     */
    public static double[][] chol2Inv(final double[][] cholR, boolean upperTriag) {
        final int matrixSize = cholR.length;
        final double[][] cholL = upperTriag ? water.util.ArrayUtils.transpose(cholR) : cholR;
        final double[][] cholU = upperTriag ? cholR : water.util.ArrayUtils.transpose(cholR);
        final double[][] inverted = new double[matrixSize][];

        // Pass 1: row i <- L^{-1} e_i.
        RecursiveAction[] forward = new RecursiveAction[matrixSize];
        for (int index = 0; index < matrixSize; ++index) {
            final int i = index;
            forward[i] = new RecursiveAction() {
                @Override
                protected void compute() {
                    double[] unit = new double[matrixSize];
                    unit[i] = 1.0;
                    inverted[i] = LinearAlgebraUtils.forwardSolve(cholL, unit);
                }
            };
        }
        ForkJoinTask.invokeAll(forward);

        // Pass 2: row i <- U^{-1} (L^{-1} e_i), i.e. column i of (L U)^{-1}.
        RecursiveAction[] backward = new RecursiveAction[matrixSize];
        for (int index = 0; index < matrixSize; ++index) {
            final int i = index;
            backward[i] = new RecursiveAction() {
                @Override
                protected void compute() {
                    inverted[i] = LinearAlgebraUtils.backwardSolve(cholU, inverted[i], null);
                }
            };
        }
        ForkJoinTask.invokeAll(backward);
        return inverted;
    }

    /**
     * Same as {@link #chol2Inv(double[][], boolean)} with an upper-triangular factor.
     *
     * @param cholR upper-triangular Cholesky factor R
     * @return {@code (Rᵀ R)⁻¹}
     */
    public static double[][] chol2Inv(double[][] cholR) {
        return LinearAlgebraUtils.chol2Inv(cholR, true);
    }

    /**
     * Builds the lower half of a symmetric tri-diagonal matrix from knot spacings.
     *
     * <p>Row {@code i} has length {@code i + 1}. Its diagonal is {@code (hj[i] + hj[i+1]) / 3}, and
     * for {@code i > 0} its sub-diagonal is {@code hj[i] / 6}. All other entries are zero.
     *
     * @param hj spacing values; the result has {@code hj.length - 1} rows (none if {@code hj.length < 2})
     * @return jagged lower-triangular array
     */
    public static double[][] generateTriDiagMatrix(double[] hj) {
        final int matrixSize = Math.max(hj.length - 1, 0);
        final double[][] lowDiag = new double[matrixSize][];
        // Work per row is O(1), so a plain loop beats fork/join fan-out here.
        for (int i = 0; i < matrixSize; ++i) {
            final double[] row = new double[i + 1];
            row[i] = (hj[i] + hj[i + 1]) * ONE_THIRD;
            if (i > 0) {
                row[i - 1] = hj[i] * ONE_SIXTH;
            }
            lowDiag[i] = row;
        }
        return lowDiag;
    }

    /**
     * Generates {@code numBasis} orthonormal vectors orthogonal to the columns of {@code orthMat}.
     *
     * <p>Each vector starts as a seeded Gaussian vector. Its projections onto the columns of
     * {@code orthMat} are removed, and the resulting set is orthonormalized with
     * {@link #applyGramSchmit(double[][])}. This assumes the columns of {@code orthMat} are
     * orthonormal.
     *
     * @param orthMat  matrix whose columns span the space to avoid ({@code vecSize x numOrthVec})
     * @param starT    unused; retained for API compatibility
     * @param numBasis number of complement vectors to generate
     * @param seed     base random seed; vector {@code i} uses {@code seed + i}
     * @return {@code numBasis x vecSize} array whose rows are the complement basis
     */
    public static double[][] generateOrthogonalComplement(double[][] orthMat, double[][] starT, int numBasis, long seed) {
        final int numOrthVec = orthMat[0].length;
        final int vecSize = orthMat.length;
        final double[][] orthMatT = water.util.ArrayUtils.transpose(orthMat);
        final double[][] orthMatCompT = new double[numBasis][vecSize];
        final double[] innerProd = new double[numOrthVec];
        final double[] scaleProd = new double[vecSize];
        for (int index = 0; index < numBasis; ++index) {
            orthMatCompT[index] = water.util.ArrayUtils.gaussianVector(seed + index, orthMatCompT[index]);
            LinearAlgebraUtils.genInnerProduct(orthMatT, orthMatCompT[index], innerProd);
            for (int basisInd = 0; basisInd < numOrthVec; ++basisInd) {
                System.arraycopy(orthMatT[basisInd], 0, scaleProd, 0, vecSize);
                water.util.ArrayUtils.mult(scaleProd, innerProd[basisInd]);
                water.util.ArrayUtils.subtract(orthMatCompT[index], scaleProd, orthMatCompT[index]);
            }
        }
        LinearAlgebraUtils.applyGramSchmit(orthMatCompT);
        return orthMatCompT;
    }

    /**
     * Creates a {@code size x size} identity matrix.
     *
     * @param size matrix dimension
     * @return new identity matrix
     */
    public static double[][] generateIdentityMat(int size) {
        final double[][] identity = new double[size][size];
        for (int index = 0; index < size; ++index) {
            identity[index][index] = 1.0;
        }
        return identity;
    }

    /**
     * Returns the Q factor of the QR decomposition of {@code starT} (computed by Jama).
     *
     * @param starT input matrix
     * @return Q as a 2D array
     */
    public static double[][] generateQR(double[][] starT) {
        Matrix starTMat = new Matrix(starT);
        QRDecomposition starTMatQr = new QRDecomposition(starTMat);
        return starTMatQr.getQ().getArray();
    }

    /**
     * Computes {@code innerProd[i] = <mat[i], vector>} for every row of {@code mat}.
     *
     * @param mat       matrix whose rows are dotted with {@code vector}
     * @param vector    vector of length {@code mat[i].length}
     * @param innerProd output buffer of length at least {@code mat.length}
     */
    public static void genInnerProduct(double[][] mat, double[] vector, double[] innerProd) {
        final int numVec = mat.length;
        for (int index = 0; index < numVec; ++index) {
            innerProd[index] = water.util.ArrayUtils.innerProduct(mat[index], vector);
        }
    }

    /**
     * Orthonormalizes the rows of {@code matT} in place using classical Gram-Schmidt.
     *
     * <p>All projection coefficients for row {@code i} are computed from the row before any
     * subtraction. A zero-norm row produces non-finite values.
     *
     * @param matT matrix whose rows are orthonormalized in place
     */
    public static void applyGramSchmit(double[][] matT) {
        final int numVec = matT.length;
        final int vecSize = matT[0].length;
        final double[] innerProd = new double[numVec];
        final double[] scaleVec = new double[vecSize];
        for (int index = 0; index < numVec; ++index) {
            final double[] current = matT[index];
            // Only the already-processed rows (j < index) are needed for the projections.
            for (int j = 0; j < index; ++j) {
                innerProd[j] = water.util.ArrayUtils.innerProduct(matT[j], current);
            }
            for (int j = 0; j < index; ++j) {
                System.arraycopy(matT[j], 0, scaleVec, 0, vecSize);
                water.util.ArrayUtils.mult(scaleVec, innerProd[j]);
                water.util.ArrayUtils.subtract(current, scaleVec, current);
            }
            double mag = 1.0 / water.util.ArrayUtils.l2norm(current);
            water.util.ArrayUtils.mult(current, mag);
        }
    }

    /**
     * Expands a lower-triangular matrix (possibly jagged) into a full square matrix.
     *
     * <p>Row {@code i} receives {@code cholL[i][0..i]}. Entries above the diagonal are zero, even if
     * {@code cholL[i]} held values there.
     *
     * @param cholL lower-triangular matrix; row {@code i} must have at least {@code i + 1} entries
     * @return new {@code n x n} matrix
     */
    public static double[][] expandLowTrian2Ful(final double[][] cholL) {
        final int numRows = cholL.length;
        final double[][] result = new double[numRows][];
        for (int i = 0; i < numRows; ++i) {
            final double[] row = new double[numRows];
            System.arraycopy(cholL[i], 0, row, 0, i + 1);
            result[i] = row;
        }
        return result;
    }

    /**
     * Multiplies {@code A} by each row of {@code B}: {@code result[i] = A · B[i]}.
     *
     * <p>If the rows of {@code B} are the columns of a matrix M, then the rows of the result are
     * the columns of {@code A·M}. Requires {@code A[0].length == B[i].length}. Each result row has
     * length {@code A.length}, and there are {@code B.length} result rows.
     *
     * @param A left operand, row-major
     * @param B right operand, stored column-by-column as rows
     * @return array of {@code B.length} vectors, each of length {@code A.length}
     */
    public static double[][] matrixMultiply(final double[][] A, final double[][] B) {
        final int resultLen = A.length;
        final int numOut = B.length;
        final double[][] result = new double[numOut][];
        RecursiveAction[] ras = new RecursiveAction[numOut];
        for (int index = 0; index < numOut; ++index) {
            final int i = index;
            ras[i] = new RecursiveAction() {
                @Override
                protected void compute() {
                    double[] tempResult = new double[resultLen];
                    water.util.ArrayUtils.multArrVec(A, B[i], tempResult);
                    result[i] = tempResult;
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        return result;
    }

    /**
     * Multiplies {@code A} by a banded {@link TriDiagonalMatrix} B of shape
     * {@code B._size x (B._size + 2)}, one column of B per task.
     *
     * <p>Column {@code i} of B has non-zeros only at rows {@code i-2} ({@code _third_diag}),
     * {@code i-1} ({@code _second_diag}) and {@code i} ({@code _first_diag}), clipped to
     * {@code [0, _size)}.
     *
     * @param A               left operand; {@code A[r].length} must equal {@code B._size}
     * @param B               banded right operand
     * @param transposeResult if true, return {@code (A·B)} row-major; otherwise column-major
     *                        (entry {@code c} holds column {@code c} of {@code A·B})
     * @return the product
     */
    public static double[][] matrixMultiplyTriagonal(final double[][] A, final TriDiagonalMatrix B, boolean transposeResult) {
        final int arow = A.length;
        final int bcol = B._size + 2;
        final int lastCol = bcol - 1;
        final int secondLastCol = bcol - 2;
        final int kMinus1 = bcol - 3;
        final int kMinus2 = bcol - 4;
        final double[][] result = new double[bcol][];
        RecursiveAction[] ras = new RecursiveAction[bcol];
        for (int index = 0; index < bcol; ++index) {
            final int i = index;
            ras[i] = new RecursiveAction() {
                @Override
                protected void compute() {
                    final double[] bCol = new double[B._size];
                    if (i == 0) {
                        bCol[0] = B._first_diag[0];
                    } else if (i == 1) {
                        bCol[0] = B._second_diag[0];
                        bCol[1] = B._first_diag[1];
                    } else if (i == lastCol) {
                        bCol[kMinus1] = B._third_diag[kMinus1];
                    } else if (i == secondLastCol) {
                        bCol[kMinus2] = B._third_diag[kMinus2];
                        bCol[kMinus1] = B._second_diag[kMinus1];
                    } else {
                        bCol[i - 2] = B._third_diag[i - 2];
                        bCol[i - 1] = B._second_diag[i - 1];
                        bCol[i] = B._first_diag[i];
                    }
                    final double[] tempResult = new double[arow];
                    water.util.ArrayUtils.multArrVec(A, bCol, tempResult);
                    result[i] = tempResult;
                }
            };
        }
        ForkJoinTask.invokeAll(ras);
        return transposeResult ? water.util.ArrayUtils.transpose(result) : result;
    }

    /**
     * Solves {@code U x = b} by backward substitution.
     *
     * @param L   square upper-triangular matrix; entries below the diagonal are not read
     * @param b   right-hand side, same length as {@code L}
     * @param res output buffer; allocated when {@code null}. May not alias {@code b}.
     * @return {@code res} (or the newly allocated buffer) holding {@code x}
     */
    public static double[] backwardSolve(double[][] L, double[] b, double[] res) {
        assert L != null && L.length == L[0].length && L.length == b.length;
        if (res == null) {
            res = new double[b.length];
        }
        final int lastIndex = b.length - 1;
        for (int rowIndex = lastIndex; rowIndex >= 0; --rowIndex) {
            double acc = b[rowIndex];
            for (int colIndex = lastIndex; colIndex > rowIndex; --colIndex) {
                acc -= L[rowIndex][colIndex] * res[colIndex];
            }
            res[rowIndex] = acc / L[rowIndex][rowIndex];
        }
        return res;
    }

    /**
     * Imputes a missing numeric value (when {@code dinfo._imputeMissing}) and applies the
     * DataInfo's shift/scale normalization when both arrays are present.
     */
    private static double modifyNumeric(double x, int col, DataInfo dinfo) {
        double y = Double.isNaN(x) && dinfo._imputeMissing ? dinfo._numNAFill[col] : x;
        if (dinfo._normSub != null && dinfo._normMul != null) {
            y = (y - dinfo._normSub[col]) * dinfo._normMul[col];
        }
        return y;
    }

    /**
     * Expands a raw row (categoricals first, then numerics) into the one-hot layout of {@code dinfo}.
     *
     * <p>{@code tmp} is not cleared; only the indicator and numeric slots are written. The caller
     * is presumably responsible for zeroing it between rows.
     *
     * @param row            raw row values
     * @param dinfo          layout description
     * @param tmp            output buffer sized to the expanded width
     * @param modify_numeric whether to impute/normalize numerics via the DataInfo settings
     * @return {@code tmp}
     */
    public static double[] expandRow(double[] row, DataInfo dinfo, double[] tmp, boolean modify_numeric) {
        for (int col = 0; col < dinfo._cats; ++col) {
            final boolean singleLevel = dinfo._catOffsets[col + 1] - dinfo._catOffsets[col] == 1;
            int cidx;
            if (Double.isNaN(row[col])) {
                if (dinfo._imputeMissing) {
                    cidx = dinfo.catNAFill()[col];
                } else if (!dinfo._catMissing[col]) {
                    continue; // no NA bucket: leave all indicators untouched
                } else {
                    cidx = dinfo._catOffsets[col + 1] - 1; // NA maps to the trailing bucket
                }
            } else {
                cidx = singleLevel ? dinfo.getCategoricalId(col, 0) : dinfo.getCategoricalId(col, (int) row[col]);
            }
            if (cidx < 0) {
                continue;
            }
            // Single-slot columns copy the raw value instead of a 1.0 indicator.
            // NOTE(review): a NaN imputed via catNAFill is still copied as NaN here; confirm intent.
            tmp[cidx] = singleLevel ? row[col] : 1.0;
        }
        int chkCnt = dinfo._cats;
        int expCnt = dinfo.numStart();
        for (int col = 0; col < dinfo._nums; ++col, ++chkCnt, ++expCnt) {
            tmp[expCnt] = modify_numeric ? LinearAlgebraUtils.modifyNumeric(row[chkCnt], col, dinfo) : row[chkCnt];
        }
        return tmp;
    }

    /**
     * Reshapes a row-major 1D array into an {@code m x n} matrix.
     *
     * @param arr source with at least {@code m * n} elements
     * @param m   number of rows
     * @param n   number of columns
     * @return new {@code m x n} array
     */
    public static double[][] reshape1DArray(double[] arr, int m, int n) {
        final double[][] arr2D = new double[m][n];
        for (int i = 0; i < m; ++i) {
            System.arraycopy(arr, i * n, arr2D[i], 0, n);
        }
        return arr2D;
    }

    /**
     * Pairs {@code eigenvalues[i]} with {@code eigenvectors[i]} and sorts by {@link EigenPair}'s
     * natural ordering.
     *
     * @param eigenvalues  eigenvalues
     * @param eigenvectors eigenvectors, one per eigenvalue (row {@code i} pairs with value {@code i})
     * @return sorted pairs
     */
    public static EigenPair[] createSortedEigenpairs(double[] eigenvalues, double[][] eigenvectors) {
        final int count = eigenvalues.length;
        final EigenPair[] eigenPairs = new EigenPair[count];
        for (int i = 0; i < count; ++i) {
            eigenPairs[i] = new EigenPair(eigenvalues[i], eigenvectors[i]);
        }
        Arrays.sort(eigenPairs);
        return eigenPairs;
    }

    /**
     * Like {@link #createSortedEigenpairs} but in reverse order.
     *
     * @param eigenvalues  eigenvalues
     * @param eigenvectors eigenvectors, one per eigenvalue
     * @return pairs sorted in reverse natural order
     */
    public static EigenPair[] createReverseSortedEigenpairs(double[] eigenvalues, double[][] eigenvectors) {
        final EigenPair[] eigenPairs = LinearAlgebraUtils.createSortedEigenpairs(eigenvalues, eigenvectors);
        ArrayUtils.reverse(eigenPairs);
        return eigenPairs;
    }

    /**
     * Extracts the eigenvalues from pairs, preserving order.
     *
     * @param eigenPairs pairs
     * @return new array of eigenvalues
     */
    public static double[] extractEigenvaluesFromEigenpairs(EigenPair[] eigenPairs) {
        final int count = eigenPairs.length;
        final double[] eigenvalues = new double[count];
        for (int i = 0; i < count; ++i) {
            eigenvalues[i] = eigenPairs[i].eigenvalue;
        }
        return eigenvalues;
    }

    /**
     * Extracts the eigenvectors from pairs, preserving order. The vectors are shared, not copied.
     *
     * @param eigenPairs pairs
     * @return new array referencing each pair's eigenvector
     */
    public static double[][] extractEigenvectorsFromEigenpairs(EigenPair[] eigenPairs) {
        final int count = eigenPairs.length;
        final double[][] eigenvectors = new double[count][];
        for (int i = 0; i < count; ++i) {
            eigenvectors[i] = eigenPairs[i].eigenvector;
        }
        return eigenvectors;
    }

    /**
     * In-place Cholesky factorization that reads and writes only the diagonal, first and second
     * sub-diagonals of {@code xx}.
     *
     * <p>The update subtracts {@code xx[row][row-2]} without multiplying by an L entry. The result
     * is therefore the true Cholesky factor only when entries below the first sub-diagonal are zero
     * (a tri-diagonal input). Entries above the diagonal are not modified.
     *
     * @param xx square matrix, overwritten with the lower factor; an empty matrix is a no-op
     */
    public static void choleskySymDiagMat(double[][] xx) {
        final int rowNumber = xx.length;
        if (rowNumber == 0) {
            return;
        }
        xx[0][0] = Math.sqrt(xx[0][0]);
        for (int row = 1; row < rowNumber; ++row) {
            final int lowerDiag = row - 1;
            if (lowerDiag > 0) {
                xx[row][lowerDiag] = (xx[row][lowerDiag] - xx[row][lowerDiag - 1]) / xx[lowerDiag][lowerDiag];
            } else {
                xx[row][lowerDiag] = xx[row][lowerDiag] / xx[lowerDiag][lowerDiag];
            }
            xx[row][row] = Math.sqrt(xx[row][row] - xx[row][lowerDiag] * xx[row][lowerDiag]);
        }
    }

    /**
     * Computes the Cholesky factor of the Gram matrix of {@code yinfo._adaptedFrame}, scaled by
     * {@code sqrt(nobs)}. The scaling presumably undoes an averaged Gram; confirm against Gram.
     *
     * @param jobKey    job key for the Gram task
     * @param yinfo     data description
     * @param transpose if true, return the lower factor L; if false, return {@code Lᵀ}
     *                  (the naming is inverted relative to what is returned)
     * @return scaled Cholesky factor
     */
    public static double[][] computeR(Key<Job> jobKey, DataInfo yinfo, boolean transpose) {
        Gram.GramTask gtsk = new Gram.GramTask(jobKey, yinfo);
        gtsk.doAll(yinfo._adaptedFrame);
        Gram.Cholesky chol = gtsk._gram.cholesky(null);
        double[][] L = chol.getL();
        water.util.ArrayUtils.mult(L, Math.sqrt(gtsk._nobs));
        return transpose ? L : water.util.ArrayUtils.transpose(L);
    }

    /**
     * Computes Q from Y = QR and writes it into the second half of {@code ywfrm}.
     *
     * @param jobKey job key for the Gram task
     * @param yinfo  data description for Y (first half of {@code ywfrm})
     * @param ywfrm  frame {@code [Y, W]}; W is overwritten with Q
     * @param xx     NOT written by this method (kept for API compatibility); the factor computed
     *               here is not returned through it
     * @return sum of squared differences between the new Q and the previous W
     */
    public static double computeQ(Key<Job> jobKey, DataInfo yinfo, Frame ywfrm, double[][] xx) {
        double[][] cholL = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        ForwardSolve qrtsk = new ForwardSolve(yinfo, cholL);
        qrtsk.doAll(ywfrm);
        return qrtsk._sse;
    }

    /**
     * Computes Q from Y = QR and writes it into the second half of {@code ywfrm}.
     *
     * @param jobKey job key for the Gram task
     * @param yinfo  data description for Y (first half of {@code ywfrm})
     * @param ywfrm  frame {@code [Y, W]}; W is overwritten with Q
     * @return the scaled lower Cholesky factor used for the solve (not Q itself)
     */
    public static double[][] computeQ(Key<Job> jobKey, DataInfo yinfo, Frame ywfrm) {
        double[][] cholL = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        ForwardSolve qrtsk = new ForwardSolve(yinfo, cholL);
        qrtsk.doAll(ywfrm);
        return cholL;
    }

    /**
     * Computes Q from Y = QR, overwriting {@code yinfo._adaptedFrame} in place.
     *
     * @param jobKey job key for the Gram task
     * @param yinfo  data description; its adapted frame is overwritten with Q
     * @return the scaled lower Cholesky factor used for the solve
     */
    public static double[][] computeQInPlace(Key<Job> jobKey, DataInfo yinfo) {
        double[][] cholL = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        ForwardSolveInPlace qrtsk = new ForwardSolveInPlace(yinfo, cholL);
        qrtsk.doAll(yinfo._adaptedFrame);
        return cholL;
    }

    /**
     * Counts the columns after one-hot expansion.
     *
     * @param fr                 frame
     * @param useAllFactorLevels if false, each categorical drops one level
     * @return expanded column count
     */
    public static int numColsExp(Frame fr, boolean useAllFactorLevels) {
        final int dropped = useAllFactorLevels ? 0 : 1;
        int cols = 0;
        for (Vec vec : fr.vecs()) {
            cols += vec.isCategorical() && vec.domain() != null ? vec.domain().length - dropped : 1;
        }
        return cols;
    }

    /**
     * Builds a normalized level-covariance matrix from level counts and returns one eigen row.
     *
     * <p>Side effect: {@code diagYY} is multiplied in place by {@code nTot}.
     * NOTE(review): Jama stores eigenvectors as the COLUMNS of V, but this returns ROW
     * {@code maxIndex} of V. This is left unchanged because existing "Eigen" encodings depend on it.
     *
     * @param diagYY per-level diagonal Gram entries (modified in place)
     * @param nTot   number of observations
     * @param nVars  number of variables (scaling divisor)
     * @return row {@code argmax(eigenvalues)} of the eigenvector matrix V
     */
    static double[] multiple(double[] diagYY, int nTot, int nVars) {
        final int ny = diagYY.length;
        for (int i = 0; i < ny; ++i) {
            diagYY[i] *= nTot;
        }
        final double[][] uu = new double[ny][ny];
        for (int i = 0; i < ny; ++i) {
            for (int j = 0; j < ny; ++j) {
                double yyij = i == j ? diagYY[i] : 0.0;
                double val = (yyij - diagYY[i] * diagYY[j] / nTot) / (nVars * Math.sqrt(diagYY[i] * diagYY[j]));
                uu[i][j] = Double.isNaN(val) ? 0.0 : val;
            }
        }
        EigenvalueDecomposition eigen = new EigenvalueDecomposition(new Matrix(uu));
        double[] eigenvalues = eigen.getRealEigenvalues();
        double[][] eigenvectors = eigen.getV().getArray();
        int maxIndex = water.util.ArrayUtils.maxIndex(eigenvalues);
        return eigenvectors[maxIndex];
    }

    /**
     * Computes the eigen-projection coordinates of each level of a categorical vec.
     *
     * <p>Columns with more than 1024 levels are first reduced via {@link Interaction}. Temporary
     * DKV entries are cleaned up even if the computation fails; {@code src} itself is never deleted.
     *
     * @param src categorical vec
     * @return one coordinate per (possibly reduced) level
     */
    public static double[] toEigenArray(Vec src) {
        final Key<Frame> source = Key.make();
        final Key<Frame> dest = Key.make();
        final int maxLevels = 1024;
        final Frame train = new Frame(source, new String[]{"enum"}, new Vec[]{src});
        Frame reduced = null;
        boolean published = false;
        try {
            if (src.cardinality() > maxLevels) {
                DKV.put(train);
                published = true;
                Log.info("Reducing the cardinality of a categorical column with " + src.cardinality() + " levels to " + maxLevels);
                reduced = (Frame) Interaction.getInteraction(train._key, train.names(), maxLevels).execImpl(dest).get();
            }
            final Frame input = reduced != null ? reduced : train;
            DataInfo dinfo = new DataInfo(input, null, 0, true, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, false, true, false, false, false, false, false);
            DKV.put(dinfo);
            final double[] rounded;
            final long nobs;
            try {
                Gram.GramTask gtsk = (Gram.GramTask) new Gram.GramTask(null, dinfo).doAll(dinfo._adaptedFrame);
                // Round to float precision so results are reproducible across runs.
                rounded = new double[gtsk._gram._diag.length];
                for (int i = 0; i < rounded.length; ++i) {
                    rounded[i] = (float) gtsk._gram._diag[i];
                }
                nobs = gtsk._nobs;
            } finally {
                dinfo.remove();
            }
            return LinearAlgebraUtils.multiple(rounded, (int) nobs, 1);
        } finally {
            if (reduced != null) {
                reduced.remove();
            }
            if (published) {
                DKV.remove(source);
            }
        }
    }

    /**
     * Replaces each level of a categorical vec by its eigen-projection coordinate.
     *
     * @param src categorical vec
     * @return new numeric vec (values rounded to float precision, NAs preserved)
     */
    public static Vec toEigen(Vec src) {
        final Key<Frame> source = Key.make();
        final Key<Frame> dest = Key.make();
        final int maxLevels = 1024;
        final Frame train = new Frame(source, new String[]{"enum"}, new Vec[]{src});
        Frame reduced = null;
        boolean published = false;
        try {
            if (src.cardinality() > maxLevels) {
                DKV.put(train);
                published = true;
                Log.info("Reducing the cardinality of a categorical column with " + src.cardinality() + " levels to " + maxLevels);
                reduced = (Frame) Interaction.getInteraction(train._key, train.names(), maxLevels).execImpl(dest).get();
            }
            final Frame input = reduced != null ? reduced : train;
            return new ProjectOntoEigenVector(LinearAlgebraUtils.toEigenArray(src)).doAll(1, Vec.T_NUM, input).outputFrame().anyVec();
        } finally {
            if (reduced != null) {
                reduced.remove();
            }
            if (published) {
                DKV.remove(source);
            }
        }
    }

    /**
     * Concatenates the eigen projections of every categorical column of {@code _origTrain}.
     *
     * @param _origTrain original training frame
     * @param _train     adapted training frame (compared by reference)
     * @param expensive  whether the projection should be computed at all
     * @return concatenated projections, or {@code null} when not applicable
     */
    public static double[] toEigenProjectionArray(Frame _origTrain, Frame _train, boolean expensive) {
        if (!expensive || _origTrain == null || _origTrain == _train) {
            return null;
        }
        final ArrayList<double[]> perColumn = new ArrayList<double[]>();
        int total = 0;
        for (int i = 0; i < _origTrain.numCols(); ++i) {
            Vec currentCol = _origTrain.vec(i);
            if (!currentCol.isCategorical()) {
                continue;
            }
            double[] projection = LinearAlgebraUtils.toEigenArray(currentCol);
            perColumn.add(projection);
            total += projection.length;
        }
        final double[] out = new double[total];
        int offset = 0;
        for (double[] projection : perColumn) {
            System.arraycopy(projection, 0, out, offset, projection.length);
            offset += projection.length;
        }
        return out;
    }

    /**
     * Formats a rectangular matrix for logging: 4 decimals, tab-separated, one row per line.
     * Positive values get a leading space so that they line up with negatives.
     *
     * @param matrix matrix to format
     * @return formatted text, {@code ""} for an empty matrix, or {@code "Stacked matrix!"} if rows differ in length
     */
    public static String getMatrixInString(double[][] matrix) {
        final int dimX = matrix.length;
        if (dimX <= 0) {
            return "";
        }
        final int dimY = matrix[0].length;
        for (int x = 1; x < dimX; ++x) {
            if (matrix[x].length != dimY) {
                return "Stacked matrix!";
            }
        }
        final StringBuilder stringOfMatrix = new StringBuilder();
        for (int x = 0; x < dimX; ++x) {
            for (int y = 0; y < dimY; ++y) {
                if (matrix[x][y] > 0.0) {
                    stringOfMatrix.append(' ');
                }
                stringOfMatrix.append(String.format("%.4f\t", matrix[x][y]));
            }
            stringOfMatrix.append('\n');
        }
        return stringOfMatrix.toString();
    }

    /** Maps each categorical level index to its coordinate (rounded to float precision); NAs stay NA. */
    static class ProjectOntoEigenVector
    extends MRTask<ProjectOntoEigenVector> {
        final double[] _yCoord;

        ProjectOntoEigenVector(double[] yCoord) {
            this._yCoord = yCoord;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] nc) {
            for (int i = 0; i < cs[0]._len; ++i) {
                if (cs[0].isNA(i)) {
                    nc[0].addNA();
                } else {
                    int which = (int) cs[0].at8(i);
                    nc[0].addNum((double) (float) this._yCoord[which]);
                }
            }
        }
    }

    /** Overwrites each good row {@code a} with {@code L⁻¹ expand(a)} (in-place forward solve). */
    public static class ForwardSolveInPlace
    extends MRTask<ForwardSolveInPlace> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;

        public ForwardSolveInPlace(DataInfo ainfo, double[][] L) {
            assert L != null && L.length == L[0].length && L.length == ainfo._adaptedFrame.numCols();
            this._ainfo = ainfo;
            this._ncols = ainfo._adaptedFrame.numCols();
            this._L = L;
        }

        @Override
        public void map(Chunk[] cs) {
            assert this._ncols == cs.length;
            Chunk[] achks = new Chunk[this._ncols];
            System.arraycopy(cs, 0, achks, 0, this._ncols);
            for (int row = 0; row < cs[0]._len; ++row) {
                DataInfo.Row arow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(achks, row, arow);
                if (arow.isBad()) {
                    continue;
                }
                double[] qrow = LinearAlgebraUtils.forwardSolve(this._L, arow.expandCats());
                assert qrow.length == this._ncols;
                for (int d = 0; d < this._ncols; ++d) {
                    cs[d].set(row, qrow[d]);
                }
            }
        }
    }

    /**
     * On a frame {@code [A, W]}, writes {@code L⁻¹ expand(a)} into W for each good row and
     * accumulates {@code _sse = Σ (new - old)²} over all chunks.
     */
    public static class ForwardSolve
    extends MRTask<ForwardSolve> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;
        public double _sse;

        public ForwardSolve(DataInfo ainfo, double[][] L) {
            assert L != null && L.length == L[0].length && L.length == ainfo._adaptedFrame.numCols();
            this._ainfo = ainfo;
            this._ncols = ainfo._adaptedFrame.numCols();
            this._L = L;
            this._sse = 0.0;
        }

        @Override
        public void map(Chunk[] cs) {
            assert 2 * this._ncols == cs.length;
            Chunk[] achks = new Chunk[this._ncols];
            System.arraycopy(cs, 0, achks, 0, this._ncols);
            for (int row = 0; row < cs[0]._len; ++row) {
                DataInfo.Row arow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(achks, row, arow);
                if (arow.isBad()) {
                    continue;
                }
                double[] qrow = LinearAlgebraUtils.forwardSolve(this._L, arow.expandCats());
                int i = 0;
                for (int d = this._ncols; d < 2 * this._ncols; ++d, ++i) {
                    double diff = qrow[i] - cs[d].atd(row);
                    this._sse += diff * diff;
                    cs[d].set(row, qrow[i]);
                }
                assert i == qrow.length;
            }
        }

        /** Combines per-chunk errors; without this, _sse reflected only part of the frame. */
        @Override
        public void reduce(ForwardSolve other) {
            this._sse += other._sse;
        }
    }

    /**
     * Computes {@code Aᵀ Q} over a frame {@code [A, Q]}, with A one-hot expanded per {@code _ainfo}.
     * The result is {@code _atq} ({@code ncolExp x ncolQ}).
     */
    public static class SMulTask
    extends MRTask<SMulTask> {
        final DataInfo _ainfo;
        final int _ncolA;
        final int _ncolExp;
        final int _ncolQ;
        public double[][] _atq;

        public SMulTask(DataInfo ainfo, int ncolQ) {
            this(ainfo, ncolQ, LinearAlgebraUtils.numColsExp(ainfo._adaptedFrame, true));
        }

        public SMulTask(DataInfo ainfo, int ncolQ, int ncolExp) {
            this._ainfo = ainfo;
            this._ncolA = ainfo._adaptedFrame.numCols();
            this._ncolExp = ncolExp;
            this._ncolQ = ncolQ;
        }

        @Override
        public void map(Chunk[] cs) {
            assert this._ncolA + this._ncolQ == cs.length;
            this._atq = new double[this._ncolExp][this._ncolQ];
            for (int k = this._ncolA; k < this._ncolA + this._ncolQ; ++k) {
                final int qCol = k - this._ncolA;
                // Categorical columns: add q into the slot of the observed level.
                for (int p = 0; p < this._ainfo._cats; ++p) {
                    for (int row = 0; row < cs[0]._len; ++row) {
                        if (cs[p].isNA(row) && this._ainfo._skipMissing) {
                            continue;
                        }
                        double q = cs[k].atd(row);
                        double a = cs[p].atd(row);
                        int cidx;
                        if (Double.isNaN(a)) {
                            if (this._ainfo._imputeMissing) {
                                cidx = this._ainfo.catNAFill()[p];
                            } else if (!this._ainfo._catMissing[p]) {
                                continue;
                            } else {
                                cidx = this._ainfo._catOffsets[p + 1] - 1;
                            }
                        } else {
                            cidx = this._ainfo.getCategoricalId(p, (int) a);
                        }
                        if (cidx < 0) {
                            continue;
                        }
                        this._atq[cidx][qCol] += q;
                    }
                }
                // Numeric columns: accumulate q * normalized(a).
                int pnum = 0;
                int pexp = this._ainfo.numStart();
                for (int p = this._ainfo._cats; p < this._ncolA; ++p, ++pexp, ++pnum) {
                    for (int row = 0; row < cs[0]._len; ++row) {
                        if (cs[p].isNA(row) && this._ainfo._skipMissing) {
                            continue;
                        }
                        double q = cs[k].atd(row);
                        double a = LinearAlgebraUtils.modifyNumeric(cs[p].atd(row), pnum, this._ainfo);
                        this._atq[pexp][qCol] += q * a;
                    }
                }
                assert pexp == this._atq.length;
            }
        }

        /** Sums partial products; tolerates a side that never ran {@code map}. */
        @Override
        public void reduce(SMulTask other) {
            if (other._atq == null) {
                return;
            }
            if (this._atq == null) {
                this._atq = other._atq;
                return;
            }
            water.util.ArrayUtils.add(this._atq, other._atq);
        }
    }

    /**
     * On a frame {@code [X, B]}, writes {@code <xrow, yt[j]>} into the j-th B column for each good
     * row of X.
     */
    public static class BMulInPlaceTask
    extends MRTask<BMulInPlaceTask> {
        final DataInfo _xinfo;
        final double[][] _yt;
        final int _ncolX;
        public boolean _originalImplementation = true;

        public BMulInPlaceTask(DataInfo xinfo, double[][] yt, int nColsExp) {
            this(xinfo, yt, nColsExp, true);
        }

        /**
         * @param originalWay if false, subtract {@code ps[ncolX - 1]} from each inner product.
         *                    NOTE(review): that index comes from the adapted-frame column count,
         *                    not {@code ps.length}; the two coincide only without categorical
         *                    expansion. Confirm.
         */
        public BMulInPlaceTask(DataInfo xinfo, double[][] yt, int nColsExp, boolean originalWay) {
            assert yt != null && yt[0].length == nColsExp;
            this._xinfo = xinfo;
            this._ncolX = xinfo._adaptedFrame.numCols();
            this._yt = yt;
            this._originalImplementation = originalWay;
        }

        @Override
        public void map(Chunk[] cs) {
            assert cs.length == this._ncolX + this._yt.length;
            final int lastColInd = this._ncolX - 1;
            Chunk[] xchk = new Chunk[this._ncolX];
            DataInfo.Row xrow = this._xinfo.newDenseRow();
            System.arraycopy(cs, 0, xchk, 0, this._ncolX);
            for (int row = 0; row < cs[0]._len; ++row) {
                this._xinfo.extractDenseRow(xchk, row, xrow);
                if (xrow.isBad()) {
                    continue;
                }
                int bidx = this._ncolX;
                for (double[] ps : this._yt) {
                    double sum = this._originalImplementation ? xrow.innerProduct(ps) : xrow.innerProduct(ps) - ps[lastColInd];
                    cs[bidx++].set(row, sum);
                }
                assert bidx == cs.length;
            }
        }
    }

    /**
     * On a frame {@code [A, R]}, adds {@code A · Y} into R, where Y is the distributed frame
     * {@code _y}. Row {@code r} of Y pairs with column {@code r} of A.
     */
    public static class BMulTaskMatrices
    extends MRTask<BMulTaskMatrices> {
        final Frame _y;
        final int _nyChunks;
        final int _yColNum;

        public BMulTaskMatrices(Frame y) {
            this._y = y;
            this._nyChunks = this._y.anyVec().nChunks();
            this._yColNum = this._y.numCols();
        }

        /** Accumulates the contribution of one row-block of Y into the result columns of X. */
        private void mulResultPerYChunk(Chunk[] xChunk, Chunk[] yChunk) {
            final int xChunkLen = xChunk[0].len();
            final int yColLen = yChunk.length;
            final int yChunkLen = yChunk[0].len();
            final int resultColOffset = xChunk.length - yColLen;
            final int xChunkColOffset = (int) yChunk[0].start();
            for (int colIndex = 0; colIndex < yColLen; ++colIndex) {
                final int resultColIndex = colIndex + resultColOffset;
                for (int rowIndex = 0; rowIndex < xChunkLen; ++rowIndex) {
                    double origResult = xChunk[resultColIndex].atd(rowIndex);
                    for (int interIndex = 0; interIndex < yChunkLen; ++interIndex) {
                        origResult += xChunk[interIndex + xChunkColOffset].atd(rowIndex) * yChunk[colIndex].atd(interIndex);
                    }
                    xChunk[resultColIndex].set(rowIndex, origResult);
                }
            }
        }

        @Override
        public void map(Chunk[] xChunk) {
            Chunk[] ychunk = new Chunk[this._yColNum];
            for (int ychunkInd = 0; ychunkInd < this._nyChunks; ++ychunkInd) {
                for (int chkIndex = 0; chkIndex < this._yColNum; ++chkIndex) {
                    ychunk[chkIndex] = this._y.vec(chkIndex).chunkForChunkIdx(ychunkInd);
                }
                this.mulResultPerYChunk(xChunk, ychunk);
            }
        }
    }

    /** Emits {@code <row, yt[p]>} as output column {@code p} for every row. */
    public static class BMulTask
    extends FrameTask<BMulTask> {
        final double[][] _yt;

        public BMulTask(Key<Job> jobKey, DataInfo dinfo, double[][] yt) {
            super(jobKey, dinfo);
            this._yt = yt;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
            for (int p = 0; p < this._yt.length; ++p) {
                outputs[p].addNum(row.innerProduct(this._yt[p]));
            }
        }
    }

    /** Copies the second half of the columns over the first half, row by row. */
    public static class CopyQtoQMatrix
    extends MRTask<CopyQtoQMatrix> {
        @Override
        public void map(Chunk[] cs) {
            final int halfColumn = cs.length / 2;
            final int totRows = cs[0].len();
            for (int rowIndex = 0; rowIndex < totRows; ++rowIndex) {
                for (int colIndex = 0; colIndex < halfColumn; ++colIndex) {
                    cs[colIndex].set(rowIndex, cs[colIndex + halfColumn].atd(rowIndex));
                }
            }
        }
    }

    /**
     * Finds the smallest global row index whose value in column {@code _colIndex} equals
     * {@code _maxValue}; {@code _maxIndex} stays -1 if no row matches.
     */
    public static class FindMaxIndex
    extends MRTask<FindMaxIndex> {
        public long _maxIndex = -1L;
        int _colIndex;
        double _maxValue;

        public FindMaxIndex(int colOfInterest, double maxValue) {
            this._colIndex = colOfInterest;
            this._maxValue = maxValue;
        }

        @Override
        public void map(Chunk[] cs) {
            final int rowLen = cs[0].len();
            final long startRowIndex = cs[0].start();
            for (int rowIndex = 0; rowIndex < rowLen; ++rowIndex) {
                if (cs[this._colIndex].atd(rowIndex) == this._maxValue) {
                    // First match in the chunk, consistent with reduce keeping the smallest index.
                    this._maxIndex = startRowIndex + rowIndex;
                    return;
                }
            }
        }

        /** Keeps the smallest found index; a side that found nothing (-1) never overrides a hit. */
        @Override
        public void reduce(FindMaxIndex other) {
            if (other._maxIndex >= 0L && (this._maxIndex < 0L || other._maxIndex < this._maxIndex)) {
                this._maxIndex = other._maxIndex;
            }
        }
    }
}

