/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline.train;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.model.CatalogModelContainer;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.pipeline.ImmutablePipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.NodePropertyStepExecutor;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.PipelineGraphFilter;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionTrainingPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.ImmutableLinkPredictionTrainPipelineResult;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionRelationshipSampler;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrain;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainResult;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.ml.util.TrainingSetWarnings;
import org.neo4j.gds.model.ModelConfig;

public class LinkPredictionTrainPipelineExecutor
extends PipelineExecutor<LinkPredictionTrainConfig, LinkPredictionTrainingPipeline, LinkPredictionTrainPipelineResult> {
    private final LinkPredictionRelationshipSampler linkPredictionRelationshipSampler;
    private final Set<RelationshipType> availableRelationshipTypesForNodeProperty;

    public LinkPredictionTrainPipelineExecutor(LinkPredictionTrainingPipeline pipeline, LinkPredictionTrainConfig config, ExecutionContext executionContext, GraphStore graphStore, ProgressTracker progressTracker) {
        super(pipeline, config, executionContext, graphStore, progressTracker);
        this.availableRelationshipTypesForNodeProperty = graphStore.relationshipTypes().stream().filter(relType -> !relType.name.equals(config.targetRelationshipType())).collect(Collectors.toSet());
        this.linkPredictionRelationshipSampler = new LinkPredictionRelationshipSampler(graphStore, pipeline.splitConfig(), config, progressTracker, this.terminationFlag);
    }

    public static Task progressTask(String taskName, final LinkPredictionTrainingPipeline pipeline, final long relationshipCount) {
        final ExpectedSetSizes sizes = pipeline.splitConfig().expectedSetSizes(relationshipCount);
        return Tasks.task((String)taskName, (List)new ArrayList<Task>(){
            {
                this.add(LinkPredictionRelationshipSampler.progressTask(sizes));
                this.add(NodePropertyStepExecutor.tasks(pipeline.nodePropertySteps(), sizes.featureInputSize()));
                this.addAll(LinkPredictionTrain.progressTasks(relationshipCount, pipeline.splitConfig(), pipeline.numberOfModelSelectionTrials()));
            }
        });
    }

    public static MemoryEstimation estimate(ExecutionContext executionContext, LinkPredictionTrainingPipeline pipeline, LinkPredictionTrainConfig configuration) {
        pipeline.validateTrainingParameterSpace();
        MemoryEstimation splitEstimations = LinkPredictionRelationshipSampler.splitEstimation(pipeline.splitConfig(), configuration.targetRelationshipType(), pipeline.relationshipWeightProperty(executionContext));
        MemoryEstimation maxOverNodePropertySteps = NodePropertyStepExecutor.estimateNodePropertySteps(executionContext.modelCatalog(), configuration.username(), pipeline.nodePropertySteps(), configuration.nodeLabels(), List.of(pipeline.splitConfig().featureInputRelationshipType().name));
        MemoryEstimation trainingEstimation = MemoryEstimations.builder().add("Train pipeline", LinkPredictionTrain.estimate(pipeline, configuration)).build();
        return MemoryEstimations.builder((String)LinkPredictionTrainPipelineExecutor.class.getSimpleName()).max("Pipeline execution", List.of(splitEstimations, maxOverNodePropertySteps, trainingEstimation)).build();
    }

    @Override
    public Map<PipelineExecutor.DatasetSplits, PipelineGraphFilter> generateDatasetSplitGraphFilters() {
        LinkPredictionSplitConfig splitConfig = ((LinkPredictionTrainingPipeline)this.pipeline).splitConfig();
        return Map.of(PipelineExecutor.DatasetSplits.TRAIN, ImmutablePipelineGraphFilter.builder().nodeLabels(((LinkPredictionTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore)).relationshipTypes(List.of(splitConfig.trainRelationshipType())).build(), PipelineExecutor.DatasetSplits.TEST, ImmutablePipelineGraphFilter.builder().nodeLabels(((LinkPredictionTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore)).relationshipTypes(List.of(splitConfig.testRelationshipType())).build(), PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutablePipelineGraphFilter.builder().nodeLabels(((LinkPredictionTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore)).relationshipTypes(List.of(splitConfig.featureInputRelationshipType())).build());
    }

    @Override
    public void splitDatasets() {
        this.linkPredictionRelationshipSampler.splitAndSampleRelationships(((LinkPredictionTrainingPipeline)this.pipeline).relationshipWeightProperty(this.executionContext));
    }

    @Override
    protected LinkPredictionTrainPipelineResult execute(Map<PipelineExecutor.DatasetSplits, PipelineGraphFilter> dataSplits) {
        ((LinkPredictionTrainingPipeline)this.pipeline).validateTrainingParameterSpace();
        PipelineGraphFilter trainDataSplit = dataSplits.get((Object)PipelineExecutor.DatasetSplits.TRAIN);
        PipelineGraphFilter testDataSplit = dataSplits.get((Object)PipelineExecutor.DatasetSplits.TEST);
        Graph trainGraph = this.graphStore.getGraph(trainDataSplit.nodeLabels(), trainDataSplit.relationshipTypes(), Optional.of("label"));
        Graph testGraph = this.graphStore.getGraph(testDataSplit.nodeLabels(), testDataSplit.relationshipTypes(), Optional.of("label"));
        TrainingSetWarnings.warnForSmallRelationshipSets((long)trainGraph.relationshipCount(), (long)testGraph.relationshipCount(), (long)((LinkPredictionTrainingPipeline)this.pipeline).splitConfig().validationFolds(), (ProgressTracker)this.progressTracker);
        LinkPredictionTrainResult trainResult = new LinkPredictionTrain(trainGraph, testGraph, (LinkPredictionTrainingPipeline)this.pipeline, (LinkPredictionTrainConfig)this.config, this.progressTracker, this.terminationFlag).compute();
        Model model = Model.of((String)"LinkPrediction", (GraphSchema)this.schemaBeforeSteps, (Object)trainResult.classifier().data(), (ModelConfig)((LinkPredictionTrainConfig)this.config), (ToMapConvertible)LinkPredictionModelInfo.of(trainResult.trainingStatistics().winningModelTestMetrics(), trainResult.trainingStatistics().winningModelOuterTrainMetrics(), trainResult.trainingStatistics().bestCandidate(), LinkPredictionPredictPipeline.from(this.pipeline)));
        return ImmutableLinkPredictionTrainPipelineResult.of((Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo>)model, trainResult.trainingStatistics());
    }

    @Override
    protected Set<RelationshipType> getAvailableRelTypesForNodePropertySteps() {
        return this.availableRelationshipTypesForNodeProperty;
    }

    private void removeDataSplitRelationships(Map<PipelineExecutor.DatasetSplits, PipelineGraphFilter> datasets) {
        datasets.values().stream().flatMap(graphFilter -> graphFilter.relationshipTypes().stream()).distinct().collect(Collectors.toList()).forEach(arg_0 -> ((GraphStore)this.graphStore).deleteRelationships(arg_0));
    }

    @Override
    protected void additionalGraphStoreCleanup(Map<PipelineExecutor.DatasetSplits, PipelineGraphFilter> datasets) {
        this.removeDataSplitRelationships(datasets);
        super.additionalGraphStoreCleanup(datasets);
    }

    @ValueClass
    public static interface LinkPredictionTrainPipelineResult
    extends CatalogModelContainer<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPredictionModelInfo> {
        public TrainingStatistics trainingStatistics();
    }
}

