/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.traversal;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.SourceNodesConfig;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.queue.QueueBasedSpliterator;
import org.neo4j.gds.degree.DegreeCentrality;
import org.neo4j.gds.degree.DegreeCentralityConfig;
import org.neo4j.gds.degree.ImmutableDegreeCentralityConfig;
import org.neo4j.gds.ml.core.EmbeddingUtils;
import org.neo4j.gds.ml.core.samplers.RandomWalkSampler;
import org.neo4j.gds.traversal.RandomWalkBaseConfig;

/**
 * Samples random walks over a {@link Graph} and exposes them as a lazily consumed {@link Stream}.
 *
 * <p>Walks are produced concurrently by {@code config.concurrency()} worker tasks and handed to the
 * consumer through a bounded queue of capacity {@code config.walkBufferSize()}. An empty array is
 * enqueued as a tombstone once all workers are done. Closing the returned stream signals the
 * workers to stop producing walks.
 */
public final class RandomWalk extends Algorithm<Stream<long[]>> {
    private final Graph graph;
    private final RandomWalkBaseConfig config;
    private final ExecutorService executorService;

    private RandomWalk(Graph graph, RandomWalkBaseConfig config, ProgressTracker progressTracker, ExecutorService executorService) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
        this.executorService = executorService;
    }

    /**
     * Creates a random-walk algorithm instance.
     *
     * <p>If the graph has a relationship property, its values are validated via
     * {@code EmbeddingUtils.validateRelationshipWeightPropertyValue} with a non-negativity predicate.
     *
     * @param graph           graph to walk on
     * @param config          walk configuration (length, walks per node, bias factors, buffer size, ...)
     * @param progressTracker tracker receiving sub-task and progress events
     * @param executorService executor used for validation, degree computation and walker tasks
     * @return a new, not yet started, {@code RandomWalk}
     */
    public static RandomWalk create(Graph graph, RandomWalkBaseConfig config, ProgressTracker progressTracker, ExecutorService executorService) {
        if (graph.hasRelationshipProperty()) {
            EmbeddingUtils.validateRelationshipWeightPropertyValue(
                graph,
                config.concurrency(),
                weight -> weight >= 0.0,
                "Node2Vec only supports non-negative weights.",
                executorService
            );
        }
        return new RandomWalk(graph, config, progressTracker, executorService);
    }

    /**
     * Starts the walker tasks asynchronously and returns a stream over the produced walks.
     *
     * <p>The "RandomWalk" sub-task is begun here and ended once the asynchronous producer completes.
     * Close the returned stream to stop the producers early.
     *
     * @return a sequential stream of walks, each walk being an array of (mapped) node ids
     */
    @Override
    public Stream<long[]> compute() {
        this.progressTracker.beginSubTask("RandomWalk");

        // The sampler needs a per-node total weight: the weighted degree when relationship
        // weights exist, otherwise the plain degree.
        RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier = this.graph.hasRelationshipProperty()
            ? this.cumulativeWeights()::get
            : nodeId -> this.graph.degree(nodeId);

        long randomSeed = this.config.randomSeed().orElseGet(() -> new Random().nextLong());

        // Walk from the configured source nodes if any, otherwise from every node of the graph.
        NextNodeSupplier nextNodeSupplier = this.config.sourceNodes() == null || this.config.sourceNodes().isEmpty()
            ? new NextNodeSupplier.GraphNodeSupplier(this.graph.nodeCount())
            : NextNodeSupplier.ListNodeSupplier.of(this.config, this.graph);

        ExternalTerminationFlag terminationFlag = new ExternalTerminationFlag(this);
        BlockingQueue<long[]> walks = new ArrayBlockingQueue<>(this.config.walkBufferSize());
        // Compared by identity downstream; a dedicated instance can never collide with a real walk.
        long[] tombstone = new long[]{};

        this.startWalkers(terminationFlag, cumulativeWeightSupplier, randomSeed, nextNodeSupplier, walks, tombstone);
        return this.walksQueueConsumer(terminationFlag, tombstone, walks);
    }

    /**
     * Computes weighted degrees via {@link DegreeCentrality}.
     *
     * <p>NOTE(review): the "DUMMY" property name presumably only switches DegreeCentrality into its
     * weighted mode, with the weights taken from the graph's own relationship property — confirm.
     */
    private DegreeCentrality.DegreeFunction cumulativeWeights() {
        DegreeCentralityConfig degreeCentralityConfig = ImmutableDegreeCentralityConfig.builder()
            .concurrency(this.config.concurrency())
            .relationshipWeightProperty("DUMMY")
            .build();
        return new DegreeCentrality(this.graph, this.executorService, degreeCentralityConfig, this.progressTracker).compute();
    }

    @Override
    public void release() {
    }

    /**
     * Builds one walker task per unit of concurrency and runs them on a single background thread.
     * When that run completes (normally or exceptionally), the "RandomWalk" sub-task is ended and
     * {@link #release()} is called.
     */
    private void startWalkers(TerminationFlag terminationFlag, RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier, long randomSeed, NextNodeSupplier nextNodeSupplier, BlockingQueue<long[]> walks, long[] tombstone) {
        // Each task gets its own concurrent graph copy; node ids are shared via nextNodeSupplier.
        List<RandomWalkTask> tasks = IntStream.range(0, this.config.concurrency())
            .mapToObj(i -> RandomWalkTask.of(
                nextNodeSupplier,
                cumulativeWeightSupplier,
                this.graph.concurrentCopy(),
                this.config,
                walks,
                randomSeed,
                this.progressTracker,
                terminationFlag
            ))
            .collect(Collectors.toList());

        CompletableFuture
            .runAsync(() -> this.tasksRunner(tasks, walks, tombstone, terminationFlag), Pools.DEFAULT_SINGLE_THREAD_POOL)
            .whenComplete((ignoredResult, ignoredError) -> {
                this.progressTracker.endSubTask("RandomWalk");
                this.release();
            });
    }

    /**
     * Runs all walker tasks on the executor, then enqueues the tombstone so the consumer knows no more
     * walks will arrive. The tombstone offer is retried every 100 ms until it is accepted or the
     * algorithm is terminated.
     */
    private void tasksRunner(Iterable<? extends Runnable> tasks, BlockingQueue<long[]> walks, long[] tombstone, TerminationFlag terminationFlag) {
        this.progressTracker.beginSubTask("create walks");
        RunWithConcurrency.builder()
            .executor(this.executorService)
            .concurrency(this.config.concurrency())
            .tasks(tasks)
            .terminationFlag(terminationFlag)
            .mayInterruptIfRunning(true)
            .run();
        this.progressTracker.endSubTask("create walks");

        try {
            boolean enqueued = false;
            // Timed offer so that termination is noticed even while the queue stays full.
            while (!enqueued && terminationFlag.running()) {
                enqueued = walks.offer(tombstone, 100L, TimeUnit.MILLISECONDS);
            }
        } catch (InterruptedException exception) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Wraps the walk queue in a sequential stream that ends at the tombstone.
     * Closing the stream stops the producers through {@code terminationFlag}.
     */
    private Stream<long[]> walksQueueConsumer(ExternalTerminationFlag terminationFlag, long[] tombstone, BlockingQueue<long[]> walks) {
        int timeoutInSeconds = 100;
        QueueBasedSpliterator<long[]> queueConsumer = new QueueBasedSpliterator<>(walks, tombstone, terminationFlag, timeoutInSeconds);
        return StreamSupport.stream(queueConsumer, false).onClose(terminationFlag::stop);
    }

    /**
     * Thread-safe source of start nodes shared by all walker tasks. Returns {@link #NO_MORE_NODES}
     * once exhausted.
     */
    @FunctionalInterface
    interface NextNodeSupplier {
        long NO_MORE_NODES = -1L;

        long nextNode();

        /** Hands out an explicit list of start nodes (already mapped to internal ids), each exactly once. */
        final class ListNodeSupplier implements NextNodeSupplier {
            private final List<Long> nodes;
            private final AtomicInteger nextIndex;

            static ListNodeSupplier of(SourceNodesConfig config, Graph graph) {
                List<Long> mappedIds = config.sourceNodes().stream()
                    .map(originalId -> graph.toMappedNodeId(originalId))
                    .collect(Collectors.toList());
                return new ListNodeSupplier(mappedIds);
            }

            private ListNodeSupplier(List<Long> nodes) {
                this.nodes = nodes;
                this.nextIndex = new AtomicInteger(0);
            }

            @Override
            public long nextNode() {
                int index = this.nextIndex.getAndIncrement();
                return index < this.nodes.size() ? this.nodes.get(index) : NO_MORE_NODES;
            }
        }

        /** Hands out every node id in {@code [0, numberOfNodes)}, each exactly once. */
        class GraphNodeSupplier implements NextNodeSupplier {
            private final long numberOfNodes;
            private final AtomicLong nextNodeId;

            GraphNodeSupplier(long numberOfNodes) {
                this.numberOfNodes = numberOfNodes;
                this.nextNodeId = new AtomicLong(0L);
            }

            @Override
            public long nextNode() {
                long nextNode = this.nextNodeId.getAndIncrement();
                return nextNode < this.numberOfNodes ? nextNode : NO_MORE_NODES;
            }
        }
    }

    /**
     * Worker that pulls start nodes from a shared {@link NextNodeSupplier}, samples
     * {@code walksPerNode} walks per node, and publishes them to the shared queue in batches of up to
     * 1000 walks.
     */
    private static final class RandomWalkTask implements Runnable {
        private final Graph graph;
        private final BlockingQueue<long[]> walks;
        private final NextNodeSupplier nextNodeSupplier;
        private final long[][] buffer;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;
        private final RandomWalkBaseConfig config;
        private final RandomWalkSampler sampler;

        /**
         * Creates a task.
         *
         * <p>The return, same-distance and in-out probabilities (1/returnFactor, 1, 1/inOutFactor)
         * are normalized by their maximum, so the largest of them is 1.
         */
        static RandomWalkTask of(NextNodeSupplier nextNodeSupplier, RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier, Graph graph, RandomWalkBaseConfig config, BlockingQueue<long[]> walks, long randomSeed, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            double maxProbability = Math.max(Math.max(1.0 / config.returnFactor(), 1.0), 1.0 / config.inOutFactor());
            double normalizedReturnProbability = 1.0 / config.returnFactor() / maxProbability;
            double normalizedSameDistanceProbability = 1.0 / maxProbability;
            double normalizedInOutProbability = 1.0 / config.inOutFactor() / maxProbability;
            return new RandomWalkTask(nextNodeSupplier, cumulativeWeightSupplier, config, walks, normalizedReturnProbability, normalizedSameDistanceProbability, normalizedInOutProbability, graph, randomSeed, progressTracker, terminationFlag);
        }

        private RandomWalkTask(NextNodeSupplier nextNodeSupplier, RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier, RandomWalkBaseConfig config, BlockingQueue<long[]> walks, double normalizedReturnProbability, double normalizedSameDistanceProbability, double normalizedInOutProbability, Graph graph, long randomSeed, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.nextNodeSupplier = nextNodeSupplier;
            this.graph = graph;
            this.config = config;
            this.walks = walks;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
            this.sampler = new RandomWalkSampler(cumulativeWeightSupplier, config.walkLength(), normalizedReturnProbability, normalizedSameDistanceProbability, normalizedInOutProbability, graph, randomSeed);
            this.buffer = new long[1000][];
        }

        /**
         * Samples walks until the supplier is exhausted or a flush reports termination/interruption.
         * Nodes without relationships are skipped but still count towards progress.
         */
        @Override
        public void run() {
            int bufferLength = 0;
            long nodeId;
            while ((nodeId = this.nextNodeSupplier.nextNode()) != NextNodeSupplier.NO_MORE_NODES) {
                if (this.graph.degree(nodeId) == 0) {
                    this.progressTracker.logProgress();
                    continue;
                }
                int walksPerNode = this.config.walksPerNode();
                this.sampler.prepareForNewNode(nodeId);
                for (int walkIndex = 0; walkIndex < walksPerNode; ++walkIndex) {
                    this.buffer[bufferLength++] = this.sampler.walk(nodeId);
                    if (bufferLength == this.buffer.length) {
                        boolean keepRunning = this.flushBuffer(bufferLength);
                        bufferLength = 0;
                        if (!keepRunning) {
                            // Terminated or interrupted: stop entirely instead of draining the
                            // shared node supplier and sampling walks that would be discarded.
                            return;
                        }
                    }
                }
                this.progressTracker.logProgress();
            }
            this.flushBuffer(bufferLength);
        }

        /**
         * Offers the first {@code bufferLength} buffered walks to the shared queue.
         *
         * @return {@code true} if the algorithm should keep running, {@code false} if it was
         *     terminated or this thread was interrupted
         */
        private boolean flushBuffer(int bufferLength) {
            int count = Math.min(bufferLength, this.buffer.length);
            int i = 0;
            while (i < count && this.terminationFlag.running()) {
                try {
                    // Timed offer so that termination is noticed even while the queue stays full.
                    if (this.walks.offer(this.buffer[i], 100L, TimeUnit.MILLISECONDS)) {
                        ++i;
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return false;
                }
            }
            return this.terminationFlag.running();
        }
    }

    /**
     * Termination flag that reports "running" only while both the consumer has not closed the stream
     * ({@link #stop()}) and the algorithm's own termination flag is still running.
     */
    private static final class ExternalTerminationFlag implements TerminationFlag {
        private volatile boolean running = true;
        private final Algorithm<?> algo;

        ExternalTerminationFlag(Algorithm<?> algo) {
            this.algo = algo;
        }

        @Override
        public boolean running() {
            return this.running && this.algo.getTerminationFlag().running();
        }

        void stop() {
            this.running = false;
        }
    }
}

