/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.DataInfo;
import hex.Distribution;
import hex.deeplearning.DeepLearningModelInfo;
import hex.deeplearning.DeepLearningParameters;
import hex.deeplearning.Dropout;
import hex.deeplearning.MurmurHash;
import hex.deeplearning.Storage;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.MathUtils;

public abstract class Neurons {
    /** Number of neurons in this layer; assigned by the constructor. */
    protected int units;
    /** Layer-local clone of the model parameters, created in {@code init}. */
    protected transient DeepLearningParameters params;
    /** Index used to fetch this layer's weights, biases and optimizer state from the model info; {@code init} stores {@code index - 1}. */
    protected transient int _index;
    /** Activations of this layer; {@code init} allocates a dense vector with {@code units} entries. */
    public transient Storage.Vector _a;
    /** Error vector; {@code init} allocates it only for layers that are neither Input nor Output. */
    public transient Storage.DenseVector _e;
    /** Preceding layer ({@code neurons[_index]}); set by {@code init} for every non-Input layer. */
    public Neurons _previous;
    /** Layer whose activations {@code autoEncoderGradient} reads as the target value; not assigned in this part of the file. */
    public Neurons _input;
    /** Model info that supplies weights, biases and optimizer state; set by {@code init} for non-Input layers. */
    DeepLearningModelInfo _minfo;
    /** Weights, obtained from {@code minfo.get_weights(_index)}. */
    public Storage.Matrix _w;
    /** Reference weights used by the elastic-averaging term during back-propagation; may be null. Not assigned in this part of the file. */
    public Storage.Matrix _wEA;
    /** Biases, obtained from {@code minfo.get_biases(_index)}. */
    public Storage.DenseVector _b;
    /** Reference biases used by the elastic-averaging term in {@code update_bias}; may be null. Not assigned in this part of the file. */
    public Storage.DenseVector _bEA;
    /** Weight momenta; only fetched when {@code minfo.has_momenta()} is true. */
    Storage.Matrix _wm;
    /** Bias momenta; only fetched when {@code minfo.has_momenta()} is true. */
    Storage.DenseVector _bm;
    /** AdaDelta state for the weights (two interleaved accumulators per weight); only fetched when {@code minfo.adaDelta()} is true. */
    Storage.Matrix _ada_dx_g;
    /** AdaDelta state for the biases (two interleaved accumulators per bias); only fetched when {@code minfo.adaDelta()} is true. */
    Storage.DenseVector _bias_ada_dx_g;
    /** Dropout helper; {@code init} creates it only while training, for Input and the *Dropout layer types. */
    protected Dropout _dropout;
    /** Computed in {@code init}; when true, {@code bprop} skips rows whose partial gradient is exactly zero. */
    private boolean _shortcut = false;
    /** Average activations, fetched in {@code init} only for autoencoders with {@code _sparsity_beta > 0}; updated by {@code compute_sparsity}. */
    public Storage.DenseVector _avg_a;
    // NOTE(review): presumably the marker for a missing integer (categorical) target - confirm against callers.
    public static final int missing_int_value = Integer.MAX_VALUE;
    // NaN marker; NOTE(review): presumably denotes a missing real-valued target. NaN never compares equal with ==, use Float.isNaN.
    public static final Float missing_real_value = Float.valueOf(Float.NaN);

    /**
     * Creates a layer of the given width. All other state is filled in later by {@code init}.
     *
     * @param units number of neurons in this layer
     */
    Neurons(final int units) {
        this.units = units;
    }

    /**
     * Describes this layer: simple class name, neuron count, parameters and,
     * when a dropout helper is present, the dropout state.
     *
     * @return multi-line human-readable description
     */
    @Override
    public String toString() {
        final StringBuilder text = new StringBuilder(this.getClass().getSimpleName());
        text.append("\nNumber of Neurons: ").append(this.units);
        text.append("\nParameters:\n").append(this.params.toString());
        if (this._dropout != null) {
            text.append("\nDropout:\n").append(this._dropout.toString());
        }
        return text.toString();
    }

    /**
     * Maps a loss function to the distribution used for it.
     * Automatic and MeanSquare map to gaussian, Huber to huber, Absolute to laplace.
     *
     * @param loss the configured loss
     * @return a new Distribution of the matching family
     * @throws RuntimeException (from {@code H2O.unimpl}) for any other loss
     */
    static Distribution getDistribution(DeepLearningParameters.Loss loss) {
        final Distribution.Family family;
        switch (loss) {
            case Automatic:
            case MeanSquare:
                family = Distribution.Family.gaussian;
                break;
            case Huber:
                family = Distribution.Family.huber;
                break;
            case Absolute:
                family = Distribution.Family.laplace;
                break;
            default:
                throw H2O.unimpl((String)loss.toString());
        }
        return new Distribution(family);
    }

    /**
     * Verifies that the state set up by {@code init} is consistent with the
     * optimizer configuration. Most checks are Java assertions and are only
     * active with {@code -ea}; the rho/epsilon checks always run.
     *
     * @param training whether the layer was initialised for training
     * @throws IllegalArgumentException if AdaDelta is active and rho or epsilon is zero
     */
    void sanityCheck(boolean training) {
        if (this instanceof Input) {
            // The input layer has no predecessor and needs a dropout helper while training.
            assert this._previous == null;
            assert !training || this._dropout != null;
            return;
        }
        assert this._previous != null;
        if (this._minfo.has_momenta()) {
            // Momentum training and AdaDelta are mutually exclusive.
            assert this._wm != null;
            assert this._bm != null;
            assert this._ada_dx_g == null;
        }
        if (this._minfo.adaDelta()) {
            if (this.params._rho == 0.0) {
                throw new IllegalArgumentException("rho must be > 0 if epsilon is >0.");
            }
            if (this.params._epsilon == 0.0) {
                throw new IllegalArgumentException("epsilon must be > 0 if rho is >0.");
            }
            assert this._minfo.adaDelta();
            assert this._bias_ada_dx_g != null;
            assert this._wm == null;
            assert this._bm == null;
        }
        final boolean dropoutLayer = this instanceof MaxoutDropout || this instanceof TanhDropout || this instanceof RectifierDropout;
        if (dropoutLayer) {
            assert !training || this._dropout != null;
        }
    }

    /**
     * Wires this layer into the network: clones the parameters, allocates the
     * activation (and, for hidden layers, error) vectors, creates the dropout
     * helper when training, and fetches weights, biases and optimizer state
     * from the model info. Ends with {@link #sanityCheck(boolean)}.
     *
     * @param neurons  all layers; {@code neurons[index - 1]} becomes the previous layer
     * @param index    position of this layer; {@code index - 1} is stored in {@code _index}
     * @param p        parameters to clone for this layer
     * @param minfo    model info supplying weights, biases, momenta and AdaDelta state
     * @param training whether dropout helpers should be created
     */
    public final void init(Neurons[] neurons, int index, DeepLearningParameters p, DeepLearningModelInfo minfo, boolean training) {
        this._index = index - 1;
        this.params = (DeepLearningParameters)p.clone();
        this.params._hidden_dropout_ratios = minfo.get_params()._hidden_dropout_ratios;
        // Learning rate decays geometrically with the layer position.
        this.params._rate *= Math.pow(this.params._rate_decay, index - 1);
        this._a = new Storage.DenseVector(this.units);

        final boolean isInput = this instanceof Input;
        if (!(this instanceof Output) && !isInput) {
            this._e = new Storage.DenseVector(this.units);
        }

        final boolean usesDropout = this instanceof MaxoutDropout || this instanceof TanhDropout || this instanceof RectifierDropout || isInput;
        if (training && usesDropout) {
            if (isInput) {
                this._dropout = new Dropout(this.units, this.params._input_dropout_ratio);
            } else {
                this._dropout = new Dropout(this.units, this.params._hidden_dropout_ratios[this._index]);
            }
        }

        if (isInput) {
            this.sanityCheck(training);
            return;
        }

        this._previous = neurons[this._index];
        this._minfo = minfo;
        this._w = minfo.get_weights(this._index);
        this._b = minfo.get_biases(this._index);
        if (this.params._autoencoder && this.params._sparsity_beta > 0.0 && this._index < this.params._hidden.length) {
            this._avg_a = minfo.get_avg_activations(this._index);
        }
        if (minfo.has_momenta()) {
            this._wm = minfo.get_weights_momenta(this._index);
            this._bm = minfo.get_biases_momenta(this._index);
        }
        if (minfo.adaDelta()) {
            this._ada_dx_g = minfo.get_ada_dx_g(this._index);
            this._bias_ada_dx_g = minfo.get_biases_ada_dx_g(this._index);
        }
        // Zero partial gradients may be skipped in fast mode, or when nothing but the
        // plain gradient (no adaptive rate, momenta, L1 or L2) contributes to the update.
        final boolean plainGradientOnly = !this.params._adaptive_rate && !this._minfo.has_momenta() && this.params._l1 == 0.0 && this.params._l2 == 0.0;
        this._shortcut = this.params._fast_mode || plainGradientOnly;
        this.sanityCheck(training);
    }

