/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.concurrent.ExecutorService;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.MutateConfig;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSage;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageModelResolver;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.EmbeddingUtils;

/**
 * Factory that creates {@link GraphSage} algorithm instances from a previously
 * trained model looked up in a {@link ModelCatalog}, and that provides the
 * matching progress task and memory estimation.
 *
 * @param <CONFIG> the concrete configuration type, bounded by {@link GraphSageBaseConfig}
 */
public class GraphSageAlgorithmFactory<CONFIG extends GraphSageBaseConfig>
extends GraphAlgorithmFactory<GraphSage, CONFIG> {
    /** Catalog the trained model is resolved from. */
    private final ModelCatalog modelCatalog;

    /**
     * @param modelCatalog catalog used to resolve the model named in the configuration
     */
    public GraphSageAlgorithmFactory(ModelCatalog modelCatalog) {
        this.modelCatalog = modelCatalog;
    }

    /**
     * @return the simple class name of {@link GraphSage}
     */
    public String taskName() {
        return GraphSage.class.getSimpleName();
    }

    /**
     * Resolves the model referenced by the configuration and wraps it in a new
     * {@link GraphSage} instance that runs on {@code Pools.DEFAULT}.
     *
     * <p>When the model's train configuration reports a relationship weight
     * property, the graph's relationship weights are validated first.
     *
     * @param graph           the graph the algorithm will run on
     * @param configuration   supplies the user name, model name and concurrency
     * @param progressTracker handed through to the created algorithm
     * @return the newly created algorithm instance
     */
    public GraphSage build(Graph graph, CONFIG configuration, ProgressTracker progressTracker) {
        ExecutorService pool = Pools.DEFAULT;
        Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> trainedModel =
            GraphSageModelResolver.resolveModel(
                modelCatalog,
                configuration.username(),
                configuration.modelName()
            );

        GraphSageTrainConfig trainConfig = (GraphSageTrainConfig) trainedModel.trainConfig();
        boolean weighted = trainConfig.hasRelationshipWeightProperty();
        if (weighted) {
            int concurrency = (int) configuration.concurrency();
            EmbeddingUtils.validateRelationshipWeightPropertyValue(graph, concurrency, pool);
        }

        // Hand the configuration over under its bound type, as the constructor call expects.
        GraphSageBaseConfig baseConfig = configuration;
        return new GraphSage(graph, trainedModel, baseConfig, pool, progressTracker);
    }

    /**
     * Builds a lazily evaluated memory estimation for the model referenced by
     * the given configuration. The model is resolved immediately; the actual
     * estimation tree is created once graph dimensions are supplied.
     *
     * @param config supplies the user name and model name; whether it is a
     *               {@link MutateConfig} decides where the result features are accounted
     * @return the memory estimation
     */
    public MemoryEstimation memoryEstimation(CONFIG config) {
        Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> trainedModel =
            GraphSageModelResolver.resolveModel(modelCatalog, config.username(), config.modelName());
        boolean mutate = config instanceof MutateConfig;

        return MemoryEstimations.setup(
            "",
            dimensions -> withNodeCount(
                (GraphSageTrainConfig) trainedModel.trainConfig(),
                dimensions.nodeCount(),
                mutate
            )
        );
    }

    /**
     * @param graph  the graph whose node count is used as the task volume
     * @param config not read by this method
     * @return a leaf task named after {@link #taskName()}
     */
    public Task progressTask(Graph graph, CONFIG config) {
        return Tasks.leaf(taskName(), graph.nodeCount());
    }

    /**
     * Assembles the "GraphSage" estimation tree.
     *
     * <p>The {@code resultFeatures} entry is placed under {@code residentMemory}
     * when {@code mutate} is set, and under {@code temporaryMemory} otherwise.
     * {@code temporaryMemory} always contains the algorithm instance, the
     * initial features and the per-thread batch estimation.
     *
     * @param config    train configuration providing the dimensions and batch size
     * @param nodeCount node count passed on to the embeddings estimation
     * @param mutate    whether the estimation is requested for a mutate configuration
     * @return the assembled memory estimation
     */
    private MemoryEstimation withNodeCount(GraphSageTrainConfig config, long nodeCount, boolean mutate) {
        MemoryEstimations.Builder estimation = MemoryEstimations.builder("GraphSage");

        if (mutate) {
            long embeddingBytes = (long) MemoryUsage.sizeOfDoubleArray((long) config.embeddingDimension());
            estimation = estimation
                .startField("residentMemory")
                .add("resultFeatures", HugeObjectArray.memoryEstimation(embeddingBytes))
                .endField();
        }

        long initialFeatureBytes = (long) MemoryUsage.sizeOfDoubleArray((long) config.estimationFeatureDimension());
        MemoryEstimations.Builder temporary = estimation
            .startField("temporaryMemory")
            .field("this.instance", GraphSage.class)
            .add("initialFeatures", HugeObjectArray.memoryEstimation(initialFeatureBytes))
            .perThread(
                "concurrentBatches",
                MemoryEstimations.builder()
                    .add(GraphSageHelper.embeddingsEstimation(config, config.batchSize(), nodeCount, 0, false))
                    .build()
            );

        if (!mutate) {
            long embeddingBytes = (long) MemoryUsage.sizeOfDoubleArray((long) config.embeddingDimension());
            temporary = temporary.add("resultFeatures", HugeObjectArray.memoryEstimation(embeddingBytes));
        }

        return temporary.endField().build();
    }
}

