/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongArrayList;
import com.carrotsearch.hppc.cursors.LongCursor;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.BiLongConsumer;
import org.neo4j.gds.core.utils.ProgressTimer;
import org.neo4j.gds.core.utils.paged.HugeCursor;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.knn.GenerateRandomNeighbors;
import org.neo4j.gds.similarity.knn.ImmutableResult;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnNeighborFilterFactory;
import org.neo4j.gds.similarity.knn.KnnNodePropertySpec;
import org.neo4j.gds.similarity.knn.KnnSampler;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.NeighborList;
import org.neo4j.gds.similarity.knn.NeighbourConsumers;
import org.neo4j.gds.similarity.knn.RandomWalkKnnSampler;
import org.neo4j.gds.similarity.knn.SimilarityFunction;
import org.neo4j.gds.similarity.knn.SplitOldAndNewNeighbors;
import org.neo4j.gds.similarity.knn.UniformKnnSampler;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.utils.StringFormatting;

/**
 * K-nearest-neighbours similarity computation.
 *
 * <p>Every node first receives a list of up to {@code topK} randomly sampled candidate neighbours
 * (see {@link #initializeRandomNeighbors()}). Each subsequent iteration:
 * <ol>
 *   <li>splits every node's current neighbours into sampled "old" and "new" lists,</li>
 *   <li>builds the reverse of those lists,</li>
 *   <li>compares neighbours of the same node with each other, plus a few random nodes.</li>
 * </ol>
 * Every compared pair is offered to the nodes' {@link NeighborList}s.
 *
 * <p>Iteration stops early once an iteration produces at most
 * {@code floor(deltaThreshold * ceil(sampleRate * topK * nodeCount))} updates. Otherwise it stops
 * after {@code maxIterations} iterations.
 */
public class Knn extends Algorithm<Knn.Result> {
    private final Graph graph;
    private final KnnBaseConfig config;
    private final NeighborFilterFactory neighborFilterFactory;
    private final ExecutorService executorService;
    private final SplittableRandom splittableRandom;
    private final SimilarityFunction similarityFunction;
    private final NeighbourConsumers neighborConsumers;
    // Running total of candidate node pairs considered. It accumulates the count from random
    // initialisation and from every join iteration; it is only updated on the calling thread after
    // the parallel tasks have finished.
    private long nodePairsConsidered;

    /**
     * Creates a KNN instance with the default similarity function over {@code config.nodeProperties()},
     * a {@link KnnNeighborFilterFactory} and no neighbour-consumer instrumentation.
     */
    public static Knn createWithDefaults(Graph graph, KnnBaseConfig config, KnnContext context) {
        return Knn.createWithDefaultsAndInstrumentation(
            graph,
            config,
            context,
            NeighbourConsumers.no_op,
            Knn.defaultSimilarityFunction(graph, config.nodeProperties())
        );
    }

    /** Builds a {@link SimilarityFunction} over the given node properties of {@code graph}. */
    public static SimilarityFunction defaultSimilarityFunction(Graph graph, List<KnnNodePropertySpec> nodeProperties) {
        return Knn.defaultSimilarityFunction(SimilarityComputer.ofProperties(graph, nodeProperties));
    }

    /** Wraps a {@link SimilarityComputer} into a {@link SimilarityFunction}. */
    private static SimilarityFunction defaultSimilarityFunction(SimilarityComputer similarityComputer) {
        return new SimilarityFunction(similarityComputer);
    }

    /**
     * Creates a KNN instance with a {@link KnnNeighborFilterFactory}, the context's executor and
     * progress tracker, and caller-supplied neighbour consumers and similarity function.
     */
    @NotNull
    public static Knn createWithDefaultsAndInstrumentation(
        Graph graph,
        KnnBaseConfig config,
        KnnContext context,
        NeighbourConsumers neighborConsumers,
        SimilarityFunction similarityFunction
    ) {
        return new Knn(
            context.progressTracker(),
            graph,
            config,
            similarityFunction,
            new KnnNeighborFilterFactory(graph.nodeCount()),
            context.executor(),
            Knn.getSplittableRandom(config.randomSeed()),
            neighborConsumers
        );
    }

