/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.tree.DHistogram;
import hex.tree.DTree;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.AtomicUtils;
import water.util.IcedBitSet;
import water.util.MathUtils;

/**
 * Histogram over a real-valued response.
 *
 * <p>Besides the per-bin totals inherited from {@link DHistogram} in {@code _bins} (presumably
 * row counts or weights), every bin accumulates the weighted sum of responses and the weighted
 * sum of squared responses. Those two statistics are enough to recover per-bin means and
 * variances and to score candidate splits by squared error.
 */
public class DRealHistogram
extends DHistogram<DRealHistogram> {
    /** Per-bin accumulated {@code w * y}. */
    private double[] _sums;
    /** Per-bin accumulated {@code w * y * y}. */
    private double[] _ssqs;

    public DRealHistogram(String name, int nbins, int nbins_cats, byte isInt, float min, float maxEx) {
        super(name, nbins, nbins_cats, isInt, min, maxEx);
    }

    /**
     * Returns the mean response of bin {@code b}, or 0 when {@code _bins[b]} is not positive.
     */
    @Override
    public double mean(int b) {
        double n = _bins[b];
        return n > 0.0 ? _sums[b] / n : 0.0;
    }

    /**
     * Returns the sample variance (denominator {@code n - 1}) of the response in bin {@code b},
     * or 0 when {@code _bins[b] <= 1}.
     *
     * <p>The one-pass formula {@code (ssq - sum*sum/n) / (n - 1)} can come out slightly negative
     * through floating-point cancellation when the values are (nearly) constant. Such results are
     * clamped to 0, the same way {@link #scoreMSE} clamps its squared-error terms.
     */
    @Override
    public double var(int b) {
        double n = _bins[b];
        if (n <= 1.0) {
            return 0.0;
        }
        double v = (_ssqs[b] - _sums[b] * _sums[b] / n) / (n - 1.0);
        return v < 0.0 ? 0.0 : v;
    }

    /** Allocates the per-bin sum and sum-of-squares arrays, one slot per bin. */
    @Override
    void init0() {
        _sums = MemoryManager.malloc8d((int) _nbin);
        _ssqs = MemoryManager.malloc8d((int) _nbin);
    }

    /**
     * Adds one weighted observation to bin {@code b}.
     *
     * <p>Adds {@code w*y} to the bin's sum and {@code w*y*y} to its sum of squares. Updates go
     * through {@code AtomicUtils.DoubleArray.add}, presumably so that concurrent callers may
     * share this histogram (confirm against AtomicUtils).
     */
    @Override
    void incr0(int b, double y, double w) {
        AtomicUtils.DoubleArray.add(_sums, b, w * y);
        AtomicUtils.DoubleArray.add(_ssqs, b, w * y * y);
    }

    /**
     * Adds already-aggregated statistics to bin {@code b}.
     *
     * <p>{@code y} is added to the bin's sum and {@code yy} to its sum of squares. No weighting
     * is applied here.
     */
    void incr1(int b, double y, double yy) {
        AtomicUtils.DoubleArray.add(_sums, b, y);
        AtomicUtils.DoubleArray.add(_ssqs, b, yy);
    }

    /** Merges another histogram's per-bin sums and sums of squares into this one, element-wise. */
    @Override
    void add0(DRealHistogram dsh) {
        ArrayUtils.add(_sums, dsh._sums);
        ArrayUtils.add(_ssqs, dsh._ssqs);
    }

    /**
     * Finds the split of this column's bins that minimizes the children's total squared error.
     *
     * <p>Three kinds of split are considered:
     * <ul>
     *   <li>A threshold split "bins {@code [0, best)} | bins {@code [best, nbins)}".</li>
     *   <li>When {@code _isInt == 2}, {@code _step == 1} and there are at least 4 bins
     *       (presumably a categorical column), the bins are first re-ordered by mean response.
     *       The right-hand group is returned as a bitset of original bin indices
     *       ({@code equal} = 2 or 3).</li>
     *   <li>For integer columns with step 1, a range wider than 2 and no re-ordering, a split
     *       that isolates a single bin from all the others ({@code equal} = 1).</li>
     * </ul>
     *
     * @param col      column index recorded in the returned split
     * @param min_rows minimum accumulated {@code _bins} total required on each side of a split
     * @return the best split, or {@code null} when there is no worthwhile split. That is the case
     *     when there are too few rows, the response is constant, no admissible split exists, no
     *     error reduction is achieved, or both children would predict (nearly) the same value
     */
    @Override
    public DTree.Split scoreMSE(int col, double min_rows) {
        final int nbins = nbins();
        assert nbins > 1;

        // Per-bin statistics. For re-ordered (presumably categorical) columns these are replaced
        // by copies sorted by mean response, and idxs maps sorted position -> original bin.
        double[] sums = _sums;
        double[] ssqs = _ssqs;
        double[] bins = _bins;
        int[] idxs = null;
        if (_isInt == 2 && _step == 1.0f && nbins >= 4) {
            idxs = MemoryManager.malloc4(nbins + 1);
            for (int i = 0; i < nbins + 1; i++) {
                idxs[i] = i;
            }
            double[] avgs = MemoryManager.malloc8d(nbins + 1);
            for (int i = 0; i < nbins; i++) {
                avgs[i] = _bins[i] == 0.0 ? 0.0 : _sums[i] / _bins[i];
            }
            // Sentinel: presumably keeps the extra slot last once idxs is ordered by avgs.
            avgs[nbins] = Double.MAX_VALUE;
            ArrayUtils.sort(idxs, avgs);
            sums = MemoryManager.malloc8d(nbins);
            ssqs = MemoryManager.malloc8d(nbins);
            bins = MemoryManager.malloc8d(nbins);
            for (int i = 0; i < nbins; i++) {
                int src = idxs[i];
                sums[i] = _sums[src];
                ssqs[i] = _ssqs[src];
                bins[i] = _bins[src];
            }
        }

        // Prefix accumulations: slot b covers bins [0, b).
        double[] sums0 = MemoryManager.malloc8d(nbins + 1);
        double[] ssqs0 = MemoryManager.malloc8d(nbins + 1);
        double[] ns0 = MemoryManager.malloc8d(nbins + 1);
        for (int b = 1; b <= nbins; b++) {
            double kPrev = ns0[b - 1];
            double kBin = bins[b - 1];
            if (kPrev == 0.0 && kBin == 0.0) {
                continue; // Nothing accumulated yet; the slot stays zero.
            }
            sums0[b] = sums0[b - 1] + sums[b - 1];
            ssqs0[b] = ssqs0[b - 1] + ssqs[b - 1];
            ns0[b] = kPrev + kBin;
        }

        double tot = ns0[nbins];
        if (tot < 2.0 * min_rows) {
            return null; // Cannot give both children min_rows.
        }
        // n * sum(y^2) - (sum y)^2: zero (or float-zero) means a constant response.
        double var = ssqs0[nbins] * tot - sums0[nbins] * sums0[nbins];
        if (var == 0.0) {
            assert isConstantResponse();
            return null;
        }
        if ((float) var == 0.0f) {
            return null;
        }

        // Suffix accumulations: slot b covers bins [b, nbins).
        double[] sums1 = MemoryManager.malloc8d(nbins + 1);
        double[] ssqs1 = MemoryManager.malloc8d(nbins + 1);
        double[] ns1 = MemoryManager.malloc8d(nbins + 1);
        for (int b = nbins - 1; b >= 0; b--) {
            double kNext = ns1[b + 1];
            double kBin = bins[b];
            if (kNext == 0.0 && kBin == 0.0) {
                continue;
            }
            sums1[b] = sums1[b + 1] + sums[b];
            ssqs1[b] = ssqs1[b + 1] + ssqs[b];
            ns1[b] = kNext + kBin;
            assert MathUtils.compare(ns0[b] + ns1[b], tot, 1.0E-5, 1.0E-5);
        }

        // Threshold splits: left = [0, b), right = [b, nbins).
        int best = 0;
        double best_se0 = Double.MAX_VALUE;
        double best_se1 = Double.MAX_VALUE;
        byte equal = 0;
        final int mid = nbins >> 1;
        for (int b = 1; b <= nbins - 1; b++) {
            if (bins[b] == 0.0 || ns0[b] < min_rows) {
                continue; // Empty split point, or left side too small.
            }
            if (ns1[b] < min_rows) {
                break; // ns1 only shrinks as b grows, so no later split can qualify.
            }
            double se0 = ssqs0[b] - sums0[b] * sums0[b] / ns0[b];
            double se1 = ssqs1[b] - sums1[b] * sums1[b] / ns1[b];
            if (se0 < 0.0) {
                se0 = 0.0; // Guard against floating-point cancellation.
            }
            if (se1 < 0.0) {
                se1 = 0.0;
            }
            double err = se0 + se1;
            double bestErr = best_se0 + best_se1;
            // Strictly lower error wins; on a tie, prefer the split closer to the middle bin.
            boolean better = err < bestErr
                    || (err == bestErr && Math.abs(b - mid) < Math.abs(best - mid));
            if (!better) {
                continue;
            }
            best_se0 = se0;
            best_se1 = se1;
            best = b;
        }

        // Single-bin ("equal") splits for integer columns: bin b versus everything else.
        if (_isInt > 0 && _step == 1.0f && _maxEx - _min > 2.0f && idxs == null) {
            for (int b = 1; b <= nbins - 1; b++) {
                if (bins[b] < min_rows) {
                    continue;
                }
                double nRest = ns0[b] + ns1[b + 1];
                if (nRest < min_rows) {
                    continue;
                }
                double sumRest = sums0[b] + sums1[b + 1];
                double ssqRest = ssqs0[b] + ssqs1[b + 1];
                double si = ssqRest - sumRest * sumRest / nRest;
                double sx = ssqs[b] - sums[b] * sums[b] / bins[b];
                if (si < 0.0) {
                    si = 0.0;
                }
                if (sx < 0.0) {
                    sx = 0.0;
                }
                if (!(si + sx < best_se0 + best_se1)) {
                    continue;
                }
                best_se0 = si;
                best_se1 = sx;
                best = b;
                equal = 1;
            }
        }

        if (best == 0) {
            return null; // No admissible split point; skip the bitset work below.
        }

        // For re-ordered bins, express the right-hand group [best, nbins) as a bitset over the
        // original bin indices, spanning only [min, max] of the selected indices.
        IcedBitSet bs = null;
        if (idxs != null) {
            int min = Integer.MAX_VALUE;
            int max = Integer.MIN_VALUE;
            for (int i = best; i < nbins; i++) {
                min = Math.min(min, idxs[i]);
                max = Math.max(max, idxs[i]);
            }
            bs = new IcedBitSet(max - min + 1, min);
            for (int i = best; i < nbins; i++) {
                bs.set(idxs[i]);
            }
            equal = (byte) (bs.max() <= 32 ? 2 : 3);
        }

        // Squared error of the unsplit node; the split must strictly improve on it.
        double se = ssqs1[0] - sums1[0] * sums1[0] / ns1[0];
        if (se <= best_se0 + best_se1) {
            return null;
        }

        boolean singleBin = equal == 1;
        double n0 = singleBin ? ns0[best] + ns1[best + 1] : ns0[best];
        double n1 = singleBin ? bins[best] : ns1[best];
        double p0 = singleBin ? sums0[best] + sums1[best + 1] : sums0[best];
        double p1 = singleBin ? sums[best] : sums1[best];
        // Children predicting (nearly) the same value make the split useless.
        if (MathUtils.equalsWithinOneSmallUlp((float) (p0 / n0), (float) (p1 / n1))) {
            return null;
        }
        return new DTree.Split(col, best, bs, equal, se, best_se0, best_se1, n0, n1, p0 / n0, p1 / n1);
    }

    /**
     * Returns an approximate heap footprint of the state added by this subclass.
     *
     * <p>Counts 8 bytes for each of the two array references, plus a 24-byte header and 8 bytes
     * per element for each allocated {@code double[]}.
     */
    @Override
    public long byteSize0() {
        // Written without shifts: the former "40 + a.length << 27 + b.length << 3" parsed as a
        // chain of shifts because '+' binds tighter than '<<', yielding a meaningless value.
        long bytes = 8L * 2;
        if (_sums != null) {
            bytes += 24L + 8L * _sums.length;
        }
        if (_ssqs != null) {
            bytes += 24L + 8L * _ssqs.length;
        }
        return bytes;
    }
}

