/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.steiner;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.mutable.MutableBoolean;
import org.jetbrains.annotations.TestOnly;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeLongArrayQueue;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.paths.PathResult;
import org.neo4j.gds.paths.dijkstra.DijkstraResult;
import org.neo4j.gds.steiner.LinkCutTree;
import org.neo4j.gds.steiner.SteinerBasedDeltaStepping;
import org.neo4j.gds.steiner.SteinerTreeResult;

/**
 * Builds a tree rooted at {@code sourceId} that connects terminal nodes, using the paths computed
 * by {@link SteinerBasedDeltaStepping}, and optionally improves that tree with a rerouting pass.
 *
 * <p>The tree is encoded in two node-indexed arrays that are handed to {@link SteinerTreeResult}:
 * <ul>
 *   <li>{@code parent[v]}: the parent of {@code v}, {@link #ROOT_NODE} for the source, or
 *       {@link #PRUNED} when {@code v} is not part of the tree;</li>
 *   <li>{@code parentCost[v]}: the cost of the edge from {@code parent[v]} to {@code v},
 *       {@code 0.0} for the source, or {@code (double) PRUNED} when {@code v} is not part of the
 *       tree.</li>
 * </ul>
 */
public class ShortestPathsSteinerAlgorithm extends Algorithm<SteinerTreeResult> {
    /** {@code parent} marker of the source node, i.e. the root of the tree. */
    public static final long ROOT_NODE = -1L;
    /** {@code parent} marker of nodes that are not part of the tree. */
    public static final long PRUNED = -2L;

    /** Bin size threshold forwarded to {@link SteinerBasedDeltaStepping} by the public constructor. */
    private static final int DEFAULT_BIN_SIZE_THRESHOLD = 1000;

    private final Graph graph;
    private final long sourceId;
    private final List<Long> terminals;
    private final int concurrency;
    /** Bit {@code v} is set iff node {@code v} is one of the terminals. */
    private final BitSet isTerminal;
    private final boolean applyRerouting;
    private final double delta;
    private final ExecutorService executorService;
    private final int binSizeThreshold;

    /**
     * Creates the algorithm with the default bin size threshold.
     *
     * @param graph           graph to compute the tree on
     * @param sourceId        node id of the root of the tree
     * @param terminals       node ids the tree should connect
     * @param delta           delta forwarded to {@link SteinerBasedDeltaStepping}
     * @param concurrency     concurrency of the parallel node loops and of the shortest-path phase
     * @param applyRerouting  whether to run the rerouting pass after the shortest-path phase
     * @param executorService executor forwarded to {@link SteinerBasedDeltaStepping}
     * @param progressTracker tracker for the "SteinerTree", "Traverse" and "Reroute" tasks
     */
    public ShortestPathsSteinerAlgorithm(
        Graph graph,
        long sourceId,
        List<Long> terminals,
        double delta,
        int concurrency,
        boolean applyRerouting,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        this(
            graph,
            sourceId,
            terminals,
            delta,
            concurrency,
            applyRerouting,
            DEFAULT_BIN_SIZE_THRESHOLD,
            executorService,
            progressTracker
        );
    }

    /**
     * Same as the public constructor, but with an explicit bin size threshold for
     * {@link SteinerBasedDeltaStepping}.
     */
    @TestOnly
    ShortestPathsSteinerAlgorithm(
        Graph graph,
        long sourceId,
        List<Long> terminals,
        double delta,
        int concurrency,
        boolean applyRerouting,
        int binSizeThreshold,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.graph = graph;
        this.sourceId = sourceId;
        this.terminals = terminals;
        this.concurrency = concurrency;
        this.delta = delta;
        this.isTerminal = createTerminals(terminals);
        this.applyRerouting = applyRerouting;
        this.executorService = executorService;
        this.binSizeThreshold = binSizeThreshold;
    }

    /** Builds a bitset with one bit set per terminal id, sized to the largest terminal id. */
    private static BitSet createTerminals(List<Long> terminals) {
        long maxTerminalId = -1L;
        for (long terminalId : terminals) {
            maxTerminalId = Math.max(maxTerminalId, terminalId);
        }
        BitSet terminalBitSet = new BitSet(maxTerminalId + 1L);
        for (long terminalId : terminals) {
            terminalBitSet.set(terminalId);
        }
        return terminalBitSet;
    }