    /**
     * Forward-propagation hook implemented by the concrete layer types.
     * NOTE(review): the parameter names were lost in decompilation; presumably
     * a seed and a training flag - confirm against the implementations.
     */
    protected abstract void fprop(long var1, boolean var3);

    /** Back-propagation hook implemented by the concrete layer types. */
    protected abstract void bprop();

    /**
     * Back-propagates column by column over the stored entries of the previous
     * layer's sparse activations, then rescales every weight row unless
     * {@code _max_w2} is positive infinity.
     *
     * @param r learning rate passed on to {@code bprop_col}
     * @param m momentum passed on to {@code bprop_col}
     */
    void bprop_sparse(float r, float m) {
        final Storage.SparseVector sparseInput = (Storage.SparseVector)this._previous._a;
        final int first = sparseInput.begin()._idx;
        final int last = sparseInput.end()._idx;
        for (int pos = first; pos < last; ++pos) {
            final int col = sparseInput._indices[pos];
            final float value = sparseInput._values[pos];
            this.bprop_col(col, value, r, m);
        }
        final int rows = this._a.size();
        final float max_w2 = this.params._max_w2;
        if (max_w2 == Float.POSITIVE_INFINITY) {
            // No constraint on the squared row norm: nothing to rescale.
            return;
        }
        for (int row = 0; row < rows; ++row) {
            Neurons.rescale_weights(this._w, row, max_w2);
        }
    }

    /**
     * Back-propagates one output row, dispatching on the storage types of the
     * weight matrix and the previous layer's activations. Only a row-major
     * dense matrix combined with a dense or sparse vector is supported.
     *
     * @param row          output neuron whose incoming weights are updated
     * @param partial_grad partial gradient for that neuron
     * @param rate         learning rate
     * @param momentum     momentum
     * @throws UnsupportedOperationException for any other storage combination
     */
    final void bprop(int row, float partial_grad, float rate, float momentum) {
        if (this._shortcut && partial_grad == 0.0f) {
            // A zero gradient cannot change anything when the shortcut is enabled.
            return;
        }
        final boolean rowMajor = this._w instanceof Storage.DenseRowMatrix;
        if (rowMajor && this._previous._a instanceof Storage.DenseVector) {
            this.bprop_dense_row_dense((Storage.DenseRowMatrix)this._w, (Storage.DenseRowMatrix)this._wEA, (Storage.DenseRowMatrix)this._wm, (Storage.DenseRowMatrix)this._ada_dx_g, (Storage.DenseVector)this._previous._a, this._previous._e, this._b, this._bEA, this._bm, row, partial_grad, rate, momentum);
            return;
        }
        if (rowMajor && this._previous._a instanceof Storage.SparseVector) {
            this.bprop_dense_row_sparse((Storage.DenseRowMatrix)this._w, (Storage.DenseRowMatrix)this._wm, (Storage.DenseRowMatrix)this._ada_dx_g, (Storage.SparseVector)this._previous._a, this._previous._e, this._b, this._bm, row, partial_grad, rate, momentum);
            return;
        }
        throw new UnsupportedOperationException("bprop for types not yet implemented.");
    }

    /**
     * Back-propagates one input column. Only a column-major dense weight matrix
     * combined with sparse previous activations is supported.
     *
     * @param col        input column to update
     * @param previous_a activation of the previous layer at {@code col}
     * @param rate       learning rate
     * @param momentum   momentum
     * @throws UnsupportedOperationException for any other storage combination
     */
    final void bprop_col(int col, float previous_a, float rate, float momentum) {
        final boolean supported = this._w instanceof Storage.DenseColMatrix && this._previous._a instanceof Storage.SparseVector;
        if (supported) {
            this.bprop_dense_col_sparse((Storage.DenseColMatrix)this._w, (Storage.DenseColMatrix)this._wm, (Storage.DenseColMatrix)this._ada_dx_g, (Storage.SparseVector)this._previous._a, this._previous._e, this._b, this._bm, col, previous_a, rate, momentum);
            return;
        }
        throw new UnsupportedOperationException("bprop_col for types not yet implemented.");
    }

    /**
     * Updates one row of a row-major dense weight matrix from dense previous
     * activations, accumulates the error into the previous layer, optionally
     * rescales the row, and finally updates the bias of that row.
     *
     * Per weight the gradient is {@code partial_grad * previous_a} minus the L1
     * and L2 penalties and, when {@code _wEA} is present, minus the elastic
     * averaging term. The update rule is AdaDelta, Nesterov-style momentum, or
     * classic momentum, depending on the configuration.
     *
     * @param _w           weights (updated in place through {@code raw()})
     * @param _wEA         elastic-averaging reference weights, may be null
     * @param _wm          weight momenta, only used when momenta are enabled
     * @param adaxg        AdaDelta accumulators, only used when AdaDelta is enabled
     * @param prev_a       activations of the previous layer
     * @param prev_e       error vector of the previous layer, may be null
     * @param _b           biases
     * @param _bEA         elastic-averaging reference biases, may be null
     * @param _bm          bias momenta
     * @param row          row to update
     * @param partial_grad partial gradient of the row's neuron
     * @param rate         learning rate (ignored for weights under AdaDelta)
     * @param momentum     momentum
     */
    private void bprop_dense_row_dense(Storage.DenseRowMatrix _w, Storage.DenseRowMatrix _wEA, Storage.DenseRowMatrix _wm, Storage.DenseRowMatrix adaxg, Storage.DenseVector prev_a, Storage.DenseVector prev_e, Storage.DenseVector _b, Storage.DenseVector _bEA, Storage.DenseVector _bm, int row, float partial_grad, float rate, float momentum) {
        final float rho = (float)this.params._rho;
        final float eps = (float)this.params._epsilon;
        final float l1 = (float)this.params._l1;
        final float l2 = (float)this.params._l2;
        final float max_w2 = this.params._max_w2;
        final boolean have_momenta = this._minfo.has_momenta();
        final boolean have_ada = this._minfo.adaDelta();
        final boolean nesterov = this.params._nesterov_accelerated_gradient;
        final boolean update_prev = prev_e != null;
        final boolean fast_mode = this.params._fast_mode;
        final int cols = prev_a.size();
        final int rowOffset = row * cols;
        float avg_grad2 = 0.0f;
        for (int col = 0; col < cols; ++col) {
            final float weight = _w.get(row, col);
            if (update_prev) {
                // The error is propagated with the weight value from before this update.
                prev_e.add(col, partial_grad * weight);
            }
            final float previous_a = prev_a.get(col);
            if (fast_mode && previous_a == 0.0f) {
                continue;
            }
            final int w = rowOffset + col;
            float grad = partial_grad * previous_a - Math.signum(weight) * l1 - weight * l2;
            if (_wEA != null) {
                grad = (float)((double)grad - this.params._elastic_averaging_regularization * (double)(_w.raw()[w] - _wEA.raw()[w]));
            }
            if (have_ada) {
                final float grad2 = grad * grad;
                avg_grad2 += grad2;
                final float brate = Neurons.computeAdaDeltaRateForWeight(grad, w, adaxg, rho, eps);
                _w.raw()[w] += brate * grad;
            } else if (nesterov) {
                float step = grad;
                if (have_momenta) {
                    // Velocity = momentum * velocity + grad; the step uses the fresh velocity.
                    _wm.raw()[w] *= momentum;
                    _wm.raw()[w] += step;
                    step = _wm.raw()[w];
                }
                _w.raw()[w] += rate * step;
            } else {
                final float delta = rate * grad;
                _w.raw()[w] += delta;
                if (have_momenta) {
                    // Classic momentum: add the previous delta scaled by momentum, remember the new one.
                    _w.raw()[w] += momentum * _wm.raw()[w];
                    _wm.raw()[w] = delta;
                }
            }
        }
        if (max_w2 != Float.POSITIVE_INFINITY) {
            Neurons.rescale_weights(_w, row, max_w2);
        }
        if (have_ada) {
            avg_grad2 /= (float)cols;
        }
        this.update_bias(_b, _bEA, _bm, row, partial_grad, avg_grad2, rate, momentum);
    }

