/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.dt.binning;

import hex.tree.dt.CategoricalSplittingRule;
import hex.tree.dt.NumericSplittingRule;
import hex.tree.dt.binning.AbstractBin;
import hex.tree.dt.binning.CategoricalBin;
import hex.tree.dt.binning.NumericBin;
import hex.tree.dt.binning.SplitStatistics;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class FeatureBins {
    private List<AbstractBin> _bins;
    private final boolean _isConstant;
    private int _numOfCategories;

    public FeatureBins(List<AbstractBin> bins) {
        this(bins, -1);
    }

    public FeatureBins(List<AbstractBin> bins, int numOfCategories) {
        if (bins == null) {
            this._isConstant = true;
        } else {
            this._isConstant = false;
            this._bins = bins;
            this._numOfCategories = numOfCategories;
        }
    }

    public List<SplitStatistics> calculateSplitStatisticsForNumericFeature() {
        List<SplitStatistics> statistics = this._bins.stream().map(b -> new SplitStatistics()).collect(Collectors.toList());
        SplitStatistics tmpAccumulatorLeft = new SplitStatistics();
        SplitStatistics tmpAccumulatorRight = new SplitStatistics();
        for (int leftIndex = 0; leftIndex < statistics.size(); ++leftIndex) {
            tmpAccumulatorLeft.accumulateLeftStatistics(this._bins.get((int)leftIndex)._count, this._bins.get((int)leftIndex)._count0);
            statistics.get(leftIndex).copyLeftValues(tmpAccumulatorLeft);
            statistics.get((int)leftIndex)._splittingRule = new NumericSplittingRule(((NumericBin)this._bins.get((int)leftIndex))._max);
            int rightIndex = this._bins.size() - leftIndex - 1;
            statistics.get(rightIndex).copyRightValues(tmpAccumulatorRight);
            tmpAccumulatorRight.accumulateRightStatistics(this._bins.get((int)rightIndex)._count, this._bins.get((int)rightIndex)._count0);
        }
        return statistics;
    }

    public boolean isConstant() {
        return this._isConstant;
    }

    List<AbstractBin> getFeatureBins() {
        return this._bins.stream().map(AbstractBin::clone).collect(Collectors.toList());
    }

    public List<SplitStatistics> calculateSplitStatisticsForCategoricalFeature() {
        return this.calculateStatisticsForCategoricalFeatureBinomialClassification();
    }

    private List<SplitStatistics> calculateStatisticsForCategoricalFeatureFullApproach() {
        assert (this._numOfCategories <= 10);
        String categories = this._bins.stream().map(b -> String.valueOf(((CategoricalBin)b)._category)).collect(Collectors.joining(""));
        Set<boolean[]> splits = this.findAllCategoricalSplits(categories);
        ArrayList<SplitStatistics> statistics = new ArrayList<SplitStatistics>();
        for (boolean[] splitMask : splits) {
            SplitStatistics splitStatistics = new SplitStatistics();
            for (AbstractBin bin : this._bins) {
                if (splitMask[((CategoricalBin)bin)._category]) {
                    splitStatistics.accumulateLeftStatistics(bin._count, bin._count0);
                    continue;
                }
                splitStatistics.accumulateRightStatistics(bin._count, bin._count0);
            }
            splitStatistics._splittingRule = new CategoricalSplittingRule(splitMask);
            statistics.add(splitStatistics);
        }
        return statistics;
    }

    private Set<boolean[]> findAllCategoricalSplits(String categories) {
        int recMaxDepth = categories.length() / 2;
        HashSet<boolean[]> masks = new HashSet<boolean[]>();
        for (int depth = 1; depth < recMaxDepth; ++depth) {
            String[] stringArray = categories.split("");
            int n = stringArray.length;
            for (int i = 0; i < n; ++i) {
                String s = stringArray[i];
                this.rec(masks, s, categories.substring(0).replaceAll(s, ""), depth - 1);
            }
        }
        if (categories.length() == recMaxDepth * 2) {
            this.rec(masks, categories.substring(0, 1), categories.substring(1), recMaxDepth - 1);
        } else {
            for (String s : categories.split("")) {
                this.rec(masks, s, categories.substring(0).replaceAll(s, ""), recMaxDepth - 1);
            }
        }
        return masks;
    }

    private void rec(Set<boolean[]> masks, String current, String categories, int stepsToGo) {
        if (stepsToGo == 0) {
            masks.add(this.createMaskFromString(current));
            return;
        }
        for (String s : categories.split("")) {
            if (s.charAt(0) <= current.charAt(current.length() - 1)) continue;
            this.rec(masks, current + s, categories.substring(0).replaceAll(s, ""), stepsToGo - 1);
        }
    }

    private boolean[] createMaskFromString(String categories) {
        boolean[] mask = new boolean[this._numOfCategories];
        for (String c : categories.split("")) {
            mask[Integer.parseInt((String)c)] = true;
        }
        return mask;
    }

    private boolean[] createMaskFromBins(List<CategoricalBin> bins) {
        boolean[] mask = new boolean[this._numOfCategories];
        bins.stream().map(CategoricalBin::getCategory).forEach(c -> {
            mask[c.intValue()] = true;
        });
        return mask;
    }

    public List<SplitStatistics> calculateStatisticsForCategoricalFeatureBinomialClassification() {
        List sortedBins = this._bins.stream().map(b -> (CategoricalBin)b).sorted(Comparator.comparingInt(AbstractBin::getCount0)).collect(Collectors.toList());
        List<SplitStatistics> statistics = sortedBins.stream().map(b -> new SplitStatistics()).collect(Collectors.toList());
        SplitStatistics tmpAccumulatorLeft = new SplitStatistics();
        SplitStatistics tmpAccumulatorRight = new SplitStatistics();
        for (int leftIndex = 0; leftIndex < statistics.size(); ++leftIndex) {
            tmpAccumulatorLeft.accumulateLeftStatistics(((CategoricalBin)sortedBins.get((int)leftIndex))._count, ((CategoricalBin)sortedBins.get((int)leftIndex))._count0);
            statistics.get(leftIndex).copyLeftValues(tmpAccumulatorLeft);
            statistics.get((int)leftIndex)._splittingRule = new CategoricalSplittingRule(this.createMaskFromBins(sortedBins.subList(0, leftIndex + 1)));
            int rightIndex = sortedBins.size() - leftIndex - 1;
            statistics.get(rightIndex).copyRightValues(tmpAccumulatorRight);
            tmpAccumulatorRight.accumulateRightStatistics(((CategoricalBin)sortedBins.get((int)rightIndex))._count, ((CategoricalBin)sortedBins.get((int)rightIndex))._count0);
        }
        return statistics;
    }
}

