/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.graphalgo.impl.multistepscc;

import com.carrotsearch.hppc.IntContainer;
import com.carrotsearch.hppc.IntScatterSet;
import com.carrotsearch.hppc.IntSet;
import com.carrotsearch.hppc.IntStack;
import com.carrotsearch.hppc.cursors.IntCursor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.function.IntPredicate;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.container.AtomicBitSet;
import org.neo4j.graphalgo.core.utils.container.FlipStack;
import org.neo4j.graphdb.Direction;
import org.neo4j.helpers.Exceptions;
import org.neo4j.kernel.impl.util.collection.SimpleBitSet;

public class MultiStepColoring {
    public static final int MIN_BATCH_SIZE = 100000;
    private final Graph graph;
    private final ExecutorService executorService;
    private final AtomicIntegerArray colors;
    private final AtomicBitSet visited;
    private final List<Future<IntContainer>> futures = new ArrayList<Future<IntContainer>>();
    private final int concurrency;
    private final int nodeCount;

    public MultiStepColoring(Graph graph, ExecutorService executorService, int concurrency) {
        this.graph = graph;
        this.nodeCount = Math.toIntExact(graph.nodeCount());
        this.executorService = executorService;
        this.concurrency = concurrency;
        this.colors = new AtomicIntegerArray(this.nodeCount);
        this.visited = new AtomicBitSet(this.nodeCount);
    }

    public MultiStepColoring compute(IntSet nodes) {
        this.resetColors(nodes);
        this.msColorParallel(nodes);
        return this;
    }

    public AtomicIntegerArray getColors() {
        return this.colors;
    }

    public void forEachColor(IntPredicate consumer) {
        SimpleBitSet bitSet = new SimpleBitSet(this.nodeCount);
        for (int i = 0; i < this.nodeCount; ++i) {
            int color = this.colors.get(i);
            if (bitSet.contains(color)) continue;
            bitSet.put(color);
            if (consumer.test(color)) continue;
            return;
        }
    }

    private void msColorParallel(IntSet nodeSet) {
        FlipStack flipStack = new FlipStack(nodeSet);
        flipStack.flip();
        while (!flipStack.isEmpty()) {
            this.futures.clear();
            int size = flipStack.popStack().size();
            int batchSize = Math.floorDiv(size, this.concurrency);
            if (this.concurrency <= 1 || batchSize < 100000) {
                this.futures.add(this.executorService.submit(() -> this.msColorTask(flipStack.popStack())));
            } else {
                Iterator<IntCursor> it = flipStack.popStack().iterator();
                for (int i = 0; i < size; i += batchSize) {
                    IntScatterSet partition = this.partition(it, batchSize);
                    this.futures.add(this.executorService.submit(() -> this.msColorTask(partition)));
                }
            }
            flipStack.pushStack().clear();
            this.union(flipStack.pushStack(), this.futures);
            flipStack.flip();
        }
    }

    private IntContainer msColorTask(IntContainer nodes) {
        IntStack levelQueue = new IntStack(nodes.size());
        nodes.forEach(node -> {
            int nodeColor = this.colors.get(node);
            boolean[] change = new boolean[]{false};
            this.graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId) -> {
                if (this.cas(targetNodeId, nodeColor) && !this.visited.get(targetNodeId)) {
                    this.visited.set(targetNodeId);
                    levelQueue.push(targetNodeId);
                    change[0] = true;
                }
                return true;
            });
            if (change[0] && !this.visited.get(node)) {
                levelQueue.push(node);
                this.visited.set(node);
            }
        });
        return levelQueue;
    }

    private void msColorSequential(IntSet nodes) {
        FlipStack queue = new FlipStack(nodes.size());
        queue.addAll(nodes);
        queue.flip();
        while (!queue.isEmpty()) {
            queue.forEach(node -> {
                int nodeColor = this.colors.get(node);
                boolean[] change = new boolean[]{false};
                this.graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId) -> {
                    if (this.cas(targetNodeId, nodeColor) && !this.visited.get(targetNodeId)) {
                        this.visited.set(targetNodeId);
                        queue.push(targetNodeId);
                        change[0] = true;
                    }
                    return true;
                });
                if (change[0] && !this.visited.get(node)) {
                    queue.push(node);
                    this.visited.set(node);
                }
            });
            queue.popStack().clear();
            queue.flip();
        }
    }

    private void simpleColor(IntSet nodes) {
        boolean[] changed = new boolean[]{false};
        do {
            changed[0] = false;
            nodes.forEach(node -> {
                int nodeColor = this.colors.get(node);
                this.graph.forEachRelationship(node, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId) -> {
                    if (this.cas(targetNodeId, nodeColor)) {
                        changed[0] = true;
                    }
                    return true;
                });
            });
        } while (changed[0]);
    }

    private void resetColors(IntContainer nodes) {
        nodes.forEach(node -> this.colors.set(node, node));
    }

    private IntScatterSet partition(Iterator<IntCursor> it, int batchSize) {
        IntScatterSet partition = new IntScatterSet(batchSize);
        for (int j = 0; j < batchSize && it.hasNext(); ++j) {
            partition.add(it.next().value);
        }
        return partition;
    }

    private boolean cas(int nodeId, int color) {
        int oldC;
        boolean stored = false;
        while (!stored && color > (oldC = this.colors.get(nodeId))) {
            stored = this.colors.compareAndSet(nodeId, oldC, color);
        }
        return stored;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void union(IntStack ret, Collection<Future<IntContainer>> futures) {
        Throwable error;
        block11: {
            boolean done = false;
            error = null;
            try {
                for (Future<IntContainer> future : futures) {
                    try {
                        future.get().forEach(ret::add);
                    }
                    catch (ExecutionException ee) {
                        error = Exceptions.chain((Throwable)error, (Throwable)ee.getCause());
                    }
                    catch (CancellationException cancellationException) {}
                }
                done = true;
            }
            catch (InterruptedException e) {
                error = Exceptions.chain((Throwable)e, error);
            }
            finally {
                if (done) break block11;
                for (Future<IntContainer> future : futures) {
                    future.cancel(true);
                }
            }
        }
        if (error != null) {
            throw Exceptions.launderedException((Throwable)error);
        }
    }
}