    /**
     * Updates one column of a column-major dense weight matrix for a single
     * stored entry of the sparse previous activations, and updates every bias
     * with a share ({@code 1 / cols}) of the gradient.
     *
     * The partial gradient of each row is derived here as
     * {@code _e[row] * (1 - _a[row]^2)}.
     * Elastic averaging is not supported on this path.
     *
     * @param w          weights
     * @param wm         weight momenta, only used when momenta are enabled
     * @param adaxg      AdaDelta accumulators, only used when AdaDelta is enabled
     * @param prev_a     sparse activations of the previous layer (only its size is read)
     * @param prev_e     error vector of the previous layer, may be null
     * @param b          biases
     * @param bm         bias momenta
     * @param col        column to update
     * @param previous_a previous-layer activation at {@code col}; asserted non-zero
     * @param rate       learning rate
     * @param momentum   momentum
     */
    private void bprop_dense_col_sparse(Storage.DenseColMatrix w, Storage.DenseColMatrix wm, Storage.DenseColMatrix adaxg, Storage.SparseVector prev_a, Storage.DenseVector prev_e, Storage.DenseVector b, Storage.DenseVector bm, int col, float previous_a, float rate, float momentum) {
        final float rho = (float)this.params._rho;
        final float eps = (float)this.params._epsilon;
        final float l1 = (float)this.params._l1;
        final float l2 = (float)this.params._l2;
        final boolean have_momenta = this._minfo.has_momenta();
        final boolean have_ada = this._minfo.adaDelta();
        final boolean nesterov = this.params._nesterov_accelerated_gradient;
        final boolean update_prev = prev_e != null;
        final int cols = prev_a.size();
        final int rows = this._a.size();
        for (int row = 0; row < rows; ++row) {
            final float err = this._e.get(row);
            final float partial_grad = err * (1.0f - this._a.get(row) * this._a.get(row));
            final float weight = w.get(row, col);
            if (update_prev) {
                prev_e.add(col, partial_grad * weight);
            }
            assert previous_a != 0.0f;
            if (this._shortcut && partial_grad == 0.0f) {
                continue;
            }
            final float grad = partial_grad * previous_a - Math.signum(weight) * l1 - weight * l2;
            if (this._wEA != null) {
                throw H2O.unimpl((String)"elastic averaging is not implemented for sparse input handling with column-major matrix format.");
            }
            if (have_ada) {
                assert !have_momenta;
                final float brate = Neurons.computeAdaDeltaRateForWeight(grad, row, col, adaxg, rho, eps);
                w.add(row, col, brate * grad);
            } else if (nesterov) {
                float step = grad;
                if (have_momenta) {
                    float velocity = wm.get(row, col);
                    velocity *= momentum;
                    velocity += step;
                    step = velocity;
                    wm.set(row, col, velocity);
                }
                w.add(row, col, rate * step);
            } else {
                final float delta = rate * grad;
                w.add(row, col, delta);
                if (have_momenta) {
                    w.add(row, col, momentum * wm.get(row, col));
                    wm.set(row, col, delta);
                }
            }
            assert this._bEA == null;
            // Each column contributes 1/cols of the bias gradient.
            this.update_bias(b, this._bEA, bm, row, partial_grad / (float)cols, grad * grad / (float)cols, rate, momentum);
        }
    }

    /**
     * Updates one row of a row-major dense weight matrix from sparse previous
     * activations (only stored entries are visited), accumulates the error into
     * the previous layer, optionally rescales the row, and updates the bias.
     *
     * Elastic averaging is not supported on this path.
     *
     * @param _w           weights (updated in place through {@code raw()})
     * @param _wm          weight momenta, only used when momenta are enabled
     * @param adaxg        AdaDelta accumulators, only used when AdaDelta is enabled
     * @param prev_a       sparse activations of the previous layer
     * @param prev_e       error vector of the previous layer, may be null
     * @param _b           biases
     * @param _bm          bias momenta
     * @param row          row to update
     * @param partial_grad partial gradient of the row's neuron
     * @param rate         learning rate (ignored for weights under AdaDelta)
     * @param momentum     momentum
     */
    private void bprop_dense_row_sparse(Storage.DenseRowMatrix _w, Storage.DenseRowMatrix _wm, Storage.DenseRowMatrix adaxg, Storage.SparseVector prev_a, Storage.DenseVector prev_e, Storage.DenseVector _b, Storage.DenseVector _bm, int row, float partial_grad, float rate, float momentum) {
        final float rho = (float)this.params._rho;
        final float eps = (float)this.params._epsilon;
        final float l1 = (float)this.params._l1;
        final float l2 = (float)this.params._l2;
        final float max_w2 = this.params._max_w2;
        final boolean have_momenta = this._minfo.has_momenta();
        final boolean have_ada = this._minfo.adaDelta();
        final boolean nesterov = this.params._nesterov_accelerated_gradient;
        final boolean update_prev = prev_e != null;
        final int cols = prev_a.size();
        final int idx = row * cols;
        float avg_grad2 = 0.0f;
        final int start = prev_a.begin()._idx;
        final int end = prev_a.end()._idx;
        for (int it = start; it < end; ++it) {
            final int col = prev_a._indices[it];
            final float weight = _w.get(row, col);
            if (update_prev) {
                // The error is propagated with the weight value from before this update.
                prev_e.add(col, partial_grad * weight);
            }
            final float previous_a = prev_a._values[it];
            assert previous_a != 0.0f;
            final float grad = partial_grad * previous_a - Math.signum(weight) * l1 - weight * l2;
            if (this._wEA != null) {
                throw H2O.unimpl((String)"elastic averaging is not implemented for sparse input handling.");
            }
            final int w = idx + col;
            if (have_ada) {
                assert !have_momenta;
                final float grad2 = grad * grad;
                avg_grad2 += grad2;
                final float brate = Neurons.computeAdaDeltaRateForWeight(grad, w, adaxg, rho, eps);
                _w.raw()[w] += brate * grad;
            } else if (nesterov) {
                float tmp = grad;
                if (have_momenta) {
                    // Velocity = momentum * velocity + grad; the step uses the fresh velocity.
                    _wm.raw()[w] *= momentum;
                    _wm.raw()[w] += tmp;
                    tmp = _wm.raw()[w];
                }
                _w.raw()[w] += rate * tmp;
            } else {
                final float delta = rate * grad;
                _w.raw()[w] += delta;
                if (have_momenta) {
                    // Classic momentum: add the previous delta scaled by momentum, remember the new one.
                    _w.raw()[w] += momentum * _wm.raw()[w];
                    _wm.raw()[w] = delta;
                }
            }
        }
        if (max_w2 != Float.POSITIVE_INFINITY) {
            Neurons.rescale_weights(_w, row, max_w2);
        }
        if (have_ada) {
            // With no stored entries the sum is 0 and dividing by nnz() == 0 would give
            // 0/0 = NaN, which would then be written into the bias AdaDelta accumulators
            // by update_bias. Keep the average at 0 in that case.
            final int nnz = prev_a.nnz();
            if (nnz > 0) {
                avg_grad2 /= (float)nnz;
            }
        }
        assert this._bEA == null;
        this.update_bias(_b, this._bEA, _bm, row, partial_grad, avg_grad2, rate, momentum);
    }

    /**
     * Scales one row of {@code w} down so that its squared L2 norm does not
     * exceed {@code max_w2}. Row-major dense matrices are delegated to the
     * specialised overload; column-major dense matrices are handled through
     * element accessors.
     *
     * @param w      weight matrix
     * @param row    row to constrain
     * @param max_w2 upper bound for the squared row norm
     * @throws UnsupportedOperationException for other matrix types
     */
    private static void rescale_weights(Storage.Matrix w, int row, float max_w2) {
        final int cols = w.cols();
        if (w instanceof Storage.DenseRowMatrix) {
            Neurons.rescale_weights((Storage.DenseRowMatrix)w, row, max_w2);
            return;
        }
        if (!(w instanceof Storage.DenseColMatrix)) {
            throw new UnsupportedOperationException("rescale weights for " + w.getClass().getSimpleName() + " not yet implemented.");
        }
        float r2 = 0.0f;
        for (int col = 0; col < cols; ++col) {
            r2 += w.get(row, col) * w.get(row, col);
        }
        if (r2 <= max_w2) {
            return;
        }
        final float scale = MathUtils.approxSqrt((float)(max_w2 / r2));
        for (int col = 0; col < cols; ++col) {
            w.set(row, col, w.get(row, col) * scale);
        }
    }

