/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongArrayList;
import java.util.List;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.Knn;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnSampler;
import org.neo4j.gds.similarity.knn.NeighborList;
import org.neo4j.gds.similarity.knn.RandomWalkKnnSampler;
import org.neo4j.gds.similarity.knn.UniformKnnSampler;

public class KnnFactory<CONFIG extends KnnBaseConfig>
extends GraphAlgorithmFactory<Knn, CONFIG> {
    private static final String KNN_BASE_TASK_NAME = "Knn";

    public String taskName() {
        return KNN_BASE_TASK_NAME;
    }

    public Knn build(Graph graph, CONFIG configuration, ProgressTracker progressTracker) {
        return Knn.createWithDefaults(graph, configuration, ImmutableKnnContext.builder().progressTracker(progressTracker).executor(Pools.DEFAULT).build());
    }

    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        return MemoryEstimations.setup((String)this.taskName(), (dim, concurrency) -> {
            int boundedK = configuration.boundedK(dim.nodeCount());
            int sampledK = configuration.sampledK(dim.nodeCount());
            MemoryEstimation tempListEstimation = HugeObjectArray.memoryEstimation((MemoryEstimation)MemoryEstimations.of((String)"elements", (MemoryRange)MemoryRange.of((long)0L, (long)(MemoryUsage.sizeOfInstance(LongArrayList.class) + MemoryUsage.sizeOfLongArray((long)sampledK)))));
            return MemoryEstimations.builder(Knn.class).add("top-k-neighbors-list", HugeObjectArray.memoryEstimation((MemoryEstimation)NeighborList.memoryEstimation(boundedK))).add("old-neighbors", tempListEstimation).add("new-neighbors", tempListEstimation).add("old-reverse-neighbors", tempListEstimation).add("new-reverse-neighbors", tempListEstimation).fixed("initial-random-neighbors (per thread)", KnnFactory.initialSamplerMemoryEstimation(configuration.initialSampler(), boundedK).times((long)concurrency)).fixed("sampled-random-neighbors (per thread)", MemoryRange.of((long)(MemoryUsage.sizeOfIntArray((long)MemoryUsage.sizeOfOpenHashContainer((long)sampledK)) * (long)concurrency))).build();
        });
    }

    static MemoryRange initialSamplerMemoryEstimation(KnnSampler.SamplerType samplerType, long boundedK) {
        switch (samplerType) {
            case UNIFORM: {
                return UniformKnnSampler.memoryEstimation(boundedK);
            }
            case RANDOMWALK: {
                return RandomWalkKnnSampler.memoryEstimation(boundedK);
            }
        }
        throw new IllegalStateException("Invalid KnnSampler");
    }

    public Task progressTask(Graph graph, CONFIG config) {
        return KnnFactory.knnTaskTree(graph, config);
    }

    public static Task knnTaskTree(Graph graph, KnnBaseConfig config) {
        return Tasks.task((String)KNN_BASE_TASK_NAME, (Task)Tasks.leaf((String)"Initialize random neighbors", (long)graph.nodeCount()), (Task[])new Task[]{Tasks.iterativeDynamic((String)"Iteration", () -> List.of(Tasks.leaf((String)"Split old and new neighbors", (long)graph.nodeCount()), Tasks.leaf((String)"Reverse old and new neighbors", (long)graph.nodeCount()), Tasks.leaf((String)"Join neighbors", (long)graph.nodeCount())), (int)config.maxIterations())});
    }
}

