/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.FeatureFunction;
import org.neo4j.gds.embeddings.graphsage.Layer;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.embeddings.graphsage.algo.MultiLabelFeatureExtractors;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.features.BiasFeature;
import org.neo4j.gds.ml.core.features.FeatureConsumer;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.features.HugeObjectArrayFeatureConsumer;
import org.neo4j.gds.ml.core.functions.NormalizeRows;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Static helpers shared by the GraphSAGE code: building the embeddings
 * computation graph, estimating its memory footprint and initialising the
 * per-node feature arrays.
 */
public final class GraphSageHelper {
    /** Not instantiable: the class only exposes static helpers. */
    private GraphSageHelper() {
    }

    /**
     * Builds the embeddings computation graph for a batch of nodes.
     *
     * <p>Creates one sub graph per layer, applies {@code featureFunction} to the
     * original node ids of the last sub graph, and hands the result to
     * {@link #embeddingsComputationGraph(List, Layer[], Variable)}.
     *
     * @param graph           graph the batch is taken from
     * @param useWeights      forwarded to {@link #subGraphsPerLayer}
     * @param nodeIds         ids of the batch nodes
     * @param features        per-node feature arrays passed to {@code featureFunction}
     * @param layers          layers to run the batch through
     * @param featureFunction produces the initial feature variable
     * @return the row-normalised output of the layers
     */
    public static Variable<Matrix> embeddingsComputationGraph(Graph graph, boolean useWeights, long[] nodeIds, HugeObjectArray<double[]> features, Layer[] layers, FeatureFunction featureFunction) {
        List<SubGraph> subGraphs = subGraphsPerLayer(graph, useWeights, nodeIds, layers);
        SubGraph lastSubGraph = subGraphs.get(subGraphs.size() - 1);
        Variable<Matrix> batchedFeaturesExtractor = featureFunction.apply(graph, lastSubGraph.originalNodeIds(), features);
        return embeddingsComputationGraph(subGraphs, layers, batchedFeaturesExtractor);
    }

    /**
     * Chains the aggregators of all layers and wraps the result in {@link NormalizeRows}.
     *
     * @param subGraphs                one sub graph per layer, read from the last index down to 0
     * @param layers                   layers, applied from index 0 upwards
     * @param batchedFeaturesExtractor input of the first aggregator
     * @return the row-normalised output of the last aggregator
     */
    public static Variable<Matrix> embeddingsComputationGraph(List<SubGraph> subGraphs, Layer[] layers, Variable<Matrix> batchedFeaturesExtractor) {
        Variable<Matrix> previousLayerRepresentations = batchedFeaturesExtractor;
        // layerNr counts down over the sub graphs while the layer index counts up,
        // so layers[0] is paired with the last sub graph.
        for (int layerNr = layers.length - 1; layerNr >= 0; --layerNr) {
            Layer layer = layers[layers.length - layerNr - 1];
            previousLayerRepresentations = layer.aggregator().aggregate(previousLayerRepresentations, subGraphs.get(layerNr));
        }
        return new NormalizeRows(previousLayerRepresentations);
    }

    /**
     * Builds the sub graphs for a batch, using the neighborhood functions of the
     * layers in reversed layer order.
     *
     * @param graph      graph the batch is taken from
     * @param useWeights forwarded to {@code SubGraph.buildSubGraphs}
     * @param nodeIds    ids of the batch nodes
     * @param layers     layers providing the neighborhood functions
     * @return the list returned by {@code SubGraph.buildSubGraphs}
     */
    public static List<SubGraph> subGraphsPerLayer(Graph graph, boolean useWeights, long[] nodeIds, Layer[] layers) {
        // NOTE(review): the decompiler dropped the functional-interface type of this
        // method reference (and with it the list's element type). The original source
        // presumably targets the neighborhood-function type that
        // SubGraph.buildSubGraphs expects; that type is not imported in this file, so
        // the statement is left as decompiled - restore the target type before compiling.
        List neighborhoodFunctions = Arrays.stream(layers).map(layer -> layer::neighborhoodFunction).collect(Collectors.toList());
        Collections.reverse(neighborhoodFunctions);
        return SubGraph.buildSubGraphs(nodeIds, neighborhoodFunctions, graph, useWeights);
    }

