/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import com.carrotsearch.hppc.LongHashSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.ImmutableRelationshipCursor;
import org.neo4j.gds.api.RelationshipCursor;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.GraphSageLoss;
import org.neo4j.gds.embeddings.graphsage.ImmutableGraphSageTrainMetrics;
import org.neo4j.gds.embeddings.graphsage.ImmutableModelTrainResult;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.gds.embeddings.graphsage.LayerFactory;
import org.neo4j.gds.embeddings.graphsage.SingleLabelFeatureFunction;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.RelationshipWeights;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.optimizer.AdamOptimizer;
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.TensorFunctions;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains the layers of a GraphSAGE model.
 *
 * <p>Nodes are range-partitioned into batches of {@code batchSize}. Each batch becomes a {@link BatchTask}
 * whose loss is computed over the batch nodes, one random-walk-sampled node per batch node, and
 * degree-weighted negative samples. Every training iteration evaluates all batch tasks (with the configured
 * concurrency), averages their weight gradients and applies a single Adam update to the shared weights.
 */
public class GraphSageModelTrainer {
    private final long randomSeed;
    private final boolean useWeights;
    private final double learningRate;
    private final double tolerance;
    private final int negativeSampleWeight;
    private final int concurrency;
    private final int epochs;
    private final int maxIterations;
    private final int maxSearchDepth;
    // Assigned in the constructor: it needs the config, which is not in scope in a field initializer.
    private final Function<Graph, List<LayerConfig>> layerConfigsFunction;
    private final FeatureFunction featureFunction;
    private final Collection<Weights<Matrix>> labelProjectionWeights;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;
    private final int batchSize;

    /**
     * Creates a trainer that uses a {@link SingleLabelFeatureFunction} and no label projection weights.
     */
    public GraphSageModelTrainer(GraphSageTrainConfig config, ExecutorService executor, ProgressTracker progressTracker) {
        this(config, executor, progressTracker, new SingleLabelFeatureFunction(), Collections.emptyList());
    }

    /**
     * Creates a trainer.
     *
     * @param config                 training configuration (batch size, learning rate, epochs, seed, ...)
     * @param executor               executor the batch tasks are run on
     * @param progressTracker        receives sub-task begin/end events and loss log messages
     * @param featureFunction        passed through to the embedding computation graph
     * @param labelProjectionWeights additional weights that are optimised together with the layer weights
     */
    public GraphSageModelTrainer(
        GraphSageTrainConfig config,
        ExecutorService executor,
        ProgressTracker progressTracker,
        FeatureFunction featureFunction,
        Collection<Weights<Matrix>> labelProjectionWeights
    ) {
        this.batchSize = config.batchSize();
        this.learningRate = config.learningRate();
        this.tolerance = config.tolerance();
        this.negativeSampleWeight = config.negativeSampleWeight();
        this.concurrency = config.concurrency();
        this.epochs = config.epochs();
        this.maxIterations = config.maxIterations();
        this.maxSearchDepth = config.searchDepth();
        this.layerConfigsFunction = graph -> config.layerConfigs(firstLayerColumns(config, graph));
        this.featureFunction = featureFunction;
        this.labelProjectionWeights = labelProjectionWeights;
        this.executor = executor;
        this.progressTracker = progressTracker;
        this.useWeights = config.hasRelationshipWeightProperty();
        // Without an explicit seed, training is not reproducible across runs.
        this.randomSeed = config.randomSeed().orElseGet(() -> ThreadLocalRandom.current().nextLong());
    }

    /**
     * Creates fresh layers for {@code graph} and trains them for up to {@code epochs} epochs.
     *
     * <p>Training stops early once the relative change of the epoch loss drops below {@code tolerance}.
     *
     * @param graph    the graph to train on
     * @param features node features, passed through to the embedding computation graph
     * @return the trained layers, the loss of every epoch that ran, and whether training converged
     */
    public ModelTrainResult train(Graph graph, HugeObjectArray<double[]> features) {
        Layer[] layers = layerConfigsFunction.apply(graph)
            .stream()
            .map(LayerFactory::createLayer)
            .toArray(Layer[]::new);

        // Label projection weights are optimised together with all layer weights.
        List<Weights<? extends Tensor<?>>> weights = new ArrayList<>(labelProjectionWeights);
        for (Layer layer : layers) {
            weights.addAll(layer.weights());
        }

        List<BatchTask> batchTasks = PartitionUtils.rangePartitionWithBatchSize(
            graph.nodeCount(),
            batchSize,
            batch -> new BatchTask(lossFunction(batch, graph, features, layers), weights, tolerance)
        );

        double previousLoss = Double.MAX_VALUE;
        boolean converged = false;
        List<Double> epochLosses = new ArrayList<>();
        for (int epoch = 1; epoch <= epochs; epoch++) {
            progressTracker.beginSubTask("train epoch");
            double newLoss = trainEpoch(batchTasks, weights);
            epochLosses.add(newLoss);
            progressTracker.endSubTask("train epoch");

            // Converged when the relative loss change between consecutive epochs is below tolerance.
            if (Math.abs((newLoss - previousLoss) / previousLoss) < tolerance) {
                converged = true;
                break;
            }
            previousLoss = newLoss;
        }

        return ModelTrainResult.of(epochLosses, converged, layers);
    }

