/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.collections.LongMultiSet;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.ClassifierTrainerFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.TrainingMethod;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.nodeClassification.ClassificationMetricComputer;
import org.neo4j.gds.ml.nodePropertyPrediction.NodeSplitter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.PipelineTrainer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.ImmutableNodeClassificationTrainResult;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.LabelsAndClassCountsExtractor;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainResult;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.training.CrossValidation;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.utils.StringFormatting;

public final class NodeClassificationTrain
implements PipelineTrainer<NodeClassificationTrainResult> {
    private final NodeClassificationTrainingPipeline pipeline;
    private final NodeClassificationPipelineTrainConfig trainConfig;
    private final HugeIntArray targets;
    private final LocalIdMap classIdMap;
    private final IdMap nodeIdMap;
    private final List<Metric> metrics;
    private final List<ClassificationMetric> classificationMetrics;
    private final LongMultiSet classCounts;
    private final NodeFeatureProducer<NodeClassificationPipelineTrainConfig> nodeFeatureProducer;
    private final ProgressTracker progressTracker;
    private TerminationFlag terminationFlag = TerminationFlag.RUNNING_TRUE;

    public static MemoryEstimation estimate(NodeClassificationTrainingPipeline pipeline, NodeClassificationPipelineTrainConfig configuration, ModelCatalog modelCatalog) {
        pipeline.validateTrainingParameterSpace();
        MemoryEstimation nodePropertyStepsEstimation = NodePropertyStepExecutor.estimateNodePropertySteps(modelCatalog, configuration.username(), pipeline.nodePropertySteps(), configuration.nodeLabels(), configuration.relationshipTypes());
        MemoryEstimation trainingEstimation = MemoryEstimations.builder().add("Training", NodeClassificationTrain.estimateExcludingNodePropertySteps(pipeline, configuration)).build();
        return MemoryEstimations.maxEstimation((String)"Node Classification Train Pipeline", List.of(nodePropertyStepsEstimation, trainingEstimation));
    }

    private static MemoryEstimation estimateExcludingNodePropertySteps(NodeClassificationTrainingPipeline pipeline, NodeClassificationPipelineTrainConfig config) {
        int fudgedClassCount = 1000;
        int fudgedFeatureCount = 500;
        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        double testFraction = splitConfig.testFraction();
        MemoryEstimation modelSelection = NodeClassificationTrain.modelTrainAndEvaluateMemoryUsage(pipeline, fudgedClassCount, fudgedFeatureCount, splitConfig::foldTrainSetSize, splitConfig::foldTestSetSize);
        MemoryEstimation bestModelEvaluation = MemoryEstimations.delegateEstimation((MemoryEstimation)NodeClassificationTrain.modelTrainAndEvaluateMemoryUsage(pipeline, fudgedClassCount, fudgedFeatureCount, splitConfig::trainSetSize, splitConfig::testSetSize), (String)"best model evaluation");
        MemoryEstimation modelTrainingEstimation = MemoryEstimations.maxEstimation(List.of(modelSelection, bestModelEvaluation));
        MemoryEstimations.Builder builder = MemoryEstimations.builder().perNode("global targets", HugeIntArray::memoryEstimation).rangePerNode("global class counts", __ -> MemoryRange.of((long)16L, (long)((long)fudgedClassCount * 8L))).add("metrics", ClassificationMetricSpecification.memoryEstimation((int)fudgedClassCount)).perNode("node IDs", HugeLongArray::memoryEstimation).add("outer split", FractionSplitter.estimate((double)(1.0 - testFraction))).add("inner split", StratifiedKFoldSplitter.memoryEstimationForNodeSet((int)splitConfig.validationFolds(), (double)(1.0 - testFraction))).add("stats map train", TrainingStatistics.memoryEstimationStatsMap((int)config.metrics().size(), (int)pipeline.numberOfModelSelectionTrials())).add("stats map validation", TrainingStatistics.memoryEstimationStatsMap((int)config.metrics().size(), (int)pipeline.numberOfModelSelectionTrials())).add("max of model selection and best model evaluation", modelTrainingEstimation);
        if (!pipeline.trainingParameterSpace().get(TrainingMethod.RandomForestClassification).isEmpty()) {
            builder.perGraphDimension("cached feature vectors", (dim, threads) -> MemoryRange.of((long)HugeObjectArray.memoryEstimation((long)dim.nodeCount(), (long)MemoryUsage.sizeOfDoubleArray((long)10L)), (long)HugeObjectArray.memoryEstimation((long)dim.nodeCount(), (long)MemoryUsage.sizeOfDoubleArray((long)fudgedFeatureCount))));
        }
        return builder.build();
    }

    public static Task progressTask(NodeClassificationTrainingPipeline pipeline, long nodeCount) {
        NodePropertyPredictionSplitConfig splitConfig = pipeline.splitConfig();
        long trainSetSize = splitConfig.trainSetSize(nodeCount);
        long testSetSize = splitConfig.testSetSize(nodeCount);
        int validationFolds = splitConfig.validationFolds();
        ArrayList<Object> tasks = new ArrayList<Object>();
        tasks.add(NodePropertyStepExecutor.tasks(pipeline.nodePropertySteps(), nodeCount));
        tasks.addAll(CrossValidation.progressTasks((int)validationFolds, (int)pipeline.numberOfModelSelectionTrials(), (long)trainSetSize));
        tasks.add(ClassifierTrainer.progressTask((String)"Train best model", (long)(5L * trainSetSize)));
        tasks.add(Tasks.leaf((String)"Evaluate on train data", (long)trainSetSize));
        tasks.add(Tasks.leaf((String)"Evaluate on test data", (long)testSetSize));
        tasks.add(ClassifierTrainer.progressTask((String)"Retrain best model", (long)(5L * nodeCount)));
        return Tasks.task((String)"Node Classification Train Pipeline", tasks);
    }

    @NotNull
    private static MemoryEstimation modelTrainAndEvaluateMemoryUsage(NodeClassificationTrainingPipeline pipeline, int fudgedClassCount, int fudgedFeatureCount, LongUnaryOperator trainSetSize, LongUnaryOperator testSetSize) {
        List foldEstimations = pipeline.trainingParameterSpace().values().stream().flatMap(Collection::stream).flatMap(TunableTrainerConfig::streamCornerCaseConfigs).map(config -> MemoryEstimations.setup((String)"max of training and evaluation", dim -> {
            MemoryEstimation training = ClassifierTrainerFactory.memoryEstimation((TrainerConfig)config, (LongUnaryOperator)trainSetSize, (int)((int)Math.min((long)fudgedClassCount, dim.nodeCount())), (MemoryRange)MemoryRange.of((long)fudgedFeatureCount), (boolean)false);
            int batchSize = config instanceof LogisticRegressionTrainConfig ? ((LogisticRegressionTrainConfig)config).batchSize() : 0;
            MemoryEstimation evaluation = ClassificationMetricComputer.estimateEvaluation((TrainerConfig)config, (int)((int)Math.min((long)batchSize, dim.nodeCount())), (LongUnaryOperator)trainSetSize, (LongUnaryOperator)testSetSize, (int)((int)Math.min((long)fudgedClassCount, dim.nodeCount())), (int)fudgedFeatureCount, (boolean)false);
            return MemoryEstimations.maxEstimation(List.of(training, evaluation));
        })).collect(Collectors.toList());
        return MemoryEstimations.builder((String)"model selection").max(foldEstimations).build();
    }

    public static NodeClassificationTrain create(GraphStore graphStore, NodeClassificationTrainingPipeline pipeline, NodeClassificationPipelineTrainConfig config, NodeFeatureProducer<NodeClassificationPipelineTrainConfig> nodeFeatureProducer, ProgressTracker progressTracker) {
        Graph nodesGraph = graphStore.getGraph(config.targetNodeLabelIdentifiers(graphStore));
        pipeline.splitConfig().validateMinNumNodesInSplitSets(nodesGraph);
        NodePropertyValues targetNodeProperty = nodesGraph.nodeProperties(config.targetProperty());
        LabelsAndClassCountsExtractor.LabelsAndClassCounts labelsAndClassCounts = LabelsAndClassCountsExtractor.extractLabelsAndClassCounts(targetNodeProperty, nodesGraph.nodeCount());
        LongMultiSet classCounts = labelsAndClassCounts.classCounts();
        LocalIdMap classIdMap = LocalIdMap.ofSorted((long[])classCounts.keys());
        List<Metric> metrics = config.metrics(classIdMap, classCounts);
        return new NodeClassificationTrain(pipeline, config, labelsAndClassCounts.labels(), classIdMap, (IdMap)nodesGraph, metrics, NodeClassificationPipelineTrainConfig.classificationMetrics(metrics), classCounts, nodeFeatureProducer, progressTracker);
    }

    private NodeClassificationTrain(NodeClassificationTrainingPipeline pipeline, NodeClassificationPipelineTrainConfig config, HugeIntArray labels, LocalIdMap classIdMap, IdMap nodeIdMap, List<Metric> metrics, List<ClassificationMetric> classificationMetrics, LongMultiSet classCounts, NodeFeatureProducer<NodeClassificationPipelineTrainConfig> nodeFeatureProducer, ProgressTracker progressTracker) {
        this.pipeline = pipeline;
        this.nodeIdMap = nodeIdMap;
        this.classificationMetrics = classificationMetrics;
        this.nodeFeatureProducer = nodeFeatureProducer;
        this.trainConfig = config;
        this.targets = labels;
        this.classIdMap = classIdMap;
        this.metrics = metrics;
        this.classCounts = classCounts;
        this.progressTracker = progressTracker;
    }

    @Override
    public void setTerminationFlag(TerminationFlag terminationFlag) {
        this.terminationFlag = terminationFlag;
    }

    @Override
    public NodeClassificationTrainResult run() {
        this.progressTracker.beginSubTask();
        NodePropertyPredictionSplitConfig splitConfig = this.pipeline.splitConfig();
        NodeSplitter.NodeSplits nodeSplits = new NodeSplitter(this.trainConfig.concurrency(), this.nodeIdMap.nodeCount(), this.progressTracker, arg_0 -> ((IdMap)this.nodeIdMap).toOriginalNodeId(arg_0), arg_0 -> ((IdMap)this.nodeIdMap).toMappedNodeId(arg_0)).split(splitConfig.testFraction(), splitConfig.validationFolds(), this.trainConfig.randomSeed());
        TrainingStatistics trainingStatistics = new TrainingStatistics(this.metrics);
        Features features = this.nodeFeatureProducer.procedureFeatures(this.pipeline);
        this.findBestModelCandidate(nodeSplits.outerSplit().trainSet(), features, trainingStatistics);
        this.evaluateBestModel(nodeSplits.outerSplit(), features, trainingStatistics);
        Classifier retrainedModelData = this.retrainBestModel(nodeSplits.allTrainingExamples(), features, trainingStatistics.bestParameters());
        this.progressTracker.endSubTask();
        return ImmutableNodeClassificationTrainResult.of(retrainedModelData, trainingStatistics, this.classIdMap, this.classCounts);
    }

    private void findBestModelCandidate(ReadOnlyHugeLongArray trainNodeIds, Features features, TrainingStatistics trainingStatistics) {
        CrossValidation crossValidation = new CrossValidation(this.progressTracker, this.terminationFlag, this.metrics, this.pipeline.splitConfig().validationFolds(), this.trainConfig.randomSeed(), (trainSet, config, metricsHandler, messageLogLevel) -> this.trainModel(trainSet, config, features, messageLogLevel, metricsHandler), (evaluationSet, classifier, scoreConsumer) -> this.registerMetricScores(evaluationSet, (Classifier)classifier, features, scoreConsumer, ProgressTracker.NULL_TRACKER));
        RandomSearch modelCandidates = new RandomSearch(this.pipeline.trainingParameterSpace(), this.pipeline.autoTuningConfig().maxTrials(), this.trainConfig.randomSeed());
        TreeSet sortedClassIds = LongStream.range(0L, this.classCounts.size()).boxed().collect(Collectors.toCollection(TreeSet::new));
        crossValidation.selectModel(trainNodeIds, arg_0 -> ((HugeIntArray)this.targets).get(arg_0), (SortedSet)sortedClassIds, trainingStatistics, (Iterator)modelCandidates);
    }

    private void registerMetricScores(ReadOnlyHugeLongArray evaluationSet, Classifier classifier, Features features, MetricConsumer scoreConsumer, ProgressTracker customProgressTracker) {
        ClassificationMetricComputer trainMetricComputer = ClassificationMetricComputer.forEvaluationSet((Features)features, (HugeIntArray)this.targets, (ReadOnlyHugeLongArray)evaluationSet, (Classifier)classifier, (int)this.trainConfig.concurrency(), (TerminationFlag)this.terminationFlag, (ProgressTracker)customProgressTracker);
        this.classificationMetrics.forEach(metric -> scoreConsumer.consume((Metric)metric, trainMetricComputer.score(metric)));
    }

    private void evaluateBestModel(TrainingExamplesSplit outerSplit, Features features, TrainingStatistics trainingStatistics) {
        this.progressTracker.beginSubTask("Train best model");
        ModelCandidateStats bestCandidate = trainingStatistics.bestCandidate();
        Classifier bestClassifier = this.trainModel(outerSplit.trainSet(), bestCandidate.trainerConfig(), features, LogLevel.INFO, ModelSpecificMetricsHandler.of(this.metrics, (arg_0, arg_1) -> ((TrainingStatistics)trainingStatistics).addTestScore(arg_0, arg_1)));
        this.progressTracker.endSubTask("Train best model");
        this.progressTracker.beginSubTask("Evaluate on train data");
        this.progressTracker.setSteps(outerSplit.trainSet().size());
        this.registerMetricScores(outerSplit.trainSet(), bestClassifier, features, (arg_0, arg_1) -> ((TrainingStatistics)trainingStatistics).addOuterTrainScore(arg_0, arg_1), this.progressTracker);
        Map outerTrainMetrics = trainingStatistics.winningModelOuterTrainMetrics();
        this.progressTracker.logInfo(StringFormatting.formatWithLocale((String)"Final model metrics on full train set: %s", (Object[])new Object[]{outerTrainMetrics}));
        this.progressTracker.endSubTask("Evaluate on train data");
        this.progressTracker.beginSubTask("Evaluate on test data");
        this.progressTracker.setSteps(outerSplit.testSet().size());
        this.registerMetricScores(outerSplit.testSet(), bestClassifier, features, (arg_0, arg_1) -> ((TrainingStatistics)trainingStatistics).addTestScore(arg_0, arg_1), this.progressTracker);
        Map testMetrics = trainingStatistics.winningModelTestMetrics();
        this.progressTracker.logInfo(StringFormatting.formatWithLocale((String)"Final model metrics on test set: %s", (Object[])new Object[]{testMetrics}));
        this.progressTracker.endSubTask("Evaluate on test data");
    }

    private Classifier retrainBestModel(ReadOnlyHugeLongArray trainSet, Features features, TrainerConfig bestParameters) {
        this.progressTracker.beginSubTask("Retrain best model");
        Classifier retrainedClassifier = this.trainModel(trainSet, bestParameters, features, LogLevel.INFO, ModelSpecificMetricsHandler.NOOP);
        this.progressTracker.endSubTask("Retrain best model");
        return retrainedClassifier;
    }

    private Classifier trainModel(ReadOnlyHugeLongArray trainSet, TrainerConfig trainerConfig, Features features, LogLevel messageLogLevel, ModelSpecificMetricsHandler metricsHandler) {
        ClassifierTrainer trainer = ClassifierTrainerFactory.create((TrainerConfig)trainerConfig, (int)this.classIdMap.size(), (TerminationFlag)this.terminationFlag, (ProgressTracker)this.progressTracker, (LogLevel)messageLogLevel, (int)this.trainConfig.concurrency(), (Optional)this.trainConfig.randomSeed(), (boolean)false, (ModelSpecificMetricsHandler)metricsHandler);
        return trainer.train(features, this.targets, trainSet);
    }
}

