/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity.filteredknn;

import com.carrotsearch.hppc.LongArrayList;
import org.apache.commons.lang3.function.TriFunction;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.similarity.filteredknn.FilteredKnn;
import org.neo4j.gds.similarity.filteredknn.FilteredKnnBaseConfig;
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.gds.similarity.knn.NeighborList;

/**
 * Algorithm factory for {@link FilteredKnn}.
 *
 * <p>{@link #build} picks between a seeded and an unseeded construction strategy depending on
 * {@code configuration.seedTargetNodes()}. {@link #memoryEstimation} describes the memory needed
 * for a run, and {@link #progressTask} reuses the plain KNN progress task tree.
 *
 * @param <CONFIG> the concrete filtered-KNN configuration type
 */
public class FilteredKnnFactory<CONFIG extends FilteredKnnBaseConfig>
    extends GraphAlgorithmFactory<FilteredKnn, CONFIG> {

    private static final String FILTERED_KNN_TASK_NAME = "Filtered KNN";

    /** Builds the algorithm when {@code seedTargetNodes()} is {@code false}. */
    private final TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> unseededFilteredKnnSupplier;

    /** Builds the algorithm when {@code seedTargetNodes()} is {@code true}. */
    private final TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> seededFilteredKnnSupplier;

    /**
     * Creates a factory backed by {@code FilteredKnn::createWithoutSeeding} (unseeded) and
     * {@code FilteredKnn::createWithDefaultSeeding} (seeded).
     */
    public FilteredKnnFactory() {
        this(FilteredKnn::createWithoutSeeding, FilteredKnn::createWithDefaultSeeding);
    }

    /**
     * Creates a factory with explicit construction strategies. This is package-private,
     * presumably so tests can inject alternative suppliers (confirm against callers).
     *
     * @param unseededFilteredKnnSupplier used by {@link #build} when target-node seeding is disabled
     * @param seededFilteredKnnSupplier   used by {@link #build} when target-node seeding is enabled
     */
    FilteredKnnFactory(
        TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> unseededFilteredKnnSupplier,
        TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> seededFilteredKnnSupplier
    ) {
        this.unseededFilteredKnnSupplier = unseededFilteredKnnSupplier;
        this.seededFilteredKnnSupplier = seededFilteredKnnSupplier;
    }

    /**
     * Returns the task name used for progress reporting and memory estimation.
     *
     * @return {@code "Filtered KNN"}
     */
    @Override
    public String taskName() {
        return FILTERED_KNN_TASK_NAME;
    }

    /**
     * Creates a {@link FilteredKnn} instance for the given graph.
     *
     * <p>The algorithm context carries the supplied progress tracker and the {@code Pools.DEFAULT}
     * executor. Which supplier is used depends on {@code configuration.seedTargetNodes()}.
     *
     * @param graph           the graph to run on
     * @param configuration   the algorithm configuration
     * @param progressTracker tracker that receives progress updates
     * @return the configured, not-yet-executed algorithm
     */
    @Override
    public FilteredKnn build(Graph graph, CONFIG configuration, ProgressTracker progressTracker) {
        KnnContext knnContext = ImmutableKnnContext.builder()
            .progressTracker(progressTracker)
            .executor(Pools.DEFAULT)
            .build();

        TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> supplier = configuration.seedTargetNodes()
            ? seededFilteredKnnSupplier
            : unseededFilteredKnnSupplier;

        return supplier.apply(graph, configuration, knnContext);
    }

    /**
     * Estimates the memory needed by a filtered-KNN run.
     *
     * <p>The estimate consists of:
     * <ul>
     *   <li>a per-node top-k neighbour list;</li>
     *   <li>four per-node temporary lists (old/new and old/new reverse neighbours);</li>
     *   <li>the initial sampler's memory, multiplied by the concurrency;</li>
     *   <li>an open-hash int buffer for sampled neighbours, multiplied by the concurrency.</li>
     * </ul>
     *
     * @param configuration the algorithm configuration (supplies k, sample size and initial sampler)
     * @return the memory estimation tree
     */
    @Override
    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        return MemoryEstimations.setup(taskName(), (dim, concurrency) -> {
            int boundedK = configuration.boundedK(dim.nodeCount());
            int sampledK = configuration.sampledK(dim.nodeCount());

            // Each temporary per-node list may be absent (0 bytes) or be one LongArrayList
            // holding up to sampledK node ids.
            MemoryEstimation tempListEstimation = HugeObjectArray.memoryEstimation(
                MemoryEstimations.of(
                    "elements",
                    MemoryRange.of(
                        0L,
                        MemoryUsage.sizeOfInstance(LongArrayList.class) + MemoryUsage.sizeOfLongArray(sampledK)
                    )
                )
            );

            return MemoryEstimations.builder(FilteredKnn.class)
                .add(
                    "top-k-neighbors-list",
                    HugeObjectArray.memoryEstimation(NeighborList.memoryEstimation(boundedK))
                )
                .add("old-neighbors", tempListEstimation)
                .add("new-neighbors", tempListEstimation)
                .add("old-reverse-neighbors", tempListEstimation)
                .add("new-reverse-neighbors", tempListEstimation)
                .fixed(
                    "initial-random-neighbors (per thread)",
                    KnnFactory.initialSamplerMemoryEstimation(configuration.initialSampler(), boundedK)
                        .times(concurrency)
                )
                .fixed(
                    "sampled-random-neighbors (per thread)",
                    MemoryRange.of(
                        MemoryUsage.sizeOfIntArray(MemoryUsage.sizeOfOpenHashContainer(sampledK)) * concurrency
                    )
                )
                .build();
        });
    }

    /**
     * Returns the progress task tree for this algorithm. It is the same tree as plain KNN
     * ({@code KnnFactory.knnTaskTree}).
     *
     * @param graph  the graph the algorithm will run on
     * @param config the algorithm configuration
     * @return the root progress task
     */
    @Override
    public Task progressTask(Graph graph, CONFIG config) {
        return KnnFactory.knnTaskTree(graph, config);
    }
}

