/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.kmeans;

import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.nodeproperties.ValueType;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.kmeans.ClusterManager;
import org.neo4j.gds.kmeans.ImmutableKmeansResult;
import org.neo4j.gds.kmeans.KmeansBaseConfig;
import org.neo4j.gds.kmeans.KmeansContext;
import org.neo4j.gds.kmeans.KmeansIterationStopper;
import org.neo4j.gds.kmeans.KmeansResult;
import org.neo4j.gds.kmeans.KmeansSampler;
import org.neo4j.gds.kmeans.KmeansTask;
import org.neo4j.gds.kmeans.SilhouetteTask;
import org.neo4j.gds.kmeans.TaskPhase;

/**
 * K-Means clustering of nodes based on an array-valued node property.
 *
 * <p>Each node's property array is treated as a point with {@code dimensions} coordinates. K-Means is
 * run {@code maximumNumberOfRestarts} times and the run with the smallest distance score (the sum of
 * the tasks' normalized distances computed in the distance phase) is kept. Silhouette scores are
 * optionally computed for the kept solution.
 *
 * <p>If {@code k} exceeds the number of nodes, every node is placed in its own cluster and its
 * property array is used as that cluster's centroid.
 */
public class Kmeans extends Algorithm<KmeansResult> {
    /** Community id of a node that has not (yet) been assigned to a cluster. */
    private static final int UNASSIGNED = -1;

    private static final String DIMENSION_MISMATCH_MESSAGE =
        "All property arrays for K-Means should have the same number of dimensions";
    private static final String NAN_VALUE_MESSAGE = "Input for K-Means should not contain any NaN values";

    private final String nodeWeightProperty;
    private final Graph graph;
    private final int k;
    private final int concurrency;
    private final ExecutorService executorService;
    private final SplittableRandom random;
    private final NodePropertyValues nodePropertyValues;
    private final int dimensions;
    private final boolean computeSilhouette;
    private final KmeansIterationStopper kmeansIterationStopper;
    private final int maximumNumberOfRestarts;
    private final KmeansSampler.SamplerType samplerType;
    private final List<List<Double>> seededCentroids;

    // State of the best solution found so far across restarts.
    private HugeIntArray bestCommunities;
    private HugeDoubleArray distanceFromCentroid;
    private double[][] bestCentroids;
    private double bestDistance;
    private long[] nodesInCluster;

    // Silhouette output; only populated when computeSilhouette is set.
    private HugeDoubleArray silhouette;
    private double averageSilhouette;

    /**
     * Creates a K-Means instance from a configuration.
     *
     * @throws IllegalArgumentException if the configured node property does not exist on the graph
     */
    public static Kmeans createKmeans(Graph graph, KmeansBaseConfig config, KmeansContext context) {
        String nodeWeightProperty = config.nodeProperty();
        NodePropertyValues nodeProperties = graph.nodeProperties(nodeWeightProperty);
        if (nodeProperties == null) {
            throw new IllegalArgumentException("Property '" + nodeWeightProperty + "' does not exist for all nodes");
        }
        return new Kmeans(
            context.progressTracker(),
            context.executor(),
            graph,
            config.k(),
            config.concurrency(),
            config.maxIterations(),
            config.numberOfRestarts(),
            config.deltaThreshold(),
            nodeProperties,
            config.computeSilhouette(),
            config.initialSampler(),
            config.seedCentroids(),
            nodeWeightProperty,
            Kmeans.getSplittableRandom(config.randomSeed())
        );
    }

    Kmeans(
        ProgressTracker progressTracker,
        ExecutorService executorService,
        Graph graph,
        int k,
        int concurrency,
        int maxIterations,
        int maximumNumberOfRestarts,
        double deltaThreshold,
        NodePropertyValues nodePropertyValues,
        boolean computeSilhouette,
        KmeansSampler.SamplerType initialSampler,
        List<List<Double>> seededCentroids,
        String nodeWeightProperty,
        SplittableRandom random
    ) {
        super(progressTracker);
        this.nodeWeightProperty = nodeWeightProperty;
        this.executorService = executorService;
        this.graph = graph;
        this.k = k;
        this.concurrency = concurrency;
        this.random = random;
        this.bestCommunities = HugeIntArray.newArray(graph.nodeCount());
        this.nodePropertyValues = nodePropertyValues;
        // The dimensionality is taken from node 0; checkInputValidity() verifies all other nodes agree.
        // An empty graph has no node 0 to read from, so it gets 0 dimensions instead of crashing here.
        this.dimensions = graph.nodeCount() > 0 ? nodePropertyValues.doubleArrayValue(0L).length : 0;
        this.kmeansIterationStopper = new KmeansIterationStopper(deltaThreshold, maxIterations, graph.nodeCount());
        this.maximumNumberOfRestarts = maximumNumberOfRestarts;
        this.distanceFromCentroid = HugeDoubleArray.newArray(graph.nodeCount());
        this.computeSilhouette = computeSilhouette;
        this.samplerType = initialSampler;
        this.seededCentroids = seededCentroids;
        this.nodesInCluster = new long[k];
    }

