/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.leiden;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.leiden.RefinementBetweenRelationshipCounter;
import org.neo4j.gds.mem.MemoryUsage;

/**
 * Refinement step of the Leiden implementation in this package.
 *
 * <p>Starting from a partition in which every node is its own community, {@link #run()} visits the
 * nodes of the working graph once and may move a node that is still alone into the refined
 * community of one of its neighbours, provided both belong to the same original community and
 * pass the {@link #isWellConnected(long)} check.
 *
 * <p>The node loop in {@link #run()} is sequential and mutates shared scratch arrays as well as
 * {@link #communityCounter}; only the initial relationship counting is handed to the executor.
 */
final class RefinementPhase {
    private final Graph workingGraph;
    /** Community of every node before refinement; merges never cross these communities. */
    private final HugeLongArray originalCommunities;
    private final HugeDoubleArray nodeVolumes;
    /** Volumes indexed by original community id. */
    private final HugeDoubleArray communityVolumes;
    /** Volumes indexed by refined community id; initialised from {@link #nodeVolumes}. */
    private final HugeDoubleArray communityVolumesAfterMerge;
    /** Factor applied to the volume product in the gain and in the well-connectedness threshold. */
    private final double gamma;
    /** Divisor of the gain inside the exponential that produces the selection weights. */
    private final double theta;
    /**
     * Per node / refined community value compared against the threshold in
     * {@link #isWellConnected(long)}. Filled by {@link RefinementBetweenRelationshipCounter}
     * tasks — presumably the relationship weight towards the rest of the original community;
     * confirm against that class.
     */
    private final HugeDoubleArray relationshipsBetweenCommunities;
    /** Slots {@code [0, communityCounter)} hold the candidate communities of the current node. */
    private final HugeLongArray encounteredCommunities;
    /**
     * Indexed by refined community id. A value whose sign bit is set means "not encountered for
     * the current node"; otherwise it is the relationship weight accumulated so far. The scheme
     * assumes non-negative relationship weights.
     */
    private final HugeDoubleArray encounteredCommunitiesWeights;
    private final long seed;
    /** Number of valid entries in {@link #encounteredCommunities} for the node being processed. */
    private long communityCounter = 0L;
    private final int concurrency;
    private final ExecutorService executorService;
    /** Slot {@code c} holds the selection weight of {@code encounteredCommunities[c]}. */
    private final HugeDoubleArray nextCommunityProbabilities;
    private final ProgressTracker progressTracker;

    /**
     * Allocates the per-node scratch arrays and builds a refinement phase for the given graph.
     *
     * @param workingGraph        graph whose nodes are refined
     * @param originalCommunities community of every node before refinement
     * @param nodeVolumes         volume of every node
     * @param communityVolumes    volume of every original community
     * @param gamma               factor applied to volume products
     * @param theta               divisor of the gain when computing selection weights
     * @param seed                seed of the random generator used for community selection
     * @param concurrency         concurrency used for the initial relationship counting
     * @param executorService     executor running the counting tasks
     * @param progressTracker     receives one progress tick per visited node
     * @return a ready-to-run refinement phase
     */
    static RefinementPhase create(
        Graph workingGraph,
        HugeLongArray originalCommunities,
        HugeDoubleArray nodeVolumes,
        HugeDoubleArray communityVolumes,
        double gamma,
        double theta,
        long seed,
        int concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        long nodeCount = workingGraph.nodeCount();
        HugeLongArray encounteredCommunities = HugeLongArray.newArray(nodeCount);
        // The "not encountered" sentinel is written by the constructor; filling it here as well
        // would only repeat an O(nodeCount) pass.
        HugeDoubleArray encounteredCommunitiesWeights = HugeDoubleArray.newArray(nodeCount);
        HugeDoubleArray nextCommunityProbabilities = HugeDoubleArray.newArray(nodeCount);
        return new RefinementPhase(
            workingGraph,
            originalCommunities,
            nodeVolumes,
            communityVolumes,
            encounteredCommunities,
            encounteredCommunitiesWeights,
            nextCommunityProbabilities,
            gamma,
            theta,
            seed,
            concurrency,
            executorService,
            progressTracker
        );
    }

