/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.beta.modularity;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.cursors.LongLongCursor;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.stream.BaseStream;
import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.NodeProperties;
import org.neo4j.gds.api.RelationshipIterator;
import org.neo4j.gds.api.nodeproperties.LongNodeProperties;
import org.neo4j.gds.beta.k1coloring.ImmutableK1ColoringStreamConfig;
import org.neo4j.gds.beta.k1coloring.K1Coloring;
import org.neo4j.gds.beta.k1coloring.K1ColoringFactory;
import org.neo4j.gds.beta.k1coloring.K1ColoringStreamConfig;
import org.neo4j.gds.beta.modularity.ModularityOptimizationTask;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeAtomicDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeLongLongMap;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Modularity Optimization community detection.
 *
 * <p>The graph is first colored with K1Coloring. Each iteration then runs one batch of
 * {@link ModularityOptimizationTask}s per used color (the tasks receive the current color and the
 * coloring, presumably so only nodes of that color move at once). After each color the tasks'
 * proposed communities are copied into the current assignment and the accumulated community weight
 * deltas are applied. Iteration stops after {@code maxIterations} iterations, or earlier once the
 * modularity no longer increases by more than {@code tolerance} (which marks the run as converged).
 */
public final class ModularityOptimization extends Algorithm<ModularityOptimization> {
    /** Maximum number of iterations granted to the K1Coloring pre-processing step. */
    public static final int K1COLORING_MAX_ITERATIONS = 5;

    private final int concurrency;
    private final int maxIterations;
    private final long nodeCount;
    private final long minBatchSize;
    private final double tolerance;
    private final Graph graph;
    // Optional initial communities; null when the run is unseeded.
    private final NodeProperties seedProperty;
    private final ExecutorService executor;

    private int iterationCounter;
    private boolean didConverge = false;
    // Half of the summed relationship weights of all nodes (computed in init()).
    private double totalNodeWeight = 0.0;
    // Sentinel start value; replaced by the first computed modularity.
    private double modularity = -1.0;
    // Colors produced by K1Coloring, one bit per used color.
    private BitSet colorsUsed;
    private HugeLongArray colors;
    // Internal community id per node.
    private HugeLongArray currentCommunities;
    // Community assignment proposed by the optimization tasks for the current color.
    private HugeLongArray nextCommunities;
    // Internal community id -> seed-space community id (only set for seeded runs).
    private HugeLongArray reverseSeedCommunityMapping;
    // Summed relationship weight per node.
    private HugeDoubleArray cumulativeNodeWeights;
    // Per-node values filled by the optimization tasks; reset every iteration and summed into modularity.
    private HugeDoubleArray nodeCommunityInfluences;
    // Summed node weight per community.
    private HugeAtomicDoubleArray communityWeights;
    // Community weight deltas accumulated while processing one color.
    private HugeAtomicDoubleArray communityWeightUpdates;

    /**
     * Creates a new Modularity Optimization run.
     *
     * @param graph           graph to detect communities in
     * @param maxIterations   upper bound on optimization iterations; must be at least 1
     * @param tolerance       minimum modularity gain for an iteration to count as an improvement
     * @param seedProperty    optional initial community per node; negative values mean "no seed"
     * @param concurrency     number of parallel workers
     * @param minBatchSize    minimum partition size for parallel tasks, also the K1Coloring batch size
     * @param executor        executor that runs the parallel tasks
     * @param progressTracker progress tracker passed to the algorithm and its sub-steps
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public ModularityOptimization(Graph graph, int maxIterations, double tolerance, @Nullable NodeProperties seedProperty, int concurrency, int minBatchSize, ExecutorService executor, ProgressTracker progressTracker) {
        super(progressTracker);
        if (maxIterations < 1) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Need to run at least one iteration, but got %d",
                maxIterations
            ));
        }
        this.graph = graph;
        this.nodeCount = graph.nodeCount();
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
        this.seedProperty = seedProperty;
        this.executor = executor;
        this.concurrency = concurrency;
        this.minBatchSize = minBatchSize;
    }

    /**
     * Runs coloring and initialization, then up to {@code maxIterations} optimization iterations.
     *
     * @return this instance, from which the results can be read
     */
    @Override
    public ModularityOptimization compute() {
        this.progressTracker.beginSubTask();

        // Sub-task 1: coloring, seeding and weight initialization.
        this.progressTracker.beginSubTask();
        this.computeColoring();
        this.initSeeding();
        this.init();
        this.progressTracker.endSubTask();

        // Sub-task 2: the optimization iterations, each tracked as its own nested sub-task.
        this.progressTracker.beginSubTask();
        this.iterationCounter = 0;
        while (this.iterationCounter < this.maxIterations) {
            this.progressTracker.beginSubTask();
            this.nodeCommunityInfluences.fill(0.0);
            for (long color = this.colorsUsed.nextSetBit(0); color != -1L; color = this.colorsUsed.nextSetBit(color + 1L)) {
                this.assertRunning();
                this.optimizeForColor(color);
            }
            boolean improved = this.updateModularity();
            this.progressTracker.endSubTask();
            this.iterationCounter++;
            if (!improved) {
                this.didConverge = true;
                break;
            }
        }
        this.progressTracker.endSubTask();

        this.progressTracker.endSubTask();
        return this;
    }

