/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.hashgnn;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.DegreePartition;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.hashgnn.HashGNN;
import org.neo4j.gds.embeddings.hashgnn.HashGNNCompanion;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.HashTask;

/**
 * One unit of parallel min-hashing work for HashGNN.
 *
 * <p>A task handles a single hash index {@code k} and a single {@link DegreePartition} of nodes.
 * For each node in the partition it chooses at most one embedding bit and sets it in
 * {@code currentEmbeddings}. The bit comes from the node's own previous embedding or from an
 * aggregate of its neighbours' previous embeddings, whichever yields the smaller min-hash.
 *
 * <p>Each task works on its own {@link Graph#concurrentCopy() concurrent copies} of the input
 * graphs and keeps its own scratch buffers and feature counter.
 */
class MinHashTask
implements Runnable {
    private final List<HashTask.Hashes> hashes;
    private final int k;
    private final int embeddingDimension;
    private final DegreePartition partition;
    private final List<Graph> concurrentGraphs;
    private final HugeObjectArray<HugeAtomicBitSet> currentEmbeddings;
    private final HugeObjectArray<HugeAtomicBitSet> previousEmbeddings;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;
    // Number of bits this task newly set in currentEmbeddings (bits that were already set are not counted).
    private long totalFeatureCount = 0L;

    MinHashTask(int k, DegreePartition partition, List<Graph> graphs, int embeddingDimension, HugeObjectArray<HugeAtomicBitSet> currentEmbeddings, HugeObjectArray<HugeAtomicBitSet> previousEmbeddings, List<HashTask.Hashes> hashes, TerminationFlag terminationFlag, ProgressTracker progressTracker) {
        this.k = k;
        this.partition = partition;
        // Each task gets its own concurrent copies so relationship traversal is not shared between tasks.
        this.concurrentGraphs = graphs.stream().map(Graph::concurrentCopy).collect(Collectors.toList());
        this.embeddingDimension = embeddingDimension;
        this.currentEmbeddings = currentEmbeddings;
        this.previousEmbeddings = previousEmbeddings;
        this.hashes = hashes;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
    }

    /**
     * Runs one min-hashing round over all nodes.
     *
     * <p>One task is created for every pair of hash index {@code k} in
     * {@code [0, config.embeddingDensity())} and partition. All tasks are run with
     * {@code config.concurrency()} and their newly-set feature counts are summed.
     *
     * @param degreePartition        node partitions; one task is created per partition and hash index
     * @param graphs                 input graphs, index-aligned with each {@code Hashes#preAggregationHashes()} list
     * @param config                 supplies {@code embeddingDensity()} and {@code concurrency()}
     * @param embeddingDimension     size of the scratch neighbour bit vector
     * @param currentEmbeddings      embeddings being written in this round
     * @param previousEmbeddings     embeddings read in this round
     * @param hashes                 hash functions, indexed by {@code k}
     * @param progressTracker        receives one sub-task with {@code embeddingDensity * nodeCount} steps
     * @param terminationFlag        checked by the runner and at the start of each task
     * @param totalFeatureCountOutput incremented by the total number of newly set bits across all tasks
     */
    static void compute(List<DegreePartition> degreePartition, List<Graph> graphs, HashGNNConfig config, int embeddingDimension, HugeObjectArray<HugeAtomicBitSet> currentEmbeddings, HugeObjectArray<HugeAtomicBitSet> previousEmbeddings, List<HashTask.Hashes> hashes, ProgressTracker progressTracker, TerminationFlag terminationFlag, MutableLong totalFeatureCountOutput) {
        progressTracker.beginSubTask("Perform min-hashing");
        // Each task logs its partition's node count once. Assuming the partitions cover every node
        // of graphs.get(0) exactly once, the total is embeddingDensity * nodeCount.
        progressTracker.setSteps((long) config.embeddingDensity() * graphs.get(0).nodeCount());

        List<MinHashTask> tasks = IntStream.range(0, config.embeddingDensity())
            .mapToObj(k -> degreePartition.stream()
                .map(p -> new MinHashTask(k, p, graphs, embeddingDimension, currentEmbeddings, previousEmbeddings, hashes, terminationFlag, progressTracker)))
            .flatMap(Function.identity())
            .collect(Collectors.toList());

        RunWithConcurrency.builder().concurrency(config.concurrency()).tasks(tasks).terminationFlag(terminationFlag).run();

        // NOTE(review): reading the per-task counters here relies on RunWithConcurrency having
        // completed all tasks before returning - confirm.
        totalFeatureCountOutput.add(tasks.stream().mapToLong(MinHashTask::totalFeatureCount).sum());
        progressTracker.endSubTask("Perform min-hashing");
    }

    /**
     * Computes and sets at most one new embedding bit for every node in this task's partition.
     *
     * <p>Steps for each node:
     * <ol>
     *   <li>Min-hash the node's previous embedding with the self-aggregation hashes.</li>
     *   <li>For every graph {@code i}, min-hash each neighbour's previous embedding with the
     *       {@code i}-th pre-aggregation hashes. Collect each resulting arg-min bit into a
     *       neighbour bit vector.</li>
     *   <li>Min-hash that neighbour vector with the neighbour-aggregation hashes.</li>
     *   <li>Keep the arg-min of the strictly smaller min; ties go to the self result. If it is
     *       not {@code -1}, set that bit in the current embedding.</li>
     * </ol>
     */
    @Override
    public void run() {
        // Scratch buffers, reused for every node of the partition.
        BitSet neighborsVector = new BitSet((long) this.embeddingDimension);
        HashGNN.MinAndArgmin selfMinAndArgMin = new HashGNN.MinAndArgmin();
        HashGNN.MinAndArgmin neighborsMinAndArgMin = new HashGNN.MinAndArgmin();
        HashGNN.MinAndArgmin tempMinAndArgMin = new HashGNN.MinAndArgmin();

        this.terminationFlag.assertRunning();

        HashTask.Hashes hashesForK = this.hashes.get(this.k);
        int[] neighborsAggregationHashes = hashesForK.neighborsAggregationHashes();
        int[] selfAggregationHashes = hashesForK.selfAggregationHashes();
        List<int[]> preAggregationHashes = hashesForK.preAggregationHashes();

        this.partition.consume(nodeId -> {
            HugeAtomicBitSet currentEmbedding = this.currentEmbeddings.get(nodeId);

            // Min-hash of the node's own previous embedding.
            HashGNNCompanion.hashArgMin(this.previousEmbeddings.get(nodeId), selfAggregationHashes, selfMinAndArgMin, tempMinAndArgMin);

            // For each graph, collect the arg-min bit of every neighbour's previous embedding.
            neighborsVector.clear();
            for (int i = 0; i < this.concurrentGraphs.size(); ++i) {
                int[] preAggregationHashesForRel = preAggregationHashes.get(i);
                Graph currentGraph = this.concurrentGraphs.get(i);
                currentGraph.forEachRelationship(nodeId, (src, trg) -> {
                    HugeAtomicBitSet prevTargetEmbedding = this.previousEmbeddings.get(trg);
                    HashGNNCompanion.hashArgMin(prevTargetEmbedding, preAggregationHashesForRel, neighborsMinAndArgMin, tempMinAndArgMin);
                    // Named differently from the enclosing argMin: a lambda body may not redeclare
                    // a local that is in scope from the enclosing block.
                    int neighborArgMin = neighborsMinAndArgMin.argMin;
                    if (neighborArgMin != -1) {
                        neighborsVector.set((long) neighborArgMin);
                    }
                    return true;
                });
            }

            // Aggregate the collected neighbour bits into a single min-hash.
            HashGNNCompanion.hashArgMin(neighborsVector, neighborsAggregationHashes, neighborsMinAndArgMin);

            // The strictly smaller min wins; on a tie the self result is kept.
            int argMin = neighborsMinAndArgMin.min < selfMinAndArgMin.min
                ? neighborsMinAndArgMin.argMin
                : selfMinAndArgMin.argMin;
            // Count only bits that were not already set in the current embedding.
            if (argMin != -1 && !currentEmbedding.getAndSet((long) argMin)) {
                ++this.totalFeatureCount;
            }
        });

        this.progressTracker.logSteps(this.partition.nodeCount());
    }

    /**
     * Returns the number of bits this task newly set in the current embeddings.
     *
     * @return newly set bit count (bits that were already set are not counted)
     */
    public long totalFeatureCount() {
        return this.totalFeatureCount;
    }
}

