/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.targetencoder.EncodingMap;
import hex.genmodel.algos.targetencoder.EncodingMaps;
import java.io.Serializable;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * MOJO scoring model for target encoding.
 *
 * <p>For every column that has an {@link EncodingMap}, {@link #score0(double[], double[])} replaces the
 * row's categorical level with a value derived from the stored numerator/denominator statistics
 * (the posterior mean), optionally blended with the prior mean. Multiclass responses produce one
 * encoded value per target class except class 0.
 */
public class TargetEncoderMojoModel
extends MojoModel {
    /** Column name -> index into the input row, built from the model's column array. */
    public final Map<String, Integer> _columnNameToIdx;
    /**
     * Per encoded column flag: {@code true} means an NA value is encoded through the map's NA
     * category; anything else (including a missing entry) falls back to the prior mean.
     */
    public Map<String, Boolean> _teColumn2HasNAs;
    /** Whether posterior means are blended with the prior mean using {@link #computeLambda}. */
    public boolean _withBlending;
    /** Inflection point of the blending sigmoid (see {@link #computeLambda}). */
    public double _inflectionPoint;
    /** Smoothing (steepness) of the blending sigmoid (see {@link #computeLambda}). */
    public double _smoothing;
    List<String> _nonPredictors;
    /** Encodings keyed by column name; ordered by input-column index when set via {@link #setEncodings}. */
    Map<String, EncodingMap> _encodingsByCol;
    boolean _keepOriginalCategoricalColumns;
    /** When {@code true}, levels without stored statistics are encoded with the prior mean rather than NaN. */
    private final boolean _imputeUnknownLevels = true;

    /**
     * Computes the blending weight {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}.
     * It equals 0.5 when {@code nrows == inflectionPoint} and tends to 1 as {@code nrows} grows.
     *
     * @param nrows           number of rows observed for the level
     * @param inflectionPoint row count at which the weight is 0.5
     * @param smoothing       controls how steeply the weight moves between 0 and 1
     * @return the weight given to the posterior mean
     */
    public static double computeLambda(long nrows, double inflectionPoint, double smoothing) {
        return 1.0 / (1.0 + Math.exp((inflectionPoint - (double)nrows) / smoothing));
    }

    /**
     * Linearly blends posterior and prior means: {@code lambda * posterior + (1 - lambda) * prior}.
     *
     * @param lambda        weight of the posterior mean
     * @param posteriorMean per-level mean
     * @param priorMean     overall mean
     * @return the blended encoding
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return lambda * posteriorMean + (1.0 - lambda) * priorMean;
    }

    /**
     * @param columns      model column names; their positions define the input-row indices
     * @param domains      categorical domains, passed through to {@link MojoModel}
     * @param responseName name of the response column, passed through to {@link MojoModel}
     */
    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseName) {
        super(columns, domains, responseName);
        this._columnNameToIdx = TargetEncoderMojoModel.name2Idx(columns);
    }

    /**
     * Builds a name -> position lookup for {@code columns}. A duplicated name maps to its last position.
     */
    static Map<String, Integer> name2Idx(String[] columns) {
        HashMap<String, Integer> nameToIdx = new HashMap<String, Integer>(columns.length);
        for (int i = 0; i < columns.length; ++i) {
            nameToIdx.put(columns[i], i);
        }
        return nameToIdx;
    }

    /** Installs the encodings, re-ordered so that scoring walks columns in input-row order. */
    protected void setEncodings(EncodingMaps encodingMaps) {
        this._encodingsByCol = this.sortByColumnIndex(encodingMaps);
    }

    /** Number of output values: encoded columns times values per column, or 0 before encodings are set. */
    @Override
    public int getPredsSize() {
        return this._encodingsByCol == null ? 0 : this._encodingsByCol.size() * this.getNumEncColsPerPredictor();
    }

    /** Values produced per encoded column: {@code nclasses() - 1} for multiclass, otherwise 1. */
    int getNumEncColsPerPredictor() {
        return this.nclasses() > 1 ? this.nclasses() - 1 : 1;
    }

    /**
     * Encodes every target-encoded column of {@code row} into consecutive slots of {@code preds}.
     * NaN values go through {@link #encodeNA}; other values are truncated to an {@code int} level.
     *
     * @param row   input row, indexed by {@link #_columnNameToIdx}
     * @param preds output buffer, at least {@link #getPredsSize()} long
     * @return {@code preds}
     * @throws IllegalStateException if no encodings are set or an encoded column is not a model column
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (this._encodingsByCol == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        int predsIdx = 0;
        for (Map.Entry<String, EncodingMap> columnToEncodings : this._encodingsByCol.entrySet()) {
            String teColumn = columnToEncodings.getKey();
            EncodingMap encodings = columnToEncodings.getValue();
            // Look up as a boxed Integer so an unknown column yields a descriptive error, not an unboxing NPE.
            Integer colIdx = this._columnNameToIdx.get(teColumn);
            if (colIdx == null) {
                throw new IllegalStateException("Encoded column '" + teColumn + "' is not a column of this model.");
            }
            double category = row[colIdx];
            predsIdx += Double.isNaN(category)
                    ? this.encodeNA(preds, predsIdx, encodings, teColumn)
                    : this.encodeCategory(preds, predsIdx, encodings, (int)category);
        }
        return preds;
    }

    /**
     * @return the encoding map for {@code column}, or {@code null} if the column is not encoded
     *         or no encodings have been set yet
     */
    public EncodingMap getEncodings(String column) {
        return this._encodingsByCol == null ? null : this._encodingsByCol.get(column);
    }

    /**
     * Turns a {@code [numerator, denominator]} pair into an encoded value: the posterior mean
     * {@code num / den}, blended with {@code priorMean} when blending is enabled. The denominator
     * doubles as the row count fed to {@link #computeLambda}.
     */
    private double computeEncodedValue(double[] numDen, double priorMean) {
        if (numDen == null) {
            // No statistics for this level (presumably unseen during training — confirm against
            // EncodingMap#getNumDen). Impute instead of failing with an NPE on numDen[0] below.
            return this._imputeUnknownLevels ? priorMean : Double.NaN;
        }
        double posteriorMean = numDen[0] / numDen[1];
        if (this._withBlending) {
            long nrows = (long)numDen[1];
            double lambda = TargetEncoderMojoModel.computeLambda(nrows, this._inflectionPoint, this._smoothing);
            return TargetEncoderMojoModel.computeBlendedEncoding(lambda, posteriorMean, priorMean);
        }
        return posteriorMean;
    }

    /**
     * Writes the encoding of level {@code category} starting at {@code result[startIdx]}.
     * For more than two classes, one value per target class {@code 1 .. nclasses()-1} is written
     * (class 0 is skipped); otherwise a single value is written.
     *
     * @return number of slots written
     */
    int encodeCategory(double[] result, int startIdx, EncodingMap encodings, int category) {
        if (this.nclasses() > 2) {
            for (int i = 0; i < this.nclasses() - 1; ++i) {
                int targetClass = i + 1;
                double[] numDen = encodings.getNumDen(category, targetClass);
                double priorMean = encodings.getPriorMean(targetClass);
                result[startIdx + i] = this.computeEncodedValue(numDen, priorMean);
            }
            return this.nclasses() - 1;
        }
        double[] numDen = encodings.getNumDen(category);
        double priorMean = encodings.getPriorMean();
        result[startIdx] = this.computeEncodedValue(numDen, priorMean);
        return 1;
    }

    /**
     * Encodes an NA value of {@code column}. If the column is flagged {@code true} in
     * {@link #_teColumn2HasNAs}, the map's NA category is encoded; otherwise (flag false, absent,
     * or no flag map at all) the prior mean(s) are written.
     *
     * @return number of slots written
     */
    int encodeNA(double[] result, int startIdx, EncodingMap encodings, String column) {
        // Null-safe lookup: the original unboxed the Boolean directly and NPE'd on a missing entry.
        boolean hasNAs = this._teColumn2HasNAs != null && Boolean.TRUE.equals(this._teColumn2HasNAs.get(column));
        return hasNAs
                ? this.encodeCategory(result, startIdx, encodings, encodings.getNACategory())
                : this.encodeWithPriorMean(result, startIdx, encodings);
    }

    /**
     * Writes the prior mean(s) starting at {@code preds[startIdx]}, using the same slot layout as
     * {@link #encodeCategory}.
     *
     * @return number of slots written
     */
    private int encodeWithPriorMean(double[] preds, int startIdx, EncodingMap encodings) {
        // Use nclasses() (not the raw _nclasses field) so the layout matches encodeCategory and
        // getNumEncColsPerPredictor, which sizes the output buffer.
        int nclasses = this.nclasses();
        if (nclasses > 2) {
            for (int i = 0; i < nclasses - 1; ++i) {
                preds[startIdx + i] = encodings.getPriorMean(i + 1);
            }
            return nclasses - 1;
        }
        preds[startIdx] = encodings.getPriorMean();
        return 1;
    }

    /** Copies the encodings into a map ordered by each column's index in the model's columns. */
    Map<String, EncodingMap> sortByColumnIndex(EncodingMaps encodingMaps) {
        TreeMap<String, EncodingMap> sorted = new TreeMap<String, EncodingMap>(new ColumnComparator(this._columnNameToIdx));
        sorted.putAll(encodingMaps.encodingMap());
        return sorted;
    }

    /** Orders column names by their position in the supplied name -> index map. */
    private static class ColumnComparator
    implements Comparator<String>,
    Serializable {
        // Field declaration intentionally unchanged: this class relies on the computed default
        // serialVersionUID, which depends on field modifiers.
        private Map<String, Integer> _columnToIdx;

        public ColumnComparator(Map<String, Integer> _columnToIdx) {
            this._columnToIdx = _columnToIdx;
        }

        @Override
        public int compare(String lhs, String rhs) {
            return Integer.compare(this.indexOf(lhs), this.indexOf(rhs));
        }

        /** Index of {@code column}; fails descriptively instead of NPE-ing on unboxing a missing entry. */
        private int indexOf(String column) {
            Integer idx = this._columnToIdx.get(column);
            if (idx == null) {
                throw new IllegalArgumentException("Unknown column '" + column + "'.");
            }
            return idx;
        }
    }
}

