/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.decisiontree;

import java.util.ArrayDeque;
import java.util.Deque;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.decisiontree.DecisionTreeLoss;
import org.neo4j.gds.decisiontree.DecisionTreePredict;
import org.neo4j.gds.decisiontree.DecisionTreeTrainConfig;
import org.neo4j.gds.decisiontree.FeatureBagger;
import org.neo4j.gds.decisiontree.GroupSizes;
import org.neo4j.gds.decisiontree.Groups;
import org.neo4j.gds.decisiontree.ImmutableGroupSizes;
import org.neo4j.gds.decisiontree.ImmutableGroups;
import org.neo4j.gds.decisiontree.ImmutableReadOnlyGroups;
import org.neo4j.gds.decisiontree.ImmutableSplit;
import org.neo4j.gds.decisiontree.ImmutableStackRecord;
import org.neo4j.gds.decisiontree.ReadOnlyGroups;
import org.neo4j.gds.decisiontree.TreeNode;
import org.neo4j.gds.models.Features;

/**
 * Base class for training a binary decision tree over a {@link Features} set.
 *
 * <p>Training builds a root node over the supplied training indices and then expands internal
 * nodes iteratively using an explicit stack (no recursion). For each node the split with the
 * lowest {@link DecisionTreeLoss#splitLoss} is chosen by {@link #findBestSplit}. A child becomes
 * a leaf (its prediction computed by {@link #toTerminal}) when the parent is already at
 * {@code maxDepth}, when the child's group is smaller than {@code minSplitSize}, or when the best
 * split found for it leaves one side empty.
 *
 * @param <LOSS>       loss used to score candidate splits
 * @param <PREDICTION> value type stored in leaf nodes
 */
public abstract class DecisionTreeTrain<LOSS extends DecisionTreeLoss, PREDICTION> {
    private final LOSS lossFunction;
    private final Features features;
    private final DecisionTreeTrainConfig config;
    private final FeatureBagger featureBagger;

    DecisionTreeTrain(Features features, DecisionTreeTrainConfig config, LOSS lossFunction, FeatureBagger featureBagger) {
        this.lossFunction = lossFunction;
        this.features = features;
        this.config = config;
        this.featureBagger = featureBagger;
    }

    /**
     * Trains a decision tree on the given samples.
     *
     * @param trainSetIndices indices into {@code features} of the samples to train on; all
     *                        {@code trainSetIndices.size()} entries are used
     * @return a predictor wrapping the root of the trained tree
     */
    public DecisionTreePredict<PREDICTION> train(ReadOnlyHugeLongArray trainSetIndices) {
        Deque<StackRecord<PREDICTION>> stack = new ArrayDeque<>();
        // The root is at depth 1.
        TreeNode<PREDICTION> root = splitAndPush(stack, trainSetIndices, trainSetIndices.size(), 1);
        int maxDepth = config.maxDepth();
        int minSplitSize = config.minSplitSize();

        while (!stack.isEmpty()) {
            StackRecord<PREDICTION> record = stack.pop();
            Split split = record.split();
            TreeNode<PREDICTION> node = record.node();
            int depth = record.depth();

            // Keep left-before-right: both calls may reach splitAndPush, which calls
            // featureBagger.sample() and pushes onto the stack, so the order can matter
            // (e.g. if sampling is stateful).
            node.setLeftChild(createChild(
                stack, split.groups().left(), split.sizes().left(), depth, maxDepth, minSplitSize
            ));
            node.setRightChild(createChild(
                stack, split.groups().right(), split.sizes().right(), depth, maxDepth, minSplitSize
            ));
        }
        return new DecisionTreePredict<>(root);
    }

