/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.BlendingParams;
import ai.h2o.targetencoding.TargetEncoderBroadcastJoin;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.fvec.task.FillNAWithLongValueTask;
import water.fvec.task.FilterByValueTask;
import water.fvec.task.IsNotNaTask;
import water.fvec.task.UniqTask;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.ast.prims.advmath.AstKFold;
import water.rapids.ast.prims.mungers.AstGroup;
import water.rapids.ast.prims.mungers.AstMelt;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNum;
import water.rapids.vals.ValStr;
import water.rapids.vals.ValStrs;
import water.util.ArrayUtils;
import water.util.FrameUtils;
import water.util.TwoDimTable;

/**
 * Static helpers used by target encoding.
 *
 * <p>An "encodings frame" holds, per category of the encoded column (and optionally per fold
 * and/or per target class), a {@code numerator} column and a {@code denominator} column. For a
 * binary/numeric target the numerator is the sum of target values and the denominator the row
 * count; for a multiclass target the numerator is the number of rows of the class given in the
 * {@code targetclass} column. The helpers below build, aggregate, merge and apply such frames, and
 * provide a few small frame utilities (filtering, renaming, rbind, ...).
 */
public class TargetEncoderHelper
extends Iced<TargetEncoderHelper> {
    /** Name of the column holding the per-category numerator. */
    static String NUMERATOR_COL = "numerator";
    /** Name of the column holding the per-category denominator (row count). */
    static String DENOMINATOR_COL = "denominator";
    /** Name of the column holding the target class in multiclass encodings. */
    static String TARGETCLASS_COL = "targetclass";
    /** Suffix appended to column names for NA-related categories/columns. */
    static String NA_POSTFIX = "_NA";
    private static final Logger logger = LoggerFactory.getLogger(TargetEncoderHelper.class);

    private TargetEncoderHelper() {
    }

    /**
     * Appends a k-fold assignment column to {@code frame}.
     *
     * @param frame  frame to extend (modified in place)
     * @param name   name of the new fold column
     * @param nfolds number of folds
     * @param seed   random seed; {@code -1} means "draw a random seed"
     * @return index of the newly added column
     */
    public static int addKFoldColumn(Frame frame, String name, int nfolds, long seed) {
        Vec foldVec = frame.anyVec().makeZero();
        long effectiveSeed = seed == -1L ? new Random().nextLong() : seed;
        frame.add(name, AstKFold.kfoldColumn(foldVec, nfolds, effectiveSeed));
        return frame.numCols() - 1;
    }

    /**
     * Computes the prior mean of a single-class (binary/numeric) encodings frame.
     * The frame must not contain a {@code targetclass} column.
     */
    static double computePriorMean(Frame encodings) {
        assert encodings.find(TARGETCLASS_COL) < 0;
        return computePriorMean(encodings, -1);
    }

    /**
     * Computes the prior mean as {@code mean(numerator) / mean(denominator)}.
     *
     * @param encodings   encodings frame
     * @param targetClass {@code -1} when the frame has no {@code targetclass} column; otherwise the
     *                    class whose rows are used (the frame is filtered on that value first)
     * @return the ratio of the numerator mean to the denominator mean
     */
    static double computePriorMean(Frame encodings, int targetClass) {
        int tcIdx = encodings.find(TARGETCLASS_COL);
        // A target class must be given exactly when the frame is a multiclass encodings frame.
        assert (targetClass == -1) == (tcIdx < 0);
        Frame fr = null;
        try {
            fr = tcIdx < 0 ? encodings : filterByValue(encodings, tcIdx, targetClass);
            Vec numeratorVec = fr.vec(NUMERATOR_COL);
            Vec denominatorVec = fr.vec(DENOMINATOR_COL);
            assert numeratorVec != null;
            assert denominatorVec != null;
            return numeratorVec.mean() / denominatorVec.mean();
        } finally {
            // Only the filtered copy is ours to delete; never delete the caller's frame.
            if (fr != null && fr != encodings) {
                fr.delete();
            }
        }
    }

    /**
     * Groups {@code fr} by the column to encode (and the fold column, if any) and computes the
     * numerator/denominator encodings.
     *
     * <p>For {@code nclasses > 2} the target is one-hot encoded, each indicator column is summed per
     * group, and the result is melted into long format with one row per (group, target class); the
     * {@code targetclass} column receives the target's domain. Otherwise the target itself is summed.
     *
     * @param fr                source frame
     * @param columnToEncodeIdx index of the categorical column being encoded
     * @param targetIdx         index of the target column
     * @param foldColumnIdx     index of the fold column, or a negative value if there is none
     * @param nclasses          number of target classes
     * @return a new encodings frame owned by the caller
     */
    static Frame buildEncodingsFrame(Frame fr, int columnToEncodeIdx, int targetIdx, int foldColumnIdx, int nclasses) {
        Scope.enter();
        try {
            int[] groupBy = foldColumnIdx < 0
                    ? new int[]{columnToEncodeIdx}
                    : new int[]{columnToEncodeIdx, foldColumnIdx};
            Frame result;
            if (nclasses > 2) {
                String targetName = fr.name(targetIdx);
                Vec targetVec = fr.vec(targetIdx);
                Frame targetFr = new Frame(new String[]{targetName}, new Vec[]{targetVec});
                Frame oheTarget = (Frame) new FrameUtils.CategoricalOneHotEncoder(targetFr, new String[0]).exec().get();
                Scope.track(new Frame[]{oheTarget});
                // The one-hot columns are appended after the original columns of fr.
                Frame expandedFr = new Frame(fr).add(oheTarget);
                int numTargetCols = oheTarget.numCols();
                AstGroup.AGG[] aggs = new AstGroup.AGG[numTargetCols + 1];
                for (int i = 0; i < numTargetCols; ++i) {
                    int partialTargetIdx = fr.numCols() + i;
                    aggs[i] = new AstGroup.AGG(AstGroup.FCN.sum, partialTargetIdx, AstGroup.NAHandling.ALL, -1);
                }
                aggs[numTargetCols] = new AstGroup.AGG(AstGroup.FCN.nrow, targetIdx, AstGroup.NAHandling.ALL, -1);
                result = new AstGroup().performGroupingWithAggregations(expandedFr, groupBy, aggs).getFrame();
                Scope.track(new Frame[]{result});

                // Strip the "<targetName>." prefix literally. The previous regex-based replaceFirst
                // broke on target names containing regex metacharacters. NOTE(review): assumes the
                // one-hot encoder names its columns "<targetName>.<level>".
                String prefix = targetName + ".";
                String[] targetVals = new String[numTargetCols];
                for (int i = 0; i < numTargetCols; ++i) {
                    String oheCol = oheTarget.name(i);
                    String targetVal = oheCol.startsWith(prefix) ? oheCol.substring(prefix.length()) : oheCol;
                    renameColumn(result, "sum_" + oheCol, targetVal);
                    targetVals[i] = targetVal;
                }
                renameColumn(result, "nrow", DENOMINATOR_COL);

                String[] idVars = foldColumnIdx < 0
                        ? new String[]{fr.name(columnToEncodeIdx), DENOMINATOR_COL}
                        : new String[]{fr.name(columnToEncodeIdx), fr.name(foldColumnIdx), DENOMINATOR_COL};
                // The grouped frame stays tracked and is cleaned up on Scope.exit; the melted frame
                // is tracked too so it is released if anything below fails.
                result = melt(result, idVars, targetVals, TARGETCLASS_COL, NUMERATOR_COL, true);
                Scope.track(new Frame[]{result});
                CategoricalWrappedVec.updateDomain(result.vec(TARGETCLASS_COL), targetVec.domain());
            } else {
                AstGroup.AGG[] aggs = new AstGroup.AGG[]{
                        new AstGroup.AGG(AstGroup.FCN.sum, targetIdx, AstGroup.NAHandling.ALL, -1),
                        new AstGroup.AGG(AstGroup.FCN.nrow, targetIdx, AstGroup.NAHandling.ALL, -1)
                };
                result = new AstGroup().performGroupingWithAggregations(fr, groupBy, aggs).getFrame();
                // Tracked so it is released if a rename below fails; untracked on success.
                Scope.track(new Frame[]{result});
                renameColumn(result, "sum_" + fr.name(targetIdx), NUMERATOR_COL);
                renameColumn(result, "nrow", DENOMINATOR_COL);
            }
            // Hand ownership of the final result to the caller.
            Scope.untrack(new Frame[]{result});
            return result;
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Sums numerators and denominators per category (and per target class when the frame has a
     * {@code targetclass} column), collapsing e.g. per-fold rows into one row per category.
     *
     * @param encodingsFrame encodings frame containing numerator and denominator columns
     * @param teColumnIdx    index of the encoded column in {@code encodingsFrame}
     * @return a new grouped frame with columns renamed back to numerator/denominator
     */
    static Frame groupEncodingsByCategory(Frame encodingsFrame, int teColumnIdx) {
        int numeratorIdx = encodingsFrame.find(NUMERATOR_COL);
        assert numeratorIdx >= 0;
        int denominatorIdx = encodingsFrame.find(DENOMINATOR_COL);
        assert denominatorIdx >= 0;
        int classesIdx = encodingsFrame.find(TARGETCLASS_COL);
        int[] groupBy = classesIdx < 0
                ? new int[]{teColumnIdx}
                : new int[]{teColumnIdx, classesIdx};
        AstGroup.AGG[] aggs = new AstGroup.AGG[]{
                new AstGroup.AGG(AstGroup.FCN.sum, numeratorIdx, AstGroup.NAHandling.ALL, -1),
                new AstGroup.AGG(AstGroup.FCN.sum, denominatorIdx, AstGroup.NAHandling.ALL, -1)
        };
        Frame result = new AstGroup().performGroupingWithAggregations(encodingsFrame, groupBy, aggs).getFrame();
        renameColumn(result, "sum_" + NUMERATOR_COL, NUMERATOR_COL);
        renameColumn(result, "sum_" + DENOMINATOR_COL, DENOMINATOR_COL);
        return result;
    }

    /**
     * Groups encodings by category when they contain folds; otherwise returns a deep copy of the
     * encodings frame under a fresh key name.
     */
    static Frame groupEncodingsByCategory(Frame encodingsFrame, int teColumnIdx, boolean hasFolds) {
        if (hasFolds) {
            return groupEncodingsByCategory(encodingsFrame, teColumnIdx);
        }
        return encodingsFrame.deepCopy(Key.make().toString());
    }

    /**
     * Replaces NAs of a categorical column with a new level appended to its domain.
     *
     * <p>The new level's index is the column's current cardinality. The domain is extended with
     * {@code naCategory} only if at least one NA was actually replaced. NOTE(review): expects a
     * categorical column, because its domain is read and extended. Confirm callers never pass numeric columns.
     *
     * @param data       frame to modify in place
     * @param columnIdx  index of the categorical column
     * @param naCategory label of the level used for NAs
     */
    static void imputeCategoricalColumn(Frame data, int columnIdx, String naCategory) {
        Vec currentVec = data.vec(columnIdx);
        int indexForNACategory = currentVec.cardinality();
        FillNAWithLongValueTask task = new FillNAWithLongValueTask(columnIdx, (long) indexForNACategory);
        task.doAll(data);
        if (task._imputationHappened) {
            String[] oldDomain = currentVec.domain();
            String[] newDomain = new String[indexForNACategory + 1];
            System.arraycopy(oldDomain, 0, newDomain, 0, oldDomain.length);
            newDomain[indexForNACategory] = naCategory;
            updateColumnDomain(data, columnIdx, newDomain);
        }
    }

    /**
     * Sets a new domain on one column of {@code fr} under the frame's write lock and publishes
     * the change. The lock is released even if the update fails.
     */
    private static void updateColumnDomain(Frame fr, int columnIdx, String[] domain) {
        fr.write_lock();
        try {
            Vec updatedVec = fr.vec(columnIdx);
            updatedVec.setDomain(domain);
            DKV.put(updatedVec);
            fr.update();
        } finally {
            fr.unlock();
        }
    }

    /**
     * Returns the distinct values of a column as longs. For a categorical column this is every
     * level index {@code 0..cardinality-1}. See {@link #uniqueValuesBy(Frame, int)}.
     *
     * @throws IllegalStateException if there are more than {@code Integer.MAX_VALUE} unique values
     */
    static long[] getUniqueColumnValues(Frame data, int columnIndex) {
        Vec uniqueValues = uniqueValuesBy(data, columnIndex).vec(0);
        try {
            long numberOfUniqueValues = uniqueValues.length();
            // Explicit check instead of an assert: a silent (int) truncation would corrupt the result.
            if (numberOfUniqueValues > Integer.MAX_VALUE) {
                throw new IllegalStateException("Number of unique values exceeded Integer.MAX_VALUE");
            }
            int length = (int) numberOfUniqueValues;
            long[] uniqueValuesArr = MemoryManager.malloc8(length);
            for (int i = 0; i < length; ++i) {
                uniqueValuesArr[i] = uniqueValues.at8(i);
            }
            return uniqueValuesArr;
        } finally {
            // The temporary Vec is always released, even if reading it fails.
            uniqueValues.remove();
        }
    }

    /**
     * Blends a category's posterior mean with the prior mean.
     *
     * <p>{@code lambda = 1 / (1 + exp((inflectionPoint - n) / smoothing))} and the result is
     * {@code lambda * posterior + (1 - lambda) * prior}, so large categories lean on their own mean.
     *
     * @param posteriorMean           mean of the target within the category
     * @param priorMean               overall prior mean
     * @param numberOfRowsForCategory number of rows {@code n} in the category
     * @param blendingParams          inflection point and smoothing
     * @return the blended value
     */
    static double getBlendedValue(double posteriorMean, double priorMean, long numberOfRowsForCategory, BlendingParams blendingParams) {
        double lambda = 1.0 / (1.0 + Math.exp((blendingParams.getInflectionPoint() - (double) numberOfRowsForCategory) / blendingParams.getSmoothing()));
        return lambda * posteriorMean + (1.0 - lambda) * priorMean;
    }

    /** Joins encodings to {@code leftFrame} on the encoded column only (no folds). */
    static Frame mergeEncodings(Frame leftFrame, Frame encodingsFrame, int leftTEColumnIdx, int encodingsTEColumnIdx) {
        return mergeEncodings(leftFrame, encodingsFrame, leftTEColumnIdx, -1, encodingsTEColumnIdx, -1, 0);
    }

    /**
     * Joins encodings to {@code leftFrame} on the encoded column and, when fold indices are
     * non-negative, on the fold column, delegating to {@link TargetEncoderBroadcastJoin#join}.
     */
    static Frame mergeEncodings(Frame leftFrame, Frame encodingsFrame, int leftTEColumnIdx, int leftFoldColumnIdx, int encodingsTEColumnIdx, int encodingsFoldColumnIdx, int maxFoldValue) {
        return TargetEncoderBroadcastJoin.join(leftFrame, new int[]{leftTEColumnIdx}, leftFoldColumnIdx, encodingsFrame, new int[]{encodingsTEColumnIdx}, encodingsFoldColumnIdx, maxFoldValue);
    }

    /**
     * Adds a column with the encoded values computed from numerator/denominator.
     *
     * <p>Rows with NA numerator or denominator get NA. Rows with a zero denominator get
     * {@code priorMean}. Other rows get {@code numerator / denominator}, blended with the prior
     * when {@code blendingParams} is non-null.
     *
     * @return index of the newly added encoded column
     */
    static int applyEncodings(Frame fr, String newEncodedColumnName, double priorMean, BlendingParams blendingParams) {
        int numeratorIdx = fr.find(NUMERATOR_COL);
        assert numeratorIdx >= 0;
        // NOTE(review): assumes the denominator column immediately follows the numerator column,
        // presumably guaranteed by the join that produced fr. TODO confirm before changing to find().
        int denominatorIdx = numeratorIdx + 1;
        Vec zeroVec = fr.anyVec().makeCon(0.0);
        fr.add(newEncodedColumnName, zeroVec);
        int encodedColumnIdx = fr.numCols() - 1;
        new ApplyEncodings(encodedColumnIdx, numeratorIdx, denominatorIdx, priorMean, blendingParams).doAll(fr);
        return encodedColumnIdx;
    }

    /**
     * Adds uniform noise in {@code [-noiseLevel, noiseLevel)} to the non-NA values of a column.
     *
     * <p>A temporary random column is appended while the task runs. It is always removed from
     * {@code fr} and deleted afterwards, even if the task fails.
     *
     * @param seed random seed; {@code -1} means "draw a random seed"
     */
    static void addNoise(Frame fr, int columnIdx, double noiseLevel, long seed) {
        if (seed == -1L) {
            seed = new Random().nextLong();
        }
        Vec zeroVec = fr.anyVec().makeCon(0.0);
        Vec randomVec = null;
        int runifIdx = -1;
        try {
            randomVec = zeroVec.makeRand(seed);
            fr.add("runIf", randomVec);
            runifIdx = fr.numCols() - 1;
            new AddNoiseTask(columnIdx, runifIdx, noiseLevel).doAll(fr);
        } finally {
            // Detach the temporary column first so fr never references a deleted Vec.
            if (runifIdx >= 0) {
                fr.remove(runifIdx);
            }
            if (randomVec != null) {
                randomVec.remove();
            }
            zeroVec.remove();
        }
    }

    /**
     * Removes each row's own contribution from the numerator and denominator (leave-one-out).
     *
     * @param targetClass {@code -1} to subtract the target value itself; otherwise subtract 1 from
     *                    the numerator only for rows whose target equals this class
     */
    static void subtractTargetValueForLOO(Frame fr, String targetColumn, int targetClass) {
        int numeratorIndex = fr.find(NUMERATOR_COL);
        int denominatorIndex = fr.find(DENOMINATOR_COL);
        int targetIndex = fr.find(targetColumn);
        assert numeratorIndex >= 0;
        assert denominatorIndex >= 0;
        assert targetIndex >= 0;
        new SubtractCurrentRowForLeaveOneOutTask(numeratorIndex, denominatorIndex, targetIndex, targetClass).doAll(fr);
    }

    /**
     * Melts {@code fr} from wide to long format via {@link AstMelt} and registers the result in
     * the DKV under a fresh key.
     */
    static Frame melt(Frame fr, String[] idVars, String[] valueVars, String varCol, String valueCol, boolean skipNA) {
        Frame melted = new AstMelt().exec(new Val[]{null, new ValFrame(fr), new ValStrs(idVars), new ValStrs(valueVars), new ValStr(varCol), new ValStr(valueCol), new ValNum(skipNA ? 1.0 : 0.0)}).getFrame();
        return register(melted);
    }

    /**
     * Row-binds two frames through Rapids. A {@code null} argument is treated as "nothing to
     * append", and the other frame is returned as-is. At least one frame must be non-null.
     */
    static Frame rBind(Frame a, Frame b) {
        if (a == null) {
            assert b != null;
            return b;
        }
        if (b == null) {
            return a;
        }
        String tree = String.format("(rbind %s %s)", a._key, b._key);
        return execRapidsAndGetFrame(tree);
    }

    /** Executes a Rapids expression and registers the resulting frame under a fresh key. */
    private static Frame execRapidsAndGetFrame(String astTree) {
        Val val = Rapids.exec(astTree);
        return register(val.getFrame());
    }

    /**
     * Appends a constant column to {@code fr}.
     *
     * @return index of the new column
     */
    static int addCon(Frame fr, String newColumnName, long constant) {
        Vec constVec = fr.anyVec().makeCon((double) constant);
        fr.add(newColumnName, constVec);
        return fr.numCols() - 1;
    }

    /** Returns a new frame with the rows of {@code fr} whose value in {@code columnIndex} is not NA. */
    static Frame filterOutNAsInColumn(Frame fr, int columnIndex) {
        Frame oneColumnFrame = new Frame(new Vec[]{fr.vec(columnIndex)});
        Frame noNaPredicateFrame = ((IsNotNaTask) new IsNotNaTask().doAll(1, Vec.T_NUM, oneColumnFrame)).outputFrame();
        try {
            return selectByPredicate(fr, noNaPredicateFrame);
        } finally {
            noNaPredicateFrame.delete();
        }
    }

    /** Returns a new frame with the rows of {@code fr} whose column value differs from {@code value}. */
    static Frame filterNotByValue(Frame fr, int columnIndex, double value) {
        return filterByValueBase(fr, columnIndex, value, true);
    }

    /** Returns a new frame with the rows of {@code fr} whose column value equals {@code value}. */
    static Frame filterByValue(Frame fr, int columnIndex, double value) {
        return filterByValueBase(fr, columnIndex, value, false);
    }

    /** Shared implementation of the value filters; the predicate frame is always deleted. */
    private static Frame filterByValueBase(Frame fr, int columnIndex, double value, boolean isInverted) {
        Frame columnFrame = new Frame(new Vec[]{fr.vec(columnIndex)});
        Frame predicateFrame = ((FilterByValueTask) new FilterByValueTask(value, isInverted).doAll(1, Vec.T_NUM, columnFrame)).outputFrame();
        try {
            return selectByPredicate(fr, predicateFrame);
        } finally {
            predicateFrame.delete();
        }
    }

    /**
     * Deep-selects the rows of {@code fr} selected by the single-column {@code predicateFrame}
     * into a new frame with a fresh key, keeping names and domains.
     */
    private static Frame selectByPredicate(Frame fr, Frame predicateFrame) {
        Vec predicate = predicateFrame.anyVec();
        Vec[] vecs = (Vec[]) ArrayUtils.append((Object[]) fr.vecs(), (Object[]) new Vec[]{predicate});
        return ((Frame.DeepSelect) new Frame.DeepSelect().doAll(fr.types(), vecs)).outputFrame(Key.make(), fr._names, fr.domains());
    }

    /**
     * Returns a single-column frame of the distinct values of a column.
     *
     * <p>For a categorical column this is the sequence {@code 0..domain.length-1}, which includes
     * every level whether or not it occurs, with the original domain attached. Otherwise the
     * distinct values are collected with {@link UniqTask}. The returned Frame wrapper is not
     * registered here; callers presumably must remove its Vec when done.
     */
    static Frame uniqueValuesBy(Frame fr, int columnIndex) {
        Vec v;
        Vec vec0 = fr.vec(columnIndex);
        if (vec0.isCategorical()) {
            v = Vec.makeSeq(0L, (long) vec0.domain().length, true);
            v.setDomain(vec0.domain());
            DKV.put(v);
        } else {
            UniqTask t = (UniqTask) new UniqTask().doAll(new Vec[]{vec0});
            int nUniq = t._uniq.size();
            final AstGroup.G[] uniq = t._uniq.keySet().toArray(new AstGroup.G[nUniq]);
            v = Vec.makeZero((long) nUniq, vec0.get_type());
            new MRTask(){

                public void map(Chunk c) {
                    int start = (int) c.start();
                    for (int i = 0; i < c._len; ++i) {
                        c.set(i, uniq[i + start]._gs[0]);
                    }
                }
            }.doAll(new Vec[]{v});
        }
        return new Frame(new Vec[]{v});
    }

    /** Renames the column at {@code colIndex}. */
    static void renameColumn(Frame fr, int colIndex, String newName) {
        String[] newNames = fr.names();
        newNames[colIndex] = newName;
        fr.setNames(newNames);
    }

    /**
     * Renames the column called {@code oldName}.
     *
     * @throws IllegalArgumentException if {@code fr} has no column named {@code oldName}
     */
    static void renameColumn(Frame fr, String oldName, String newName) {
        int colIndex = fr.find(oldName);
        if (colIndex < 0) {
            throw new IllegalArgumentException("Column '" + oldName + "' not found in frame");
        }
        renameColumn(fr, colIndex, newName);
    }

    /** Maps each column name of {@code fr} to its index. */
    static Map<String, Integer> nameToIndex(Frame fr) {
        return nameToIndex(fr.names());
    }

    /** Maps each name to its index in {@code columns}; a later duplicate wins. */
    static Map<String, Integer> nameToIndex(String[] columns) {
        Map<String, Integer> nameToIdx = new HashMap<String, Integer>(columns.length);
        for (int i = 0; i < columns.length; ++i) {
            nameToIdx.put(columns[i], i);
        }
        return nameToIdx;
    }

    /** Assigns a fresh key to {@code frame}, puts it into the DKV and returns it. */
    static Frame register(Frame frame) {
        frame._key = Key.make();
        DKV.put(frame);
        return frame;
    }

    /** Debug helper: prints all rows of {@code fr} to stdout. */
    static void printFrame(Frame fr) {
        TwoDimTable twoDimTable = fr.toTwoDimTable(0L, (int) fr.numRows(), false);
        System.out.println(twoDimTable.toString(2, true));
    }

    /**
     * For each row with a non-NA target, removes that row's contribution from the encodings.
     * With {@code targetClass == -1} the target value is subtracted from the numerator. Otherwise
     * 1 is subtracted only when the target equals the class. The denominator is always decremented.
     */
    private static class SubtractCurrentRowForLeaveOneOutTask
    extends MRTask<SubtractCurrentRowForLeaveOneOutTask> {
        private int _numeratorIdx;
        private int _denominatorIdx;
        private int _targetIdx;
        private int _targetClass;

        public SubtractCurrentRowForLeaveOneOutTask(int numeratorIdx, int denominatorIdx, int targetIdx, int targetClass) {
            this._numeratorIdx = numeratorIdx;
            this._denominatorIdx = denominatorIdx;
            this._targetIdx = targetIdx;
            this._targetClass = targetClass;
        }

        public void map(Chunk[] cs) {
            Chunk num = cs[this._numeratorIdx];
            Chunk den = cs[this._denominatorIdx];
            Chunk target = cs[this._targetIdx];
            for (int i = 0; i < num._len; ++i) {
                if (target.isNA(i)) continue;
                double ti = target.atd(i);
                if (this._targetClass == -1) {
                    num.set(i, num.atd(i) - ti);
                } else if ((double) this._targetClass == ti) {
                    num.set(i, num.atd(i) - 1.0);
                }
                den.set(i, den.atd(i) - 1.0);
            }
        }
    }

    /**
     * Adds {@code (runif * 2 - 1) * noiseLevel} to every non-NA value of a column, where
     * {@code runif} is read from a companion column.
     */
    private static class AddNoiseTask
    extends MRTask<AddNoiseTask> {
        private int _columnIdx;
        private int _runifIdx;
        private double _noiseLevel;

        public AddNoiseTask(int columnIdx, int runifIdx, double noiseLevel) {
            this._columnIdx = columnIdx;
            this._runifIdx = runifIdx;
            this._noiseLevel = noiseLevel;
        }

        public void map(Chunk[] cs) {
            Chunk column = cs[this._columnIdx];
            Chunk runifCol = cs[this._runifIdx];
            for (int i = 0; i < column._len; ++i) {
                if (column.isNA(i)) continue;
                column.set(i, column.atd(i) + (runifCol.atd(i) * 2.0 - 1.0) * this._noiseLevel);
            }
        }
    }

    /**
     * Writes encoded values into the encoded column. NA inputs produce NA. A zero denominator
     * produces the prior mean. Otherwise the value is {@code num / den}, blended with the prior
     * when blending params are set.
     */
    private static class ApplyEncodings
    extends MRTask<ApplyEncodings> {
        private int _encodedColIdx;
        private int _numeratorIdx;
        private int _denominatorIdx;
        private double _priorMean;
        private BlendingParams _blendingParams;

        ApplyEncodings(int encodedColIdx, int numeratorIdx, int denominatorIdx, double priorMean, BlendingParams blendingParams) {
            this._encodedColIdx = encodedColIdx;
            this._numeratorIdx = numeratorIdx;
            this._denominatorIdx = denominatorIdx;
            this._priorMean = priorMean;
            this._blendingParams = blendingParams;
        }

        public void map(Chunk[] cs) {
            Chunk num = cs[this._numeratorIdx];
            Chunk den = cs[this._denominatorIdx];
            Chunk encoded = cs[this._encodedColIdx];
            boolean useBlending = this._blendingParams != null;
            for (int i = 0; i < num._len; ++i) {
                if (num.isNA(i) || den.isNA(i)) {
                    encoded.setNA(i);
                    continue;
                }
                if (den.at8(i) == 0L) {
                    if (logger.isDebugEnabled()) {
                        logger.debug("Denominator is zero for column index = " + this._encodedColIdx + ". Imputing with _priorMean = " + this._priorMean);
                    }
                    encoded.set(i, this._priorMean);
                    continue;
                }
                double posteriorMean = num.atd(i) / den.atd(i);
                double encodedValue;
                if (useBlending) {
                    // The denominator is the number of rows in the category.
                    long numberOfRowsInCurrentCategory = den.at8(i);
                    encodedValue = getBlendedValue(posteriorMean, this._priorMean, numberOfRowsInCurrentCategory, this._blendingParams);
                } else {
                    encodedValue = posteriorMean;
                }
                encoded.set(i, encodedValue);
            }
        }
    }
}

