/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.targetencoder.ColumnsMapping;
import hex.genmodel.algos.targetencoder.ColumnsToSingleMapping;
import hex.genmodel.algos.targetencoder.EncodingMap;
import hex.genmodel.algos.targetencoder.EncodingMaps;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * MOJO scoring model that replaces categorical input columns (or interactions of categorical
 * columns) with their target encodings.
 *
 * <p>For every entry of {@link #_inencMapping}, {@link #score0(double[], double[])} writes
 * {@link #getNumEncColsPerPredictor()} consecutive encoded values into {@code preds}.
 */
public class TargetEncoderMojoModel extends MojoModel {
    /** Position of each input column name in the column array passed to the constructor. */
    public final Map<String, Integer> _columnNameToIdx;
    /**
     * Per encoded column: {@code true} when NA values are encoded through the encoding map's NA
     * category; otherwise (or when no entry exists) NAs are encoded with the prior mean.
     */
    public Map<String, Boolean> _teColumn2HasNAs;
    /** Whether posterior means are blended with the prior mean (see {@link #computeLambda}). */
    public boolean _withBlending;
    /** Blending parameter passed to {@link #computeLambda} as {@code inflectionPoint}. */
    public double _inflectionPoint;
    /** Blending parameter passed to {@link #computeLambda} as {@code smoothing}. */
    public double _smoothing;
    /** Input column group -> name of its encoding (the key into {@code _encodingsByCol}). */
    public List<ColumnsToSingleMapping> _inencMapping;
    /** Input column group -> names of the produced encoded columns. */
    public List<ColumnsMapping> _inoutMapping;
    // NOTE(review): not read within this class; presumably consumed elsewhere in this package.
    List<String> _nonPredictors;
    // Encoding map per encoded column name; null until encodings are set (e.g. via setEncodings).
    Map<String, EncodingMap> _encodingsByCol;
    // NOTE(review): not read within this class; presumably consumed elsewhere in this package.
    boolean _keepOriginalCategoricalColumns;
    // Compile-time constant that is never read in this class; kept to avoid changing the fields.
    // NOTE(review): presumably records that imputation of NA/unknown levels is always on — confirm.
    private final boolean _imputeUnknownLevels = true;

    /**
     * Computes the blending weight given to the posterior mean.
     *
     * <p>This is a logistic curve in {@code nrows}: it equals 0.5 when {@code nrows == inflectionPoint}
     * and approaches 1 as {@code nrows} grows. {@code smoothing} controls how steep the transition is.
     *
     * @param nrows number of rows backing the posterior estimate
     * @param inflectionPoint row count at which the weight is 0.5
     * @param smoothing steepness divisor of the logistic curve
     * @return the weight {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}
     */
    public static double computeLambda(long nrows, double inflectionPoint, double smoothing) {
        return 1.0 / (1.0 + Math.exp((inflectionPoint - (double) nrows) / smoothing));
    }

    /**
     * Linearly blends the posterior and prior means.
     *
     * @param lambda weight of the posterior mean, typically from {@link #computeLambda}
     * @param posteriorMean per-category mean
     * @param priorMean global mean
     * @return {@code lambda * posteriorMean + (1 - lambda) * priorMean}
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return lambda * posteriorMean + (1.0 - lambda) * priorMean;
    }

    /**
     * Builds a lookup from column name to its index in {@code columns}.
     * For duplicate names, the last occurrence wins.
     *
     * @param columns column names in model order
     * @return a mutable map from name to index
     */
    static Map<String, Integer> name2Idx(String[] columns) {
        Map<String, Integer> nameToIdx = new HashMap<>(columns.length);
        for (int i = 0; i < columns.length; ++i) {
            nameToIdx.put(columns[i], i);
        }
        return nameToIdx;
    }

    /**
     * Creates the model and indexes its input column names.
     *
     * @param columns model column names
     * @param domains categorical domains per column
     * @param responseName name of the response column
     */
    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseName) {
        super(columns, domains, responseName);
        this._columnNameToIdx = name2Idx(columns);
    }

    /**
     * Ensures the column mappings exist.
     *
     * <p>When no mappings were provided, this derives a default one-to-one mapping for every
     * encoded column. A single output column is named {@code <col>_te}. Multiple output columns are
     * named {@code <col>_1_te}, {@code <col>_2_te}, and so on. This method does nothing when no
     * encodings are set.
     */
    protected void init() {
        if (_encodingsByCol == null) {
            return;
        }
        if (_inencMapping == null) {
            _inencMapping = new ArrayList<>();
        }
        if (_inoutMapping == null) {
            _inoutMapping = new ArrayList<>();
        }
        // Only derive defaults when the model did not ship explicit mappings.
        if (_inencMapping.isEmpty() && _inoutMapping.isEmpty()) {
            int numEncCols = getNumEncColsPerPredictor();
            for (String teColumn : _encodingsByCol.keySet()) {
                String[] from = new String[]{teColumn};
                _inencMapping.add(new ColumnsToSingleMapping(from, teColumn, null));
                _inoutMapping.add(new ColumnsMapping(from, defaultEncodedColumnNames(teColumn, numEncCols)));
            }
        }
    }

    /** Builds the default output column names for one encoded column ({@code count} >= 1). */
    private static String[] defaultEncodedColumnNames(String teColumn, int count) {
        String[] names = new String[count];
        if (count > 1) {
            for (int i = 0; i < count; ++i) {
                names[i] = teColumn + "_" + (i + 1) + "_te";
            }
        } else {
            names[0] = teColumn + "_te";
        }
        return names;
    }

    /**
     * Installs the per-column encoding maps.
     *
     * @param encodingMaps container whose {@code encodingMap()} becomes the encodings by column
     */
    protected void setEncodings(EncodingMaps encodingMaps) {
        this._encodingsByCol = encodingMaps.encodingMap();
    }

    /**
     * Returns the number of values {@link #score0} produces.
     *
     * @return encoded columns times encoded values per column, or 0 when no encodings are set
     */
    @Override
    public int getPredsSize() {
        if (_encodingsByCol == null) {
            return 0;
        }
        return _encodingsByCol.size() * getNumEncColsPerPredictor();
    }

    /**
     * Returns the number of encoded values produced per encoded column.
     *
     * @return {@code nclasses() - 1} when {@code nclasses() > 1}, otherwise 1
     */
    int getNumEncColsPerPredictor() {
        int numClasses = nclasses();
        return numClasses > 1 ? numClasses - 1 : 1;
    }

    /**
     * Encodes one row.
     *
     * @param row input values indexed like the model columns; categorical levels are stored as
     *     numbers and NA is {@code NaN}
     * @param preds output buffer that receives the encoded values, in {@link #_inencMapping} order
     * @return {@code preds}
     * @throws IllegalStateException if encodings are missing, or a mapping references an unknown
     *     column or an encoding that is not present
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (_encodingsByCol == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        int predsIdx = 0;
        for (ColumnsToSingleMapping colMap : _inencMapping) {
            String[] colGroup = colMap.from();
            String teColumn = colMap.toSingle();
            EncodingMap encodings = _encodingsByCol.get(teColumn);
            if (encodings == null) {
                // Fail fast with context instead of an NPE deep inside the encode helpers.
                throw new IllegalStateException("Encoding map is missing for column " + teColumn + ".");
            }
            int[] colsIdx = columnsIndices(colGroup);

            double category;
            if (colsIdx.length == 1) {
                category = row[colsIdx[0]];
            } else {
                assert (colMap.toDomainAsNum() != null) : "Missing domain for interaction between columns " + Arrays.toString(colGroup);
                category = interactionValue(row, colsIdx, colMap.toDomainAsNum());
            }

            int filled;
            if (Double.isNaN(category)) {
                filled = encodeNA(preds, predsIdx, encodings, teColumn);
            } else {
                // Categorical levels are integral; truncation only drops a zero fraction.
                // NOTE(review): handling of levels absent from the map is up to EncodingMap.
                filled = encodeCategory(preds, predsIdx, encodings, (int) category);
            }
            predsIdx += filled;
        }
        return preds;
    }

    /**
     * Returns the encoding map for {@code column}.
     *
     * @param column encoded column name
     * @return the map, or {@code null} if there is none for this column; throws
     *     {@link NullPointerException} if no encodings have been set
     */
    public EncodingMap getEncodings(String column) {
        return _encodingsByCol.get(column);
    }

    /** Resolves column names to model column indices, failing with context for unknown names. */
    private int[] columnsIndices(String[] names) {
        int[] indices = new int[names.length];
        for (int i = 0; i < names.length; ++i) {
            Integer idx = _columnNameToIdx.get(names[i]);
            if (idx == null) {
                throw new IllegalStateException("Column " + names[i] + " is not one of the model's columns.");
            }
            indices[i] = idx;
        }
        return indices;
    }

    /**
     * Maps the levels of several categorical columns to the index of their combined interaction
     * value in {@code interactionDomain}.
     *
     * <p>The combined value is a mixed-radix number. Each column contributes a digit in
     * {@code [0, cardinality]}; NA and out-of-domain values use the extra digit {@code cardinality}.
     * The sum is computed in exact {@code long} arithmetic.
     *
     * @return the position in {@code interactionDomain}, or {@code NaN} if the value is not present
     *     (assumes {@code interactionDomain} is sorted ascending, as binarySearch requires)
     */
    private double interactionValue(double[] row, int[] colsIdx, long[] interactionDomain) {
        long interaction = 0L;
        long multiplier = 1L;
        for (int colIdx : colsIdx) {
            int domainCard = getDomainValues(colIdx).length;
            double val = row[colIdx];
            long level = (Double.isNaN(val) || val >= domainCard) ? domainCard : (long) val;
            interaction += multiplier * level;
            multiplier *= (domainCard + 1);
        }
        int pos = Arrays.binarySearch(interactionDomain, interaction);
        return pos < 0 ? Double.NaN : pos;
    }

    /**
     * Turns a numerator/denominator pair into an encoded value.
     *
     * <p>The posterior mean is {@code numDen[0] / numDen[1]}. When blending is enabled,
     * {@code numDen[1]} (truncated to {@code long}) is also used as the row count for
     * {@link #computeLambda}.
     */
    private double computeEncodedValue(double[] numDen, double priorMean) {
        double posteriorMean = numDen[0] / numDen[1];
        if (!_withBlending) {
            return posteriorMean;
        }
        double lambda = computeLambda((long) numDen[1], _inflectionPoint, _smoothing);
        return computeBlendedEncoding(lambda, posteriorMean, priorMean);
    }

    /**
     * Writes the encoding(s) of {@code category} into {@code result} starting at {@code startIdx}.
     *
     * <p>With more than two classes, one value is written per target class {@code 1..nclasses()-1}
     * (class 0 gets no column). Otherwise a single value is written.
     *
     * @return the number of values written
     */
    int encodeCategory(double[] result, int startIdx, EncodingMap encodings, int category) {
        int numClasses = nclasses();
        if (numClasses > 2) {
            for (int i = 0; i < numClasses - 1; ++i) {
                int targetClass = i + 1;
                double[] numDen = encodings.getNumDen(category, targetClass);
                double priorMean = encodings.getPriorMean(targetClass);
                result[startIdx + i] = computeEncodedValue(numDen, priorMean);
            }
            return numClasses - 1;
        }
        double[] numDen = encodings.getNumDen(category);
        double priorMean = encodings.getPriorMean();
        result[startIdx] = computeEncodedValue(numDen, priorMean);
        return 1;
    }

    /**
     * Writes the encoding(s) for an NA value of {@code column}.
     *
     * <p>If {@link #_teColumn2HasNAs} marks the column, the encoding map's NA category is used.
     * Otherwise, including when that map or its entry is absent, the prior mean(s) are written.
     *
     * @return the number of values written
     */
    int encodeNA(double[] result, int startIdx, EncodingMap encodings, String column) {
        if (hasNAEncoding(column)) {
            return encodeCategory(result, startIdx, encodings, encodings.getNACategory());
        }
        return encodeWithPriorMean(result, startIdx, encodings);
    }

    /** Null-safe lookup of whether NAs of {@code column} have a dedicated encoding. */
    private boolean hasNAEncoding(String column) {
        return _teColumn2HasNAs != null && Boolean.TRUE.equals(_teColumn2HasNAs.get(column));
    }

    /** Writes the prior mean(s) into {@code preds}, using the same layout as {@link #encodeCategory}. */
    private int encodeWithPriorMean(double[] preds, int startIdx, EncodingMap encodings) {
        int numClasses = nclasses();
        if (numClasses > 2) {
            for (int i = 0; i < numClasses - 1; ++i) {
                preds[startIdx + i] = encodings.getPriorMean(i + 1);
            }
            return numClasses - 1;
        }
        preds[startIdx] = encodings.getPriorMean();
        return 1;
    }
}

