/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.string;

import no.priv.garshol.duke.Comparator;
import no.priv.garshol.duke.comparators.JaccardIndexComparator;
import no.priv.garshol.duke.comparators.JaroWinkler;
import no.priv.garshol.duke.comparators.Levenshtein;
import no.priv.garshol.duke.comparators.LongestCommonSubstring;
import no.priv.garshol.duke.comparators.QGramComparator;
import no.priv.garshol.duke.comparators.SoundexComparator;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;

/**
 * Rapids primitive {@code (strDistance ary_x ary_y measure)}.
 *
 * <p>Takes two frames with the same number of columns and rows whose columns are all
 * string or categorical, and produces a numeric frame with one column per input column
 * pair. Each output cell holds the value returned by the selected Duke
 * {@link Comparator} for the two strings in the corresponding row; a row in which
 * either input is NA yields NA.
 */
public class AstStrDistance
extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary_x", "ary_y", "measure"};
    }

    /** Number of AST slots: the primitive itself plus the three arguments from {@link #args()}. */
    @Override
    public int nargs() {
        return 1 + 3;
    }

    @Override
    public String str() {
        return "strDistance";
    }

    /**
     * Evaluates the primitive.
     *
     * @param env  execution environment used to evaluate the argument ASTs
     * @param stk  stack helper used to track the two input frames
     * @param asts {@code asts[1]} and {@code asts[2]} are the two frames, {@code asts[3]} the measure name
     * @return a frame of numeric columns, one per input column pair
     * @throws IllegalArgumentException if the frames differ in column or row count, if any
     *         column is neither string nor categorical, or if the measure name is unknown
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frX = stk.track(asts[1].exec(env)).getFrame();
        Frame frY = stk.track(asts[2].exec(env)).getFrame();
        String measure = asts[3].exec(env).getStr();
        if (frX.numCols() != frY.numCols() || frX.numRows() != frY.numRows()) {
            throw new IllegalArgumentException("strDistance() requires the frames to have the same number of columns and rows.");
        }
        final int numCols = frX.numCols();
        for (int i = 0; i < numCols; ++i) {
            if (!(isCharacterType(frX.vec(i)) && isCharacterType(frY.vec(i)))) {
                throw new IllegalArgumentException("Types of columns of both frames need to be String/Factor");
            }
        }
        // Result intentionally discarded: called only so that an unknown measure name
        // fails here, before the distributed task is launched.
        makeComparator(measure);
        byte[] outputTypes = new byte[numCols];
        // Layout expected by StringDistanceComparator.map: [x_0 .. x_{n-1}, y_0 .. y_{n-1}].
        Vec[] vecs = new Vec[numCols * 2];
        for (int i = 0; i < numCols; ++i) {
            outputTypes[i] = Vec.T_NUM;
            vecs[i] = frX.vec(i);
            vecs[i + numCols] = frY.vec(i);
        }
        Frame distFr = new StringDistanceComparator(measure).doAll(outputTypes, vecs).outputFrame();
        return new ValFrame(distFr);
    }

    /** Returns {@code true} when the vec is a string or a categorical column. */
    private static boolean isCharacterType(Vec v) {
        byte type = v.get_type();
        return type == Vec.T_STR || type == Vec.T_CAT;
    }

    /**
     * Creates a new comparator for the given measure name.
     *
     * @param measure short or long comparator name (for example {@code "lv"} or {@code "Levenshtein"})
     * @return a fresh comparator instance
     * @throws IllegalArgumentException if {@code measure} is {@code null} or not a known name
     */
    private static Comparator makeComparator(String measure) {
        if (measure == null) {
            // A switch on a null String would throw a bare NullPointerException;
            // report it the same way as any other unrecognised name instead.
            throw new IllegalArgumentException("Unknown comparator: " + measure);
        }
        switch (measure) {
            case "jaccard":
            case "JaccardIndex":
                return new JaccardIndexComparator();
            case "jw":
            case "JaroWinkler":
                return new JaroWinkler();
            case "lv":
            case "Levenshtein":
                return new Levenshtein();
            case "lcs":
            case "LongestCommonSubstring":
                return new LongestCommonSubstring();
            case "qgram":
            case "QGram":
                return new QGramComparator();
            case "soundex":
            case "Soundex":
                return new SoundexComparator();
            default:
                throw new IllegalArgumentException("Unknown comparator: " + measure);
        }
    }

    /**
     * Task that compares paired columns row by row. It carries the measure name rather
     * than a comparator instance and builds its own comparator inside each
     * {@link #map(Chunk[], NewChunk[])} call.
     */
    private static class StringDistanceComparator
    extends MRTask<StringDistanceComparator> {
        private final String _measure;

        private StringDistanceComparator(String measure) {
            this._measure = measure;
        }

        /**
         * Compares column {@code i} with column {@code i + N} for every row of the chunk,
         * where {@code N} is the number of output columns.
         *
         * @param cs input chunks; the first half are the x columns, the second half the y columns
         * @param nc output chunks, one per column pair
         */
        @Override
        public void map(Chunk[] cs, NewChunk[] nc) {
            BufferedString tmpStr = new BufferedString();
            Comparator cmp = makeComparator(this._measure);
            int N = nc.length;
            assert N * 2 == cs.length;
            for (int i = 0; i < N; ++i) {
                Chunk cX = cs[i];
                Chunk cY = cs[i + N];
                // A non-null domain means cell values are indices into that domain array.
                String[] domainX = this._fr.vec(i).domain();
                String[] domainY = this._fr.vec(i + N).domain();
                NewChunk out = nc[i];
                for (int row = 0; row < cX._len; ++row) {
                    if (cX.isNA(row) || cY.isNA(row)) {
                        out.addNA();
                        continue;
                    }
                    // tmpStr is reused for both reads; getString copies the content out
                    // into a String before the buffer is overwritten by the next read.
                    String strX = getString(tmpStr, cX, row, domainX);
                    String strY = getString(tmpStr, cY, row, domainY);
                    out.addNum(cmp.compare(strX, strY));
                }
            }
        }

        /**
         * Reads the string value of a single cell.
         *
         * @param tmpStr scratch buffer used when reading from a string chunk
         * @param chk    chunk to read from
         * @param row    row index within the chunk
         * @param domain domain of the column, or {@code null} when the column has none
         * @return the domain label selected by the cell value when a domain is present,
         *         otherwise the cell's string content
         */
        private static String getString(BufferedString tmpStr, Chunk chk, int row, String[] domain) {
            if (domain != null) {
                return domain[(int)chk.at8(row)];
            }
            return chk.atStr(tmpStr, row).toString();
        }
    }
}

