/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.steiner;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.DoubleArrayDeque;
import com.carrotsearch.hppc.LongArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.LongAdder;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeAtomicDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeAtomicLongArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.queue.HugeLongPriorityQueue;
import org.neo4j.gds.paths.ImmutablePathResult;
import org.neo4j.gds.paths.PathResult;
import org.neo4j.gds.paths.delta.TentativeDistances;
import org.neo4j.gds.paths.dijkstra.DijkstraResult;
import org.neo4j.gds.steiner.SteinerBasedDeltaTask;

/**
 * Bin-based (delta-stepping style) shortest-path search from a single start node towards a set of
 * terminal nodes.
 *
 * <p>The search alternates between a parallel relax phase and a parallel sync phase, both executed by
 * {@link SteinerBasedDeltaTask} instances. Whenever the closest queued terminal is known to be final
 * (see {@link #ensureShortest}), its path is recorded, every node on that path is marked as
 * "merged with the source" with tentative value {@code 0.0}, and the search restarts from bin 0 with
 * those nodes as the new frontier.
 */
public final class SteinerBasedDeltaStepping extends Algorithm<DijkstraResult> {
    /** Sentinel bin index meaning "no non-empty bin is left". */
    static final int NO_BIN = Integer.MAX_VALUE;
    /** Sentinel node id meaning "no terminal is currently available". */
    private static final long NO_TERMINAL = -1L;
    public static final int BIN_SIZE_THRESHOLD = 1000;
    private static final long[] EMPTY_ARRAY = new long[0];

    private final Graph graph;
    private final long startNode;
    private final double delta;
    private final int concurrency;
    /** Shared buffer of node ids processed by the tasks in the next relax phase. */
    private final HugeLongArray frontier;
    private final TentativeDistances distances;
    private final ExecutorService executorService;
    /** Running index assigned to the next emitted path. */
    private long pathIndex;
    private final long numOfTerminals;
    /** Bit per node; initialised as a copy of the terminal set, flipped when a terminal is reached. */
    private final BitSet unvisitedTerminal;
    /** Bit per node; set for the start node and for every node on an already emitted path. */
    private final BitSet mergedWithSource;
    private final LongAdder metTerminals;
    private final int binSizeThreshold;

    /**
     * @param graph            graph to search; each worker task receives its own {@code concurrentCopy()}
     * @param startNode        node the search starts from
     * @param delta            bin width used when comparing distances against bin boundaries
     * @param isTerminal       bit set marking the terminal nodes; copied, not modified
     * @param concurrency      number of worker tasks to create
     * @param binSizeThreshold forwarded unchanged to every {@link SteinerBasedDeltaTask}
     * @param executorService  executor the worker tasks run on
     * @param progressTracker  progress tracker, notified once per reached terminal
     */
    SteinerBasedDeltaStepping(Graph graph, long startNode, double delta, BitSet isTerminal, int concurrency, int binSizeThreshold, ExecutorService executorService, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.startNode = startNode;
        this.delta = delta;
        this.concurrency = concurrency;
        this.executorService = executorService;
        // compute() always writes the start node into slot 0, so the frontier needs at least one
        // slot even when the graph has no relationships at all.
        this.frontier = HugeLongArray.newArray(Math.max(1L, graph.relationshipCount()));
        this.distances = TentativeDistances.distanceAndPredecessors(graph.nodeCount(), concurrency);
        this.mergedWithSource = new BitSet(graph.nodeCount());
        this.unvisitedTerminal = new BitSet(isTerminal.size());
        this.unvisitedTerminal.or(isTerminal);
        this.pathIndex = 0L;
        this.metTerminals = new LongAdder();
        this.numOfTerminals = isTerminal.cardinality();
        this.binSizeThreshold = binSizeThreshold;
    }

    /**
     * Follows the predecessor chain starting at {@code nodeId} until it hits a node that is already
     * marked in {@link #mergedWithSource}. Every node newly marked on the way keeps its predecessor,
     * gets the value {@code 0.0}, and is appended to the frontier at the next free slot.
     */
    private void mergeNodesOnPathToSource(long nodeId, AtomicLong frontierIndex) {
        long currentId = nodeId;
        while (!this.mergedWithSource.getAndSet(currentId)) {
            long predecessor = this.distances.predecessor(currentId);
            this.distances.set(currentId, predecessor, 0.0);
            this.frontier.set(frontierIndex.getAndIncrement(), currentId);
            currentId = predecessor;
        }
    }

