/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongArrayList;
import com.carrotsearch.hppc.cursors.LongCursor;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.BiLongConsumer;
import org.neo4j.gds.core.utils.ProgressTimer;
import org.neo4j.gds.core.utils.paged.HugeCursor;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.knn.GenerateRandomNeighbors;
import org.neo4j.gds.similarity.knn.ImmutableResult;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnNeighborFilterFactory;
import org.neo4j.gds.similarity.knn.KnnSampler;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.NeighborList;
import org.neo4j.gds.similarity.knn.RandomWalkKnnSampler;
import org.neo4j.gds.similarity.knn.SplitOldAndNewNeighbors;
import org.neo4j.gds.similarity.knn.UniformKnnSampler;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
import org.neo4j.gds.utils.StringFormatting;

public class Knn
extends Algorithm<Result> {
    // Graph whose node count drives partitioning and neighbor-array sizes.
    private final Graph graph;
    // Algorithm configuration (topK, sample rate, concurrency, thresholds, ...).
    private final KnnBaseConfig config;
    // Produces one NeighborFilter per parallel task.
    private final NeighborFilterFactory neighborFilterFactory;
    // Supplies the executor used for all parallel phases.
    private final KnnContext context;
    // Root random source; it is split() for each parallel task that needs randomness.
    private final SplittableRandom splittableRandom;
    // Computes the similarity between two node ids.
    private final SimilarityComputer similarityComputer;
    // Running total accumulated after the initialization phase and after every iteration.
    private long nodePairsConsidered;

    /**
     * Creates a {@code Knn} instance wired with the default collaborators: a
     * property-based {@link SimilarityComputer}, a {@link KnnNeighborFilterFactory}
     * sized by the graph's node count, and a random source derived from the
     * configured seed.
     *
     * @param graph   the graph to run on
     * @param config  the algorithm configuration
     * @param context the context providing the progress tracker (and later the executor)
     * @return a new {@code Knn} instance
     */
    public static Knn createWithDefaults(Graph graph, KnnBaseConfig config, KnnContext context) {
        ProgressTracker tracker = context.progressTracker();
        SimilarityComputer defaultComputer = SimilarityComputer.ofProperties(graph, config.nodeProperties());
        NeighborFilterFactory defaultFilterFactory = new KnnNeighborFilterFactory(graph.nodeCount());
        SplittableRandom random = Knn.getSplittableRandom(config.randomSeed());
        return new Knn(tracker, graph, config, defaultComputer, defaultFilterFactory, context, random);
    }

    /**
     * Creates a {@code Knn} instance with caller-supplied similarity computer and
     * neighbor filter factory; only the random source is derived here, from the
     * configured seed.
     *
     * @param graph                 the graph to run on
     * @param config                the algorithm configuration
     * @param similarityComputer    computes the similarity of two nodes
     * @param neighborFilterFactory creates the per-task neighbor filters
     * @param context               the context providing the progress tracker
     * @return a new {@code Knn} instance
     */
    public static Knn create(Graph graph, KnnBaseConfig config, SimilarityComputer similarityComputer, NeighborFilterFactory neighborFilterFactory, KnnContext context) {
        SplittableRandom random = Knn.getSplittableRandom(config.randomSeed());
        ProgressTracker tracker = context.progressTracker();
        return new Knn(
            tracker,
            graph,
            config,
            similarityComputer,
            neighborFilterFactory,
            context,
            random
        );
    }

    /**
     * Returns a random source seeded with {@code randomSeed} when one is present,
     * otherwise an unseeded {@link SplittableRandom}.
     *
     * @param randomSeed the optional seed
     * @return a new random source, never {@code null}
     */
    @NotNull
    private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
        if (randomSeed.isPresent()) {
            return new SplittableRandom(randomSeed.get());
        }
        return new SplittableRandom();
    }

    /**
     * Package-private constructor; stores every collaborator as given and hands
     * the progress tracker to the superclass.
     *
     * @param progressTracker       tracker passed to the superclass constructor
     * @param graph                 the graph to run on
     * @param config                the algorithm configuration
     * @param similarityComputer    computes the similarity of two nodes
     * @param neighborFilterFactory creates the per-task neighbor filters
     * @param context               the execution context
     * @param splittableRandom      root random source, split per task
     */
    Knn(ProgressTracker progressTracker, Graph graph, KnnBaseConfig config, SimilarityComputer similarityComputer, NeighborFilterFactory neighborFilterFactory, KnnContext context, SplittableRandom splittableRandom) {
        super(progressTracker);
        // Inputs.
        this.graph = graph;
        this.config = config;
        this.context = context;
        // Collaborators.
        this.similarityComputer = similarityComputer;
        this.neighborFilterFactory = neighborFilterFactory;
        this.splittableRandom = splittableRandom;
    }

    /**
     * @return the node count reported by the underlying graph
     */
    public long nodeCount() {
        long count = this.graph.nodeCount();
        return count;
    }

    /**
     * @return the context this instance was constructed with
     */
    public KnnContext context() {
        KnnContext current = this.context;
        return current;
    }

    /**
     * Runs the algorithm: initializes random neighbor lists, then repeatedly
     * calls {@link #iteration(HugeObjectArray)} until either the configured
     * maximum number of iterations is reached or an iteration reports no more
     * than {@code deltaThreshold * maxUpdates} updates. When the configured
     * similarity cutoff is positive, the neighbor lists are filtered in
     * parallel afterwards.
     *
     * @return an {@link EmptyResult} when initialization produced no neighbor
     *         lists, otherwise a result holding the neighbor lists, the number
     *         of iterations actually executed, the convergence flag and the
     *         number of node pairs considered
     */
    public Result compute() {
        this.progressTracker.beginSubTask();
        try (ProgressTimer ignored1 = ProgressTimer.start(this::logOverallTime)) {
            HugeObjectArray<NeighborList> neighbors;
            try (ProgressTimer ignored2 = ProgressTimer.start(this::logInitTime)) {
                this.progressTracker.beginSubTask();
                neighbors = this.initializeRandomNeighbors();
                this.progressTracker.endSubTask();
            }
            if (neighbors == null) {
                // NOTE(review): the outer sub task begun above is not ended on this
                // path; kept as in the original - confirm this is intended.
                return new EmptyResult();
            }

            int maxIterations = this.config.maxIterations();
            long maxUpdates = (long) Math.ceil(this.config.sampleRate() * (double) this.config.topK() * (double) this.graph.nodeCount());
            long updateThreshold = (long) Math.floor(this.config.deltaThreshold() * (double) maxUpdates);

            boolean didConverge = false;
            // Counts the iterations actually executed; it is advanced exactly once
            // per pass so that every configured iteration runs and the reported
            // number of iterations is accurate.
            int iteration = 0;
            this.progressTracker.beginSubTask();
            while (iteration < maxIterations) {
                // 1-based number of the pass, used for the timing log message.
                int currentIteration = ++iteration;
                long updateCount;
                try (ProgressTimer ignored3 = ProgressTimer.start(took -> this.logIterationTime(currentIteration, took))) {
                    updateCount = this.iteration(neighbors);
                }
                if (updateCount <= updateThreshold) {
                    didConverge = true;
                    break;
                }
            }

            if (this.config.similarityCutoff() > 0.0) {
                double similarityCutoff = this.config.similarityCutoff();
                List<Runnable> neighborFilterTasks = PartitionUtils.rangePartition(
                    this.config.concurrency(),
                    neighbors.size(),
                    partition -> (Runnable) () -> partition.consume(
                        nodeId -> neighbors.get(nodeId).filterHighSimilarityResults(similarityCutoff)
                    ),
                    Optional.of(this.config.minBatchSize())
                );
                ParallelUtil.runWithConcurrency(this.config.concurrency(), neighborFilterTasks, this.context.executor());
            }

            this.progressTracker.endSubTask();
            this.progressTracker.endSubTask();
            return ImmutableResult.of(neighbors, iteration, didConverge, this.nodePairsConsidered);
        }
    }

    /**
     * No-op: the body is empty, so nothing is released here.
     */
    public void release() {
    }

    /**
     * Allocates one neighbor list slot per node and fills the lists in parallel
     * with {@link GenerateRandomNeighbors} tasks, one per range partition. Each
     * task receives its own split of the root random source, its own sampler
     * and its own neighbor filter. The neighbor counts reported by the tasks
     * are added to {@code nodePairsConsidered}.
     *
     * @return the populated neighbor lists, or {@code null} when the graph has
     *         fewer than two nodes or {@code topK} is zero
     */
    @Nullable
    private HugeObjectArray<NeighborList> initializeRandomNeighbors() {
        int k = this.config.topK();
        // The narrowing cast is safe: the minimum is never larger than k, which is an int.
        int boundedK = (int) Math.min(this.graph.nodeCount() - 1L, (long) k);
        assert boundedK <= k && (long) boundedK <= this.graph.nodeCount() - 1L;
        if (this.graph.nodeCount() < 2L || k == 0) {
            return null;
        }

        HugeObjectArray<NeighborList> neighbors = HugeObjectArray.newArray(NeighborList.class, this.graph.nodeCount());
        List<GenerateRandomNeighbors> randomNeighborGenerators = PartitionUtils.rangePartition(
            this.config.concurrency(),
            this.graph.nodeCount(),
            partition -> {
                // One independent random stream per task; shared with that task's sampler.
                SplittableRandom localRandom = this.splittableRandom.split();
                return new GenerateRandomNeighbors(
                    this.initializeSampler(localRandom),
                    localRandom,
                    this.similarityComputer,
                    this.neighborFilterFactory.create(),
                    neighbors,
                    k,
                    boundedK,
                    partition,
                    this.progressTracker
                );
            },
            Optional.of(this.config.minBatchSize())
        );
        ParallelUtil.runWithConcurrency(this.config.concurrency(), randomNeighborGenerators, this.context.executor());

        this.nodePairsConsidered += randomNeighborGenerators.stream()
            .mapToLong(GenerateRandomNeighbors::neighborsFound)
            .sum();
        return neighbors;
    }

    /**
     * Creates the initial-neighbor sampler selected by the configuration.
     *
     * @param random the random source handed to the sampler
     * @return a {@link UniformKnnSampler} or a {@link RandomWalkKnnSampler}
     * @throws IllegalStateException if the configured sampler type is neither of the two
     */
    private KnnSampler initializeSampler(SplittableRandom random) {
        switch (this.config.initialSampler()) {
            case UNIFORM:
                return new UniformKnnSampler(random, this.graph.nodeCount());
            case RANDOMWALK:
                return new RandomWalkKnnSampler(
                    this.graph.concurrentCopy(),
                    random,
                    this.config.randomSeed(),
                    this.config.boundedK(this.graph.nodeCount())
                );
            default:
                throw new IllegalStateException("Invalid KnnSampler");
        }
    }

    /**
     * Executes a single refinement pass over all neighbor lists in three
     * phases, each wrapped in its own progress sub task:
     * <ol>
     *   <li>split every node's current neighbors into "old" and "new" lists
     *       ({@link SplitOldAndNewNeighbors}, run in parallel),</li>
     *   <li>build the reverse old/new neighbor lists (sequential),</li>
     *   <li>run one {@link JoinNeighbors} task per range partition.</li>
     * </ol>
     * The node pairs considered by the join tasks are added to
     * {@code nodePairsConsidered}.
     *
     * @param neighbors the neighbor lists to refine in place
     * @return the sum of the update counts of all join tasks, or {@code 0} when
     *         the graph has fewer than two nodes or {@code topK} is zero
     */
    private long iteration(HugeObjectArray<NeighborList> neighbors) {
        long nodeCount = this.graph.nodeCount();
        if (nodeCount < 2L || this.config.topK() == 0) {
            return 0L;
        }

        int concurrency = this.config.concurrency();
        ExecutorService executor = this.context.executor();
        int sampledK = this.config.sampledK(nodeCount);

        // Phase 1: split current neighbors into old and new ones.
        HugeObjectArray<LongArrayList> allOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        HugeObjectArray<LongArrayList> allNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        this.progressTracker.beginSubTask();
        ParallelUtil.readParallel(
            concurrency,
            nodeCount,
            executor,
            new SplitOldAndNewNeighbors(this.splittableRandom, neighbors, allOldNeighbors, allNewNeighbors, sampledK, this.progressTracker)
        );
        this.progressTracker.endSubTask();

        // Phase 2: build the reverse old/new neighbor lists.
        HugeObjectArray<LongArrayList> reverseOldNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        HugeObjectArray<LongArrayList> reverseNewNeighbors = HugeObjectArray.newArray(LongArrayList.class, nodeCount);
        this.progressTracker.beginSubTask();
        Knn.reverseOldAndNewNeighbors(nodeCount, allOldNeighbors, allNewNeighbors, reverseOldNeighbors, reverseNewNeighbors, this.config, this.progressTracker);
        this.progressTracker.endSubTask();

        // Phase 3: join neighbors; every task gets its own random split and neighbor filter.
        List<JoinNeighbors> neighborsJoiners = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> new JoinNeighbors(
                this.splittableRandom.split(),
                this.similarityComputer,
                this.neighborFilterFactory.create(),
                neighbors,
                allOldNeighbors,
                allNewNeighbors,
                reverseOldNeighbors,
                reverseNewNeighbors,
                nodeCount,
                this.config.topK(),
                sampledK,
                this.config.perturbationRate(),
                this.config.randomJoins(),
                partition,
                this.progressTracker
            ),
            Optional.of(this.config.minBatchSize())
        );
        this.progressTracker.beginSubTask();
        ParallelUtil.runWithConcurrency(concurrency, neighborsJoiners, executor);
        this.progressTracker.endSubTask();

        this.nodePairsConsidered += neighborsJoiners.stream().mapToLong(JoinNeighbors::nodePairsConsidered).sum();
        return neighborsJoiners.stream().mapToLong(joiner -> joiner.updateCount).sum();
    }

    /**
     * Sequentially builds the reverse old and reverse new neighbor lists for all
     * nodes, logging progress once for every full batch of processed nodes.
     *
     * @param nodeCount           number of nodes to process, ids {@code 0 .. nodeCount - 1}
     * @param allOldNeighbors     forward old-neighbor lists (read)
     * @param allNewNeighbors     forward new-neighbor lists (read)
     * @param reverseOldNeighbors reverse old-neighbor lists (written)
     * @param reverseNewNeighbors reverse new-neighbor lists (written)
     * @param config              supplies concurrency and minimum batch size for the log batch size
     * @param progressTracker     receives the progress updates
     */
    private static void reverseOldAndNewNeighbors(long nodeCount, HugeObjectArray<LongArrayList> allOldNeighbors, HugeObjectArray<LongArrayList> allNewNeighbors, HugeObjectArray<LongArrayList> reverseOldNeighbors, HugeObjectArray<LongArrayList> reverseNewNeighbors, KnnBaseConfig config, ProgressTracker progressTracker) {
        long progressBatch = ParallelUtil.adjustedBatchSize(nodeCount, config.concurrency(), (long) config.minBatchSize());
        long current = 0L;
        while (current < nodeCount) {
            Knn.reverseNeighbors(current, allOldNeighbors, reverseOldNeighbors);
            Knn.reverseNeighbors(current, allNewNeighbors, reverseNewNeighbors);
            // Report only when a whole batch has been completed.
            boolean batchCompleted = (current + 1L) % progressBatch == 0L;
            if (batchCompleted) {
                progressTracker.logProgress(progressBatch);
            }
            ++current;
        }
    }

    /**
     * Registers {@code nodeId} in the reverse list of each of its neighbors:
     * for every neighbor {@code v} of {@code nodeId}, {@code nodeId} is appended
     * to {@code reverseNeighbors[v]}, creating that list on first use. Does
     * nothing when {@code nodeId} has no neighbor list.
     *
     * @param nodeId           the node whose forward neighbors are reversed
     * @param allNeighbors     forward neighbor lists (read)
     * @param reverseNeighbors reverse neighbor lists (written)
     */
    static void reverseNeighbors(long nodeId, HugeObjectArray<LongArrayList> allNeighbors, HugeObjectArray<LongArrayList> reverseNeighbors) {
        LongArrayList neighbors = allNeighbors.get(nodeId);
        if (neighbors == null) {
            return;
        }
        for (LongCursor neighbor : neighbors) {
            assert neighbor.value != nodeId;
            LongArrayList reversed = reverseNeighbors.get(neighbor.value);
            if (reversed == null) {
                // Reverse lists are created lazily, on the first incoming edge.
                reversed = new LongArrayList();
                reverseNeighbors.set(neighbor.value, reversed);
            }
            reversed.add(nodeId);
        }
    }

    /**
     * Logs the duration of the initialization phase through the progress tracker.
     *
     * @param ms elapsed time in milliseconds
     */
    private void logInitTime(long ms) {
        String message = StringFormatting.formatWithLocale("Graph init took %d ms", ms);
        this.progressTracker.logMessage(message);
    }

    /**
     * Logs the duration of a single iteration through the progress tracker.
     *
     * @param iteration the iteration number to print
     * @param ms        elapsed time in milliseconds
     */
    private void logIterationTime(int iteration, long ms) {
        String message = StringFormatting.formatWithLocale("Graph iteration %d took %d ms", iteration, ms);
        this.progressTracker.logMessage(message);
    }

    /**
     * Logs the duration of the whole computation through the progress tracker.
     *
     * @param ms elapsed time in milliseconds
     */
    private void logOverallTime(long ms) {
        String message = StringFormatting.formatWithLocale("Graph execution took %d ms", ms);
        this.progressTracker.logMessage(message);
    }

    /**
     * Result returned when no neighbor lists could be initialized: it holds no
     * neighbor lists, reports zero iterations and zero considered pairs, and is
     * never converged.
     */
    private static final class EmptyResult
    extends Result {
        private EmptyResult() {
        }

        /**
         * @return a new, empty neighbor-list array on every call
         */
        @Override
        HugeObjectArray<NeighborList> neighborList() {
            // Typed empty array so the element type matches the declared return type.
            return HugeObjectArray.of(new NeighborList[0]);
        }

        @Override
        public int ranIterations() {
            return 0;
        }

        @Override
        public boolean didConverge() {
            return false;
        }

        @Override
        public long nodePairsConsidered() {
            return 0L;
        }

        /**
         * @return always an empty stream, regardless of {@code nodeId}
         */
        @Override
        public LongStream neighborsOf(long nodeId) {
            return LongStream.empty();
        }

        @Override
        public long size() {
            return 0L;
        }
    }

    /**
     * Outcome of a {@link Knn} run: the per-node neighbor lists plus run
     * statistics, with convenience views over the neighbor lists.
     */
    @ValueClass
    public static abstract class Result {
        /** @return the neighbor lists, indexed by node id */
        abstract HugeObjectArray<NeighborList> neighborList();

        /** @return the number of iterations reported for the run */
        public abstract int ranIterations();

        /** @return whether the run stopped because it converged */
        public abstract boolean didConverge();

        /** @return the number of node pairs considered during the run */
        public abstract long nodePairsConsidered();

        /**
         * @param nodeId the node to look up
         * @return the elements of that node's neighbor list, each passed through
         *         {@link NeighborList#clearCheckedFlag}
         */
        public LongStream neighborsOf(long nodeId) {
            NeighborList neighborsOfNode = (NeighborList) this.neighborList().get(nodeId);
            LongStream rawElements = neighborsOfNode.elements();
            return rawElements.map(NeighborList::clearCheckedFlag);
        }

        /**
         * Walks the neighbor-list array cursor by cursor and concatenates the
         * similarity streams of all neighbor lists.
         *
         * @return a stream of all similarity results
         */
        public Stream<SimilarityResult> streamSimilarityResult() {
            HugeObjectArray<NeighborList> lists = this.neighborList();
            return Stream
                .iterate(lists.initCursor(lists.newCursor()), HugeCursor::next, UnaryOperator.identity())
                .flatMap(cursor -> IntStream
                    .range(cursor.offset, cursor.limit)
                    // cursor.base + index is the node id of the list at this position
                    .mapToObj(index -> ((NeighborList[]) cursor.array)[index].similarityStream((long) index + cursor.base))
                    .flatMap(Function.identity()));
        }

        /**
         * @return the sum of the sizes of all neighbor lists
         */
        public long totalSimilarityPairs() {
            HugeObjectArray<NeighborList> lists = this.neighborList();
            return Stream
                .iterate(lists.initCursor(lists.newCursor()), HugeCursor::next, UnaryOperator.identity())
                .flatMapToLong(cursor -> IntStream
                    .range(cursor.offset, cursor.limit)
                    .mapToLong(index -> ((NeighborList[]) cursor.array)[index].size()))
                .sum();
        }

        /**
         * @return the size of the neighbor-list array
         */
        public long size() {
            HugeObjectArray<NeighborList> lists = this.neighborList();
            return lists.size();
        }
    }

    /**
     * Task that runs the join step for the nodes of one partition. For each
     * node it merges sampled reverse neighbors into the old and new lists,
     * offers the resulting candidate pairs to the neighbor lists, and finally
     * performs a configured number of random joins.
     */
    private static final class JoinNeighbors
    implements Runnable {
        private final SplittableRandom random;
        private final SimilarityComputer computer;
        private final NeighborFilter neighborFilter;
        private final HugeObjectArray<NeighborList> neighbors;
        private final HugeObjectArray<LongArrayList> allOldNeighbors;
        private final HugeObjectArray<LongArrayList> allNewNeighbors;
        private final HugeObjectArray<LongArrayList> allReverseOldNeighbors;
        private final HugeObjectArray<LongArrayList> allReverseNewNeighbors;
        private final long n;
        private final int k;
        private final int sampledK;
        private final int randomJoins;
        private final ProgressTracker progressTracker;
        // Sum of the values returned by NeighborList#add for the new-neighbor joins; read by Knn#iteration.
        private long updateCount;
        private final Partition partition;
        // Number of similarity computations performed by this task.
        private long nodePairsConsidered;
        private final double perturbationRate;

        private JoinNeighbors(SplittableRandom random, SimilarityComputer computer, NeighborFilter neighborFilter, HugeObjectArray<NeighborList> neighbors, HugeObjectArray<LongArrayList> allOldNeighbors, HugeObjectArray<LongArrayList> allNewNeighbors, HugeObjectArray<LongArrayList> allReverseOldNeighbors, HugeObjectArray<LongArrayList> allReverseNewNeighbors, long n, int k, int sampledK, double perturbationRate, int randomJoins, Partition partition, ProgressTracker progressTracker) {
            this.random = random;
            this.computer = computer;
            this.neighborFilter = neighborFilter;
            this.neighbors = neighbors;
            this.allOldNeighbors = allOldNeighbors;
            this.allNewNeighbors = allNewNeighbors;
            this.allReverseOldNeighbors = allReverseOldNeighbors;
            this.allReverseNewNeighbors = allReverseNewNeighbors;
            this.n = n;
            this.k = k;
            this.sampledK = sampledK;
            this.randomJoins = randomJoins;
            this.partition = partition;
            this.progressTracker = progressTracker;
            this.perturbationRate = perturbationRate;
            this.updateCount = 0L;
            this.nodePairsConsidered = 0L;
        }

        /**
         * Processes every node of the partition and logs the partition's node
         * count as progress once all of them are done.
         */
        @Override
        public void run() {
            // Copy the fields into locals once; they are passed to the helpers for every node.
            SplittableRandom rng = this.random;
            SimilarityComputer computer = this.computer;
            long n = this.n;
            int k = this.k;
            int sampledK = this.sampledK;
            HugeObjectArray<NeighborList> allNeighbors = this.neighbors;
            HugeObjectArray<LongArrayList> allNewNeighbors = this.allNewNeighbors;
            HugeObjectArray<LongArrayList> allOldNeighbors = this.allOldNeighbors;
            HugeObjectArray<LongArrayList> allReverseNewNeighbors = this.allReverseNewNeighbors;
            HugeObjectArray<LongArrayList> allReverseOldNeighbors = this.allReverseOldNeighbors;

            long startNode = this.partition.startNode();
            long endNode = startNode + this.partition.nodeCount();
            for (long nodeId = startNode; nodeId < endNode; ++nodeId) {
                // Old neighbors extended by a sample of the reverse old neighbors.
                LongArrayList oldNeighbors = allOldNeighbors.get(nodeId);
                if (oldNeighbors != null) {
                    this.joinOldNeighbors(rng, sampledK, allReverseOldNeighbors, nodeId, oldNeighbors);
                }
                // New neighbors extended by a sample of the reverse new neighbors, then joined.
                LongArrayList newNeighbors = allNewNeighbors.get(nodeId);
                if (newNeighbors != null) {
                    this.updateCount += this.joinNewNeighbors(rng, computer, n, k, sampledK, allNeighbors, allReverseNewNeighbors, nodeId, oldNeighbors, newNeighbors);
                }
                this.randomJoins(rng, computer, n, k, allNeighbors, nodeId, this.randomJoins);
            }
            this.progressTracker.logProgress(this.partition.nodeCount());
        }

        /**
         * Appends a random sample of the node's reverse neighbors to
         * {@code oldNeighbors}. Each reverse neighbor is kept when a random
         * draw from {@code [0, reverseCount)} is below {@code sampledK}, so
         * all of them are kept when there are at most {@code sampledK}.
         * Elements already present in the target list may be added again.
         *
         * @param rng                    random source for the sampling
         * @param sampledK               sampling threshold
         * @param allReverseOldNeighbors reverse neighbor lists to sample from
         * @param nodeId                 the node being processed
         * @param oldNeighbors           the list that receives the sampled elements
         */
        private void joinOldNeighbors(SplittableRandom rng, int sampledK, HugeObjectArray<LongArrayList> allReverseOldNeighbors, long nodeId, LongArrayList oldNeighbors) {
            LongArrayList reverseOldNeighbors = allReverseOldNeighbors.get(nodeId);
            if (reverseOldNeighbors == null) {
                return;
            }
            int numberOfReverseOldNeighbors = reverseOldNeighbors.size();
            for (LongCursor elem : reverseOldNeighbors) {
                if (rng.nextInt(numberOfReverseOldNeighbors) < sampledK) {
                    oldNeighbors.add(elem.value);
                }
            }
        }

        /**
         * Extends {@code newNeighbors} with sampled reverse new neighbors and
         * then offers, for every new neighbor {@code u}: {@code nodeId} to
         * {@code u}, and - in both directions - every pair of {@code u} with a
         * later new neighbor and with every old neighbor. Pairs of equal ids
         * are skipped.
         *
         * @param oldNeighbors the node's old neighbors, may be {@code null}
         * @param newNeighbors the node's new neighbors, extended in place
         * @return the sum of the values returned by the joins
         */
        private long joinNewNeighbors(SplittableRandom rng, SimilarityComputer computer, long n, int k, int sampledK, HugeObjectArray<NeighborList> allNeighbors, HugeObjectArray<LongArrayList> allReverseNewNeighbors, long nodeId, LongArrayList oldNeighbors, LongArrayList newNeighbors) {
            long updateCount = 0L;
            this.joinOldNeighbors(rng, sampledK, allReverseNewNeighbors, nodeId, newNeighbors);

            // Snapshot the backing array only after the reverse sample has been appended.
            long[] newNeighborElements = newNeighbors.buffer;
            int newNeighborsCount = newNeighbors.elementsCount;
            for (int i = 0; i < newNeighborsCount; ++i) {
                long elem1 = newNeighborElements[i];
                assert elem1 != nodeId;
                updateCount += this.join(rng, computer, allNeighbors, n, k, elem1, nodeId);

                // Join elem1 with every later new neighbor; the partner must be read at
                // index j - reading index i would compare elem1 with itself and skip every pair.
                for (int j = i + 1; j < newNeighborsCount; ++j) {
                    long elem2 = newNeighborElements[j];
                    if (elem1 == elem2) {
                        // The list may contain duplicates; a node is never joined with itself.
                        continue;
                    }
                    updateCount += this.join(rng, computer, allNeighbors, n, k, elem1, elem2);
                    updateCount += this.join(rng, computer, allNeighbors, n, k, elem2, elem1);
                }

                if (oldNeighbors == null) {
                    continue;
                }
                // Join elem1 with every old neighbor.
                for (LongCursor oldElemCursor : oldNeighbors) {
                    long elem2 = oldElemCursor.value;
                    if (elem1 == elem2) {
                        continue;
                    }
                    updateCount += this.join(rng, computer, allNeighbors, n, k, elem1, elem2);
                    updateCount += this.join(rng, computer, allNeighbors, n, k, elem2, elem1);
                }
            }
            return updateCount;
        }

        /**
         * Offers {@code randomJoins} random nodes, each different from
         * {@code nodeId}, to the neighbor list of {@code nodeId}. The values
         * returned by these joins are not added to the update count.
         */
        private void randomJoins(SplittableRandom rng, SimilarityComputer computer, long n, int k, HugeObjectArray<NeighborList> allNeighbors, long nodeId, int randomJoins) {
            for (int i = 0; i < randomJoins; ++i) {
                // Draw from [0, n - 1) and shift values >= nodeId by one: this yields
                // a value in [0, n) that is never nodeId itself.
                long randomNodeId = rng.nextLong(n - 1L);
                if (randomNodeId >= nodeId) {
                    ++randomNodeId;
                }
                this.join(rng, computer, allNeighbors, n, k, nodeId, randomNodeId);
            }
        }

        /**
         * Offers {@code joiner} as a neighbor of {@code base}, unless the
         * neighbor filter excludes the pair. The similarity is computed
         * outside the lock; the insertion synchronizes on the base node's
         * neighbor list.
         *
         * @param base   the node whose neighbor list may be updated
         * @param joiner the candidate neighbor
         * @return {@code 0} when the pair is excluded, otherwise the value
         *         returned by {@code NeighborList#add}
         */
        private long join(SplittableRandom splittableRandom, SimilarityComputer computer, HugeObjectArray<NeighborList> allNeighbors, long n, int k, long base, long joiner) {
            assert base != joiner;
            assert n > 1L && k > 0;
            if (this.neighborFilter.excludeNodePair(base, joiner)) {
                return 0L;
            }
            double similarity = computer.safeSimilarity(base, joiner);
            ++this.nodePairsConsidered;
            NeighborList neighbors = allNeighbors.get(base);
            // base may lie outside this task's partition, so presumably other tasks
            // can touch the same list concurrently - hence the lock.
            synchronized (neighbors) {
                int k2 = neighbors.size();
                assert k2 > 0;
                assert k2 <= k;
                assert (long) k2 <= n - 1L;
                return neighbors.add(joiner, similarity, splittableRandom, this.perturbationRate);
            }
        }

        /**
         * @return the number of similarity computations performed by this task so far
         */
        long nodePairsConsidered() {
            return this.nodePairsConsidered;
        }
    }
}

