/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Output;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impl.InvertedFeature;
import org.tribuo.regression.rtree.impl.TreeFeature;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

/**
 * A training node for a regression decision tree which scores candidate splits using the
 * impurity of all regression outputs jointly (averaged over the outputs), so a single tree
 * predicts every output dimension.
 * <p>
 * This is a runtime-only class used while growing a tree. It is turned into immutable
 * {@link Node}s by {@link #convertTree()} and throws if Java serialization is attempted.
 */
public class JointRegressorTrainingNode extends AbstractTrainingNode<Regressor> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(JointRegressorTrainingNode.class.getName());

    // Per-thread scratch buffers reused when merging index arrays during a split.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(16));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(16));

    /** Inverted (per-feature) view of this node's examples; set to null once buildTree has run. */
    private transient ArrayList<TreeFeature> data;
    /** If true, leaf predictions are normalized to sum to one and carry no variance. */
    private final boolean normalize;
    private final ImmutableOutputInfo<Regressor> labelIDMap;
    private final ImmutableFeatureMap featureIDMap;
    private final RegressorImpurity impurity;
    /** Indices (into {@link #targets} columns and {@link #weights}) of the examples in this node. */
    private final int[] indices;
    /** Regression targets indexed as [output id][example index]; shared with child nodes. */
    private final float[][] targets;
    /** Example weights indexed by example index; shared with child nodes. */
    private final float[] weights;
    /** Total weight of the examples in this node. */
    private final float weightSum;

    /**
     * Constructs the root training node for the supplied dataset.
     *
     * @param impurity       The impurity function used to score splits.
     * @param examples       The training data, which is inverted into per-feature form.
     * @param normalize      If true, leaf predictions are normalized to sum to one.
     * @param leafDeterminer Leaf-creation criteria (e.g. the minimum impurity decrease).
     */
    public JointRegressorTrainingNode(RegressorImpurity impurity, Dataset<Regressor> examples, boolean normalize, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        this(impurity, invertData(examples), examples.size(), examples.getFeatureIDMap(), examples.getOutputIDInfo(), normalize, leafDeterminer);
    }

    /** Unpacks the inverted dataset into a depth-0 node. */
    private JointRegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int numExamples, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo, boolean normalize, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        this(impurity, tuple.data, tuple.indices, tuple.targets, tuple.weights, numExamples, 0, featureIDMap, outputInfo, normalize, leafDeterminer);
    }

    /** Builds a node, computing its weight sum and impurity from the supplied indices. */
    private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, boolean normalize, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.normalize = normalize;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.indices = indices;
        this.targets = targets;
        this.weights = weights;
        this.weightSum = Util.sum(indices, indices.length, weights);
        this.impurityScore = calcImpurity(indices);
    }

    /** Builds a node whose weight sum and impurity were already computed by the parent split. */
    private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, boolean normalize, AbstractTrainingNode.LeafDeterminer leafDeterminer, float weightSum, double impurityScore) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.normalize = normalize;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.indices = indices;
        this.targets = targets;
        this.weights = weights;
        this.weightSum = weightSum;
        this.impurityScore = impurityScore;
    }

    @Override
    public double getImpurity() {
        return impurityScore;
    }

    /**
     * Returns the total weight of the examples in this node.
     *
     * @return The weight sum.
     */
    public float getWeightSum() {
        return weightSum;
    }

    /**
     * Computes the impurity of the given examples, averaged over all regression outputs.
     *
     * @param curIndices The example indices.
     * @return The mean per-output impurity.
     */
    private double calcImpurity(int[] curIndices) {
        double tmp = 0.0;
        for (float[] target : targets) {
            tmp += impurity.impurity(curIndices, target, weights);
        }
        return tmp / (double) targets.length;
    }

    /**
     * Attempts to split this node on one of the supplied features.
     *
     * @param featureIDs           Candidate feature ids.
     * @param rng                  Randomness source, used only for random split points.
     * @param useRandomSplitPoints If true, each feature is split once at a random point (extra-trees style);
     *                             otherwise every split point is tried.
     * @return The child nodes that still need to be grown; empty if this node was not split
     *         or both children became leaves.
     */
    @Override
    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs, SplittableRandom rng, boolean useRandomSplitPoints) {
        if (useRandomSplitPoints) {
            return buildRandomTree(featureIDs, rng);
        }
        return buildGreedyTree(featureIDs);
    }

    /** Exhaustively scores every split point of every candidate feature and splits on the best. */
    private List<AbstractTrainingNode<Regressor>> buildGreedyTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = getImpurity();
        List<int[]> curIndices = new ArrayList<>();
        List<int[]> bestLeftIndices = new ArrayList<>();
        List<int[]> bestRightIndices = new ArrayList<>();
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            curIndices.clear();
            for (InvertedFeature f : feature) {
                curIndices.add(f.indices());
            }
            // Try a split between each adjacent pair of InvertedFeature entries.
            for (int j = 0; j < feature.size() - 1; j++) {
                List<int[]> curLeftIndices = curIndices.subList(0, j + 1);
                List<int[]> curRightIndices = curIndices.subList(j + 1, feature.size());
                double score = splitScore(curLeftIndices, curRightIndices);
                if (score < bestScore) {
                    bestID = i;
                    bestScore = score;
                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
                    // subList views are copied so later iterations don't alias them.
                    bestLeftIndices.clear();
                    bestLeftIndices.addAll(curLeftIndices);
                    bestRightIndices.clear();
                    bestRightIndices.addAll(curRightIndices);
                }
            }
        }
        return finishSplit(featureIDs, bestID, bestScore, bestSplitValue, bestLeftIndices, bestRightIndices);
    }

    /** Splits each candidate feature once at a random point and splits on the best of those. */
    private List<AbstractTrainingNode<Regressor>> buildRandomTree(int[] featureIDs, SplittableRandom rng) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = getImpurity();
        List<int[]> curLeftIndices = new ArrayList<>();
        List<int[]> curRightIndices = new ArrayList<>();
        List<int[]> bestLeftIndices = new ArrayList<>();
        List<int[]> bestRightIndices = new ArrayList<>();
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            // A feature with a single entry has nowhere to split.
            if (feature.size() == 1) {
                continue;
            }
            int splitIdx = rng.nextInt(feature.size() - 1);
            // The candidate lists are reused across features, so reset them; otherwise each
            // candidate would be scored over index arrays accumulated from earlier features.
            curLeftIndices.clear();
            curRightIndices.clear();
            for (int j = 0; j <= splitIdx; j++) {
                curLeftIndices.add(feature.get(j).indices());
            }
            for (int j = splitIdx + 1; j < feature.size(); j++) {
                curRightIndices.add(feature.get(j).indices());
            }
            double score = splitScore(curLeftIndices, curRightIndices);
            if (score < bestScore) {
                bestID = i;
                bestScore = score;
                bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0;
                bestLeftIndices.clear();
                bestLeftIndices.addAll(curLeftIndices);
                bestRightIndices.clear();
                bestRightIndices.addAll(curRightIndices);
            }
        }
        return finishSplit(featureIDs, bestID, bestScore, bestSplitValue, bestLeftIndices, bestRightIndices);
    }

    /**
     * Scores a candidate split as the summed weighted impurity of both sides over all outputs,
     * divided by (number of outputs * this node's weight sum).
     */
    private double splitScore(List<int[]> leftIndices, List<int[]> rightIndices) {
        double lessThanScore = 0.0;
        double greaterThanScore = 0.0;
        for (float[] target : targets) {
            RegressorImpurity.ImpurityTuple left = impurity.impurityTuple(leftIndices, target, weights);
            lessThanScore += (double) (left.impurity * left.weight);
            RegressorImpurity.ImpurityTuple right = impurity.impurityTuple(rightIndices, target, weights);
            greaterThanScore += (double) (right.impurity * right.weight);
        }
        return (lessThanScore + greaterThanScore) / (double) ((float) targets.length * weightSum);
    }

    /**
     * Splits at the best candidate if one was found and its weighted impurity decrease meets the
     * leaf determiner's scaled minimum, then releases the inverted data.
     */
    private List<AbstractTrainingNode<Regressor>> finishSplit(int[] featureIDs, int bestID, double bestScore, double bestSplitValue, List<int[]> bestLeftIndices, List<int[]> bestRightIndices) {
        double impurityDecrease = weightSum * (getImpurity() - bestScore);
        List<AbstractTrainingNode<Regressor>> output;
        if (bestID != -1 && impurityDecrease >= leafDeterminer.getScaledMinImpurityDecrease()) {
            output = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices);
        } else {
            output = Collections.emptyList();
        }
        data = null;
        return output;
    }

    /**
     * Records the split on this node and creates the two children (as leaves or training nodes).
     *
     * @return The children that are training nodes and still need to be grown.
     */
    private List<AbstractTrainingNode<Regressor>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue, List<int[]> bestLeftIndices, List<int[]> bestRightIndices) {
        this.splitID = featureIDs[bestID];
        this.split = true;
        this.splitValue = bestSplitValue;
        IntArrayContainer firstBuffer = mergeBufferOne.get();
        firstBuffer.size = 0;
        firstBuffer.grow(indices.length);
        IntArrayContainer secondBuffer = mergeBufferTwo.get();
        secondBuffer.size = 0;
        secondBuffer.grow(indices.length);
        int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
        int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);

        float leftWeightSum = Util.sum(leftIndices, leftIndices.length, weights);
        double leftImpurityScore = calcImpurity(leftIndices);
        float rightWeightSum = Util.sum(rightIndices, rightIndices.length, weights);
        double rightImpurityScore = calcImpurity(rightIndices);
        boolean shouldMakeLeftLeaf = shouldMakeLeaf(leftImpurityScore, leftWeightSum);
        boolean shouldMakeRightLeaf = shouldMakeLeaf(rightImpurityScore, rightWeightSum);

        if (shouldMakeLeftLeaf && shouldMakeRightLeaf) {
            // Both children are leaves, so the inverted data never needs to be partitioned.
            lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices);
            greaterThan = createLeaf(rightImpurityScore, rightIndices);
            return Collections.emptyList();
        }

        ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
        ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
        for (TreeFeature feature : data) {
            Pair<TreeFeature, TreeFeature> splitPair = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
            lessThanData.add(splitPair.getA());
            greaterThanData.add(splitPair.getB());
        }

        List<AbstractTrainingNode<Regressor>> output = new ArrayList<>(2);
        if (shouldMakeLeftLeaf) {
            lessThanOrEqual = createLeaf(leftImpurityScore, leftIndices);
        } else {
            JointRegressorTrainingNode leftNode = new JointRegressorTrainingNode(impurity, lessThanData, leftIndices, targets, weights, leftIndices.length, depth + 1, featureIDMap, labelIDMap, normalize, leafDeterminer, leftWeightSum, leftImpurityScore);
            lessThanOrEqual = leftNode;
            output.add(leftNode);
        }
        if (shouldMakeRightLeaf) {
            greaterThan = createLeaf(rightImpurityScore, rightIndices);
        } else {
            JointRegressorTrainingNode rightNode = new JointRegressorTrainingNode(impurity, greaterThanData, rightIndices, targets, weights, rightIndices.length, depth + 1, featureIDMap, labelIDMap, normalize, leafDeterminer, rightWeightSum, rightImpurityScore);
            greaterThan = rightNode;
            output.add(rightNode);
        }
        return output;
    }

    /**
     * Converts this training node into an immutable tree node.
     *
     * @return A split node if this node was split, otherwise a leaf over this node's examples.
     */
    @Override
    public Node<Regressor> convertTree() {
        if (split) {
            return createSplitNode();
        }
        return createLeaf(getImpurity(), indices);
    }

    /**
     * Creates a leaf predicting the weighted mean of each output over the given examples.
     * <p>
     * If {@code normalize} is set, the means are divided by their sum and no variance is
     * reported. Otherwise a per-output weighted variance is also computed, using an
     * incremental (running-mean) update.
     *
     * @param impurityScore The impurity recorded on the leaf.
     * @param leafIndices   The example indices that fall in the leaf.
     * @return The leaf node.
     */
    private LeafNode<Regressor> createLeaf(double impurityScore, int[] leafIndices) {
        int numOutputs = targets.length;
        String[] names = new String[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            names[i] = labelIDMap.getOutput(i).getNames()[0];
        }

        double[] mean = new double[numOutputs];
        double[] variance = new double[numOutputs];
        double leafWeightSum = 0.0;
        for (int idx : leafIndices) {
            float weight = weights[idx];
            leafWeightSum += weight;
            for (int j = 0; j < numOutputs; j++) {
                float value = targets[j][idx];
                double oldMean = mean[j];
                mean[j] += weight / leafWeightSum * (value - oldMean);
                variance[j] += weight * (value - oldMean) * (value - mean[j]);
            }
        }

        Regressor leafPred;
        if (normalize) {
            double sum = 0.0;
            for (int i = 0; i < numOutputs; i++) {
                sum += mean[i];
            }
            // NOTE(review): if the means sum to zero this produces NaN/Infinity predictions.
            for (int i = 0; i < numOutputs; i++) {
                mean[i] /= sum;
            }
            leafPred = new Regressor(names, mean);
        } else {
            for (int i = 0; i < numOutputs; i++) {
                // NOTE(review): the denominator is (total weight - 1), which is zero or negative
                // when the leaf's total weight is <= 1 even if it holds several examples.
                variance[i] = leafIndices.length > 1 ? variance[i] / (leafWeightSum - 1.0) : 0.0;
            }
            leafPred = new Regressor(names, mean, variance);
        }
        return new LeafNode<>(impurityScore, leafPred, Collections.emptyMap(), false);
    }

    /**
     * Inverts a dataset into per-feature form, and extracts its targets and weights.
     * <p>
     * Every example records a value for every feature; features absent from an example
     * are recorded as 0.0.
     *
     * @param examples The dataset to invert.
     * @return The inverted data, identity index array, targets ([output id][example]) and weights.
     * @throws IllegalStateException If an example's sparse features are repeated or out of order.
     */
    private static InvertedData invertData(Dataset<Regressor> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();
        int[] indices = new int[numExamples];
        float[][] targets = new float[numLabels][numExamples];
        float[] weights = new float[numExamples];

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");
        ArrayList<TreeFeature> data = new ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            data.add(new TreeFeature(i));
        }

        // ids[j] is the output id under which the j-th entry of Regressor.getValues() is stored.
        int[] ids = ((ImmutableRegressionInfo) labelInfo).getNaturalOrderToIDMapping();
        for (int i = 0; i < numExamples; i++) {
            Example<Regressor> e = examples.getExample(i);
            indices[i] = i;
            weights[i] = e.getWeight();
            double[] output = e.getOutput().getValues();
            for (int j = 0; j < targets.length; j++) {
                targets[ids[j]][i] = (float) output[j];
            }

            SparseVector vec = SparseVector.createSparseVector(e, featureInfos, false);
            int lastID = 0;
            for (VectorTuple f : vec) {
                int curID = f.index;
                // Check "repeated" first: a repeat (curID == lastID - 1) also satisfies
                // lastID > curID and would otherwise be misreported as out of order.
                if (lastID - 1 == curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                if (lastID > curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + lastID + ", curID = " + curID);
                }
                // Fill the implicit zeros for the features skipped by the sparse vector.
                for (int j = lastID; j < curID; j++) {
                    data.get(j).observeValue(0.0, i);
                }
                data.get(curID).observeValue(f.value, i);
                lastID = curID + 1;
            }
            for (int j = lastID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0, i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");
        data.forEach(TreeFeature::sort);
        logger.fine("Fixing InvertedFeature sizes");
        data.forEach(TreeFeature::fixSize);
        logger.fine("Built initial List<TreeFeature>");
        return new InvertedData(data, indices, targets, weights);
    }

    /** Always throws: this class is only used during training and must not be serialized. */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        throw new NotSerializableException("JointRegressorTrainingNode is a runtime class only, and should not be serialized.");
    }

    /** Holder for the output of {@link #invertData(Dataset)}. */
    private static class InvertedData {
        final ArrayList<TreeFeature> data;
        final int[] indices;
        final float[][] targets;
        final float[] weights;

        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
            this.data = data;
            this.indices = indices;
            this.targets = targets;
            this.weights = weights;
        }
    }
}

