/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.kmeans;

import java.util.List;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.kmeans.ClusterManager;
import org.neo4j.gds.kmeans.ImmutableKmeansContext;
import org.neo4j.gds.kmeans.Kmeans;
import org.neo4j.gds.kmeans.KmeansBaseConfig;
import org.neo4j.gds.kmeans.KmeansTask;
import org.neo4j.gds.mem.MemoryUsage;

public final class KmeansAlgorithmFactory<CONFIG extends KmeansBaseConfig>
extends GraphAlgorithmFactory<Kmeans, CONFIG> {
    public String taskName() {
        return "Kmeans";
    }

    public Kmeans build(Graph graph, CONFIG configuration, ProgressTracker progressTracker) {
        List<List<Double>> seedCentroids = configuration.seedCentroids();
        if (configuration.numberOfRestarts() > 1 && seedCentroids.size() > 0) {
            throw new IllegalArgumentException("K-Means cannot be run multiple time when seeded");
        }
        if (seedCentroids.size() > 0 && seedCentroids.size() != configuration.k()) {
            throw new IllegalArgumentException("Incorrect number of seeded centroids given for running K-Means");
        }
        return Kmeans.createKmeans(graph, configuration, ImmutableKmeansContext.builder().progressTracker(progressTracker).executor(Pools.DEFAULT).build());
    }

    public Task progressTask(Graph graph, CONFIG config) {
        int iterations = config.numberOfRestarts();
        if (iterations == 1) {
            return this.kMeansTask(graph, this.taskName(), config);
        }
        return Tasks.iterativeFixed((String)this.taskName(), () -> List.of(this.kMeansTask(graph, "KMeans Iteration", config)), (int)iterations);
    }

    @NotNull
    private Task kMeansTask(Graph graph, String description, CONFIG config) {
        if (config.computeSilhouette()) {
            return Tasks.task((String)description, List.of(Tasks.leaf((String)"Initialization", (long)config.k()), Tasks.iterativeDynamic((String)"Main", () -> List.of(Tasks.leaf((String)"Iteration")), (int)config.maxIterations()), Tasks.leaf((String)"Silhouette", (long)graph.nodeCount())));
        }
        return Tasks.task((String)description, List.of(Tasks.leaf((String)"Initialization", (long)config.k()), Tasks.iterativeDynamic((String)"Main", () -> List.of(Tasks.leaf((String)"Iteration")), (int)config.maxIterations())));
    }

    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        int fakeLength = 128;
        MemoryEstimations.Builder builder = MemoryEstimations.builder(Kmeans.class).perNode("bestCommunities", HugeIntArray::memoryEstimation).fixed("bestCentroids", MemoryUsage.sizeOfArray((long)configuration.k(), (long)MemoryUsage.sizeOfDoubleArray((long)fakeLength))).perNode("nodesInCluster", MemoryUsage::sizeOfLongArray).perNode("distanceFromCentroid", HugeDoubleArray::memoryEstimation).add(ClusterManager.memoryEstimation(configuration.k(), fakeLength)).perThread("KMeansTask", KmeansTask.memoryEstimation(configuration.k(), fakeLength));
        if (configuration.computeSilhouette()) {
            builder.perNode("silhouette", HugeDoubleArray::memoryEstimation);
        }
        if (configuration.isSeeded()) {
            List<List<Double>> centroids = configuration.seedCentroids();
            builder.fixed("seededCentroids", MemoryUsage.sizeOf(centroids));
        }
        return builder.build();
    }
}

