/*
 * Decompiled with CFR 0.152.
 */
package hex.util;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import hex.DataInfo;
import hex.FrameTask;
import hex.gram.Gram;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

public class LinearAlgebraUtils {
    public static final double[] forwardSolve(double[][] L, double[] b) {
        assert (L != null && L.length == L[0].length && L.length == b.length);
        double[] res = new double[b.length];
        for (int i = 0; i < b.length; ++i) {
            res[i] = b[i];
            for (int j = 0; j < i; ++j) {
                int n = i;
                res[n] = res[n] - L[i][j] * res[j];
            }
            int n = i;
            res[n] = res[n] / L[i][i];
        }
        return res;
    }

    public static final double[] backwardSolve(double[][] U, double[] b) {
        assert (U != null && U.length == U[0].length && U.length == b.length);
        double[] res = new double[b.length];
        for (int i = b.length - 1; i >= 0; --i) {
            res[i] = b[i];
            for (int j = i + 1; j < b.length; ++j) {
                int n = i;
                res[n] = res[n] - U[i][j] * res[j];
            }
            int n = i;
            res[n] = res[n] / U[i][i];
        }
        return res;
    }

    private static double modifyNumeric(double x, int col, DataInfo dinfo) {
        double y = x;
        if (Double.isNaN(x) && dinfo._imputeMissing) {
            y = dinfo._numMeans[col];
        }
        if (dinfo._normSub != null && dinfo._normMul != null) {
            y = (y - dinfo._normSub[col]) * dinfo._normMul[col];
        }
        return y;
    }

    public static double[] expandRow(double[] row, DataInfo dinfo, double[] tmp) {
        return LinearAlgebraUtils.expandRow(row, dinfo, tmp, true);
    }

    public static double[] expandRow(double[] row, DataInfo dinfo, double[] tmp, boolean modify_numeric) {
        for (int col = 0; col < dinfo._cats; ++col) {
            int cidx;
            if (Double.isNaN(row[col])) {
                if (dinfo._imputeMissing) {
                    cidx = dinfo._catModes[col];
                } else {
                    if (dinfo._catMissing[col] == 0) continue;
                    cidx = dinfo._catOffsets[col + 1] - 1;
                }
            } else {
                cidx = dinfo.getCategoricalId(col, (int)row[col]);
            }
            if (cidx < 0) continue;
            tmp[cidx] = 1.0;
        }
        int chk_cnt = dinfo._cats;
        int exp_cnt = dinfo.numStart();
        for (int col = 0; col < dinfo._nums; ++col) {
            tmp[exp_cnt] = modify_numeric ? LinearAlgebraUtils.modifyNumeric(row[chk_cnt], col, dinfo) : row[chk_cnt];
            ++exp_cnt;
            ++chk_cnt;
        }
        return tmp;
    }

    public static double[] expandRow(Chunk[] chks, int row_in_chunk, DataInfo dinfo, double[] tmp, boolean modify_numeric) {
        for (int col = 0; col < dinfo._cats; ++col) {
            int cidx;
            double x = chks[col].atd(row_in_chunk);
            if (Double.isNaN(x)) {
                if (dinfo._imputeMissing) {
                    cidx = dinfo._catModes[col];
                } else {
                    if (dinfo._catMissing[col] == 0) continue;
                    cidx = dinfo._catOffsets[col + 1] - 1;
                }
            } else {
                cidx = dinfo.getCategoricalId(col, (int)x);
            }
            if (cidx < 0) continue;
            tmp[cidx] = 1.0;
        }
        int exp_cnt = dinfo.numStart();
        for (int col = 0; col < dinfo._nums; ++col) {
            double x = chks[col].atd(row_in_chunk);
            tmp[exp_cnt] = modify_numeric ? LinearAlgebraUtils.modifyNumeric(x, col, dinfo) : x;
            ++exp_cnt;
        }
        return tmp;
    }

