/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import water.MRTask;
import water.MemoryManager;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;

/**
 * Joins pre-computed target-encoding components ("numerator" and "denominator") onto a frame by
 * first collecting them into an in-memory array and then looking each row up in that array,
 * instead of performing a distributed merge. Only a single categorical column is supported.
 */
class TargetEncoderBroadcastJoin {
    TargetEncoderBroadcastJoin() {
    }

    /**
     * Appends "numerator" and "denominator" columns to a new frame built from {@code leftFrame}'s
     * columns, filled with the per-level (and per-fold, when a fold column is given) values found in
     * {@code rightFrame}.
     *
     * <p>A row gets NA in both appended columns when its categorical value is NA, when its level
     * does not exist in the right frame's domain, or when the matching denominator is 0.
     *
     * @param leftFrame           frame to enrich; its categorical levels are mapped onto the right
     *                            frame's domain via {@code CategoricalWrappedVec.computeMap}
     * @param leftCatColumnsIdxs  index of the categorical column in {@code leftFrame} (exactly one element)
     * @param leftFoldColumnIdx   index of the fold column in {@code leftFrame}, or a negative value for none;
     *                            its values are used as indices into the per-fold table, so they are
     *                            expected to lie in [0, maxFoldValue]
     * @param rightFrame          encodings frame; must contain "numerator" and "denominator" columns
     * @param rightCatColumnsIdxs index of the categorical column in {@code rightFrame} (exactly one element)
     * @param rightFoldColumnIdx  index of the fold column in {@code rightFrame}, or -1 for none
     * @param maxFoldValue        largest fold value; sizes the per-fold table when a fold column is given
     * @return a new frame holding leftFrame's columns followed by "numerator" and "denominator";
     *         the two appended vecs are created by this method
     * @throws IllegalArgumentException if {@code rightFrame} lacks the numerator/denominator columns,
     *                                  or its fold column has values outside [0, Integer.MAX_VALUE]
     */
    static Frame join(Frame leftFrame, int[] leftCatColumnsIdxs, int leftFoldColumnIdx, Frame rightFrame, int[] rightCatColumnsIdxs, int rightFoldColumnIdx, int maxFoldValue) {
        int rightNumeratorIdx = rightFrame.find("numerator");
        int rightDenominatorIdx = rightFrame.find("denominator");
        if (rightNumeratorIdx < 0 || rightDenominatorIdx < 0) {
            // Without this check the MRTask below would index chunk arrays with -1.
            throw new IllegalArgumentException("Encodings frame should contain 'numerator' and 'denominator' columns");
        }
        assert leftCatColumnsIdxs.length == 1;
        assert rightCatColumnsIdxs.length == 1;
        int leftCatColumnIdx = leftCatColumnsIdxs[0];
        int rightCatColumnIdx = rightCatColumnsIdxs[0];
        int rightCatCardinality = rightFrame.vec(rightCatColumnIdx).cardinality();

        if (rightFoldColumnIdx != -1) {
            // Fold values become array indices (after an int cast), so they must be non-negative
            // and fit into an int.
            double minFold = rightFrame.vec(rightFoldColumnIdx).min();
            double maxFold = rightFrame.vec(rightFoldColumnIdx).max();
            if (minFold < 0 || maxFold > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Fold value should be a non-negative integer (i.e. should belong to [0, Integer.MAX_VALUE] range)");
            }
        }

        int[][] levelMappings = {
                CategoricalWrappedVec.computeMap(leftFrame.vec(leftCatColumnIdx).domain(), rightFrame.vec(rightCatColumnIdx).domain())
        };
        double[][] encodingData = encodingsToArray(rightFrame, rightCatColumnIdx, rightFoldColumnIdx, rightNumeratorIdx, rightDenominatorIdx, rightCatCardinality, maxFoldValue);

        Frame resultFrame = new Frame(leftFrame);
        // Placeholder columns; BroadcastJoiner writes (or NA-s) every row of both of them.
        resultFrame.add("numerator", resultFrame.anyVec().makeCon(0.0));
        resultFrame.add("denominator", resultFrame.anyVec().makeCon(0.0));
        new BroadcastJoiner(leftCatColumnsIdxs, leftFoldColumnIdx, encodingData, levelMappings, rightCatCardinality - 1).doAll(resultFrame);
        return resultFrame;
    }