    /**
     * Estimates the memory of the embeddings computation graph.
     *
     * <p>The estimation has a {@code subgraphs} field with one entry per layer, an
     * optional {@code multiLabelFeatureFunction} entry, a {@code forward} field with
     * the aggregator estimations and, if {@code withGradientDescent} is set, a
     * {@code backward} field with the same aggregator estimations again.
     *
     * @param config              training configuration (layer configs, feature and embedding dimensions)
     * @param batchSize           number of nodes in a batch
     * @param nodeCount           upper bound applied to every per-layer node count
     * @param labelCount          sizes the label copy of the multi-label entry; unused otherwise
     * @param withGradientDescent whether to add the {@code backward} field
     * @return the assembled memory estimation
     */
    public static MemoryEstimation embeddingsEstimation(GraphSageTrainConfig config, long batchSize, long nodeCount, int labelCount, boolean withGradientDescent) {
        boolean isMultiLabel = config.isMultiLabel();
        List<LayerConfig> layerConfigs = config.layerConfigs(config.estimationFeatureDimension());
        int numberOfLayers = layerConfigs.size();
        MemoryEstimations.Builder computationGraphBuilder = MemoryEstimations.builder("computationGraph").startField("subgraphs");

        // Index 0 is the batch itself; index i + 1 is the node count after sampling for layer config i.
        List<Long> minBatchNodeCounts = new ArrayList<>(numberOfLayers + 1);
        List<Long> maxBatchNodeCounts = new ArrayList<>(numberOfLayers + 1);
        minBatchNodeCounts.add(batchSize);
        maxBatchNodeCounts.add(batchSize);
        for (int i = 0; i < numberOfLayers; ++i) {
            int sampleSize = layerConfigs.get(i).sampleSize();
            long min = minBatchNodeCounts.get(i);
            long max = maxBatchNodeCounts.get(i);
            // Lower bound: no node gains a neighbor. Upper bound: every node gains
            // sampleSize neighbors. Both are capped at nodeCount.
            long minNextNodeCount = Math.min(min, nodeCount);
            long maxNextNodeCount = Math.min(max * (sampleSize + 1), nodeCount);
            minBatchNodeCounts.add(minNextNodeCount);
            maxBatchNodeCounts.add(maxNextNodeCount);

            long minSubgraphBytes = MemoryUsage.sizeOfIntArray(min)
                + MemoryUsage.sizeOfObjectArray(min)
                + min * MemoryUsage.sizeOfIntArray(0L)
                + MemoryUsage.sizeOfLongArray(minNextNodeCount);
            long maxSubgraphBytes = MemoryUsage.sizeOfIntArray(max)
                + MemoryUsage.sizeOfObjectArray(max)
                + max * MemoryUsage.sizeOfIntArray(sampleSize)
                + MemoryUsage.sizeOfLongArray(maxNextNodeCount);
            MemoryRange subgraphRange = MemoryRange.of(minSubgraphBytes, maxSubgraphBytes);
            computationGraphBuilder.add(MemoryEstimations.of("subgraph " + (i + 1), subgraphRange));
        }

        // Reverse so that index i holds the "previous" node count for layer config i
        // and index i + 1 the node count of that layer itself.
        Collections.reverse(minBatchNodeCounts);
        Collections.reverse(maxBatchNodeCounts);

        MemoryEstimations.Builder aggregatorsBuilder = MemoryEstimations.builder();
        for (int i = 0; i < numberOfLayers; ++i) {
            LayerConfig layerConfig = layerConfigs.get(i);
            long minPreviousNodeCount = minBatchNodeCounts.get(i);
            long maxPreviousNodeCount = maxBatchNodeCounts.get(i);
            long minNodeCount = minBatchNodeCounts.get(i + 1);
            long maxNodeCount = maxBatchNodeCounts.get(i + 1);

            if (i == 0) {
                int featureSize = config.estimationFeatureDimension();
                MemoryRange firstLayerMemory = MemoryRange.of(
                    MemoryUsage.sizeOfDoubleArray(minPreviousNodeCount * featureSize),
                    MemoryUsage.sizeOfDoubleArray(maxPreviousNodeCount * featureSize)
                );
                if (isMultiLabel) {
                    // Multi-label configs account for one extra double array of featureSize elements.
                    firstLayerMemory = firstLayerMemory.add(MemoryRange.of(MemoryUsage.sizeOfDoubleArray(featureSize)));
                }
                aggregatorsBuilder.fixed("firstLayer", firstLayerMemory);
            }

            Aggregator.AggregatorType aggregatorType = layerConfig.aggregatorType();
            int embeddingDimension = config.embeddingDimension();
            aggregatorsBuilder.fixed(
                StringFormatting.formatWithLocale("%s %d", aggregatorType.name(), i + 1),
                aggregatorType.memoryEstimation(minNodeCount, maxNodeCount, minPreviousNodeCount, maxPreviousNodeCount, layerConfig.cols(), embeddingDimension)
            );

            // The row normalisation is only applied to the output of the last layer.
            if (i == numberOfLayers - 1) {
                aggregatorsBuilder.fixed("normalizeRows", MemoryRange.of(
                    MemoryUsage.sizeOfDoubleArray(minNodeCount * embeddingDimension),
                    MemoryUsage.sizeOfDoubleArray(maxNodeCount * embeddingDimension)
                ));
            }
        }

        computationGraphBuilder = computationGraphBuilder.endField();
        if (isMultiLabel) {
            long minFeatureFunction = MemoryUsage.sizeOfObjectArray(minBatchNodeCounts.get(0));
            long maxFeatureFunction = MemoryUsage.sizeOfObjectArray(maxBatchNodeCounts.get(0));
            long copyOfLabels = MemoryUsage.sizeOfObjectArray(labelCount);
            computationGraphBuilder.fixed("multiLabelFeatureFunction", MemoryRange.of(minFeatureFunction, maxFeatureFunction).add(MemoryRange.of(copyOfLabels)));
        }

        computationGraphBuilder = computationGraphBuilder.startField("forward").addComponentsOf(aggregatorsBuilder.build());
        if (withGradientDescent) {
            computationGraphBuilder = computationGraphBuilder.endField().startField("backward").addComponentsOf(aggregatorsBuilder.build());
        }
        return computationGraphBuilder.endField().build();
    }

