/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.wcc;

import com.carrotsearch.hppc.LongIntHashMap;
import com.carrotsearch.hppc.cursors.LongIntCursor;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.function.LongConsumer;
import java.util.stream.Collectors;
import org.immutables.builder.Builder;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.RelationshipConsumer;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.wcc.Wcc;

final class SampledStrategy {
    // Same value as the literal 2 used by SamplingTask (per-node relationship limit) and by
    // LinkTask (number of relationships skipped per node).
    // NOTE(review): presumably the decompiler inlined this constant at those sites - confirm against the original source.
    private static final int NEIGHBOR_ROUNDS = 2;
    // Number of random node draws; equals the iteration count of the sampling loop in findLargestComponent.
    private static final int SAMPLING_SIZE = 1024;
    // Graph handed to every sampling and link task.
    private final Graph graph;
    // Disjoint-set structure that compute() passes to all three phases.
    private final DisjointSetStruct disjointSetStruct;
    // Passed to PartitionUtils.rangePartition when the partitions are created.
    private final int concurrency;
    // When present, the *WithThresholdTask variants are used; they union a relationship only if
    // its property is strictly greater than this value.
    private final Optional<Double> threshold;
    // Polled by the tasks (assertRunning) while they iterate over their nodes.
    private final TerminationFlag terminationFlag;
    // Receives logProgress calls from the tasks.
    private final ProgressTracker progressTracker;
    // Executor handed to ParallelUtil.run for both the sampling and the link tasks.
    private final ExecutorService executorService;

    /**
     * Stores the collaborators needed by {@link #compute()}. No argument is validated or copied.
     *
     * @param graph             graph handed to the sampling and link tasks
     * @param disjointSetStruct disjoint-set structure the tasks union node pairs into
     * @param concurrency       value passed to the range partitioning in {@link #compute()}
     * @param threshold         optional property threshold; when present the threshold-aware tasks are used
     * @param terminationFlag   flag the tasks poll via {@code assertRunning()}
     * @param progressTracker   tracker the tasks report progress to
     * @param executorService   executor the tasks are submitted to
     */
    @Builder.Constructor
    SampledStrategy(
        Graph graph,
        DisjointSetStruct disjointSetStruct,
        int concurrency,
        Optional<Double> threshold,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker,
        ExecutorService executorService
    ) {
        // Execution infrastructure.
        this.executorService = executorService;
        this.concurrency = concurrency;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
        // Algorithm input and state.
        this.graph = graph;
        this.threshold = threshold;
        this.disjointSetStruct = disjointSetStruct;
    }

    /**
     * Runs the strategy in three steps on the shared disjoint-set structure:
     * <ol>
     *   <li>{@code sampleSubgraph} - every partition unions a limited number of relationships per node;</li>
     *   <li>{@code findLargestComponent} - estimates the most frequent component by random sampling;</li>
     *   <li>{@code linkRemaining} - processes the nodes that are not in that component.</li>
     * </ol>
     * The same partition list, obtained from {@code PartitionUtils.rangePartition}, is used for
     * both parallel phases.
     */
    void compute() {
        // Typed list instead of a raw List: Function.identity() makes rangePartition return the partitions themselves.
        List<Partition> partitions = PartitionUtils.rangePartition(
            this.concurrency,
            this.graph.nodeCount(),
            Function.identity(),
            Optional.empty()
        );
        this.sampleSubgraph(this.disjointSetStruct, partitions);
        long largestComponent = this.findLargestComponent(this.disjointSetStruct);
        this.linkRemaining(this.disjointSetStruct, partitions, largestComponent);
    }

    /**
     * Creates one sampling task per partition and runs them all on the executor service.
     * <p>
     * If a threshold is configured a {@link SamplingWithThresholdTask} is created, otherwise a
     * plain {@link SamplingTask}. Both variants union into {@code components}.
     *
     * @param components disjoint-set structure the sampling tasks union node pairs into
     * @param partitions node partitions; one task is created for each
     */
    private void sampleSubgraph(DisjointSetStruct components, List<Partition> partitions) {
        // Both branches use the `components` argument, so the threshold variant no longer bypasses
        // the parameter by reaching for the field directly.
        List<SamplingTask> tasks = partitions.stream()
            .map(partition -> this.threshold.isPresent()
                ? new SamplingWithThresholdTask(this.graph, this.threshold.get(), partition, components, this.progressTracker, this.terminationFlag)
                : new SamplingTask(this.graph, partition, components, this.progressTracker, this.terminationFlag))
            .collect(Collectors.toList());
        ParallelUtil.run(tasks, this.executorService);
    }