    /**
     * Creates the child node for one side of a split.
     *
     * @param stack        pending internal nodes; receives the child if it becomes one
     * @param group        sample indices of the child; only the first {@code groupSize} are valid
     * @param groupSize    number of valid entries in {@code group}
     * @param parentDepth  depth of the node being split (the root has depth 1)
     * @param maxDepth     configured maximum depth
     * @param minSplitSize configured minimum group size required to attempt a split
     * @return a leaf if {@code parentDepth >= maxDepth} or {@code groupSize < minSplitSize},
     *         otherwise the result of {@link #splitAndPush} at depth {@code parentDepth + 1}
     */
    private TreeNode<PREDICTION> createChild(
        Deque<StackRecord<PREDICTION>> stack,
        ReadOnlyHugeLongArray group,
        long groupSize,
        int parentDepth,
        int maxDepth,
        int minSplitSize
    ) {
        if (parentDepth >= maxDepth || groupSize < minSplitSize) {
            return new TreeNode<>(toTerminal(group, groupSize));
        }
        return splitAndPush(stack, group, groupSize, parentDepth + 1);
    }

    /**
     * Computes the prediction stored in a leaf node.
     *
     * @param group     sample indices reaching the leaf; only the first {@code groupSize} entries
     *                  are meaningful (the backing array may be larger)
     * @param groupSize number of valid entries in {@code group}
     * @return the leaf's prediction
     */
    protected abstract PREDICTION toTerminal(ReadOnlyHugeLongArray group, long groupSize);

    /**
     * Finds the best split of {@code group} and returns either a leaf or a new internal node.
     *
     * <p>An internal node is returned (and pushed onto {@code stack} so its children are created
     * later) only when a valid split was found and both sides are non-empty.
     *
     * @param stack     pending internal nodes
     * @param group     sample indices; only the first {@code groupSize} entries are valid
     * @param groupSize number of valid entries, must be positive
     * @param depth     depth of the node being created, at least 1
     * @return the created node
     */
    private TreeNode<PREDICTION> splitAndPush(
        Deque<StackRecord<PREDICTION>> stack,
        ReadOnlyHugeLongArray group,
        long groupSize,
        int depth
    ) {
        assert groupSize > 0L;
        assert group.size() >= groupSize;
        assert depth >= 1;

        Split split = findBestSplit(group, groupSize);
        if (split.index() < 0) {
            // No candidate was accepted (empty feature bag, or no loss strictly below
            // Double.MAX_VALUE; NaN losses are never accepted). The split's groups and sizes
            // are placeholders in that case, so make the whole group a leaf instead of an
            // internal node with an invalid feature index.
            return new TreeNode<>(toTerminal(group, groupSize));
        }
        if (split.sizes().right() == 0L) {
            return new TreeNode<>(toTerminal(split.groups().left(), split.sizes().left()));
        }
        if (split.sizes().left() == 0L) {
            return new TreeNode<>(toTerminal(split.groups().right(), split.sizes().right()));
        }

        TreeNode<PREDICTION> node = new TreeNode<>(split.index(), split.value());
        stack.push(ImmutableStackRecord.of(node, split, depth));
        return node;
    }

    /**
     * Partitions the first {@code groupSize} entries of {@code group} by one feature.
     *
     * <p>Samples whose feature {@code index} is strictly less than {@code value} are written to
     * the front of {@code groups.left()}; all others go to {@code groups.right()}. NaN feature
     * values therefore go right. Existing contents of both arrays are overwritten from index 0.
     *
     * @param index     feature index to split on
     * @param value     threshold value
     * @param group     sample indices; only the first {@code groupSize} entries are read
     * @param groupSize number of valid entries, must be positive
     * @param groups    output buffers, each with room for at least {@code groupSize} entries
     * @return the number of entries written to the left and right buffers
     */
    private GroupSizes createSplit(int index, double value, ReadOnlyHugeLongArray group, long groupSize, Groups groups) {
        assert groupSize > 0L;
        assert group.size() >= groupSize;
        assert index >= 0 && index < features.featureDimension();

        HugeLongArray leftGroup = groups.left();
        HugeLongArray rightGroup = groups.right();
        long leftGroupSize = 0L;
        long rightGroupSize = 0L;

        // A long counter is required: groupSize may exceed Integer.MAX_VALUE for huge arrays.
        for (long i = 0L; i < groupSize; i++) {
            long featuresIdx = group.get(i);
            if (features.get(featuresIdx)[index] < value) {
                leftGroup.set(leftGroupSize++, featuresIdx);
            } else {
                rightGroup.set(rightGroupSize++, featuresIdx);
            }
        }
        return ImmutableGroupSizes.of(leftGroupSize, rightGroupSize);
    }

