/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.models.randomforest;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.LongAdder;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.paged.HugeAtomicLongArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.decisiontree.DecisionTreePredict;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.models.Features;

/**
 * Computes the out-of-bag (OOB) error of a random forest.
 *
 * <p>Votes are accumulated in a flat {@link HugeAtomicLongArray} laid out as
 * {@code predictions[trainSetIdx * numClasses + localClassId]}. Each slot counts how many trees
 * predicted that local class for that training-set position while the position was not sampled
 * for the tree.
 */
public final class OutOfBagErrorMetric {
    private OutOfBagErrorMetric() {
    }

    /**
     * Records the votes of a single decision tree for every training-set position it was not sampled with.
     *
     * <p>For each position {@code i} in {@code trainSet} whose bit in {@code sampledTrainSet} is clear, the
     * tree predicts on the feature vector of {@code trainSet.get(i)}. The counter at
     * {@code i * classMapping.size() + prediction} is then incremented by one.
     *
     * <p>NOTE(review): this assumes {@code decisionTree.predict} returns a local class id in
     * {@code [0, classMapping.size())}, which matches how {@link #evaluate} maps ids back through
     * {@code classMapping.toOriginal}. It also assumes {@code predictions} holds at least
     * {@code trainSet.size() * classMapping.size()} slots. Confirm both against callers.
     *
     * @param decisionTree      the tree whose votes are recorded
     * @param classMapping      mapping between local class ids and original labels; only its size is used here
     * @param allFeatureVectors feature vectors, indexed by the ids stored in {@code trainSet}
     * @param trainSet          ids of the training examples
     * @param sampledTrainSet   set bits mark training-set positions the tree was sampled with (in-bag)
     * @param predictions       vote counters, updated with atomic {@code getAndAdd}
     */
    static void addPredictionsForTree(
        DecisionTreePredict<Integer> decisionTree,
        LocalIdMap classMapping,
        Features allFeatureVectors,
        ReadOnlyHugeLongArray trainSet,
        BitSet sampledTrainSet,
        HugeAtomicLongArray predictions
    ) {
        int numClasses = classMapping.size();
        long trainSetSize = trainSet.size();
        for (long trainSetIdx = 0L; trainSetIdx < trainSetSize; trainSetIdx++) {
            if (sampledTrainSet.get(trainSetIdx)) {
                // In-bag for this tree: only out-of-bag positions receive a vote.
                continue;
            }
            double[] featureVector = allFeatureVectors.get(trainSet.get(trainSetIdx));
            int predictedClass = decisionTree.predict(featureVector);
            predictions.getAndAdd(trainSetIdx * numClasses + predictedClass, 1L);
        }
    }

    /**
     * Computes the out-of-bag error from the votes accumulated by {@link #addPredictionsForTree}.
     *
     * <p>For every training-set position with at least one recorded vote, the class with the most votes
     * is compared against the expected label. The method returns the fraction of such positions where
     * they differ. Positions with no votes at all are ignored.
     *
     * @param trainSet       ids of the training examples, in the same order used when recording votes
     * @param classMapping   mapping between local class ids and original labels
     * @param expectedLabels original labels, indexed by the ids stored in {@code trainSet}
     * @param concurrency    number of tasks to partition the training set into and run concurrently
     * @param predictions    vote counters laid out as {@code [trainSetIdx * numClasses + localClassId]}
     * @return the number of mistakes divided by the number of voted positions. This is {@code NaN} when
     *         no position received any vote, because the division is {@code 0.0 / 0.0}.
     */
    public static double evaluate(
        ReadOnlyHugeLongArray trainSet,
        LocalIdMap classMapping,
        HugeLongArray expectedLabels,
        int concurrency,
        HugeAtomicLongArray predictions
    ) {
        LongAdder totalMistakes = new LongAdder();
        LongAdder totalOutOfAnyBagVectors = new LongAdder();

        List<Runnable> tasks = PartitionUtils.rangePartition(
            concurrency,
            trainSet.size(),
            partition -> accumulationTask(
                partition,
                classMapping,
                trainSet,
                predictions,
                expectedLabels,
                totalMistakes,
                totalOutOfAnyBagVectors
            ),
            Optional.empty()
        );
        ParallelUtil.runWithConcurrency(concurrency, tasks, Pools.DEFAULT);

        return totalMistakes.doubleValue() / totalOutOfAnyBagVectors.doubleValue();
    }

    /**
     * Builds a task that tallies mistakes over one contiguous range of training-set positions.
     *
     * <p>For each position, the local class with the strictly highest vote count wins, so ties go to the
     * lowest class index. Positions whose counters are all zero are skipped. Per-task totals are added to
     * the shared {@link LongAdder}s once, at the end of the range.
     *
     * @param partition               the range of training-set positions this task covers
     * @param classMapping            mapping between local class ids and original labels
     * @param trainSet                ids of the training examples
     * @param predictions             vote counters laid out as {@code [trainSetIdx * numClasses + localClassId]}
     * @param expectedLabels          original labels, indexed by the ids stored in {@code trainSet}
     * @param totalMistakes           shared accumulator for mispredicted positions
     * @param totalOutOfAnyBagVectors shared accumulator for positions with at least one vote
     * @return the task to run
     */
    private static Runnable accumulationTask(
        Partition partition,
        LocalIdMap classMapping,
        ReadOnlyHugeLongArray trainSet,
        HugeAtomicLongArray predictions,
        HugeLongArray expectedLabels,
        LongAdder totalMistakes,
        LongAdder totalOutOfAnyBagVectors
    ) {
        int numClasses = classMapping.size();
        return () -> {
            long numMistakes = 0L;
            long numOutOfAnyBagVectors = 0L;
            long startOffset = partition.startNode();
            long endOffset = startOffset + partition.nodeCount();

            for (long i = startOffset; i < endOffset; i++) {
                long innerOffset = i * numClasses;
                long maxVotes = 0L;
                int majorityClassIdx = 0;
                for (int j = 0; j < numClasses; j++) {
                    long numPredictions = predictions.get(innerOffset + j);
                    // Strictly greater: on ties the lowest class index is kept.
                    if (numPredictions > maxVotes) {
                        maxVotes = numPredictions;
                        majorityClassIdx = j;
                    }
                }

                if (maxVotes == 0L) {
                    // No votes were recorded for this position, so it does not count towards the error.
                    continue;
                }

                numOutOfAnyBagVectors++;
                if (classMapping.toOriginal(majorityClassIdx) != expectedLabels.get(trainSet.get(i))) {
                    numMistakes++;
                }
            }

            totalMistakes.add(numMistakes);
            totalOutOfAnyBagVectors.add(numOutOfAnyBagVectors);
        };
    }
}

