/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.wcc;

import com.carrotsearch.hppc.LongIntHashMap;
import com.carrotsearch.hppc.cursors.LongIntCursor;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.NodeProperties;
import org.neo4j.gds.api.RelationshipConsumer;
import org.neo4j.gds.api.RelationshipIterator;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct;
import org.neo4j.gds.core.utils.paged.dss.HugeAtomicDisjointSetStruct;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.wcc.WccBaseConfig;

/**
 * Computes weakly connected components by unioning the endpoints of relationships into a
 * {@link DisjointSetStruct}.
 *
 * <p>Two strategies are used, selected in {@link #compute()}:
 * <ul>
 *   <li>Undirected graph without a weight threshold: every node is first unioned with at most
 *       {@code NEIGHBOR_ROUNDS} of its neighbours, the most frequent component among
 *       {@code SAMPLING_SIZE} randomly sampled nodes is estimated, and the remaining
 *       relationships are then processed only for nodes outside that component.</li>
 *   <li>Otherwise: every relationship is processed in node batches of {@code batchSize},
 *       optionally ignoring relationships whose weight is not above the configured threshold.</li>
 * </ul>
 */
public class Wcc
extends Algorithm<DisjointSetStruct> {
    /** Number of leading neighbours each node is unioned with during the sampling phase. */
    private static final int NEIGHBOR_ROUNDS = 2;
    /** Number of random nodes sampled to estimate the largest component. */
    private static final int SAMPLING_SIZE = 1024;
    /** Termination is checked whenever the current node id is a multiple of this value. */
    private static final long RUN_CHECK_NODE_COUNT = 10_000L;

    private final WccBaseConfig config;
    /** Seed component values; only set when the config is incremental, otherwise {@code null}. */
    private final NodeProperties initialComponents;
    private final ExecutorService executor;
    private final long nodeCount;
    private final long batchSize;
    private final int threadSize;
    private Graph graph;

    /**
     * Estimates the memory needed by this algorithm, which is dominated by the disjoint-set
     * structure.
     *
     * @param incremental whether the structure is seeded from an existing property
     * @return the memory estimation
     */
    public static MemoryEstimation memoryEstimation(boolean incremental) {
        return MemoryEstimations.builder(Wcc.class)
            .add("dss", HugeAtomicDisjointSetStruct.memoryEstimation(incremental))
            .build();
    }

    /**
     * Creates a new WCC run.
     *
     * @param graph           the graph to compute components for
     * @param executor        executor on which the union tasks are run
     * @param minBatchSize    lower bound for the node batch size of the directed strategy
     * @param config          algorithm configuration
     * @param progressTracker progress tracker for logging
     * @throws IllegalArgumentException if the number of batches exceeds {@link Integer#MAX_VALUE}
     */
    public Wcc(Graph graph, ExecutorService executor, int minBatchSize, WccBaseConfig config, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
        this.initialComponents = config.isIncremental() ? graph.nodeProperties(config.seedProperty()) : null;
        this.executor = executor;
        this.nodeCount = graph.nodeCount();
        this.batchSize = ParallelUtil.adjustedBatchSize(this.nodeCount, config.concurrency(), (long) minBatchSize, (long) Integer.MAX_VALUE);
        long threadCount = ParallelUtil.threadCount(this.batchSize, this.nodeCount);
        if (threadCount > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Too many nodes (%d) to run union find with the given concurrency (%d) and batchSize (%d)",
                this.nodeCount,
                config.concurrency(),
                this.batchSize
            ));
        }
        this.threadSize = (int) threadCount;
    }

    /**
     * Runs the algorithm.
     *
     * <p>Undirected graphs without a threshold use the sampling strategy; all other inputs
     * union every relationship.
     *
     * @return the populated disjoint-set structure
     */
    @Override
    public DisjointSetStruct compute() {
        progressTracker.beginSubTask();
        DisjointSetStruct dss = config.isIncremental()
            ? new HugeAtomicDisjointSetStruct(nodeCount, initialComponents, config.concurrency())
            : new HugeAtomicDisjointSetStruct(nodeCount, config.concurrency());
        if (graph.isUndirected() && !config.hasThreshold()) {
            computeUndirected(dss);
        } else {
            computeDirected(dss);
        }
        progressTracker.endSubTask();
        return dss;
    }

    /** Drops the reference to the graph. */
    public void release() {
        this.graph = null;
    }

    /** @return the relationship weight threshold from the configuration */
    public double threshold() {
        return config.threshold();
    }

    /** Unions the endpoints of every relationship, one task per node batch. */
    private void computeDirected(DisjointSetStruct dss) {
        List<DirectedUnionTask> tasks = new ArrayList<>(threadSize);
        for (long offset = 0L; offset < nodeCount; offset += batchSize) {
            DirectedUnionTask task = config.hasThreshold()
                ? new DirectedUnionWithThresholdTask(threshold(), dss, offset)
                : new DirectedUnionTask(dss, offset);
            tasks.add(task);
        }
        ParallelUtil.run(tasks, executor);
    }

    /** Sampling strategy: sample neighbours, estimate the largest component, link the rest. */
    private void computeUndirected(DisjointSetStruct components) {
        List<Partition> partitions = PartitionUtils.rangePartition(
            config.concurrency(),
            nodeCount,
            Function.identity(),
            Optional.empty()
        );
        sampleSubgraph(components, partitions);
        long largestComponent = findLargestComponent(components);
        linkRemaining(components, partitions, largestComponent);
    }

    /** Unions every node with at most {@code NEIGHBOR_ROUNDS} neighbours, one task per partition. */
    private void sampleSubgraph(DisjointSetStruct components, List<Partition> partitions) {
        List<UndirectedSamplingTask> tasks = partitions.stream()
            .map(partition -> new UndirectedSamplingTask(graph, partition, components, progressTracker, this))
            .collect(Collectors.toList());
        ParallelUtil.run(tasks, executor);
    }

    /**
     * Estimates the largest component by sampling {@code SAMPLING_SIZE} random nodes.
     *
     * @return the most frequent sampled set id, or {@code -1} if the graph has no nodes
     */
    private long findLargestComponent(DisjointSetStruct components) {
        if (nodeCount == 0L) {
            // SplittableRandom.nextLong(0) throws; an empty graph has no component to skip.
            return -1L;
        }
        SplittableRandom random = new SplittableRandom();
        LongIntHashMap sampleCounts = new LongIntHashMap();
        for (int i = 0; i < SAMPLING_SIZE; ++i) {
            long node = random.nextLong(nodeCount);
            sampleCounts.addTo(components.setIdOf(node), 1);
        }
        int max = -1;
        long mostFrequent = -1L;
        for (LongIntCursor entry : sampleCounts) {
            if (entry.value > max) {
                max = entry.value;
                mostFrequent = entry.key;
            }
        }
        return mostFrequent;
    }

    /** Processes the remaining relationships of nodes not in {@code largestComponent}. */
    private void linkRemaining(DisjointSetStruct components, List<Partition> partitions, long largestComponent) {
        List<UndirectedUnionTask> tasks = partitions.stream()
            .map(partition -> new UndirectedUnionTask(graph, partition, largestComponent, components, progressTracker, this))
            .collect(Collectors.toList());
        ParallelUtil.run(tasks, executor);
    }

    /**
     * Fallback weight for relationships without a property; it is strictly above the threshold,
     * so such relationships are always unioned.
     */
    private static double defaultWeight(double threshold) {
        return threshold + 1.0;
    }

    /**
     * Second phase of the sampling strategy: for each node outside {@code skipComponent} with
     * more than {@code NEIGHBOR_ROUNDS} relationships, unions all relationships after the first
     * {@code NEIGHBOR_ROUNDS} ones.
     *
     * <p>NOTE(review): skipping the first relationships assumes {@code forEachRelationship}
     * visits neighbours in the same order as during {@link UndirectedSamplingTask}.
     */
    static final class UndirectedUnionTask
    implements Runnable,
    RelationshipConsumer {
        private final Graph graph;
        private final long skipComponent;
        private final Partition partition;
        private final DisjointSetStruct components;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;
        /** Relationships seen so far for the current node. */
        private long skip;

        UndirectedUnionTask(Graph graph, Partition partition, long skipComponent, DisjointSetStruct components, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.graph = graph.concurrentCopy();
            this.skipComponent = skipComponent;
            this.partition = partition;
            this.components = components;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        @Override
        public void run() {
            long startNode = partition.startNode();
            long endNode = startNode + partition.nodeCount();
            for (long node = startNode; node < endNode; ++node) {
                // Check before the skip conditions: most nodes are expected to be skipped,
                // and a check placed after them would rarely run.
                if (node % RUN_CHECK_NODE_COUNT == 0L) {
                    terminationFlag.assertRunning();
                }
                if (components.setIdOf(node) == skipComponent) {
                    continue;
                }
                int degree = graph.degree(node);
                if (degree <= NEIGHBOR_ROUNDS) {
                    continue;
                }
                reset();
                graph.forEachRelationship(node, this);
                progressTracker.logProgress(degree - NEIGHBOR_ROUNDS);
            }
        }

        @Override
        public boolean accept(long source, long target) {
            ++skip;
            if (skip > NEIGHBOR_ROUNDS) {
                components.union(source, target);
            }
            return true;
        }

        public void reset() {
            this.skip = 0L;
        }
    }

    /**
     * First phase of the sampling strategy: unions each node of the partition with at most
     * {@code NEIGHBOR_ROUNDS} of its neighbours.
     */
    static final class UndirectedSamplingTask
    implements Runnable,
    RelationshipConsumer {
        private final Graph graph;
        private final Partition partition;
        private final DisjointSetStruct components;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;
        /** Remaining relationships to union for the current node. */
        private long limit;

        UndirectedSamplingTask(Graph graph, Partition partition, DisjointSetStruct components, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.graph = graph.concurrentCopy();
            this.partition = partition;
            this.components = components;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        @Override
        public void run() {
            long startNode = partition.startNode();
            long endNode = startNode + partition.nodeCount();
            for (long node = startNode; node < endNode; ++node) {
                reset();
                graph.forEachRelationship(node, this);
                if (node % RUN_CHECK_NODE_COUNT == 0L) {
                    terminationFlag.assertRunning();
                }
                progressTracker.logProgress(Math.min(NEIGHBOR_ROUNDS, graph.degree(node)));
            }
        }

        @Override
        public boolean accept(long s, long t) {
            components.union(s, t);
            --limit;
            // Returning false stops the iteration once the limit is exhausted.
            return limit != 0L;
        }

        public void reset() {
            this.limit = NEIGHBOR_ROUNDS;
        }
    }

    /**
     * Batch task that only unions relationships whose weight is strictly above the threshold.
     * Relationships without a weight receive {@link Wcc#defaultWeight(double)} and are always
     * unioned.
     */
    private class DirectedUnionWithThresholdTask
    extends DirectedUnionTask
    implements RelationshipWithPropertyConsumer {
        private final double threshold;

        DirectedUnionWithThresholdTask(double threshold, DisjointSetStruct struct, long offset) {
            super(struct, offset);
            this.threshold = threshold;
        }

        @Override
        void compute(long node) {
            rels.forEachRelationship(node, Wcc.defaultWeight(threshold), this);
        }

        @Override
        public boolean accept(long sourceNodeId, long targetNodeId, double property) {
            if (property > threshold) {
                struct.union(sourceNodeId, targetNodeId);
            }
            return true;
        }
    }

    /** Batch task that unions the endpoints of every relationship of nodes in {@code [offset, end)}. */
    private class DirectedUnionTask
    implements Runnable,
    RelationshipConsumer {
        final DisjointSetStruct struct;
        final RelationshipIterator rels;
        private final long offset;
        private final long end;

        DirectedUnionTask(DisjointSetStruct struct, long offset) {
            this.struct = struct;
            this.rels = Wcc.this.graph.concurrentCopy();
            this.offset = offset;
            this.end = Math.min(offset + Wcc.this.batchSize, Wcc.this.nodeCount);
        }

        @Override
        public void run() {
            for (long node = offset; node < end; ++node) {
                compute(node);
                if (node % RUN_CHECK_NODE_COUNT == 0L) {
                    Wcc.this.assertRunning();
                }
                Wcc.this.progressTracker.logProgress(Wcc.this.graph.degree(node));
            }
        }

        /** Processes all relationships of {@code node}; overridden to filter by weight. */
        void compute(long node) {
            rels.forEachRelationship(node, this);
        }

        @Override
        public boolean accept(long sourceNodeId, long targetNodeId) {
            struct.union(sourceNodeId, targetNodeId);
            return true;
        }
    }
}