    /**
     * Row-major variant: scales the contiguous slice of {@code w.raw()} that
     * holds {@code row} so that its sum of squares does not exceed {@code max_w2}.
     *
     * @param w      row-major dense weight matrix
     * @param row    row to constrain
     * @param max_w2 upper bound for the squared row norm
     */
    private static void rescale_weights(Storage.DenseRowMatrix w, int row, float max_w2) {
        final int cols = w.cols();
        final int from = row * cols;
        final int to = from + cols;
        final float r2 = MathUtils.sumSquares((float[])w.raw(), (int)from, (int)to);
        if (r2 > max_w2) {
            final float scale = MathUtils.approxSqrt((float)(max_w2 / r2));
            for (int i = from; i < to; ++i) {
                w.raw()[i] *= scale;
            }
        }
    }

    /**
     * Gradient of the given distribution for one output neuron of an
     * autoencoder, using the activation of {@code _input} at the same index as
     * the target and this layer's activation as the prediction.
     *
     * @param dist distribution whose gradient is evaluated
     * @param row  neuron index
     * @return {@code dist.gradient(target, prediction)} narrowed to float
     */
    protected float autoEncoderGradient(Distribution dist, int row) {
        assert this._minfo.get_params()._autoencoder && this._index == this._minfo.get_params()._hidden.length;
        final float target = this._input._a.get(row);
        final float predicted = this._a.get(row);
        final double gradient = dist.gradient((double)target, (double)predicted);
        return (float)gradient;
    }

    /**
     * AdaDelta step size for one weight of a column-major matrix. The state
     * matrix holds two accumulators per weight: row {@code 2*row + 1} is the
     * decayed average of squared gradients, row {@code 2*row} the decayed
     * average of squared updates. Both are updated in place.
     *
     * @param grad     gradient of the weight
     * @param row      weight row
     * @param col      weight column
     * @param ada_dx_g AdaDelta accumulators
     * @param rho      decay factor
     * @param eps      smoothing constant
     * @return the adaptive rate to multiply the gradient with
     */
    private static float computeAdaDeltaRateForWeight(float grad, int row, int col, Storage.DenseColMatrix ada_dx_g, float rho, float eps) {
        final int dxSlot = 2 * row;
        final int gSlot = 2 * row + 1;
        ada_dx_g.set(gSlot, col, rho * ada_dx_g.get(gSlot, col) + (1.0f - rho) * grad * grad);
        final float ratio = (float)((ada_dx_g.get(dxSlot, col) + eps) / (ada_dx_g.get(gSlot, col) + eps));
        final float rate = MathUtils.approxSqrt(ratio);
        ada_dx_g.set(dxSlot, col, rho * ada_dx_g.get(dxSlot, col) + (1.0f - rho) * rate * rate * grad * grad);
        return rate;
    }

    /**
     * AdaDelta step size for one weight of a row-major matrix, addressed by its
     * flat index {@code w}. Slot {@code 2*w + 1} of the raw state array is the
     * decayed average of squared gradients, slot {@code 2*w} the decayed average
     * of squared updates. Both are updated in place.
     *
     * @param grad     gradient of the weight
     * @param w        flat index of the weight
     * @param ada_dx_g AdaDelta accumulators
     * @param rho      decay factor
     * @param eps      smoothing constant
     * @return the adaptive rate to multiply the gradient with
     */
    private static float computeAdaDeltaRateForWeight(float grad, int w, Storage.DenseRowMatrix ada_dx_g, float rho, float eps) {
        final int dxSlot = 2 * w;
        final int gSlot = 2 * w + 1;
        ada_dx_g.raw()[gSlot] = rho * ada_dx_g.raw()[gSlot] + (1.0f - rho) * grad * grad;
        final float ratio = (float)((ada_dx_g.raw()[dxSlot] + eps) / (ada_dx_g.raw()[gSlot] + eps));
        final float rate = MathUtils.approxSqrt(ratio);
        ada_dx_g.raw()[dxSlot] = rho * ada_dx_g.raw()[dxSlot] + (1.0f - rho) * rate * rate * grad * grad;
        return rate;
    }

    /**
     * AdaDelta step size for one bias. Takes an already squared gradient.
     * Slot {@code 2*row + 1} of the raw state array is the decayed average of
     * squared gradients, slot {@code 2*row} the decayed average of squared
     * updates. Both are updated in place.
     *
     * @param grad2         squared gradient
     * @param row           bias index
     * @param bias_ada_dx_g AdaDelta accumulators for the biases
     * @param rho           decay factor
     * @param eps           smoothing constant
     * @return the adaptive rate for the bias update
     */
    private static float computeAdaDeltaRateForBias(float grad2, int row, Storage.DenseVector bias_ada_dx_g, float rho, float eps) {
        final int dxSlot = 2 * row;
        final int gSlot = 2 * row + 1;
        bias_ada_dx_g.raw()[gSlot] = rho * bias_ada_dx_g.raw()[gSlot] + (1.0f - rho) * grad2;
        final float ratio = (float)((bias_ada_dx_g.raw()[dxSlot] + eps) / (bias_ada_dx_g.raw()[gSlot] + eps));
        final float rate = MathUtils.approxSqrt(ratio);
        bias_ada_dx_g.raw()[dxSlot] = rho * bias_ada_dx_g.raw()[dxSlot] + (1.0f - rho) * rate * rate * grad2;
        return rate;
    }

    /**
     * Folds the current activations into the running average {@code _avg_a}
     * with weights 0.999 (old) and 0.001 (new). Does nothing when no average
     * vector has been set up.
     */
    void compute_sparsity() {
        if (this._avg_a == null) {
            return;
        }
        for (int i = 0; i < this._avg_a.size(); ++i) {
            final float blended = 0.999f * this._avg_a.get(i) + 0.001f * this._a.get(i);
            this._avg_a.set(i, blended);
        }
    }

    /**
     * Updates the bias of one neuron. The gradient is reduced by the L1/L2
     * penalties and, when {@code _bEA} is present, by the elastic-averaging
     * term. Under AdaDelta the rate is replaced by the adaptive rate. For
     * autoencoders with a sparsity penalty the bias is additionally corrected
     * from the average activation. An infinite bias marks the model unstable.
     *
     * @param _b           biases (updated in place)
     * @param _bEA         elastic-averaging reference biases, may be null
     * @param _bm          bias momenta, only used when momenta are enabled
     * @param row          neuron index
     * @param partial_grad gradient for the bias before regularisation
     * @param avg_grad2    squared gradient fed to the AdaDelta accumulator
     * @param rate         learning rate (overwritten under AdaDelta)
     * @param momentum     momentum
     */
    void update_bias(Storage.DenseVector _b, Storage.DenseVector _bEA, Storage.DenseVector _bm, int row, float partial_grad, float avg_grad2, float rate, float momentum) {
        final boolean have_momenta = this._minfo.has_momenta();
        final boolean have_ada = this._minfo.adaDelta();
        final float l1 = (float)this.params._l1;
        final float l2 = (float)this.params._l2;
        final float bias = _b.get(row);
        partial_grad = partial_grad - (Math.signum(bias) * l1 + bias * l2);
        if (_bEA != null) {
            partial_grad = (float)((double)partial_grad - (double)(bias - _bEA.get(row)) * this.params._elastic_averaging_regularization);
        }
        if (have_ada) {
            final float rho = (float)this.params._rho;
            final float eps = (float)this.params._epsilon;
            rate = Neurons.computeAdaDeltaRateForBias(avg_grad2, row, this._bias_ada_dx_g, rho, eps);
        }
        if (this.params._nesterov_accelerated_gradient) {
            float step = partial_grad;
            if (have_momenta) {
                // Velocity = momentum * velocity + gradient; the step uses the fresh velocity.
                _bm.set(row, _bm.get(row) * momentum);
                _bm.add(row, step);
                step = _bm.get(row);
            }
            _b.add(row, rate * step);
        } else {
            final float delta = rate * partial_grad;
            _b.add(row, delta);
            if (have_momenta) {
                _b.add(row, momentum * _bm.get(row));
                _bm.set(row, delta);
            }
        }
        final boolean sparsityPenalty = this.params._autoencoder && this.params._sparsity_beta > 0.0 && !(this instanceof Output) && !(this instanceof Input) && this._index != this.params._hidden.length;
        if (sparsityPenalty) {
            final double excess = (double)this._avg_a.raw()[row] - this.params._average_activation;
            _b.add(row, -((float)((double)rate * this.params._sparsity_beta * excess)));
        }
        if (Float.isInfinite(_b.get(row))) {
            this._minfo.set_unstable();
        }
    }