    private RefinementPhase(
        Graph workingGraph,
        HugeLongArray originalCommunities,
        HugeDoubleArray nodeVolumes,
        HugeDoubleArray communityVolumes,
        HugeLongArray encounteredCommunities,
        HugeDoubleArray encounteredCommunitiesWeights,
        HugeDoubleArray nextCommunityProbabilities,
        double gamma,
        double theta,
        long seed,
        int concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        this.workingGraph = workingGraph;
        this.originalCommunities = originalCommunities;
        this.nodeVolumes = nodeVolumes;
        // Every node starts in its own refined community, so refined volumes start as node volumes.
        this.communityVolumesAfterMerge = nodeVolumes.copyOf(nodeVolumes.size());
        this.communityVolumes = communityVolumes;
        this.encounteredCommunities = encounteredCommunities;
        this.encounteredCommunitiesWeights = encounteredCommunitiesWeights;
        this.nextCommunityProbabilities = nextCommunityProbabilities;
        this.gamma = gamma;
        this.theta = theta;
        this.seed = seed;
        // Mark every community as "not encountered".
        this.encounteredCommunitiesWeights.setAll(c -> -1.0);
        this.relationshipsBetweenCommunities = HugeDoubleArray.newArray(workingGraph.nodeCount());
        this.concurrency = concurrency;
        this.executorService = executorService;
        this.progressTracker = progressTracker;
    }

    /**
     * Describes the per-node memory components accounted for this phase.
     *
     * @return the memory estimation tree for {@code RefinementPhase}
     */
    static MemoryEstimation memoryEstimation() {
        return MemoryEstimations.builder(RefinementPhase.class)
            .perNode("encountered communities", HugeLongArray::memoryEstimation)
            .perNode("encountered community weights", HugeDoubleArray::memoryEstimation)
            .perNode("next community probabilities", HugeDoubleArray::memoryEstimation)
            .perNode("merged community volumes", HugeDoubleArray::memoryEstimation)
            .perNode("relationships between communities", HugeDoubleArray::memoryEstimation)
            .perNode("refined communities", HugeLongArray::memoryEstimation)
            .perNode("merge tracking bitset", MemoryUsage::sizeOfBitset)
            .build();
    }

    /**
     * Runs the refinement: counts relationships in parallel, then visits every node once and
     * tries to merge it while it is still a singleton and well connected.
     *
     * @return the refined community of every node, the refined community volumes and the
     *         largest refined community id that is assigned to a node ({@code -1} if the graph
     *         has no nodes)
     */
    RefinementPhaseResult run() {
        long nodeCount = this.workingGraph.nodeCount();
        HugeLongArray refinedCommunities = HugeLongArray.newArray(nodeCount);
        // Singleton partition: community id == node id.
        refinedCommunities.setAll(nodeId -> nodeId);
        this.computeRelationshipsBetweenCommunities();
        // Bit i stays set while refined community i has not received another node.
        BitSet singleton = new BitSet(nodeCount);
        singleton.set(0L, nodeCount);
        Random random = new Random(this.seed);
        MutableLong maximumCommunityId = new MutableLong(-1L);
        this.workingGraph.forEachNode(nodeId -> {
            boolean isSingleton = singleton.get(nodeId);
            if (isSingleton && this.isWellConnected(nodeId)) {
                this.mergeNodeSubset(nodeId, refinedCommunities, singleton, random);
            }
            long refinedId = refinedCommunities.get(nodeId);
            if (maximumCommunityId.longValue() < refinedId) {
                maximumCommunityId.setValue(refinedId);
            }
            this.progressTracker.logProgress();
            return true;
        });
        return new RefinementPhaseResult(refinedCommunities, this.communityVolumesAfterMerge, maximumCommunityId.longValue());
    }

    /**
     * Fills {@link #relationshipsBetweenCommunities} by running one
     * {@link RefinementBetweenRelationshipCounter} per degree partition on the executor.
     */
    private void computeRelationshipsBetweenCommunities() {
        List<RefinementBetweenRelationshipCounter> tasks = PartitionUtils.degreePartition(
            this.workingGraph,
            this.concurrency,
            degreePartition -> new RefinementBetweenRelationshipCounter(
                this.workingGraph.concurrentCopy(),
                this.relationshipsBetweenCommunities,
                this.originalCommunities,
                (Partition) degreePartition
            ),
            Optional.empty()
        );
        RunWithConcurrency.builder()
            .concurrency(this.concurrency)
            .tasks(tasks)
            .executor(this.executorService)
            .run();
    }