    /** Configures every task for the relax phase of {@code currentBin} and runs them all. */
    private void relaxPhase(List<SteinerBasedDeltaTask> tasks, int currentBin, AtomicLong frontierSize) {
        long frontierLength = frontierSize.longValue();
        for (SteinerBasedDeltaTask task : tasks) {
            task.setPhase(Phase.RELAX);
            task.setBinIndex(currentBin);
            task.setFrontierLength(frontierLength);
        }
        ParallelUtil.run(tasks, this.executorService);
    }

    /**
     * Resets the shared frontier write index, configures every task for the sync phase of
     * {@code currentBin} and runs them all.
     */
    private void syncPhase(List<SteinerBasedDeltaTask> tasks, int currentBin, AtomicLong frontierIndex) {
        frontierIndex.set(0L);
        for (SteinerBasedDeltaTask task : tasks) {
            task.setPhase(Phase.SYNC);
            task.setBinIndex(currentBin);
        }
        ParallelUtil.run(tasks, this.executorService);
    }

    /** @return the top of {@code terminalQueue}, or {@link #NO_TERMINAL} when the queue is empty */
    private long nextTerminal(HugeLongPriorityQueue terminalQueue) {
        return terminalQueue.isEmpty() ? NO_TERMINAL : terminalQueue.top();
    }

    /**
     * Records the path to {@code terminalId} and updates the bookkeeping for a reached terminal.
     *
     * @return {@code true} when all terminals have been reached (the path is then not merged into
     *         the source); {@code false} otherwise, after the path nodes were merged and placed on
     *         the frontier
     */
    private boolean updateSteinerTree(long terminalId, AtomicLong frontierIndex, List<PathResult> paths, ImmutablePathResult.Builder pathResultBuilder) {
        // The path must be extracted before merging: merging marks its nodes in mergedWithSource,
        // which is exactly the stop condition of the path walk.
        paths.add(pathResult(pathResultBuilder, this.pathIndex++, terminalId, this.distances.distances(), this.distances.predecessors().get(), this.mergedWithSource));
        frontierIndex.set(0L);
        this.metTerminals.increment();
        this.unvisitedTerminal.flip(terminalId);
        this.progressTracker.logProgress();
        if (this.metTerminals.longValue() == this.numOfTerminals) {
            return true;
        }
        this.mergeNodesOnPathToSource(terminalId, frontierIndex);
        return false;
    }

    /**
     * Decides whether {@code distance} can be treated as final given the bin that was just relaxed
     * ({@code oldBin}) and the next non-empty bin ({@code currentBin}).
     *
     * <ul>
     *   <li>No bin left: always final.</li>
     *   <li>Same bin again: final only if the distance lies below the upper boundary of that bin and
     *       does not exceed the smallest distance any task reports as considered.</li>
     *   <li>Different bin: final only if the distance lies strictly below the lower boundary of
     *       {@code currentBin}.</li>
     * </ul>
     */
    private boolean ensureShortest(double distance, long oldBin, long currentBin, List<SteinerBasedDeltaTask> tasks) {
        if (currentBin == NO_BIN) {
            return true;
        }
        if (oldBin == currentBin) {
            if (distance >= (double) (currentBin + 1L) * this.delta) {
                return false;
            }
            double currentMinDistance = tasks.stream().mapToDouble(SteinerBasedDeltaTask::getSmallestConsideredDistance).min().orElseThrow();
            return distance <= currentMinDistance;
        }
        return distance < (double) currentBin * this.delta;
    }

    /**
     * Peeks at the closest queued terminal without removing it.
     *
     * @return that terminal's id if {@link #ensureShortest} accepts its current distance, otherwise
     *         {@link #NO_TERMINAL} (also when the queue is empty)
     */
    private long tryToUpdateSteinerTree(long oldBin, long currentBin, HugeLongPriorityQueue terminalQueue, List<SteinerBasedDeltaTask> tasks) {
        long terminalId = this.nextTerminal(terminalQueue);
        if (terminalId == NO_TERMINAL) {
            return NO_TERMINAL;
        }
        boolean shouldReturnTerminal = this.ensureShortest(this.distances.distance(terminalId), oldBin, currentBin, tasks);
        return shouldReturnTerminal ? terminalId : NO_TERMINAL;
    }