    /**
     * Validates the input, runs all K-Means restarts and returns the best solution found.
     *
     * @throws IllegalArgumentException if a node lacks the property or the input contains NaN values
     * @throws IllegalStateException if property arrays (or seeded centroids) differ in dimensionality
     */
    public KmeansResult compute() {
        this.progressTracker.beginSubTask();
        this.checkInputValidity();
        long nodeCount = this.graph.nodeCount();
        if ((long) this.k > nodeCount) {
            return this.singletonClustersResult(nodeCount);
        }

        HugeIntArray currentCommunities = HugeIntArray.newArray(nodeCount);
        HugeDoubleArray currentDistanceFromCentroid = HugeDoubleArray.newArray(nodeCount);
        this.bestDistance = Double.POSITIVE_INFINITY;
        this.bestCommunities.setAll(v -> UNASSIGNED);
        if (this.maximumNumberOfRestarts == 1) {
            // A single run has no per-restart progress sub task.
            this.kMeans(nodeCount, currentCommunities, currentDistanceFromCentroid, 0);
        } else {
            for (int restartIteration = 0; restartIteration < this.maximumNumberOfRestarts; ++restartIteration) {
                this.progressTracker.beginSubTask();
                this.kMeans(nodeCount, currentCommunities, currentDistanceFromCentroid, restartIteration);
                this.progressTracker.endSubTask();
            }
        }
        if (this.computeSilhouette) {
            this.calculateSilhouette();
        }
        this.progressTracker.endSubTask();
        return ImmutableKmeansResult.of(
            this.bestCommunities,
            this.distanceFromCentroid,
            this.bestCentroids,
            this.bestDistance,
            this.silhouette,
            this.averageSilhouette
        );
    }

    /**
     * Builds the result for the degenerate case {@code k > nodeCount}: each node is its own cluster,
     * sits at distance 0 from its centroid, and that centroid is the node's own property array.
     */
    private KmeansResult singletonClustersResult(long nodeCount) {
        this.progressTracker.logWarning("Number of requested clusters is larger than the number of nodes.");
        this.bestCommunities.setAll(v -> (int) v);
        this.distanceFromCentroid.setAll(v -> 0.0);
        this.progressTracker.endSubTask();
        // nodeCount < k here, and k is an int, so the narrowing cast is safe.
        int nodes = (int) nodeCount;
        this.bestCentroids = new double[nodes][];
        for (int i = 0; i < nodes; ++i) {
            this.bestCentroids[i] = this.nodePropertyValues.doubleArrayValue((long) i);
        }
        return ImmutableKmeansResult.of(
            this.bestCommunities,
            this.distanceFromCentroid,
            this.bestCentroids,
            0.0,
            this.silhouette,
            0.0
        );
    }

