/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.BatchSampler;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.GraphSageLoss;
import org.neo4j.gds.embeddings.graphsage.ImmutableEpochResult;
import org.neo4j.gds.embeddings.graphsage.ImmutableGraphSageTrainMetrics;
import org.neo4j.gds.embeddings.graphsage.ImmutableModelTrainResult;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.LayerFactory;
import org.neo4j.gds.embeddings.graphsage.SingleLabelFeatureFunction;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains the layers of a GraphSage model on a graph.
 *
 * <p>Training is organised as epochs, each consisting of a bounded number of iterations.
 * In every iteration a set of {@link BatchTask}s is sampled and executed through
 * {@link RunWithConcurrency}; their losses are averaged for the convergence check and their
 * weight gradients are averaged and handed to an {@link AdamOptimizer}.
 */
public class GraphSageModelTrainer {
    private final long randomSeed;
    private final FeatureFunction featureFunction;
    private final Collection<Weights<Matrix>> labelProjectionWeights;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;
    private final GraphSageTrainConfig config;

    /**
     * Creates a trainer that uses a {@link SingleLabelFeatureFunction} and no label projection weights.
     *
     * @param config          training configuration
     * @param executor        executor the batch tasks are run on
     * @param progressTracker tracker that receives sub-task and progress events
     */
    public GraphSageModelTrainer(GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
        this(config, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList());
    }

    /**
     * Creates a trainer with an explicit feature function and label projection weights.
     *
     * @param config                 training configuration
     * @param executor               executor the batch tasks are run on
     * @param progressTracker        tracker that receives sub-task and progress events
     * @param featureFunction        function producing the feature variable for a batch
     * @param labelProjectionWeights additional weights that are trained together with the layer weights
     */
    public GraphSageModelTrainer(
        GraphSageTrainConfig config,
        ExecutorService executor,
        ProgressTracker progressTracker,
        FeatureFunction featureFunction,
        Collection<Weights<Matrix>> labelProjectionWeights
    ) {
        this.config = config;
        this.featureFunction = featureFunction;
        this.labelProjectionWeights = labelProjectionWeights;
        this.executor = executor;
        this.progressTracker = progressTracker;
        // Without a configured seed, pick a random one once so that all derived seeds stay consistent.
        this.randomSeed = config.randomSeed().orElseGet(() -> ThreadLocalRandom.current().nextLong());
    }

    /**
     * Describes the progress task hierarchy that {@link #train} reports against.
     *
     * @param config training configuration providing the epoch and iteration bounds
     * @return a "Prepare batches" leaf followed by the nested "Train model" / "Epoch" / "Iteration" tasks
     */
    public static List<Task> progressTasks(GraphSageTrainConfig config) {
        return List.of(
            Tasks.leaf("Prepare batches"),
            Tasks.iterativeDynamic(
                "Train model",
                () -> List.of(Tasks.iterativeDynamic(
                    "Epoch",
                    () -> List.of(Tasks.leaf("Iteration")),
                    config.maxIterations()
                )),
                config.epochs()
            )
        );
    }

    /**
     * Trains the model layers on the given graph.
     *
     * @param graph    graph to train on
     * @param features node features, passed on to the feature function
     * @return the trained layers together with the iteration losses of every epoch and the convergence flag
     */
    public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
        Layer[] layers = this.config
            .layerConfigs(firstLayerColumns(this.config, graph))
            .stream()
            .map(LayerFactory::createLayer)
            .toArray(Layer[]::new);

        assert (graph.hasRelationshipProperty() == this.config.hasRelationshipWeightProperty()) : "Weight property of graph and config needs to match.";

        // All trainable weights: the label projection weights followed by the weights of every layer.
        List<Weights<? extends Tensor<?>>> weights = new ArrayList<Weights<? extends Tensor<?>>>(this.labelProjectionWeights);
        for (Layer layer : layers) {
            weights.addAll(layer.weights());
        }