    /**
     * Exhaustively evaluates candidate splits of {@code group} and returns the one with the
     * lowest loss.
     *
     * <p>Candidates are every (sample, feature) pair where the sample is one of the first
     * {@code groupSize} entries and the feature is in {@code featureBagger.sample()}. The
     * threshold is that sample's value for that feature. A candidate replaces the current best
     * only if its loss is strictly lower. If no candidate is accepted, the returned split has
     * index {@code -1} and sizes {@code (-1, -1)}.
     *
     * @param group     sample indices; only the first {@code groupSize} entries are used
     * @param groupSize number of valid entries, must be positive
     * @return the best split found
     */
    private Split findBestSplit(ReadOnlyHugeLongArray group, long groupSize) {
        assert groupSize > 0L;
        assert group.size() >= groupSize;

        int bestIdx = -1;
        double bestValue = Double.MAX_VALUE;
        double bestLoss = Double.MAX_VALUE;

        // Two pairs of scratch buffers: each candidate is written into childGroups. When it wins,
        // the pairs are swapped so the winner is kept and the old best pair is reused.
        Groups childGroups = ImmutableGroups.of(HugeLongArray.newArray(groupSize), HugeLongArray.newArray(groupSize));
        Groups bestChildGroups = ImmutableGroups.of(HugeLongArray.newArray(groupSize), HugeLongArray.newArray(groupSize));
        GroupSizes bestGroupSizes = ImmutableGroupSizes.of(-1L, -1L);

        int[] featureBag = featureBagger.sample();

        for (long j = 0L; j < groupSize; j++) {
            // Invariant across the feature loop below, so fetch it once per sample.
            double[] featureVector = features.get(group.get(j));
            for (int i : featureBag) {
                // Capture the threshold before createSplit performs further feature lookups.
                double candidateValue = featureVector[i];
                GroupSizes groupSizes = createSplit(i, candidateValue, group, groupSize, childGroups);
                double loss = lossFunction.splitLoss(childGroups, groupSizes);
                if (loss < bestLoss) {
                    bestIdx = i;
                    bestValue = candidateValue;
                    bestLoss = loss;
                    bestGroupSizes = groupSizes;

                    Groups tmpGroups = bestChildGroups;
                    bestChildGroups = childGroups;
                    childGroups = tmpGroups;
                }
            }
        }

        return ImmutableSplit.of(
            bestIdx,
            bestValue,
            ImmutableReadOnlyGroups.of(
                ReadOnlyHugeLongArray.of(bestChildGroups.left()),
                ReadOnlyHugeLongArray.of(bestChildGroups.right())
            ),
            bestGroupSizes
        );
    }

    /** An internal node whose children still have to be created, together with its split. */
    @ValueClass
    interface StackRecord<PREDICTION> {
        TreeNode<PREDICTION> node();

        Split split();

        int depth();
    }

    /**
     * Result of {@link DecisionTreeTrain#findBestSplit}: samples whose feature {@code index()} is
     * strictly less than {@code value()} are in the left group, all others in the right group.
     * An {@code index()} of {@code -1} means no split was accepted.
     */
    @ValueClass
    interface Split {
        int index();

        double value();

        ReadOnlyGroups groups();

        GroupSizes sizes();
    }
}