    /**
     * Runs one K-Means execution (one restart) and records it if it beats the best solution so far.
     *
     * @param nodeCount number of nodes in the graph
     * @param currentCommunities scratch array for this run's assignments; reset at the start
     * @param currentDistanceFromCentroid scratch array for this run's per-node distances
     * @param restartIteration zero-based index of this restart
     */
    private void kMeans(
        long nodeCount,
        HugeIntArray currentCommunities,
        HugeDoubleArray currentDistanceFromCentroid,
        int restartIteration
    ) {
        ClusterManager clusterManager = ClusterManager.createClusterManager(this.nodePropertyValues, this.dimensions, this.k);
        currentCommunities.setAll(v -> UNASSIGNED);
        List<KmeansTask> tasks = PartitionUtils.rangePartition(
            this.concurrency,
            nodeCount,
            partition -> KmeansTask.createTask(
                this.samplerType,
                clusterManager,
                this.nodePropertyValues,
                currentCommunities,
                currentDistanceFromCentroid,
                this.k,
                this.dimensions,
                partition
            ),
            this.minBatchSize(nodeCount)
        );
        int numberOfTasks = tasks.size();
        KmeansSampler sampler = KmeansSampler.createSampler(
            this.samplerType,
            this.random,
            this.nodePropertyValues,
            clusterManager,
            nodeCount,
            this.k,
            this.concurrency,
            currentDistanceFromCentroid,
            this.executorService,
            tasks,
            this.progressTracker
        );
        assert numberOfTasks <= this.concurrency;
        this.initializeCentroids(clusterManager, sampler);

        int iteration = 0;
        long numberOfSwaps;
        this.progressTracker.beginSubTask();
        do {
            this.progressTracker.beginSubTask();
            numberOfSwaps = 0L;
            // In the first iteration the tasks are only run for the UNIFORM sampler; presumably the
            // other sampler already performed the assignment during initial sampling (it is handed
            // the tasks). NOTE(review): with seeded centroids the initial sampling is skipped, so a
            // non-UNIFORM sampler would recompute centroids from tasks that have not run yet —
            // confirm that the configuration rules this combination out.
            boolean shouldComputeDistance = iteration > 0 || this.samplerType == KmeansSampler.SamplerType.UNIFORM;
            if (shouldComputeDistance) {
                RunWithConcurrency.builder()
                    .concurrency(this.concurrency)
                    .tasks(tasks)
                    .executor(this.executorService)
                    .run();
                for (KmeansTask task : tasks) {
                    numberOfSwaps += task.getSwaps();
                }
            }
            this.recomputeCentroids(clusterManager, tasks);
            this.progressTracker.endSubTask();
        } while (!this.kmeansIterationStopper.shouldQuit(numberOfSwaps, ++iteration));
        this.progressTracker.endSubTask();

        double averageDistanceFromCentroid = this.calculateDistancePhase(tasks);
        this.updateBestSolution(
            restartIteration,
            clusterManager,
            averageDistanceFromCentroid,
            currentCommunities,
            currentDistanceFromCentroid
        );
    }

    /**
     * Minimum partition size hint for {@link PartitionUtils#rangePartition}. The division happens
     * before narrowing, and the result is clamped, so node counts above {@code Integer.MAX_VALUE}
     * do not overflow.
     */
    private Optional<Integer> minBatchSize(long nodeCount) {
        return Optional.of((int) Math.min(Integer.MAX_VALUE, nodeCount / this.concurrency));
    }

    /** Seeds the centroids from the user-supplied centroids if present, otherwise via the sampler. */
    private void initializeCentroids(ClusterManager clusterManager, KmeansSampler sampler) {
        this.progressTracker.beginSubTask();
        if (!this.seededCentroids.isEmpty()) {
            clusterManager.assignSeededCentroids(this.seededCentroids);
        } else {
            sampler.performInitialSampling();
        }
        this.progressTracker.endSubTask();
    }

    /** Rebuilds the centroids from the per-task contributions of the latest iteration. */
    private void recomputeCentroids(ClusterManager clusterManager, List<KmeansTask> tasks) {
        clusterManager.reset();
        for (KmeansTask task : tasks) {
            clusterManager.updateFromTask(task);
        }
        clusterManager.normalizeClusters();
    }

    /** No resources to release; all state is held on the heap. */
    public void release() {
    }

    /** Returns a seeded generator when a seed is configured, otherwise an unseeded one. */
    @NotNull
    private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
        return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    /**
     * Validates the seeded centroids and every node's property array.
     *
     * @throws IllegalStateException if any array's length differs from {@code dimensions}
     * @throws IllegalArgumentException if a node lacks the property or any value is NaN
     */
    private void checkInputValidity() {
        for (List<Double> centroid : this.seededCentroids) {
            if (centroid.size() != this.dimensions) {
                throw new IllegalStateException(DIMENSION_MISMATCH_MESSAGE);
            }
            for (double value : centroid) {
                if (Double.isNaN(value)) {
                    throw new IllegalArgumentException(NAN_VALUE_MESSAGE);
                }
            }
        }
        // The value type is a property of the whole column, so it is checked once, not per node.
        boolean isFloatArray = this.nodePropertyValues.valueType() == ValueType.FLOAT_ARRAY;
        ParallelUtil.parallelForEachNode(this.graph.nodeCount(), this.concurrency, nodeId -> {
            if (isFloatArray) {
                this.validateFloatArray(this.nodePropertyValues.floatArrayValue(nodeId));
            } else {
                this.validateDoubleArray(this.nodePropertyValues.doubleArrayValue(nodeId));
            }
        });
    }

    /** Checks that a node's float property array exists, has the expected length and contains no NaN. */
    private void validateFloatArray(float[] value) {
        if (value == null) {
            throw this.missingPropertyException();
        }
        if (value.length != this.dimensions) {
            throw new IllegalStateException(DIMENSION_MISMATCH_MESSAGE);
        }
        for (float coordinate : value) {
            if (Float.isNaN(coordinate)) {
                throw new IllegalArgumentException(NAN_VALUE_MESSAGE);
            }
        }
    }

