/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.DataInfo;
import hex.Distribution;
import hex.DistributionFactory;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningModelInfo;
import hex.deeplearning.Dropout;
import hex.deeplearning.MurmurHash;
import hex.deeplearning.Storage;
import java.nio.ByteBuffer;
import java.util.Arrays;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.MathUtils;

public abstract class Neurons {
    // Number of maxout channels per unit. Only Maxout sets it; every other layer keeps the
    // default 0, which is what the `_k != 0` checks in bprop/update_bias/rescale_weights test.
    short _k;
    // Maxout only: [mini-batch row][unit] -> index of the winning channel from the last fprop.
    int[][] _maxIncoming;
    // Loss distribution built from params in init(); used for output/autoencoder gradients.
    Distribution _dist;
    // Number of units (neurons) in this layer.
    protected int units;
    // Per-layer clone of the model parameters; init() decays the learning rate by layer index.
    protected transient DeepLearningModel.DeepLearningParameters params;
    // Index used to look up this layer's weights/biases in the model info (init() stores index - 1).
    protected transient int _index;
    // Input layer only (autoencoder with input dropout): allocated in init() and read by
    // autoEncoderGradient() as the reconstruction target. NOTE(review): filled outside this view.
    public transient Storage.DenseVector[] _origa;
    // Activations, one vector per mini-batch row.
    public transient Storage.DenseVector[] _a;
    // Error signals, one vector per mini-batch row; not allocated by init() for the Input layer.
    public transient Storage.DenseVector[] _e;
    // Layer feeding into this one (null for Input).
    public Neurons _previous;
    // Network input layer, read by autoEncoderGradient(). NOTE(review): not assigned in this view.
    public Neurons _input;
    // Shared model state: weights, biases, optimizer buffers and the processed-sample counter.
    DeepLearningModelInfo _minfo;
    // Incoming weight matrix and, when non-null, its elastic-averaging reference copy.
    public Storage.DenseRowMatrix _w;
    public Storage.DenseRowMatrix _wEA;
    // Biases and, when non-null, their elastic-averaging reference copy.
    public Storage.DenseVector _b;
    public Storage.DenseVector _bEA;
    // Momentum buffers for weights and biases (set in init() only when momenta are enabled).
    Storage.DenseRowMatrix _wm;
    Storage.DenseVector _bm;
    // AdaDelta accumulators, two slots per parameter: [2*i] running mean of squared updates,
    // [2*i+1] running mean of squared gradients.
    Storage.DenseRowMatrix _ada_dx_g;
    Storage.DenseVector _bias_ada_dx_g;
    // Dropout mask generator; null when dropout is not active for this layer.
    protected Dropout _dropout;
    // When true, bprop() skips the update for a zero error signal (computed in init()).
    private boolean _shortcut = false;
    // Autoencoder sparsity: running average activation per unit (see compute_sparsity()).
    public Storage.DenseVector _avg_a;

    /**
     * Creates a layer with the given number of units. All other state is wired up later by
     * {@code init}.
     *
     * @param units number of neurons in this layer
     */
    Neurons(int units) {
        super();
        this.units = units;
    }

    /**
     * Builds a multi-line description: simple class name, unit count, the layer's parameters and,
     * if present, its dropout configuration.
     */
    @Override
    public String toString() {
        final StringBuilder out = new StringBuilder(this.getClass().getSimpleName());
        out.append("\nNumber of Neurons: ").append(this.units);
        out.append("\nParameters:\n").append(this.params.toString());
        if (this._dropout == null) {
            return out.toString();
        }
        out.append("\nDropout:\n").append(this._dropout.toString());
        return out.toString();
    }

    /**
     * Validates the wiring produced by {@code init}.
     *
     * <p>The input layer must have no predecessor. Every other layer must have a predecessor, and
     * its optimizer buffers must match the optimizer in use: momenta without AdaDelta buffers, or
     * AdaDelta buffers without momenta. Dropout layers must have a dropout generator while training.
     *
     * @param training whether the network is being initialized for training
     * @throws IllegalArgumentException if AdaDelta is enabled with {@code rho == 0} or {@code epsilon == 0}
     */
    void sanityCheck(boolean training) {
        if (this instanceof Input) {
            assert (this._previous == null);
            return;
        }
        assert (this._previous != null);
        if (this._minfo.has_momenta()) {
            assert (this._wm != null);
            assert (this._bm != null);
            assert (this._ada_dx_g == null);
        }
        if (this._minfo.adaDelta()) {
            if (this.params._rho == 0.0) {
                throw new IllegalArgumentException("rho must be > 0 if epsilon is >0.");
            }
            if (this.params._epsilon == 0.0) {
                throw new IllegalArgumentException("epsilon must be > 0 if rho is >0.");
            }
            assert (this._bias_ada_dx_g != null);
            assert (this._wm == null);
            assert (this._bm == null);
        }
        // Keep this list in sync with the dropout-capable hidden layers that init() equips with a Dropout.
        final boolean dropoutLayer = this instanceof MaxoutDropout || this instanceof TanhDropout
                || this instanceof RectifierDropout || this instanceof ExpRectifierDropout;
        if (dropoutLayer) {
            assert (!training || this._dropout != null);
        }
    }

