/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.easy;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.exception.PredictNumberFormatException;
import hex.genmodel.easy.exception.PredictUnknownCategoricalLevelException;
import hex.genmodel.easy.exception.PredictUnknownTypeException;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import hex.genmodel.easy.prediction.SortedClassProbability;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;
import javax.imageio.ImageIO;

/**
 * Convenience wrapper around a {@link GenModel} that scores a {@link RowData} (a mapping from
 * model column name to value) and returns typed prediction objects instead of raw
 * {@code double[]} arrays.
 *
 * <p>Accepted input values:
 * <ul>
 *   <li>{@code String}s: parsed as numbers for numeric columns, or looked up as levels for
 *       categorical columns.</li>
 *   <li>{@code Double}s: for categorical columns, only {@code NaN} (missing) is accepted.</li>
 *   <li>For Deep Water image models: image file paths or URLs ({@code String}), or raw image
 *       bytes ({@code byte[]}).</li>
 * </ul>
 *
 * <p>The column-name and domain lookup tables are built once in the constructor and only read
 * afterwards. The unknown-categorical-level counters are {@link AtomicLong}s held in a
 * {@link ConcurrentHashMap}.
 */
public class EasyPredictModelWrapper implements Serializable {
    /** Image inputs matching this pattern are fetched as URLs; all others are opened as local files. */
    private static final Pattern IMAGE_URL_PATTERN =
            Pattern.compile("^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]");

    private final GenModel m;
    /** Model column name -> column index, built from {@link GenModel#getNames()}. */
    private final HashMap<String, Integer> modelColumnNameToIndexMap;
    /** Categorical column index -> (level name -> level index). */
    private final HashMap<Integer, HashMap<String, Integer>> domainMap;
    private final boolean convertUnknownCategoricalLevelsToNa;
    /** Per categorical column: how many unknown levels were mapped to NA (only populated when enabled). */
    private final ConcurrentHashMap<String, AtomicLong> unknownCategoricalLevelsSeenPerColumn;

    /**
     * Creates a wrapper for the model held by {@code config}.
     *
     * @param config wrapper configuration; {@link Config#getModel()} must be non-null
     * @throws NullPointerException if the config carries no model
     */
    public EasyPredictModelWrapper(Config config) {
        this.m = config.getModel();
        if (this.m == null) {
            throw new NullPointerException("Config has no model; call Config.setModel() first");
        }

        String[] modelColumnNames = this.m.getNames();
        this.modelColumnNameToIndexMap = new HashMap<String, Integer>();
        for (int i = 0; i < modelColumnNames.length; ++i) {
            this.modelColumnNameToIndexMap.put(modelColumnNames[i], i);
        }

        this.unknownCategoricalLevelsSeenPerColumn = new ConcurrentHashMap<String, AtomicLong>();
        this.convertUnknownCategoricalLevelsToNa = config.getConvertUnknownCategoricalLevelsToNa();
        this.setupConvertUnknownCategoricalLevelsToNa();

        // Precompute level-name -> level-index lookups for every categorical column.
        this.domainMap = new HashMap<Integer, HashMap<String, Integer>>();
        for (int i = 0; i < this.m.getNumCols(); ++i) {
            String[] domainValues = this.m.getDomainValues(i);
            if (domainValues == null) {
                continue;
            }
            HashMap<String, Integer> levelToIndex = new HashMap<String, Integer>();
            for (int j = 0; j < domainValues.length; ++j) {
                levelToIndex.put(domainValues[j], j);
            }
            this.domainMap.put(i, levelToIndex);
        }
    }

    /**
     * Creates a wrapper for {@code model} with the default {@link Config}.
     *
     * @param model the model to wrap; must be non-null
     */
    public EasyPredictModelWrapper(GenModel model) {
        this(new Config().setModel(model));
    }

    /**
     * Returns the total number of unknown categorical levels that were converted to NA, summed
     * over all columns.
     *
     * @return the total count; always 0 unless {@link Config#setConvertUnknownCategoricalLevelsToNa}
     *     was enabled
     */
    public long getTotalUnknownCategoricalLevelsSeen() {
        long total = 0L;
        for (AtomicLong count : this.getUnknownCategoricalLevelsSeenPerColumn().values()) {
            total += count.get();
        }
        return total;
    }