    /**
     * Runs the shortest-path phase, records every returned path in the tree arrays, optionally
     * reroutes, and returns the resulting tree.
     *
     * @return the tree arrays together with the accumulated edge cost, the node count (the source
     *     plus one per node attached by a path, minus nodes cut after rerouting) and the number of
     *     paths returned by the shortest-path phase (reported as terminals reached)
     */
    public SteinerTreeResult compute() {
        progressTracker.beginSubTask("SteinerTree");
        progressTracker.beginSubTask("Traverse");

        long nodeCount = graph.nodeCount();
        HugeLongArray parent = HugeLongArray.newArray(nodeCount);
        HugeDoubleArray parentCost = HugeDoubleArray.newArray(nodeCount);
        // Every node starts outside the tree; the paths below attach nodes to it.
        ParallelUtil.parallelForEachNode(nodeCount, concurrency, nodeId -> {
            parentCost.set(nodeId, (double) PRUNED);
            parent.set(nodeId, PRUNED);
        });

        DoubleAdder totalCost = new DoubleAdder();
        LongAdder effectiveNodeCount = new LongAdder();
        LongAdder terminalsReached = new LongAdder();
        // The source is always part of the tree.
        effectiveNodeCount.increment();

        DijkstraResult shortestPaths = runShortestPaths();
        initForSource(parent, parentCost);
        shortestPaths.forEachPath(path -> {
            processPath(path, parent, parentCost, totalCost, effectiveNodeCount);
            terminalsReached.increment();
        });
        progressTracker.endSubTask("Traverse");

        if (applyRerouting) {
            reroute(parent, parentCost, totalCost, effectiveNodeCount);
        }
        progressTracker.endSubTask("SteinerTree");

        return SteinerTreeResult.of(
            parent,
            parentCost,
            totalCost.doubleValue(),
            effectiveNodeCount.longValue(),
            terminalsReached.longValue()
        );
    }

    /** Nothing to release; all state is local to {@link #compute()}. */
    public void release() {
    }

    /** Marks the source as the root of the tree with zero cost. */
    private void initForSource(HugeLongArray parent, HugeDoubleArray parentCost) {
        parent.set(sourceId, ROOT_NODE);
        parentCost.set(sourceId, 0.0);
    }

    /**
     * Attaches every node of {@code path} to its predecessor on the path, provided the path ends
     * at a terminal.
     *
     * <p>Assumes {@code path.nodeIds().length == path.costs().length + 1} and that
     * {@code costs[j]} is the cumulative cost from {@code ids[0]} to {@code ids[j + 1]};
     * NOTE(review): confirm against {@link SteinerBasedDeltaStepping}.
     */
    private void processPath(
        PathResult path,
        HugeLongArray parent,
        HugeDoubleArray parentCost,
        DoubleAdder totalCost,
        LongAdder effectiveNodeCount
    ) {
        long targetId = path.targetNode();
        if (!isTerminal.get(targetId)) {
            return;
        }
        long[] ids = path.nodeIds();
        double[] costs = path.costs();
        totalCost.add(path.totalCost());
        for (int j = costs.length - 1; j >= 0; --j) {
            long nodeId = ids[j + 1];
            long parentId = ids[j];
            // Costs are cumulative along the path; the edge cost is the delta to the previous entry.
            double edgeCost = j > 0 ? costs[j] - costs[j - 1] : costs[j];
            parent.set(nodeId, parentId);
            parentCost.set(nodeId, edgeCost);
            effectiveNodeCount.increment();
        }
    }

    /** Runs {@link SteinerBasedDeltaStepping} from the source with the configured parameters. */
    private DijkstraResult runShortestPaths() {
        SteinerBasedDeltaStepping steinerBasedDelta = new SteinerBasedDeltaStepping(
            graph,
            sourceId,
            delta,
            isTerminal,
            concurrency,
            binSizeThreshold,
            executorService,
            progressTracker
        );
        return steinerBasedDelta.compute();
    }

    /**
     * Makes {@code source} the parent of {@code target} with edge cost {@code weight}, adjusts
     * {@code totalCost} by the cost difference and links the two nodes in {@code tree}.
     */
    private void reconnect(
        LinkCutTree tree,
        HugeLongArray parent,
        HugeDoubleArray parentCost,
        DoubleAdder totalCost,
        long source,
        long target,
        double weight
    ) {
        double previousEdgeCost = parentCost.get(target);
        parent.set(target, source);
        parentCost.set(target, weight);
        totalCost.add(weight - previousEdgeCost);
        tree.link(source, target);
    }

    /**
     * Removes the edge {@code (parentTarget, target)} from {@code tree} and reports whether
     * {@code source} and {@code target} are then disconnected.
     *
     * <p>Side effect: the edge stays deleted; callers must re-link it if the reroute is rejected.
     * A reroute is only valid when {@code source} is not in the detached subtree of
     * {@code target}, otherwise attaching {@code target} under {@code source} would form a cycle.
     */
    private boolean checkIfRerouteIsValid(LinkCutTree tree, long source, long target, long parentTarget) {
        tree.delete(parentTarget, target);
        return !tree.connected(source, target);
    }

