/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.ReadOnlyHugeLongIdentityArray;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.ImmutableEvaluationScores;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.ModelStatsBuilder;
import org.neo4j.gds.ml.metrics.SignedProbabilities;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.ClassifierTrainerFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.models.automl.RandomSearch;
import org.neo4j.gds.ml.models.automl.TunableTrainerConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.FeaturesAndLabels;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.ImmutableLinkPredictionTrainResult;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkFeaturesAndLabelsExtractor;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainResult;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.training.CrossValidation;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Training phase of a link prediction pipeline.
 *
 * <p>{@link #compute()} extracts link features and labels from the train graph, selects the best
 * model candidate through cross-validation, retrains that candidate on the full train set, and
 * records metrics both on the train set and on the data extracted from the validation graph.
 * {@link #progressTasks} and {@link #estimate} describe the matching progress tasks and memory
 * requirements.
 */
public final class LinkPredictionTrain {
    private final Graph trainGraph;
    private final Graph validationGraph;
    private final LinkPredictionTrainingPipeline pipeline;
    private final LinkPredictionTrainConfig config;
    private final LocalIdMap classIdMap;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /**
     * Builds the id map holding the two class ids {@code 0} and {@code 1}.
     * NOTE(review): presumably 0 = absent link and 1 = present link; confirm against the label extractor.
     */
    private static LocalIdMap makeClassIdMap() {
        return LocalIdMap.of(new long[]{0L, 1L});
    }

    /**
     * @param trainGraph      graph the training features and labels are extracted from
     * @param validationGraph graph the test features and labels are extracted from
     * @param pipeline        supplies feature steps, split config, parameter space and auto-tuning config
     * @param config          supplies metrics, concurrency, random seed and negative class weight
     * @param progressTracker receives sub-task begin/end notifications and info log lines
     * @param terminationFlag forwarded to feature extraction, training and metric computation
     */
    public LinkPredictionTrain(Graph trainGraph, Graph validationGraph, LinkPredictionTrainingPipeline pipeline, LinkPredictionTrainConfig config, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        this.trainGraph = trainGraph;
        this.validationGraph = validationGraph;
        this.pipeline = pipeline;
        this.config = config;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
        this.classIdMap = makeClassIdMap();
    }

    /**
     * Lists the progress tasks that {@link #compute()} opens and closes, in execution order.
     * The task descriptions must stay in sync with the names used in {@link #compute()}.
     *
     * @param relationshipCount            relationship count used to derive the expected train/test set sizes
     * @param splitConfig                  provides the expected set sizes and the number of validation folds
     * @param numberOfModelSelectionTrials number of model candidates tried during cross-validation
     * @return a new mutable list of tasks
     */
    public static List<Task> progressTasks(long relationshipCount, LinkPredictionSplitConfig splitConfig, int numberOfModelSelectionTrials) {
        ExpectedSetSizes sizes = splitConfig.expectedSetSizes(relationshipCount);
        long trainSize = sizes.trainSize();
        long testSize = sizes.testSize();

        List<Task> tasks = new ArrayList<>();
        tasks.add(Tasks.leaf("Extract train features", trainSize * 3L));
        tasks.addAll(CrossValidation.progressTasks(splitConfig.validationFolds(), numberOfModelSelectionTrials, trainSize));
        tasks.add(ClassifierTrainer.progressTask("Train best model", trainSize * 5L));
        tasks.add(Tasks.leaf("Compute train metrics", trainSize));
        tasks.add(Tasks.task(
            "Evaluate on test data",
            Tasks.leaf("Extract test features", testSize * 3L),
            Tasks.leaf("Compute test metrics", testSize)
        ));
        return tasks;
    }

    /**
     * Runs the full training procedure and returns the trained classifier together with the
     * statistics collected along the way.
     *
     * @return the classifier trained on the complete train set plus its training statistics
     */
    public LinkPredictionTrainResult compute() {
        progressTracker.beginSubTask("Extract train features");
        FeaturesAndLabels trainData = LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels(
            trainGraph,
            pipeline.featureSteps(),
            config.concurrency(),
            progressTracker,
            terminationFlag
        );
        // All extracted rows belong to the train set, so an identity array over trainData.size() is used as the id set.
        ReadOnlyHugeLongArray trainRelationshipIds = new ReadOnlyHugeLongIdentityArray(trainData.size());
        progressTracker.endSubTask("Extract train features");

        TrainingStatistics trainingStatistics = new TrainingStatistics(config.metrics());
        findBestModelCandidate(trainData, trainRelationshipIds, trainingStatistics);

        // Retrain the winning candidate on the complete train set.
        progressTracker.beginSubTask("Train best model");
        Classifier classifier = trainModel(
            trainData,
            trainRelationshipIds,
            trainingStatistics.bestParameters(),
            LogLevel.INFO,
            ModelSpecificMetricsHandler.of(config.metrics(), trainingStatistics::addTestScore)
        );
        progressTracker.endSubTask("Train best model");

        progressTracker.beginSubTask("Compute train metrics");
        computeTrainMetric(trainData, classifier, trainRelationshipIds, trainingStatistics::addOuterTrainScore, progressTracker);
        progressTracker.endSubTask("Compute train metrics");
        Map<?, ?> outerTrainMetrics = trainingStatistics.winningModelOuterTrainMetrics();
        progressTracker.logInfo(StringFormatting.formatWithLocale("Final model metrics on full train set: %s", outerTrainMetrics));

        progressTracker.beginSubTask("Evaluate on test data");
        computeTestMetric(classifier, trainingStatistics);
        progressTracker.endSubTask("Evaluate on test data");
        Map<?, ?> testMetrics = trainingStatistics.winningModelTestMetrics();
        progressTracker.logInfo(StringFormatting.formatWithLocale("Final model metrics on test set: %s", testMetrics));

        return ImmutableLinkPredictionTrainResult.of(classifier, trainingStatistics);
    }

    /**
     * Runs cross-validation over the candidates produced by a {@link RandomSearch} on the
     * pipeline's parameter space and records the outcome in {@code trainingStatistics}.
     */
    private void findBestModelCandidate(FeaturesAndLabels trainData, ReadOnlyHugeLongArray trainRelationshipIds, TrainingStatistics trainingStatistics) {
        RandomSearch modelCandidates = new RandomSearch(
            pipeline.trainingParameterSpace(),
            pipeline.autoTuningConfig().maxTrials(),
            config.randomSeed()
        );

        CrossValidation<Classifier> crossValidation = new CrossValidation<>(
            progressTracker,
            terminationFlag,
            config.metrics(),
            pipeline.splitConfig().validationFolds(),
            config.randomSeed(),
            (trainSet, modelParameters, metricsHandler, messageLogLevel) ->
                trainModel(trainData, trainSet, modelParameters, messageLogLevel, metricsHandler),
            // Per-fold evaluation reports no progress of its own, hence the null tracker.
            (evaluationSet, classifier, scoreConsumer) ->
                computeTrainMetric(trainData, classifier, evaluationSet, scoreConsumer, ProgressTracker.NULL_TRACKER)
        );

        crossValidation.selectModel(
            trainRelationshipIds,
            relationshipIdx -> trainData.labels().get(relationshipIdx),
            new TreeSet<>(classIdMap.originalIdsList()),
            trainingStatistics,
            modelCandidates
        );
    }

    /**
     * Creates a trainer for {@code trainerConfig} and trains it on the rows of
     * {@code featureAndLabels} selected by {@code trainSet}.
     */
    @NotNull
    private Classifier trainModel(FeaturesAndLabels featureAndLabels, ReadOnlyHugeLongArray trainSet, TrainerConfig trainerConfig, LogLevel messageLogLevel, ModelSpecificMetricsHandler metricsHandler) {
        // NOTE(review): the positional `true` is interpreted by ClassifierTrainerFactory.create; confirm its meaning there.
        ClassifierTrainer trainer = ClassifierTrainerFactory.create(
            trainerConfig,
            classIdMap.size(),
            terminationFlag,
            progressTracker,
            messageLogLevel,
            config.concurrency(),
            config.randomSeed(),
            true,
            metricsHandler
        );
        return trainer.train(featureAndLabels.features(), featureAndLabels.labels(), trainSet);
    }

    /**
     * Extracts features and labels from the validation graph, scores them with {@code classifier}
     * and stores every configured link metric as a test score in {@code trainingStatistics}.
     */
    private void computeTestMetric(Classifier classifier, TrainingStatistics trainingStatistics) {
        progressTracker.beginSubTask("Extract test features");
        FeaturesAndLabels testData = LinkFeaturesAndLabelsExtractor.extractFeaturesAndLabels(
            validationGraph,
            pipeline.featureSteps(),
            config.concurrency(),
            progressTracker,
            terminationFlag
        );
        progressTracker.endSubTask("Extract test features");

        progressTracker.beginSubTask("Compute test metrics");
        SignedProbabilities signedProbabilities = SignedProbabilities.computeFromLabeledData(
            testData.features(),
            testData.labels(),
            classifier,
            BatchQueue.consecutive(testData.size()),
            config.concurrency(),
            terminationFlag,
            progressTracker
        );
        config.linkMetrics().forEach(metric -> {
            double score = metric.compute(signedProbabilities, config.negativeClassWeight());
            trainingStatistics.addTestScore(metric, score);
        });
        progressTracker.endSubTask("Compute test metrics");
    }

    /**
     * Scores the rows of {@code trainData} listed in {@code evaluationSet} with {@code classifier}
     * and hands every configured link metric to {@code metricConsumer}.
     *
     * @param progressTracker tracker to report to; intentionally a parameter (shadowing the field)
     *                        so callers can pass a different tracker
     */
    private void computeTrainMetric(FeaturesAndLabels trainData, Classifier classifier, ReadOnlyHugeLongArray evaluationSet, MetricConsumer metricConsumer, ProgressTracker progressTracker) {
        SignedProbabilities signedProbabilities = SignedProbabilities.computeFromLabeledData(
            trainData.features(),
            trainData.labels(),
            classifier,
            BatchQueue.fromArray(evaluationSet),
            config.concurrency(),
            terminationFlag,
            progressTracker
        );
        config.linkMetrics().forEach(metric ->
            metricConsumer.consume(metric, metric.compute(signedProbabilities, config.negativeClassWeight()))
        );
    }

    /**
     * Builds the memory estimation for a complete training run.
     *
     * <p>The train and test feature/label estimates are combined with {@code max} rather than
     * summed. A fixed range of 10 to 500 stands in for the link feature dimension.
     *
     * @param pipeline    supplies the split config and the training parameter space
     * @param trainConfig supplies the number of link metrics
     */
    public static MemoryEstimation estimate(LinkPredictionTrainingPipeline pipeline, LinkPredictionTrainConfig trainConfig) {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();
        MemoryEstimations.Builder builder = MemoryEstimations.builder(LinkPredictionTrain.class.getSimpleName());
        MemoryRange fudgedLinkFeatureDim = MemoryRange.of(10L, 500L);
        int numberOfMetrics = trainConfig.linkMetrics().size();

        return builder
            .max("Features and labels", List.of(
                LinkFeaturesAndLabelsExtractor.estimate(
                    fudgedLinkFeatureDim,
                    relCounts -> relCounts.get(splitConfig.trainRelationshipType()),
                    "Train"
                ),
                LinkFeaturesAndLabelsExtractor.estimate(
                    fudgedLinkFeatureDim,
                    relCounts -> relCounts.get(splitConfig.testRelationshipType()),
                    "Test"
                )
            ))
            .add(estimateTrainingAndEvaluation(pipeline, fudgedLinkFeatureDim, numberOfMetrics))
            .add("Outer train stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, 1, 1))
            .add("Test stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, 1, 1))
            .fixed(
                "Best model stats",
                MemoryRange.of(MemoryUsage.sizeOfInstance(ImmutableEvaluationScores.class))
                    .times(2L)
                    .add(16L)
                    .times(numberOfMetrics)
            )
            .build();
    }

    /**
     * Estimates model selection: cross-validation splitting, the most expensive model candidate
     * (maximum over the corner-case configs of every tunable config in the parameter space), and
     * the inner-train and validation stats maps.
     */
    private static MemoryEstimation estimateTrainingAndEvaluation(LinkPredictionTrainingPipeline pipeline, MemoryRange linkFeatureDimension, int numberOfMetrics) {
        LinkPredictionSplitConfig splitConfig = pipeline.splitConfig();

        List<MemoryEstimation> candidateEstimations = pipeline.trainingParameterSpace()
            .values()
            .stream()
            .flatMap(Collection::stream)
            .flatMap(TunableTrainerConfig::streamCornerCaseConfigs)
            .map(trainerConfig -> estimateModelCandidate(splitConfig, trainerConfig, linkFeatureDimension, numberOfMetrics))
            .collect(Collectors.toList());
        MemoryEstimation maxEstimationOverModelCandidates = MemoryEstimations.maxEstimation(
            "Max over model candidates",
            candidateEstimations
        );

        int numberOfTrials = pipeline.numberOfModelSelectionTrials();
        return MemoryEstimations.builder("model selection")
            .add(
                "Cross-Validation splitting",
                StratifiedKFoldSplitter.memoryEstimation(
                    splitConfig.validationFolds(),
                    dim -> dim.relationshipCounts().get(splitConfig.trainRelationshipType())
                )
            )
            .add(maxEstimationOverModelCandidates)
            .add("Inner train stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, numberOfTrials, 1))
            .add("Validation stats map", TrainingStatistics.memoryEstimationStatsMap(numberOfMetrics, numberOfTrials, 1))
            .build();
    }

    /**
     * Estimates training and evaluating a single model candidate: two stats-map builders plus the
     * maximum of training the model and computing its train metrics.
     */
    private static MemoryEstimation estimateModelCandidate(LinkPredictionSplitConfig splitConfig, TrainerConfig trainerConfig, MemoryRange linkFeatureDimension, int numberOfMetrics) {
        return MemoryEstimations.builder("Train and evaluate model")
            .fixed("Stats map builder train", ModelStatsBuilder.sizeInBytes(numberOfMetrics))
            .fixed("Stats map builder validation", ModelStatsBuilder.sizeInBytes(numberOfMetrics))
            .max("Train model and compute train metrics", List.of(
                estimateTraining(splitConfig, trainerConfig, linkFeatureDimension),
                estimateComputeTrainMetrics(splitConfig)
            ))
            .build();
    }

    /**
     * Estimates training one classifier on the train relationship set, using a class count of 2
     * and the given link feature dimension.
     */
    private static MemoryEstimation estimateTraining(LinkPredictionSplitConfig splitConfig, TrainerConfig trainerConfig, MemoryRange linkFeatureDimension) {
        // NOTE(review): the positional `true` is interpreted by ClassifierTrainerFactory.memoryEstimation; confirm its meaning there.
        return MemoryEstimations.setup("Training", dim -> ClassifierTrainerFactory.memoryEstimation(
            trainerConfig,
            unused -> dim.relationshipCounts().get(splitConfig.trainRelationshipType()),
            2,
            linkFeatureDimension,
            true
        ));
    }

    /**
     * Estimates the signed probabilities held while computing metrics over the train
     * relationship set.
     */
    private static MemoryEstimation estimateComputeTrainMetrics(LinkPredictionSplitConfig splitConfig) {
        return MemoryEstimations.builder("Compute train metrics")
            .perGraphDimension("Sorted probabilities", (dim, threads) -> {
                long trainSetSize = dim.relationshipCounts().get(splitConfig.trainRelationshipType());
                return MemoryRange.of(SignedProbabilities.estimateMemory(trainSetSize));
            })
            .build();
    }
}