    /** Colors the graph with K1Coloring and stores the per-node colors and the set of used colors. */
    private void computeColoring() {
        K1ColoringStreamConfig k1Config = ImmutableK1ColoringStreamConfig.builder()
            .concurrency(this.concurrency)
            .maxIterations(K1COLORING_MAX_ITERATIONS)
            .batchSize((int) this.minBatchSize)
            .build();
        K1Coloring coloring = new K1ColoringFactory<K1ColoringStreamConfig>().build(this.graph, k1Config, this.progressTracker);
        coloring.setTerminationFlag(this.terminationFlag);
        this.colors = coloring.compute();
        this.colorsUsed = coloring.usedColors();
    }

    /**
     * Allocates {@code currentCommunities} and, when a seed property is given, maps the seed-space
     * communities onto dense internal ids starting at 0. It also records the reverse mapping used by
     * {@link #getCommunityId(long)}.
     *
     * <p>A node with a negative seed value gets its own community {@code originalId + maxSeed + 1},
     * so it can never share an initial community with a seeded node.
     */
    private void initSeeding() {
        this.currentCommunities = HugeLongArray.newArray(this.nodeCount);
        if (this.seedProperty == null) {
            // Unseeded runs get identity communities in InitTask.
            return;
        }
        long maxSeedCommunity = this.seedProperty.getMaxLongPropertyValue().orElse(0L);
        HugeLongLongMap communityMapping = new HugeLongLongMap(this.nodeCount);
        long nextAvailableInternalCommunityId = 0L;
        for (long nodeId = 0L; nodeId < this.nodeCount; ++nodeId) {
            long seedCommunity = this.seedProperty.longValue(nodeId);
            if (seedCommunity < 0L) {
                // The +1 matters: without it the unseeded node with original id 0 would map to
                // maxSeedCommunity and join the community of the node carrying the largest seed.
                seedCommunity = this.graph.toOriginalNodeId(nodeId) + maxSeedCommunity + 1L;
            }
            long internalCommunityId = communityMapping.getOrDefault(seedCommunity, -1L);
            if (internalCommunityId < 0L) {
                internalCommunityId = nextAvailableInternalCommunityId++;
                communityMapping.addTo(seedCommunity, internalCommunityId);
            }
            this.currentCommunities.set(nodeId, internalCommunityId);
        }
        this.reverseSeedCommunityMapping = HugeLongArray.newArray(communityMapping.size());
        for (LongLongCursor entry : communityMapping) {
            this.reverseSeedCommunityMapping.set(entry.value, entry.key);
        }
    }