    /**
     * Tries to move {@code nodeId} out of its singleton community into one of the candidate
     * communities collected by {@link #computeCommunityInformation(long, HugeLongArray)}.
     *
     * <p>Each candidate gets the gain {@code weight - nodeVolume * communityVolume * gamma}.
     * Candidates with a non-negative gain receive the selection weight {@code exp(gain / theta)}
     * and one of them is drawn at random. If the weights sum to infinity or to nothing usable,
     * the candidate with the largest strictly positive gain is taken instead, if there is one.
     *
     * @param nodeId             node to merge
     * @param refinedCommunities refined community per node, updated on a move
     * @param singleton          singleton flags per refined community, updated on a move
     * @param random             generator used for the random draw
     */
    private void mergeNodeSubset(long nodeId, HugeLongArray refinedCommunities, BitSet singleton, Random random) {
        this.communityCounter = 0L;
        this.computeCommunityInformation(nodeId, refinedCommunities);
        if (this.communityCounter == 0L) {
            // No eligible neighbouring community.
            return;
        }
        long currentNodeCommunityId = refinedCommunities.get(nodeId);
        double currentNodeVolume = this.nodeVolumes.get(nodeId);
        double probabilitiesSum = 0.0;
        double bestGain = 0.0;
        long bestCommunityId = 0L;
        double totalSumOfRelationships = 0.0;
        for (long c = 0L; c < this.communityCounter; ++c) {
            long candidateCommunityId = this.encounteredCommunities.get(c);
            double communityRelationshipsCount = this.encounteredCommunitiesWeights.get(candidateCommunityId);
            totalSumOfRelationships += communityRelationshipsCount;
            // Negating sets the sign bit again ("not encountered" for the next node) while
            // keeping the magnitude readable for addToCommunity. A zero total becomes -0.0,
            // which isUnencountered recognises.
            this.encounteredCommunitiesWeights.set(candidateCommunityId, -communityRelationshipsCount);
            double modularityGain = communityRelationshipsCount
                - currentNodeVolume * this.communityVolumesAfterMerge.get(candidateCommunityId) * this.gamma;
            if (modularityGain > bestGain) {
                bestGain = modularityGain;
                bestCommunityId = candidateCommunityId;
            }
            double nextCommunityProbability = modularityGain >= 0.0 ? Math.exp(modularityGain / this.theta) : 0.0;
            this.nextCommunityProbabilities.set(c, nextCommunityProbability);
            probabilitiesSum += nextCommunityProbability;
        }
        long nextCommunityId = currentNodeCommunityId;
        if (Double.isInfinite(probabilitiesSum) || probabilitiesSum <= 0.0) {
            // The weights cannot be sampled from: fall back to the best strictly positive gain.
            if (bestGain > 0.0) {
                nextCommunityId = bestCommunityId;
            }
        } else {
            nextCommunityId = this.selectRandomCommunity(this.nextCommunityProbabilities, probabilitiesSum, random, nextCommunityId);
        }
        if (nextCommunityId != currentNodeCommunityId) {
            this.addToCommunity(nodeId, refinedCommunities, singleton, currentNodeCommunityId, totalSumOfRelationships, nextCommunityId);
        }
    }

    /**
     * Draws one of the current candidate communities with probability proportional to its
     * selection weight.
     *
     * @param nextCommunityProbabilities selection weight per candidate slot
     * @param probabilitiesSum           sum of all selection weights
     * @param random                     generator supplying the uniform draw
     * @param defaultCommunity           returned when no cumulative weight reaches the draw
     * @return the selected community id, or {@code defaultCommunity}
     */
    private long selectRandomCommunity(HugeDoubleArray nextCommunityProbabilities, double probabilitiesSum, Random random, long defaultCommunity) {
        double x = probabilitiesSum * random.nextDouble();
        assert (x >= 0.0);
        double cumulative = 0.0;
        for (long c = 0L; c < this.communityCounter; ++c) {
            cumulative += nextCommunityProbabilities.get(c);
            if (x <= cumulative) {
                return this.encounteredCommunities.get(c);
            }
        }
        return defaultCommunity;
    }

    /**
     * Moves {@code nodeId} into {@code nextCommunityId} and updates the bookkeeping: the target
     * is no longer a singleton, the node's volume moves between the two refined communities and
     * the target's relationship count grows by the weight towards all other candidates.
     *
     * @param nodeId                  node being moved
     * @param refinedCommunities      refined community per node
     * @param singleton               singleton flags per refined community
     * @param currentNodeCommunityId  refined community the node leaves
     * @param totalSumOfRelationships weight from the node to all candidate communities
     * @param nextCommunityId         refined community the node joins
     */
    private void addToCommunity(long nodeId, HugeLongArray refinedCommunities, BitSet singleton, long currentNodeCommunityId, double totalSumOfRelationships, long nextCommunityId) {
        refinedCommunities.set(nodeId, nextCommunityId);
        if (singleton.get(nextCommunityId)) {
            singleton.flip(nextCommunityId);
        }
        double nodeVolume = this.nodeVolumes.get(nodeId);
        this.communityVolumesAfterMerge.addTo(nextCommunityId, nodeVolume);
        this.communityVolumesAfterMerge.addTo(currentNodeCommunityId, -nodeVolume);
        // mergeNodeSubset stored the weight negated, hence the absolute value.
        double externalEdgesWithNewCommunity = Math.abs(this.encounteredCommunitiesWeights.get(nextCommunityId));
        this.relationshipsBetweenCommunities.addTo(nextCommunityId, totalSumOfRelationships - externalEdgesWithNewCommunity);
    }

