/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.deeplearning;

import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.deeplearning.NeuralNetwork;
import hex.genmodel.utils.DistributionFamily;
import java.io.Serializable;

/**
 * MOJO scoring model for a feed-forward deep learning network.
 *
 * <p>Scoring encodes the raw row into the input layer, runs one {@link NeuralNetwork} forward pass
 * per layer, and post-processes the output-layer activations into the prediction array (see
 * {@link #modifyOutputs}). {@link #init()} must be called once the public fields are populated and
 * before any scoring call.
 */
public class DeeplearningMojoModel
extends MojoModel {
    /** Mini-batch size; not referenced by the scoring code in this class. */
    public int _mini_batch_size;
    /** Number of numeric input columns; for auto-encoders also the size of the trailing numeric output block. */
    public int _nums;
    /** Number of categorical input columns. */
    public int _cats;
    /** Categorical offsets forwarded to {@code setInput}. */
    public int[] _catoffsets;
    /** Numeric normalization multipliers (used for input encoding, auto-encoder de-normalization and reconstruction error). */
    public double[] _normmul;
    /** Numeric normalization offsets (used for input encoding and auto-encoder de-normalization). */
    public double[] _normsub;
    /** Response normalization multiplier for regression; {@code null} when the response is not normalized. */
    public double[] _normrespmul;
    /** Response normalization offset for regression; read only when {@link #_normrespmul} is non-null. */
    public double[] _normrespsub;
    /** Forwarded to {@code setInput}; controls categorical encoding there. */
    public boolean _use_all_factor_levels;
    /** Activation name used for every hidden layer (and the output layer of auto-encoders). */
    public String _activation;
    /** Per-layer activation names, derived by {@link #init()}. */
    public String[] _allActivations;
    /** Not referenced by the scoring code in this class. */
    public boolean _imputeMeans;
    /** Layer widths, input layer first; {@code _units[0]} is the input-layer width. */
    public int[] _units;
    /** Per-layer dropout ratios passed to {@link NeuralNetwork}. */
    public double[] _all_drop_out_ratios;
    /** Per-layer weights and biases passed to {@link NeuralNetwork}. */
    public StoreWeightsBias[] _weightsAndBias;
    /** Not referenced by the scoring code in this class. */
    public int[] _catNAFill;
    /** Number of weight layers ({@code _units.length - 1}), derived by {@link #init()}. */
    public int _numLayers;
    /** Distribution family; selects the inverse link applied to non-auto-encoder outputs. */
    public DistributionFamily _family;

    DeeplearningMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Derives {@link #_numLayers} and {@link #_allActivations} from {@link #_units} and
     * {@link #_activation}.
     *
     * <p>Every hidden layer uses {@link #_activation}. The output layer uses:
     * <ul>
     * <li>{@link #_activation} for auto-encoders;</li>
     * <li>"Softmax" for classifiers other than modified Huber;</li>
     * <li>"Linear" otherwise. This covers regression and modified Huber, whose first output is
     * treated as a raw score by {@link #modifyOutputs}.</li>
     * </ul>
     *
     * @throws IllegalStateException if {@link #_units} lists fewer than two layer sizes
     */
    public void init() {
        if (this._units.length < 2) {
            throw new IllegalStateException(
                    "_units must list at least an input and an output layer size, but has length " + this._units.length);
        }
        this._numLayers = this._units.length - 1;
        this._allActivations = new String[this._numLayers];
        int outputLayer = this._numLayers - 1;
        for (int layer = 0; layer < outputLayer; ++layer) {
            this._allActivations[layer] = this._activation;
        }
        this._allActivations[outputLayer] = this.outputLayerActivation();
    }

    /** Chooses the activation name for the last layer (see {@link #init()}). */
    private String outputLayerActivation() {
        if (this.isAutoEncoder()) {
            return this._activation;
        }
        // Modified Huber feeds a single unbounded score into a logistic link, so its output
        // layer must not be normalized by softmax.
        if (this.isClassifier() && this._family != DistributionFamily.modified_huber) {
            return "Softmax";
        }
        return "Linear";
    }

    /**
     * Scores one row.
     *
     * @param dataRow raw input row (must not be {@code null})
     * @param offset ignored; this model does not apply an offset
     * @param preds destination array, sized by {@link #getPredsSize}; filled and returned
     * @return {@code preds}
     */
    @Override
    public final double[] score0(double[] dataRow, double offset, double[] preds) {
        assert (dataRow != null) : "doubles are null";
        double[] activations = new double[this._units[0]];
        // Scratch buffers required by setInput.
        double[] numericScratch = new double[this._nums];
        int[] categoricalScratch = new int[this._cats];
        DeeplearningMojoModel.setInput(dataRow, activations, numericScratch, categoricalScratch, this._nums, this._cats, this._catoffsets, this._normmul, this._normsub, this._use_all_factor_levels, true);
        for (int layer = 0; layer < this._numLayers; ++layer) {
            NeuralNetwork oneLayer = new NeuralNetwork(this._allActivations[layer], this._all_drop_out_ratios[layer], this._weightsAndBias[layer], activations, this._units[layer + 1]);
            activations = oneLayer.fprop1Layer();
        }
        // Multi-output classifiers emit one activation per class; modified Huber emits one score.
        if (!this.isAutoEncoder() && this._family != DistributionFamily.modified_huber) {
            assert (this._nclasses == activations.length) : "nclasses " + this._nclasses + " neuronsOutput.length " + activations.length;
        }
        return this.modifyOutputs(activations, preds, dataRow);
    }

    /**
     * Converts output-layer activations into the prediction array.
     *
     * <ul>
     * <li>Auto-encoder: copies the reconstruction. When {@link #_normmul} is non-empty, the
     * trailing {@link #_nums} entries are mapped back with {@code x / _normmul[k] + _normsub[k]}.</li>
     * <li>Modified Huber:
     *   <ul>
     *   <li>{@code preds[2]} is the logistic inverse link of {@code out[0]};</li>
     *   <li>{@code preds[1] = 1 - preds[2]};</li>
     *   <li>{@code preds[0]} is the label chosen by {@code GenModel.getPrediction}.</li>
     *   </ul></li>
     * <li>Other classifiers:
     *   <ul>
     *   <li>{@code preds[1..]} are the activations of {@code out}, corrected with
     *   {@code GenModel.correctProbabilities} when classes were balanced;</li>
     *   <li>{@code preds[0]} is the label chosen by {@code GenModel.getPrediction}.</li>
     *   </ul></li>
     * <li>Regression: {@code preds[0]} is {@code out[0]}, de-normalized when
     * {@link #_normrespmul} is set, then passed through the inverse link of {@link #_family}.</li>
     * </ul>
     *
     * @param out output-layer activations
     * @param preds destination array, filled in place
     * @param dataRow original input row, forwarded to the label decision
     * @return {@code preds}
     * @throws RuntimeException if a class probability or the regression prediction is NaN
     */
    public double[] modifyOutputs(double[] out, double[] preds, double[] dataRow) {
        if (this.isAutoEncoder()) {
            if (this._normmul != null && this._normmul.length > 0) {
                // The last _nums outputs are the numeric columns; everything before them is copied unchanged.
                int firstNumeric = out.length - this._nums;
                System.arraycopy(out, 0, preds, 0, firstNumeric);
                for (int k = 0; k < this._nums; ++k) {
                    int idx = firstNumeric + k;
                    preds[idx] = out[idx] / this._normmul[k] + this._normsub[k];
                }
            } else {
                System.arraycopy(out, 0, preds, 0, out.length);
            }
        } else if (this._family == DistributionFamily.modified_huber) {
            // Apply the link to the network output. Applying it to a constant (as before) made
            // every row receive the same probability and an invalid label of -1.
            preds[2] = this.linkInv(this._family, out[0]);
            preds[1] = 1.0 - preds[2];
            preds[0] = GenModel.getPrediction(preds, this._priorClassDistrib, dataRow, this._defaultThreshold);
        } else if (this.isClassifier()) {
            assert (preds.length == out.length + 1);
            for (int i = 0; i < out.length; ++i) {
                if (Double.isNaN(out[i])) {
                    throw new RuntimeException("Predicted class probability NaN!");
                }
                preds[i + 1] = out[i];
            }
            if (this._balanceClasses) {
                GenModel.correctProbabilities(preds, this._priorClassDistrib, this._modelClassDistrib);
            }
            preds[0] = GenModel.getPrediction(preds, this._priorClassDistrib, dataRow, this._defaultThreshold);
        } else {
            double raw = this._normrespmul != null ? out[0] / this._normrespmul[0] + this._normrespsub[0] : out[0];
            preds[0] = this.linkInv(this._family, raw);
            if (Double.isNaN(preds[0])) {
                throw new RuntimeException("Predicted regression target NaN!");
            }
        }
        return preds;
    }

    /**
     * Applies the inverse link function of {@code distribution} to {@code f}.
     *
     * <ul>
     * <li>Logistic for the binomial-style families.</li>
     * <li>Exponential for the log-link families.</li>
     * <li>Identity otherwise.</li>
     * </ul>
     * In both non-identity cases {@code exp} is capped at {@code 1e19}.
     */
    private double linkInv(DistributionFamily distribution, double f) {
        switch (distribution) {
            case bernoulli:
            case quasibinomial:
            case modified_huber:
            case ordinal:
                return 1.0 / (1.0 + Math.min(1.0E19, Math.exp(-f)));
            case multinomial:
            case poisson:
            case gamma:
            case tweedie:
                return Math.min(1.0E19, Math.exp(f));
            default:
                return f;
        }
    }

    /** Scores one row without an offset; equivalent to {@code score0(row, 0.0, preds)}. */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return this.score0(row, 0.0, preds);
    }

    /**
     * Returns the required prediction-array length.
     *
     * <ul>
     * <li>Auto-encoders: the input-layer width.</li>
     * <li>Classifiers: {@code nclasses() + 1}.</li>
     * <li>Otherwise: 2.</li>
     * </ul>
     */
    @Override
    public int getPredsSize(ModelCategory mc) {
        if (mc == ModelCategory.AutoEncoder) {
            return this._units[0];
        }
        return this.isClassifier() ? this.nclasses() + 1 : 2;
    }

    /**
     * Computes the mean squared reconstruction error of one row.
     *
     * <p>When {@link #_normmul} is non-empty and {@link #_nums} is positive, the differences in the
     * trailing {@link #_nums} positions are multiplied by the matching {@link #_normmul} entry
     * before squaring; all other positions use weight 1.
     *
     * @param original the original row values
     * @param reconstructed the reconstruction, same length as {@code original}
     * @return sum of squared (weighted) differences divided by {@code original.length}
     */
    public double calculateReconstructionErrorPerRowData(double[] original, double[] reconstructed) {
        assert (original != null && original.length > 0 && reconstructed != null && reconstructed.length > 0);
        assert (original.length == reconstructed.length);
        int numStartIndex = original.length - this._nums;
        boolean scaleNumerics = this._normmul != null && this._normmul.length > 0 && this._nums > 0;
        double sumSq = 0.0;
        for (int i = 0; i < original.length; ++i) {
            double norm = scaleNumerics && i >= numStartIndex ? this._normmul[i - numStartIndex] : 1.0;
            double diff = (reconstructed[i] - original[i]) * norm;
            sumSq += diff * diff;
        }
        return sumSq / (double)original.length;
    }

    /** Weights ({@code float}) and biases ({@code double}) of one network layer. */
    public static class StoreWeightsBias
    implements Serializable {
        float[] _wValues;
        double[] _bValues;

        StoreWeightsBias(float[] wvalues, double[] bvalues) {
            this._wValues = wvalues;
            this._bValues = bvalues;
        }
    }
}