    /**
     * Annealed learning rate: {@code _rate / (1 + _rate_annealing * n)}.
     *
     * @param n value the annealing term is scaled by (callers in this file pass the processed-sample total)
     * @return the annealed rate narrowed to float
     */
    public float rate(double n) {
        final double annealed = this.params._rate / (1.0 + this.params._rate_annealing * n);
        return (float)annealed;
    }

    /**
     * Convenience overload: delegates to {@link #momentum(double)} with the
     * argument {@code -1.0}.
     *
     * @return the momentum computed by {@code momentum(-1.0)}
     */
    protected float momentum() {
        final double sentinel = -1.0;
        return this.momentum(sentinel);
    }

    /**
     * Momentum after {@code n} training samples. When {@code _momentum_ramp} is
     * positive the momentum moves linearly from {@code _momentum_start} to
     * {@code _momentum_stable} over {@code _momentum_ramp} samples and stays at
     * {@code _momentum_stable} afterwards; otherwise it is {@code _momentum_start}.
     *
     * The value {@code -1} is a sentinel meaning "use the model's processed
     * total". The previous code had this test inverted: it ignored an explicit
     * {@code n} and used {@code -1} itself as the sample count when the sentinel
     * was passed, so the ramp never advanced on the {@code momentum()} path.
     *
     * @param n number of samples seen, or {@code -1} to use {@code _minfo.get_processed_total()}
     * @return the momentum narrowed to float
     */
    public float momentum(double n) {
        double m = this.params._momentum_start;
        if (this.params._momentum_ramp > 0.0) {
            final double num = n == -1.0 ? (double)this._minfo.get_processed_total() : n;
            if (num >= this.params._momentum_ramp) {
                m = this.params._momentum_stable;
            } else {
                m += (this.params._momentum_stable - this.params._momentum_start) * num / this.params._momentum_ramp;
            }
        }
        return (float)m;
    }

    /**
     * Reference implementation of {@code res = a * x + y} for a row-major
     * matrix stored in a flat array. Rows whose bit in {@code row_bits} is
     * cleared are set to 0 and skipped (bias included).
     *
     * @param res      output, length must equal {@code y.length}
     * @param a        row-major matrix with {@code y.length} rows and {@code x.length} columns
     * @param x        input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows (bit {@code row % 8} of byte {@code row / 8}); null means all rows
     */
    static void gemv_naive(float[] res, float[] a, float[] x, float[] y, byte[] row_bits) {
        final int cols = x.length;
        final int rows = y.length;
        assert res.length == rows;
        for (int row = 0; row < rows; ++row) {
            res[row] = 0.0f;
            final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
            if (!active) {
                continue;
            }
            final int base = row * cols;
            for (int col = 0; col < cols; ++col) {
                res[row] += a[base + col] * x[col];
            }
            res[row] += y[row];
        }
    }

    /**
     * Unrolled implementation of {@code res = a * x + y} for a row-major matrix
     * stored in a flat array. Columns are processed in groups of eight with
     * eight independent partial sums, followed by the remaining columns. Rows
     * whose bit in {@code row_bits} is cleared are set to 0 and skipped.
     *
     * @param res      output, length must equal {@code y.length}
     * @param a        row-major matrix with {@code y.length} rows and {@code x.length} columns
     * @param x        input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows (bit {@code row % 8} of byte {@code row / 8}); null means all rows
     */
    static void gemv_row_optimized(float[] res, float[] a, float[] x, float[] y, byte[] row_bits) {
        final int cols = x.length;
        final int rows = y.length;
        assert res.length == rows;
        // First column that is not covered by a full group of eight.
        final int tailStart = cols - cols % 8;
        // Loop bound for the unrolled part; a group starts at every multiple of 8 below it.
        final int unrolledLimit = cols / 8 * 8 - 1;
        for (int row = 0, base = 0; row < rows; ++row, base += cols) {
            res[row] = 0.0f;
            final boolean skipped = row_bits != null && (row_bits[row / 8] & (1 << (row % 8))) == 0;
            if (skipped) {
                continue;
            }
            float s0 = 0.0f;
            float s1 = 0.0f;
            float s2 = 0.0f;
            float s3 = 0.0f;
            float s4 = 0.0f;
            float s5 = 0.0f;
            float s6 = 0.0f;
            float s7 = 0.0f;
            for (int col = 0; col < unrolledLimit; col += 8) {
                final int off = base + col;
                s0 += a[off] * x[col];
                s1 += a[off + 1] * x[col + 1];
                s2 += a[off + 2] * x[col + 2];
                s3 += a[off + 3] * x[col + 3];
                s4 += a[off + 4] * x[col + 4];
                s5 += a[off + 5] * x[col + 5];
                s6 += a[off + 6] * x[col + 6];
                s7 += a[off + 7] * x[col + 7];
            }
            res[row] += s0 + s1 + s2 + s3;
            res[row] += s4 + s5 + s6 + s7;
            for (int col = tailStart; col < cols; ++col) {
                res[row] += a[base + col] * x[col];
            }
            res[row] += y[row];
        }
    }

    /**
     * Computes {@code res = a * x + y}, dispatching on the storage types of the
     * matrix and the input vector. Supported: dense row- or column-major matrix
     * with a dense or sparse vector.
     *
     * @param res      output vector
     * @param a        weight matrix
     * @param x        input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     * @throws UnsupportedOperationException for any other combination
     */
    static void gemv(Storage.DenseVector res, Storage.Matrix a, Storage.Vector x, Storage.DenseVector y, byte[] row_bits) {
        final boolean rowMajor = a instanceof Storage.DenseRowMatrix;
        final boolean colMajor = a instanceof Storage.DenseColMatrix;
        final boolean denseInput = x instanceof Storage.DenseVector;
        final boolean sparseInput = x instanceof Storage.SparseVector;
        if (rowMajor && denseInput) {
            Neurons.gemv(res, (Storage.DenseRowMatrix)a, (Storage.DenseVector)x, y, row_bits);
        } else if (colMajor && sparseInput) {
            Neurons.gemv(res, (Storage.DenseColMatrix)a, (Storage.SparseVector)x, y, row_bits);
        } else if (rowMajor && sparseInput) {
            Neurons.gemv(res, (Storage.DenseRowMatrix)a, (Storage.SparseVector)x, y, row_bits);
        } else if (colMajor && denseInput) {
            Neurons.gemv(res, (Storage.DenseColMatrix)a, (Storage.DenseVector)x, y, row_bits);
        } else {
            throw new UnsupportedOperationException("gemv for matrix " + a.getClass().getSimpleName() + " and vector + " + x.getClass().getSimpleName() + " not yet implemented.");
        }
    }