    /**
     * Runs up to {@code maxIterations} optimisation steps over all batch tasks with a newly created Adam
     * optimizer.
     *
     * <p>The loop stops early, without applying a further update, once every batch task reports convergence.
     *
     * @return the mean batch-task loss of the last evaluated iteration, or {@code NaN} if no iteration ran
     */
    private double trainEpoch(List<BatchTask> batchTasks, List<Weights<? extends Tensor<?>>> weights) {
        AdamOptimizer updater = new AdamOptimizer(weights, learningRate);
        double totalLoss = Double.NaN;
        for (int iteration = 1; iteration <= maxIterations; iteration++) {
            progressTracker.beginSubTask("iteration");
            ParallelUtil.runWithConcurrency(concurrency, batchTasks, executor);

            totalLoss = batchTasks.stream().mapToDouble(BatchTask::loss).average().orElseThrow();
            boolean converged = batchTasks.stream().allMatch(task -> task.converged);
            if (converged) {
                // Use the same named end call as the regular path below.
                progressTracker.endSubTask("iteration");
                break;
            }

            var batchedGradients = batchTasks.stream()
                .map(BatchTask::weightGradients)
                .collect(Collectors.toList());
            var meanGradients = TensorFunctions.averageTensors(batchedGradients);
            updater.update(meanGradients);

            progressTracker.logMessage(StringFormatting.formatWithLocale("LOSS: %.10f", totalLoss));
            progressTracker.endSubTask("iteration");
        }
        return totalLoss;
    }

    /**
     * Builds the loss variable for one batch.
     *
     * <p>The batch is extended with sampled neighbour and negative nodes. The embedding computation graph
     * over that extended batch is fed into a {@link GraphSageLoss}.
     */
    private Variable<Scalar> lossFunction(Partition batch, Graph graph, HugeObjectArray<double[]> features, Layer[] layers) {
        Graph localGraph = graph.concurrentCopy();
        long[] totalBatch = addSamplesPerBatchNode(batch, localGraph);
        Variable<Matrix> embeddingVariable = GraphSageHelper.embeddingsComputationGraph(
            localGraph,
            useWeights,
            totalBatch,
            features,
            layers,
            featureFunction
        );
        RelationshipWeights relationshipWeights = useWeights
            ? localGraph::relationshipProperty
            : RelationshipWeights.UNWEIGHTED;
        return new GraphSageLoss(relationshipWeights, embeddingVariable, totalBatch, negativeSampleWeight);
    }

    /**
     * Returns the node ids of one training batch.
     *
     * <p>The result is the batch nodes, followed by one sampled neighbour per batch node, followed by the
     * negative samples.
     */
    private long[] addSamplesPerBatchNode(Partition batch, Graph localGraph) {
        // Each batch gets its own seed, so batches draw different samples while staying reproducible.
        long batchLocalRandomSeed = getBatchIndex(batch, batchSize) + randomSeed;

        long[] neighbours = neighborBatch(localGraph, batch, batchLocalRandomSeed).toArray();
        LongHashSet neighborsSet = new LongHashSet(neighbours.length);
        neighborsSet.addAll(neighbours);

        LongStream negatives = negativeBatch(
            localGraph,
            Math.toIntExact(batch.nodeCount()),
            neighborsSet,
            batchLocalRandomSeed
        );
        return LongStream.concat(batch.stream(), LongStream.concat(Arrays.stream(neighbours), negatives)).toArray();
    }

    /**
     * Samples one node per batch node via a random walk of between 1 and {@code maxSearchDepth} steps.
     *
     * <p>The walk starts at the batch node, and each step samples from the node reached so far. The walk
     * ends early when no neighbour can be sampled; the last reached node is then used.
     */
    LongStream neighborBatch(Graph graph, Partition batch, long batchLocalSeed) {
        LongStream.Builder neighborBatchBuilder = LongStream.builder();
        Random localRandom = new Random(batchLocalSeed);
        batch.consume(nodeId -> {
            long currentNode = nodeId;
            for (int searchDepth = localRandom.nextInt(maxSearchDepth) + 1; searchDepth > 0; searchDepth--) {
                NeighborhoodSampler neighborhoodSampler = new NeighborhoodSampler(currentNode + searchDepth);
                // Sample from the walk's current position, not from the start node.
                OptionalLong maybeSample = neighborhoodSampler.sampleOne(graph, currentNode);
                if (!maybeSample.isPresent()) {
                    break;
                }
                currentNode = maybeSample.getAsLong();
            }
            neighborBatchBuilder.add(currentNode);
        });
        return neighborBatchBuilder.build();
    }