        this.progressTracker.beginSubTask("Prepare batches");
        BatchSampler batchSampler = new BatchSampler(graph);
        List<long[]> extendedBatches = batchSampler.extendedBatches(
            this.config.batchSize(),
            this.config.searchDepth(),
            this.randomSeed
        );
        SplittableRandom random = new SplittableRandom(this.randomSeed);
        this.progressTracker.endSubTask("Prepare batches");

        this.progressTracker.beginSubTask("Train model");
        boolean converged = false;
        List<List<Double>> iterationLossesPerEpoch = new ArrayList<>();
        double prevEpochLoss = Double.NaN;
        int epochs = this.config.epochs();
        int batchesPerIteration = this.config.batchesPerIteration(graph.nodeCount());

        // If more batches are sampled per epoch than distinct batches exist, the tasks are built once per
        // epoch and reused; otherwise they are built on demand. The product is computed as long so that
        // large settings cannot overflow and flip the decision.
        boolean createBatchTasksEagerly = (long) batchesPerIteration * this.config.maxIterations() > extendedBatches.size();

        for (int epoch = 1; epoch <= epochs && !converged; ++epoch) {
            this.progressTracker.beginSubTask("Epoch");
            long epochLocalSeed = (long) epoch + this.randomSeed;

            Supplier<List<BatchTask>> batchTaskSampler;
            if (createBatchTasksEagerly) {
                List<BatchTask> tasksForEpoch = extendedBatches
                    .stream()
                    .map(extendedBatch -> this.createBatchTask(extendedBatch, graph, features, layers, weights, epochLocalSeed))
                    .collect(Collectors.toList());
                batchTaskSampler = () -> IntStream
                    .range(0, batchesPerIteration)
                    .mapToObj(ignored -> tasksForEpoch.get(random.nextInt(tasksForEpoch.size())))
                    .collect(Collectors.toList());
            } else {
                batchTaskSampler = () -> IntStream
                    .range(0, batchesPerIteration)
                    .mapToObj(ignored -> this.createBatchTask(
                        extendedBatches.get(random.nextInt(extendedBatches.size())),
                        graph,
                        features,
                        layers,
                        weights,
                        epochLocalSeed
                    ))
                    .collect(Collectors.toList());
            }

            EpochResult epochResult = this.trainEpoch(batchTaskSampler, weights, prevEpochLoss);
            List<Double> epochLosses = epochResult.losses();
            iterationLossesPerEpoch.add(epochLosses);
            // The loss of the last iteration is the reference for the first convergence check of the next epoch.
            prevEpochLoss = epochLosses.get(epochLosses.size() - 1);
            converged = epochResult.converged();
            this.progressTracker.endSubTask("Epoch");
        }
        this.progressTracker.endSubTask("Train model");

