/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.targetencoder.ColumnsMapping;
import hex.genmodel.algos.targetencoder.ColumnsToSingleMapping;
import hex.genmodel.algos.targetencoder.EncodingMap;
import hex.genmodel.algos.targetencoder.EncodingMaps;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Scoring-side model of the target encoder.
 *
 * <p>For every entry of {@link #_inencMapping} the model reads one categorical value from the
 * input row (or derives one value from an interaction of several columns), looks up the
 * numerator/denominator pair stored for that value in the column's {@link EncodingMap}, and
 * writes the resulting encoded value(s) into the predictions array. Each encoded column
 * produces {@link #getNumEncColsPerPredictor()} output values.
 */
public class TargetEncoderMojoModel extends MojoModel {
    /** Position of every input column, keyed by column name (built from the constructor's {@code columns}). */
    public final Map<String, Integer> _columnNameToIdx;
    /**
     * Per encoded column flag consulted by {@link #encodeNA}: when {@code true}, a missing value is
     * encoded through the encodings' NA category; otherwise the prior mean is used.
     */
    public Map<String, Boolean> _teColumn2HasNAs;
    /** When {@code true}, the posterior mean is blended with the prior mean (see {@link #computeBlendedEncoding}). */
    public boolean _withBlending;
    /** Inflection point passed to {@link #computeLambda} when blending is enabled. */
    public double _inflectionPoint;
    /** Smoothing factor passed to {@link #computeLambda} when blending is enabled. */
    public double _smoothing;
    /** Mapping from input column group to the single column name used as key into the encodings. */
    public List<ColumnsToSingleMapping> _inencMapping;
    /** Mapping from input column group to the names of the produced encoded columns. */
    public List<ColumnsMapping> _inoutMapping;
    List<String> _nonPredictors;
    /** Encodings keyed by encoded column name; {@code null} until {@link #setEncodings} is called. */
    Map<String, EncodingMap> _encodingsByCol;
    boolean _keepOriginalCategoricalColumns;
    // Not read anywhere in this class; kept as-is because the class layout may be relied upon elsewhere.
    private final boolean _imputeUnknownLevels = true;

    /**
     * Computes the blending weight as a logistic function of the number of rows:
     * {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}.
     *
     * @param nrows           number of rows observed for the category
     * @param inflectionPoint row count at which the weight equals 0.5
     * @param smoothing       divisor controlling how quickly the weight moves between 0 and 1
     * @return the weight to give to the posterior mean
     */
    public static double computeLambda(long nrows, double inflectionPoint, double smoothing) {
        return 1.0 / (1.0 + Math.exp((inflectionPoint - (double) nrows) / smoothing));
    }

    /**
     * Linearly interpolates between the prior and the posterior mean.
     *
     * @param lambda        weight given to {@code posteriorMean}
     * @param posteriorMean mean observed for the category
     * @param priorMean     global mean
     * @return {@code lambda * posteriorMean + (1 - lambda) * priorMean}
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return lambda * posteriorMean + (1.0 - lambda) * priorMean;
    }

    /**
     * Builds a name-to-position lookup for the given column names.
     * If a name occurs more than once, the last position wins.
     */
    static Map<String, Integer> name2Idx(String[] columns) {
        Map<String, Integer> nameToIdx = new HashMap<String, Integer>(columns.length);
        for (int i = 0; i < columns.length; ++i) {
            nameToIdx.put(columns[i], i);
        }
        return nameToIdx;
    }

    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseName) {
        super(columns, domains, responseName);
        _columnNameToIdx = name2Idx(columns);
    }

    /**
     * Ensures the column mappings exist. Does nothing when no encodings are set.
     *
     * <p>When both mappings are empty, a default mapping is generated for every encoded column
     * {@code col}: the input group is {@code [col]} and the output names are {@code col_te}
     * (single output) or {@code col_1_te}, {@code col_2_te}, ... (several outputs).
     */
    protected void init() {
        if (_encodingsByCol == null) {
            return;
        }
        if (_inencMapping == null) {
            _inencMapping = new ArrayList<ColumnsToSingleMapping>();
        }
        if (_inoutMapping == null) {
            _inoutMapping = new ArrayList<ColumnsMapping>();
        }
        if (!_inencMapping.isEmpty() || !_inoutMapping.isEmpty()) {
            return;
        }
        int numOut = getNumEncColsPerPredictor(); // always >= 1
        for (String col : _encodingsByCol.keySet()) {
            String[] in = new String[]{col};
            _inencMapping.add(new ColumnsToSingleMapping(in, col, null));
            String[] out = new String[numOut];
            if (numOut > 1) {
                for (int i = 0; i < numOut; ++i) {
                    out[i] = col + "_" + (i + 1) + "_te";
                }
            } else {
                out[0] = col + "_te";
            }
            _inoutMapping.add(new ColumnsMapping(in, out));
        }
    }

    /** Installs the encodings used for scoring. */
    protected void setEncodings(EncodingMaps encodingMaps) {
        _encodingsByCol = encodingMaps.encodingMap();
    }

    /**
     * @return the number of values written by {@link #score0}: one group of
     *         {@link #getNumEncColsPerPredictor()} values per encoded column, or 0 when no encodings are set
     */
    @Override
    public int getPredsSize() {
        return _encodingsByCol == null ? 0 : _encodingsByCol.size() * getNumEncColsPerPredictor();
    }

    /** @return {@code nclasses() - 1} when there is more than one class, otherwise 1 */
    int getNumEncColsPerPredictor() {
        int nclasses = nclasses();
        return nclasses > 1 ? nclasses - 1 : 1;
    }

    /**
     * Encodes one row.
     *
     * @param row   input values, indexed like the constructor's {@code columns}
     * @param preds output array; encoded values are written consecutively starting at index 0,
     *              in the iteration order of {@link #_inencMapping}
     * @return {@code preds}
     * @throws IllegalStateException if the encodings or the input mapping are missing, if a mapped
     *                               column has no encodings, or if a mapped input column is unknown
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        if (_encodingsByCol == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        if (_inencMapping == null) {
            throw new IllegalStateException("Input columns mapping is missing.");
        }
        int predsIdx = 0;
        for (ColumnsToSingleMapping colMap : _inencMapping) {
            String[] colGroup = colMap.from();
            String teColumn = colMap.toSingle();
            EncodingMap encodings = _encodingsByCol.get(teColumn);
            if (encodings == null) {
                // Fail with the column name instead of an anonymous NullPointerException further down.
                throw new IllegalStateException("Missing encodings for column " + teColumn);
            }
            int[] colsIdx = columnsIndices(colGroup);
            double category;
            if (colsIdx.length == 1) {
                category = row[colsIdx[0]];
            } else {
                long[] interactionDomain = colMap.toDomainAsNum();
                assert interactionDomain != null : "Missing domain for interaction between columns " + Arrays.toString(colGroup);
                category = interactionValue(row, colsIdx, interactionDomain);
            }
            // Categorical levels are read as integers: a non-NaN value is truncated to int.
            int filled = Double.isNaN(category)
                    ? encodeNA(preds, predsIdx, encodings, teColumn)
                    : encodeCategory(preds, predsIdx, encodings, (int) category);
            predsIdx += filled;
        }
        return preds;
    }

    /**
     * @param column name of an encoded column
     * @return the encodings stored for {@code column}, or {@code null} when there are none
     *         (including when no encodings have been set at all)
     */
    public EncodingMap getEncodings(String column) {
        return _encodingsByCol == null ? null : _encodingsByCol.get(column);
    }

    /**
     * Resolves column names to their positions in the input row.
     *
     * @throws IllegalStateException if a name is not one of the model's columns
     */
    private int[] columnsIndices(String[] names) {
        int[] indices = new int[names.length];
        for (int i = 0; i < indices.length; ++i) {
            Integer idx = _columnNameToIdx.get(names[i]);
            if (idx == null) {
                // Unboxing null would throw a NullPointerException without naming the column.
                throw new IllegalStateException("Unknown column in target encoder mapping: " + names[i]);
            }
            indices[i] = idx;
        }
        return indices;
    }

    /**
     * Combines the levels of several columns into one interaction code and returns the position
     * of that code in {@code interactionDomain}.
     *
     * <p>Each column contributes {@code level * multiplier}, where the multiplier grows by a factor
     * of {@code (domain cardinality + 1)} per column. A NaN value, or a value at or beyond the
     * column's domain cardinality, is mapped to the extra level {@code cardinality}.
     *
     * @param interactionDomain codes to search; searched with {@link Arrays#binarySearch(long[], long)},
     *                          which requires ascending order
     * @return the index of the code in {@code interactionDomain}, or NaN when the code is not present
     */
    private double interactionValue(double[] row, int[] colsIdx, long[] interactionDomain) {
        long interaction = 0L;
        long multiplier = 1L;
        for (int colIdx : colsIdx) {
            double val = row[colIdx];
            int domainCard = getDomainValues(colIdx).length;
            long level = (Double.isNaN(val) || val >= (double) domainCard) ? (long) domainCard : (long) val;
            // Integer arithmetic: accumulating through double loses precision once the code exceeds 2^53.
            interaction += multiplier * level;
            multiplier *= (long) (domainCard + 1);
        }
        int catVal = Arrays.binarySearch(interactionDomain, interaction);
        return catVal < 0 ? Double.NaN : (double) catVal;
    }

    /**
     * Turns a numerator/denominator pair into an encoded value.
     *
     * @param numDen    {@code [numerator, denominator]}; the posterior mean is their quotient
     * @param priorMean value blended in when {@link #_withBlending} is set
     * @return the posterior mean, blended with the prior mean when blending is enabled
     */
    private double computeEncodedValue(double[] numDen, double priorMean) {
        double posteriorMean = numDen[0] / numDen[1];
        if (!_withBlending) {
            return posteriorMean;
        }
        // The denominator is used as the row count of the category.
        long nrows = (long) numDen[1];
        double lambda = computeLambda(nrows, _inflectionPoint, _smoothing);
        return computeBlendedEncoding(lambda, posteriorMean, priorMean);
    }

    /**
     * Writes the encoded value(s) of one category into {@code result}.
     *
     * <p>With more than two classes, one value is written per target class 1..nclasses-1
     * (class 0 is skipped); otherwise a single value is written.
     *
     * @param result    output array
     * @param startIdx  first index of {@code result} to write
     * @param encodings encodings of the column
     * @param category  categorical level to encode
     * @return the number of values written
     */
    int encodeCategory(double[] result, int startIdx, EncodingMap encodings, int category) {
        int nclasses = nclasses();
        if (nclasses > 2) {
            for (int i = 0; i < nclasses - 1; ++i) {
                int targetClass = i + 1;
                double[] numDen = encodings.getNumDen(category, targetClass);
                double priorMean = encodings.getPriorMean(targetClass);
                result[startIdx + i] = computeEncodedValue(numDen, priorMean);
            }
            return nclasses - 1;
        }
        double[] numDen = encodings.getNumDen(category);
        double priorMean = encodings.getPriorMean();
        result[startIdx] = computeEncodedValue(numDen, priorMean);
        return 1;
    }

    /**
     * Writes the encoded value(s) for a missing category.
     *
     * <p>When {@link #_teColumn2HasNAs} flags {@code column} as {@code true}, the encodings' NA
     * category is encoded like any other category. In every other case (flag {@code false},
     * no entry for the column, or no map at all) the prior mean is written.
     *
     * @return the number of values written
     */
    int encodeNA(double[] result, int startIdx, EncodingMap encodings, String column) {
        // Null-safe lookup: unboxing a missing entry would throw a NullPointerException.
        Boolean hasNAs = _teColumn2HasNAs == null ? null : _teColumn2HasNAs.get(column);
        if (Boolean.TRUE.equals(hasNAs)) {
            return encodeCategory(result, startIdx, encodings, encodings.getNACategory());
        }
        return encodeWithPriorMean(result, startIdx, encodings);
    }

    /**
     * Writes the prior mean(s) into {@code preds}: one per target class 1..nclasses-1 when there
     * are more than two classes, otherwise the single prior mean.
     *
     * @return the number of values written
     */
    private int encodeWithPriorMean(double[] preds, int startIdx, EncodingMap encodings) {
        if (this._nclasses > 2) {
            for (int i = 0; i < this._nclasses - 1; ++i) {
                preds[startIdx + i] = encodings.getPriorMean(i + 1);
            }
            return this._nclasses - 1;
        }
        preds[startIdx] = encodings.getPriorMean();
        return 1;
    }
}

