/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.genmodel.ConverterFactoryProvidingModel;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.deepwater.DWImageConverter;
import hex.genmodel.algos.deepwater.DWTextConverter;
import hex.genmodel.algos.deepwater.caffe.DeepwaterCaffeBackend;
import hex.genmodel.easy.CategoricalEncoder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowToRawDataConverter;
import java.io.File;
import java.util.Map;

/**
 * MOJO model that scores rows with an H2O Deep Water network through a pluggable
 * {@link BackendTrain} implementation.
 *
 * <p>Each row is converted to a {@code float[]} before prediction. When the model has numeric
 * columns ({@code _nums > 0}), the row is encoded with {@link GenModel#setInput}. Otherwise the raw
 * values are used, minus {@code _meanImageData} element-wise when that array is present.
 */
public class DeepwaterMojoModel
extends MojoModel
implements ConverterFactoryProvidingModel {
    /** Directory whose presence enables the built-in Caffe backend. */
    private static final String CAFFE_INSTALL_DIR = "/opt/caffe-h2o/";

    /** Problem type; {@code "image"} and {@code "text"} select specialised row converters. */
    public String _problem_type;
    public int _mini_batch_size;
    public int _height;
    public int _width;
    public int _channels;
    /** Number of numeric input columns; zero selects the raw (non-tabular) input path. */
    public int _nums;
    /** Number of categorical input columns. */
    public int _cats;
    /** Categorical offsets passed to setInput; {@code _catOffsets[_cats]} is the encoded width. */
    public int[] _catOffsets;
    public double[] _normMul;
    public double[] _normSub;
    /** Response de-normalisation factors, applied to regression output when both are non-null. */
    public double[] _normRespMul;
    public double[] _normRespSub;
    public boolean _useAllFactorLevels;
    transient byte[] _network;
    transient byte[] _parameters;
    /** Optional values subtracted element-wise from raw inputs when {@code _nums == 0}. */
    public transient float[] _meanImageData;
    BackendTrain _backend;
    BackendModel _model;
    ImageDataSet _imageDataSet;
    RuntimeOptions _opts;
    BackendParams _backendParams;

    DeepwaterMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Scores a single row with the configured backend.
     *
     * @param doubles input row; must not be null
     * @param offset  not used by this model
     * @param preds   output buffer. For classification, {@code preds[1..nclasses]} receives the
     *                per-class backend outputs. These are corrected when class balancing is on, and
     *                {@code preds[0]} receives the result of {@link GenModel#getPrediction}. For
     *                regression, only {@code preds[0]} is written.
     * @return {@code preds}
     */
    public final double[] score0(double[] doubles, double offset, double[] preds) {
        assert (doubles != null) : "doubles are null";
        float[] input = this.toFloatInput(doubles);
        float[] predFloats = this._backend.predict(this._model, input);
        assert (this._nclasses == predFloats.length) : "nclasses " + this._nclasses + " predFloats.length " + predFloats.length;
        if (this._nclasses > 1) {
            for (int k = 0; k < predFloats.length; ++k) {
                preds[k + 1] = predFloats[k];
            }
            if (this._balanceClasses) {
                GenModel.correctProbabilities(preds, this._priorClassDistrib, this._modelClassDistrib);
            }
            preds[0] = GenModel.getPrediction(preds, this._priorClassDistrib, doubles, this._defaultThreshold);
        } else if (this._normRespMul != null && this._normRespSub != null) {
            // Undo the response normalisation applied during training.
            preds[0] = (double) predFloats[0] * this._normRespMul[0] + this._normRespSub[0];
        } else {
            preds[0] = predFloats[0];
        }
        return preds;
    }

    /**
     * Scores a single row with an offset of {@code 0.0}.
     *
     * @param row   input row
     * @param preds output buffer, filled as described for the three-argument overload
     * @return {@code preds}
     */
    public double[] score0(double[] row, double[] preds) {
        return this.score0(row, 0.0, preds);
    }

    /** Converts a scoring row into the float vector handed to the backend. */
    private float[] toFloatInput(double[] doubles) {
        if (this._nums > 0) {
            // Combined width of the encoded categorical columns; zero when there are none.
            int catWidth = this._catOffsets == null ? 0 : this._catOffsets[this._cats];
            float[] floats = new float[this._nums + catWidth];
            GenModel.setInput(doubles, floats, this._nums, this._cats, this._catOffsets, this._normMul, this._normSub, this._useAllFactorLevels, true);
            return floats;
        }
        float[] floats = new float[doubles.length];
        float[] mean = this._meanImageData;
        for (int i = 0; i < floats.length; ++i) {
            floats[i] = mean == null ? (float) doubles[i] : (float) doubles[i] - mean[i];
        }
        return floats;
    }

    /**
     * Instantiates the Deep Water backend named by {@code backend}.
     *
     * <p>{@code "caffe"} yields the built-in {@link DeepwaterCaffeBackend} when
     * {@value #CAFFE_INSTALL_DIR} exists as a directory. {@code "mxnet"} and {@code "tensorflow"}
     * are mapped to their backend class names. Any other value (including {@code "caffe"} without
     * the directory) is treated as a class name and instantiated reflectively through its no-arg
     * constructor.
     *
     * @param backend backend alias or fully qualified class name
     * @return the backend instance, or {@code null} if the name is null or the backend cannot be
     *     created
     */
    public static BackendTrain createDeepWaterBackend(String backend) {
        if (backend == null) {
            return null;
        }
        try {
            if ("caffe".equals(backend) && new File(CAFFE_INSTALL_DIR).isDirectory()) {
                return new DeepwaterCaffeBackend();
            }
            Class<?> backendClass = Class.forName(resolveBackendClassName(backend));
            try {
                return (BackendTrain) backendClass.getDeclaredConstructor().newInstance();
            } catch (java.lang.reflect.InvocationTargetException e) {
                // Let Errors raised by the constructor (such as a failed native-library load)
                // propagate instead of being reported as a missing backend.
                if (e.getCause() instanceof Error) {
                    throw (Error) e.getCause();
                }
                return null;
            }
        } catch (Exception ignored) {
            // Deliberate best effort: callers receive null when the backend class is unavailable,
            // is not a BackendTrain, or cannot be constructed.
            return null;
        }
    }

    /** Maps the short aliases "mxnet" and "tensorflow" to class names; other values pass through. */
    private static String resolveBackendClassName(String backend) {
        if ("mxnet".equals(backend)) {
            return "deepwater.backends.mxnet.MXNetBackend";
        }
        if ("tensorflow".equals(backend)) {
            return "deepwater.backends.tensorflow.TensorflowBackend";
        }
        return backend;
    }

    /**
     * Chooses the row converter matching {@code _problem_type}: an image converter for
     * {@code "image"}, a text converter for {@code "text"}, and the generic converter otherwise
     * (including when the problem type is null).
     */
    public RowToRawDataConverter makeConverterFactory(Map<String, Integer> modelColumnNameToIndexMap, Map<Integer, CategoricalEncoder> domainMap, EasyPredictModelWrapper.ErrorConsumer errorConsumer, EasyPredictModelWrapper.Config config) {
        if ("image".equals(this._problem_type)) {
            return new DWImageConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
        }
        if ("text".equals(this._problem_type)) {
            return new DWTextConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
        }
        return new RowToRawDataConverter(this, modelColumnNameToIndexMap, domainMap, errorConsumer, config);
    }
}

