/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import water.Key;
import water.MRTask;
import water.fvec.C16Chunk;
import water.fvec.CStrChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNum;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code (cor x y use)}: Pearson correlation between the columns of two frames.
 *
 * <p>If the frames have a single row, the values of that row are treated as two paired samples
 * (one value per column) and a single number is returned. Otherwise every column of {@code y} is
 * correlated with every column of {@code x}. A 1x1 result is returned as a number. Anything larger
 * is returned as a frame with one column per {@code y} column (named after it) and one row per
 * {@code x} column.
 *
 * <p>The {@code use} argument (option names taken from R's {@code cor}) selects NA handling:
 * <ul>
 *   <li>{@code "everything"}: NAs are not removed, so a correlation touching an NA comes out NaN;</li>
 *   <li>{@code "all.obs"}: an {@link IllegalArgumentException} is thrown if any NA is present;</li>
 *   <li>{@code "complete.obs"}: rows (or, for single-row input, column pairs) containing an NA are
 *       dropped before computing.</li>
 * </ul>
 */
public class AstCorrelation
extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary", "x", "y", "use"};
    }

    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "cor";
    }

    /**
     * Evaluates the {@code x}, {@code y} and {@code use} arguments and dispatches to the
     * single-row or the column-wise implementation.
     *
     * @throws IllegalArgumentException if the frames differ in row count or {@code use} is unknown
     */
    @Override
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frx = stk.track(asts[1].exec(env)).getFrame();
        Frame fry = stk.track(asts[2].exec(env)).getFrame();
        if (frx.numRows() != fry.numRows()) {
            throw new IllegalArgumentException("Frames must have the same number of rows, found " + frx.numRows() + " and " + fry.numRows());
        }
        Mode mode = Mode.fromUse(stk.track(asts[3].exec(env)).getStr());
        return fry.numRows() == 1L ? this.scalar(frx, fry, mode) : this.array(frx, fry, mode);
    }

    /**
     * Correlation of two single-row frames, treating the i-th column of each as the i-th paired
     * observation.
     *
     * @throws IllegalArgumentException if the column counts differ, or if {@code mode} is
     *         {@code AllObs} and a pair contains an NA
     */
    private ValNum scalar(Frame frx, Frame fry, Mode mode) {
        if (frx.numCols() != fry.numCols()) {
            throw new IllegalArgumentException("Single rows must have the same number of columns, found " + frx.numCols() + " and " + fry.numCols());
        }
        Vec[] vecxs = frx.vecs();
        Vec[] vecys = fry.vecs();
        int ncols = fry.numCols();

        // Read every value exactly once; both passes below reuse the cached values.
        double[] xs = new double[ncols];
        double[] ys = new double[ncols];
        int naCount = 0;
        double xsum = 0.0;
        double ysum = 0.0;
        for (int c = 0; c < ncols; ++c) {
            xs[c] = vecxs[c].at(0L);
            ys[c] = vecys[c].at(0L);
            if (Double.isNaN(xs[c]) || Double.isNaN(ys[c])) {
                ++naCount;
            } else {
                xsum += xs[c];
                ysum += ys[c];
            }
        }
        if (naCount != 0) {
            if (mode == Mode.AllObs) {
                throw new IllegalArgumentException("Mode is 'all.obs' but NAs are present");
            }
            if (mode == Mode.Everything) {
                return new ValNum(Double.NaN);
            }
        }

        // From here on only complete (x, y) pairs contribute.
        double n = ncols - naCount;
        double xmean = xsum / n;
        double ymean = ysum / n;
        double xvar = 0.0;
        double yvar = 0.0;
        double ss = 0.0;
        for (int c = 0; c < ncols; ++c) {
            if (Double.isNaN(xs[c]) || Double.isNaN(ys[c])) {
                continue;
            }
            double dx = xs[c] - xmean;
            double dy = ys[c] - ymean;
            xvar += dx * dx;
            yvar += dy * dy;
            ss += dx * dy;
        }
        double xsd = Math.sqrt(xvar / (n - 1.0));
        double ysd = Math.sqrt(yvar / (n - 1.0));
        return new ValNum(ss / (n - 1.0) / (xsd * ysd));
    }

    /**
     * Column-wise correlation of every {@code y} column against every {@code x} column.
     *
     * @throws IllegalArgumentException if {@code mode} is {@code AllObs} and either frame has NAs
     */
    private Val array(Frame frx, Frame fry, Mode mode) {
        if (mode == Mode.CompleteObs) {
            return completeObs(frx, fry);
        }
        // all.obs must reject NAs in either input, not only in x.
        if (mode == Mode.AllObs && (hasNAs(frx) || hasNAs(fry))) {
            throw new IllegalArgumentException("Mode is 'all.obs' but NAs are present");
        }
        return correlationMatrix(frx, fry, fry._names);
    }

    /**
     * "complete.obs" mode: drops every row that has an NA in any column of x or y, then
     * correlates the remaining rows. The intermediate NA-free frame is deleted afterwards.
     */
    private static Val completeObs(Frame frx, Frame fry) {
        int ncolx = frx.numCols();
        Frame frxy = new Frame(frx).add(fry);
        Frame naomit = new NaOmitTask().doAll(frxy.types(), frxy).outputFrame(frxy.names(), frxy.domains());
        try {
            Frame xs = naomit.subframe(0, ncolx);
            Frame ys = naomit.subframe(ncolx, naomit.numCols());
            // correlationMatrix joins all of its covariance tasks before returning, so the
            // temporary vecs are no longer in use when they are deleted below.
            return correlationMatrix(xs, ys, fry._names);
        } finally {
            naomit.delete();
        }
    }

    /**
     * Computes cov(y_j, x_i) / (sd(y_j) * sd(x_i)) for every pair of columns, using
     * {@code numRows - 1} as the covariance denominator.
     *
     * @param frx    frame whose columns index the rows of the result
     * @param fry    frame whose columns index the columns of the result; same row count as frx
     * @param ynames column names for the result frame, one per column of {@code fry}
     * @return a {@link ValNum} for a 1x1 result, otherwise a {@link ValFrame}
     */
    private static Val correlationMatrix(Frame frx, Frame fry, String[] ynames) {
        Vec[] vecxs = frx.vecs();
        Vec[] vecys = fry.vecs();
        int ncolx = vecxs.length;
        int ncoly = vecys.length;

        double[] xmeans = new double[ncolx];
        for (int x = 0; x < ncolx; ++x) {
            xmeans[x] = vecxs[x].mean();
        }
        // Launch one covariance pass per y column without blocking; results are joined below.
        CoVarTask[] cvs = new CoVarTask[ncoly];
        double[] sigmay = new double[ncoly];
        for (int y = 0; y < ncoly; ++y) {
            cvs[y] = new CoVarTask(vecys[y].mean(), xmeans).dfork(new Frame(vecys[y]).add(frx));
            sigmay[y] = vecys[y].sigma();
        }
        double[] sigmax = new double[ncolx];
        for (int x = 0; x < ncolx; ++x) {
            sigmax[x] = vecxs[x].sigma();
        }

        double nMinus1 = (double) (fry.numRows() - 1L);
        if (ncolx == 1 && ncoly == 1) {
            return new ValNum(cvs[0].getResult()._covs[0] / nMinus1 / (sigmay[0] * sigmax[0]));
        }
        Vec[] res = new Vec[ncoly];
        Key<Vec>[] keys = Vec.VectorGroup.VG_LEN1.addVecs(ncoly);
        for (int y = 0; y < ncoly; ++y) {
            double[] covs = cvs[y].getResult()._covs;
            double[] cors = new double[ncolx];
            for (int x = 0; x < ncolx; ++x) {
                cors[x] = covs[x] / nMinus1 / (sigmay[y] * sigmax[x]);
            }
            res[y] = Vec.makeVec(cors, keys[y]);
        }
        return new ValFrame(new Frame(ynames, res));
    }

    /** Returns true if any column of {@code fr} reports a non-zero NA count. */
    private static boolean hasNAs(Frame fr) {
        for (Vec v : fr.vecs()) {
            if (v.naCnt() != 0L) {
                return true;
            }
        }
        return false;
    }

    /**
     * Copies every row that has no NA in any column to the output chunks, preserving the
     * column representation (string, UUID, floating point, or integer).
     */
    private static class NaOmitTask
    extends MRTask<NaOmitTask> {
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            for (int row = 0; row < cs[0]._len; ++row) {
                if (!hasNA(cs, row)) {
                    copyRow(row, cs, ncs);
                }
            }
        }

        /** True if any chunk has an NA at {@code row}. */
        private static boolean hasNA(Chunk[] cs, int row) {
            for (Chunk c : cs) {
                if (c.isNA(row)) {
                    return true;
                }
            }
            return false;
        }

        /** Appends row {@code row} of every input chunk to the matching output chunk. */
        private static void copyRow(int row, Chunk[] cs, NewChunk[] ncs) {
            for (int i = 0; i < cs.length; ++i) {
                if (cs[i] instanceof CStrChunk) {
                    ncs[i].addStr(cs[i], row);
                } else if (cs[i] instanceof C16Chunk) {
                    ncs[i].addUUID(cs[i], row);
                } else if (cs[i].hasFloat()) {
                    ncs[i].addNum(cs[i].atd(row));
                } else {
                    ncs[i].addNum(cs[i].at8(row), 0);
                }
            }
        }
    }

    /**
     * Sums (x_i - mean(x_i)) * (y - mean(y)) over all rows. The input frame's first column is
     * y and the remaining columns are the x columns. The result has one entry per x column and
     * is not yet divided by the row count. NAs (read as NaN by {@code atd}) propagate into the sums.
     */
    private static class CoVarTask
    extends MRTask<CoVarTask> {
        double[] _covs;
        final double[] _xmeans;
        final double _ymean;

        CoVarTask(double ymean, double[] xmeans) {
            this._ymean = ymean;
            this._xmeans = xmeans;
        }

        @Override
        public void map(Chunk[] cs) {
            int ncolsx = cs.length - 1;
            Chunk cy = cs[0];
            int len = cy._len;
            this._covs = new double[ncolsx];
            for (int x = 0; x < ncolsx; ++x) {
                double sum = 0.0;
                Chunk cx = cs[x + 1];
                double xmean = this._xmeans[x];
                for (int row = 0; row < len; ++row) {
                    sum += (cx.atd(row) - xmean) * (cy.atd(row) - this._ymean);
                }
                this._covs[x] = sum;
            }
        }

        @Override
        public void reduce(CoVarTask cvt) {
            if (this._covs == null) {
                this._covs = cvt._covs;
            } else if (cvt._covs != null) {
                ArrayUtils.add(this._covs, cvt._covs);
            }
        }
    }

    /** NA-handling strategy selected by the {@code use} argument. */
    private static enum Mode {
        Everything,
        AllObs,
        CompleteObs;

        /**
         * Parses the Rapids {@code use} string.
         *
         * @throws IllegalArgumentException if {@code use} is not a recognised mode
         */
        static Mode fromUse(String use) {
            switch (use) {
                case "everything":
                    return Everything;
                case "all.obs":
                    return AllObs;
                case "complete.obs":
                    return CompleteObs;
                default:
                    throw new IllegalArgumentException("unknown use mode: " + use);
            }
        }
    }
}

