/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import java.util.Locale;
import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstBuiltin;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;
import water.util.Log;

public class AstDistance
extends AstBuiltin<AstDistance> {
    @Override
    public String[] args() {
        return new String[]{"ary", "x", "y", "measure"};
    }

    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "distance";
    }

    @Override
    public String description() {
        return "Compute a pairwise distance measure between all rows of two numeric H2OFrames.\nFor a given (usually larger) reference frame (N rows x p cols),\nand a (usually smaller) query frame (M rows x p cols), we return a numeric Frame of size (N rows x M cols),\nwhere the ij-th element is the distance measure between the i-th reference row and the j-th query row.\nNote1: The output frame is symmetric.\nNote2: Since N x M can be very large, it may be more efficient (memory-wise) to make multiple calls with smaller query Frames.";
    }

    @Override
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frx = stk.track(asts[1].exec(env)).getFrame();
        Frame fry = stk.track(asts[2].exec(env)).getFrame();
        String measure = stk.track(asts[3].exec(env)).getStr();
        return this.computeCosineDistances(frx, fry, measure);
    }

    public Val computeCosineDistances(Frame references, Frame queries, String distanceMetric) {
        Log.info("Number of references: " + references.numRows());
        Log.info("Number of queries   : " + queries.numRows());
        Object[] options = new String[]{"cosine", "cosine_sq", "l1", "l2"};
        if (!ArrayUtils.contains(options, distanceMetric.toLowerCase())) {
            throw new IllegalArgumentException("Invalid distance measure provided: " + distanceMetric + ". Mustbe one of " + Arrays.toString(options));
        }
        if (references.numRows() * queries.numRows() * 8L > H2O.CLOUD.free_mem()) {
            throw new IllegalArgumentException("Not enough free memory to allocate the distance matrix (" + references.numRows() + " rows and " + queries.numRows() + " cols. Try specifying a smaller query frame.");
        }
        if (references.numCols() != queries.numCols()) {
            throw new IllegalArgumentException("Frames must have the same number of cols, found " + references.numCols() + " and " + queries.numCols());
        }
        if (queries.numRows() > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Queries can't be larger than 2 billion rows.");
        }
        if (queries.numCols() != references.numCols()) {
            throw new IllegalArgumentException("Queries and References must have the same dimensionality");
        }
        for (int i2 = 0; i2 < queries.numCols(); ++i2) {
            if (!references.vec(i2).isNumeric()) {
                throw new IllegalArgumentException("References column " + references.name(i2) + " is not numeric.");
            }
            if (!queries.vec(i2).isNumeric()) {
                throw new IllegalArgumentException("Queries column " + references.name(i2) + " is not numeric.");
            }
            if (references.vec(i2).naCnt() > 0L) {
                throw new IllegalArgumentException("References column " + references.name(i2) + " contains missing values.");
            }
            if (queries.vec(i2).naCnt() <= 0L) continue;
            throw new IllegalArgumentException("Queries column " + references.name(i2) + " contains missing values.");
        }
        return new ValFrame(((DistanceComputer)new DistanceComputer(queries, distanceMetric).doAll((int)queries.numRows(), (byte)3, references)).outputFrame());
    }

    public static class DistanceComputer
    extends MRTask<DistanceComputer> {
        Frame _queries;
        String _measure;

        DistanceComputer(Frame queries, String measure) {
            this._queries = queries;
            this._measure = measure;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            int r2;
            int p2 = cs.length;
            int Q2 = (int)this._queries.numRows();
            int R = cs[0]._len;
            Vec.Reader[] Qs = new Vec.Reader[p2];
            for (int i2 = 0; i2 < p2; ++i2) {
                Qs[i2] = this._queries.vec(i2).new Vec.Reader();
            }
            double[] denomR = null;
            double[] denomQ = null;
            boolean cosine = this._measure.toLowerCase().equals("cosine");
            boolean cosine_sq = this._measure.toLowerCase().equals("cosine_sq");
            boolean l1 = this._measure.toLowerCase().equals("l1");
            boolean l2 = this._measure.toLowerCase().equals("l2");
            if (cosine || cosine_sq) {
                int c2;
                denomR = new double[R];
                denomQ = new double[Q2];
                for (r2 = 0; r2 < R; ++r2) {
                    for (c2 = 0; c2 < p2; ++c2) {
                        int n2 = r2;
                        denomR[n2] = denomR[n2] + Math.pow(cs[c2].atd(r2), 2.0);
                    }
                }
                for (int q2 = 0; q2 < Q2; ++q2) {
                    for (c2 = 0; c2 < p2; ++c2) {
                        int n3 = q2;
                        denomQ[n3] = denomQ[n3] + Math.pow(Qs[c2].at(q2), 2.0);
                    }
                }
            }
            for (r2 = 0; r2 < cs[0]._len; ++r2) {
                for (int q3 = 0; q3 < Q2; ++q3) {
                    int c3;
                    double distRQ = 0.0;
                    if (l1) {
                        for (c3 = 0; c3 < p2; ++c3) {
                            distRQ += Math.abs(cs[c3].atd(r2) - Qs[c3].at(q3));
                        }
                    } else if (l2) {
                        for (c3 = 0; c3 < p2; ++c3) {
                            distRQ += Math.pow(cs[c3].atd(r2) - Qs[c3].at(q3), 2.0);
                        }
                        distRQ = Math.sqrt(distRQ);
                    } else if (cosine || cosine_sq) {
                        for (c3 = 0; c3 < p2; ++c3) {
                            distRQ += cs[c3].atd(r2) * Qs[c3].at(q3);
                        }
                        if (cosine_sq) {
                            distRQ *= distRQ;
                            distRQ /= denomR[r2] * denomQ[q3];
                        } else {
                            distRQ /= Math.sqrt(denomR[r2] * denomQ[q3]);
                        }
                    }
                    ncs[q3].addNum(distRQ);
                }
            }
        }
    }
}

