/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.hashgnn;

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.schema.GraphSchema;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.hashgnn.BinarizeTask;
import org.neo4j.gds.embeddings.hashgnn.DensifyTask;
import org.neo4j.gds.embeddings.hashgnn.GenerateFeaturesTask;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashTask;
import org.neo4j.gds.embeddings.hashgnn.MinHashTask;
import org.neo4j.gds.embeddings.hashgnn.RawFeaturesTask;
import org.neo4j.gds.utils.StringFormatting;

public class HashGNN
extends Algorithm<HashGNNResult> {
    private static final long DEGREE_PARTITIONS_PER_THREAD = 4L;
    private final long randomSeed;
    private final Graph graph;
    private final SplittableRandom rng;
    private final HashGNNConfig config;
    private final MutableLong currentTotalFeatureCount = new MutableLong();

    public HashGNN(Graph graph, HashGNNConfig config, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
        long tempRandomSeed = config.randomSeed().orElse(new SplittableRandom().nextLong());
        this.randomSeed = new SplittableRandom(tempRandomSeed).nextLong();
        this.rng = new SplittableRandom(this.randomSeed);
    }

    public HashGNNResult compute() {
        HugeObjectArray<double[]> outputVectors;
        HugeObjectArray<HugeAtomicBitSet> binaryOutputVectors;
        this.progressTracker.beginSubTask("HashGNN");
        List degreePartition = PartitionUtils.degreePartition((Graph)this.graph, (int)Math.toIntExact(Math.min((long)this.config.concurrency() * 4L, this.graph.nodeCount())), Function.identity(), Optional.of(1));
        List rangePartition = PartitionUtils.rangePartition((int)this.config.concurrency(), (long)this.graph.nodeCount(), Function.identity(), Optional.of(1));
        Graph graphCopy = this.graph.concurrentCopy();
        GraphSchema schema = this.graph.schema();
        List<Graph> graphs = this.config.heterogeneous() ? schema.relationshipSchema().availableTypes().stream().map(rt -> this.graph.relationshipTypeFilteredGraph(Set.of(rt))).collect(Collectors.toList()) : List.of(graphCopy);
        HugeObjectArray<HugeAtomicBitSet> embeddingsB = this.constructInputEmbeddings(rangePartition);
        int embeddingDimension = (int)((HugeAtomicBitSet)embeddingsB.get(0L)).size();
        double avgInputActiveFeatures = this.currentTotalFeatureCount.doubleValue() / (double)this.graph.nodeCount();
        this.progressTracker.logInfo(StringFormatting.formatWithLocale((String)"Density (number of active features) of binary input features is %.4f.", (Object[])new Object[]{avgInputActiveFeatures}));
        HugeObjectArray<HugeAtomicBitSet> embeddingsA = HugeObjectArray.newArray(HugeAtomicBitSet.class, (long)this.graph.nodeCount());
        embeddingsA.setAll(unused -> HugeAtomicBitSet.create((long)embeddingDimension));
        double avgDegree = (double)this.graph.relationshipCount() / (double)this.graph.nodeCount();
        double upperBoundNeighborExpectedBits = embeddingDimension == 0 ? 1.0 : (double)embeddingDimension * (1.0 - Math.pow(1.0 - 1.0 / (double)embeddingDimension, avgDegree));
        this.progressTracker.beginSubTask("Propagate embeddings");
        for (int iteration = 0; iteration < this.config.iterations(); ++iteration) {
            this.terminationFlag.assertRunning();
            HugeObjectArray<HugeAtomicBitSet> currentEmbeddings = iteration % 2 == 0 ? embeddingsA : embeddingsB;
            HugeObjectArray<HugeAtomicBitSet> previousEmbeddings = iteration % 2 == 0 ? embeddingsB : embeddingsA;
            for (long i = 0L; i < currentEmbeddings.size(); ++i) {
                ((HugeAtomicBitSet)currentEmbeddings.get(i)).clear();
            }
            double scaledNeighborInfluence = this.graph.relationshipCount() == 0L ? 1.0 : this.currentTotalFeatureCount.doubleValue() / (double)this.graph.nodeCount() * this.config.neighborInfluence() / upperBoundNeighborExpectedBits;
            this.currentTotalFeatureCount.setValue(0L);
            List<HashTask.Hashes> hashes = HashTask.compute(embeddingDimension, scaledNeighborInfluence, graphs.size(), this.config, this.randomSeed + (long)(this.config.embeddingDensity() * iteration), this.terminationFlag, this.progressTracker);
            MinHashTask.compute(degreePartition, graphs, this.config, embeddingDimension, currentEmbeddings, previousEmbeddings, hashes, this.progressTracker, this.terminationFlag, this.currentTotalFeatureCount);
            double avgActiveFeatures = this.currentTotalFeatureCount.doubleValue() / (double)this.graph.nodeCount();
            this.progressTracker.logInfo(StringFormatting.formatWithLocale((String)"After iteration %d average node embedding density (number of active features) is %.4f.", (Object[])new Object[]{iteration, avgActiveFeatures}));
        }
        this.progressTracker.endSubTask("Propagate embeddings");
        HugeObjectArray<HugeAtomicBitSet> hugeObjectArray = binaryOutputVectors = (this.config.iterations() - 1) % 2 == 0 ? embeddingsA : embeddingsB;
        if (this.config.outputDimension().isPresent()) {
            outputVectors = DensifyTask.compute(this.graph, rangePartition, this.config, this.rng, binaryOutputVectors, this.progressTracker, this.terminationFlag);
        } else {
            outputVectors = HugeObjectArray.newArray(double[].class, (long)this.graph.nodeCount());
            outputVectors.setAll(nodeId -> this.bitSetToArray((HugeAtomicBitSet)binaryOutputVectors.get(nodeId), embeddingDimension));
        }
        this.progressTracker.endSubTask("HashGNN");
        return new HashGNNResult(outputVectors);
    }

    private double[] bitSetToArray(HugeAtomicBitSet bitSet, int dimension) {
        double[] array = new double[dimension];
        bitSet.forEachSetBit(bit -> {
            array[(int)bit] = 1.0;
        });
        return array;
    }

    public void release() {
    }

    private HugeObjectArray<HugeAtomicBitSet> constructInputEmbeddings(List<Partition> partition) {
        if (!this.config.featureProperties().isEmpty()) {
            if (this.config.binarizeFeatures().isPresent()) {
                return BinarizeTask.compute(this.graph, partition, this.config, this.rng, this.progressTracker, this.terminationFlag, this.currentTotalFeatureCount);
            }
            return RawFeaturesTask.compute(this.config, this.progressTracker, this.graph, partition, this.terminationFlag, this.currentTotalFeatureCount);
        }
        return GenerateFeaturesTask.compute(this.graph, partition, this.config, this.randomSeed, this.progressTracker, this.terminationFlag, this.currentTotalFeatureCount);
    }

    public static class HashGNNResult {
        private final HugeObjectArray<double[]> embeddings;

        public HashGNNResult(HugeObjectArray<double[]> embeddings) {
            this.embeddings = embeddings;
        }

        public HugeObjectArray<double[]> embeddings() {
            return this.embeddings;
        }
    }

    static final class MinAndArgmin {
        public int min = -1;
        public int argMin = Integer.MAX_VALUE;

        MinAndArgmin() {
        }
    }
}

