/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.kmeans;

import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.kmeans.ClusterManager;
import org.neo4j.gds.kmeans.DoubleClusterManager;
import org.neo4j.gds.kmeans.DoubleKmeansTask;
import org.neo4j.gds.kmeans.FloatKmeansTask;
import org.neo4j.gds.kmeans.KmeansSampler;
import org.neo4j.gds.kmeans.TaskPhase;
import org.neo4j.gds.mem.MemoryUsage;

/**
 * A {@link Runnable} that performs one partition's share of the k-means work.
 *
 * <p>What {@link #run()} does depends on the task's current {@link TaskPhase}:
 * <ul>
 *   <li>{@code ITERATION}: assign every node of the partition to its closest centroid,</li>
 *   <li>{@code DISTANCE}: record each node's distance to the centroid of its community,</li>
 *   <li>any other phase: update distances/communities against the most recently
 *       sampled centroid.</li>
 * </ul>
 * Subclasses supply the per-cluster bookkeeping through {@link #reset()} and
 * {@link #updateAfterAssignmentToCentroid(long, int)}.
 */
public abstract class KmeansTask implements Runnable {
    private final ClusterManager clusterManager;
    private final Partition partition;
    final NodePropertyValues nodePropertyValues;
    private final HugeDoubleArray distanceFromCentroid;
    final HugeIntArray communities;
    final long[] communitySizes;
    final int k;
    final int dimensions;
    private long swaps;
    private double distance;
    private double squaredDistance = 0.0;
    private TaskPhase phase;

    /**
     * Creates a task for a single partition.
     *
     * <p>With the {@code UNIFORM} sampler the task starts in the {@code ITERATION}
     * phase; with any other sampler it starts in the {@code INITIAL} phase.
     *
     * @param samplerType          sampler in use; decides the starting phase
     * @param clusterManager       provides centroid lookups and distance computations
     * @param nodePropertyValues   node property values made available to subclasses
     * @param communities          per-node community assignment, read and written by this task
     * @param distanceFromCentroid per-node distance storage, read and written by this task
     * @param k                    number of clusters; also the length of {@code communitySizes}
     * @param dimensions           dimensionality made available to subclasses
     * @param partition            node range this task works on
     */
    KmeansTask(KmeansSampler.SamplerType samplerType, ClusterManager clusterManager, NodePropertyValues nodePropertyValues, HugeIntArray communities, HugeDoubleArray distanceFromCentroid, int k, int dimensions, Partition partition) {
        this.partition = partition;
        this.clusterManager = clusterManager;
        this.nodePropertyValues = nodePropertyValues;
        this.distanceFromCentroid = distanceFromCentroid;
        this.communities = communities;
        this.k = k;
        this.dimensions = dimensions;
        this.communitySizes = new long[k];
        this.distance = 0.0;
        if (samplerType == KmeansSampler.SamplerType.UNIFORM) {
            this.phase = TaskPhase.ITERATION;
        } else {
            this.phase = TaskPhase.INITIAL;
        }
    }

    /**
     * Builds the task variant matching the given cluster manager: a
     * {@link DoubleKmeansTask} for a {@link DoubleClusterManager}, otherwise a
     * {@link FloatKmeansTask}. All arguments are forwarded unchanged.
     */
    static KmeansTask createTask(KmeansSampler.SamplerType samplerType, ClusterManager clusterManager, NodePropertyValues nodePropertyValues, HugeIntArray communities, HugeDoubleArray distanceFromCentroid, int k, int dimensions, Partition partition) {
        if (!(clusterManager instanceof DoubleClusterManager)) {
            return new FloatKmeansTask(samplerType, clusterManager, nodePropertyValues, communities, distanceFromCentroid, k, dimensions, partition);
        }
        return new DoubleKmeansTask(samplerType, clusterManager, nodePropertyValues, communities, distanceFromCentroid, k, dimensions, partition);
    }

    /**
     * Memory estimation for one task: a {@code long[k]} for the community sizes
     * plus {@code k} coordinate-sum arrays of {@code fakeDimensions} entries each,
     * ranging from the float-array size (lower bound) to the double-array size
     * (upper bound).
     *
     * @param k              number of clusters
     * @param fakeDimensions number of entries assumed per coordinate-sum array
     */
    static MemoryEstimation memoryEstimation(int k, int fakeDimensions) {
        long clusterCount = k;
        long entriesPerCluster = fakeDimensions;

        long lowerBound = clusterCount * MemoryUsage.sizeOfFloatArray(entriesPerCluster);
        long upperBound = clusterCount * MemoryUsage.sizeOfDoubleArray(entriesPerCluster);
        MemoryRange coordinateSumsRange = MemoryRange.of(lowerBound, upperBound);
        MemoryEstimation coordinateSums = MemoryEstimations.of("communityCoordinateSums", coordinateSumsRange);

        MemoryEstimations.Builder builder = MemoryEstimations.builder(KmeansTask.class);
        builder.fixed("communitySizes", MemoryUsage.sizeOfLongArray(clusterCount))
            .add("communityCoordinateSums", coordinateSums);
        return builder.build();
    }

