/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.TargetEncoder;
import water.MRTask;
import water.MemoryManager;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;

class BroadcastJoinForTargetEncoder {
    BroadcastJoinForTargetEncoder() {
    }

    static Frame join(Frame leftFrame, int[] leftCatColumnsIdxs, int leftFoldColumnIdx, Frame broadcastedFrame, int[] rightCatColumnsIdxs, int rightFoldColumnIdx, int maxFoldValue) {
        int numeratorIdx = broadcastedFrame.find(TargetEncoder.NUMERATOR_COL_NAME);
        int denominatorIdx = broadcastedFrame.find(TargetEncoder.DENOMINATOR_COL_NAME);
        int broadcastedFrameCatCardinality = broadcastedFrame.vec(rightCatColumnsIdxs[0]).cardinality();
        if (rightFoldColumnIdx != -1 && broadcastedFrame.vec(rightFoldColumnIdx).max() > 2.147483647E9) {
            throw new IllegalArgumentException("Fold value should be a non-negative integer (i.e. should belong to [0, Integer.MAX_VALUE] range)");
        }
        int[][] levelMappings = new int[][]{CategoricalWrappedVec.computeMap((String[])leftFrame.vec(leftCatColumnsIdxs[0]).domain(), (String[])broadcastedFrame.vec(0).domain())};
        int[][] encodingDataArray = ((FrameWithEncodingDataToArray)new FrameWithEncodingDataToArray(rightCatColumnsIdxs[0], rightFoldColumnIdx, numeratorIdx, denominatorIdx, broadcastedFrameCatCardinality, maxFoldValue).doAll(broadcastedFrame)).getEncodingDataArray();
        new BroadcastJoiner(leftCatColumnsIdxs, leftFoldColumnIdx, encodingDataArray, levelMappings, broadcastedFrameCatCardinality).doAll(leftFrame);
        return leftFrame;
    }

    static class BroadcastJoiner
    extends MRTask<BroadcastJoiner> {
        int _categoricalColumnIdx;
        int _foldColumnIdx;
        int _cardinalityOfCatCol;
        int[][] _encodingDataArray;
        int[][] _levelMappings;

        BroadcastJoiner(int[] categoricalColumnsIdxs, int foldColumnIdx, int[][] encodingDataArray, int[][] levelMappings, int cardinalityOfCatCol) {
            assert (categoricalColumnsIdxs.length == 1) : "Only single column target encoding(i.e. one categorical column is used to produce its encodings) is supported for now";
            this._categoricalColumnIdx = categoricalColumnsIdxs[0];
            this._foldColumnIdx = foldColumnIdx;
            this._encodingDataArray = encodingDataArray;
            this._levelMappings = levelMappings;
            this._cardinalityOfCatCol = cardinalityOfCatCol;
        }

        public void map(Chunk[] cs) {
            Chunk categoricalChunk = cs[this._categoricalColumnIdx];
            int numOfVecs = cs.length;
            Chunk num = cs[numOfVecs - 2];
            Chunk den = cs[numOfVecs - 1];
            for (int i = 0; i < num.len(); ++i) {
                int levelValue = (int)categoricalChunk.at8(i);
                int mappedLevelValue = -1;
                int numberOfLevelValues = this._levelMappings[0].length;
                if (levelValue >= numberOfLevelValues) {
                    this.setEncodingComponentsToNAs(num, den, i);
                    continue;
                }
                mappedLevelValue = this._levelMappings[0][levelValue];
                int[] arrForNumeratorsAndDenominators = null;
                int foldValue = -1;
                if (this._foldColumnIdx != -1) {
                    long foldValueFromVec = cs[this._foldColumnIdx].at8(i);
                    foldValue = (int)foldValueFromVec;
                } else {
                    foldValue = 0;
                }
                arrForNumeratorsAndDenominators = this._encodingDataArray[foldValue];
                if (mappedLevelValue >= this._cardinalityOfCatCol) {
                    this.setEncodingComponentsToNAs(num, den, i);
                    continue;
                }
                int denominator = arrForNumeratorsAndDenominators[this._cardinalityOfCatCol + mappedLevelValue];
                if (denominator == 0) {
                    this.setEncodingComponentsToNAs(num, den, i);
                    continue;
                }
                int numerator = arrForNumeratorsAndDenominators[mappedLevelValue];
                num.set(i, (long)numerator);
                den.set(i, (long)denominator);
            }
        }

        private void setEncodingComponentsToNAs(Chunk num, Chunk den, int i) {
            num.setNA(i);
            den.setNA(i);
        }
    }