    /**
     * Allocates one feature slot per node of {@code graph} and fills it using the
     * extractors returned by {@link #featureExtractors(Graph, GraphSageTrainConfig)}.
     *
     * @param graph  graph whose nodes get features
     * @param config provides the feature properties
     * @return the array returned by {@code FeatureExtraction.extract}
     */
    public static HugeObjectArray<double[]> initializeSingleLabelFeatures(Graph graph, GraphSageTrainConfig config) {
        HugeObjectArray<double[]> features = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        List<FeatureExtractor> extractors = featureExtractors(graph, config);
        return FeatureExtraction.extract(graph, extractors, features);
    }

    /**
     * Returns the property extractors for the feature properties of {@code config}.
     *
     * @param graph  graph the properties are read from
     * @param config provides the feature properties
     * @return the extractors created by {@code FeatureExtraction.propertyExtractors}
     */
    public static List<FeatureExtractor> featureExtractors(Graph graph, GraphSageTrainConfig config) {
        return FeatureExtraction.propertyExtractors(graph, config.featureProperties());
    }

    /**
     * Collects, per node label, the feature extractors and their feature count.
     *
     * <p>For the first node encountered with a given label, the extractors are the
     * property extractors for that label's filtered feature properties followed by
     * a {@link BiasFeature}.
     *
     * @param graph  graph to scan; every node must have exactly one label
     * @param config provides the feature properties
     * @return extractors and feature counts keyed by node label
     * @throws IllegalArgumentException if a node does not have exactly one label
     */
    public static MultiLabelFeatureExtractors multiLabelFeatureExtractors(Graph graph, GraphSageTrainConfig config) {
        Map<NodeLabel, Set<String>> filteredKeysPerLabel = filteredPropertyKeysPerNodeLabel(graph, config);
        HashMap<NodeLabel, Integer> featureCountPerLabel = new HashMap<>();
        HashMap<NodeLabel, List<FeatureExtractor>> extractorsPerLabel = new HashMap<>();
        graph.forEachNode(nodeId -> {
            NodeLabel nodeLabel = labelOf(graph, nodeId);
            extractorsPerLabel.computeIfAbsent(nodeLabel, label -> {
                Set<String> propertyKeys = filteredKeysPerLabel.get(label);
                // The list must hold FeatureExtractor: it mixes property extractors with the bias feature.
                List<FeatureExtractor> featureExtractors = new ArrayList<>(FeatureExtraction.propertyExtractors(graph, propertyKeys, nodeId));
                featureExtractors.add(new BiasFeature());
                return featureExtractors;
            });
            featureCountPerLabel.computeIfAbsent(nodeLabel, label -> FeatureExtraction.featureCount(extractorsPerLabel.get(label)));
            return true;
        });
        return new MultiLabelFeatureExtractors(featureCountPerLabel, extractorsPerLabel);
    }

