/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.models.randomforest;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.paged.HugeAtomicLongArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.decisiontree.ClassificationDecisionTreeTrain;
import org.neo4j.gds.decisiontree.DecisionTreeLoss;
import org.neo4j.gds.decisiontree.DecisionTreePredict;
import org.neo4j.gds.decisiontree.DecisionTreeTrainConfig;
import org.neo4j.gds.decisiontree.DecisionTreeTrainConfigImpl;
import org.neo4j.gds.decisiontree.FeatureBagger;
import org.neo4j.gds.decisiontree.GiniIndex;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.models.Features;
import org.neo4j.gds.models.Trainer;
import org.neo4j.gds.models.randomforest.ClassificationRandomForestPredictor;
import org.neo4j.gds.models.randomforest.DatasetBootstrapper;
import org.neo4j.gds.models.randomforest.ImmutableBootstrappedDataset;
import org.neo4j.gds.models.randomforest.OutOfBagErrorMetric;
import org.neo4j.gds.models.randomforest.RandomForestTrainConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains a classification random forest: an ensemble of decision trees, each fitted with the Gini index
 * loss on its own (optionally bootstrapped) sample of the training set.
 *
 * <p>Trees are trained in parallel on {@link Pools#DEFAULT} with the configured concurrency. Each tree
 * receives its own {@link SplittableRandom}, split off (in a fixed order, on the calling thread) from a
 * single generator seeded with {@code randomSeed} when one is given.
 */
public class ClassificationRandomForestTrainer
implements Trainer {
    private final LocalIdMap classIdMap;
    private final RandomForestTrainConfig config;
    private final int concurrency;
    private final boolean computeOutOfBagError;
    private final SplittableRandom random;
    private final ProgressTracker progressTracker;
    // Populated by train() only when computeOutOfBagError is true.
    private Optional<Double> outOfBagError = Optional.empty();

    /**
     * @param concurrency          number of trees trained concurrently; also passed to the out-of-bag evaluation
     * @param classIdMap           mapping between original class labels and local class ids
     * @param config               forest hyper-parameters (number of trees, depth, split size, sampling ratios)
     * @param computeOutOfBagError whether {@link #train} should also compute the out-of-bag error
     * @param randomSeed           seed for all randomness; a random seed is drawn when empty
     * @param progressTracker      receives one progress unit per trained tree
     */
    public ClassificationRandomForestTrainer(int concurrency, LocalIdMap classIdMap, RandomForestTrainConfig config, boolean computeOutOfBagError, Optional<Long> randomSeed, ProgressTracker progressTracker) {
        this.classIdMap = classIdMap;
        this.config = config;
        this.concurrency = concurrency;
        this.computeOutOfBagError = computeOutOfBagError;
        this.random = new SplittableRandom(randomSeed.orElseGet(() -> new SplittableRandom().nextLong()));
        this.progressTracker = progressTracker;
    }

    /**
     * Trains {@code numberOfDecisionTrees} decision trees in parallel and bundles them into a predictor.
     *
     * @param allFeatureVectors feature vectors of all elements
     * @param allLabels         original class labels of all elements
     * @param trainSet          ids of the elements to train on
     * @return a predictor backed by all trained trees
     */
    @Override
    public ClassificationRandomForestPredictor train(Features allFeatureVectors, HugeLongArray allLabels, ReadOnlyHugeLongArray trainSet) {
        // Prediction counters sized classCount * trainSetSize (presumably one slot per train element and class);
        // only allocated when the out-of-bag error is requested.
        Optional<HugeAtomicLongArray> maybePredictions = computeOutOfBagError
            ? Optional.of(HugeAtomicLongArray.newArray((long) classIdMap.size() * trainSet.size()))
            : Optional.empty();

        DecisionTreeTrainConfig decisionTreeTrainConfig = DecisionTreeTrainConfigImpl.builder()
            .maxDepth(config.maxDepth())
            .minSplitSize(config.minSplitSize())
            .build();
        int numberOfDecisionTrees = config.numberOfDecisionTrees();
        GiniIndex lossFunction = GiniIndex.fromOriginalLabels(allLabels, classIdMap);

        progressTracker.setVolume((long) numberOfDecisionTrees);
        AtomicInteger numberOfTreesTrained = new AtomicInteger(0);

        // The IntStream is sequential, so random.split() is called in a fixed order on this thread.
        List<TrainDecisionTreeTask<GiniIndex>> tasks = IntStream.range(0, numberOfDecisionTrees)
            .mapToObj(unused -> new TrainDecisionTreeTask<GiniIndex>(
                maybePredictions,
                decisionTreeTrainConfig,
                config,
                random.split(),
                allFeatureVectors,
                allLabels,
                classIdMap,
                lossFunction,
                trainSet,
                progressTracker,
                numberOfTreesTrained
            ))
            .collect(Collectors.toList());
        ParallelUtil.runWithConcurrency(concurrency, tasks, Pools.DEFAULT);

        outOfBagError = maybePredictions.map(predictions ->
            OutOfBagErrorMetric.evaluate(trainSet, classIdMap, allLabels, concurrency, predictions)
        );

        List<DecisionTreePredict<Integer>> decisionTrees = tasks.stream()
            .map(TrainDecisionTreeTask::trainedTree)
            .collect(Collectors.toList());
        return new ClassificationRandomForestPredictor(decisionTrees, classIdMap, allFeatureVectors.featureDimension());
    }

    /**
     * Returns the out-of-bag error computed by the last call to {@link #train}.
     *
     * @throws IllegalStateException if {@link #train} has not run or out-of-bag error computation was disabled
     */
    double outOfBagError() {
        return outOfBagError.orElseThrow(() -> new IllegalStateException("Out of bag error has not been computed."));
    }

    /**
     * Trains a single decision tree of the forest, then records its out-of-bag predictions when a
     * prediction cache is present and reports one unit of progress.
     *
     * @param <LOSS> the loss function used to evaluate candidate splits
     */
    static class TrainDecisionTreeTask<LOSS extends DecisionTreeLoss>
    implements Runnable {
        // Assigned by run(); null until the task has completed.
        private DecisionTreePredict<Integer> trainedTree;
        private final Optional<HugeAtomicLongArray> maybePredictions;
        private final DecisionTreeTrainConfig decisionTreeTrainConfig;
        private final RandomForestTrainConfig randomForestTrainConfig;
        private final SplittableRandom random;
        private final Features allFeatureVectors;
        private final HugeLongArray allLabels;
        private final LocalIdMap classIdMap;
        private final LOSS lossFunction;
        private final ReadOnlyHugeLongArray trainSet;
        private final ProgressTracker progressTracker;
        // Shared across all tasks; only used to number the progress log messages.
        private final AtomicInteger numberOfTreesTrained;

        TrainDecisionTreeTask(Optional<HugeAtomicLongArray> maybePredictions, DecisionTreeTrainConfig decisionTreeTrainConfig, RandomForestTrainConfig randomForestTrainConfig, SplittableRandom random, Features allFeatureVectors, HugeLongArray allLabels, LocalIdMap classIdMap, LOSS lossFunction, ReadOnlyHugeLongArray trainSet, ProgressTracker progressTracker, AtomicInteger numberOfTreesTrained) {
            this.maybePredictions = maybePredictions;
            this.decisionTreeTrainConfig = decisionTreeTrainConfig;
            this.randomForestTrainConfig = randomForestTrainConfig;
            this.random = random;
            this.allFeatureVectors = allFeatureVectors;
            this.allLabels = allLabels;
            this.classIdMap = classIdMap;
            this.lossFunction = lossFunction;
            this.trainSet = trainSet;
            this.progressTracker = progressTracker;
            this.numberOfTreesTrained = numberOfTreesTrained;
        }

        /** Returns the tree built by {@link #run()}, or {@code null} if the task has not run yet. */
        public DecisionTreePredict<Integer> trainedTree() {
            return trainedTree;
        }

        @Override
        public void run() {
            int featureDimension = allFeatureVectors.featureDimension();
            FeatureBagger featureBagger = new FeatureBagger(
                random,
                featureDimension,
                randomForestTrainConfig.maxFeaturesRatio(featureDimension)
            );
            ClassificationDecisionTreeTrain<LOSS> decisionTree = new ClassificationDecisionTreeTrain<>(
                lossFunction,
                allFeatureVectors,
                allLabels,
                classIdMap,
                decisionTreeTrainConfig,
                featureBagger
            );
            BootstrappedDataset bootstrappedDataset = bootstrappedDataset();
            trainedTree = decisionTree.train(bootstrappedDataset.allVectorsIndices());

            maybePredictions.ifPresent(predictionsCache -> OutOfBagErrorMetric.addPredictionsForTree(
                trainedTree,
                classIdMap,
                allFeatureVectors,
                trainSet,
                bootstrappedDataset.trainSetIndices(),
                predictionsCache
            ));

            progressTracker.logProgress(1L, StringFormatting.formatWithLocale(
                ":: trained decision tree %d out of %d",
                numberOfTreesTrained.incrementAndGet(),
                randomForestTrainConfig.numberOfDecisionTrees()
            ));
        }

        /**
         * Chooses the training sample for this tree.
         *
         * <p>A {@code numberOfSamplesRatio} of exactly {@code 0.0} disables bootstrapping. In that case the
         * tree trains on the whole train set, and every train set position is marked as sampled. Otherwise
         * sampling is delegated to {@link DatasetBootstrapper#bootstrap}, which presumably records the
         * sampled train set positions in {@code trainSetIndices}.
         */
        private BootstrappedDataset bootstrappedDataset() {
            BitSet trainSetIndices = new BitSet(trainSet.size());
            ReadOnlyHugeLongArray allVectorsIndices;
            if (Double.compare(randomForestTrainConfig.numberOfSamplesRatio(), 0.0) == 0) {
                allVectorsIndices = trainSet;
                // Every position [0, size) is used for training. Starting at 0 matters: an unmarked position
                // is handed to OutOfBagErrorMetric as if it had not been sampled for this tree.
                trainSetIndices.set(0L, trainSet.size());
            } else {
                allVectorsIndices = DatasetBootstrapper.bootstrap(
                    random,
                    randomForestTrainConfig.numberOfSamplesRatio(),
                    trainSet,
                    trainSetIndices
                );
            }
            return ImmutableBootstrappedDataset.of(trainSetIndices, allVectorsIndices);
        }

        /** The training sample chosen for one tree. */
        @ValueClass
        interface BootstrappedDataset {
            /** Positions within the train set that were sampled for this tree. */
            BitSet trainSetIndices();

            /** Element ids the tree is trained on (may contain repeats when bootstrapped). */
            ReadOnlyHugeLongArray allVectorsIndices();
        }
    }
}