    public static double[][] computeR(Key jobKey, DataInfo yinfo, boolean transpose) {
        Gram.GramTask gtsk = new Gram.GramTask(jobKey, yinfo);
        gtsk.doAll(yinfo._adaptedFrame);
        Matrix ygram = new Matrix(gtsk._gram.getXX());
        CholeskyDecomposition chol = new CholeskyDecomposition(ygram);
        double[][] L = chol.getL().getArray();
        ArrayUtils.mult((double[][])L, (double)Math.sqrt(gtsk._nobs));
        return transpose ? L : ArrayUtils.transpose((double[][])L);
    }

    public static double computeQ(Key jobKey, DataInfo yinfo, Frame ywfrm) {
        double[][] cholL = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        ForwardSolve qrtsk = new ForwardSolve(yinfo, cholL);
        qrtsk.doAll(ywfrm);
        return qrtsk._sse;
    }

    public static void computeQInPlace(Key jobKey, DataInfo yinfo) {
        double[][] cholL = LinearAlgebraUtils.computeR(jobKey, yinfo, true);
        ForwardSolveInPlace qrtsk = new ForwardSolveInPlace(yinfo, cholL);
        qrtsk.doAll(yinfo._adaptedFrame);
    }

    public static class ForwardSolveInPlace
    extends MRTask<ForwardSolveInPlace> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;

        public ForwardSolveInPlace(DataInfo ainfo, double[][] L) {
            assert (L != null && L.length == L[0].length && L.length == ainfo._adaptedFrame.numCols());
            this._ainfo = ainfo;
            this._ncols = ainfo._adaptedFrame.numCols();
            this._L = L;
        }

