/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.node2vec;

import java.util.stream.BaseStream;
import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.embeddings.node2vec.ImmutableRandomWalkProbabilities;

/**
 * Per-node sampling statistics derived from a set of random walks.
 *
 * <p>Instances are assembled by {@link Builder}, which counts how often each node occurs
 * in the registered walks and derives positive and negative sampling tables from those counts.
 */
@ValueClass
interface RandomWalkProbabilities {

    /**
     * Number of times each node occurred across all registered walks, indexed by node id.
     */
    HugeLongArray nodeFrequencies();

    /**
     * Per-node positive sampling value.
     *
     * <p>It is computed as {@code (sqrt(f / t) + 1) * (t / f)}, where {@code f} is the node's
     * relative frequency and {@code t} is the positive sampling factor. Values can exceed 1
     * for rarely visited nodes.
     */
    HugeDoubleArray positiveSamplingProbabilities();

    /**
     * Cumulative (prefix-sum) weights for negative sampling, indexed by node id.
     *
     * <p>Entry {@code i} holds the sum over nodes {@code 0..i} of
     * {@code floor(frequency^negativeSamplingExponent)}, so the last entry is the total weight.
     */
    HugeLongArray negativeSamplingDistribution();

    /**
     * Total number of node occurrences registered, i.e. the sum of all walk lengths.
     */
    long sampleCount();

    /**
     * Memory estimation covering the three per-node arrays held by this value.
     */
    static MemoryEstimation memoryEstimation() {
        return MemoryEstimations.builder(RandomWalkProbabilities.class)
            .perNode("node frequencies", HugeLongArray::memoryEstimation)
            .perNode("positive sampling probabilities", HugeDoubleArray::memoryEstimation)
            .perNode("negative sampling distribution", HugeLongArray::memoryEstimation)
            .build();
    }

    /**
     * Accumulates node occurrence counts from walks and builds a {@link RandomWalkProbabilities}.
     */
    public static class Builder {
        private final long nodeCount;
        private final int concurrency;
        private final double positiveSamplingFactor;
        private final double negativeSamplingExponent;
        private final HugeLongArray nodeFrequencies;
        private final MutableLong sampleCount;

        /**
         * @param nodeCount                number of nodes; sizes every per-node array
         * @param positiveSamplingFactor   the factor {@code t} used in the positive sampling formula
         * @param negativeSamplingExponent exponent applied to node frequencies for negative sampling
         * @param concurrency              parallelism used when computing positive sampling values
         */
        Builder(long nodeCount, double positiveSamplingFactor, double negativeSamplingExponent, int concurrency) {
            this.nodeCount = nodeCount;
            this.concurrency = concurrency;
            this.positiveSamplingFactor = positiveSamplingFactor;
            this.negativeSamplingExponent = negativeSamplingExponent;
            this.nodeFrequencies = HugeLongArray.newArray(nodeCount);
            this.sampleCount = new MutableLong(0L);
        }

        /**
         * Counts every node occurrence in {@code walk} and adds the walk length to the sample count.
         *
         * <p>NOTE(review): this method performs no synchronization of its own. It presumably
         * expects a single calling thread; confirm against the callers.
         *
         * @param walk node ids visited by one random walk
         * @return this builder, for chaining
         */
        Builder registerWalk(long[] walk) {
            for (long node : walk) {
                this.nodeFrequencies.addTo(node, 1L);
            }
            this.sampleCount.add((long) walk.length);
            return this;
        }

        /**
         * Derives the sampling tables from the registered walks and returns the immutable result.
         * The returned value shares this builder's frequency array rather than copying it.
         */
        RandomWalkProbabilities build() {
            HugeDoubleArray centerProbabilities = this.computePositiveSamplingProbabilities();
            HugeLongArray contextDistribution = this.computeNegativeSamplingDistribution();
            return ImmutableRandomWalkProbabilities.builder()
                .nodeFrequencies(this.nodeFrequencies)
                .positiveSamplingProbabilities(centerProbabilities)
                .negativeSamplingDistribution(contextDistribution)
                .sampleCount(this.sampleCount.longValue())
                .build();
        }

        /**
         * Computes {@code (sqrt(f / t) + 1) * (t / f)} for every node in parallel.
         *
         * <p>Here {@code f = frequency / sampleCount} and {@code t = positiveSamplingFactor}.
         * A node that never occurred gets {@code Infinity}. If no samples were registered at all,
         * every value is {@code NaN}, because the computation evaluates 0.0 / 0.0.
         */
        private HugeDoubleArray computePositiveSamplingProbabilities() {
            HugeDoubleArray centerProbabilities = HugeDoubleArray.newArray(this.nodeCount);
            // Loop-invariant denominator, read once as a primitive to avoid per-node unboxing.
            double totalSamples = this.sampleCount.doubleValue();
            ParallelUtil.parallelStreamConsume(
                LongStream.range(0L, this.nodeCount),
                this.concurrency,
                nodeStream -> nodeStream.forEach(nodeId -> {
                    double frequency = this.nodeFrequencies.get(nodeId) / totalSamples;
                    centerProbabilities.set(
                        nodeId,
                        (Math.sqrt(frequency / this.positiveSamplingFactor) + 1.0)
                            * (this.positiveSamplingFactor / frequency)
                    );
                })
            );
            return centerProbabilities;
        }

        /**
         * Builds the cumulative negative-sampling table.
         *
         * <p>Entry {@code i} is the running sum of {@code floor(frequency^negativeSamplingExponent)}
         * over nodes {@code 0..i}.
         *
         * @throws ArithmeticException if the running sum overflows a {@code long}
         */
        private HugeLongArray computeNegativeSamplingDistribution() {
            HugeLongArray contextDistribution = HugeLongArray.newArray(this.nodeCount);
            long sum = 0L;
            for (long nodeId = 0L; nodeId < this.nodeCount; ++nodeId) {
                // Each node contributes its weight exactly once. The previous code added it
                // twice, which doubled every cumulative entry.
                long weight = (long) Math.pow(this.nodeFrequencies.get(nodeId), this.negativeSamplingExponent);
                sum = Math.addExact(sum, weight);
                contextDistribution.set(nodeId, sum);
            }
            return contextDistribution;
        }
    }
}