    /**
     * Estimates which component is the largest by drawing {@code SAMPLING_SIZE} random node ids
     * (with replacement) and counting how often each set id is hit.
     *
     * @param components disjoint-set structure queried via {@code setIdOf}
     * @return the set id that was sampled most often, or {@code -1} if the graph has no nodes
     */
    private long findLargestComponent(DisjointSetStruct components) {
        long nodeCount = this.graph.nodeCount();
        if (nodeCount <= 0L) {
            // SplittableRandom.nextLong(bound) rejects a non-positive bound; with nothing to
            // sample, return the same sentinel the scan below yields for an empty count map.
            return -1L;
        }

        SplittableRandom random = new SplittableRandom();
        LongIntHashMap sampleCounts = new LongIntHashMap();
        for (int i = 0; i < SAMPLING_SIZE; ++i) {
            long node = random.nextLong(nodeCount);
            sampleCounts.addTo(components.setIdOf(node), 1);
        }

        int max = -1;
        long mostFrequent = -1L;
        for (LongIntCursor entry : sampleCounts) {
            // Strict comparison: on a tie the entry encountered first in map iteration order wins.
            if (entry.value > max) {
                max = entry.value;
                mostFrequent = entry.key;
            }
        }
        return mostFrequent;
    }

    /**
     * Creates one link task per partition and runs them all on the executor service.
     * <p>
     * If a threshold is configured a {@link LinkWithThresholdTask} is created, otherwise a plain
     * {@link LinkTask}. The tasks skip every node whose set id equals {@code largestComponent}.
     *
     * @param components       disjoint-set structure the link tasks union node pairs into
     * @param partitions       node partitions; one task is created for each
     * @param largestComponent set id of the component whose nodes are skipped
     */
    private void linkRemaining(DisjointSetStruct components, List<Partition> partitions, long largestComponent) {
        // Typed list instead of a raw List; the stream element is already a Partition, so no cast is needed.
        List<LinkTask> tasks = partitions.stream()
            .map(partition -> this.threshold.isPresent()
                ? new LinkWithThresholdTask(this.graph, this.threshold.get(), partition, largestComponent, components, this.progressTracker, this.terminationFlag)
                : new LinkTask(this.graph, partition, largestComponent, components, this.progressTracker, this.terminationFlag))
            .collect(Collectors.toList());
        ParallelUtil.run(tasks, this.executorService);
    }

    /**
     * Variant of {@link LinkTask} that only considers relationships whose property is strictly
     * greater than a threshold. Relationships that do not pass the threshold are neither unioned
     * nor counted towards the skipped ones.
     */
    static final class LinkWithThresholdTask extends LinkTask implements RelationshipWithPropertyConsumer {
        private final double threshold;
        // Null when the graph is not inverse indexed.
        private final RelationshipWithPropertyConsumer inverseConsumer;

        LinkWithThresholdTask(
            Graph graph,
            double threshold,
            Partition partition,
            long skipComponent,
            DisjointSetStruct components,
            ProgressTracker progressTracker,
            TerminationFlag terminationFlag
        ) {
            super(graph, partition, skipComponent, components, progressTracker, terminationFlag);
            this.threshold = threshold;
            if (graph.characteristics().isInverseIndexed()) {
                // Inverse relationships are never skipped: every one above the threshold is unioned.
                this.inverseConsumer = (source, target, weight) -> {
                    if (weight > threshold) {
                        components.union(source, target);
                    }
                    return true;
                };
            } else {
                this.inverseConsumer = null;
            }
        }

        /**
         * Visits the relationships of {@code node} together with their property, using this task
         * as the consumer. {@code Wcc.defaultWeight(threshold)} is passed as the second argument
         * (presumably the fallback property value - confirm against the Graph API).
         */
        @Override
        void link(long node) {
            double fallback = Wcc.defaultWeight(this.threshold);
            this.graph.forEachRelationship(node, fallback, (RelationshipWithPropertyConsumer)this);
        }

        /** Visits the inverse relationships of {@code node} with the threshold-aware inverse consumer. */
        @Override
        void linkInverse(long node) {
            double fallback = Wcc.defaultWeight(this.threshold);
            this.graph.forEachInverseRelationship(node, fallback, this.inverseConsumer);
        }

        /**
         * Counts each relationship that passes the threshold and unions its endpoints once more
         * than two such relationships have been seen since the last reset.
         *
         * @return always {@code true}
         */
        @Override
        public boolean accept(long sourceNodeId, long targetNodeId, double property) {
            // Written as a negated '>' so that a NaN property is treated exactly like "not above threshold".
            if (!(property > this.threshold)) {
                return true;
            }
            this.skip += 1L;
            if (this.skip > 2L) {
                this.components.union(sourceNodeId, targetNodeId);
            }
            return true;
        }
    }

