/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.concurrent.ExecutorService;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.graphsage.GraphSageEmbeddingsGenerator;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.ImmutableGraphSageResult;

/**
 * Produces GraphSAGE node embeddings for a graph from an already-trained model.
 *
 * <p>The model supplies the layers, the feature function and the training configuration
 * (random seed, single- vs. multi-label feature setup). The caller's {@link GraphSageBaseConfig}
 * supplies the batch size and concurrency used while generating embeddings.
 */
public class GraphSage extends Algorithm<GraphSageResult> {

    /** Model type name for GraphSAGE models. */
    public static final String MODEL_TYPE = "graphSage";

    private final Graph graph;
    private final GraphSageBaseConfig config;
    private final Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model;
    private final ExecutorService executor;

    /**
     * Creates the algorithm instance.
     *
     * @param graph           the graph whose nodes are embedded
     * @param model           trained GraphSAGE model providing layers, feature function and train config
     * @param config          run-time configuration; batch size and concurrency are read from it
     * @param executor        executor handed to the embeddings generator
     * @param progressTracker progress tracker forwarded to the base class and the generator
     */
    public GraphSage(
        Graph graph,
        Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model,
        GraphSageBaseConfig config,
        ExecutorService executor,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.graph = graph;
        this.model = model;
        this.config = config;
        this.executor = executor;
    }

    /**
     * Computes one embedding per node of {@link #graph}.
     *
     * @return a result wrapping the embeddings produced by {@link GraphSageEmbeddingsGenerator}
     */
    public GraphSageResult compute() {
        ModelData modelData = this.model.data();
        GraphSageTrainConfig trainConfig = this.model.trainConfig();

        GraphSageEmbeddingsGenerator generator = new GraphSageEmbeddingsGenerator(
            modelData.layers(),
            this.config.batchSize(),
            this.config.concurrency(),
            modelData.featureFunction(),
            trainConfig.randomSeed(),
            this.executor,
            this.progressTracker
        );

        HugeObjectArray<double[]> nodeFeatures = initializeFeatures(trainConfig);
        return GraphSageResult.of(generator.makeEmbeddings(this.graph, nodeFeatures));
    }

    /** Builds the input feature arrays, choosing the multi-label or single-label path per the train config. */
    private HugeObjectArray<double[]> initializeFeatures(GraphSageTrainConfig trainConfig) {
        if (!trainConfig.isMultiLabel()) {
            return GraphSageHelper.initializeSingleLabelFeatures(this.graph, trainConfig);
        }
        return GraphSageHelper.initializeMultiLabelFeatures(
            this.graph,
            GraphSageHelper.multiLabelFeatureExtractors(this.graph, trainConfig)
        );
    }

    /** No resources are held by this class, so there is nothing to release. */
    public void release() {
    }

    /** Immutable holder for the embeddings computed by {@link GraphSage#compute()}. */
    @ValueClass
    public interface GraphSageResult {

        /** The computed embeddings, one {@code double[]} per node. */
        HugeObjectArray<double[]> embeddings();

        /**
         * Wraps the given embeddings in a result instance.
         *
         * @param embeddings the embeddings to wrap
         * @return a new immutable result
         */
        static GraphSageResult of(HugeObjectArray<double[]> embeddings) {
            return ImmutableGraphSageResult.of(embeddings);
        }
    }
}

