/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.exception.PredictUnknownTypeException;
import hex.genmodel.easy.exception.PredictWrongModelCategoryException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.Serializable;
import java.util.HashMap;

/**
 * Convenience wrapper around a {@link GenModel} that scores a single {@link RowData} row and
 * returns a typed prediction object matching the model's {@link ModelCategory}.
 *
 * <p>Row values are keyed by column name. Numeric columns accept a {@link String} (parsed with
 * {@link Double#parseDouble}) or any {@link Number}. Categorical columns accept the level name as
 * a {@link String}. Columns that are absent from the row, unknown to the model, or mapped to
 * {@code null} are passed to the model as {@code NaN}.
 */
public class EasyPredictModelWrapper
implements Serializable {
    private final GenModel m;
    /** Maps each name from {@link GenModel#getNames()} to its position in that array. */
    private final HashMap<String, Integer> modelColumnNameToIndexMap;
    /** For each column index {@code < getNumCols()} that has a domain: level name -> level index. */
    private final HashMap<Integer, HashMap<String, Integer>> domainMap;

    /**
     * Builds the name and categorical-level lookup tables for {@code model}.
     *
     * @param model the generated model to wrap
     */
    public EasyPredictModelWrapper(GenModel model) {
        this.m = model;
        String[] modelColumnNames = this.m.getNames();
        this.modelColumnNameToIndexMap = new HashMap<String, Integer>();
        for (int i = 0; i < modelColumnNames.length; ++i) {
            this.modelColumnNameToIndexMap.put(modelColumnNames[i], i);
        }
        this.domainMap = new HashMap<Integer, HashMap<String, Integer>>();
        for (int i = 0; i < this.m.getNumCols(); ++i) {
            String[] domainValues = this.m.getDomainValues(i);
            if (domainValues == null) {
                // Numeric column: no level table needed.
                continue;
            }
            HashMap<String, Integer> levelToIndex = new HashMap<String, Integer>();
            for (int j = 0; j < domainValues.length; ++j) {
                levelToIndex.put(domainValues[j], j);
            }
            this.domainMap.put(i, levelToIndex);
        }
    }

    /**
     * Scores {@code data} using the prediction type implied by the model's category.
     *
     * @param data one input row keyed by column name
     * @return the category-specific prediction
     * @throws PredictException if the category is unknown/unhandled or scoring fails
     */
    public AbstractPrediction predict(RowData data) throws PredictException {
        ModelCategory category = this.m.getModelCategory();
        switch (category) {
            case AutoEncoder: {
                return this.predictAutoEncoder(data);
            }
            case Binomial: {
                return this.predictBinomial(data);
            }
            case Multinomial: {
                return this.predictMultinomial(data);
            }
            case Clustering: {
                return this.predictClustering(data);
            }
            case Regression: {
                return this.predictRegression(data);
            }
            case Unknown: {
                throw new PredictException("Unknown model category");
            }
        }
        throw new PredictException("Unhandled model category (" + category + ") in switch statement");
    }

    /**
     * Autoencoder prediction is not implemented. The model category is still validated and the
     * row is still scored before the exception is raised.
     *
     * @throws PredictException if the model is not an autoencoder or input conversion fails
     * @throws UnsupportedOperationException always, once validation and scoring succeed
     */
    public AutoEncoderModelPrediction predictAutoEncoder(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.AutoEncoder, data);
        throw new UnsupportedOperationException("Unimplemented " + preds.length);
    }

    /**
     * Scores a binomial model.
     *
     * <p>{@code preds[0]} holds the predicted label index. The following
     * {@code getNumResponseClasses()} entries hold the class probabilities.
     *
     * @throws PredictException if the model is not binomial or input conversion fails
     */
    public BinomialModelPrediction predictBinomial(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Binomial, data);
        BinomialModelPrediction p = new BinomialModelPrediction();
        p.classProbabilities = new double[this.m.getNumResponseClasses()];
        p.labelIndex = (int) preds[0];
        String[] domainValues = this.m.getDomainValues(this.m.getResponseIdx());
        p.label = domainValues[p.labelIndex];
        System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
        return p;
    }

    /**
     * Scores a multinomial model. The layout of {@code preds} is the same as for
     * {@link #predictBinomial}.
     *
     * @throws PredictException if the model is not multinomial or input conversion fails
     */
    public MultinomialModelPrediction predictMultinomial(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Multinomial, data);
        MultinomialModelPrediction p = new MultinomialModelPrediction();
        p.classProbabilities = new double[this.m.getNumResponseClasses()];
        p.labelIndex = (int) preds[0];
        String[] domainValues = this.m.getDomainValues(this.m.getResponseIdx());
        p.label = domainValues[p.labelIndex];
        System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
        return p;
    }

    /**
     * Scores a clustering model. The cluster id is {@code preds[0]} truncated to {@code int}.
     *
     * @throws PredictException if the model is not a clustering model or input conversion fails
     */
    public ClusteringModelPrediction predictClustering(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Clustering, data);
        ClusteringModelPrediction p = new ClusteringModelPrediction();
        p.cluster = (int) preds[0];
        return p;
    }

    /**
     * Scores a regression model. The predicted value is {@code preds[0]}.
     *
     * @throws PredictException if the model is not a regression model or input conversion fails
     */
    public RegressionModelPrediction predictRegression(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Regression, data);
        RegressionModelPrediction p = new RegressionModelPrediction();
        p.value = preds[0];
        return p;
    }

    /** Returns the wrapped model's category. */
    public ModelCategory getModelCategory() {
        return this.m.getModelCategory();
    }

    /** Returns the domain (level names) of the model's response column. */
    public String[] getResponseDomainValues() {
        return this.m.getDomainValues(this.m.getResponseIdx());
    }

    /** Returns the wrapped model's header string. */
    public String getHeader() {
        return this.m.getHeader();
    }

    /** Throws unless the wrapped model's category is exactly {@code c}. */
    private void validateModelCategory(ModelCategory c) throws PredictException {
        if (this.m.getModelCategory() != c) {
            throw new PredictWrongModelCategoryException("Prediction type unsupported by model of category " + this.m.getModelCategory());
        }
    }

    /** Validates the category, then scores {@code data} into a fresh {@code getPredsSize()} array. */
    private double[] preamble(ModelCategory c, RowData data) throws PredictException {
        this.validateModelCategory(c);
        double[] preds = new double[this.m.getPredsSize()];
        return this.predict(data, preds);
    }

    /** Overwrites every element of {@code arr} with {@code NaN}. */
    private void setToNaN(double[] arr) {
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = Double.NaN;
        }
    }

    /**
     * Copies the values from {@code data} into their model positions in {@code rawData}.
     * Slots that are not written keep their previous value (NaN when called from
     * {@link #predict(RowData, double[])}).
     */
    private void fillRawData(RowData data, double[] rawData) throws PredictException {
        for (String dataColumnName : data.keySet()) {
            Integer index = this.modelColumnNameToIndexMap.get(dataColumnName);
            // Skip names the model does not know. Also skip names whose index lies past the
            // feature vector: getNames() may cover more columns than nfeatures() (presumably
            // the response column), and writing there would overflow rawData.
            if (index == null || index >= rawData.length) {
                continue;
            }
            Object o = data.get(dataColumnName);
            // A null value is treated the same as an absent column.
            if (o == null) {
                continue;
            }
            if (this.m.getDomainValues(index) == null) {
                rawData[index] = toNumericValue(o);
            } else {
                rawData[index] = toCategoricalValue(dataColumnName, index, o);
            }
        }
    }

    /** Converts a non-null value supplied for a numeric column into a double. */
    private static double toNumericValue(Object o) throws PredictException {
        if (o instanceof String) {
            return Double.parseDouble((String) o);
        }
        if (o instanceof Number) {
            // Covers Double as before, plus Integer, Long, Float, etc.
            return ((Number) o).doubleValue();
        }
        throw new PredictUnknownTypeException("Unknown object type " + o.getClass().getName());
    }

    /** Maps a non-null categorical level name to its index within the column's domain. */
    private double toCategoricalValue(String dataColumnName, int index, Object o) throws PredictException {
        if (!(o instanceof String)) {
            throw new PredictUnknownTypeException("Unknown object type " + o.getClass().getName());
        }
        String levelName = (String) o;
        HashMap<String, Integer> columnDomainMap = this.domainMap.get(index);
        Integer levelIndex = columnDomainMap.get(levelName);
        if (levelIndex == null) {
            throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + dataColumnName + "," + levelName + ")", dataColumnName, levelName);
        }
        return levelIndex.intValue();
    }

    /** Builds an all-NaN feature vector, fills it from {@code data}, and runs {@code score0}. */
    private double[] predict(RowData data, double[] preds) throws PredictException {
        double[] rawData = new double[this.m.nfeatures()];
        this.setToNaN(rawData);
        this.fillRawData(data, rawData);
        return this.m.score0(rawData, preds);
    }
}