        return ModelTrainResult.of(iterationLossesPerEpoch, converged, layers);
    }

    /**
     * Builds the loss computation for one extended batch and wraps it into a runnable task.
     *
     * @param extendedBatch node ids of the extended batch
     * @param graph         graph to train on; a concurrent copy is used for the task
     * @param features      node features, passed on to the feature function
     * @param layers        model layers
     * @param weights       all trainable weights whose gradients the task collects
     * @param localSeed     seed used for sub-graph sampling
     * @return a task that computes the loss and the weight gradients for the batch
     */
    private BatchTask createBatchTask(
        long[] extendedBatch,
        Graph graph,
        HugeObjectArray<double[]> features,
        Layer[] layers,
        List<Weights<? extends Tensor<?>>> weights,
        long localSeed
    ) {
        Graph localGraph = graph.concurrentCopy();
        List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(localGraph, extendedBatch, layers, localSeed);
        Variable<Matrix> batchedFeaturesExtractor = this.featureFunction.apply(
            localGraph,
            subGraphs.get(subGraphs.size() - 1).originalNodeIds(),
            features
        );
        Variable<Matrix> embeddingVariable = GraphSageHelper.embeddingsComputationGraph(subGraphs, layers, batchedFeaturesExtractor);
        GraphSageLoss lossWithoutPenalty = new GraphSageLoss(
            SubGraph.relationshipWeightFunction(localGraph),
            embeddingVariable,
            extendedBatch,
            this.config.negativeSampleWeight()
        );

        // NOTE(review): presumably an extended batch consists of three equally sized parts, so a third of
        // its length is the size of the originating batch - confirm against BatchSampler.
        long originalBatchSize = extendedBatch.length / 3;

        Variable<Scalar> loss;
        if (this.config.penaltyL2() > 0.0) {
            // One squared L2 norm per non-bias weight of every layer's aggregator.
            List<Variable<?>> l2penalty = Arrays
                .stream(layers)
                .map(layer -> layer.aggregator().weightsWithoutBias())
                .flatMap(layerWeights -> layerWeights.stream().map(L2NormSquared::new))
                .collect(Collectors.toList());
            // The penalty is scaled by the fraction of the graph's nodes that the original batch covers.
            double penaltyScale = this.config.penaltyL2() * (double) originalBatchSize / (double) graph.nodeCount();
            loss = new ElementSum(List.of(
                lossWithoutPenalty,
                new ConstantScale<>(new ElementSum(l2penalty), penaltyScale)
            ));
        } else {
            loss = lossWithoutPenalty;
        }
        return new BatchTask(loss, weights, this.progressTracker);
    }

    /**
     * Runs the iterations of a single epoch.
     *
     * <p>An iteration converges when the absolute difference between its average loss and the previous
     * loss is below the configured tolerance; in that case the weights are not updated again and the
     * epoch ends.
     *
     * @param sampledBatchTaskSupplier supplies the batch tasks to run for each iteration
     * @param weights                  weights updated by the optimizer
     * @param prevEpochLoss            last loss of the previous epoch, {@code NaN} for the first epoch
     * @return the losses of all executed iterations and whether the epoch converged
     */
    private EpochResult trainEpoch(
        Supplier<List<BatchTask>> sampledBatchTaskSupplier,
        List<Weights<? extends Tensor<?>>> weights,
        double prevEpochLoss
    ) {
        AdamOptimizer updater = new AdamOptimizer(weights, this.config.learningRate());
        List<Double> iterationLosses = new ArrayList<>();
        double prevLoss = prevEpochLoss;
        boolean converged = false;
        int maxIterations = this.config.maxIterations();

        for (int iteration = 1; iteration <= maxIterations; ++iteration) {
            this.progressTracker.beginSubTask("Iteration");
            List<BatchTask> sampledBatchTasks = sampledBatchTaskSupplier.get();
            RunWithConcurrency.builder()
                .concurrency(this.config.concurrency())
                .tasks(sampledBatchTasks)
                .executor(this.executor)
                .run();

            double avgLossPerNode = sampledBatchTasks.stream().mapToDouble(BatchTask::loss).sum()
                / (double) sampledBatchTasks.size();
            iterationLosses.add(avgLossPerNode);
            this.progressTracker.logInfo(StringFormatting.formatWithLocale("Average loss per node: %.10f", avgLossPerNode));

            // A NaN previous loss (first iteration of the first epoch) never satisfies this comparison.
            if (Math.abs(prevLoss - avgLossPerNode) < this.config.tolerance()) {
                converged = true;
                this.progressTracker.endSubTask("Iteration");
                break;
            }
            prevLoss = avgLossPerNode;

            List<List<? extends Tensor<?>>> batchedGradients = sampledBatchTasks
                .stream()
                .map(BatchTask::weightGradients)
                .collect(Collectors.toList());
            updater.update(TensorFunctions.averageTensors(batchedGradients));
            this.progressTracker.endSubTask("Iteration");
        }
        return ImmutableEpochResult.of(converged, iterationLosses);
    }

    /**
     * Determines the number of input columns of the first layer: the configured projected feature
     * dimension if present, otherwise the feature count of the graph's feature extractors.
     */
    private static int firstLayerColumns(GraphSageTrainConfig config, Graph graph) {
        return config.projectedFeatureDimension().orElseGet(() -> {
            List<FeatureExtractor> featureExtractors = GraphSageHelper.featureExtractors(graph, config);
            return FeatureExtraction.featureCount(featureExtractors);
        });
    }

    /**
     * Result of a training run: the trained layers and the collected metrics.
     */
    @ValueClass
    public static interface ModelTrainResult {
        public GraphSageTrainMetrics metrics();

        public Layer[] layers();

        /**
         * Creates a result from the raw training outcome.
         *
         * @param iterationLossesPerEpoch iteration losses, one list per executed epoch
         * @param converged               whether training converged
         * @param layers                  the trained layers
         */
        public static ModelTrainResult of(List<List<Double>> iterationLossesPerEpoch, boolean converged, Layer[] layers) {
            return ImmutableModelTrainResult.builder()
                .layers(layers)
                .metrics(ImmutableGraphSageTrainMetrics.of(iterationLossesPerEpoch, converged))
                .build();
        }
    }

    /**
     * Metrics of a training run, derived from the iteration losses recorded per epoch.
     */
    @ValueClass
    public static interface GraphSageTrainMetrics
    extends ToMapConvertible {
        /** Returns metrics without any epochs and with {@code didConverge} set to {@code false}. */
        public static GraphSageTrainMetrics empty() {
            return ImmutableGraphSageTrainMetrics.of(List.of(), false);
        }

        /** Returns the loss of the last iteration of every epoch. */
        @Value.Derived
        default public List<Double> epochLosses() {
            return this.iterationLossPerEpoch()
                .stream()
                .map(iterationLosses -> iterationLosses.get(iterationLosses.size() - 1))
                .collect(Collectors.toList());
        }

        public List<List<Double>> iterationLossPerEpoch();

        public boolean didConverge();

        /** Returns the number of epochs that were executed. */
        @Value.Derived
        default public int ranEpochs() {
            return this.iterationLossPerEpoch().size();
        }

        /** Returns the number of iterations that were executed in every epoch. */
        @Value.Derived
        default public List<Integer> ranIterationsPerEpoch() {
            return this.iterationLossPerEpoch().stream().map(List::size).collect(Collectors.toList());
        }

        /** Returns all metrics nested under the single key {@code "metrics"}. */
        @Value.Auxiliary
        @Value.Derived
        default public Map<String, Object> toMap() {
            return Map.of(
                "metrics",
                Map.of(
                    "epochLosses", this.epochLosses(),
                    "iterationLossesPerEpoch", this.iterationLossPerEpoch(),
                    "didConverge", this.didConverge(),
                    "ranEpochs", this.ranEpochs(),
                    "ranIterationsPerEpoch", this.ranIterationsPerEpoch()
                )
            );
        }
    }

    /**
     * Computes the loss of one batch and the gradients of all weight variables with respect to it.
     * The results are available through {@link #loss()} and {@link #weightGradients()} after
     * {@link #run()} has completed.
     */
    static class BatchTask
    implements Runnable {
        private final Variable<Scalar> lossFunction;
        private final List<Weights<? extends Tensor<?>>> weightVariables;
        private List<? extends Tensor<?>> weightGradients;
        private final ProgressTracker progressTracker;
        private double loss;

        BatchTask(Variable<Scalar> lossFunction, List<Weights<? extends Tensor<?>>> weightVariables, ProgressTracker progressTracker) {
            this.lossFunction = lossFunction;
            this.weightVariables = weightVariables;
            this.progressTracker = progressTracker;
        }

        @Override
        public void run() {
            // Every run uses its own computation context for the forward and backward pass.
            ComputationContext localCtx = new ComputationContext();
            this.loss = localCtx.forward(this.lossFunction).value();
            localCtx.backward(this.lossFunction);
            this.weightGradients = this.weightVariables
                .stream()
                .map(localCtx::gradient)
                .collect(Collectors.toList());
            this.progressTracker.logProgress();
        }

        /** Returns the loss computed by the last {@link #run()}. */
        public double loss() {
            return this.loss;
        }

        /** Returns the gradients computed by the last {@link #run()}, in the order of the weight variables. */
        List<? extends Tensor<?>> weightGradients() {
            return this.weightGradients;
        }
    }

    /**
     * Outcome of a single epoch: the convergence flag and the loss of every executed iteration.
     */
    @ValueClass
    static interface EpochResult {
        public boolean converged();

        public List<Double> losses();
    }
}

