/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.concurrent.ExecutorService;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.tensor.Matrix;

/**
 * Produces GraphSAGE node embeddings by running the forward pass of the given layers
 * over the graph in fixed-size batches of nodes, in parallel on the supplied executor.
 */
public class GraphSageEmbeddingsGenerator {
    private final Layer[] layers;
    private final int batchSize;
    private final int concurrency;
    private final boolean isWeighted;
    private final FeatureFunction featureFunction;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;

    /**
     * @param layers          layers forwarded to {@link GraphSageHelper#embeddingsComputationGraph} for every batch
     * @param batchSize       maximum number of nodes per batch / task
     * @param concurrency     maximum number of batches run concurrently
     * @param isWeighted      forwarded to the computation-graph builder
     *                        (presumably toggles relationship weights — confirm in GraphSageHelper)
     * @param featureFunction forwarded to the computation-graph builder
     * @param executor        executor the batch tasks are submitted to
     * @param progressTracker receives one sub-task for the whole run and per-batch progress
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        int concurrency,
        boolean isWeighted,
        FeatureFunction featureFunction,
        ExecutorService executor,
        ProgressTracker progressTracker
    ) {
        this.layers = layers;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.isWeighted = isWeighted;
        this.featureFunction = featureFunction;
        this.executor = executor;
        this.progressTracker = progressTracker;
    }

    /**
     * Computes one embedding per node of {@code graph}.
     *
     * <p>The node id range {@code [0, graph.nodeCount())} is split into partitions of at most
     * {@code batchSize} nodes; each partition becomes one task and the tasks are run with
     * {@code concurrency} on {@code executor}. The whole run is wrapped in a single progress sub-task.
     *
     * @param graph    graph whose nodes are embedded
     * @param features input node features, passed through to the computation-graph builder
     * @return a new array of size {@code graph.nodeCount()} where index {@code nodeId} holds that node's embedding
     */
    public HugeObjectArray<double[]> makeEmbeddings(Graph graph, HugeObjectArray<double[]> features) {
        HugeObjectArray<double[]> result = HugeObjectArray.newArray(double[].class, graph.nodeCount());

        this.progressTracker.beginSubTask();
        List<Runnable> tasks = PartitionUtils.rangePartitionWithBatchSize(
            graph.nodeCount(),
            this.batchSize,
            partition -> this.createEmbeddings(graph, partition, features, result)
        );
        ParallelUtil.runWithConcurrency(this.concurrency, tasks, this.executor);
        this.progressTracker.endSubTask();

        return result;
    }

    /**
     * Builds the task that embeds a single partition and writes the rows into {@code result}.
     *
     * <p>Row {@code i} of the forward-pass output is stored at node id {@code partition.startNode() + i}.
     * NOTE(review): this assumes {@code partition.stream()} yields exactly the contiguous ids
     * {@code startNode .. startNode + nodeCount - 1} in ascending order — confirm in {@link Partition}.
     */
    private Runnable createEmbeddings(
        Graph graph,
        Partition partition,
        HugeObjectArray<double[]> features,
        HugeObjectArray<double[]> result
    ) {
        return () -> {
            Variable<Matrix> embeddingVariable = GraphSageHelper.embeddingsComputationGraph(
                graph,
                this.isWeighted,
                partition.stream().toArray(),
                features,
                this.layers,
                this.featureFunction
            );
            // A fresh context per task, so no forward-pass state is shared between batches.
            Matrix embeddings = new ComputationContext().forward(embeddingVariable);

            long partitionStartNodeId = partition.startNode();
            long partitionNodeCount = partition.nodeCount();
            // partitionNodeCount is bounded by batchSize (an int), so an int row index is sufficient.
            for (int partitionIdx = 0; partitionIdx < partitionNodeCount; partitionIdx++) {
                result.set(partitionStartNodeId + partitionIdx, embeddings.getRow(partitionIdx));
            }

            this.progressTracker.logProgress(partitionNodeCount);
        };
    }
}

