/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity.filteredknn;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.LongPredicate;
import java.util.function.UnaryOperator;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.paged.HugeCursor;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.filteredknn.TargetNodeFilter;
import org.neo4j.gds.similarity.knn.NeighbourConsumers;
import org.neo4j.gds.similarity.knn.SimilarityFunction;

/**
 * Holds one {@link TargetNodeFilter} per node id and hands them out as {@link NeighbourConsumers}.
 *
 * <p>Filters are stored in a {@link HugeObjectArray} indexed by node id in {@code [0, nodeCount)}.
 */
public final class TargetNodeFiltering
implements NeighbourConsumers {
    /** One filter per node id; index {@code i} holds the filter for node {@code i}. */
    private final HugeObjectArray<TargetNodeFilter> targetNodeFilters;

    /**
     * Creates a {@link TargetNodeFilter} for every node id in {@code [0, nodeCount)}.
     *
     * @param nodeCount                  number of nodes; one filter is created per id below this value
     * @param k                          passed to every filter, and the maximum number of seeds per node
     * @param targetNodePredicate        passed to every filter; also decides which nodes may be seeds
     * @param graph                      graph whose nodes are scanned when collecting seeds
     * @param optionalSimilarityFunction when present, each filter is created with seeds scored by this
     *                                   function; when empty, filters are created without seeds
     * @param similarityCutoff           passed to every filter
     * @return the per-node filters
     */
    static TargetNodeFiltering create(
        long nodeCount,
        int k,
        LongPredicate targetNodePredicate,
        Graph graph,
        Optional<SimilarityFunction> optionalSimilarityFunction,
        double similarityCutoff
    ) {
        HugeObjectArray<TargetNodeFilter> neighbourConsumers =
            HugeObjectArray.newArray(TargetNodeFilter.class, nodeCount);
        // long counter: an int index would overflow for graphs with more than Integer.MAX_VALUE nodes
        for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
            Optional<Set<Pair<Double, Long>>> optionalSeeds =
                prepareSeeds(graph, targetNodePredicate, k, nodeId, optionalSimilarityFunction);
            TargetNodeFilter targetNodeFilter =
                TargetNodeFilter.create(targetNodePredicate, k, optionalSeeds, similarityCutoff);
            neighbourConsumers.set(nodeId, targetNodeFilter);
        }
        return new TargetNodeFiltering(neighbourConsumers);
    }

    /**
     * Collects up to {@code k} (similarity, nodeId) seed pairs for node {@code n}.
     *
     * <p>Candidates are other nodes that satisfy {@code targetNodePredicate}, visited in
     * {@link Graph#forEachNode} order; the node itself is never a candidate.
     * NOTE(review): when few nodes are targets this scans the whole graph for every node.
     *
     * @return empty when no similarity function is given, otherwise the collected seeds
     */
    private static Optional<Set<Pair<Double, Long>>> prepareSeeds(
        Graph graph,
        LongPredicate targetNodePredicate,
        int k,
        long n,
        Optional<SimilarityFunction> similarityFunction
    ) {
        if (similarityFunction.isEmpty()) {
            return Optional.empty();
        }
        SimilarityFunction similarity = similarityFunction.get();
        Set<Pair<Double, Long>> seeds = prepareSeedSet(k);
        graph.forEachNode(m -> {
            // skip the node itself and any non-target node, and keep scanning
            if (m == n || !targetNodePredicate.test(m)) {
                return true;
            }
            double similarityScore = similarity.computeSimilarity(n, m);
            seeds.add(Pair.of(similarityScore, m));
            // presumably forEachNode stops once the callback returns false, i.e. after k seeds
            return seeds.size() < k;
        });
        return Optional.of(seeds);
    }

    /** Creates a hash set pre-sized so that {@code k} entries fit under the default load factor. */
    @NotNull
    private static Set<Pair<Double, Long>> prepareSeedSet(int k) {
        float defaultLoadFactor = 0.75f;
        int initialCapacity = (int) ((float) k / defaultLoadFactor);
        return new HashSet<>(initialCapacity, defaultLoadFactor);
    }

    private TargetNodeFiltering(HugeObjectArray<TargetNodeFilter> targetNodeFilters) {
        this.targetNodeFilters = targetNodeFilters;
    }

    /** Returns the filter for {@code nodeId}. */
    @Override
    public TargetNodeFilter get(long nodeId) {
        return this.targetNodeFilters.get(nodeId);
    }

    /**
     * Streams the similarity results of every node accepted by {@code sourceNodePredicate},
     * in ascending node-id order.
     *
     * <p>Iterates node ids directly rather than sharing one mutable page cursor across stream
     * elements, so the result stays correct even if a caller consumes it in parallel.
     */
    Stream<SimilarityResult> asSimilarityResultStream(LongPredicate sourceNodePredicate) {
        return LongStream.range(0, this.targetNodeFilters.size())
            .filter(sourceNodePredicate)
            .mapToObj(nodeId -> this.targetNodeFilters.get(nodeId).asSimilarityStream(nodeId))
            .flatMap(Function.identity());
    }

    /** Sums {@code size()} over the filters of all nodes accepted by {@code sourceNodePredicate}. */
    long numberOfSimilarityPairs(LongPredicate sourceNodePredicate) {
        return LongStream.range(0, this.targetNodeFilters.size())
            .filter(sourceNodePredicate)
            .map(nodeId -> this.targetNodeFilters.get(nodeId).size())
            .sum();
    }
}