    /**
     * Row-major dense matrix times dense vector plus bias; runs the unrolled
     * kernel on the raw backing arrays.
     *
     * @param res      output vector
     * @param a        row-major dense matrix
     * @param x        dense input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv(Storage.DenseVector res, Storage.DenseRowMatrix a, Storage.DenseVector x, Storage.DenseVector y, byte[] row_bits) {
        final float[] out = res.raw();
        final float[] weights = a.raw();
        final float[] input = x.raw();
        final float[] bias = y.raw();
        Neurons.gemv_row_optimized(out, weights, input, bias, row_bits);
    }

    /**
     * Row-major dense matrix times dense vector plus bias; runs the reference
     * (non-unrolled) kernel on the raw backing arrays.
     *
     * @param res      output vector
     * @param a        row-major dense matrix
     * @param x        dense input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv_naive(Storage.DenseVector res, Storage.DenseRowMatrix a, Storage.DenseVector x, Storage.DenseVector y, byte[] row_bits) {
        final float[] out = res.raw();
        final float[] weights = a.raw();
        final float[] input = x.raw();
        final float[] bias = y.raw();
        Neurons.gemv_naive(out, weights, input, bias, row_bits);
    }

    /**
     * Column-major dense matrix times dense vector plus bias. The result is
     * zeroed first, then accumulated column by column, then the bias is added.
     * Rows whose bit in {@code row_bits} is cleared stay 0.
     *
     * @param res      output vector, size must equal {@code y.size()}
     * @param a        column-major dense matrix
     * @param x        dense input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv(Storage.DenseVector res, Storage.DenseColMatrix a, Storage.DenseVector x, Storage.DenseVector y, byte[] row_bits) {
        final int cols = x.size();
        final int rows = y.size();
        assert res.size() == rows;
        for (int row = 0; row < rows; ++row) {
            res.set(row, 0.0f);
        }
        for (int col = 0; col < cols; ++col) {
            final float input = x.get(col);
            for (int row = 0; row < rows; ++row) {
                final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
                if (active) {
                    res.add(row, a.get(row, col) * input);
                }
            }
        }
        for (int row = 0; row < rows; ++row) {
            final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
            if (active) {
                res.add(row, y.get(row));
            }
        }
    }

    /**
     * Row-major dense matrix times sparse vector plus bias. For every active
     * row only the stored entries of {@code x} are visited. Rows whose bit in
     * {@code row_bits} is cleared are set to 0.
     *
     * @param res      output vector, size must equal {@code y.size()}
     * @param a        row-major dense matrix
     * @param x        sparse input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv(Storage.DenseVector res, Storage.DenseRowMatrix a, Storage.SparseVector x, Storage.DenseVector y, byte[] row_bits) {
        final int rows = y.size();
        assert res.size() == rows;
        for (int row = 0; row < rows; ++row) {
            res.set(row, 0.0f);
            final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
            if (!active) {
                continue;
            }
            final int first = x.begin()._idx;
            final int last = x.end()._idx;
            for (int pos = first; pos < last; ++pos) {
                res.add(row, a.get(row, x._indices[pos]) * x._values[pos]);
            }
            res.add(row, y.get(row));
        }
    }

    /**
     * Column-major dense matrix times sparse vector plus bias. The result is
     * zeroed, then each stored non-zero entry of {@code x} adds its scaled
     * column to the active rows, then the bias is added to the active rows.
     *
     * @param res      output vector, size must equal {@code y.size()}
     * @param a        column-major dense matrix
     * @param x        sparse input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv(Storage.DenseVector res, Storage.DenseColMatrix a, Storage.SparseVector x, Storage.DenseVector y, byte[] row_bits) {
        final int rows = y.size();
        assert res.size() == rows;
        for (int row = 0; row < rows; ++row) {
            res.set(row, 0.0f);
        }
        final int first = x.begin()._idx;
        final int last = x.end()._idx;
        for (int pos = first; pos < last; ++pos) {
            final float input = x._values[pos];
            if (input == 0.0f) {
                continue;
            }
            for (int row = 0; row < rows; ++row) {
                final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
                if (active) {
                    res.add(row, a.get(row, x._indices[pos]) * input);
                }
            }
        }
        for (int row = 0; row < rows; ++row) {
            final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
            if (active) {
                res.add(row, y.get(row));
            }
        }
    }

    /**
     * Sparse row matrix times sparse vector plus bias. For every active row the
     * stored matrix entries are multiplied with the matching entry of
     * {@code x}; zero inputs are skipped. Rows whose bit in {@code row_bits} is
     * cleared are set to 0.
     *
     * @param res      output vector, size must equal {@code y.size()}
     * @param a        sparse row matrix (column index to value per row)
     * @param x        sparse input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv(Storage.DenseVector res, Storage.SparseRowMatrix a, Storage.SparseVector x, Storage.DenseVector y, byte[] row_bits) {
        final int rows = y.size();
        assert res.size() == rows;
        for (int r = 0; r < rows; ++r) {
            res.set(r, 0.0f);
            final boolean active = row_bits == null || (row_bits[r / 8] & (1 << (r % 8))) != 0;
            if (!active) {
                continue;
            }
            for (Map.Entry<Integer, Float> cell : a.row(r).entrySet()) {
                final float input = x.get(cell.getKey());
                if (input != 0.0f) {
                    res.add(r, cell.getValue().floatValue() * input);
                }
            }
            res.add(r, y.get(r));
        }
    }

    /**
     * Sparse column matrix times sparse vector plus bias. The result is zeroed,
     * then every column with a non-zero input adds its stored entries to the
     * active rows, then the bias is added to the active rows.
     *
     * @param res      output vector, size must equal {@code y.size()}
     * @param a        sparse column matrix (row index to value per column)
     * @param x        sparse input vector
     * @param y        bias vector
     * @param row_bits optional bit mask of active rows; null means all rows
     */
    static void gemv(Storage.DenseVector res, Storage.SparseColMatrix a, Storage.SparseVector x, Storage.DenseVector y, byte[] row_bits) {
        final int rows = y.size();
        assert res.size() == rows;
        for (int row = 0; row < rows; ++row) {
            res.set(row, 0.0f);
        }
        for (int c = 0; c < a.cols(); ++c) {
            final TreeMap<Integer, Float> column = a.col(c);
            final float input = x.get(c);
            if (input == 0.0f) {
                continue;
            }
            for (Map.Entry<Integer, Float> cell : column.entrySet()) {
                final int row = cell.getKey();
                final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
                if (active) {
                    res.add(row, cell.getValue().floatValue() * input);
                }
            }
        }
        for (int row = 0; row < rows; ++row) {
            final boolean active = row_bits == null || (row_bits[row / 8] & (1 << (row % 8))) != 0;
            if (active) {
                res.add(row, y.get(row));
            }
        }
    }