    /**
     * Collects the numerator/denominator columns of an encodings frame into a dense array.
     *
     * @return array where {@code [fold][2 * level]} holds the numerator and {@code [fold][2 * level + 1]}
     *         the denominator; a single row (fold 0) is used when {@code foldColIdx == -1}.
     *         Cells of (fold, level) pairs absent from the frame stay 0.0.
     */
    static double[][] encodingsToArray(Frame encodingsFrame, int categoricalColIdx, int foldColIdx, int numColIdx, int denColIdx, int categoricalColCardinality, int maxFoldValue) {
        return new FrameWithEncodingDataToArray(categoricalColIdx, foldColIdx, numColIdx, denColIdx, categoricalColCardinality, maxFoldValue)
                .doAll(encodingsFrame)
                .getEncodingDataArray();
    }

    /**
     * Fills the last two columns of the processed frame (numerator, denominator) by looking up each
     * row's (fold, mapped level) in the broadcast encoding array.
     */
    private static class BroadcastJoiner
    extends MRTask<BroadcastJoiner> {
        final int _categoricalColumnIdx;
        final int _foldColumnIdx;
        /** Highest level index present in the encodings frame's domain; larger mapped levels are unseen. */
        final int _maxKnownCatLevel;
        final double[][] _encodingDataArray;
        final int[][] _levelMappings;

        BroadcastJoiner(int[] categoricalColumnsIdxs, int foldColumnIdx, double[][] encodingDataArray, int[][] levelMappings, int maxKnownCatLevel) {
            assert (categoricalColumnsIdxs.length == 1) : "Only single column target encoding (i.e. one categorical column is used to produce its encodings) is supported for now";
            assert (levelMappings.length == 1);
            this._categoricalColumnIdx = categoricalColumnsIdxs[0];
            this._foldColumnIdx = foldColumnIdx;
            this._encodingDataArray = encodingDataArray;
            this._levelMappings = levelMappings;
            this._maxKnownCatLevel = maxKnownCatLevel;
        }

        @Override
        public void map(Chunk[] cs) {
            int[] levelMapping = _levelMappings[0];
            Chunk categoricalChunk = cs[_categoricalColumnIdx];
            // join() appends "numerator" and then "denominator" as the last two columns.
            Chunk num = cs[cs.length - 2];
            Chunk den = cs[cs.length - 1];
            for (int row = 0; row < num.len(); ++row) {
                // at8 cannot read a missing value, so NA levels are handled before it is called.
                if (categoricalChunk.isNA(row)) {
                    setEncodingComponentsToNAs(num, den, row);
                    continue;
                }
                int level = (int) categoricalChunk.at8(row);
                if (level >= levelMapping.length) {
                    setEncodingComponentsToNAs(num, den, row);
                    continue;
                }
                int mappedLevel = levelMapping[level];
                int foldValue = _foldColumnIdx >= 0 ? (int) cs[_foldColumnIdx].at8(row) : 0;
                double[] numDenArray = _encodingDataArray[foldValue];
                if (mappedLevel > _maxKnownCatLevel) {
                    // Level only exists in the left frame's domain: no encoding is available.
                    setEncodingComponentsToNAs(num, den, row);
                    continue;
                }
                double denominator = numDenArray[2 * mappedLevel + 1];
                if (denominator == 0.0) {
                    // 0.0 is also the value of cells never written by FrameWithEncodingDataToArray.
                    setEncodingComponentsToNAs(num, den, row);
                    continue;
                }
                num.set(row, numDenArray[2 * mappedLevel]);
                den.set(row, denominator);
            }
        }

        /** Marks both encoding components of the given row as missing. */
        private void setEncodingComponentsToNAs(Chunk num, Chunk den, int row) {
            num.setNA(row);
            den.setNA(row);
        }
    }