    /**
     * Allocates the working arrays and runs {@link InitTask}s in parallel to compute node and
     * community weights. It also computes the total node weight and seeds {@code nextCommunities}
     * with the initial assignment.
     */
    private void init() {
        this.nextCommunities = HugeLongArray.newArray(this.nodeCount);
        this.cumulativeNodeWeights = HugeDoubleArray.newArray(this.nodeCount);
        this.nodeCommunityInfluences = HugeDoubleArray.newArray(this.nodeCount);
        this.communityWeights = HugeAtomicDoubleArray.newArray(this.nodeCount);
        this.communityWeightUpdates = HugeAtomicDoubleArray.newArray(this.nodeCount);
        boolean isSeeded = this.seedProperty != null;
        List<InitTask> initTasks = PartitionUtils.rangePartition(
            this.concurrency,
            this.nodeCount,
            partition -> new InitTask(
                (RelationshipIterator) this.graph.concurrentCopy(),
                this.currentCommunities,
                this.communityWeights,
                this.cumulativeNodeWeights,
                isSeeded,
                partition
            ),
            Optional.of((int) this.minBatchSize)
        );
        ParallelUtil.run(initTasks, this.executor);
        double doubleTotalNodeWeight = initTasks.stream().mapToDouble(InitTask::localSum).sum();
        // Halved because every node sums its own relationships; presumably each relationship is
        // therefore seen from both endpoints (undirected graph) — TODO confirm with callers.
        this.totalNodeWeight = doubleTotalNodeWeight / 2.0;
        this.currentCommunities.copyTo(this.nextCommunities, this.nodeCount);
    }

    /**
     * Runs one batch of optimization tasks for {@code currentColor}. It then adopts the proposed
     * communities, folds the accumulated weight deltas into {@code communityWeights}, and starts a
     * fresh delta array.
     */
    private void optimizeForColor(long currentColor) {
        ParallelUtil.runWithConcurrency(this.concurrency, this.createModularityOptimizationTasks(currentColor), this.executor);
        this.nextCommunities.copyTo(this.currentCommunities, this.nodeCount);
        ParallelUtil.parallelStreamConsume(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            stream -> stream.forEach(nodeId -> {
                double update = this.communityWeightUpdates.get(nodeId);
                this.communityWeights.update(nodeId, w -> w + update);
            })
        );
        this.communityWeightUpdates = HugeAtomicDoubleArray.newArray(this.nodeCount);
    }

    /** Builds one {@link ModularityOptimizationTask} per node-range partition for the given color. */
    private Collection<ModularityOptimizationTask> createModularityOptimizationTasks(long currentColor) {
        return PartitionUtils.rangePartition(
            this.concurrency,
            this.nodeCount,
            partition -> new ModularityOptimizationTask(
                this.graph,
                partition,
                currentColor,
                this.totalNodeWeight,
                this.colors,
                this.currentCommunities,
                this.nextCommunities,
                this.cumulativeNodeWeights,
                this.nodeCommunityInfluences,
                this.communityWeights,
                this.communityWeightUpdates,
                this.progressTracker
            ),
            Optional.of((int) this.minBatchSize)
        );
    }

    /**
     * Recomputes the modularity.
     *
     * @return {@code true} if it increased by more than {@code tolerance}, i.e. iteration should continue
     */
    private boolean updateModularity() {
        double oldModularity = this.modularity;
        this.modularity = this.calculateModularity();
        return this.modularity > oldModularity && Math.abs(this.modularity - oldModularity) > this.tolerance;
    }

