/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.List;
import java.util.concurrent.ExecutorService;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSage;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrain;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.MultiLabelGraphSageTrain;
import org.neo4j.gds.embeddings.graphsage.algo.SingleLabelGraphSageTrain;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.EmbeddingUtils;

public final class GraphSageTrainAlgorithmFactory
extends GraphAlgorithmFactory<GraphSageTrain, GraphSageTrainConfig> {
    public String taskName() {
        return GraphSageTrain.class.getSimpleName();
    }

    public GraphSageTrain build(Graph graph, GraphSageTrainConfig configuration, ProgressTracker progressTracker) {
        ExecutorService executorService = Pools.DEFAULT;
        if (configuration.hasRelationshipWeightProperty()) {
            EmbeddingUtils.validateRelationshipWeightPropertyValue((Graph)graph, (int)configuration.concurrency(), (ExecutorService)executorService);
        }
        return configuration.isMultiLabel() ? new MultiLabelGraphSageTrain(graph, configuration, executorService, progressTracker) : new SingleLabelGraphSageTrain(graph, configuration, executorService, progressTracker);
    }

    public MemoryEstimation memoryEstimation(GraphSageTrainConfig configuration) {
        return MemoryEstimations.setup((String)"", graphDimensions -> this.estimate(configuration, graphDimensions.nodeCount(), graphDimensions.estimationNodeLabelCount()));
    }

    public Task progressTask(Graph graph, GraphSageTrainConfig config) {
        return Tasks.task((String)this.taskName(), GraphSageModelTrainer.progressTasks(config));
    }

    private MemoryEstimation estimate(GraphSageTrainConfig config, long nodeCount, int labelCount) {
        List<LayerConfig> layerConfigs = config.layerConfigs(config.estimationFeatureDimension());
        int numberOfLayers = layerConfigs.size();
        MemoryEstimations.Builder layerBuilder = MemoryEstimations.builder((String)"GraphSageTrain").startField("residentMemory").startField("weights");
        long initialAdamOptimizer = 0L;
        long updateAdamOptimizer = 0L;
        for (int i = 0; i < numberOfLayers; ++i) {
            LayerConfig layerConfig = layerConfigs.get(i);
            int weightDimensions = layerConfig.rows() * layerConfig.cols();
            long weightsMemory = MemoryUsage.sizeOfDoubleArray((long)weightDimensions);
            if (layerConfig.aggregatorType() == Aggregator.AggregatorType.POOL) {
                weightsMemory += MemoryUsage.sizeOfDoubleArray((long)((long)layerConfig.rows() * (long)layerConfig.rows()));
                weightsMemory += MemoryUsage.sizeOfDoubleArray((long)((long)layerConfig.rows() * (long)layerConfig.rows()));
                weightsMemory += MemoryUsage.sizeOfDoubleArray((long)layerConfig.rows());
            }
            layerBuilder.fixed("layer " + (i + 1), weightsMemory);
            initialAdamOptimizer += 2L * MemoryUsage.sizeOfDoubleArray((long)weightDimensions);
            updateAdamOptimizer += 5L * (long)weightDimensions;
        }
        boolean isMultiLabel = config.isMultiLabel();
        MemoryRange perNodeFeaturesMemory = MemoryRange.of((long)MemoryUsage.sizeOfDoubleArray((long)(isMultiLabel ? 1L : (long)config.estimationFeatureDimension())), (long)MemoryUsage.sizeOfDoubleArray((long)config.estimationFeatureDimension()));
        MemoryEstimation initialFeaturesMemory = HugeObjectArray.memoryEstimation((MemoryEstimation)MemoryEstimations.of((String)"", (MemoryRange)perNodeFeaturesMemory));
        MemoryEstimations.Builder estimationsBuilder = layerBuilder.endField().endField().startField("temporaryMemory").field("this.instance", GraphSage.class);
        if (isMultiLabel) {
            int minNumProperties = 1;
            int maxNumProperties = config.featureProperties().size();
            long minWeightsMemory = MemoryUsage.sizeOfDoubleArray((long)(config.estimationFeatureDimension() * minNumProperties));
            long maxWeightsMemory = MemoryUsage.sizeOfDoubleArray((long)((long)config.estimationFeatureDimension() * (long)(++maxNumProperties)));
            MemoryRange weightByLabelMemory = MemoryRange.of((long)minWeightsMemory, (long)maxWeightsMemory).times((long)labelCount);
            estimationsBuilder.fixed("weightsByLabel", weightByLabelMemory);
        }
        return estimationsBuilder.add("initialFeatures", initialFeaturesMemory).startField("trainOnEpoch").fixed("initialAdamOptimizer", initialAdamOptimizer).perThread("concurrentBatches", MemoryEstimations.builder().startField("trainOnBatch").add(GraphSageHelper.embeddingsEstimation(config, 3L * (long)config.batchSize(), nodeCount, labelCount, true)).fixed("updateAdamOptimizer", updateAdamOptimizer).endField().build()).endField().endField().build();
    }
}