    /**
     * Allocates and fills a feature array per node, sized by the feature count of
     * the node's label.
     *
     * @param graph                      graph whose nodes get features; every node must have exactly one label
     * @param multiLabelFeatureExtractors extractors and feature counts keyed by node label
     * @return the filled feature arrays, indexed by node id
     * @throws IllegalArgumentException if a node does not have exactly one label
     */
    public static HugeObjectArray<double[]> initializeMultiLabelFeatures(Graph graph, MultiLabelFeatureExtractors multiLabelFeatureExtractors) {
        HugeObjectArray<double[]> features = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        HugeObjectArrayFeatureConsumer featureConsumer = new HugeObjectArrayFeatureConsumer(features);
        graph.forEachNode(nodeId -> {
            NodeLabel nodeLabel = labelOf(graph, nodeId);
            List<FeatureExtractor> extractors = multiLabelFeatureExtractors.extractorsPerLabel().get(nodeLabel);
            int featureCount = multiLabelFeatureExtractors.featureCountPerLabel().get(nodeLabel);
            // The slot has to exist before the consumer writes the extracted values into it.
            features.set(nodeId, new double[featureCount]);
            FeatureExtraction.extract(nodeId, nodeId, extractors, featureConsumer);
            return true;
        });
        return features;
    }

    /**
     * Maps every node label of the schema to the set of its property keys.
     *
     * @param graphSchema schema to read the node properties from
     * @return property keys per node label
     */
    public static Map<NodeLabel, Set<String>> propertyKeysPerNodeLabel(GraphSchema graphSchema) {
        return graphSchema.nodeSchema().properties().entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().keySet()));
    }

    /**
     * Maps every node label to the configured feature properties that the label
     * actually has according to the graph schema.
     */
    private static Map<NodeLabel, Set<String>> filteredPropertyKeysPerNodeLabel(Graph graph, GraphSageTrainConfig config) {
        return propertyKeysPerNodeLabel(graph.schema()).entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> {
                Set<String> keysOfLabel = e.getValue();
                return config.featureProperties().stream().filter(keysOfLabel::contains).collect(Collectors.toSet());
            }));
    }

    /**
     * Returns the single label of a node.
     *
     * @throws IllegalArgumentException if the node has no label or more than one
     */
    private static NodeLabel labelOf(IdMap idMap, long nodeId) {
        AtomicReference<NodeLabel> labelRef = new AtomicReference<>();
        MutableInt labelCount = new MutableInt(0);
        idMap.forEachNodeLabel(nodeId, nodeLabel -> {
            labelRef.set(nodeLabel);
            // True only for the first label; from the second label on this returns
            // false, which presumably ends the iteration early.
            return labelCount.getAndIncrement() == 0;
        });
        if (labelCount.intValue() != 1) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale("Each node must have exactly one label: nodeId=%d, labels=%s", nodeId, idMap.nodeLabels(nodeId)));
        }
        return labelRef.get();
    }
}

