/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deeplearning;

import hex.ModelCategory;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.deeplearning.NeuralNetwork;
import hex.genmodel.utils.DistributionFamily;
import java.io.Serializable;

public class DeeplearningMojoModel
extends MojoModel {
    public int _mini_batch_size;
    public int _nums;
    public int _cats;
    public int[] _catoffsets;
    public double[] _normmul;
    public double[] _normsub;
    public double[] _normrespmul;
    public double[] _normrespsub;
    public boolean _use_all_factor_levels;
    public String _activation;
    public String[] _allActivations;
    public boolean _imputeMeans;
    public int[] _units;
    public double[] _all_drop_out_ratios;
    public StoreWeightsBias[] _weightsAndBias;
    public int[] _catNAFill;
    public int _numLayers;
    public DistributionFamily _family;
    protected String _genmodel_encoding;
    protected String[] _orig_names;
    protected String[][] _orig_domain_values;
    protected double[] _orig_projection_array;

    DeeplearningMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    public void init() {
        this._numLayers = this._units.length - 1;
        this._allActivations = new String[this._numLayers];
        int n2 = this._numLayers - 1;
        for (int i2 = 0; i2 < n2; ++i2) {
            this._allActivations[i2] = this._activation;
        }
        this._allActivations[n2] = this.isAutoEncoder() ? this._activation : (this.isClassifier() ? "Softmax" : "Linear");
    }

    @Override
    public final double[] score0(double[] dataRow, double offset, double[] preds) {
        assert (dataRow != null) : "doubles are null";
        double[] dArray = new double[this._units[0]];
        double[] dArray2 = new double[this._nums];
        int[] nArray = new int[this._cats];
        DeeplearningMojoModel.setInput(dataRow, dArray, dArray2, nArray, this._nums, this._cats, this._catoffsets, this._normmul, this._normsub, this._use_all_factor_levels, true);
        for (int i2 = 0; i2 < this._numLayers; ++i2) {
            NeuralNetwork neuralNetwork = new NeuralNetwork(this._allActivations[i2], this._all_drop_out_ratios[i2], this._weightsAndBias[i2], dArray, this._units[i2 + 1]);
            double[] dArray3 = neuralNetwork.fprop1Layer();
            dArray = dArray3;
        }
        if (!this.isAutoEncoder()) assert (this._nclasses == dArray.length) : "nclasses " + this._nclasses + " neuronsOutput.length " + dArray.length;
        return this.modifyOutputs(dArray, preds, dataRow);
    }

    public double[] modifyOutputs(double[] out, double[] preds, double[] dataRow) {
        if (this.isAutoEncoder()) {
            if (this._normmul != null && this._normmul.length > 0) {
                int n2;
                int n3 = out.length - this._nums;
                for (n2 = 0; n2 < n3; ++n2) {
                    preds[n2] = out[n2];
                }
                for (n2 = 0; n2 < this._nums; ++n2) {
                    int n4 = n3 + n2;
                    preds[n4] = out[n4] / this._normmul[n2] + this._normsub[n2];
                }
            } else {
                for (int i2 = 0; i2 < out.length; ++i2) {
                    preds[i2] = out[i2];
                }
            }
        } else if (this._family == DistributionFamily.modified_huber) {
            preds[0] = -1.0;
            DeeplearningMojoModel deeplearningMojoModel = this;
            preds[2] = deeplearningMojoModel.linkInv(deeplearningMojoModel._family, preds[0]);
            preds[1] = 1.0 - preds[2];
        } else if (this.isClassifier()) {
            assert (preds.length == out.length + 1);
            for (int i3 = 0; i3 < preds.length - 1; ++i3) {
                preds[i3 + 1] = out[i3];
                if (!Double.isNaN(preds[i3 + 1])) continue;
                throw new RuntimeException("Predicted class probability NaN!");
            }
            if (this._balanceClasses) {
                GenModel.correctProbabilities(preds, this._priorClassDistrib, this._modelClassDistrib);
            }
            preds[0] = GenModel.getPrediction(preds, this._priorClassDistrib, dataRow, this._defaultThreshold);
        } else {
            preds[0] = this._normrespmul != null ? out[0] / this._normrespmul[0] + this._normrespsub[0] : out[0];
            DeeplearningMojoModel deeplearningMojoModel = this;
            preds[0] = deeplearningMojoModel.linkInv(deeplearningMojoModel._family, preds[0]);
            if (Double.isNaN(preds[0])) {
                throw new RuntimeException("Predicted regression target NaN!");
            }
        }
        return preds;
    }

    private double linkInv(DistributionFamily distribution, double f2) {
        switch (distribution) {
            case bernoulli: 
            case quasibinomial: 
            case modified_huber: 
            case ordinal: {
                return 1.0 / (1.0 + Math.min(1.0E19, Math.exp(-f2)));
            }
            case multinomial: 
            case poisson: 
            case gamma: 
            case tweedie: {
                return Math.min(1.0E19, Math.exp(f2));
            }
        }
        return f2;
    }

    @Override
    public double[] score0(double[] row, double[] preds) {
        return this.score0(row, 0.0, preds);
    }

    @Override
    public int getPredsSize(ModelCategory mc) {
        if (mc == ModelCategory.AutoEncoder) {
            return this._units[0];
        }
        if (this.isClassifier()) {
            return this.nclasses() + 1;
        }
        return 2;
    }

    public double calculateReconstructionErrorPerRowData(double[] original, double[] reconstructed) {
        assert (original != null && original.length > 0 && reconstructed != null && reconstructed.length > 0);
        assert (original.length == reconstructed.length);
        int n2 = original.length - this._nums;
        double d2 = 0.0;
        for (int i2 = 0; i2 < original.length; ++i2) {
            double d3 = this._normmul != null && this._normmul.length > 0 && this._nums > 0 && i2 >= n2 ? this._normmul[i2 - n2] : 1.0;
            d2 += Math.pow((reconstructed[i2] - original[i2]) * d3, 2.0);
        }
        return d2 / (double)original.length;
    }

    @Override
    public CategoricalEncoding getCategoricalEncoding() {
        switch (this._genmodel_encoding) {
            case "AUTO": 
            case "SortByResponse": 
            case "OneHotInternal": {
                return CategoricalEncoding.AUTO;
            }
            case "Binary": {
                return CategoricalEncoding.Binary;
            }
            case "Eigen": {
                return CategoricalEncoding.Eigen;
            }
            case "LabelEncoder": {
                return CategoricalEncoding.LabelEncoder;
            }
        }
        return null;
    }

    @Override
    public String[] getOrigNames() {
        return this._orig_names;
    }

    @Override
    public double[] getOrigProjectionArray() {
        return this._orig_projection_array;
    }

    @Override
    public String[][] getOrigDomainValues() {
        return this._orig_domain_values;
    }

    public static class StoreWeightsBias
    implements Serializable {
        float[] _wValues;
        double[] _bValues;

        StoreWeightsBias(float[] wvalues, double[] bvalues) {
            this._wValues = wvalues;
            this._bValues = bvalues;
        }
    }
}