        public void map(Chunk[] cs) {
            assert (this._ncols == cs.length);
            Chunk[] achks = new Chunk[this._ncols];
            for (int i = 0; i < this._ncols; ++i) {
                achks[i] = cs[i];
            }
            for (int row = 0; row < cs[0]._len; ++row) {
                DataInfo.Row arow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(achks, row, arow);
                if (arow.bad) continue;
                double[] aexp = arow.expandCats();
                double[] qrow = LinearAlgebraUtils.forwardSolve(this._L, aexp);
                assert (qrow.length == this._ncols);
                for (int d = 0; d < this._ncols; ++d) {
                    cs[d].set(row, qrow[d]);
                }
            }
        }
    }

    public static class ForwardSolve
    extends MRTask<ForwardSolve> {
        final DataInfo _ainfo;
        final int _ncols;
        final double[][] _L;
        public double _sse;

        public ForwardSolve(DataInfo ainfo, double[][] L) {
            assert (L != null && L.length == L[0].length && L.length == ainfo._adaptedFrame.numCols());
            this._ainfo = ainfo;
            this._ncols = ainfo._adaptedFrame.numCols();
            this._L = L;
            this._sse = 0.0;
        }

        public void map(Chunk[] cs) {
            assert (2 * this._ncols == cs.length);
            Chunk[] achks = new Chunk[this._ncols];
            for (int i = 0; i < this._ncols; ++i) {
                achks[i] = cs[i];
            }
            for (int row = 0; row < cs[0]._len; ++row) {
                DataInfo.Row arow = this._ainfo.newDenseRow();
                this._ainfo.extractDenseRow(achks, row, arow);
                if (arow.bad) continue;
                double[] aexp = arow.expandCats();
                double[] qrow = LinearAlgebraUtils.forwardSolve(this._L, aexp);
                int i = 0;
                for (int d = this._ncols; d < 2 * this._ncols; ++d) {
                    double qold = cs[d].atd(row);
                    double diff = qrow[i] - qold;
                    this._sse += diff * diff;
                    cs[d].set(row, qrow[i++]);
                }
                assert (i == qrow.length);
            }
        }
    }

    public static class SMulTask
    extends MRTask<SMulTask> {
        final DataInfo _ainfo;
        final int _ncolA;
        final int _ncolExp;
        final int _ncolQ;
        public double[][] _atq;

        public SMulTask(DataInfo ainfo, int ncolQ) {
            this._ainfo = ainfo;
            this._ncolA = ainfo._adaptedFrame.numCols();
            this._ncolExp = ainfo._adaptedFrame.numColsExp();
            this._ncolQ = ncolQ;
        }

        public void map(Chunk[] cs) {
            assert (this._ncolA + this._ncolQ == cs.length);
            this._atq = new double[this._ncolExp][this._ncolQ];
            for (int k = this._ncolA; k < this._ncolA + this._ncolQ; ++k) {
                for (int p = 0; p < this._ainfo._cats; ++p) {
                    for (int row = 0; row < cs[0]._len; ++row) {
                        int cidx;
                        if (cs[p].isNA(row) && this._ainfo._skipMissing) continue;
                        double q = cs[k].atd(row);
                        double a = cs[p].atd(row);
                        if (Double.isNaN(a)) {
                            if (this._ainfo._imputeMissing) {
                                cidx = this._ainfo._catModes[p];
                            } else {
                                if (this._ainfo._catMissing[p] == 0) continue;
                                cidx = this._ainfo._catOffsets[p + 1] - 1;
                            }
                        } else {
                            cidx = this._ainfo.getCategoricalId(p, (int)a);
                        }
                        if (cidx < 0) continue;
                        double[] dArray = this._atq[cidx];
                        int n = k - this._ncolA;
                        dArray[n] = dArray[n] + q;
                    }
                }
                int pnum = 0;
                int pexp = this._ainfo.numStart();
                for (int p = this._ainfo._cats; p < this._ncolA; ++p) {
                    for (int row = 0; row < cs[0]._len; ++row) {
                        if (cs[p].isNA(row) && this._ainfo._skipMissing) continue;
                        double q = cs[k].atd(row);
                        double a = cs[p].atd(row);
                        a = LinearAlgebraUtils.modifyNumeric(a, pnum, this._ainfo);
                        double[] dArray = this._atq[pexp];
                        int n = k - this._ncolA;
                        dArray[n] = dArray[n] + q * a;
                    }
                    ++pexp;
                    ++pnum;
                }
                assert (pexp == this._atq.length);
            }
        }

        public void reduce(SMulTask other) {
            ArrayUtils.add((double[][])this._atq, (double[][])other._atq);
        }
    }

    public static class BMulInPlaceTask
    extends MRTask<BMulInPlaceTask> {
        final DataInfo _xinfo;
        final double[][] _yt;
        final int _ncolX;

        public BMulInPlaceTask(DataInfo xinfo, double[][] yt) {
            assert (yt != null && yt[0].length == xinfo._adaptedFrame.numColsExp());
            this._xinfo = xinfo;
            this._ncolX = xinfo._adaptedFrame.numCols();
            this._yt = yt;
        }

        public void map(Chunk[] cs) {
            assert (cs.length == this._ncolX + this._yt.length);
            Chunk[] xchk = new Chunk[this._ncolX];
            for (int i = 0; i < this._ncolX; ++i) {
                xchk[i] = cs[i];
            }
            for (int row = 0; row < cs[0]._len; ++row) {
                DataInfo.Row xrow = this._xinfo.newDenseRow();
                this._xinfo.extractDenseRow(xchk, row, xrow);
                if (xrow.bad) continue;
                int bidx = this._ncolX;
                for (int p = 0; p < this._yt.length; ++p) {
                    double sum = xrow.innerProduct(this._yt[p]);
                    cs[bidx].set(row, sum);
                    ++bidx;
                }
                assert (bidx == cs.length);
            }
        }
    }

    public static class BMulTask
    extends FrameTask<BMulTask> {
        final double[][] _yt;

        public BMulTask(Key jobKey, DataInfo dinfo, double[][] yt) {
            super(jobKey, dinfo);
            this._yt = yt;
        }

        @Override
        protected void processRow(long gid, DataInfo.Row row, NewChunk[] outputs) {
            for (int p = 0; p < this._yt.length; ++p) {
                double x = row.innerProduct(this._yt[p]);
                outputs[p].addNum(x);
            }
        }
    }
}

