/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNum;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code cor}: Pearson correlation between the columns of two frames.
 *
 * <p>Arguments (see {@link #args()}): frames {@code x} and {@code y}, which must have the same
 * number of rows, and a {@code use} string selecting NA handling (named after R's
 * {@code cor(use=)} options):
 * <ul>
 *   <li>{@code "everything"} - NAs propagate, so a correlation involving an NA is NaN;</li>
 *   <li>{@code "all.obs"} - an {@link IllegalArgumentException} is thrown if any NA is present
 *       in either frame;</li>
 *   <li>{@code "complete.obs"} - observations with an NA in any participating column are
 *       dropped (casewise deletion).</li>
 * </ul>
 *
 * <p>If the frames have exactly one row, their columns are treated as paired observations and a
 * single number is returned. Otherwise every column of {@code y} is correlated with every column
 * of {@code x}. The result is a number when both frames have a single column. Otherwise it is a
 * frame with one column per {@code y} column (named after it), holding one entry per {@code x}
 * column.
 */
public class AstCorrelation extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary", "x", "y", "use"};
    }

    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "cor";
    }

    @Override
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frx = stk.track(asts[1].exec(env)).getFrame();
        Frame fry = stk.track(asts[2].exec(env)).getFrame();
        if (frx.numRows() != fry.numRows()) {
            throw new IllegalArgumentException("Frames must have the same number of rows, found " + frx.numRows() + " and " + fry.numRows());
        }
        // Evaluated after the row check to keep the original evaluation order of the arguments.
        Mode mode = parseMode(stk.track(asts[3].exec(env)).getStr());
        return fry.numRows() == 1L ? scalar(frx, fry, mode) : array(frx, fry, mode);
    }

    /**
     * Maps the {@code use} argument onto a {@link Mode}.
     *
     * @throws IllegalArgumentException if {@code use} is not one of the three supported options
     */
    private static Mode parseMode(String use) {
        switch (use) {
            case "everything":
                return Mode.Everything;
            case "all.obs":
                return Mode.AllObs;
            case "complete.obs":
                return Mode.CompleteObs;
            default:
                throw new IllegalArgumentException("unknown use mode: " + use);
        }
    }

    /** Throws if any vector in {@code vecs} reports a non-zero NA count (enforces {@code all.obs}). */
    private static void requireNoNAs(Vec[] vecs) {
        for (Vec v : vecs) {
            if (v.naCnt() != 0L) {
                throw new IllegalArgumentException("Mode is 'all.obs' but NAs are present");
            }
        }
    }

    /** Returns true if any chunk in {@code cs} holds a NaN/NA at {@code row}. */
    private static boolean rowHasNA(Chunk[] cs, int row) {
        for (Chunk c : cs) {
            if (Double.isNaN(c.atd(row))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Correlation of two single-row frames, treating column {@code c} of {@code frx} and column
     * {@code c} of {@code fry} as one paired observation.
     *
     * <p>A column pair is dropped when either value is NA. If any pair contains an NA, then
     * {@code all.obs} throws and {@code everything} returns NaN.
     */
    private ValNum scalar(Frame frx, Frame fry, Mode mode) {
        if (frx.numCols() != fry.numCols()) {
            throw new IllegalArgumentException("Single rows must have the same number of columns, found " + frx.numCols() + " and " + fry.numCols());
        }
        Vec[] vecxs = frx.vecs();
        Vec[] vecys = fry.vecs();
        int ncols = vecxs.length;

        // Read every cell once; a Vec.at() call is not free.
        double[] xs = new double[ncols];
        double[] ys = new double[ncols];
        int naCount = 0;
        double xmean = 0.0;
        double ymean = 0.0;
        for (int c = 0; c < ncols; ++c) {
            xs[c] = vecxs[c].at(0L);
            ys[c] = vecys[c].at(0L);
            if (Double.isNaN(xs[c]) || Double.isNaN(ys[c])) {
                ++naCount;
            } else {
                xmean += xs[c];
                ymean += ys[c];
            }
        }
        if (naCount != 0) {
            if (mode == Mode.AllObs) {
                throw new IllegalArgumentException("Mode is 'all.obs' but NAs are present");
            }
            if (mode == Mode.Everything) {
                return new ValNum(Double.NaN);
            }
        }
        int nobs = ncols - naCount;
        xmean /= nobs; // NaN when no complete pairs remain
        ymean /= nobs;

        // Single pass: the covariance is accumulated exactly once (it was previously doubled).
        double cov = 0.0;
        double xvar = 0.0;
        double yvar = 0.0;
        for (int c = 0; c < ncols; ++c) {
            if (Double.isNaN(xs[c]) || Double.isNaN(ys[c])) {
                continue;
            }
            double dx = xs[c] - xmean;
            double dy = ys[c] - ymean;
            cov += dx * dy;
            xvar += dx * dx;
            yvar += dy * dy;
        }
        return new ValNum(cov / (Math.sqrt(xvar) * Math.sqrt(yvar)));
    }

    /**
     * Column-wise correlations between every column of {@code fry} and every column of
     * {@code frx}. This is done in two distributed passes: column means first, then
     * centred cross-product and sum-of-squares totals.
     */
    private Val array(Frame frx, Frame fry, Mode mode) {
        Vec[] vecxs = frx.vecs();
        Vec[] vecys = fry.vecs();
        int ncolx = vecxs.length;
        int ncoly = vecys.length;
        if (mode == Mode.AllObs) {
            // Both inputs must be NA-free, not just x.
            requireNoNAs(vecxs);
            requireNoNAs(vecys);
        }
        boolean completeObs = mode == Mode.CompleteObs;

        // Layout shared by both passes: y columns first, then x columns.
        Frame combined = new Frame(fry).add(frx);
        CorTaskMean taskMean = (CorTaskMean) new CorTaskMean(ncoly, ncolx, completeObs).doAll(combined);
        double nobs = (double) (fry.numRows() - taskMean._naCount);
        double[] ymeans = ArrayUtils.div(taskMean._ysum, nobs);
        double[] xmeans = ArrayUtils.div(taskMean._xsum, nobs);

        CorTask[] tasks = new CorTask[ncoly];
        for (int y = 0; y < ncoly; ++y) {
            tasks[y] = (CorTask) new CorTask(y, ymeans[y], xmeans, completeObs).dfork(combined);
        }
        if (ncolx == 1 && ncoly == 1) {
            return new ValNum(((CorTask) tasks[0].getResult()).correlations()[0]);
        }
        Vec[] res = new Vec[ncoly];
        Key<Vec>[] keys = Vec.VectorGroup.VG_LEN1.addVecs(ncoly);
        for (int y = 0; y < ncoly; ++y) {
            res[y] = Vec.makeVec(((CorTask) tasks[y].getResult()).correlations(), keys[y]);
        }
        return new ValFrame(new Frame(fry._names, res));
    }

    /**
     * First pass: per-column sums of the y columns (chunks {@code 0..ncoly-1}) and the x columns
     * (the following {@code ncolx} chunks).
     *
     * <p>With {@code completeObs}, a row with an NA in any column is skipped and counted in
     * {@code _naCount}. Otherwise NAs are summed in, so the affected means become NaN.
     */
    private static class CorTaskMean extends MRTask<CorTaskMean> {
        double[] _xsum;
        double[] _ysum;
        long _naCount;
        int _ncolx;
        int _ncoly;
        boolean _completeObs;

        CorTaskMean(int ncoly, int ncolx, boolean completeObs) {
            this._ncolx = ncolx;
            this._ncoly = ncoly;
            this._completeObs = completeObs;
        }

        @Override
        public void map(Chunk[] cs) {
            _xsum = new double[_ncolx];
            _ysum = new double[_ncoly];
            int len = cs[0]._len;
            for (int row = 0; row < len; ++row) {
                if (_completeObs && rowHasNA(cs, row)) {
                    ++_naCount;
                    continue;
                }
                for (int y = 0; y < _ncoly; ++y) {
                    _ysum[y] += cs[y].atd(row);
                }
                for (int x = 0; x < _ncolx; ++x) {
                    _xsum[x] += cs[_ncoly + x].atd(row);
                }
            }
        }

        @Override
        public void reduce(CorTaskMean cvt) {
            ArrayUtils.add(_xsum, cvt._xsum);
            ArrayUtils.add(_ysum, cvt._ysum);
            _naCount += cvt._naCount;
        }
    }

    /**
     * Second pass, for one y column: centred cross-product and sum-of-squares totals against
     * every x column.
     *
     * <p>Raw sums are reduced across chunks, and the ratio is formed only in
     * {@link #correlations()}. Combining per-chunk square roots would give a wrong denominator
     * for multi-chunk frames. Rows are skipped using the same casewise rule as
     * {@link CorTaskMean}, so both passes see identical observations.
     */
    private static class CorTask extends MRTask<CorTask> {
        double[] _cov;
        double[] _varx;
        double _vary;
        final int _yIdx;
        final double _ymean;
        final double[] _xmeans;
        final boolean _completeObs;

        CorTask(int yIdx, double ymean, double[] xmeans, boolean completeObs) {
            this._yIdx = yIdx;
            this._ymean = ymean;
            this._xmeans = xmeans;
            this._completeObs = completeObs;
        }

        @Override
        public void map(Chunk[] cs) {
            int ncolx = _xmeans.length;
            int xOffset = cs.length - ncolx; // x columns follow all y columns
            Chunk cy = cs[_yIdx];
            int len = cy._len;
            _cov = new double[ncolx];
            _varx = new double[ncolx];
            _vary = 0.0;
            for (int row = 0; row < len; ++row) {
                if (_completeObs && rowHasNA(cs, row)) {
                    continue;
                }
                double dy = cy.atd(row) - _ymean;
                _vary += dy * dy;
                for (int x = 0; x < ncolx; ++x) {
                    double dx = cs[xOffset + x].atd(row) - _xmeans[x];
                    _cov[x] += dx * dy;
                    _varx[x] += dx * dx;
                }
            }
        }

        @Override
        public void reduce(CorTask cvt) {
            ArrayUtils.add(_cov, cvt._cov);
            ArrayUtils.add(_varx, cvt._varx);
            _vary += cvt._vary;
        }

        /** Pearson coefficients against each x column; call only on the reduced result. */
        double[] correlations() {
            double ysd = Math.sqrt(_vary);
            double[] cors = new double[_cov.length];
            for (int x = 0; x < cors.length; ++x) {
                cors[x] = _cov[x] / (Math.sqrt(_varx[x]) * ysd);
            }
            return cors;
        }
    }

    /** NA-handling policy selected by the {@code use} argument. */
    private enum Mode {
        Everything,
        AllObs,
        CompleteObs
    }
}