    /**
     * Collects the candidate communities of {@code nodeId}: the refined communities of its
     * neighbours that share the node's original community and are well connected. Candidates are
     * appended to {@link #encounteredCommunities} (advancing {@link #communityCounter}) and the
     * relationship weight towards each one is summed in {@link #encounteredCommunitiesWeights}.
     *
     * @param nodeId             node whose relationships are scanned
     * @param refinedCommunities refined community per node
     */
    private void computeCommunityInformation(long nodeId, HugeLongArray refinedCommunities) {
        long originalCommunityId = this.originalCommunities.get(nodeId);
        this.workingGraph.forEachRelationship(nodeId, 1.0, (s, t, relationshipWeight) -> {
            if (this.originalCommunities.get(t) != originalCommunityId) {
                return true;
            }
            long tCommunity = refinedCommunities.get(t);
            if (!this.isWellConnected(tCommunity)) {
                return true;
            }
            if (isUnencountered(this.encounteredCommunitiesWeights.get(tCommunity))) {
                this.encounteredCommunities.set(this.communityCounter, tCommunity);
                ++this.communityCounter;
                // Store +0.0 for a signed-zero weight so the sign bit only ever means
                // "not encountered".
                this.encounteredCommunitiesWeights.set(tCommunity, relationshipWeight == 0.0 ? 0.0 : relationshipWeight);
            } else {
                this.encounteredCommunitiesWeights.addTo(tCommunity, relationshipWeight);
            }
            return true;
        });
    }

    /**
     * Tells whether a stored weight carries the "not encountered" marker.
     *
     * <p>A plain {@code < 0.0} test misses {@code -0.0}, which is what a zero total becomes when
     * {@link #mergeNodeSubset} negates it; such a community would never be registered again and
     * its slot would keep accumulating weight across nodes. {@link Double#compare(double, double)}
     * orders {@code -0.0} before {@code 0.0} and therefore catches it.
     *
     * @param storedWeight value read from {@link #encounteredCommunitiesWeights}
     * @return {@code true} if the community has not been encountered for the current node
     */
    private static boolean isUnencountered(double storedWeight) {
        return Double.compare(storedWeight, 0.0) < 0;
    }

    /**
     * Checks {@code relationshipsBetweenCommunities[id] >= gamma * v * (V - v)}, where {@code v}
     * is the refined volume at {@code id} and {@code V} the volume of the original community of
     * {@code id}.
     *
     * <p>The argument is either a node id or a refined community id; refined community ids are
     * node ids because the refined partition starts as the identity and a move only ever adopts
     * the community of a neighbour.
     *
     * @param nodeOrCommunityId node id or refined community id
     * @return {@code true} if the inequality holds
     */
    private boolean isWellConnected(long nodeOrCommunityId) {
        long originalCommunityId = this.originalCommunities.get(nodeOrCommunityId);
        double originalCommunityVolume = this.communityVolumes.get(originalCommunityId);
        double updatedCommunityVolume = this.communityVolumesAfterMerge.get(nodeOrCommunityId);
        double rightSide = this.gamma * updatedCommunityVolume * (originalCommunityVolume - updatedCommunityVolume);
        return this.relationshipsBetweenCommunities.get(nodeOrCommunityId) >= rightSide;
    }

    /** Immutable holder for the outcome of {@link RefinementPhase#run()}. */
    static class RefinementPhaseResult {
        private final HugeLongArray communities;
        private final HugeDoubleArray communityVolumes;
        private final long maximumRefinedCommunityId;

        /**
         * @param communities               refined community per node
         * @param communityVolumes          volume per refined community
         * @param maximumRefinedCommunityId largest refined community id assigned to a node
         */
        RefinementPhaseResult(HugeLongArray communities, HugeDoubleArray communityVolumes, long maximumRefinedCommunityId) {
            this.communities = communities;
            this.communityVolumes = communityVolumes;
            this.maximumRefinedCommunityId = maximumRefinedCommunityId;
        }

        /** @return refined community per node */
        HugeLongArray communities() {
            return this.communities;
        }

        /** @return volume per refined community */
        HugeDoubleArray communityVolumes() {
            return this.communityVolumes;
        }

        /** @return largest refined community id assigned to a node */
        long maximumRefinedCommunityId() {
            return this.maximumRefinedCommunityId;
        }
    }
}