    static class LinkTask
    implements Runnable,
    RelationshipConsumer {
        final Graph graph;
        final DisjointSetStruct components;
        long skip;
        private final RelationshipConsumer inverseConsumer;
        private final long skipComponent;
        private final Partition partition;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;

        LinkTask(Graph graph, Partition partition, long skipComponent, DisjointSetStruct components, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.graph = graph.concurrentCopy();
            this.skipComponent = skipComponent;
            this.partition = partition;
            this.components = components;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
            this.inverseConsumer = graph.characteristics().isInverseIndexed() ? (sourceNodeId, targetNodeId) -> {
                components.union(sourceNodeId, targetNodeId);
                return true;
            } : null;
        }

        @Override
        public void run() {
            long startNode = this.partition.startNode();
            long endNode = startNode + this.partition.nodeCount();
            LongConsumer linkInverseFn = this.inverseConsumer != null ? this::linkInverse : ignored -> {};
            for (long node = startNode; node < endNode; ++node) {
                if (this.components.setIdOf(node) == this.skipComponent) continue;
                int degree = this.graph.degree(node);
                if (degree > 2) {
                    this.reset();
                    this.link(node);
                    this.progressTracker.logProgress((long)(degree - 2));
                    if (node % 10000L == 0L) {
                        this.terminationFlag.assertRunning();
                    }
                }
                linkInverseFn.accept(node);
            }
        }

        void link(long node) {
            this.graph.forEachRelationship(node, (RelationshipConsumer)this);
        }

        void linkInverse(long node) {
            this.graph.forEachInverseRelationship(node, this.inverseConsumer);
        }

        public boolean accept(long sourceNodeId, long targetNodeId) {
            ++this.skip;
            if (this.skip > 2L) {
                this.components.union(sourceNodeId, targetNodeId);
            }
            return true;
        }

        void reset() {
            this.skip = 0L;
        }
    }

    /**
     * Variant of {@link SamplingTask} that unions only relationships whose property is strictly
     * greater than a threshold; other relationships do not consume the per-node limit.
     */
    static final class SamplingWithThresholdTask extends SamplingTask implements RelationshipWithPropertyConsumer {
        private final double threshold;

        SamplingWithThresholdTask(
            Graph graph,
            double threshold,
            Partition partition,
            DisjointSetStruct components,
            ProgressTracker progressTracker,
            TerminationFlag terminationFlag
        ) {
            super(graph, partition, components, progressTracker, terminationFlag);
            this.threshold = threshold;
        }

        /**
         * Visits the relationships of {@code node} together with their property, using this task
         * as the consumer. {@code Wcc.defaultWeight(threshold)} is passed as the second argument
         * (presumably the fallback property value - confirm against the Graph API).
         */
        @Override
        void sample(long node) {
            double fallback = Wcc.defaultWeight(this.threshold);
            this.graph.forEachRelationship(node, fallback, (RelationshipWithPropertyConsumer)this);
        }

        /**
         * Unions the endpoints and decrements the remaining limit when the property exceeds the
         * threshold.
         *
         * @return {@code false} once the limit has reached zero, {@code true} otherwise
         */
        @Override
        public boolean accept(long sourceNodeId, long targetNodeId, double property) {
            boolean aboveThreshold = property > this.threshold;
            if (aboveThreshold) {
                this.components.union(sourceNodeId, targetNodeId);
                this.limit -= 1L;
            }
            return this.limit != 0L;
        }
    }

    static class SamplingTask
    implements Runnable,
    RelationshipConsumer {
        final Graph graph;
        final DisjointSetStruct components;
        long limit;
        private final Partition partition;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;

        SamplingTask(Graph graph, Partition partition, DisjointSetStruct components, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.graph = graph.concurrentCopy();
            this.partition = partition;
            this.components = components;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        @Override
        public void run() {
            long startNode = this.partition.startNode();
            long endNode = startNode + this.partition.nodeCount();
            for (long node = startNode; node < endNode; ++node) {
                this.reset();
                this.sample(node);
                if (node % 10000L == 0L) {
                    this.terminationFlag.assertRunning();
                }
                this.progressTracker.logProgress((long)Math.min(2, this.graph.degree(node)));
            }
        }

        void sample(long node) {
            this.graph.forEachRelationship(node, (RelationshipConsumer)this);
        }

        public boolean accept(long sourceNodeId, long targetNodeId) {
            this.components.union(sourceNodeId, targetNodeId);
            --this.limit;
            return this.limit != 0L;
        }

        void reset() {
            this.limit = 2L;
        }
    }
}

