/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import com.carrotsearch.hppc.LongHashSet;
import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.ImmutableRelationshipCursor;
import org.neo4j.gds.api.RelationshipCursor;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;

/**
 * Builds "extended" training batches for GraphSAGE: for every node of a batch it
 * picks one node reached by a short random walk and one negative-sample candidate,
 * and returns all three groups in a single id array.
 */
public final class BatchSampler {
    /** Exponent applied to a node's degree when weighting negative-sample candidates. */
    public static final double DEGREE_SMOOTHING_FACTOR = 0.75;
    private final Graph graph;

    BatchSampler(Graph graph) {
        this.graph = graph;
    }

    /**
     * Partitions the node id range into batches and extends each of them via
     * {@link #sampleNeighborAndNegativeNodePerBatchNode(Partition, int, long)}.
     *
     * @param batchSize   number of nodes per batch, passed on to the range partitioner
     * @param searchDepth exclusive upper bound minus one of the random-walk length; must be positive
     * @param randomSeed  base seed; each batch derives its own seed from it
     * @return one id array per batch
     */
    List<long[]> extendedBatches(int batchSize, int searchDepth, long randomSeed) {
        // Guard the divisor so that deriving the seed can never introduce a division by zero.
        long seedDivisor = Math.max(batchSize, 1);
        return PartitionUtils.rangePartitionWithBatchSize(this.graph.nodeCount(), batchSize, batch -> {
            // Offset the seed by the batch index. Dividing the start node by the node count
            // (as done previously) always yields 0, which handed the identical seed - and
            // therefore the identical random sequence - to every batch.
            // NOTE(review): assumes batches start at multiples of batchSize - confirm against PartitionUtils.
            long localSeed = Math.floorDiv(batch.startNode(), seedDivisor) + randomSeed;
            return this.sampleNeighborAndNegativeNodePerBatchNode(batch, searchDepth, localSeed);
        });
    }

    /**
     * Extends a single batch.
     *
     * @param batch       the node range to extend
     * @param searchDepth maximum random-walk length used for neighbor sampling
     * @param randomSeed  seed used for both the neighbor and the negative sampling of this batch
     * @return the concatenation of the batch's node ids, the sampled neighbors and the negative samples, in that order
     */
    long[] sampleNeighborAndNegativeNodePerBatchNode(Partition batch, int searchDepth, long randomSeed) {
        long[] neighbours = this.neighborBatch(batch, randomSeed, searchDepth);
        LongStream negativeSamples = this.negativeBatch(Math.toIntExact(batch.nodeCount()), neighbours, randomSeed);
        return LongStream.concat(
            batch.stream(),
            LongStream.concat(Arrays.stream(neighbours), negativeSamples)
        ).toArray();
    }

    /**
     * For every node of the batch, performs a random walk of 1 to {@code searchDepth}
     * steps and records the node the walk ends on. A walk stops early when it reaches a
     * node of degree 0, so a node without relationships is recorded as its own neighbor.
     *
     * @param batch          the node range; node ids are taken as {@code startNode + index}
     * @param batchLocalSeed seed of the random generator used for this batch
     * @param searchDepth    maximum walk length; must be positive
     * @return one walk end node per batch node, in batch order
     * @throws IllegalArgumentException if {@code searchDepth} is not positive
     */
    long[] neighborBatch(Partition batch, long batchLocalSeed, int searchDepth) {
        int batchNodeCount = Math.toIntExact(batch.nodeCount());
        long[] neighbors = new long[batchNodeCount];
        SplittableRandom localRandom = new SplittableRandom(batchLocalSeed);
        long batchOffset = batch.startNode();
        for (int idx = 0; idx < batchNodeCount; ++idx) {
            long currentNode = batchOffset + idx;
            // The walk length is drawn before any step so the order of random draws stays fixed.
            int walkLength = localRandom.nextInt(searchDepth) + 1;
            for (int step = 0; step < walkLength; ++step) {
                int degree = this.graph.degree(currentNode);
                if (degree == 0) {
                    // Dead end: keep the node reached so far.
                    break;
                }
                int sampledIdx = localRandom.nextInt(degree);
                long nextNode = this.graph.nthTarget(currentNode, sampledIdx);
                assert (nextNode != -1L) : "The offset '" + sampledIdx + "' is bound by the degree but no target could be found for nodeId " + currentNode;
                currentNode = nextNode;
            }
            neighbors[idx] = currentNode;
        }
        return neighbors;
    }

    /**
     * Requests {@code batchSize} negative samples from all nodes of the graph, each node
     * weighted by {@code degree ^ DEGREE_SMOOTHING_FACTOR}; nodes contained in
     * {@code batchNeighbors} are rejected by the filter handed to the sampler.
     *
     * @param batchSize            number of samples requested from the sampler
     * @param batchNeighbors       node ids that must not be returned as negative samples
     * @param batchLocalRandomSeed seed of the sampler
     * @return the stream produced by the sampler
     */
    LongStream negativeBatch(int batchSize, long[] batchNeighbors, long batchLocalRandomSeed) {
        long nodeCount = this.graph.nodeCount();
        WeightedUniformSampler sampler = new WeightedUniformSampler(batchLocalRandomSeed);
        LongHashSet neighborsSet = new LongHashSet(batchNeighbors.length);
        neighborsSet.addAll(batchNeighbors);
        // Every node becomes a cursor (source 0 -> nodeId) whose third component is the smoothed degree.
        Stream<RelationshipCursor> degreeWeightedNodes = LongStream.range(0L, nodeCount)
            .mapToObj(nodeId -> ImmutableRelationshipCursor.of(
                0L,
                nodeId,
                Math.pow(this.graph.degree(nodeId), DEGREE_SMOOTHING_FACTOR)
            ));
        return sampler.sample(degreeWeightedNodes, nodeCount, batchSize, sample -> !neighborsSet.contains(sample));
    }
}

