/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.targetencoder.EncodingMap;
import hex.genmodel.algos.targetencoder.EncodingMaps;
import java.io.Serializable;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * MOJO scoring model that replaces categorical levels of target-encoded columns with their
 * target encodings.
 *
 * <p>For every target-encoded column, {@link #score0(double[], double[])} writes one encoded
 * value when {@code nclasses() <= 2} or {@code nclasses() - 1} encoded values when
 * {@code nclasses() > 2}. Values are written consecutively into {@code preds}, following the
 * iteration order of {@code _encodingsByCol}. That order is column-index order when the map is
 * installed through {@link #setEncodings(EncodingMaps)}.
 */
public class TargetEncoderMojoModel
extends MojoModel {
    /** Model column name to its index in the input row (built by {@link #name2Idx(String[])}). */
    public final Map<String, Integer> _columnNameToIdx;
    /**
     * Per target-encoded column: {@code true} if NA inputs are encoded through the encoding
     * map's NA category, {@code false} if they are imputed with the prior mean.
     */
    public Map<String, Boolean> _teColumn2HasNAs;
    /** Whether the posterior mean is blended with the prior mean (see {@link #computeLambda}). */
    public boolean _withBlending;
    /** Inflection point passed to {@link #computeLambda} when blending. */
    public double _inflectionPoint;
    /** Smoothing passed to {@link #computeLambda} when blending. */
    public double _smoothing;
    // Not referenced within this class.
    List<String> _nonPredictors;
    /** Encodings keyed by target-encoded column name; null until encodings are set. */
    Map<String, EncodingMap> _encodingsByCol;
    // Not referenced within this class.
    boolean _keepOriginalCategoricalColumns;
    // NOTE(review): not referenced within this class. Kept as-is because removing a non-static
    // field would change the default serialVersionUID if this class is Java-serializable
    // (presumably via MojoModel - TODO confirm).
    private final boolean _imputeUnknownLevels = true;

    /**
     * Computes the blending weight {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}.
     *
     * @param nrows number of rows observed for the category
     * @param inflectionPoint row count at which the weight equals 0.5
     * @param smoothing controls how steeply the weight moves between 0 and 1
     * @return the weight given to the posterior mean
     */
    public static double computeLambda(long nrows, double inflectionPoint, double smoothing) {
        return 1.0 / (1.0 + Math.exp((inflectionPoint - (double)nrows) / smoothing));
    }

    /**
     * Blends the posterior and prior means: {@code lambda * posterior + (1 - lambda) * prior}.
     *
     * @param lambda weight of the posterior mean
     * @param posteriorMean per-category mean
     * @param priorMean overall mean
     * @return the blended encoding
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return lambda * posteriorMean + (1.0 - lambda) * priorMean;
    }

    /**
     * Creates the model and indexes its column names.
     *
     * @param columns model column names, in input-row order
     * @param domains categorical domains per column
     * @param responseName name of the response column
     */
    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseName) {
        super(columns, domains, responseName);
        this._columnNameToIdx = TargetEncoderMojoModel.name2Idx(columns);
    }

    /**
     * Maps each column name to its position in {@code columns}.
     * If a name is duplicated, its last position wins.
     *
     * @param columns column names
     * @return a mutable map from name to index
     */
    static Map<String, Integer> name2Idx(String[] columns) {
        HashMap<String, Integer> nameToIdx = new HashMap<String, Integer>(columns.length);
        for (int idx = 0; idx < columns.length; ++idx) {
            nameToIdx.put(columns[idx], idx);
        }
        return nameToIdx;
    }

    /**
     * Installs the encodings, ordered by each column's index in the model.
     *
     * @param encodingMaps encodings per target-encoded column
     * @throws IllegalArgumentException if an encoded column is not a column of this model
     */
    protected void setEncodings(EncodingMaps encodingMaps) {
        this._encodingsByCol = this.sortByColumnIndex(encodingMaps);
    }

    /**
     * @return number of values written by {@link #score0}; 0 when no encodings are set
     */
    @Override
    public int getPredsSize() {
        if (this._encodingsByCol == null) {
            return 0;
        }
        return this._encodingsByCol.size() * this.getNumEncColsPerPredictor();
    }

    /**
     * @return {@code nclasses() - 1} for {@code nclasses() > 1}, otherwise 1
     */
    int getNumEncColsPerPredictor() {
        if (this.nclasses() > 1) {
            return this.nclasses() - 1;
        }
        return 1;
    }

    /**
     * Encodes every target-encoded column of {@code row} into {@code preds}.
     *
     * @param row input row; categorical levels are expected as double-encoded integer indices
     * @param preds output buffer, filled from index 0 and expected to hold at least
     *     {@link #getPredsSize()} values
     * @return {@code preds}
     * @throws IllegalStateException if no encodings are set, if a target-encoded column is not a
     *     column of this model, or if NA information for a column with an NA value is missing
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (this._encodingsByCol == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        int predsIdx = 0;
        for (Map.Entry<String, EncodingMap> entry : this._encodingsByCol.entrySet()) {
            String teColumn = entry.getKey();
            EncodingMap encodings = entry.getValue();
            double level = row[this.teColumnIndex(teColumn)];
            int filled;
            if (Double.isNaN(level)) {
                filled = this.encodeNA(preds, predsIdx, encodings, teColumn);
            } else {
                // Categorical levels arrive as double-encoded integer level indices.
                filled = this.encodeCategory(preds, predsIdx, encodings, (int)level);
            }
            predsIdx += filled;
        }
        return preds;
    }

    /** Looks up the row index of a target-encoded column, failing with context if unknown. */
    private int teColumnIndex(String column) {
        Integer idx = this._columnNameToIdx.get(column);
        if (idx == null) {
            throw new IllegalStateException("Target-encoded column '" + column + "' is not a column of this model.");
        }
        return idx;
    }

    /**
     * @param column target-encoded column name
     * @return its encodings, or {@code null} if the column is unknown or no encodings are set
     */
    public EncodingMap getEncodings(String column) {
        if (this._encodingsByCol == null) {
            return null;
        }
        return this._encodingsByCol.get(column);
    }

    /**
     * Computes {@code numDen[0] / numDen[1]}, blended with {@code priorMean} when
     * {@code _withBlending} is set. The denominator, truncated to a long, is the row count
     * passed to {@link #computeLambda}.
     */
    private double computeEncodedValue(double[] numDen, double priorMean) {
        // NOTE(review): a zero denominator yields NaN/Infinity here; presumably encoding maps
        // never store a zero count - TODO confirm.
        double posteriorMean = numDen[0] / numDen[1];
        if (!this._withBlending) {
            return posteriorMean;
        }
        long nrows = (long)numDen[1];
        double lambda = TargetEncoderMojoModel.computeLambda(nrows, this._inflectionPoint, this._smoothing);
        return TargetEncoderMojoModel.computeBlendedEncoding(lambda, posteriorMean, priorMean);
    }

    /**
     * Writes the encoding(s) of {@code category} into {@code result} starting at {@code startIdx}.
     * When {@code nclasses() > 2}, one value is written per target class 1..nclasses()-1.
     * Otherwise a single value is written.
     *
     * @return number of values written
     */
    int encodeCategory(double[] result, int startIdx, EncodingMap encodings, int category) {
        if (this.nclasses() > 2) {
            int nEncodings = this.nclasses() - 1;
            for (int offset = 0; offset < nEncodings; ++offset) {
                int targetClass = offset + 1;
                double[] numDen = encodings.getNumDen(category, targetClass);
                double priorMean = encodings.getPriorMean(targetClass);
                result[startIdx + offset] = this.computeEncodedValue(numDen, priorMean);
            }
            return nEncodings;
        }
        double[] numDen = encodings.getNumDen(category);
        double priorMean = encodings.getPriorMean();
        result[startIdx] = this.computeEncodedValue(numDen, priorMean);
        return 1;
    }

    /**
     * Writes the encoding(s) for an NA input. It uses the map's NA category when
     * {@code _teColumn2HasNAs} flags the column, and the prior mean(s) otherwise.
     *
     * @return number of values written
     * @throws IllegalStateException if no NA information is available for {@code column}
     */
    int encodeNA(double[] result, int startIdx, EncodingMap encodings, String column) {
        if (this.teColumnHasNAs(column)) {
            return this.encodeCategory(result, startIdx, encodings, encodings.getNACategory());
        }
        return this.encodeWithPriorMean(result, startIdx, encodings);
    }

    /** Reads the NA flag of a column, failing with context instead of an unboxing NPE. */
    private boolean teColumnHasNAs(String column) {
        Boolean hasNAs = this._teColumn2HasNAs == null ? null : this._teColumn2HasNAs.get(column);
        if (hasNAs == null) {
            throw new IllegalStateException("Missing NA information for target-encoded column '" + column + "'.");
        }
        return hasNAs;
    }

    /** Writes the prior mean(s) into {@code preds}; returns the number of values written. */
    private int encodeWithPriorMean(double[] preds, int startIdx, EncodingMap encodings) {
        // Uses nclasses(), like encodeCategory, so both NA paths agree on the value count.
        if (this.nclasses() > 2) {
            int nEncodings = this.nclasses() - 1;
            for (int offset = 0; offset < nEncodings; ++offset) {
                preds[startIdx + offset] = encodings.getPriorMean(offset + 1);
            }
            return nEncodings;
        }
        preds[startIdx] = encodings.getPriorMean();
        return 1;
    }

    /**
     * Copies the encodings into a map ordered by each column's index in this model.
     *
     * @throws IllegalArgumentException if an encoded column is not a column of this model
     */
    Map<String, EncodingMap> sortByColumnIndex(EncodingMaps encodingMaps) {
        TreeMap<String, EncodingMap> byIndex = new TreeMap<String, EncodingMap>(new ColumnComparator(this._columnNameToIdx));
        byIndex.putAll(encodingMaps.encodingMap());
        return byIndex;
    }

    /**
     * Orders column names by their index in the model.
     * Serializable so that a TreeMap using it can be serialized.
     */
    private static class ColumnComparator
    implements Serializable,
    Comparator<String> {
        // Field name and modifiers kept unchanged to preserve the default serialVersionUID.
        private Map<String, Integer> _columnToIdx;

        public ColumnComparator(Map<String, Integer> _columnToIdx) {
            this._columnToIdx = _columnToIdx;
        }

        @Override
        public int compare(String lhs, String rhs) {
            return Integer.compare(this.indexOf(lhs), this.indexOf(rhs));
        }

        /** Index of {@code column}, failing with context instead of an unboxing NPE. */
        private int indexOf(String column) {
            Integer idx = this._columnToIdx.get(column);
            if (idx == null) {
                throw new IllegalArgumentException("Unknown column '" + column + "': it is not a column of this model.");
            }
            return idx;
        }
    }
}