    /**
     * Runs the search until no non-empty bin remains or every terminal has been reached.
     *
     * @return a {@link DijkstraResult} streaming one path per reached terminal, in the order the
     *         terminals were reached
     */
    @Override
    public DijkstraResult compute() {
        int currentBin = 0;
        ImmutablePathResult.Builder pathResultBuilder = ImmutablePathResult.builder().sourceNode(this.startNode);
        AtomicLong frontierIndex = new AtomicLong(0L);
        AtomicLong frontierSize = new AtomicLong(1L);
        List<PathResult> paths = new ArrayList<>();

        // Seed the search: the start node is the only frontier entry and is merged by definition.
        this.frontier.set(currentBin, this.startNode);
        this.mergedWithSource.set(this.startNode);
        // -1 is stored as the start node's predecessor.
        this.distances.set(this.startNode, -1L, 0.0);

        HugeLongPriorityQueue terminalQueue = HugeLongPriorityQueue.min(this.unvisitedTerminal.size());
        ReentrantLock terminalQueueLock = new ReentrantLock();
        List<SteinerBasedDeltaTask> tasks = IntStream.range(0, this.concurrency)
            .mapToObj(i -> new SteinerBasedDeltaTask(this.graph.concurrentCopy(), this.frontier, this.distances, this.delta, frontierIndex, this.mergedWithSource, terminalQueue, terminalQueueLock, this.unvisitedTerminal, this.binSizeThreshold))
            .collect(Collectors.toList());

        boolean shouldBreak = false;
        while (currentBin != NO_BIN && !shouldBreak) {
            this.relaxPhase(tasks, currentBin, frontierSize);
            long oldCurrentBin = currentBin;
            currentBin = tasks.stream().mapToInt(SteinerBasedDeltaTask::minNonEmptyBin).min().orElseThrow();
            long terminalId = this.tryToUpdateSteinerTree(oldCurrentBin, currentBin, terminalQueue, tasks);
            if (terminalId != NO_TERMINAL) {
                terminalQueue.pop();
                shouldBreak = this.updateSteinerTree(terminalId, frontierIndex, paths, pathResultBuilder);
                // The merged path nodes now have value 0.0, so relaxation restarts from the first bin.
                currentBin = 0;
            } else {
                this.syncPhase(tasks, currentBin, frontierIndex);
            }
            // Whatever was appended to the frontier in this round becomes the input of the next relax phase.
            frontierSize.set(frontierIndex.longValue());
            frontierIndex.set(0L);
        }
        return new DijkstraResult(paths.stream());
    }

    /**
     * Intentionally empty; nothing is released here.
     * NOTE(review): presumably implements a hook declared by {@code Algorithm} - confirm and add
     * {@code @Override} if so.
     */
    public void release() {
    }

    /**
     * Builds the path that ends at {@code targetNode} by walking predecessors backwards until a node
     * marked in {@code mergedWithSource} is reached. That node is included as the first path node.
     * A cost entry is emitted for every node on the path except that first one, so {@code costs}
     * has one element fewer than {@code nodeIds}. Relationship ids are always empty.
     */
    private static PathResult pathResult(ImmutablePathResult.Builder pathResultBuilder, long pathIndex, long targetNode, HugeAtomicDoubleArray distances, HugeAtomicLongArray predecessors, BitSet mergedWithSource) {
        LongArrayDeque pathNodeIds = new LongArrayDeque();
        DoubleArrayDeque costs = new DoubleArrayDeque();
        long lastNode = targetNode;
        while (true) {
            pathNodeIds.addFirst(lastNode);
            if (mergedWithSource.get(lastNode)) {
                break;
            }
            costs.addFirst(distances.get(lastNode));
            lastNode = predecessors.get(lastNode);
        }
        return pathResultBuilder
            .index(pathIndex)
            .targetNode(targetNode)
            .nodeIds(pathNodeIds.toArray())
            .relationshipIds(EMPTY_ARRAY)
            .costs(costs.toArray())
            .build();
    }

    /** The two phases a {@link SteinerBasedDeltaTask} can be configured to execute. */
    enum Phase {
        RELAX,
        SYNC;
    }
}