    /**
     * Wires this layer into the network and allocates its per-mini-batch buffers.
     *
     * <p>The parameters are cloned per layer. The learning rate is multiplied by
     * {@code _rate_decay^(index-1)}. Hidden dropout ratios and the distribution are taken from the
     * model info. Activation buffers are always allocated. Error buffers are allocated only for
     * non-input layers. The input layer of an autoencoder with input dropout also gets a buffer for
     * the original (un-dropped) inputs. When training, dropout-capable layers (and the input layer,
     * if {@code _input_dropout_ratio > 0}) get a {@link Dropout} generator.
     *
     * @param neurons  all layers of the network; {@code neurons[index - 1]} becomes {@code _previous}
     * @param index    position of this layer in {@code neurons}
     * @param p2       model parameters (cloned, not shared)
     * @param minfo    shared model state supplying weights, biases and optimizer buffers
     * @param training whether this network instance is used for training
     */
    public final void init(Neurons[] neurons, int index, DeepLearningModel.DeepLearningParameters p2, DeepLearningModelInfo minfo, boolean training) {
        this._index = index - 1;
        this.params = (DeepLearningModel.DeepLearningParameters) p2.clone();
        this.params._hidden_dropout_ratios = minfo.get_params()._hidden_dropout_ratios;
        this.params._rate *= Math.pow(this.params._rate_decay, index - 1);
        this.params._distribution = minfo.get_params()._distribution;
        this._dist = DistributionFactory.getDistribution(this.params);
        this._a = new Storage.DenseVector[this.params._mini_batch_size];
        for (int mb = 0; mb < this._a.length; ++mb) {
            this._a[mb] = new Storage.DenseVector(this.units);
        }
        if (!(this instanceof Input)) {
            this._e = new Storage.DenseVector[this.params._mini_batch_size];
            for (int mb = 0; mb < this._e.length; ++mb) {
                this._e[mb] = new Storage.DenseVector(this.units);
            }
        } else if (this.params._autoencoder && this.params._input_dropout_ratio > 0.0) {
            this._origa = new Storage.DenseVector[this.params._mini_batch_size];
            for (int mb = 0; mb < this._origa.length; ++mb) {
                this._origa[mb] = new Storage.DenseVector(this.units);
            }
        }
        if (training && (this instanceof MaxoutDropout || this instanceof TanhDropout || this instanceof RectifierDropout || this instanceof ExpRectifierDropout || this instanceof Input)) {
            // Both branches must assign _dropout; previously the Input branch's result was discarded,
            // which silently disabled input dropout.
            if (this instanceof Input) {
                this._dropout = this.params._input_dropout_ratio == 0.0 ? null : new Dropout(this.units, this.params._input_dropout_ratio);
            } else {
                this._dropout = new Dropout(this.units, this.params._hidden_dropout_ratios[this._index]);
            }
        }
        if (!(this instanceof Input)) {
            this._previous = neurons[this._index];
            this._minfo = minfo;
            this._w = minfo.get_weights(this._index);
            this._b = minfo.get_biases(this._index);
            if (this.params._autoencoder && this.params._sparsity_beta > 0.0 && this._index < this.params._hidden.length) {
                this._avg_a = minfo.get_avg_activations(this._index);
            }
            if (minfo.has_momenta()) {
                this._wm = minfo.get_weights_momenta(this._index);
                this._bm = minfo.get_biases_momenta(this._index);
            }
            if (minfo.adaDelta()) {
                this._ada_dx_g = minfo.get_ada_dx_g(this._index);
                this._bias_ada_dx_g = minfo.get_biases_ada_dx_g(this._index);
            }
            // A zero error signal may skip the update in fast mode, or when nothing else
            // (adaptive rate, momenta, L1/L2) could move the weights.
            this._shortcut = this.params._fast_mode
                    || (!this.params._adaptive_rate && !this._minfo.has_momenta() && this.params._l1 == 0.0 && this.params._l2 == 0.0);
        }
        this.sanityCheck(training);
    }

    /**
     * Forward-propagates the first {@code n} mini-batch rows through this layer, writing {@code _a}.
     *
     * @param seed     base random seed (dropout subclasses derive their mask seed from it)
     * @param training whether this is a training pass
     * @param n        number of mini-batch rows to process
     */
    protected abstract void fprop(long seed, boolean training, int n);

    /**
     * Back-propagates the error of the first {@code n} mini-batch rows through this layer and
     * updates its parameters.
     *
     * @param n number of mini-batch rows to process
     */
    protected abstract void bprop(int n);