    /**
     * Draws negative samples from all nodes of {@code graph}.
     *
     * <p>Each node is weighted by {@code degree^0.75}. Nodes contained in {@code neighbours} are rejected.
     * {@code batchSize} is passed to the sampler as the requested sample count.
     */
    LongStream negativeBatch(Graph graph, int batchSize, LongHashSet neighbours, long batchLocalRandomSeed) {
        long nodeCount = graph.nodeCount();
        WeightedUniformSampler sampler = new WeightedUniformSampler(batchLocalRandomSeed);
        // NOTE(review): this visits every node of the graph once per batch.
        Stream<RelationshipCursor> degreeWeightedNodes = LongStream.range(0L, nodeCount)
            .mapToObj(nodeId -> ImmutableRelationshipCursor.of(0L, nodeId, Math.pow(graph.degree(nodeId), 0.75)));
        return sampler.sample(degreeWeightedNodes, nodeCount, batchSize, sample -> !neighbours.contains(sample));
    }

    /**
     * Returns the zero-based index of a batch produced by range partitioning with {@code batchSize}.
     *
     * <p>The divisor must be the batch size. Dividing by the node count always yields 0, which would give
     * every batch the same seed.
     */
    private static int getBatchIndex(Partition partition, int batchSize) {
        return Math.toIntExact(Math.floorDiv(partition.startNode(), (long) batchSize));
    }

    /**
     * Returns the input width of the first layer.
     *
     * <p>This is the configured projected feature dimension if present; otherwise it is the feature count of
     * the graph's feature extractors.
     */
    private static int firstLayerColumns(GraphSageTrainConfig config, Graph graph) {
        return config.projectedFeatureDimension().orElseGet(() -> {
            List<FeatureExtractor> featureExtractors = GraphSageHelper.featureExtractors(graph, config);
            return FeatureExtraction.featureCount(featureExtractors);
        });
    }

    /** Result of {@link #train}: the trained layers plus training metrics. */
    @ValueClass
    public interface ModelTrainResult {
        GraphSageTrainMetrics metrics();

        Layer[] layers();

        static ModelTrainResult of(List<Double> epochLosses, boolean converged, Layer[] layers) {
            return ImmutableModelTrainResult.builder()
                .layers(layers)
                .metrics(ImmutableGraphSageTrainMetrics.of(epochLosses, converged))
                .build();
        }
    }

    /** Per-epoch losses and the convergence flag of a training run. */
    @ValueClass
    public interface GraphSageTrainMetrics extends ToMapConvertible {
        static GraphSageTrainMetrics empty() {
            return ImmutableGraphSageTrainMetrics.of(List.of(), false);
        }

        List<Double> epochLosses();

        boolean didConverge();

        /** Number of epochs that ran; one loss is recorded per epoch. */
        @Value.Derived
        default int ranEpochs() {
            return epochLosses().size();
        }

        @Value.Auxiliary
        @Value.Derived
        default Map<String, Object> toMap() {
            return Map.of(
                "metrics",
                Map.of("epochLosses", epochLosses(), "didConverge", didConverge(), "ranEpochs", ranEpochs())
            );
        }
    }

    /**
     * Evaluates the loss of one batch and its gradients with respect to the shared weights.
     *
     * <p>Once the absolute loss change between two consecutive runs drops below {@code tolerance}, the task
     * is marked converged. Later runs then return immediately and keep their last loss and gradients.
     */
    static class BatchTask implements Runnable {
        private final Variable<Scalar> lossFunction;
        private final List<Weights<? extends Tensor<?>>> weightVariables;
        private List<? extends Tensor<?>> weightGradients;
        private final double tolerance;
        private boolean converged;
        private double prevLoss;

        BatchTask(Variable<Scalar> lossFunction, List<Weights<? extends Tensor<?>>> weightVariables, double tolerance) {
            this.lossFunction = lossFunction;
            this.weightVariables = weightVariables;
            this.tolerance = tolerance;
        }

        @Override
        public void run() {
            if (converged) {
                return;
            }
            ComputationContext localCtx = new ComputationContext();
            double loss = localCtx.forward(lossFunction).value();
            // prevLoss starts at 0.0, so the first run only converges if |loss| < tolerance.
            converged = Math.abs(prevLoss - loss) < tolerance;
            prevLoss = loss;
            localCtx.backward(lossFunction);
            weightGradients = weightVariables.stream().map(localCtx::gradient).collect(Collectors.toList());
        }

        /** Loss of the most recent evaluation (0.0 before the first run). */
        public double loss() {
            return prevLoss;
        }

        /** Gradients of the most recent evaluation, in the order of the weight variables. */
        List<? extends Tensor<?>> weightGradients() {
            return weightGradients;
        }
    }
}