    /** Invoked at the start of every assignment pass, before any node is processed. */
    abstract void reset();

    /**
     * Invoked once per node after that node's community has been determined.
     *
     * @param nodeId      the node that was assigned
     * @param communityId the community it was assigned to
     */
    abstract void updateAfterAssignmentToCentroid(long nodeId, int communityId);

    /** Returns the counter kept for cluster {@code ith} in {@code communitySizes}. */
    long getNumAssignedAtCluster(int ith) {
        return communitySizes[ith];
    }

    /** Returns how many nodes changed community during the latest assignment pass. */
    long getSwaps() {
        return swaps;
    }

    /** Selects what the next {@link #run()} invocation will do. */
    void switchToPhase(TaskPhase newPhase) {
        phase = newPhase;
    }

    /**
     * Returns the distance accumulated by this task divided by the size of the
     * whole {@code communities} array (not by the size of this task's partition).
     */
    public double getDistanceFromCentroidNormalized() {
        double totalNodes = communities.size();
        return distance / totalNodes;
    }

    /** Returns the sum of squared distances computed by the latest sampling-phase run. */
    public double getSquaredDistance() {
        return squaredDistance;
    }

    /**
     * Dispatches on the current phase over the node range
     * {@code [startNode, startNode + nodeCount)} of this task's partition.
     */
    @Override
    public void run() {
        long first = partition.startNode();
        long end = first + partition.nodeCount();

        if (phase == TaskPhase.ITERATION) {
            assignNodeToCentroid(first, end);
            return;
        }
        if (phase == TaskPhase.DISTANCE) {
            calculateFinalDistance(first, end);
            return;
        }
        // Every remaining phase is handled as centroid sampling.
        distanceFromLastSampledCentroid(first, end, clusterManager.getCurrentlyAssigned());
    }

    /**
     * Assigns each node in {@code [startNode, endNode)} to the closest centroid as
     * reported by the cluster manager, counting nodes whose community changed.
     * The swap counter is zeroed and {@link #reset()} is called before the pass.
     */
    private void assignNodeToCentroid(long startNode, long endNode) {
        swaps = 0L;
        reset();
        for (long node = startNode; node < endNode; node++) {
            int closest = clusterManager.findClosestCentroid(node);
            communitySizes[closest]++;
            int previous = communities.get(node);
            if (previous != closest) {
                swaps++;
            }
            communities.set(node, closest);
            updateAfterAssignmentToCentroid(node, closest);
        }
    }

    /**
     * Stores, for each node in {@code [startNode, endNode)}, the distance to the
     * centroid of its current community and adds it to the running total.
     *
     * <p>NOTE(review): the running total is not cleared here, so every invocation
     * adds on top of the previous ones - confirm callers run this phase only once
     * per task.
     */
    private void calculateFinalDistance(long startNode, long endNode) {
        for (long node = startNode; node < endNode; node++) {
            int community = communities.get(node);
            double toCentroid = clusterManager.euclidean(node, community);
            distance += toCentroid;
            distanceFromCentroid.set(node, toCentroid);
        }
    }

    /**
     * Updates the nodes in {@code [startNode, endNode)} against the centroid with
     * index {@code numAssigned - 1} and recomputes the squared-distance sum.
     *
     * <p>Nodes whose stored distance is greater than {@code -1} are compared with
     * that centroid; they adopt it when it is the first centroid
     * ({@code numAssigned == 1}) or when it is strictly closer than the stored
     * distance. Other nodes are skipped for the comparison.
     *
     * <p>Once {@code numAssigned == k}, a stored value {@code v <= -1} is decoded
     * into community {@code (int) (-v) - 1} and its distance set to {@code 0};
     * presumably such values mark nodes fixed during sampling - verify against the
     * sampler. Every node is then counted in {@code communitySizes} and reported
     * to {@link #updateAfterAssignmentToCentroid(long, int)}.
     *
     * @param numAssigned number of centroids assigned so far, as reported by the cluster manager
     */
    private void distanceFromLastSampledCentroid(long startNode, long endNode, int numAssigned) {
        final int lastSampled = numAssigned - 1;
        final boolean samplingComplete = numAssigned == k;

        squaredDistance = 0.0;
        for (long node = startNode; node < endNode; node++) {
            if (distanceFromCentroid.get(node) > -1.0) {
                double candidate = clusterManager.euclidean(node, lastSampled);
                // For the very first centroid there is nothing to compare against.
                boolean adopt = numAssigned == 1 || distanceFromCentroid.get(node) > candidate;
                if (adopt) {
                    distanceFromCentroid.set(node, candidate);
                    squaredDistance += candidate * candidate;
                    communities.set(node, lastSampled);
                } else {
                    double kept = distanceFromCentroid.get(node);
                    squaredDistance += kept * kept;
                }
            }

            if (samplingComplete) {
                double stored = distanceFromCentroid.get(node);
                if (stored <= -1.0) {
                    communities.set(node, (int) (-stored) - 1);
                    distanceFromCentroid.set(node, 0.0);
                }
                int community = communities.get(node);
                communitySizes[community]++;
                updateAfterAssignmentToCentroid(node, community);
            }
        }
    }
}