    /**
     * Returns the live per-column counters of unknown categorical levels converted to NA.
     *
     * @return the internal map, keyed by categorical column name (empty when conversion is disabled)
     */
    public ConcurrentHashMap<String, AtomicLong> getUnknownCategoricalLevelsSeenPerColumn() {
        return this.unknownCategoricalLevelsSeenPerColumn;
    }

    /**
     * Scores {@code data} as the given model category.
     *
     * @param data input row
     * @param mc the prediction type to produce
     * @return the category-specific prediction object
     * @throws PredictException if the category is unknown or unsupported, or the row cannot be
     *     converted
     */
    public AbstractPrediction predict(RowData data, ModelCategory mc) throws PredictException {
        switch (mc) {
            case AutoEncoder:
                return this.predictAutoEncoder(data);
            case Binomial:
                return this.predictBinomial(data);
            case Multinomial:
                return this.predictMultinomial(data);
            case Clustering:
                return this.predictClustering(data);
            case Regression:
                return this.predictRegression(data);
            case DimReduction:
                return this.predictDimReduction(data);
            case Unknown:
                throw new PredictException("Unknown model category");
            default:
                // Report the category actually requested, not the model's own category.
                throw new PredictException("Unhandled model category (" + mc + ") in switch statement");
        }
    }

    /**
     * Scores {@code data} as the model's own category ({@link GenModel#getModelCategory()}).
     *
     * @param data input row
     * @return the category-specific prediction object
     * @throws PredictException if the row cannot be scored
     */
    public AbstractPrediction predict(RowData data) throws PredictException {
        return this.predict(data, this.m.getModelCategory());
    }