    /**
     * Copies the numerator/denominator values of an encodings frame into a dense per-fold array.
     */
    private static class FrameWithEncodingDataToArray
    extends MRTask<FrameWithEncodingDataToArray> {
        /**
         * Row index = fold value (a single row when there is no fold column); within a row the
         * numerator of level {@code l} is at {@code 2 * l} and its denominator at {@code 2 * l + 1}.
         */
        final double[][] _encodingDataPerNode;
        final int _categoricalColumnIdx;
        final int _foldColumnIdx;
        final int _numeratorIdx;
        final int _denominatorIdx;
        final int _cardinalityOfCatCol;

        FrameWithEncodingDataToArray(int categoricalColumnIdx, int foldColumnIdx, int numeratorIdx, int denominatorIdx, int cardinalityOfCatCol, int maxFoldValue) {
            this._categoricalColumnIdx = categoricalColumnIdx;
            this._foldColumnIdx = foldColumnIdx;
            this._numeratorIdx = numeratorIdx;
            this._denominatorIdx = denominatorIdx;
            this._cardinalityOfCatCol = cardinalityOfCatCol;
            if (foldColumnIdx == -1) {
                this._encodingDataPerNode = MemoryManager.malloc8d(1, this._cardinalityOfCatCol * 2);
            } else {
                assert (maxFoldValue >= 1) : "There should be at least two folds in the fold column";
                // Two cells per level, so the cardinality must stay below half the int range.
                assert (this._cardinalityOfCatCol > 0 && this._cardinalityOfCatCol < Integer.MAX_VALUE / 2) : "Cardinality of categ. column should be within range (0, Integer.MAX_VALUE / 2 )";
                this._encodingDataPerNode = MemoryManager.malloc8d(maxFoldValue + 1, this._cardinalityOfCatCol * 2);
            }
        }

        @Override
        public void map(Chunk[] cs) {
            Chunk categoricalChunk = cs[_categoricalColumnIdx];
            Chunk numeratorChunk = cs[_numeratorIdx];
            Chunk denominatorChunk = cs[_denominatorIdx];
            Chunk foldChunk = _foldColumnIdx != -1 ? cs[_foldColumnIdx] : null;
            for (int row = 0; row < categoricalChunk.len(); ++row) {
                int level = (int) categoricalChunk.at8(row);
                int foldValue = foldChunk != null ? (int) foldChunk.at8(row) : 0;
                double[] numDenArray = _encodingDataPerNode[foldValue];
                numDenArray[2 * level] = numeratorChunk.atd(row);
                // atd (not at8, which yields an integer) so fractional denominators, if any, are kept.
                numDenArray[2 * level + 1] = denominatorChunk.atd(row);
            }
        }

        @Override
        public void reduce(FrameWithEncodingDataToArray mrt) {
            double[][] leftArr = getEncodingDataArray();
            double[][] rightArr = mrt.getEncodingDataArray();
            // Tasks whose array is the very same object already share every write; merging would be a no-op.
            if (leftArr == rightArr) {
                return;
            }
            for (int rowIdx = 0; rowIdx < leftArr.length; ++rowIdx) {
                double[] leftRow = leftArr[rowIdx];
                double[] rightRow = rightArr[rowIdx];
                for (int colIdx = 0; colIdx < leftRow.length; ++colIdx) {
                    leftRow[colIdx] = mergeCell(leftRow[colIdx], rightRow[colIdx]);
                }
            }
        }

        /**
         * Combines one cell of two partial arrays. Unwritten cells hold the initial 0.0, so when only
         * one side is non-zero that side carries the value. Taking it directly (rather than
         * {@code Math.max}) preserves negative numerators, which {@code max(negative, 0.0)} would drop.
         * When both sides are non-zero the previous max rule is kept.
         */
        private static double mergeCell(double left, double right) {
            if (left == 0.0) {
                return right;
            }
            if (right == 0.0) {
                return left;
            }
            return Math.max(left, right);
        }

        double[][] getEncodingDataArray() {
            return _encodingDataPerNode;
        }
    }
}

