/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongHashSet;
import java.util.Optional;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.function.LongPredicate;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.samplers.LongUniformSamplerFromRange;
import org.neo4j.gds.ml.core.samplers.RandomWalkSampler;
import org.neo4j.gds.similarity.knn.KnnSampler;

/**
 * {@link KnnSampler} that takes its samples from a random walk started at the
 * requested node and, when the walk does not yield enough distinct valid nodes,
 * fills the remainder with node ids drawn by a {@link LongUniformSamplerFromRange}
 * over {@code [0, graph.nodeCount())}.
 *
 * <p>Each instance owns a scratch set that is cleared and refilled on every
 * {@link #sample} call.
 * NOTE(review): because of that shared scratch state an instance is presumably
 * meant to be confined to one thread at a time; confirm against the callers.
 */
class RandomWalkKnnSampler implements KnnSampler {
    /** The walk length handed to the random walk sampler is this multiple of {@code k}. */
    private static final int WALK_LENGTH_MULTIPLIER = 3;

    private final RandomWalkSampler randomWalkSampler;
    private final LongUniformSamplerFromRange uniformSamplerFromRange;
    /** Exclusive upper bound of the id range used for uniform fill-up sampling (the graph's node count). */
    private final long exclusiveMax;
    /** Nodes already accepted during the current {@link #sample} call; reset at the start of each call. */
    private final LongHashSet sampledValuesCache;

    /**
     * @param graph      graph to walk on; also supplies node degrees and the node count
     * @param random     random source for the uniform fill-up sampler
     * @param randomSeed seed for the random walk sampler; when empty, a seed is drawn from a fresh {@link Random}
     * @param k          number of neighbours; must be positive (checked via {@code assert} only)
     */
    RandomWalkKnnSampler(Graph graph, SplittableRandom random, Optional<Long> randomSeed, int k) {
        assert k > 0 : "k must be positive, got " + k;
        long walkSeed = randomSeed.orElseGet(() -> new Random().nextLong());
        // NOTE(review): 0.4 / 0.6 / 1.0 are passed through unchanged; they are presumably the
        // walk-bias probabilities expected by RandomWalkSampler - confirm against its constructor.
        this.randomWalkSampler = new RandomWalkSampler(
            nodeId -> graph.degree(nodeId),
            WALK_LENGTH_MULTIPLIER * k,
            0.4,
            0.6,
            1.0,
            graph,
            walkSeed
        );
        this.uniformSamplerFromRange = new LongUniformSamplerFromRange(random);
        this.exclusiveMax = graph.nodeCount();
        this.sampledValuesCache = new LongHashSet();
    }

    /**
     * Estimates the memory needed by one sampler instance.
     *
     * <p>The result is the union of two estimates that share the same base (random walk
     * sampler for a walk of {@code boundedK * WALK_LENGTH_MULTIPLIER}, this instance, a
     * {@code long[boundedK]} and a long hash set of {@code boundedK} entries) and differ only
     * in the uniform sampler part, which is estimated once for {@code 0} and once for
     * {@code boundedK}.
     *
     * @param boundedK the (bounded) number of neighbours to size the estimate for
     * @return the combined memory range
     */
    public static MemoryRange memoryEstimation(long boundedK) {
        long fixedOverhead = MemoryUsage.sizeOfInstance(RandomWalkKnnSampler.class)
            + MemoryUsage.sizeOfLongArray(boundedK)
            + MemoryUsage.sizeOfLongHashSet(boundedK);
        MemoryRange baseEstimation = RandomWalkSampler
            .memoryEstimation(boundedK * WALK_LENGTH_MULTIPLIER)
            .add(MemoryRange.of(fixedOverhead));

        MemoryRange withZeroUniformSamples = baseEstimation.add(LongUniformSamplerFromRange.memoryEstimation(0L));
        MemoryRange withBoundedKUniformSamples = baseEstimation.add(LongUniformSamplerFromRange.memoryEstimation(boundedK));
        return withZeroUniformSamples.union(withBoundedKUniformSamples);
    }

    /**
     * Collects up to {@code numberOfSamples} distinct nodes that pass {@code isInvalidSample},
     * first from a random walk starting at {@code nodeId}, then from the uniform sampler.
     *
     * @param nodeId                          node the random walk starts from
     * @param lowerBoundOnValidSamplesInRange forwarded to the uniform sampler, reduced by the
     *                                        number of samples already taken from the walk
     * @param numberOfSamples                 length of the returned array; {@code 0} yields an empty array
     * @param isInvalidSample                 returns {@code true} for nodes that must not be sampled
     * @return an array of length {@code numberOfSamples}; walk samples come first, followed by
     *         the uniform samples
     */
    @Override
    public long[] sample(long nodeId, long lowerBoundOnValidSamplesInRange, int numberOfSamples, LongPredicate isInvalidSample) {
        long[] walk = this.randomWalkSampler.walk(nodeId);
        this.sampledValuesCache.clear();
        long[] samples = new long[numberOfSamples];
        int addedSamples = 0;

        // The first walk entry is skipped - presumably it is the start node itself (confirm
        // against RandomWalkSampler#walk). Bounding the loop by numberOfSamples also keeps a
        // request for zero samples from writing into the empty result array.
        for (int i = 1; i < walk.length && addedSamples < numberOfSamples; i++) {
            long candidate = walk[i];
            if (isInvalidSample.test(candidate) || this.sampledValuesCache.contains(candidate)) {
                continue;
            }
            this.sampledValuesCache.add(candidate);
            samples[addedSamples++] = candidate;
        }
        if (addedSamples == numberOfSamples) {
            // The walk alone satisfied the request; no uniform fill-up needed.
            return samples;
        }

        // Fill the remaining slots uniformly, rejecting nodes already taken from the walk.
        long[] uniformSamples = this.uniformSamplerFromRange.sample(
            0L,
            this.exclusiveMax,
            lowerBoundOnValidSamplesInRange - (long) addedSamples,
            numberOfSamples - addedSamples,
            node -> isInvalidSample.test(node) || this.sampledValuesCache.contains(node)
        );
        // NOTE(review): if the uniform sampler ever returns fewer ids than requested, the
        // uncopied tail of `samples` stays 0 - confirm callers cannot hit that case.
        System.arraycopy(uniformSamples, 0, samples, addedSamples, uniformSamples.length);
        return samples;
    }
}