    /** Builds a link-cut tree containing one link per non-root tree edge in {@code parent}. */
    private LinkCutTree createLinkCutTree(HugeLongArray parent) {
        LinkCutTree tree = new LinkCutTree(graph.nodeCount());
        for (long nodeId = 0L; nodeId < graph.nodeCount(); ++nodeId) {
            long parentId = parent.get(nodeId);
            if (parentId == PRUNED || parentId == ROOT_NODE) {
                continue;
            }
            tree.link(parentId, nodeId);
        }
        return tree;
    }

    /**
     * Removes every tree node that no longer lies on the path from the source to a terminal in
     * the tree. For each removed node, its parent and cost are reset to {@link #PRUNED}, its edge
     * cost is subtracted from {@code totalCost} and {@code effectiveNodeCount} is decremented.
     */
    private void cutNodesAfterRerouting(
        HugeLongArray parent,
        HugeDoubleArray parentCost,
        DoubleAdder totalCost,
        LongAdder effectiveNodeCount
    ) {
        BitSet endsAtTerminal = new BitSet(graph.nodeCount());
        HugeLongArrayQueue queue = HugeLongArrayQueue.newQueue(graph.nodeCount());
        for (long terminal : terminals) {
            // Skip terminals outside the tree and duplicate terminal ids; duplicates would not change
            // the marking but could overflow the fixed-capacity queue.
            if (parent.get(terminal) == PRUNED || endsAtTerminal.getAndSet(terminal)) {
                continue;
            }
            queue.add(terminal);
        }

        // Walk upwards from every tree terminal and mark all ancestors. Stop at the source, at the
        // root/pruned sentinels (never pass a negative index to the bitset), or at an
        // already-marked node whose ancestors have been handled.
        while (!queue.isEmpty()) {
            long current = queue.remove();
            long parentId = parent.get(current);
            if (parentId == ROOT_NODE || parentId == PRUNED || parentId == sourceId) {
                continue;
            }
            if (endsAtTerminal.getAndSet(parentId)) {
                continue;
            }
            queue.add(parentId);
        }

        ParallelUtil.parallelForEachNode(graph.nodeCount(), concurrency, nodeId -> {
            long parentId = parent.get(nodeId);
            if (parentId != PRUNED && parentId != ROOT_NODE && !endsAtTerminal.get(nodeId)) {
                parent.set(nodeId, PRUNED);
                totalCost.add(-parentCost.get(nodeId));
                parentCost.set(nodeId, (double) PRUNED);
                effectiveNodeCount.decrement();
            }
        });
    }

    /**
     * Tries to lower the tree cost by replacing parent edges with cheaper relationships.
     *
     * <p>For every tree node {@code s} and each relationship {@code (s, t, w)}, where {@code t} is
     * a non-root tree node whose current edge cost exceeds {@code w}, {@code t} is tentatively
     * detached from its parent. If {@code s} and {@code t} are then disconnected, {@code s}
     * becomes the new parent of {@code t}; otherwise the old edge is restored. If any reroute
     * happened, branches that no longer lead to a terminal are cut afterwards.
     */
    private void reroute(
        HugeLongArray parent,
        HugeDoubleArray parentCost,
        DoubleAdder totalCost,
        LongAdder effectiveNodeCount
    ) {
        progressTracker.beginSubTask("Reroute");
        LinkCutTree tree = createLinkCutTree(parent);
        MutableBoolean didReroutes = new MutableBoolean();
        graph.forEachNode(nodeId -> {
            // Only nodes that are already in the tree can become new parents.
            if (parent.get(nodeId) != PRUNED) {
                // 1.0 is the fallback weight passed to forEachRelationship
                // (presumably used when there is no relationship property).
                graph.forEachRelationship(nodeId, 1.0, (s, t, w) -> {
                    long parentId = parent.get(t);
                    if (parentId != PRUNED && parentId != ROOT_NODE && w < parentCost.get(t)) {
                        if (checkIfRerouteIsValid(tree, s, t, parentId)) {
                            didReroutes.setTrue();
                            reconnect(tree, parent, parentCost, totalCost, s, t, w);
                        } else {
                            // Restore the edge deleted by checkIfRerouteIsValid.
                            tree.link(parentId, t);
                        }
                    }
                    return true;
                });
            }
            progressTracker.logProgress();
            return true;
        });
        if (didReroutes.isTrue()) {
            cutNodesAfterRerouting(parent, parentCost, totalCost, effectiveNodeCount);
        }
        progressTracker.endSubTask("Reroute");
    }
}

