/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.tree.DHistogram;
import hex.tree.DTree;
import java.util.Arrays;
import java.util.Comparator;
import water.H2O;
import water.MemoryManager;
import water.util.ArrayUtils;
import water.util.AtomicUtils;
import water.util.IcedBitSet;

/**
 * Histogram specialization that keeps one {@code long} response sum per bin
 * in {@code _sums}, next to the per-bin row counts ({@code _bins}) maintained
 * by {@link DHistogram}.
 *
 * <p>Every recorded response adds exactly one to its bin's sum (see
 * {@link #incr0(int, double)}), so the class is presumably meant for 0/1
 * (binomial) responses where only the non-zero rows are recorded.
 * NOTE(review): confirm against the caller in {@code DHistogram}.
 */
public class DBinomHistogram
extends DHistogram<DBinomHistogram> {
    /** Per-bin response sums; allocated by {@link #init0()}. */
    private long[] _sums;

    /**
     * Forwards all arguments unchanged to the {@link DHistogram} constructor.
     */
    public DBinomHistogram(String name, int nbins, byte isInt, float min, float maxEx, long nelems, boolean doGrpSplit) {
        super(name, nbins, isInt, min, maxEx, nelems, doGrpSplit);
    }

    /** @return always {@code true} */
    @Override
    boolean isBinom() {
        return true;
    }

    /**
     * Mean response of bin {@code b}.
     *
     * @param b bin index
     * @return {@code _sums[b] / _bins[b]}, or {@code 0.0} for an empty bin
     */
    @Override
    public double mean(int b) {
        long n = this._bins[b];
        return n > 0L ? (double)this._sums[b] / (double)n : 0.0;
    }

    /**
     * Sample variance of the response in bin {@code b}, computed as
     * {@code (S - S*S/n) / (n - 1)} where {@code S} is the bin's sum and
     * {@code n} its row count. The sum itself stands in for the sum of
     * squares, which is exact when every response is 0 or 1.
     *
     * @param b bin index
     * @return the variance, or {@code 0.0} when the bin holds at most one row
     */
    @Override
    public double var(int b) {
        long n = this._bins[b];
        if (n <= 1L) {
            return 0.0;
        }
        double sum = (double)this._sums[b];
        return (sum - sum * sum / (double)n) / (double)(n - 1L);
    }

    /** Allocates the per-bin sum array with {@code _nbin} entries. */
    @Override
    void init0() {
        this._sums = MemoryManager.malloc8((int)this._nbin);
    }

    /**
     * Atomically adds one to the sum of bin {@code b}.
     *
     * @param b bin index
     * @param y response value; not read here.
     *          NOTE(review): presumably the caller only invokes this for
     *          non-zero responses - confirm.
     */
    @Override
    void incr0(int b, double y) {
        AtomicUtils.LongArray.incr(this._sums, b);
    }

    /**
     * Adds the per-bin sums of {@code dsh} into this histogram's sums.
     *
     * @param dsh histogram whose sums are merged in
     */
    @Override
    void add0(DBinomHistogram dsh) {
        ArrayUtils.add(this._sums, dsh._sums);
    }

    /**
     * Searches the bins of this histogram for the split with the smallest
     * total squared error and packages it as a {@link DTree.Split}.
     *
     * <p>Two candidate families are scored: cutting the (possibly reordered)
     * bin sequence into a left and a right part, and - for some integer
     * columns - isolating a single bin from all the others.
     *
     * @param col column index handed through to the resulting split
     * @return the best split, or {@code null} when the response sum is zero
     *         or equals the row count, or when no candidate was accepted
     */
    @Override
    public DTree.Split scoreMSE(int col) {
        final int nbins = this.nbins();
        assert (nbins > 1);

        // idx maps a scan position to a bin index; identity unless reordered below.
        Integer[] idx = new Integer[nbins];
        for (int i = 0; i < nbins; ++i) {
            idx[i] = i;
        }

        // Conditions under which bins are ordered by mean response and a
        // bitset (group) split is attempted at the end of this method.
        final boolean grpSplit = this._isInt == 2 && this._step == 1.0f && nbins >= 4 && this._doGrpSplit;
        if (grpSplit) {
            final double[] means = new double[nbins];
            for (int i = 0; i < nbins; ++i) {
                means[i] = this.mean(i);
            }
            Arrays.sort(idx, new Comparator<Integer>(){

                @Override
                public int compare(Integer o1, Integer o2) {
                    return Double.compare(means[o1], means[o2]);
                }
            });
        }

        // Prefix totals: sums0[p] / ns0[p] cover scan positions 0 .. p-1.
        long[] sums0 = MemoryManager.malloc8(nbins + 1);
        long[] ns0 = MemoryManager.malloc8(nbins + 1);
        for (int p = 1; p <= nbins; ++p) {
            long prevSum = sums0[p - 1];
            long binSum = this._sums[idx[p - 1]];
            long prevCnt = ns0[p - 1];
            long binCnt = this._bins[idx[p - 1]];
            if (prevCnt == 0L && binCnt == 0L) continue; // nothing seen yet; entries stay 0
            sums0[p] = prevSum + binSum;
            ns0[p] = prevCnt + binCnt;
        }
        long tot = ns0[nbins];

        // All recorded responses are zero, or every row contributed to the sum:
        // there is nothing to separate.
        if (sums0[nbins] == 0L || sums0[nbins] == tot) {
            assert (this.isConstantResponse());
            return null;
        }

        // Suffix totals: sums1[p] / ns1[p] cover scan positions p .. nbins-1.
        long[] sums1 = MemoryManager.malloc8(nbins + 1);
        long[] ns1 = MemoryManager.malloc8(nbins + 1);
        for (int p = nbins - 1; p >= 0; --p) {
            long nextSum = sums1[p + 1];
            long binSum = this._sums[idx[p]];
            long nextCnt = ns1[p + 1];
            long binCnt = this._bins[idx[p]];
            if (nextCnt == 0L && binCnt == 0L) continue;
            sums1[p] = nextSum + binSum;
            ns1[p] = nextCnt + binCnt;
            assert (ns0[p] + ns1[p] == tot);
        }

        int best = 0; // 0 means "no split chosen"
        double best_se0 = Double.MAX_VALUE;
        double best_se1 = Double.MAX_VALUE;
        int equal = 0; // split kind handed to DTree.Split; 0 = left/right cut
        final int mid = nbins >> 1;

        // Left/right cuts: positions [0, b) against positions [b, nbins).
        for (int b = 1; b <= nbins - 1; ++b) {
            if (this._bins[idx[b]] == 0) continue;
            // S - S*S/n per side; for a 0/1 response this is the side's squared error.
            double se0 = sums0[b];
            se0 -= se0 * se0 / (double)ns0[b];
            double se1 = sums1[b];
            se1 -= se1 * se1 / (double)ns1[b];
            double err = se0 + se1;
            double bestErr = best_se0 + best_se1;
            // Accept a strictly smaller error, or an equal error located closer
            // to the middle position. A NaN error (empty side) fails both tests.
            boolean better = err < bestErr;
            boolean tieNearerMiddle = err == bestErr && Math.abs(b - mid) < Math.abs(best - mid);
            if (better || tieNearerMiddle) {
                best_se0 = se0;
                best_se1 = se1;
                best = b;
            }
        }

        // Single-bin candidates: the bin at position b against every other bin.
        if (this._isInt > 0 && this._step == 1.0f && this._maxEx - this._min > 2.0f) {
            for (int b = 1; b <= nbins - 1; ++b) {
                if (this._bins[idx[b]] == 0) continue;
                long N = ns0[b] + ns1[b + 1]; // rows in all positions except b
                if (N == 0L) continue;
                double sums = (double)(sums0[b] + sums1[b + 1]);
                double si = sums - sums * sums / (double)N;
                double sumb = (double)this._sums[idx[b]];
                double sx = sumb - sumb * sumb / (double)this._bins[idx[b]];
                if (si + sx < best_se0 + best_se1) {
                    best_se0 = si;
                    best_se1 = sx;
                    best = b;
                    equal = 1; // mark as a single-bin split
                }
            }
        }

        if (best == 0) {
            return null;
        }
        assert (best > 0) : "Must actually pick a split " + best;

        // Row counts and response sums of the two sides of the chosen split.
        long n0 = equal == 0 ? ns0[best] : ns0[best] + ns1[best + 1];
        long n1 = equal == 0 ? ns1[best] : (long)this._bins[idx[best]];
        double p0 = equal == 0 ? (double)sums0[best] : (double)(sums0[best] + sums1[best + 1]);
        double p1 = equal == 0 ? (double)sums1[best] : (double)this._sums[idx[best]];

        IcedBitSet bs = null;
        if (grpSplit) {
            // Group splits are not implemented here: both variants build their
            // bitset and then throw, so this branch never reaches the return.
            int offset = (int)this._min;
            if (this._maxEx <= 32.0f) {
                equal = 2;
                bs = new IcedBitSet(32);
                for (int i = best; i < nbins; ++i) {
                    bs.set(idx[i] + offset);
                }
                throw H2O.unimpl();
            }
            equal = 3;
            bs = new IcedBitSet(nbins, offset);
            for (int i = best; i < nbins; ++i) {
                bs.set(idx[i].intValue());
            }
            throw H2O.unimpl();
        }
        return new DTree.Split(col, best, bs, (byte)equal, best_se0, best_se1, n0, n1, p0 / (double)n0, p1 / (double)n1);
    }

    /**
     * Size contribution of this subclass: a fixed 32 bytes plus 8 bytes per
     * entry of {@code _sums}.
     *
     * <p>The shift is parenthesized on purpose: {@code +} binds tighter than
     * {@code <<}, so {@code 32 + length << 3} would shift the whole sum and
     * report 256 bytes of fixed overhead instead of 32. The arithmetic is
     * done in {@code long} to match the return type.
     *
     * @return estimated size in bytes
     */
    @Override
    public long byteSize0() {
        // NOTE(review): 32 presumably covers the array reference and header - confirm.
        return 32L + ((long)this._sums.length << 3);
    }
}