    /**
     * Computes {@code ex / (2m) - ax / (2m)^2}. Here {@code m} is {@code totalNodeWeight},
     * {@code ex} is the sum of {@code nodeCommunityInfluences}, and {@code ax} is the sum of squared
     * community weights. Throws a {@link RuntimeException} if a reduction is empty (no nodes).
     */
    private double calculateModularity() {
        double ex = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            nodeStream -> nodeStream
                .mapToDouble(this.nodeCommunityInfluences::get)
                .reduce(Double::sum)
                .orElseThrow(() -> new RuntimeException("Error while computing modularity"))
        );
        double ax = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            nodeStream -> nodeStream
                .mapToDouble(nodeId -> Math.pow(this.communityWeights.get(nodeId), 2.0))
                .reduce(Double::sum)
                .orElseThrow(() -> new RuntimeException("Error while computing modularity"))
        );
        return ex / (2.0 * this.totalNodeWeight) - ax / Math.pow(2.0 * this.totalNodeWeight, 2.0);
    }

    /**
     * Releases the intermediate arrays. {@code currentCommunities} and the reverse seed mapping are
     * kept because {@link #getCommunityId(long)} still reads them.
     */
    @Override
    public void release() {
        this.nextCommunities.release();
        this.communityWeights.release();
        this.communityWeightUpdates.release();
        this.cumulativeNodeWeights.release();
        this.nodeCommunityInfluences.release();
        this.colors.release();
        this.colorsUsed = null;
    }

    /**
     * Returns the community of {@code nodeId}. For seeded runs this is translated back into the
     * seed-space id; otherwise it is the internal community id.
     */
    public long getCommunityId(long nodeId) {
        if (this.seedProperty == null || this.reverseSeedCommunityMapping == null) {
            return this.currentCommunities.get(nodeId);
        }
        return this.reverseSeedCommunityMapping.get(this.currentCommunities.get(nodeId));
    }

    /** Returns the number of optimization iterations that were run. */
    public int getIterations() {
        return this.iterationCounter;
    }

    /** Returns the modularity after the last iteration ({@code -1.0} before {@link #compute()}). */
    public double getModularity() {
        return this.modularity;
    }

    /** Returns whether the run stopped because the modularity gain fell within {@code tolerance}. */
    public boolean didConverge() {
        return this.didConverge;
    }

    /** Exposes the community ids (as returned by {@link #getCommunityId(long)}) as node properties. */
    public LongNodeProperties asNodeProperties() {
        return new LongNodeProperties() {

            @Override
            public long longValue(long nodeId) {
                return ModularityOptimization.this.getCommunityId(nodeId);
            }

            public long size() {
                return ModularityOptimization.this.currentCommunities.size();
            }
        };
    }

    /**
     * Initializes one node partition, doing the following for each node:
     * <ul>
     *   <li>assigns the identity community when unseeded;</li>
     *   <li>sums the node's relationship weights;</li>
     *   <li>stores that sum per node and adds it to the node's community weight;</li>
     *   <li>accumulates a partition-local total.</li>
     * </ul>
     */
    private static final class InitTask implements Runnable {
        private final RelationshipIterator relationshipIterator;
        private final HugeLongArray currentCommunities;
        private final HugeAtomicDoubleArray communityWeights;
        private final HugeDoubleArray cumulativeNodeWeights;
        private final boolean isSeeded;
        private final Partition partition;
        // Sum of node weights in this partition; read via localSum() after run().
        private double localSum;

        private InitTask(RelationshipIterator relationshipIterator, HugeLongArray currentCommunities, HugeAtomicDoubleArray communityWeights, HugeDoubleArray cumulativeNodeWeights, boolean isSeeded, Partition partition) {
            this.relationshipIterator = relationshipIterator;
            this.currentCommunities = currentCommunities;
            this.communityWeights = communityWeights;
            this.cumulativeNodeWeights = cumulativeNodeWeights;
            this.isSeeded = isSeeded;
            this.partition = partition;
            this.localSum = 0.0;
        }

        @Override
        public void run() {
            MutableDouble nodeWeightAccumulator = new MutableDouble();
            this.partition.consume(nodeId -> {
                if (!this.isSeeded) {
                    this.currentCommunities.set(nodeId, nodeId);
                }
                nodeWeightAccumulator.setValue(0.0);
                // 1.0 is the fallback weight handed to forEachRelationship (presumably used when the
                // graph has no relationship property).
                this.relationshipIterator.forEachRelationship(nodeId, 1.0, (source, target, relationshipWeight) -> {
                    nodeWeightAccumulator.add(relationshipWeight);
                    return true;
                });
                double nodeWeight = nodeWeightAccumulator.doubleValue();
                this.communityWeights.update(this.currentCommunities.get(nodeId), acc -> acc + nodeWeight);
                this.cumulativeNodeWeights.set(nodeId, nodeWeight);
                this.localSum += nodeWeight;
            });
        }

        double localSum() {
            return this.localSum;
        }
    }
}