    /**
     * Creates a KNN instance with a caller-supplied similarity computer and neighbour filter factory,
     * and without neighbour-consumer instrumentation.
     */
    public static Knn create(
        Graph graph,
        KnnBaseConfig config,
        SimilarityComputer similarityComputer,
        NeighborFilterFactory neighborFilterFactory,
        KnnContext context
    ) {
        SplittableRandom splittableRandom = Knn.getSplittableRandom(config.randomSeed());
        SimilarityFunction similarityFunction = Knn.defaultSimilarityFunction(similarityComputer);
        return new Knn(
            context.progressTracker(),
            graph,
            config,
            similarityFunction,
            neighborFilterFactory,
            context.executor(),
            splittableRandom,
            NeighbourConsumers.no_op
        );
    }

    /** Returns a generator seeded with {@code randomSeed} when present, otherwise an unseeded one. */
    @NotNull
    private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
        return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    Knn(
        ProgressTracker progressTracker,
        Graph graph,
        KnnBaseConfig config,
        SimilarityFunction similarityFunction,
        NeighborFilterFactory neighborFilterFactory,
        ExecutorService executorService,
        SplittableRandom splittableRandom,
        NeighbourConsumers neighborConsumers
    ) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
        this.similarityFunction = similarityFunction;
        this.neighborFilterFactory = neighborFilterFactory;
        this.executorService = executorService;
        this.splittableRandom = splittableRandom;
        this.neighborConsumers = neighborConsumers;
    }

    /** Number of nodes in the input graph. */
    public long nodeCount() {
        return this.graph.nodeCount();
    }

    /** Executor used for all parallel phases of the algorithm. */
    public ExecutorService executorService() {
        return this.executorService;
    }

    /**
     * Runs random initialisation followed by up to {@code maxIterations} join iterations.
     * Optionally filters out neighbours at or below {@code similarityCutoff} at the end.
     *
     * @return the computed neighbour lists. An {@link EmptyResult} is returned when the graph has
     *     fewer than two nodes or {@code topK == 0}.
     */
    @Override
    public Result compute() {
        this.progressTracker.beginSubTask();
        try (ProgressTimer ignored1 = ProgressTimer.start(this::logOverallTime)) {
            HugeObjectArray<NeighborList> neighbors;
            try (ProgressTimer ignored2 = ProgressTimer.start(this::logInitTime)) {
                this.progressTracker.beginSubTask();
                neighbors = this.initializeRandomNeighbors();
                this.progressTracker.endSubTask();
            }
            if (neighbors == null) {
                // Nothing to compare: fewer than two nodes or topK == 0.
                // NOTE(review): the outer subtask begun above is not ended on this path.
                return new EmptyResult();
            }

            int maxIterations = this.config.maxIterations();
            long maxUpdates = (long) Math.ceil(
                this.config.sampleRate() * (double) this.config.topK() * (double) this.graph.nodeCount()
            );
            long updateThreshold = (long) Math.floor(this.config.deltaThreshold() * (double) maxUpdates);

            // After the loop, `iteration` holds the number of iterations actually executed.
            int iteration = 0;
            boolean didConverge = false;
            this.progressTracker.beginSubTask();
            while (iteration < maxIterations) {
                // Increment exactly once per pass. The copy is effectively final so the lambda can capture it.
                int currentIteration = iteration++;
                long updateCount;
                try (ProgressTimer ignored3 = ProgressTimer.start(took -> this.logIterationTime(currentIteration + 1, took))) {
                    updateCount = this.iteration(neighbors);
                }
                if (updateCount <= updateThreshold) {
                    didConverge = true;
                    break;
                }
            }

            if (this.config.similarityCutoff() > 0.0) {
                double similarityCutoff = this.config.similarityCutoff();
                List<Runnable> neighborFilterTasks = PartitionUtils.rangePartition(
                    this.config.concurrency(),
                    neighbors.size(),
                    partition -> (Runnable) () -> partition.consume(
                        nodeId -> neighbors.get(nodeId).filterHighSimilarityResults(similarityCutoff)
                    ),
                    Optional.of(this.config.minBatchSize())
                );
                RunWithConcurrency.builder()
                    .concurrency(this.config.concurrency())
                    .tasks(neighborFilterTasks)
                    .terminationFlag(this.terminationFlag)
                    .executor(this.executorService)
                    .run();
            }
            this.progressTracker.endSubTask();
            this.progressTracker.endSubTask();
            return ImmutableResult.of(neighbors, iteration, didConverge, this.nodePairsConsidered);
        }
    }

    /** Intentionally a no-op: all per-run buffers are local to {@link #compute()}. */
    public void release() {
    }

    /**
     * Allocates one {@link NeighborList} slot per node and fills it in parallel with up to
     * {@code min(topK, nodeCount - 1)} sampled neighbours.
     *
     * @return the per-node neighbour lists, or {@code null} when the graph has fewer than two nodes
     *     or {@code topK == 0}
     */
    @Nullable
    private HugeObjectArray<NeighborList> initializeRandomNeighbors() {
        long nodeCount = this.graph.nodeCount();
        int k = this.config.topK();
        if (nodeCount < 2L || k == 0) {
            return null;
        }
        // A node can have at most nodeCount - 1 distinct neighbours (no self-loops).
        int boundedK = (int) Math.min(nodeCount - 1L, (long) k);
        assert boundedK <= k && (long) boundedK <= nodeCount - 1L;

        HugeObjectArray<NeighborList> neighbors = HugeObjectArray.newArray(NeighborList.class, nodeCount);
        List<GenerateRandomNeighbors> randomNeighborGenerators = PartitionUtils.rangePartition(
            this.config.concurrency(),
            nodeCount,
            partition -> {
                // Each partition gets its own split generator so tasks never share random state.
                SplittableRandom localRandom = this.splittableRandom.split();
                return new GenerateRandomNeighbors(
                    this.initializeSampler(localRandom),
                    localRandom,
                    this.similarityFunction,
                    this.neighborFilterFactory.create(),
                    neighbors,
                    boundedK,
                    partition,
                    this.progressTracker,
                    this.neighborConsumers
                );
            },
            Optional.of(this.config.minBatchSize())
        );
        RunWithConcurrency.builder()
            .concurrency(this.config.concurrency())
            .tasks(randomNeighborGenerators)
            .terminationFlag(this.terminationFlag)
            .executor(this.executorService)
            .run();
        this.nodePairsConsidered += randomNeighborGenerators.stream()
            .mapToLong(GenerateRandomNeighbors::neighborsFound)
            .sum();
        return neighbors;
    }

    /**
     * Creates the initial sampler selected by {@code config.initialSampler()}.
     *
     * @throws IllegalStateException for an unhandled sampler type
     */
    private KnnSampler initializeSampler(SplittableRandom random) {
        switch (this.config.initialSampler()) {
            case UNIFORM:
                return new UniformKnnSampler(random, this.graph.nodeCount());
            case RANDOMWALK:
                return new RandomWalkKnnSampler(
                    this.graph.concurrentCopy(),
                    random,
                    this.config.randomSeed(),
                    this.config.boundedK(this.graph.nodeCount())
                );
            default:
                break;
        }
        throw new IllegalStateException("Invalid KnnSampler");
    }

    /**
     * Runs one refinement iteration over all nodes.
     *
     * @return the number of neighbour-list updates produced by the local joins. Random joins are not
     *     counted. Returns 0 when the graph has fewer than two nodes or {@code topK == 0}.
     */
    private long iteration(HugeObjectArray<NeighborList> neighbors) {
        long nodeCount = this.graph.nodeCount();
        if (nodeCount < 2L || this.config.topK() == 0) {
            return 0L;
        }
        int concurrency = this.config.concurrency();
        int sampledK = this.config.sampledK(nodeCount);

        // Phase 1: split each node's current neighbours into sampled old/new lists.
        HugeObjectArray<LongArrayList> allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        HugeObjectArray<LongArrayList> allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        this.progressTracker.beginSubTask();
        ParallelUtil.readParallel(
            concurrency,
            nodeCount,
            this.executorService,
            (BiLongConsumer) new SplitOldAndNewNeighbors(
                this.splittableRandom,
                neighbors,
                allOldNeighbors,
                allNewNeighbors,
                sampledK,
                this.progressTracker
            )
        );
        this.progressTracker.endSubTask();

        // Phase 2: build reverse old/new lists (single-threaded).
        HugeObjectArray<LongArrayList> reverseOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        HugeObjectArray<LongArrayList> reverseNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        this.progressTracker.beginSubTask();
        Knn.reverseOldAndNewNeighbors(
            allOldNeighbors,
            allNewNeighbors,
            reverseOldNeighbors,
            reverseNewNeighbors,
            this.config,
            this.progressTracker
        );
        this.progressTracker.endSubTask();

        // Phase 3: local joins, one task per partition with its own split generator.
        List<JoinNeighbors> neighborsJoiners = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> new JoinNeighbors(
                this.splittableRandom.split(),
                this.similarityFunction,
                this.neighborFilterFactory.create(),
                neighbors,
                allOldNeighbors,
                allNewNeighbors,
                reverseOldNeighbors,
                reverseNewNeighbors,
                sampledK,
                this.config.perturbationRate(),
                this.config.randomJoins(),
                partition,
                this.progressTracker
            ),
            Optional.of(this.config.minBatchSize())
        );
        this.progressTracker.beginSubTask();
        RunWithConcurrency.builder()
            .concurrency(concurrency)
            .tasks(neighborsJoiners)
            .terminationFlag(this.terminationFlag)
            .executor(this.executorService)
            .run();
        this.progressTracker.endSubTask();

        long updateCount = 0L;
        for (JoinNeighbors joiner : neighborsJoiners) {
            this.nodePairsConsidered += joiner.nodePairsConsidered();
            updateCount += joiner.updateCount;
        }
        return updateCount;
    }

    /**
     * Fills {@code reverseOldNeighbors} / {@code reverseNewNeighbors} from the forward lists.
     * Progress is logged once per {@code logBatchSize} processed nodes.
     */
    private static void reverseOldAndNewNeighbors(
        HugeObjectArray<LongArrayList> allOldNeighbors,
        HugeObjectArray<LongArrayList> allNewNeighbors,
        HugeObjectArray<LongArrayList> reverseOldNeighbors,
        HugeObjectArray<LongArrayList> reverseNewNeighbors,
        KnnBaseConfig config,
        ProgressTracker progressTracker
    ) {
        long nodeCount = allNewNeighbors.size();
        long logBatchSize = ParallelUtil.adjustedBatchSize(nodeCount, config.concurrency(), (long) config.minBatchSize());
        for (long nodeId = 0L; nodeId < nodeCount; ++nodeId) {
            Knn.reverseNeighbors(nodeId, allOldNeighbors, reverseOldNeighbors);
            Knn.reverseNeighbors(nodeId, allNewNeighbors, reverseNewNeighbors);
            if ((nodeId + 1L) % logBatchSize == 0L) {
                progressTracker.logProgress(logBatchSize);
            }
        }
    }

    /**
     * For every neighbour {@code n} of {@code nodeId} in {@code allNeighbors}, appends {@code nodeId}
     * to {@code reverseNeighbors[n]}, creating that list on first use. A {@code null} entry in
     * {@code allNeighbors} is treated as "no neighbours".
     */
    static void reverseNeighbors(
        long nodeId,
        HugeObjectArray<LongArrayList> allNeighbors,
        HugeObjectArray<LongArrayList> reverseNeighbors
    ) {
        LongArrayList neighbors = allNeighbors.get(nodeId);
        if (neighbors == null) {
            return;
        }
        for (LongCursor neighbor : neighbors) {
            assert neighbor.value != nodeId;
            LongArrayList reverse = reverseNeighbors.get(neighbor.value);
            if (reverse == null) {
                reverse = new LongArrayList();
                reverseNeighbors.set(neighbor.value, reverse);
            }
            reverse.add(nodeId);
        }
    }

    private void logInitTime(long ms) {
        this.progressTracker.logInfo(StringFormatting.formatWithLocale("Graph init took %d ms", ms));
    }

    private void logIterationTime(int iteration, long ms) {
        this.progressTracker.logInfo(StringFormatting.formatWithLocale("Graph iteration %d took %d ms", iteration, ms));
    }

    private void logOverallTime(long ms) {
        this.progressTracker.logInfo(StringFormatting.formatWithLocale("Graph execution took %d ms", ms));
    }

    /** Result returned when there is nothing to compute: no neighbours, no iterations, not converged. */
    private static final class EmptyResult extends Result {
        private EmptyResult() {
        }

        @Override
        HugeObjectArray<NeighborList> neighborList() {
            return HugeObjectArray.of(new NeighborList[0]);
        }

        @Override
        public int ranIterations() {
            return 0;
        }

        @Override
        public boolean didConverge() {
            return false;
        }

        @Override
        public long nodePairsConsidered() {
            return 0L;
        }

        @Override
        public LongStream neighborsOf(long nodeId) {
            return LongStream.empty();
        }

        @Override
        public long size() {
            return 0L;
        }
    }

    /** Outcome of a KNN run: per-node neighbour lists plus run statistics. */
    @ValueClass
    public abstract static class Result {
        /** Per-node neighbour lists, indexed by node id. */
        abstract HugeObjectArray<NeighborList> neighborList();

        /** Number of join iterations that were executed. */
        public abstract int ranIterations();

        /** Whether the update count fell to or below the delta threshold before {@code maxIterations}. */
        public abstract boolean didConverge();

        /** Total number of candidate node pairs considered across initialisation and iterations. */
        public abstract long nodePairsConsidered();

        /**
         * Neighbour ids of {@code nodeId}. Each stored element is passed through
         * {@link NeighborList#clearCheckedFlag} before being returned.
         */
        public LongStream neighborsOf(long nodeId) {
            return this.neighborList().get(nodeId).elements().map(NeighborList::clearCheckedFlag);
        }

        /** Streams the similarity results of every node, in node-id order. */
        public Stream<SimilarityResult> streamSimilarityResult() {
            HugeObjectArray<NeighborList> neighborList = this.neighborList();
            // HugeCursor::next acts as the "hasNext" predicate and advances to the next page as a
            // side effect; the same cursor instance is re-emitted for every page.
            return Stream.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
                .flatMap(cursor -> IntStream.range(cursor.offset, cursor.limit)
                    .mapToObj(index -> ((NeighborList[]) cursor.array)[index].similarityStream((long) index + cursor.base))
                    .flatMap(Function.identity()));
        }

        /** Sum of the neighbour-list sizes over all nodes. */
        public long totalSimilarityPairs() {
            HugeObjectArray<NeighborList> neighborList = this.neighborList();
            return Stream.iterate(neighborList.initCursor(neighborList.newCursor()), HugeCursor::next, UnaryOperator.identity())
                .flatMapToLong(cursor -> IntStream.range(cursor.offset, cursor.limit)
                    .mapToLong(index -> ((NeighborList[]) cursor.array)[index].size()))
                .sum();
        }

        /** Number of per-node entries in the result. */
        public long size() {
            return this.neighborList().size();
        }
    }

    /**
     * Local-join task for one node partition.
     *
     * <p>For each node it samples reverse neighbours into the node's old/new lists. It compares new
     * neighbours with the node and with each other, and new neighbours with old ones. Finally it
     * performs {@code randomJoins} joins against random nodes. Neighbour lists are shared across
     * tasks, so every mutation of a {@link NeighborList} happens while holding that list's monitor.
     */
    static final class JoinNeighbors implements Runnable {
        private final SplittableRandom random;
        private final SimilarityFunction similarityFunction;
        private final NeighborFilter neighborFilter;
        private final HugeObjectArray<NeighborList> allNeighbors;
        private final HugeObjectArray<LongArrayList> allOldNeighbors;
        private final HugeObjectArray<LongArrayList> allNewNeighbors;
        private final HugeObjectArray<LongArrayList> allReverseOldNeighbors;
        private final HugeObjectArray<LongArrayList> allReverseNewNeighbors;
        private final int sampledK;
        private final int randomJoins;
        private final ProgressTracker progressTracker;
        private final long nodeCount;
        // Updates from local joins only; random joins are deliberately not counted.
        private long updateCount;
        private final Partition partition;
        private long nodePairsConsidered;
        private final double perturbationRate;

        JoinNeighbors(
            SplittableRandom random,
            SimilarityFunction similarityFunction,
            NeighborFilter neighborFilter,
            HugeObjectArray<NeighborList> allNeighbors,
            HugeObjectArray<LongArrayList> allOldNeighbors,
            HugeObjectArray<LongArrayList> allNewNeighbors,
            HugeObjectArray<LongArrayList> allReverseOldNeighbors,
            HugeObjectArray<LongArrayList> allReverseNewNeighbors,
            int sampledK,
            double perturbationRate,
            int randomJoins,
            Partition partition,
            ProgressTracker progressTracker
        ) {
            this.random = random;
            this.similarityFunction = similarityFunction;
            this.neighborFilter = neighborFilter;
            this.allNeighbors = allNeighbors;
            this.nodeCount = allNewNeighbors.size();
            this.allOldNeighbors = allOldNeighbors;
            this.allNewNeighbors = allNewNeighbors;
            this.allReverseOldNeighbors = allReverseOldNeighbors;
            this.allReverseNewNeighbors = allReverseNewNeighbors;
            this.sampledK = sampledK;
            this.randomJoins = randomJoins;
            this.partition = partition;
            this.progressTracker = progressTracker;
            this.perturbationRate = perturbationRate;
            this.updateCount = 0L;
            this.nodePairsConsidered = 0L;
        }

        @Override
        public void run() {
            long startNode = this.partition.startNode();
            long endNode = startNode + this.partition.nodeCount();
            for (long nodeId = startNode; nodeId < endNode; ++nodeId) {
                LongArrayList oldNeighbors = this.allOldNeighbors.get(nodeId);
                if (oldNeighbors != null) {
                    this.combineNeighbors(this.allReverseOldNeighbors.get(nodeId), oldNeighbors);
                }
                LongArrayList newNeighbors = this.allNewNeighbors.get(nodeId);
                if (newNeighbors != null) {
                    this.combineNeighbors(this.allReverseNewNeighbors.get(nodeId), newNeighbors);
                    this.updateCount += this.joinNewNeighbors(nodeId, oldNeighbors, newNeighbors);
                }
                this.randomJoins(this.nodeCount, nodeId);
            }
            this.progressTracker.logProgress(this.partition.nodeCount());
        }

        /**
         * Joins each new neighbour with {@code nodeId}, with every later new neighbour and with every
         * old neighbour (if {@code oldNeighbors} is non-null). Identical ids are skipped.
         *
         * @return the total number of neighbour-list updates
         */
        private long joinNewNeighbors(long nodeId, LongArrayList oldNeighbors, LongArrayList newNeighbors) {
            long updates = 0L;
            long[] newNeighborElements = newNeighbors.buffer;
            int newNeighborsCount = newNeighbors.elementsCount;
            boolean similarityIsSymmetric = this.similarityFunction.isSymmetric();
            for (int i = 0; i < newNeighborsCount; ++i) {
                long elem1 = newNeighborElements[i];
                assert elem1 != nodeId;
                // Offers nodeId to elem1's list (one direction only).
                updates += this.join(elem1, nodeId);
                for (int j = i + 1; j < newNeighborsCount; ++j) {
                    long elem2 = newNeighborElements[j];
                    if (elem1 != elem2) {
                        updates += this.joinPair(elem1, elem2, similarityIsSymmetric);
                    }
                }
                if (oldNeighbors == null) {
                    continue;
                }
                for (LongCursor oldElemCursor : oldNeighbors) {
                    long elem2 = oldElemCursor.value;
                    if (elem1 != elem2) {
                        updates += this.joinPair(elem1, elem2, similarityIsSymmetric);
                    }
                }
            }
            return updates;
        }

        /** Joins a pair in both directions, computing the similarity once when it is symmetric. */
        private long joinPair(long first, long second, boolean similarityIsSymmetric) {
            if (similarityIsSymmetric) {
                return this.joinSymmetric(first, second);
            }
            return this.join(first, second) + this.join(second, first);
        }

        /**
         * Appends each reverse neighbour to {@code neighbors} with probability
         * {@code min(sampledK, |reversedNeighbors|) / |reversedNeighbors|}. Does nothing if
         * {@code reversedNeighbors} is null.
         */
        private void combineNeighbors(@Nullable LongArrayList reversedNeighbors, LongArrayList neighbors) {
            if (reversedNeighbors == null) {
                return;
            }
            int numberOfReverseNeighbors = reversedNeighbors.size();
            for (LongCursor elem : reversedNeighbors) {
                if (this.random.nextInt(numberOfReverseNeighbors) < this.sampledK) {
                    neighbors.add(elem.value);
                }
            }
        }

        /** Joins {@code nodeId} with {@code randomJoins} uniformly chosen other nodes; updates are not counted. */
        private void randomJoins(long nodeCount, long nodeId) {
            for (int i = 0; i < this.randomJoins; ++i) {
                long randomNodeId = this.random.nextLong(nodeCount - 1L);
                // Shift ids >= nodeId by one so a node is never joined with itself.
                if (randomNodeId >= nodeId) {
                    ++randomNodeId;
                }
                this.join(nodeId, randomNodeId);
            }
        }

        /**
         * Computes the similarity of {@code node1}/{@code node2} once and offers each node to the
         * other's list. Each list is mutated only while holding its own monitor.
         *
         * @return the number of updates across both lists (0 if the pair is excluded by the filter)
         */
        private long joinSymmetric(long node1, long node2) {
            assert node1 != node2;
            if (this.neighborFilter.excludeNodePair(node1, node2)) {
                return 0L;
            }
            ++this.nodePairsConsidered;
            double similarity = this.similarityFunction.computeSimilarity(node1, node2);
            long updates = 0L;
            NeighborList neighbors1 = this.allNeighbors.get(node1);
            synchronized (neighbors1) {
                updates += neighbors1.add(node2, similarity, this.random, this.perturbationRate);
            }
            NeighborList neighbors2 = this.allNeighbors.get(node2);
            synchronized (neighbors2) {
                updates += neighbors2.add(node1, similarity, this.random, this.perturbationRate);
            }
            return updates;
        }

        /**
         * Offers {@code node2} to {@code node1}'s neighbour list under that list's monitor.
         *
         * @return the number of updates reported by the list (0 if the pair is excluded by the filter)
         */
        private long join(long node1, long node2) {
            assert node1 != node2;
            if (this.neighborFilter.excludeNodePair(node1, node2)) {
                return 0L;
            }
            double similarity = this.similarityFunction.computeSimilarity(node1, node2);
            ++this.nodePairsConsidered;
            NeighborList neighbors = this.allNeighbors.get(node1);
            synchronized (neighbors) {
                return neighbors.add(node2, similarity, this.random, this.perturbationRate);
            }
        }

        /** Number of candidate pairs this task passed through the neighbour filter. */
        long nodePairsConsidered() {
            return this.nodePairsConsidered;
        }
    }
}

