/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.tree.DHistogram;
import hex.tree.DTree;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.AtomicUtils;
import water.util.IcedBitSet;

/**
 * Histogram whose per-bin response accumulator is an integer count.
 *
 * <p>Alongside the per-bin row counts kept by the superclass in {@code _bins}, this class keeps
 * {@link #_sums}, which is bumped by one on every {@link #incr0(int, double)} call. With a 0/1
 * response, {@code _sums[b] / _bins[b]} is then the fraction of positive rows in bin {@code b}.
 * NOTE(review): this presumes the superclass only calls {@code incr0} for non-zero responses -
 * confirm against {@code DHistogram}.
 */
public class DBinomHistogram
extends DHistogram<DBinomHistogram> {
    /** Per-bin integer response sums; allocated by {@link #init0()}. */
    public int[] _sums;

    /** Forwards all arguments unchanged to the {@code DHistogram} constructor. */
    public DBinomHistogram(String name, int nbins, byte isInt, float min, float maxEx, long nelems) {
        super(name, nbins, isInt, min, maxEx, nelems);
    }

    /** @return always {@code true} for this histogram type */
    @Override
    boolean isBinom() {
        return true;
    }

    /**
     * Mean response of bin {@code b}.
     *
     * @param b bin index
     * @return {@code _sums[b] / _bins[b]}, or {@code 0.0} for an empty bin
     */
    @Override
    public double mean(int b) {
        int n = this._bins[b];
        return n > 0 ? (double)this._sums[b] / (double)n : 0.0;
    }

    /**
     * Sample variance of the response in bin {@code b}, computed as
     * {@code (s - s*s/n) / (n - 1)} where {@code s} is the bin sum and {@code n} the bin count.
     *
     * @param b bin index
     * @return the variance, or {@code 0.0} when the bin holds at most one row
     */
    @Override
    public double var(int b) {
        int n = this._bins[b];
        if (n <= 1) {
            return 0.0;
        }
        double s = this._sums[b];
        return (s - s * s / (double)n) / (double)(n - 1);
    }

    /** Allocates {@link #_sums} with one slot per bin. */
    @Override
    void init0() {
        this._sums = MemoryManager.malloc4((int)this._nbin);
    }

    /**
     * Atomically increments the sum of bin {@code b} by one.
     *
     * @param b bin index
     * @param y response value; not read here (see the class-level review note)
     */
    @Override
    void incr0(int b, double y) {
        AtomicUtils.IntArray.incr(this._sums, b);
    }

    /**
     * Adds the sums of another histogram into this one via {@code ArrayUtils.add}.
     *
     * @param dsh histogram whose {@code _sums} are added into this one
     */
    @Override
    void add0(DBinomHistogram dsh) {
        ArrayUtils.add(this._sums, dsh._sums);
    }

    /**
     * Searches the bins of this histogram for the split with the smallest combined squared error.
     *
     * <p>The squared error of a group with sum {@code s} over {@code n} rows is computed as
     * {@code s * (1 - s/n)}. Three kinds of split are considered:
     * <ul>
     *   <li>a threshold split "bins below {@code b}" versus "bins {@code b} and above";</li>
     *   <li>an equality split "bin {@code b}" versus "every other bin", when {@code _isInt > 0},
     *       the step is 1 and the range exceeds 2;</li>
     *   <li>a group split over bins reordered by their mean response, when {@code _isInt == 2},
     *       the step is 1 and there are at least 4 bins (presumably a categorical column -
     *       confirm against {@code DHistogram}). The chosen bins are returned as a bitset.</li>
     * </ul>
     *
     * @param col      column index, passed through to the returned split
     * @param min_rows minimum number of rows required on each side of a split
     * @return the best split, or {@code null} when there are too few rows, the response is
     *         constant, or no candidate improves on the unsplit squared error
     */
    @Override
    public DTree.Split scoreMSE(int col, int min_rows) {
        int nbins = this.nbins();
        assert (nbins > 1);
        int[] sums = this._sums;
        int[] bins = this._bins;
        int[] idxs = null;
        if (this._isInt == 2 && this._step == 1.0f && nbins >= 4) {
            // Reorder the bins by mean response so that a prefix/suffix split over the
            // reordered bins corresponds to an arbitrary subset of the original bins.
            idxs = MemoryManager.malloc4(nbins + 1);
            for (int i = 0; i < nbins + 1; ++i) {
                idxs[i] = i;
            }
            double[] avgs = MemoryManager.malloc8d(nbins + 1);
            for (int i = 0; i < nbins; ++i) {
                avgs[i] = this._bins[i] == 0 ? 0.0 : (double)this._sums[i] / (double)this._bins[i];
            }
            // Sentinel key for the extra trailing slot.
            avgs[nbins] = Double.MAX_VALUE;
            ArrayUtils.sort(idxs, avgs);
            sums = MemoryManager.malloc4(nbins);
            bins = MemoryManager.malloc4(nbins);
            for (int i = 0; i < nbins; ++i) {
                sums[i] = this._sums[idxs[i]];
                bins[i] = this._bins[idxs[i]];
            }
        }

        // Prefix totals: sums0[b] / ns0[b] cover bins [0, b).
        double[] sums0 = MemoryManager.malloc8d(nbins + 1);
        long[] ns0 = MemoryManager.malloc8(nbins + 1);
        for (int b = 1; b <= nbins; ++b) {
            double m0 = sums0[b - 1];
            double m1 = sums[b - 1];
            long k0 = ns0[b - 1];
            long k1 = bins[b - 1];
            if (k0 == 0L && k1 == 0L) continue;
            sums0[b] = m0 + m1;
            ns0[b] = k0 + k1;
        }
        long tot = ns0[nbins];
        // Widen before multiplying so a very large min_rows cannot overflow int.
        if (tot < 2L * min_rows) {
            return null;
        }
        double var = sums0[nbins] * ((double)tot - sums0[nbins]);
        if (var == 0.0) {
            assert (this.isConstantResponse());
            return null;
        }
        if ((float)var == 0.0f) {
            return null;
        }

        // Suffix totals: sums1[b] / ns1[b] cover bins [b, nbins).
        double[] sums1 = MemoryManager.malloc8d(nbins + 1);
        long[] ns1 = MemoryManager.malloc8(nbins + 1);
        for (int b = nbins - 1; b >= 0; --b) {
            double m0 = sums1[b + 1];
            double m1 = sums[b];
            long k0 = ns1[b + 1];
            long k1 = bins[b];
            if (k0 == 0L && k1 == 0L) continue;
            sums1[b] = m0 + m1;
            ns1[b] = k0 + k1;
            assert (ns0[b] + ns1[b] == tot);
        }

        // Threshold splits: [0, b) versus [b, nbins).
        int best = 0;
        double best_se0 = Double.MAX_VALUE;
        double best_se1 = Double.MAX_VALUE;
        byte equal = 0;
        int mid = nbins >> 1;
        for (int b = 1; b <= nbins - 1; ++b) {
            if (bins[b] == 0 || ns0[b] < (long)min_rows) continue;
            // The right side only shrinks as b grows, so no later b can qualify.
            if (ns1[b] < (long)min_rows) break;
            double se0 = sums0[b] * (1.0 - sums0[b] / (double)ns0[b]);
            double se1 = sums1[b] * (1.0 - sums1[b] / (double)ns1[b]);
            boolean better = se0 + se1 < best_se0 + best_se1;
            // On an exact tie prefer the split nearer the middle bin.
            boolean tieCloserToMiddle = se0 + se1 == best_se0 + best_se1
                    && Math.abs(b - mid) < Math.abs(best - mid);
            if (!better && !tieCloserToMiddle) continue;
            best_se0 = se0;
            best_se1 = se1;
            best = b;
        }

        // Equality splits: bin b alone versus every other bin.
        if (this._isInt > 0 && this._step == 1.0f && this._maxEx - this._min > 2.0f && idxs == null) {
            for (int b = 1; b <= nbins - 1; ++b) {
                if (bins[b] < min_rows) continue;
                long N = ns0[b] + ns1[b + 1];
                if (N < (long)min_rows) continue;
                double sums2 = sums0[b] + sums1[b + 1];
                double si = sums2 * (1.0 - sums2 / (double)N);
                // Computed in double: the int form sums[b]*sums[b]/bins[b] overflows once a
                // bin sum exceeds 46340 and truncates the quotient.
                double s = sums[b];
                double sx = s * (1.0 - s / (double)bins[b]);
                if (!(si + sx < best_se0 + best_se1)) continue;
                best_se0 = si;
                best_se1 = sx;
                best = b;
                equal = 1;
            }
        }

        // Group split: record which original bins fall at or after 'best' in the reordered view.
        IcedBitSet bs = null;
        if (idxs != null) {
            int min = Integer.MAX_VALUE;
            int max = Integer.MIN_VALUE;
            for (int i = best; i < nbins; ++i) {
                min = Math.min(min, idxs[i]);
                max = Math.max(max, idxs[i]);
            }
            bs = new IcedBitSet(max - min + 1, min);
            for (int i = best; i < nbins; ++i) {
                bs.set(idxs[i]);
            }
            // NOTE(review): codes 2 and 3 presumably select a small vs. large bitset
            // encoding in DTree.Split - confirm there.
            equal = (byte)(bs.max() <= 32 ? 2 : 3);
        }
        if (best == 0) {
            return null;
        }
        // Squared error without splitting; reject splits that do not improve on it.
        double se = sums1[0] * (1.0 - sums1[0] / (double)ns1[0]);
        if (se <= best_se0 + best_se1) {
            return null;
        }
        long n0 = equal != 1 ? ns0[best] : ns0[best] + ns1[best + 1];
        long n1 = equal != 1 ? ns1[best] : (long)bins[best];
        double p0 = equal != 1 ? sums0[best] : sums0[best] + sums1[best + 1];
        double p1 = equal != 1 ? sums1[best] : (double)sums[best];
        return new DTree.Split(col, best, bs, equal, se, best_se0, best_se1, n0, n1, p0 / (double)n0, p1 / (double)n1);
    }

    /**
     * Estimated size in bytes of the state added by this subclass.
     *
     * <p>The 32-byte fixed overhead is the constant from the original expression; each element
     * of the {@code int[]} adds 4 bytes. The shift is parenthesised because {@code +} binds
     * tighter than {@code <<}; unparenthesised, the fixed overhead was shifted as well.
     *
     * @return {@code 32 + 4 * _sums.length}
     */
    @Override
    public long byteSize0() {
        return 32L + ((long)this._sums.length << 2);
    }
}

