/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendParams;
import deepwater.backends.BackendTrain;
import deepwater.backends.RuntimeOptions;
import deepwater.datasets.ImageDataSet;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;

/**
 * MOJO scoring model that converts an input row to a {@code float[]} and hands it to a
 * Deep Water backend ({@link BackendTrain#predict}) to obtain the raw predictions.
 *
 * <p>Two input layouts are handled by {@link #score0(double[], double, double[])}:
 * <ul>
 *   <li>{@code _nums > 0}: the row is encoded via {@code GenModel.setInput} using the
 *       categorical offsets and normalization arrays held by this model;</li>
 *   <li>otherwise: every row value is cast to {@code float} and, when
 *       {@code _meanImageData} is set, the matching mean value is subtracted.</li>
 * </ul>
 *
 * <p>NOTE(review): the fields are assigned outside this class (presumably by the MOJO
 * reader in this package); nothing here initializes {@code _backend} or {@code _model},
 * so both must be set before scoring.
 */
public class DeepwaterMojoModel extends MojoModel {
    // Model metadata; not read by the scoring code in this class.
    public String _problem_type;
    public int _mini_batch_size;
    public int _height;
    public int _width;
    public int _channels;

    /** Number of numeric inputs; a positive value selects the {@code GenModel.setInput} path. */
    public int _nums;
    /** Number of categorical inputs; also the index into {@code _catOffsets} that yields the categorical slot count. */
    public int _cats;
    /** Categorical offsets; {@code _catOffsets[_cats]} is used as the number of categorical input slots. May be null. */
    public int[] _catOffsets;
    /** Input normalization arrays, passed through unchanged to {@code GenModel.setInput}. */
    public double[] _normMul;
    public double[] _normSub;
    /** Response de-normalization: when both are non-null, a single-output prediction becomes {@code p * mul[0] + sub[0]}. */
    public double[] _normRespMul;
    public double[] _normRespSub;
    public boolean _useAllFactorLevels;

    // Not read by this class — presumably the serialized network/weights; verify against the reader.
    transient byte[] _network;
    transient byte[] _parameters;

    /** Optional per-element mean subtracted from the row when {@code _nums <= 0}; null means no subtraction. */
    public transient float[] _meanImageData;

    /** Backend used for prediction and the backend-side model handle passed to it. */
    BackendTrain _backend;
    BackendModel _model;

    // Not read by this class.
    ImageDataSet _imageDataSet;
    RuntimeOptions _opts;
    BackendParams _backendParams;

    DeepwaterMojoModel(String[] columns, String[][] domains) {
        super(columns, domains);
    }

    /**
     * Scores one row.
     *
     * @param doubles input row; must not be null (checked only via {@code assert})
     * @param offset  not used by this implementation
     * @param preds   output array, written in place; for {@code _nclasses > 1} index 0 holds
     *                the predicted label and indices {@code 1..n} the per-class values,
     *                otherwise only index 0 is written
     * @return the same {@code preds} array
     */
    @Override
    public final double[] score0(double[] doubles, double offset, double[] preds) {
        assert (doubles != null) : "doubles are null";

        float[] input = this.toBackendInput(doubles);
        float[] predFloats = this._backend.predict(this._model, input);
        assert (this._nclasses == predFloats.length) : "nclasses " + this._nclasses + " predFloats.length " + predFloats.length;

        if (this._nclasses > 1) {
            // Slot 0 is reserved for the label, so class values start at index 1.
            for (int k = 0; k < predFloats.length; ++k) {
                preds[k + 1] = predFloats[k];
            }
            if (this._balanceClasses) {
                GenModel.correctProbabilities(preds, this._priorClassDistrib, this._modelClassDistrib);
            }
            preds[0] = GenModel.getPrediction(preds, this._priorClassDistrib, doubles, this._defaultThreshold);
        } else {
            double value = predFloats[0];
            // Undo response normalization only when both factors are available.
            if (this._normRespMul != null && this._normRespSub != null) {
                value = value * this._normRespMul[0] + this._normRespSub[0];
            }
            preds[0] = value;
        }
        return preds;
    }

    /**
     * Scores one row with a zero offset.
     *
     * @param row   input row
     * @param preds output array, written in place
     * @return the same {@code preds} array
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return this.score0(row, 0.0, preds);
    }

    /**
     * Resolves a backend name to a class and instantiates it reflectively.
     *
     * <p>The short names {@code "mxnet"}, {@code "tensorflow"} and {@code "caffe"} map to their
     * backend class names; any other value is used as a class name as-is.
     *
     * @param backend short backend name or a class name; may be null
     * @return a new backend instance, or {@code null} if {@code backend} is null or the class
     *         cannot be loaded, instantiated, or cast to {@link BackendTrain}
     */
    public static BackendTrain createDeepWaterBackend(String backend) {
        if (backend == null) {
            return null;
        }

        final String className;
        if (backend.equals("mxnet")) {
            className = "deepwater.backends.mxnet.MXNetBackend";
        } else if (backend.equals("tensorflow")) {
            className = "deepwater.backends.tensorflow.TensorflowBackend";
        } else if (backend.equals("caffe")) {
            className = "deepwater.backends.caffe.CaffeBackend";
        } else {
            className = backend;
        }

        try {
            // NOTE(review): Class.newInstance() is deprecated since Java 9. It is kept here because
            // getDeclaredConstructor().newInstance() wraps constructor-thrown Errors in an
            // InvocationTargetException, which the catch below would then swallow.
            return (BackendTrain) Class.forName(className).newInstance();
        } catch (Exception ignored) {
            // Best effort: an unavailable or unusable backend is reported to the caller as null.
            return null;
        }
    }

    /** Builds the float input handed to the backend from a raw row of doubles. */
    private float[] toBackendInput(double[] doubles) {
        // Read before branching so an inconsistent _cats/_catOffsets pair fails on both paths.
        int cats = this._catOffsets == null ? 0 : this._catOffsets[this._cats];

        if (this._nums > 0) {
            float[] encoded = new float[this._nums + cats];
            GenModel.setInput(doubles, encoded, this._nums, this._cats, this._catOffsets, this._normMul, this._normSub, this._useAllFactorLevels);
            return encoded;
        }

        float[] mean = this._meanImageData;
        float[] shifted = new float[doubles.length];
        for (int i = 0; i < shifted.length; ++i) {
            shifted[i] = (float) doubles[i] - (mean == null ? 0.0f : mean[i]);
        }
        return shifted;
    }
}