    /**
     * Autoencoder scoring is not implemented. This method still validates the category and scores
     * the row before failing.
     *
     * @throws PredictException if the model does not support the AutoEncoder category
     * @throws UnsupportedOperationException always, once scoring succeeds
     */
    public AutoEncoderModelPrediction predictAutoEncoder(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.AutoEncoder, data);
        throw new UnsupportedOperationException("Unimplemented " + preds.length);
    }

    /**
     * Scores a dimensionality-reduction model.
     *
     * @param data input row
     * @return a prediction whose {@code dimensions} is the raw score vector
     * @throws PredictException if the model does not support this category or the row is invalid
     */
    public DimReductionModelPrediction predictDimReduction(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.DimReduction, data);
        DimReductionModelPrediction p = new DimReductionModelPrediction();
        p.dimensions = preds;
        return p;
    }

    /**
     * Scores a binomial model.
     *
     * <p>{@code preds[0]} holds the predicted label index and {@code preds[1..]} the class
     * probabilities.
     *
     * @param data input row
     * @return the predicted label, its index and the per-class probabilities
     * @throws PredictException if the model does not support this category or the row is invalid
     */
    public BinomialModelPrediction predictBinomial(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Binomial, data);
        BinomialModelPrediction p = new BinomialModelPrediction();
        p.classProbabilities = new double[this.m.getNumResponseClasses()];
        p.labelIndex = (int) preds[0];
        p.label = this.getResponseDomainValues()[p.labelIndex];
        System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
        return p;
    }

    /**
     * Scores a multinomial model.
     *
     * <p>The layout is the same as for binomial: {@code preds[0]} is the label index and
     * {@code preds[1..]} are the class probabilities.
     *
     * @param data input row
     * @return the predicted label, its index and the per-class probabilities
     * @throws PredictException if the model does not support this category or the row is invalid
     */
    public MultinomialModelPrediction predictMultinomial(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Multinomial, data);
        MultinomialModelPrediction p = new MultinomialModelPrediction();
        p.classProbabilities = new double[this.m.getNumResponseClasses()];
        p.labelIndex = (int) preds[0];
        p.label = this.getResponseDomainValues()[p.labelIndex];
        System.arraycopy(preds, 1, p.classProbabilities, 0, p.classProbabilities.length);
        return p;
    }

    /** Pairs each response level with its probability, sorted from most to least probable. */
    private SortedClassProbability[] sortByDescendingClassProbability(String[] domainValues, double[] classProbabilities) {
        assert classProbabilities.length == domainValues.length;
        SortedClassProbability[] arr = new SortedClassProbability[domainValues.length];
        for (int i = 0; i < domainValues.length; ++i) {
            SortedClassProbability entry = new SortedClassProbability();
            entry.name = domainValues[i];
            entry.probability = classProbabilities[i];
            arr[i] = entry;
        }
        Arrays.sort(arr, Collections.reverseOrder());
        return arr;
    }

    /**
     * Returns the response levels of a binomial prediction, ordered by descending probability.
     *
     * @param p a prediction produced by this wrapper
     * @return level/probability pairs, most probable first
     */
    public SortedClassProbability[] sortByDescendingClassProbability(BinomialModelPrediction p) {
        return this.sortByDescendingClassProbability(this.getResponseDomainValues(), p.classProbabilities);
    }

    /**
     * Returns the response levels of a multinomial prediction, ordered by descending probability.
     *
     * @param p a prediction produced by this wrapper
     * @return level/probability pairs, most probable first
     */
    public SortedClassProbability[] sortByDescendingClassProbability(MultinomialModelPrediction p) {
        return this.sortByDescendingClassProbability(this.getResponseDomainValues(), p.classProbabilities);
    }

    /**
     * Scores a clustering model.
     *
     * @param data input row
     * @return the assigned cluster, taken from {@code preds[0]}
     * @throws PredictException if the model does not support this category or the row is invalid
     */
    public ClusteringModelPrediction predictClustering(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Clustering, data);
        ClusteringModelPrediction p = new ClusteringModelPrediction();
        p.cluster = (int) preds[0];
        return p;
    }

    /**
     * Scores a regression model.
     *
     * @param data input row
     * @return the predicted value, taken from {@code preds[0]}
     * @throws PredictException if the model does not support this category or the row is invalid
     */
    public RegressionModelPrediction predictRegression(RowData data) throws PredictException {
        double[] preds = this.preamble(ModelCategory.Regression, data);
        RegressionModelPrediction p = new RegressionModelPrediction();
        p.value = preds[0];
        return p;
    }

    /** @return the wrapped model's category */
    public ModelCategory getModelCategory() {
        return this.m.getModelCategory();
    }

    /** @return the domain (level names) of the response column, as reported by the model */
    public String[] getResponseDomainValues() {
        return this.m.getDomainValues(this.m.getResponseIdx());
    }

    /** @return the wrapped model's header, as reported by {@link GenModel#getHeader()} */
    public String getHeader() {
        return this.m.getHeader();
    }

    /**
     * When NA conversion is enabled, registers a zeroed counter for each categorical column (one
     * with a non-null domain). Otherwise clears the counter map.
     */
    private void setupConvertUnknownCategoricalLevelsToNa() {
        if (!this.convertUnknownCategoricalLevelsToNa) {
            this.unknownCategoricalLevelsSeenPerColumn.clear();
            return;
        }
        String[] names = this.m.getNames();
        for (int i = 0; i < this.m.getNumCols(); ++i) {
            if (this.m.getDomainValues(i) != null) {
                this.unknownCategoricalLevelsSeenPerColumn.put(names[i], new AtomicLong());
            }
        }
    }

    /** @throws PredictException if the model does not list {@code c} among its supported categories */
    private void validateModelCategory(ModelCategory c) throws PredictException {
        if (!this.m.getModelCategories().contains(c)) {
            throw new PredictException(c + " prediction type is not supported for this model.");
        }
    }

    /** Validates {@code c}, then scores {@code data} into a prediction buffer sized for {@code c}. */
    private double[] preamble(ModelCategory c, RowData data) throws PredictException {
        this.validateModelCategory(c);
        return this.predict(data, new double[this.m.getPredsSize(c)]);
    }

    /** Marks every entry of {@code arr} as missing. */
    private void setToNaN(double[] arr) {
        Arrays.fill(arr, Double.NaN);
    }

    /**
     * Copies the values of {@code data} into {@code rawData}, indexed by model column.
     *
     * <p>Row entries whose name is not a model column, or whose index is outside
     * {@code rawData}, are ignored. For Deep Water image models, the first image found is
     * vectorized and returned as a NEW array; the remaining row entries are not processed.
     *
     * @param data input row
     * @param rawData pre-filled (NaN) feature buffer of length {@code nfeatures()}
     * @return {@code rawData} filled in, or a fresh pixel array for image models
     * @throws PredictException if a value cannot be converted for its column
     */
    private double[] fillRawData(RowData data, double[] rawData) throws PredictException {
        boolean isImage = this.isDeepwaterProblemType("image");
        boolean isText = this.isDeepwaterProblemType("text");
        for (String dataColumnName : data.keySet()) {
            Integer index = this.modelColumnNameToIndexMap.get(dataColumnName);
            if (index == null || index >= rawData.length) {
                continue;
            }
            String[] domainValues = this.m.getDomainValues(index);
            Object o = data.get(dataColumnName);

            if (domainValues != null) {
                rawData[index] = this.categoricalValue(dataColumnName, index, o);
                continue;
            }

            // Numeric (non-categorical) column; image models also accept image inputs here.
            BufferedImage img = null;
            double value = Double.NaN;
            if (o instanceof String) {
                String s = ((String) o).trim();
                if (isImage) {
                    img = readImage(s);
                } else if (isText) {
                    throw new IllegalArgumentException("MOJO scoring for text classification is not yet implemented.");
                } else {
                    value = parseNumber(s, dataColumnName);
                }
            } else if (o instanceof Double) {
                value = (Double) o;
            } else if (o instanceof byte[] && isImage) {
                img = readImage((byte[]) o);
            } else {
                throw new PredictUnknownTypeException("Unexpected object type " + typeName(o) + " for numeric column " + dataColumnName);
            }

            if (isImage && img != null) {
                return this.imageToRawData(img);
            }
            rawData[index] = value;
        }
        return rawData;
    }

    /**
     * Converts a value for a categorical column to its level index.
     *
     * <p>A {@code String} is looked up first as-is, then prefixed with {@code "<column>."}.
     * Unknown levels either fail or, when NA conversion is enabled, become {@code NaN} and
     * increment that column's counter. A {@code Double} is accepted only if it is {@code NaN}.
     *
     * @throws PredictUnknownCategoricalLevelException for an unknown level when NA conversion is disabled
     * @throws PredictUnknownTypeException for any other value type (including {@code null})
     */
    private double categoricalValue(String dataColumnName, int index, Object o) throws PredictException {
        if (o instanceof String) {
            String levelName = (String) o;
            HashMap<String, Integer> columnDomainMap = this.domainMap.get(index);
            Integer levelIndex = columnDomainMap.get(levelName);
            if (levelIndex == null) {
                levelIndex = columnDomainMap.get(dataColumnName + "." + levelName);
            }
            if (levelIndex != null) {
                return levelIndex.intValue();
            }
            if (!this.convertUnknownCategoricalLevelsToNa) {
                throw new PredictUnknownCategoricalLevelException("Unknown categorical level (" + dataColumnName + "," + levelName + ")", dataColumnName, levelName);
            }
            this.unknownCategoricalLevelsSeenPerColumn.get(dataColumnName).incrementAndGet();
            return Double.NaN;
        }
        if (o instanceof Double && Double.isNaN((Double) o)) {
            return Double.NaN;
        }
        throw new PredictUnknownTypeException("Unexpected object type " + typeName(o) + " for categorical column " + dataColumnName);
    }

    /** @return true when the model is a Deep Water MOJO whose problem type equals {@code type} */
    private boolean isDeepwaterProblemType(String type) {
        // Constant-first equals: tolerates a null _problem_type.
        return this.m instanceof DeepwaterMojoModel && type.equals(((DeepwaterMojoModel) this.m)._problem_type);
    }

    /** Parses {@code s} as a double, reporting failures as {@link PredictNumberFormatException}. */
    private static double parseNumber(String s, String dataColumnName) throws PredictException {
        try {
            return Double.parseDouble(s);
        } catch (NumberFormatException nfe) {
            throw withCause(new PredictNumberFormatException("Unable to parse value: " + s + ", from column: " + dataColumnName + ", as Double; " + nfe.getMessage()), nfe);
        }
    }

    /**
     * Reads an image from a URL (when {@code location} matches {@link #IMAGE_URL_PATTERN}) or
     * otherwise from a local file.
     *
     * @throws PredictException if reading fails or no image reader recognises the data
     */
    private static BufferedImage readImage(String location) throws PredictException {
        BufferedImage img;
        try {
            if (IMAGE_URL_PATTERN.matcher(location).matches()) {
                img = ImageIO.read(new URL(location));
            } else {
                img = ImageIO.read(new File(location));
            }
        } catch (IOException e) {
            throw withCause(new PredictException("Couldn't read image from " + location), e);
        }
        if (img == null) {
            // ImageIO.read returns null when no registered reader can decode the data.
            throw new PredictException("Couldn't read image from " + location);
        }
        return img;
    }

    /**
     * Decodes raw image bytes.
     *
     * @throws PredictException if decoding fails or no image reader recognises the data
     */
    private static BufferedImage readImage(byte[] bytes) throws PredictException {
        BufferedImage img;
        try {
            img = ImageIO.read(new ByteArrayInputStream(bytes));
        } catch (IOException e) {
            throw withCause(new PredictException("Couldn't interpret raw bytes as an image."), e);
        }
        if (img == null) {
            throw new PredictException("Couldn't interpret raw bytes as an image.");
        }
        return img;
    }

    /**
     * Vectorizes {@code img} using the Deep Water model's width, height, channels and mean image.
     *
     * @return a new feature array of length {@code width * height * channels}
     */
    private double[] imageToRawData(BufferedImage img) throws PredictException {
        DeepwaterMojoModel dwm = (DeepwaterMojoModel) this.m;
        int width = dwm._width;
        int height = dwm._height;
        int channels = dwm._channels;
        float[] pixels = new float[width * height * channels];
        try {
            GenModel.img2pixels(img, width, height, channels, pixels, 0, dwm._meanImageData);
        } catch (IOException e) {
            throw withCause(new PredictException("Couldn't vectorize image."), e);
        }
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; ++i) {
            out[i] = pixels[i];
        }
        return out;
    }

    /** @return the runtime class name of {@code o}, or {@code "null"} */
    private static String typeName(Object o) {
        return o == null ? "null" : o.getClass().getName();
    }

    /**
     * Attaches {@code cause} to {@code e} and returns {@code e}, ready to be thrown.
     * If {@code e}'s constructor already fixed a cause, that cause is kept.
     */
    private static <E extends Exception> E withCause(E e, Throwable cause) {
        try {
            e.initCause(cause);
        } catch (IllegalStateException ignored) {
            // Cause was already set by the exception's constructor; keep it.
        }
        return e;
    }

    /** Builds a feature vector for {@code data} and scores it into {@code preds}. */
    private double[] predict(RowData data, double[] preds) throws PredictException {
        double[] rawData = new double[this.m.nfeatures()];
        this.setToNaN(rawData);
        rawData = this.fillRawData(data, rawData);
        return this.m.score0(rawData, preds);
    }

    /** Fluent configuration for {@link EasyPredictModelWrapper}. */
    public static class Config {
        private GenModel model;
        private boolean convertUnknownCategoricalLevelsToNa = false;

        /**
         * @param value the model to wrap
         * @return this config, for chaining
         */
        public Config setModel(GenModel value) {
            this.model = value;
            return this;
        }

        /** @return the model to wrap, or {@code null} if not yet set */
        public GenModel getModel() {
            return this.model;
        }

        /**
         * @param value if true, unknown categorical levels become NA (and are counted) instead of
         *     raising {@link PredictUnknownCategoricalLevelException}; defaults to false
         * @return this config, for chaining
         */
        public Config setConvertUnknownCategoricalLevelsToNa(boolean value) {
            this.convertUnknownCategoricalLevelsToNa = value;
            return this;
        }

        /** @return whether unknown categorical levels are converted to NA */
        public boolean getConvertUnknownCategoricalLevelsToNa() {
            return this.convertUnknownCategoricalLevelsToNa;
        }
    }
}