    /**
     * Back-propagation entry point for the output layer. The error signal already stored in
     * {@code _e} (set via {@code setOutputLayerGradient}) is fed unit by unit into the shared
     * per-row update. With AdaDelta, the learning rate and momentum passed down are 0.
     *
     * @param n2 number of mini-batch rows to process
     */
    protected final void bpropOutputLayer(int n2) {
        assert (this._index == this.params._hidden.length);
        assert (this._a.length == this.params._mini_batch_size);
        final int outputs = this._a[0].size();
        final boolean ada = this._minfo.adaDelta();
        final float mom = ada ? 0.0f : this.momentum();
        final float lr = ada ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - mom);
        for (int unit = 0; unit < outputs; ++unit) {
            final double[] grad = new double[n2];
            for (int i = 0; i < n2; ++i) {
                grad[i] = this._e[i].raw()[unit];
            }
            this.bprop(unit, grad, lr, mom, n2);
        }
    }

    /**
     * Default output-layer gradient for autoencoders: each unit's reconstruction gradient divided
     * by the mini-batch size {@code n2}. The target argument is unused because the target is the
     * network input itself.
     *
     * @param ignored unused target value
     * @param mb      mini-batch row
     * @param n2      mini-batch size used for averaging
     */
    protected void setOutputLayerGradient(double ignored, int mb, int n2) {
        assert (this._minfo.get_params()._autoencoder && this._index == this._minfo.get_params()._hidden.length);
        final Storage.DenseVector err = this._e[mb];
        final int outputs = this._a[mb].size();
        int unit = 0;
        while (unit < outputs) {
            err.set(unit, this.autoEncoderGradient(unit, mb) / (double) n2);
            ++unit;
        }
    }

    /**
     * Updates the incoming weights and the bias of one unit ({@code row}) of this layer, and
     * back-propagates its error signal into {@code _previous._e}.
     *
     * @param row          unit of this layer whose incoming weights are updated
     * @param partial_grad per-mini-batch-row error signal for this unit (length {@code n2}); it is
     *                     also modified in place by {@code update_bias}
     * @param rate         learning rate (the weight step uses the AdaDelta rate instead when enabled)
     * @param momentum     momentum coefficient used when momenta are enabled
     * @param n2           number of mini-batch rows to process
     */
    final void bprop(int row, double[] partial_grad, float rate, float momentum, int n2) {
        int mb;
        float rho = (float)this.params._rho;
        float eps = (float)this.params._epsilon;
        float l1 = (float)this.params._l1;
        float l2 = (float)this.params._l2;
        float max_w2 = this.params._max_w2;
        boolean have_momenta = this._minfo.has_momenta();
        boolean have_ada = this._minfo.adaDelta();
        boolean nesterov = this.params._nesterov_accelerated_gradient;
        boolean fast_mode = this.params._fast_mode;
        int cols = this._previous._a[0].size();
        assert (partial_grad.length == n2);
        // Accumulates squared weight gradients; its mean drives the AdaDelta step for the bias.
        double avg_grad2 = 0.0;
        int idx = row * cols;
        for (mb = 0; mb < n2; ++mb) {
            // Shortcut path (see _shortcut in init()): a zero error signal skips the update.
            // NOTE(review): this returns from the whole method, so the remaining mini-batch rows and
            // the bias update are skipped as well; confirm this is intended for mini_batch_size > 1.
            if (this._shortcut && partial_grad[mb] == 0.0) {
                return;
            }
            boolean update_prev = this._previous._e != null && this._previous._e[mb] != null;
            for (int col = 0; col < cols; ++col) {
                int w2 = idx + col;
                // Maxout: redirect to the weight of the channel that won in fprop.
                if (this._k != 0) {
                    w2 = this._k * w2 + this._maxIncoming[mb][row];
                }
                double weight = this._w.raw()[w2];
                // Propagate the error to the previous layer through the (pre-update) connecting weight.
                if (update_prev) {
                    this._previous._e[mb].add(col, partial_grad[mb] * weight);
                }
                double previous_a = this._previous._a[mb].get(col);
                // In fast mode, inputs with zero activation leave their weight untouched.
                if (fast_mode && previous_a == 0.0) continue;
                // Weight gradient: data term plus L1 (sign) and L2 penalties.
                double grad = partial_grad[mb] * previous_a + Math.signum(weight) * (double)l1 + weight * (double)l2;
                // Elastic averaging pulls the weight toward the reference copy _wEA.
                if (this._wEA != null) {
                    grad += this.params._elastic_averaging_regularization * (double)(this._w.raw()[w2] - this._wEA.raw()[w2]);
                }
                if (DeepLearningModelInfo.gradientCheck != null) {
                    DeepLearningModelInfo.gradientCheck.apply(this._index, row, col, grad);
                }
                // AdaDelta: per-weight adaptive step; rate and momentum are not used.
                if (have_ada) {
                    double grad2 = grad * grad;
                    avg_grad2 += grad2;
                    float brate = Neurons.computeAdaDeltaRateForWeight(grad, w2, this._ada_dx_g, rho, eps);
                    float[] fArray = this._w.raw();
                    int n3 = w2;
                    fArray[n3] = (float)((double)fArray[n3] - (double)brate * grad);
                    continue;
                }
                // Plain SGD step, optionally adding the previous step scaled by momentum.
                if (!nesterov) {
                    double delta = (double)(-rate) * grad;
                    float[] fArray = this._w.raw();
                    int n4 = w2;
                    fArray[n4] = (float)((double)fArray[n4] + delta);
                    if (!have_momenta) continue;
                    float[] fArray2 = this._w.raw();
                    int n5 = w2;
                    fArray2[n5] = fArray2[n5] + momentum * this._wm.raw()[w2];
                    this._wm.raw()[w2] = (float)delta;
                    continue;
                }
                // Nesterov-style step: the momentum buffer accumulates -grad and is applied with rate.
                double tmp = -grad;
                if (have_momenta) {
                    float[] fArray = this._wm.raw();
                    int n6 = w2;
                    fArray[n6] = fArray[n6] * momentum;
                    float[] fArray3 = this._wm.raw();
                    int n7 = w2;
                    fArray3[n7] = (float)((double)fArray3[n7] + tmp);
                    tmp = this._wm.raw()[w2];
                }
                float[] fArray = this._w.raw();
                int n8 = w2;
                fArray[n8] = (float)((double)fArray[n8] + (double)rate * tmp);
            }
        }
        // Max-norm constraint on this unit's incoming weights.
        if (max_w2 != Float.POSITIVE_INFINITY) {
            for (mb = 0; mb < n2; ++mb) {
                this.rescale_weights(this._w, row, max_w2, mb);
            }
        }
        // Mean squared weight gradient, consumed by the AdaDelta bias update.
        if (have_ada) {
            avg_grad2 /= (double)(cols * n2);
        }
        for (mb = 0; mb < n2; ++mb) {
            this.update_bias(this._b, this._bEA, this._bm, row, partial_grad, avg_grad2, rate, momentum, mb);
        }
    }

    /**
     * Enforces the max-norm constraint on the incoming weights of unit {@code row}. If their squared
     * L2 norm exceeds {@code max_w2}, they are scaled so the squared norm becomes (approximately)
     * {@code max_w2}.
     *
     * <p>Dense layers ({@code _k == 0}) share one weight row across the mini-batch, so only
     * {@code mb == 0} does the work. For Maxout, the weight of channel {@code k} for
     * {@code (row, col)} lives at {@code _k * (row * cols + col) + k}. The winning channel of
     * mini-batch row {@code mb} is therefore a strided slice of the matrix, and only that slice is
     * measured and rescaled. The previous contiguous range mixed in the other channel's weights and
     * dropped the winner's last weight.
     *
     * @param w2     weight matrix to constrain in place
     * @param row    unit whose incoming weights are constrained
     * @param max_w2 maximum allowed squared L2 norm
     * @param mb     mini-batch row (selects the Maxout winning channel)
     */
    private void rescale_weights(Storage.DenseRowMatrix w2, int row, float max_w2, int mb) {
        final int cols = this._previous._a[0].size();
        final float[] raw = w2.raw();
        if (this._k == 0) {
            if (mb > 0) {
                return; // already rescaled for mb == 0
            }
            final int start = row * cols;
            final int end = start + cols;
            final float r2 = MathUtils.sumSquares(raw, start, end);
            if (r2 > max_w2) {
                final float scale = MathUtils.approxSqrt(max_w2 / r2);
                for (int c = start; c < end; ++c) {
                    raw[c] *= scale;
                }
            }
            return;
        }
        final int stride = this._k;
        final int first = stride * (row * cols) + this._maxIncoming[mb][row];
        float r2 = 0.0f;
        for (int c = 0, i = first; c < cols; ++c, i += stride) {
            r2 += raw[i] * raw[i];
        }
        if (r2 > max_w2) {
            final float scale = MathUtils.approxSqrt(max_w2 / r2);
            for (int c = 0, i = first; c < cols; ++c, i += stride) {
                raw[i] *= scale;
            }
        }
    }

    /**
     * Reconstruction gradient of output unit {@code row} for mini-batch row {@code mb}.
     *
     * <p>The target is the network input: the saved un-dropped copy ({@code _origa}) when it
     * exists, otherwise the input activations. The value is {@code -2 * negHalfGradient(target, output)}
     * of the configured distribution.
     */
    protected double autoEncoderGradient(int row, int mb) {
        assert (this._minfo.get_params()._autoencoder && this._index == this._minfo.get_params()._hidden.length);
        final Neurons in = this._input;
        final double target;
        if (in._origa != null) {
            target = in._origa[mb].get(row);
        } else {
            target = in._a[mb].get(row);
        }
        final double output = this._a[mb].get(row);
        return -2.0 * this._dist.negHalfGradient(target, output);
    }

    /**
     * Computes the AdaDelta step size for weight {@code w2} and updates its accumulators in place.
     *
     * <p>Accumulator layout: slot {@code 2*w2} holds the running mean of squared updates, and slot
     * {@code 2*w2+1} the running mean of squared gradients. The float/double promotions below are
     * deliberate and mirror the original arithmetic exactly.
     *
     * @return the per-weight rate; the caller multiplies it by {@code grad}
     */
    private static float computeAdaDeltaRateForWeight(double grad, int w2, Storage.DenseRowMatrix ada_dx_g, float rho, float eps) {
        final float[] acc = ada_dx_g.raw();
        final int updIdx = 2 * w2;
        final int gradIdx = updIdx + 1;
        final double gradSq = grad * grad;
        final float decay = 1.0f - rho;
        acc[gradIdx] = (float) ((double) (rho * acc[gradIdx]) + (double) decay * gradSq);
        final float step = MathUtils.approxSqrt((acc[updIdx] + eps) / (acc[gradIdx] + eps));
        acc[updIdx] = (float) ((double) (rho * acc[updIdx]) + (double) (decay * step * step) * gradSq);
        return step;
    }

    /**
     * Computes the AdaDelta step size for bias {@code row} from the squared gradient {@code grad2}
     * and updates its accumulators in place. Slot {@code 2*row} holds the running mean of squared
     * updates, and slot {@code 2*row+1} the running mean of squared gradients.
     *
     * @return the per-bias rate
     */
    private static double computeAdaDeltaRateForBias(double grad2, int row, Storage.DenseVector bias_ada_dx_g, float rho, float eps) {
        final double[] acc = bias_ada_dx_g.raw();
        final int updIdx = 2 * row;
        final int gradIdx = updIdx + 1;
        final double keep = rho;
        final double decay = 1.0f - rho; // float subtraction, then widened (as in the original)
        acc[gradIdx] = keep * acc[gradIdx] + decay * grad2;
        final double step = MathUtils.approxSqrt((acc[updIdx] + eps) / (acc[gradIdx] + eps));
        acc[updIdx] = keep * acc[updIdx] + decay * step * step * grad2;
        return step;
    }

    /**
     * Updates the running average activation used by the autoencoder sparsity penalty, as an
     * exponential moving average with weight 0.001 on the current activations. Does nothing when
     * {@code _avg_a} is null.
     *
     * @throws RuntimeException (from {@code H2O.unimpl}) when the mini-batch size exceeds 1
     */
    void compute_sparsity() {
        if (this._avg_a == null) {
            return;
        }
        if (this.params._mini_batch_size > 1) {
            throw H2O.unimpl("Sparsity constraint is not yet implemented for mini-batch size > 1.");
        }
        final int batches = this._minfo.get_params()._mini_batch_size;
        for (int b = 0; b < batches; ++b) {
            final Storage.DenseVector act = this._a[b];
            for (int unit = 0; unit < this._avg_a.size(); ++unit) {
                this._avg_a.set(unit, 0.999 * this._avg_a.get(unit) + 0.001 * act.get(unit));
            }
        }
    }

    /**
     * Applies the bias update for unit {@code row} and mini-batch row {@code mb}.
     *
     * <p>L1/L2 (and elastic-averaging) terms are added to {@code partial_grad[mb]} in place. Then
     * the AdaDelta, plain-momentum or Nesterov-style step is taken, followed by the autoencoder
     * sparsity correction for hidden layers. The model is flagged unstable if the bias becomes
     * infinite.
     *
     * @param avg_grad2 mean squared weight gradient for this unit (used only with AdaDelta)
     * @param rate      learning rate; replaced by the AdaDelta bias rate when AdaDelta is enabled
     */
    private void update_bias(Storage.DenseVector _b, Storage.DenseVector _bEA, Storage.DenseVector _bm, int row, double[] partial_grad, double avg_grad2, double rate, double momentum, int mb) {
        boolean have_momenta = this._minfo.has_momenta();
        boolean have_ada = this._minfo.adaDelta();
        float l1 = (float)this.params._l1;
        float l2 = (float)this.params._l2;
        // Maxout: bias of the winning channel; otherwise the unit's own bias.
        int b2 = this._k != 0 ? this._k * row + this._maxIncoming[mb][row] : row;
        double bias = _b.get(b2);
        int n2 = mb;
        partial_grad[n2] = partial_grad[n2] + (Math.signum(bias) * (double)l1 + bias * (double)l2);
        if (_bEA != null) {
            int n3 = mb;
            partial_grad[n3] = partial_grad[n3] + (bias - _bEA.get(b2)) * this.params._elastic_averaging_regularization;
        }
        if (DeepLearningModelInfo.gradientCheck != null) {
            DeepLearningModelInfo.gradientCheck.apply(this._index, row, -1, partial_grad[mb]);
        }
        if (have_ada) {
            float rho = (float)this.params._rho;
            float eps = (float)this.params._epsilon;
            rate = Neurons.computeAdaDeltaRateForBias(avg_grad2, b2, this._bias_ada_dx_g, rho, eps);
        }
        if (!this.params._nesterov_accelerated_gradient) {
            double delta = -rate * partial_grad[mb];
            _b.add(b2, delta);
            if (have_momenta) {
                _b.add(b2, momentum * _bm.get(b2));
                _bm.set(b2, delta);
            }
        } else {
            double d2 = -partial_grad[mb];
            if (have_momenta) {
                _bm.set(b2, _bm.get(b2) * momentum);
                _bm.add(b2, d2);
                d2 = _bm.get(b2);
            }
            _b.add(b2, rate * d2);
        }
        // Sparsity penalty for autoencoder hidden layers.
        // NOTE(review): _avg_a is indexed by b2, which for Maxout is _k*row + channel; confirm
        // _avg_a is sized for that case.
        if (this.params._autoencoder && this.params._sparsity_beta > 0.0 && !(this instanceof Output) && !(this instanceof Input) && this._index != this.params._hidden.length) {
            _b.add(b2, -(rate * this.params._sparsity_beta * (this._avg_a.raw()[b2] - this.params._average_activation)));
        }
        if (Double.isInfinite(_b.get(b2))) {
            this._minfo.setUnstable();
        }
    }

    /**
     * Annealed learning rate: {@code _rate / (1 + _rate_annealing * n2)}.
     *
     * @param n2 number of training samples processed so far
     */
    public float rate(double n2) {
        final double annealing = 1.0 + this.params._rate_annealing * n2;
        final double annealed = this.params._rate / annealing;
        return (float) annealed;
    }

    /**
     * Current momentum. Delegates to {@link #momentum(double)} with the sentinel {@code -1.0} in
     * place of an explicit sample count.
     */
    protected float momentum() {
        final double noExplicitCount = -1.0;
        return this.momentum(noExplicitCount);
    }

    /**
     * Momentum after a linear ramp from {@code _momentum_start} to {@code _momentum_stable} over
     * {@code _momentum_ramp} samples. Without a ramp, {@code _momentum_start} is returned.
     *
     * @param n2 number of training samples processed so far, or {@code -1} to use the model's
     *           live processed-sample counter. The original test was inverted: an explicit count
     *           was ignored, and the sentinel {@code -1} itself was used as the sample count.
     * @return the momentum for that point of training
     */
    public final float momentum(double n2) {
        double m4 = this.params._momentum_start;
        if (this.params._momentum_ramp > 0.0) {
            final double num = n2 != -1.0 ? n2 : (double) this._minfo.get_processed_total();
            if (num >= this.params._momentum_ramp) {
                m4 = this.params._momentum_stable;
            } else {
                m4 += (this.params._momentum_stable - this.params._momentum_start) * num / this.params._momentum_ramp;
            }
        }
        return (float) m4;
    }

    /**
     * Reference dense matrix-vector product {@code res = A * x + y}, with {@code A} stored
     * row-major as {@code rows x cols} floats. Rows whose bit in {@code row_bits} is clear
     * (bit {@code r % 8} of byte {@code r / 8}) are set to 0 and skipped. A null mask keeps every row.
     */
    static void gemv_naive(double[] res, float[] a2, double[] x2, double[] y2, byte[] row_bits) {
        final int cols = x2.length;
        final int rows = y2.length;
        assert (res.length == rows);
        for (int r = 0; r < rows; ++r) {
            res[r] = 0.0;
            final boolean active = row_bits == null || (row_bits[r / 8] & 1 << r % 8) != 0;
            if (!active) {
                continue;
            }
            final int base = r * cols;
            double acc = 0.0;
            for (int c = 0; c < cols; ++c) {
                acc += (double) a2[base + c] * x2[c];
            }
            res[r] = acc + y2[r];
        }
    }

    /**
     * Same contract as {@code gemv_naive} ({@code res = A * x + y}, row-major {@code A}, rows with a
     * clear bit in {@code row_bits} set to 0). The dot product is unrolled 8-wide into independent
     * partial sums, followed by a scalar tail for the remaining {@code cols % 8} columns.
     */
    static void gemv_row_optimized(double[] res, float[] a2, double[] x2, double[] y2, byte[] row_bits) {
        final int cols = x2.length;
        final int rows = y2.length;
        assert (res.length == rows);
        final int unrolledEnd = cols - cols % 8;
        for (int r = 0; r < rows; ++r) {
            res[r] = 0.0;
            if (row_bits != null && (row_bits[r / 8] & 1 << r % 8) == 0) {
                continue;
            }
            final int base = r * cols;
            double s0 = 0.0;
            double s1 = 0.0;
            double s2 = 0.0;
            double s3 = 0.0;
            double s4 = 0.0;
            double s5 = 0.0;
            double s6 = 0.0;
            double s7 = 0.0;
            int c = 0;
            for (; c < unrolledEnd; c += 8) {
                final int off = base + c;
                s0 += (double) a2[off] * x2[c];
                s1 += (double) a2[off + 1] * x2[c + 1];
                s2 += (double) a2[off + 2] * x2[c + 2];
                s3 += (double) a2[off + 3] * x2[c + 3];
                s4 += (double) a2[off + 4] * x2[c + 4];
                s5 += (double) a2[off + 5] * x2[c + 5];
                s6 += (double) a2[off + 6] * x2[c + 6];
                s7 += (double) a2[off + 7] * x2[c + 7];
            }
            double acc = 0.0;
            acc += s0 + s1 + s2 + s3;
            acc += s4 + s5 + s6 + s7;
            for (c = unrolledEnd; c < cols; ++c) {
                acc += (double) a2[base + c] * x2[c];
            }
            acc += y2[r];
            res[r] = acc;
        }
    }

    /**
     * Storage-typed front end for {@code gemv_row_optimized}: computes {@code res = A * x + y} on
     * the backing arrays, zeroing rows whose bit in {@code row_bits} is clear.
     */
    static void gemv(Storage.DenseVector res, Storage.DenseRowMatrix a2, Storage.DenseVector x2, Storage.DenseVector y2, byte[] row_bits) {
        final double[] out = res.raw();
        final float[] matrix = a2.raw();
        final double[] vec = x2.raw();
        final double[] bias = y2.raw();
        gemv_row_optimized(out, matrix, vec, bias, row_bits);
    }

    /**
     * Storage-typed front end for the reference {@code gemv_naive} on primitive arrays.
     */
    static void gemv_naive(Storage.DenseVector res, Storage.DenseRowMatrix a2, Storage.DenseVector x2, Storage.DenseVector y2, byte[] row_bits) {
        final double[] out = res.raw();
        final float[] matrix = a2.raw();
        final double[] vec = x2.raw();
        final double[] bias = y2.raw();
        gemv_naive(out, matrix, vec, bias, row_bits);
    }

    public static class Linear
    extends Output {
        /** Single-unit linear (identity-activation) output layer. */
        public Linear() {
            super(1);
        }

        /** Output = W * previous activations + b for each mini-batch row (no nonlinearity). */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            for (int mb = 0; mb < n2; ++mb) {
                Linear.gemv(this._a[mb], this._w, this._previous._a[mb], this._b, this._dropout != null ? this._dropout.bits() : null);
            }
        }

        /**
         * Stores the loss gradient {@code -2 * negHalfGradient(target, output)} of the single output
         * unit, averaged over the mini-batch size {@code n2}.
         */
        @Override
        protected void setOutputLayerGradient(double target, int mb, int n2) {
            final double y2 = this._a[mb].get(0);
            final double g2 = -2.0 * this._dist.negHalfGradient(target, y2);
            this._e[mb].set(0, g2 / (double) n2);
        }
    }

    public static class Softmax
    extends Output {
        public Softmax(int units) {
            super(units);
        }

        @Override
        protected void fprop(long seed, boolean training, int n2) {
            int mb;
            for (mb = 0; mb < n2; ++mb) {
                Softmax.gemv(this._a[mb], this._w, this._previous._a[mb], this._b, null);
            }
            for (mb = 0; mb < n2; ++mb) {
                int row;
                double max = ArrayUtils.maxValue(this._a[mb].raw());
                double scaling = 0.0;
                int rows = this._a[mb].size();
                for (row = 0; row < rows; ++row) {
                    this._a[mb].set(row, Math.exp(this._a[mb].get(row) - max));
                    scaling += this._a[mb].get(row);
                }
                row = 0;
                while (row < rows) {
                    double[] dArray = this._a[mb].raw();
                    int n3 = row++;
                    dArray[n3] = dArray[n3] / scaling;
                }
            }
        }

        @Override
        protected void setOutputLayerGradient(double target, int mb, int n2) {
            assert (target == (double)((int)target));
            int rows = this._a[mb].size();
            for (int row = 0; row < rows; ++row) {
                double g2;
                double t2 = row == (int)target ? 1 : 0;
                double y2 = this._a[mb].get(row);
                switch (this.params._loss) {
                    case CrossEntropy: {
                        g2 = y2 - t2;
                        break;
                    }
                    case ModifiedHuber: {
                        g2 = -2.0 * this._dist.negHalfGradient(t2, y2) * (1.0 - y2) * y2;
                        break;
                    }
                    case Quadratic: {
                        g2 = (y2 - t2) * (1.0 - y2) * y2;
                        break;
                    }
                    default: {
                        throw H2O.unimpl();
                    }
                }
                this._e[mb].set(row, g2 / (double)n2);
            }
        }
    }

    public abstract static class Output
    extends Neurons {
        /** Base for output layers; {@code units} is the number of outputs. */
        Output(int units) {
            super(units);
        }

        /**
         * Not supported for output layers. Their back-propagation goes through the output-specific
         * path ({@code bpropOutputLayer} asserts it runs on the last layer).
         */
        @Override
        protected void bprop(int n) {
            throw new UnsupportedOperationException();
        }
    }

    public static class ExpRectifierDropout
    extends ExpRectifier {
        /** ELU layer with hidden dropout. */
        public ExpRectifierDropout(int units) {
            super(units);
        }

        /**
         * Training: refresh the dropout mask from a layer-specific seed, then run the ELU forward
         * pass. Inference: run the forward pass and scale the activations by the keep probability
         * {@code 1 - hidden_dropout_ratio}.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            if (!training) {
                super.fprop(seed, false, n2);
                for (int mb = 0; mb < n2; ++mb) {
                    final double keep = 1.0 - this.params._hidden_dropout_ratios[this._index];
                    ArrayUtils.mult(this._a[mb].raw(), keep);
                }
                return;
            }
            final long layerSeed = seed + (this.params._seed + -629514240L);
            this._dropout.fillBytes(layerSeed);
            super.fprop(layerSeed, true, n2);
        }
    }

    public static class ExpRectifier
    extends Neurons {
        /** Exponential linear unit (ELU) hidden layer. */
        public ExpRectifier(int units) {
            super(units);
        }

        /**
         * Affine transform followed by the ELU {@code f(z) = z} for {@code z >= 0} and
         * {@code exp(z) - 1} otherwise. {@code _a} is overwritten in place with {@code f(z)}.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            for (int mb = 0; mb < n2; ++mb) {
                ExpRectifier.gemv(this._a[mb], this._w, this._previous._a[mb], this._b, this._dropout != null ? this._dropout.bits() : null);
            }
            int rows = this._a[0].size();
            for (int row = 0; row < rows; ++row) {
                for (int mb = 0; mb < n2; ++mb) {
                    double x2 = this._a[mb].get(row);
                    double val = x2 >= 0.0 ? x2 : Math.exp(x2) - 1.0;
                    this._a[mb].set(row, val);
                }
            }
            this.compute_sparsity();
        }

        /**
         * Back-propagates through the ELU. {@code _a} holds the post-activation value
         * {@code a = f(z)}, not {@code z}, so the derivative must be written in terms of {@code a}:
         * {@code f'(z) = 1} for {@code a >= 0} and {@code exp(z) = a + 1} for {@code a < 0}. The
         * previous code evaluated {@code exp(a)}, which is not the ELU derivative.
         */
        @Override
        protected void bprop(int n2) {
            assert (this._index < this._minfo.get_params()._hidden.length);
            float m4 = this._minfo.adaDelta() ? 0.0f : this.momentum();
            float r2 = this._minfo.adaDelta() ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - m4);
            int rows = this._a[0].size();
            double[] g2 = new double[n2];
            for (int row = 0; row < rows; ++row) {
                for (int mb = 0; mb < n2; ++mb) {
                    double a = this._a[mb].get(row);
                    double slope = a >= 0.0 ? 1.0 : a + 1.0;
                    g2[mb] = this._e[mb].get(row) * slope;
                }
                this.bprop(row, g2, r2, m4, n2);
            }
        }
    }

    public static class RectifierDropout
    extends Rectifier {
        /** Rectifier (ReLU) layer with hidden dropout. */
        public RectifierDropout(int units) {
            super(units);
        }

        /**
         * Training: refresh the dropout mask from a layer-specific seed, then run the ReLU forward
         * pass. Inference: run the forward pass and scale the activations by the keep probability
         * {@code 1 - hidden_dropout_ratio}.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            if (!training) {
                super.fprop(seed, false, n2);
                for (int mb = 0; mb < n2; ++mb) {
                    final double keep = 1.0 - this.params._hidden_dropout_ratios[this._index];
                    ArrayUtils.mult(this._a[mb].raw(), keep);
                }
                return;
            }
            final long layerSeed = seed + (this.params._seed + 1014100461L);
            this._dropout.fillBytes(layerSeed);
            super.fprop(layerSeed, true, n2);
        }
    }

    public static class Rectifier
    extends Neurons {
        public Rectifier(int units) {
            super(units);
        }

        @Override
        protected void fprop(long seed, boolean training, int n2) {
            for (int mb = 0; mb < n2; ++mb) {
                Rectifier.gemv(this._a[mb], this._w, this._previous._a[mb], this._b, this._dropout != null ? this._dropout.bits() : null);
            }
            int rows = this._a[0].size();
            for (int mb = 0; mb < n2; ++mb) {
                for (int row = 0; row < rows; ++row) {
                    this._a[mb].set(row, 0.5 * (this._a[mb].get(row) + Math.abs(this._a[mb].get(row))));
                }
            }
            this.compute_sparsity();
        }

        @Override
        protected void bprop(int n2) {
            assert (this._index < this._minfo.get_params()._hidden.length);
            float m4 = this._minfo.adaDelta() ? 0.0f : this.momentum();
            float r2 = this._minfo.adaDelta() ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - m4);
            int rows = this._a[0].size();
            double[] g2 = new double[n2];
            for (int row = 0; row < rows; ++row) {
                for (int mb = 0; mb < n2; ++mb) {
                    g2[mb] = this._a[mb].get(row) > 0.0 ? this._e[mb].get(row) : 0.0;
                }
                this.bprop(row, g2, r2, m4, n2);
            }
        }
    }

    public static class MaxoutDropout
    extends Maxout {
        /** Maxout layer with hidden dropout; arguments are forwarded unchanged to {@link Maxout}. */
        public MaxoutDropout(DeepLearningModel.DeepLearningParameters params, short k2, int units) {
            super(params, k2, units);
        }

        /**
         * Training: refresh the dropout mask from a layer-specific seed, then run the Maxout forward
         * pass. Inference: run the forward pass and scale the activations by the keep probability
         * {@code 1 - hidden_dropout_ratio}.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            if (!training) {
                super.fprop(seed, false, n2);
                for (int mb = 0; mb < n2; ++mb) {
                    final double keep = 1.0 - this.params._hidden_dropout_ratios[this._index];
                    ArrayUtils.mult(this._a[mb].raw(), keep);
                }
                return;
            }
            final long layerSeed = seed + (this.params._seed + 1372114957L);
            this._dropout.fillBytes(layerSeed);
            super.fprop(layerSeed, true, n2);
        }
    }

    public static class Maxout
    extends Neurons {
        /**
         * Maxout hidden layer with {@code k2} linear channels per unit (only 2 is supported).
         *
         * @throws RuntimeException (from {@code H2O.unimpl}) if {@code k2 != 2}
         */
        public Maxout(DeepLearningModel.DeepLearningParameters params, short k2, int units) {
            super(units);
            this._k = k2;
            this._maxIncoming = new int[params._mini_batch_size][];
            for (int i2 = 0; i2 < this._maxIncoming.length; ++i2) {
                this._maxIncoming[i2] = new int[units];
            }
            if (this._k != 2) {
                throw H2O.unimpl("Maxout is currently hardcoded for 2 channels. Trivial to enable k > 2 though.");
            }
        }

        /**
         * Each unit computes {@code _k} affine channels and outputs the largest one, recording the
         * winning channel in {@code _maxIncoming} for bprop. Channel {@code k}'s weight for
         * {@code (row, col)} is at {@code _k * (row * cols + col) + k}, and its bias at
         * {@code _k * row + k}. Units dropped by dropout during training output 0.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            assert (this._b.size() == this._a[0].size() * this._k);
            assert (this._w.size() == (long)(this._a[0].size() * this._previous._a[0].size() * this._k));
            int rows = this._a[0].size();
            double[] channel = new double[this._k];
            for (int row = 0; row < rows; ++row) {
                for (int mb = 0; mb < n2; ++mb) {
                    this._a[mb].set(row, 0.0);
                    if (training && this._dropout != null && !this._dropout.unit_active(row)) continue;
                    int cols = this._previous._a[mb].size();
                    int maxK = 0;
                    for (int k = 0; k < this._k; ++k) {
                        channel[k] = 0.0;
                        for (int col = 0; col < cols; ++col) {
                            channel[k] += (double)this._w.raw()[this._k * (row * cols + col) + k] * this._previous._a[mb].get(col);
                        }
                        channel[k] += this._b.raw()[this._k * row + k];
                        if (channel[k] > channel[maxK]) {
                            maxK = k;
                        }
                    }
                    this._maxIncoming[mb][row] = maxK;
                    this._a[mb].set(row, channel[maxK]);
                }
            }
            // Update the sparsity running averages once per forward pass, after all activations are
            // final, as the other activation layers do. It used to run once per row.
            this.compute_sparsity();
        }

        /** The max is piecewise linear: the error flows unchanged to the winning channel. */
        @Override
        protected void bprop(int n2) {
            assert (this._index != this.params._hidden.length);
            float m4 = this._minfo.adaDelta() ? 0.0f : this.momentum();
            float r2 = this._minfo.adaDelta() ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - m4);
            double[] g2 = new double[n2];
            int rows = this._a[0].size();
            for (int row = 0; row < rows; ++row) {
                for (int mb = 0; mb < n2; ++mb) {
                    g2[mb] = this._e[mb].get(row);
                }
                this.bprop(row, g2, r2, m4, n2);
            }
        }
    }

    public static class TanhDropout
    extends Tanh {
        /** Tanh layer with hidden dropout. */
        public TanhDropout(int units) {
            super(units);
        }

        /**
         * Training: refresh the dropout mask from a layer-specific seed, then run the tanh forward
         * pass. Inference: run the forward pass and scale the activations by the keep probability
         * {@code 1 - hidden_dropout_ratio}.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            if (!training) {
                super.fprop(seed, false, n2);
                for (int mb = 0; mb < n2; ++mb) {
                    final double keep = 1.0 - this.params._hidden_dropout_ratios[this._index];
                    ArrayUtils.mult(this._a[mb].raw(), keep);
                }
                return;
            }
            final long layerSeed = seed + (this.params._seed + -629514240L);
            this._dropout.fillBytes(layerSeed);
            super.fprop(layerSeed, true, n2);
        }
    }

    /**
     * Hidden layer with hyperbolic-tangent activation.
     */
    public static class Tanh
    extends Neurons {
        public Tanh(int units) {
            super(units);
        }

        /**
         * Forward pass. For each of the first {@code n2} mini-batch samples, computes the
         * affine map with {@code gemv}, passing the dropout bits when a dropout object is
         * attached. It then applies tanh in place using the identity
         * tanh(x) = 1 - 2 / (1 + e^(2x)) and finally updates the sparsity statistics.
         */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            // Phase 1: affine transform for every sample.
            for (int sample = 0; sample < n2; sample++) {
                Tanh.gemv(this._a[sample], this._w, this._previous._a[sample], this._b,
                        this._dropout == null ? null : this._dropout.bits());
            }
            // Phase 2: nonlinearity, overwriting the pre-activations.
            final int width = this._a[0].size();
            for (int sample = 0; sample < n2; sample++) {
                for (int unit = 0; unit < width; unit++) {
                    final double z = this._a[sample].get(unit);
                    this._a[sample].set(unit, 1.0 - 2.0 / (1.0 + Math.exp(2.0 * z)));
                }
            }
            this.compute_sparsity();
        }

        /**
         * Backward pass. The gradient for each unit is the error times the tanh derivative
         * {@code 1 - a^2}, where {@code a} is the stored activation. The momentum and rate
         * arguments are 0 when AdaDelta is active.
         *
         * @param n2 number of mini-batch samples to process
         */
        @Override
        protected void bprop(int n2) {
            assert (this._index < this._minfo.get_params()._hidden.length);
            float momentumFactor = 0.0f;
            float learningRate = 0.0f;
            if (!this._minfo.adaDelta()) {
                momentumFactor = this.momentum();
                learningRate = this.rate(this._minfo.get_processed_total()) * (1.0f - momentumFactor);
            }
            final int width = this._a[0].size();
            final double[] grad = new double[n2];
            for (int unit = 0; unit < width; unit++) {
                for (int sample = 0; sample < n2; sample++) {
                    final double act = this._a[sample].get(unit);
                    grad[sample] = this._e[sample].get(unit) * (1.0 - act * act);
                }
                this.bprop(unit, grad, learningRate, momentumFactor, n2);
            }
        }
    }

    /**
     * Input layer: one dense activation vector per mini-batch slot, filled from a row of
     * predictors by {@code setInput}. It has no forward or backward pass of its own.
     */
    public static class Input
    extends Neurons {
        /** Column layout (categorical offsets, numeric normalization) used to encode rows. */
        private DataInfo _dinfo;

        /**
         * Allocates {@code params._mini_batch_size} activation vectors of length {@code units}.
         *
         * @param params model parameters; only the mini-batch size is read here
         * @param units  length of each activation vector
         * @param d2     column layout / normalization info used to encode rows
         */
        Input(DeepLearningModel.DeepLearningParameters params, int units, DataInfo d2) {
            super(units);
            this._dinfo = d2;
            this._a = new Storage.DenseVector[params._mini_batch_size];
            for (int i2 = 0; i2 < this._a.length; ++i2) {
                this._a[i2] = new Storage.DenseVector(units);
            }
        }

        /** Not supported for the input layer. */
        @Override
        protected void bprop(int n2) {
            throw new UnsupportedOperationException();
        }

        /** Not supported for the input layer; activations are set via {@code setInput}. */
        @Override
        protected void fprop(long seed, boolean training, int n2) {
            throw new UnsupportedOperationException();
        }

        /**
         * Encodes one raw row into mini-batch slot {@code mb}.
         * <p>
         * The first {@code _dinfo._cats} entries of {@code data} are categorical level indices
         * (NaN = missing). The remaining entries are numeric values, normalized with
         * {@code _normSub}/{@code _normMul} when {@code _normMul} is non-null.
         * <p>
         * Each categorical is converted to a global one-hot index:
         * <ul>
         *   <li>Missing and out-of-range levels map to the last index of the column's range.</li>
         *   <li>With {@code _useAllFactorLevels == false}, level 0 maps to -1.</li>
         * </ul>
         *
         * @param seed seed forwarded to input dropout
         * @param data raw row: categoricals first, then numerics
         * @param mb   mini-batch slot to fill
         */
        public void setInput(long seed, double[] data, int mb) {
            assert (this._dinfo != null);
            final int catCols = this._dinfo._cats;
            double[] nums = MemoryManager.malloc8d(this._dinfo._nums);
            int[] cats = MemoryManager.malloc4(catCols);
            for (int col = 0; col < catCols; ++col) {
                assert (this._dinfo._catMissing[col]);
                // Last slot of this column's one-hot range; presumably the NA level given the
                // _catMissing assertion above — confirm against DataInfo.
                final int lastSlot = this._dinfo._catOffsets[col + 1] - 1;
                if (Double.isNaN(data[col])) {
                    cats[col] = lastSlot;
                    continue;
                }
                final int level = (int) data[col];
                int idx;
                if (this._dinfo._useAllFactorLevels) {
                    idx = level + this._dinfo._catOffsets[col];
                } else if (level != 0) {
                    idx = level + this._dinfo._catOffsets[col] - 1;
                } else {
                    idx = -1; // reference level; the dense encoding sets no bit for it
                }
                // Levels past the end of the column's range are treated like missing values.
                if (idx >= this._dinfo._catOffsets[col + 1]) {
                    idx = lastSlot;
                }
                cats[col] = idx;
            }
            for (int i = catCols; i < data.length; ++i) {
                double v = data[i];
                if (this._dinfo._normMul != null) {
                    v = (v - this._dinfo._normSub[i - catCols]) * this._dinfo._normMul[i - catCols];
                }
                nums[i - catCols] = v; // may be NaN; NaNs are written as 0.0 into the activations
            }
            this.setInput(seed, null, nums, catCols, cats, mb);
        }

        /**
         * Writes pre-encoded categorical and numeric values into mini-batch slot {@code mb},
         * then applies input dropout if a dropout object is attached.
         * <p>
         * The encoding depends on the number of categorical features:
         * <ul>
         *   <li>If {@code _max_categorical_features} is smaller than the number of categorical
         *       one-hot slots, categoricals are hashed (MurmurHash of the big-endian 4-byte
         *       index) into {@code _max_categorical_features} buckets. Numerics follow the
         *       buckets.</li>
         *   <li>Otherwise a dense one-hot encoding is used.</li>
         * </ul>
         * NaN numerics are written as 0.0.
         *
         * @param seed   seed forwarded to input dropout
         * @param numIds explicit activation indices for {@code nums}, or null to place them
         *               contiguously starting at {@code _dinfo.numStart()}
         * @param nums   numeric values
         * @param numcat number of valid entries in {@code cats}
         * @param cats   global one-hot indices of the categorical levels
         * @param mb     mini-batch slot to fill
         */
        public void setInput(long seed, int[] numIds, double[] nums, int numcat, int[] cats, int mb) {
            final Storage.DenseVector act = this._a[mb];
            Arrays.fill(act.raw(), 0.0);
            final int maxCatFeatures = this.params._max_categorical_features;
            if (maxCatFeatures < this._dinfo.fullN() - this._dinfo._nums) {
                // Hashing trick: layout is [maxCatFeatures hash buckets | numerics].
                assert (nums.length == this._dinfo._nums);
                assert (act.size() == nums.length + maxCatFeatures);
                final MurmurHash murmur = MurmurHash.getInstance();
                final int hashSeed = (int) this.params._seed;
                // One scratch buffer reused for every level. The absolute putInt overwrites all
                // 4 bytes each time, so the hashed bytes are the same as with a fresh buffer.
                final ByteBuffer buf = ByteBuffer.allocate(4);
                final byte[] bytes = buf.array();
                // NOTE(review): unlike the dense branch, negative (reference-level, -1) indices
                // are hashed here rather than skipped. Kept for compatibility with already
                // trained models — confirm whether this is intended.
                for (int i = 0; i < numcat; ++i) {
                    buf.putInt(0, cats[i]);
                    final int hashval = murmur.hash(bytes, 4, hashSeed);
                    act.add(Math.abs(hashval % maxCatFeatures), 1.0);
                }
                for (int i = 0; i < nums.length; ++i) {
                    act.set(maxCatFeatures + i, Double.isNaN(nums[i]) ? 0.0 : nums[i]);
                }
            } else {
                // Dense one-hot layout covering all fullN() slots.
                assert (act.size() == this._dinfo.fullN());
                for (int i = 0; i < numcat; ++i) {
                    if (cats[i] < 0) {
                        continue;
                    }
                    act.set(cats[i], 1.0);
                }
                if (numIds != null) {
                    for (int i = 0; i < numIds.length; ++i) {
                        act.set(numIds[i], Double.isNaN(nums[i]) ? 0.0 : nums[i]);
                    }
                } else {
                    final int numStart = this._dinfo.numStart();
                    for (int i = 0; i < nums.length; ++i) {
                        act.set(numStart + i, Double.isNaN(nums[i]) ? 0.0 : nums[i]);
                    }
                }
            }
            if (this._dropout == null) {
                return;
            }
            if (this.params._autoencoder && this.params._input_dropout_ratio > 0.0) {
                // Keep an un-dropped copy of the input in _origa before sparsifying
                // (presumably the autoencoder reconstruction target — confirm).
                System.arraycopy(act.raw(), 0, this._origa[mb].raw(), 0, act.raw().length);
            }
            this._dropout.randomlySparsifyActivation((Storage.Vector) act, seed + (this.params._seed + 322417854L));
        }
    }
}