    /**
     * Output layer with linear activation: the activation is the weighted sum
     * of the previous layer plus bias. Back-propagation uses the gradient of
     * the configured distribution for the single output neuron (index 0).
     */
    public static class Linear
    extends Output {
        /**
         * @param units number of neurons in this layer
         */
        public Linear(int units) {
            super(units);
        }

        /**
         * Computes the activations as {@code _w * previous._a + _b}, honouring
         * the dropout bit mask when a dropout helper is present.
         */
        protected void fprop() {
            Linear.gemv((Storage.DenseVector)this._a, this._w, this._previous._a, this._b, this._dropout != null ? this._dropout.bits() : null);
        }

        /**
         * Back-propagates the error for a real-valued target.
         *
         * @param target actual response; must not be NaN (the missing-value marker)
         */
        protected void bprop(float target) {
            // The missing-value marker is NaN, and NaN never compares equal to anything,
            // so the former "target != missing_real_value" check could never fail.
            assert !Float.isNaN(target);
            final float t = target;
            final float y = this._a.get(0);
            final float g = (float)new Distribution(this.params._distribution, this.params._tweedie_power).gradient((double)t, (double)y);
            final float m = this.momentum();
            // Under AdaDelta the rate is computed per weight, so none is passed here.
            final float r = this._minfo.adaDelta() ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - m);
            this.bprop(0, g, r, m);
        }
    }

    /**
     * Output layer with softmax activation for classification.
     */
    public static class Softmax
    extends Output {
        /**
         * @param units number of neurons (classes) in this layer
         */
        public Softmax(int units) {
            super(units);
        }

        /**
         * Computes {@code _w * previous._a + _b} and normalises it with a
         * softmax. The maximum is subtracted before exponentiation.
         *
         * @throws RuntimeException if a normalised activation is NaN; the model is marked unstable first
         */
        protected void fprop() {
            Softmax.gemv((Storage.DenseVector)this._a, (Storage.DenseRowMatrix)this._w, (Storage.DenseVector)this._previous._a, this._b, null);
            final float max = ArrayUtils.maxValue((float[])this._a.raw());
            float scale = 0.0f;
            // The row count is kept as an int: comparing a float-converted index against a
            // float-converted size is lossy above 2^24 and buys nothing.
            final int rows = this._a.size();
            for (int row = 0; row < rows; ++row) {
                this._a.set(row, (float)Math.exp(this._a.get(row) - max));
                scale += this._a.get(row);
            }
            for (int row = 0; row < rows; ++row) {
                this._a.raw()[row] /= scale;
                if (Float.isNaN(this._a.get(row))) {
                    this._minfo.set_unstable();
                    throw new RuntimeException("Numerical instability, predicted NaN.");
                }
            }
        }

        /**
         * Back-propagates the error for a class label. The one-hot target is
         * 1 for {@code row == target} and 0 otherwise; the per-row gradient
         * depends on the configured loss.
         *
         * @param target index of the actual class; must not be {@code Integer.MAX_VALUE}
         */
        protected void bprop(int target) {
            assert target != Integer.MAX_VALUE;
            final float m = this.momentum();
            // Under AdaDelta the rate is computed per weight, so none is passed here.
            final float r = this._minfo.adaDelta() ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - m);
            final int rows = this._a.size();
            for (int row = 0; row < rows; ++row) {
                float g;
                final float t = row == target ? 1.0f : 0.0f;
                final float y = this._a.get(row);
                switch (this.params._loss) {
                    case Automatic:
                    case CrossEntropy: {
                        g = t - y;
                        break;
                    }
                    case Absolute: {
                        g = (2.0f * t - 1.0f) * (1.0f - y) * y;
                        break;
                    }
                    case MeanSquare: {
                        g = (t - y) * (1.0f - y) * y;
                        break;
                    }
                    case Huber: {
                        g = t == 0.0f ? ((double)y < 0.5 ? -4.0f * y : -2.0f) : ((double)y > 0.5 ? 4.0f * (1.0f - y) : 2.0f);
                        g *= (1.0f - y) * y;
                        break;
                    }
                    default: {
                        throw H2O.unimpl();
                    }
                }
                this.bprop(row, g, r, m);
            }
        }
    }

    /**
     * Common base class of the output layers. The two generic
     * propagation entry points overridden here are not supported on output
     * layers and always throw.
     */
    public static abstract class Output
    extends Neurons {
        Output(int units) {
            super(units);
        }

        /**
         * Not supported on output layers.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        protected void fprop(long seed, boolean training) {
            throw new UnsupportedOperationException();
        }

        /**
         * Not supported on output layers.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        protected void bprop() {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * {@link Rectifier} layer combined with a dropout helper.
     */
    public static class RectifierDropout
    extends Rectifier {
        public RectifierDropout(int units) {
            super(units);
        }

        /**
         * When training, refills the dropout bytes from a seed derived from
         * {@code seed} and {@code params._seed} before running the rectifier
         * forward pass. Otherwise runs the forward pass and multiplies all
         * activations by {@code 1 - _hidden_dropout_ratios[_index]}.
         */
        @Override
        protected void fprop(long seed, boolean training) {
            if (!training) {
                super.fprop(seed, false);
                float keepFraction = (float)(1.0 - this.params._hidden_dropout_ratios[this._index]);
                ArrayUtils.mult(this._a.raw(), keepFraction);
                return;
            }
            long dropoutSeed = seed + (this.params._seed + 1014100461L);
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true);
        }
    }

    /**
     * Hidden layer with rectified-linear activation, {@code max(x, 0)}.
     */
    public static class Rectifier
    extends Neurons {
        public Rectifier(int units) {
            super(units);
        }

        /**
         * Forward pass: fills {@code _a} via {@code gemv} (passing the dropout
         * bit mask when a dropout helper is attached) and clamps every
         * activation at zero from below.
         *
         * @param seed     unused by this layer
         * @param training unused by this layer
         */
        @Override
        protected void fprop(long seed, boolean training) {
            Rectifier.gemv((Storage.DenseVector)this._a, this._w, this._previous._a, this._b, this._dropout != null ? this._dropout.bits() : null);
            int rows = this._a.size();
            for (int row = 0; row < rows; ++row) {
                this._a.set(row, Math.max(this._a.get(row), 0.0f));
            }
            // Update the sparsity bookkeeping once, after all activations are final,
            // as the Tanh and Maxout layers do. Calling it inside the loop repeated
            // the update once per row on partially rectified activations.
            this.compute_sparsity();
        }

        /**
         * Backward pass. With a dense row-major weight matrix each row's
         * gradient is the incoming error where the activation is positive and
         * zero elsewhere; any other weight storage is delegated to
         * {@code bprop_sparse}. When this is the last layer of an autoencoder
         * (index equals the number of hidden layers) the error is first
         * replaced by the autoencoder gradient.
         */
        @Override
        protected void bprop() {
            float m = this.momentum();
            // With ADADELTA enabled the explicit learning rate is forced to zero.
            float r = this._minfo.adaDelta() ? 0.0f : this.rate(this._minfo.get_processed_total()) * (1.0f - m);
            int rows = this._a.size();
            if (this._w instanceof Storage.DenseRowMatrix) {
                Distribution dist = this._minfo.get_params()._autoencoder && this._index == this._minfo.get_params()._hidden.length ? new Distribution(this.params._distribution) : null;
                for (int row = 0; row < rows; ++row) {
                    if (dist != null) {
                        this._e.set(row, this.autoEncoderGradient(dist, row));
                    }
                    float g = this._a.get(row) > 0.0f ? this._e.get(row) : 0.0f;
                    this.bprop(row, g, r, m);
                }
            } else {
                this.bprop_sparse(r, m);
            }
        }
    }

    /**
     * {@link Maxout} layer combined with a dropout helper.
     */
    public static class MaxoutDropout
    extends Maxout {
        public MaxoutDropout(int units) {
            super(units);
        }

        /**
         * When training, refills the dropout bytes from a seed derived from
         * {@code seed} and {@code params._seed} before running the maxout
         * forward pass. Otherwise runs the forward pass and multiplies all
         * activations by {@code 1 - _hidden_dropout_ratios[_index]}.
         */
        @Override
        protected void fprop(long seed, boolean training) {
            if (!training) {
                super.fprop(seed, false);
                float keepFraction = (float)(1.0 - this.params._hidden_dropout_ratios[this._index]);
                ArrayUtils.mult(this._a.raw(), keepFraction);
                return;
            }
            long dropoutSeed = seed + (this.params._seed + 1372114957L);
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true);
        }
    }

    /**
     * Hidden layer whose activation for a row is the largest single
     * weight-times-input product of that row, plus the row's bias.
     */
    public static class Maxout
    extends Neurons {
        public Maxout(int units) {
            super(units);
        }

        /**
         * Forward pass. Every row starts at zero; rows whose unit is inactive
         * under dropout (training only) stay at zero. For the remaining rows
         * the maximum product over the previous layer's activations is taken
         * (dense or sparse input), an infinite result is reset to zero, and
         * the bias is added. If the largest resulting activation exceeds 1,
         * the whole activation vector is divided by it.
         *
         * @param seed     unused by this layer
         * @param training whether dropout masking should be honoured
         */
        @Override
        protected void fprop(long seed, boolean training) {
            final int rowCount = this._a.size();
            final boolean denseInput = this._previous._a instanceof Storage.DenseVector;
            final Storage.SparseVector sparseInput = denseInput ? null : (Storage.SparseVector)this._previous._a;
            float largest = 0.0f;
            for (int row = 0; row < rowCount; ++row) {
                this._a.set(row, 0.0f);
                boolean dropped = training && this._dropout != null && !this._dropout.unit_active(row);
                if (dropped) {
                    continue;
                }
                float best = denseInput ? this.denseRowMax(row) : this.sparseRowMax(row, sparseInput);
                // No product at all (or an infinite one) leaves an infinity behind.
                if (Float.isInfinite(-best)) {
                    best = 0.0f;
                }
                this._a.set(row, best);
                this._a.add(row, this._b.get(row));
                largest = Math.max(this._a.get(row), largest);
            }
            if (largest > 1.0f) {
                ArrayUtils.div(this._a.raw(), largest);
            }
            this.compute_sparsity();
        }

        /** Largest weight-times-input product of {@code row} over a dense input vector. */
        private float denseRowMax(int row) {
            float best = Float.NEGATIVE_INFINITY;
            int inputs = this._previous._a.size();
            for (int col = 0; col < inputs; ++col) {
                best = Math.max(best, this._w.get(row, col) * this._previous._a.get(col));
            }
            return best;
        }

        /** Largest weight-times-input product of {@code row} over the stored entries of a sparse input vector. */
        private float sparseRowMax(int row, Storage.SparseVector x) {
            float best = Float.NEGATIVE_INFINITY;
            int from = x.begin()._idx;
            int to = x.end()._idx;
            for (int pos = from; pos < to; ++pos) {
                best = Math.max(best, this._w.get(row, x._indices[pos]) * x._values[pos]);
            }
            return best;
        }

        /**
         * Backward pass. With a dense row-major weight matrix the incoming
         * error of each row is passed on unchanged as its gradient; any other
         * weight storage is delegated to {@code bprop_sparse}.
         */
        @Override
        protected void bprop() {
            final float m = this.momentum();
            final float r;
            // With ADADELTA enabled the explicit learning rate is forced to zero.
            if (this._minfo.adaDelta()) {
                r = 0.0f;
            } else {
                r = this.rate(this._minfo.get_processed_total()) * (1.0f - m);
            }
            if (!(this._w instanceof Storage.DenseRowMatrix)) {
                this.bprop_sparse(r, m);
                return;
            }
            final int rowCount = this._a.size();
            for (int row = 0; row < rowCount; ++row) {
                assert (!this._minfo.get_params()._autoencoder);
                this.bprop(row, this._e.get(row), r, m);
            }
        }
    }

    /**
     * {@link Tanh} layer combined with a dropout helper.
     */
    public static class TanhDropout
    extends Tanh {
        public TanhDropout(int units) {
            super(units);
        }

        /**
         * When training, refills the dropout bytes from a seed derived from
         * {@code seed} and {@code params._seed} before running the tanh
         * forward pass. Otherwise runs the forward pass and multiplies all
         * activations by {@code 1 - _hidden_dropout_ratios[_index]}.
         */
        @Override
        protected void fprop(long seed, boolean training) {
            if (!training) {
                super.fprop(seed, false);
                float keepFraction = (float)(1.0 - this.params._hidden_dropout_ratios[this._index]);
                ArrayUtils.mult(this._a.raw(), keepFraction);
                return;
            }
            long dropoutSeed = seed + (this.params._seed + -629514240L);
            this._dropout.fillBytes(dropoutSeed);
            super.fprop(dropoutSeed, true);
        }
    }

    /**
     * Hidden layer with hyperbolic-tangent activation, evaluated as
     * {@code 1 - 2 / (1 + exp(2x))}.
     */
    public static class Tanh
    extends Neurons {
        public Tanh(int units) {
            super(units);
        }

        /**
         * Forward pass: fills {@code _a} via {@code gemv} (passing the dropout
         * bit mask when a dropout helper is attached), applies the tanh
         * formula to every entry and then updates the sparsity bookkeeping.
         *
         * @param seed     unused by this layer
         * @param training unused by this layer
         */
        @Override
        protected void fprop(long seed, boolean training) {
            byte[] mask = this._dropout != null ? this._dropout.bits() : null;
            Tanh.gemv((Storage.DenseVector)this._a, this._w, this._previous._a, this._b, mask);
            final int rowCount = this._a.size();
            for (int row = 0; row < rowCount; ++row) {
                float x = this._a.get(row);
                this._a.set(row, 1.0f - 2.0f / (1.0f + (float)Math.exp(2.0f * x)));
            }
            this.compute_sparsity();
        }

        /**
         * Backward pass. With a dense row-major weight matrix each row's
         * gradient is the incoming error times {@code 1 - a*a}; any other
         * weight storage is delegated to {@code bprop_sparse}. When this is
         * the last layer of an autoencoder (index equals the number of hidden
         * layers) the error is first replaced by the autoencoder gradient.
         */
        @Override
        protected void bprop() {
            final float m = this.momentum();
            final float r;
            // With ADADELTA enabled the explicit learning rate is forced to zero.
            if (this._minfo.adaDelta()) {
                r = 0.0f;
            } else {
                r = this.rate(this._minfo.get_processed_total()) * (1.0f - m);
            }
            if (!(this._w instanceof Storage.DenseRowMatrix)) {
                this.bprop_sparse(r, m);
                return;
            }
            final int rowCount = this._a.size();
            boolean autoencoderOutput = this._minfo.get_params()._autoencoder && this._index == this._minfo.get_params()._hidden.length;
            Distribution dist = autoencoderOutput ? new Distribution(this.params._distribution) : null;
            for (int row = 0; row < rowCount; ++row) {
                if (dist != null) {
                    this._e.set(row, this.autoEncoderGradient(dist, row));
                }
                float a = this._a.get(row);
                float g = this._e.get(row) * (1.0f - a * a);
                this.bprop(row, g, r, m);
            }
        }
    }

    /**
     * Input layer. It holds no weights; its activations are filled directly
     * from a data row through one of the {@code setInput} overloads.
     */
    public static class Input
    extends Neurons {
        private DataInfo _dinfo;
        Storage.SparseVector _svec;
        Storage.DenseVector _dvec;

        /**
         * @param units number of input units; sizes the dense activation vector
         * @param d     layout information for categorical and numeric columns
         */
        Input(int units, DataInfo d) {
            super(units);
            this._dinfo = d;
            this._a = new Storage.DenseVector(units);
            this._dvec = (Storage.DenseVector)this._a;
        }

        /**
         * Not supported on the input layer.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        protected void bprop() {
            throw new UnsupportedOperationException();
        }

        /**
         * Not supported on the input layer; use {@code setInput} instead.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        protected void fprop(long seed, boolean training) {
            throw new UnsupportedOperationException();
        }

        /**
         * Splits a raw data row into active categorical indices and
         * (optionally standardized) numeric values, then forwards them to
         * {@link #setInput(long, double[], int, int[])}.
         *
         * @param seed seed forwarded to the input dropout
         * @param data the row: {@code _dinfo._cats} categorical values first,
         *             then the numeric values, optionally followed by a weight
         *             and/or offset entry that is ignored here
         */
        public void setInput(long seed, double[] data) {
            assert (this._dinfo != null);
            double[] nums = MemoryManager.malloc8d((int)this._dinfo._nums);
            int[] cats = MemoryManager.malloc4((int)this._dinfo._cats);
            int ncats = 0;
            int i;
            for (i = 0; i < this._dinfo._cats; ++i) {
                assert (this._dinfo._catMissing[i] != 0);
                // The last slot of each categorical's index range is used for missing values.
                int missingLevel = this._dinfo._catOffsets[i + 1] - 1;
                int index;
                if (Double.isNaN(data[i])) {
                    index = missingLevel;
                } else {
                    int c = (int)data[i];
                    if (this._dinfo._useAllFactorLevels) {
                        index = c + this._dinfo._catOffsets[i];
                    } else if (c != 0) {
                        index = c + this._dinfo._catOffsets[i] - 1;
                    } else {
                        // Level 0 has no slot of its own when not all factor levels are
                        // used, so it must not contribute an active index. Counting it
                        // anyway would leave a zero in cats[] and switch on input unit 0.
                        continue;
                    }
                    // An index past this categorical's range is mapped to the missing-value slot.
                    if (index >= this._dinfo._catOffsets[i + 1]) {
                        index = missingLevel;
                    }
                }
                cats[ncats++] = index;
            }
            // Trailing weight/offset entries are not inputs.
            int n = data.length - (this._dinfo._weights ? 1 : 0) - (this._dinfo._offset ? 1 : 0);
            for (; i < n; ++i) {
                double d = data[i];
                if (this._dinfo._normMul != null) {
                    d = (d - this._dinfo._normSub[i - this._dinfo._cats]) * this._dinfo._normMul[i - this._dinfo._cats];
                }
                nums[i - this._dinfo._cats] = d;
            }
            this.setInput(seed, nums, ncats, cats);
        }

        /**
         * Fills the activation vector from already-separated numeric values
         * and active categorical indices, then applies input dropout if a
         * dropout helper is attached.
         *
         * @param seed   seed for the input dropout (combined with {@code params._seed})
         * @param nums   numeric values; {@code NaN} entries are written as 0
         * @param numcat number of valid entries at the front of {@code cats}
         * @param cats   indices of the active categorical inputs
         */
        public void setInput(long seed, double[] nums, int numcat, int[] cats) {
            this._a = this._dvec;
            Arrays.fill(this._a.raw(), 0.0f);
            if (this.params._max_categorical_features < this._dinfo.fullN() - this._dinfo._nums) {
                // Fewer categorical slots than categorical inputs: hash every
                // categorical index into one of cM buckets and count the hits.
                assert (nums.length == this._dinfo._nums);
                int M = nums.length + this.params._max_categorical_features;
                int cM = this.params._max_categorical_features;
                assert (this._a.size() == M);
                MurmurHash murmur = MurmurHash.getInstance();
                for (int i = 0; i < numcat; ++i) {
                    ByteBuffer buf = ByteBuffer.allocate(4);
                    int hashval = murmur.hash(buf.putInt(cats[i]).array(), 4, (int)this.params._seed);
                    this._a.add(Math.abs(hashval % cM), 1.0f);
                }
                for (int i = 0; i < nums.length; ++i) {
                    this._a.set(cM + i, Double.isNaN(nums[i]) ? 0.0f : (float)nums[i]);
                }
            } else {
                // One slot per categorical index, numeric values after numStart().
                assert (this._a.size() == this._dinfo.fullN());
                for (int i = 0; i < numcat; ++i) {
                    this._a.set(cats[i], 1.0f);
                }
                for (int i = 0; i < nums.length; ++i) {
                    this._a.set(this._dinfo.numStart() + i, Double.isNaN(nums[i]) ? 0.0f : (float)nums[i]);
                }
            }
            if (this._dropout == null) {
                return;
            }
            this._dropout.randomlySparsifyActivation(this._a, seed += this.params._seed + 322417854L);
            if (this.params._sparse) {
                // Switch the layer's activations to a sparse copy of the dense vector.
                this._svec = new Storage.SparseVector(this._dvec);
                this._a = this._svec;
            }
        }
    }
}

