/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;

/**
 * Computes GraphSAGE node embeddings by running the forward computation of the given
 * {@link Layer}s over the graph in node batches, executed as concurrent tasks.
 */
public class GraphSageEmbeddingsGenerator {
    private final Layer[] layers;
    private final int batchSize;
    private final int concurrency;
    private final FeatureFunction featureFunction;
    private final long randomSeed;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;

    /**
     * @param layers          layers whose forward computation produces the embeddings
     * @param batchSize       batch size used to range-partition the node id space into tasks
     * @param concurrency     concurrency passed to {@code RunWithConcurrency}
     * @param featureFunction builds the feature input variable for each batch
     * @param randomSeed      seed passed to {@code GraphSageHelper.subGraphsPerLayer};
     *                        when empty, a random seed is drawn once here
     * @param executor        executor on which the batch tasks are run
     * @param progressTracker wraps the whole run in one sub task and receives per-batch progress
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        int concurrency,
        FeatureFunction featureFunction,
        Optional<Long> randomSeed,
        ExecutorService executor,
        ProgressTracker progressTracker
    ) {
        this.layers = layers;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.featureFunction = featureFunction;
        // Fix the seed at construction time so every batch task uses the same value.
        this.randomSeed = randomSeed.orElseGet(() -> ThreadLocalRandom.current().nextLong());
        this.executor = executor;
        this.progressTracker = progressTracker;
    }

    /**
     * Computes an embedding for every node of {@code graph}.
     *
     * <p>The node id range is split into partitions of {@code batchSize} nodes. Each partition
     * becomes one task that writes the rows for its own node ids into the returned array.
     *
     * @param graph    graph whose nodes are embedded
     * @param features per-node input features, handed to the feature function
     * @return an array indexed by node id holding each node's embedding
     */
    public HugeObjectArray<double[]> makeEmbeddings(Graph graph, HugeObjectArray<double[]> features) {
        HugeObjectArray<double[]> result = HugeObjectArray.newArray(double[].class, graph.nodeCount());

        progressTracker.beginSubTask();
        // Every task receives its own graph.concurrentCopy() rather than sharing one instance.
        List<Runnable> tasks = PartitionUtils.rangePartitionWithBatchSize(
            graph.nodeCount(),
            batchSize,
            partition -> createEmbeddings(graph.concurrentCopy(), partition, features, result)
        );
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(tasks)
            .executor(executor)
            .run();
        progressTracker.endSubTask();

        return result;
    }

    /**
     * Creates the task that embeds the nodes of one partition and stores each resulting row
     * in {@code result} at the corresponding node id.
     */
    private Runnable createEmbeddings(
        Graph graph,
        Partition partition,
        HugeObjectArray<double[]> features,
        HugeObjectArray<double[]> result
    ) {
        return () -> {
            List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(
                graph,
                partition.stream().toArray(),
                layers,
                randomSeed
            );
            // Input features are extracted for the original node ids of the last sub graph.
            Variable<Matrix> batchedFeaturesExtractor = featureFunction.apply(
                graph,
                subGraphs.get(subGraphs.size() - 1).originalNodeIds(),
                features
            );
            Variable<Matrix> embeddingVariable = GraphSageHelper.embeddingsComputationGraph(
                subGraphs,
                layers,
                batchedFeaturesExtractor
            );
            Matrix embeddings = (Matrix) new ComputationContext().forward(embeddingVariable);

            // Row i of the embedding matrix belongs to node (partition start + i).
            long partitionStartNodeId = partition.startNode();
            long partitionNodeCount = partition.nodeCount();
            for (int row = 0; row < partitionNodeCount; row++) {
                result.set(partitionStartNodeId + row, embeddings.getRow(row));
            }

            progressTracker.logProgress(partitionNodeCount);
        };
    }
}

