/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.targetencoder.EncodingMap;
import hex.genmodel.algos.targetencoder.EncodingMaps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * MOJO scoring model for target encoding.
 *
 * <p>For every column that has an entry in {@link #_targetEncodingMap},
 * {@link #score0(double[], double[])} looks up the row's categorical level and writes its
 * target-mean encoding (optionally blended with {@link #_priorMean}) into the output array.
 */
public class TargetEncoderMojoModel
extends MojoModel {
    /** Per-column encoding maps; must be set before {@link #score0(double[], double[])} is called. */
    public EncodingMaps _targetEncodingMap;
    /**
     * Maps a column name to its index in the input row. The constructor registers every column
     * except the last one (presumably the response column -- TODO confirm against the MOJO writer).
     */
    public Map<String, Integer> _teColumnNameToIdx;
    /**
     * Per encoded column: {@code 1} when the column's encoding map contains an NA level, which is
     * looked up under key {@code size() - 1}. Any other value, or no entry, means there is no NA
     * level and missing values are encoded with {@link #_priorMean}.
     */
    public Map<String, Integer> _teColumnNameToMissingValuesPresence;
    /** Whether posterior means are blended with {@link #_priorMean}. */
    public boolean _withBlending;
    /** Blending inflection point: the category row count at which the blending weight is 0.5. */
    public double _inflectionPoint;
    /** Blending smoothing: scales how steeply the blending weight changes around the inflection point. */
    public double _smoothing;
    /** Prior target mean, used for blending and as the fallback encoding. */
    public double _priorMean;
    /** When true, levels without an entry in the encoding map are encoded with {@link #_priorMean}. */
    private final boolean _imputationOfUnknownLevelsIsEnabled = true;

    /**
     * @param columns      model column names; every column except the last is registered in
     *                     {@link #_teColumnNameToIdx} with its position as the row index
     * @param domains      passed through to the {@link MojoModel} constructor
     * @param responseName passed through to the {@link MojoModel} constructor
     */
    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseName) {
        super(columns, domains, responseName);
        this._teColumnNameToIdx = new HashMap<String, Integer>(columns.length);
        // The last column is intentionally skipped (presumably the response; TODO confirm).
        for (int i = 0; i < columns.length - 1; ++i) {
            this._teColumnNameToIdx.put(columns[i], i);
        }
    }

    /**
     * Computes the blending weight {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}.
     *
     * <p>The weight is 0.5 when {@code nrows == inflectionPoint} and approaches 1 as {@code nrows}
     * grows past it.
     *
     * @param nrows           number of rows in the category
     * @param inflectionPoint row count at which the weight is 0.5
     * @param smoothing       divisor controlling how steep the transition is
     * @return the weight given to the posterior mean
     */
    public static double computeLambda(int nrows, double inflectionPoint, double smoothing) {
        return 1.0 / (1.0 + Math.exp((inflectionPoint - (double)nrows) / smoothing));
    }

    /**
     * Returns {@code lambda * posteriorMean + (1 - lambda) * priorMean}.
     *
     * @param lambda        weight of the posterior mean, typically from {@link #computeLambda}
     * @param posteriorMean the category's target mean
     * @param priorMean     the prior target mean
     * @return the blended encoding
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return lambda * posteriorMean + (1.0 - lambda) * priorMean;
    }

    /**
     * Encodes the target-encoded columns of {@code row}.
     *
     * <p>Encoded columns are processed in ascending order of their row index, and the i-th
     * processed column's encoding is written to {@code preds[i]}. A missing value ({@code NaN})
     * uses the column's NA level when one is present, and {@link #_priorMean} otherwise. A level
     * with no entry in the encoding map is imputed with {@link #_priorMean}.
     *
     * @param row   input row; each encoded column holds a categorical level index or {@code NaN}
     * @param preds output array; needs one slot per encoded column
     * @return {@code preds}
     * @throws IllegalStateException if {@link #_targetEncodingMap} has not been set
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (this._targetEncodingMap == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        int predictionIndex = 0;
        // Do not rely on the encoding map's own iteration order: sort entries by column index so
        // the output layout follows the row layout.
        LinkedHashMap<String, EncodingMap> sortedByColumnIndex = this.sortByColumnIndex(this._targetEncodingMap.encodingMap());
        for (Map.Entry<String, EncodingMap> columnToEncodingsMap : sortedByColumnIndex.entrySet()) {
            EncodingMap encodings = columnToEncodingsMap.getValue();
            String teColumn = columnToEncodingsMap.getKey();
            int indexOfColumnInRow = this._teColumnNameToIdx.get(teColumn);
            double categoricalLevelIndex = row[indexOfColumnInRow];
            if (Double.isNaN(categoricalLevelIndex)) {
                // A missing presence entry is treated as "no NA level" instead of failing on unboxing.
                Integer naLevelPresence = this._teColumnNameToMissingValuesPresence.get(teColumn);
                if (naLevelPresence != null && naLevelPresence.intValue() == 1) {
                    // The NA level is stored under key size() - 1 of the column's encoding map.
                    int indexOfNALevel = encodings._encodingMap.size() - 1;
                    this.computeEncodings(preds, predictionIndex, encodings, indexOfNALevel);
                } else {
                    preds[predictionIndex] = this._priorMean;
                }
            } else {
                // Categorical levels are represented by their integer level index.
                int categoricalLevelIndexAsInt = (int)categoricalLevelIndex;
                this.computeEncodings(preds, predictionIndex, encodings, categoricalLevelIndexAsInt);
            }
            ++predictionIndex;
        }
        return preds;
    }

    /**
     * Writes the encoding of one categorical level into {@code preds[predictionIndex]}.
     *
     * <p>A level's entry holds a numerator at index 0 and a denominator (the category's row count)
     * at index 1. The posterior mean is their ratio. With blending enabled, the posterior is mixed
     * with {@link #_priorMean} using {@link #computeLambda}. A level with no entry (e.g. unseen when
     * the encoding map was built) is imputed with {@link #_priorMean} when unknown-level imputation
     * is enabled.
     *
     * @throws IllegalStateException if the level has no entry and imputation is disabled
     */
    private void computeEncodings(double[] preds, int predictionIndex, EncodingMap encodings, int originalValueAsInt) {
        int[] correspondingNumAndDen = encodings._encodingMap.get(originalValueAsInt);
        if (correspondingNumAndDen == null) {
            if (this._imputationOfUnknownLevelsIsEnabled) {
                preds[predictionIndex] = this._priorMean;
                return;
            }
            throw new IllegalStateException("No encoding found for categorical level " + originalValueAsInt + ".");
        }
        double posteriorMean = (double)correspondingNumAndDen[0] / (double)correspondingNumAndDen[1];
        if (this._withBlending) {
            int numberOfRowsInCurrentCategory = correspondingNumAndDen[1];
            double lambda = TargetEncoderMojoModel.computeLambda(numberOfRowsInCurrentCategory, this._inflectionPoint, this._smoothing);
            preds[predictionIndex] = TargetEncoderMojoModel.computeBlendedEncoding(lambda, posteriorMean, this._priorMean);
        } else {
            preds[predictionIndex] = posteriorMean;
        }
    }

    /**
     * Returns a copy of {@code map} that iterates in ascending order of each key's index in
     * {@link #_teColumnNameToIdx}.
     *
     * <p>Keys are expected to be column-name Strings registered in {@link #_teColumnNameToIdx}. An
     * unregistered key makes the comparison throw {@link NullPointerException}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    <K, V> LinkedHashMap<K, V> sortByColumnIndex(Map<K, V> map) {
        ArrayList<Map.Entry<K, V>> list = new ArrayList<Map.Entry<K, V>>(map.entrySet());
        // SortByKeyAssociatedIndex requires K extends String, which this signature does not
        // declare, hence the raw comparator.
        Collections.sort(list, new SortByKeyAssociatedIndex(this._teColumnNameToIdx));
        LinkedHashMap<K, V> result = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : list) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /**
     * Orders map entries by the index associated with their key in a column-name-to-index map.
     * Entries whose key has no associated index cause a {@link NullPointerException}.
     */
    public static class SortByKeyAssociatedIndex<K extends String, V>
    implements Comparator<Map.Entry<K, V>> {
        /** Column name to index lookup used for ordering. */
        public Map<String, Integer> _teColumnNameToIdx;

        public SortByKeyAssociatedIndex(Map<String, Integer> teColumnNameToIdx) {
            this._teColumnNameToIdx = teColumnNameToIdx;
        }

        @Override
        public int compare(Map.Entry<K, V> o1, Map.Entry<K, V> o2) {
            Integer keyLeftIdx = this._teColumnNameToIdx.get(o1.getKey());
            Integer keyRightIdx = this._teColumnNameToIdx.get(o2.getKey());
            return keyLeftIdx.compareTo(keyRightIdx);
        }
    }
}