    static class FrameWithEncodingDataToArray
    extends MRTask<FrameWithEncodingDataToArray> {
        int[][] _encodingDataPerNode = null;
        int _categoricalColumnIdx;
        int _foldColumnIdx;
        int _numeratorIdx;
        int _denominatorIdx;
        int _cardinalityOfCatCol;
        int _maxFoldValue;

        FrameWithEncodingDataToArray(int categoricalColumnIdx, int foldColumnId, int numeratorIdx, int denominatorIdx, int cardinalityOfCatCol, int maxFoldValue) {
            this._categoricalColumnIdx = categoricalColumnIdx;
            this._foldColumnIdx = foldColumnId;
            this._numeratorIdx = numeratorIdx;
            this._denominatorIdx = denominatorIdx;
            this._cardinalityOfCatCol = cardinalityOfCatCol;
            if (foldColumnId == -1) {
                this._encodingDataPerNode = MemoryManager.malloc4((int)1, (int)(this._cardinalityOfCatCol * 2));
            } else {
                assert (maxFoldValue >= 1) : "There should be at leas two folds in the fold column";
                assert (this._cardinalityOfCatCol > 0 && this._cardinalityOfCatCol < 0x3FFFFFFF) : "Cardinality of categ. column should be within range (0, Integer.MAX_VALUE / 2 )";
                this._encodingDataPerNode = MemoryManager.malloc4((int)(maxFoldValue + 1), (int)(this._cardinalityOfCatCol * 2));
            }
        }

        public void map(Chunk[] cs) {
            Chunk categoricalChunk = cs[this._categoricalColumnIdx];
            Chunk numeratorChunk = cs[this._numeratorIdx];
            Chunk denominatorChunk = cs[this._denominatorIdx];
            for (int i = 0; i < categoricalChunk.len(); ++i) {
                int levelValue = (int)categoricalChunk.at8(i);
                int foldValue = this._foldColumnIdx != -1 ? (int)cs[this._foldColumnIdx].at8(i) : 0;
                int[] arrForNumeratorsAndDenominators = this._encodingDataPerNode[foldValue];
                arrForNumeratorsAndDenominators[levelValue] = (int)numeratorChunk.at8(i);
                arrForNumeratorsAndDenominators[this._cardinalityOfCatCol + levelValue] = (int)denominatorChunk.at8(i);
            }
        }

        public void reduce(FrameWithEncodingDataToArray mrt) {
            int[][] rightArr;
            int[][] leftArr = this.getEncodingDataArray();
            if (leftArr != (rightArr = mrt.getEncodingDataArray())) {
                for (int rowIdx = 0; rowIdx < leftArr.length; ++rowIdx) {
                    for (int colIdx = 0; colIdx < leftArr[rowIdx].length; ++colIdx) {
                        int valueFromLeftArr = leftArr[rowIdx][colIdx];
                        int valueFromRIghtArr = rightArr[rowIdx][colIdx];
                        leftArr[rowIdx][colIdx] = Math.max(valueFromLeftArr, valueFromRIghtArr);
                    }
                }
            }
        }

        int[][] getEncodingDataArray() {
            return this._encodingDataPerNode;
        }
    }
}