    /** Checks that a node's double property array exists, has the expected length and contains no NaN. */
    private void validateDoubleArray(double[] value) {
        if (value == null) {
            throw this.missingPropertyException();
        }
        if (value.length != this.dimensions) {
            throw new IllegalStateException(DIMENSION_MISMATCH_MESSAGE);
        }
        for (double coordinate : value) {
            if (Double.isNaN(coordinate)) {
                throw new IllegalArgumentException(NAN_VALUE_MESSAGE);
            }
        }
    }

    private IllegalArgumentException missingPropertyException() {
        return new IllegalArgumentException("Property '" + this.nodeWeightProperty + "' does not exist for all nodes");
    }

    /**
     * Computes per-node silhouette scores for the best solution and sums the tasks' average
     * silhouettes into {@code averageSilhouette}.
     */
    private void calculateSilhouette() {
        long nodeCount = this.graph.nodeCount();
        this.progressTracker.beginSubTask();
        this.silhouette = HugeDoubleArray.newArray(nodeCount);
        List<SilhouetteTask> tasks = PartitionUtils.rangePartition(
            this.concurrency,
            nodeCount,
            partition -> SilhouetteTask.createTask(
                this.nodePropertyValues,
                this.bestCommunities,
                this.silhouette,
                this.k,
                this.dimensions,
                this.nodesInCluster,
                partition,
                this.progressTracker
            ),
            this.minBatchSize(nodeCount)
        );
        RunWithConcurrency.builder()
            .concurrency(this.concurrency)
            .tasks(tasks)
            .executor(this.executorService)
            .run();
        // Reset so that repeated compute() calls do not accumulate across runs.
        this.averageSilhouette = 0.0;
        for (SilhouetteTask task : tasks) {
            this.averageSilhouette += task.getAverageSilhouette();
        }
        this.progressTracker.endSubTask();
    }

    /**
     * Switches all tasks to the DISTANCE phase, runs them, and returns the sum of their
     * normalized distances from centroid.
     */
    private double calculateDistancePhase(List<KmeansTask> tasks) {
        for (KmeansTask task : tasks) {
            task.switchToPhase(TaskPhase.DISTANCE);
        }
        RunWithConcurrency.builder()
            .concurrency(this.concurrency)
            .tasks(tasks)
            .executor(this.executorService)
            .run();
        double averageDistanceFromCentroid = 0.0;
        for (KmeansTask task : tasks) {
            averageDistanceFromCentroid += task.getDistanceFromCentroidNormalized();
        }
        return averageDistanceFromCentroid;
    }

    /**
     * Records the current run as the best solution if it is the first run, or if its distance score
     * is strictly smaller than the best one seen so far.
     */
    private void updateBestSolution(
        int restartIteration,
        ClusterManager clusterManager,
        double averageDistanceFromCentroid,
        HugeIntArray currentCommunities,
        HugeDoubleArray currentDistanceFromCentroid
    ) {
        if (restartIteration >= 1) {
            if (averageDistanceFromCentroid < this.bestDistance) {
                this.copyIntoBestSolution(currentCommunities, currentDistanceFromCentroid);
                this.recordBestRun(clusterManager, averageDistanceFromCentroid);
            }
        } else {
            if (this.maximumNumberOfRestarts > 1) {
                // Later restarts reset and overwrite the current arrays, so the best solution must
                // own a copy. Aliasing here would make the result reflect the last restart's
                // assignments even when that restart was worse.
                this.copyIntoBestSolution(currentCommunities, currentDistanceFromCentroid);
            } else {
                // Single run: the current arrays are never reused, so aliasing avoids a copy.
                this.bestCommunities = currentCommunities;
                this.distanceFromCentroid = currentDistanceFromCentroid;
            }
            this.recordBestRun(clusterManager, averageDistanceFromCentroid);
        }
    }

    /** Copies a run's assignments and distances into the best-solution arrays. */
    private void copyIntoBestSolution(HugeIntArray currentCommunities, HugeDoubleArray currentDistanceFromCentroid) {
        ParallelUtil.parallelForEachNode(this.graph, this.concurrency, v -> {
            this.bestCommunities.set(v, currentCommunities.get(v));
            this.distanceFromCentroid.set(v, currentDistanceFromCentroid.get(v));
        });
    }

    /** Stores a run's centroids, distance score and (if needed for silhouette) cluster sizes. */
    private void recordBestRun(ClusterManager clusterManager, double averageDistanceFromCentroid) {
        this.bestDistance = averageDistanceFromCentroid;
        this.bestCentroids = clusterManager.getCentroids();
        if (this.computeSilhouette) {
            this.nodesInCluster = clusterManager.getNodesInCluster();
        }
    }
}

