/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.hashgnn;

import java.util.List;
import java.util.SplittableRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.primes.Primes;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.hashgnn.HashGNNCompanion;
import org.neo4j.gds.embeddings.hashgnn.HashGNNConfig;
import org.neo4j.gds.embeddings.hashgnn.ImmutableHashes;
import org.neo4j.gds.mem.MemoryUsage;

/**
 * Precomputes the hash arrays used by one HashGNN embedding sample.
 *
 * <p>Each task owns its own {@link SplittableRandom} and, when run, fills three sets of
 * hashes: one for neighbor aggregation, one for self aggregation and one per relationship
 * type for pre-aggregation. The results are exposed through {@link #hashes()} after
 * {@link #run()} has completed.
 */
class HashTask
implements Runnable {
    /** Upper clamp applied to {@code scaledNeighborInfluence}. */
    private static final double MAX_FINAL_INFLUENCE = 10000.0;
    /** Lower clamp applied to {@code scaledNeighborInfluence}; the reciprocal of the upper clamp (1.0E-4). */
    private static final double MIN_FINAL_INFLUENCE = 1.0 / MAX_FINAL_INFLUENCE;
    /** Smallest value from which the search for the neighbor prime may start. */
    private static final int PRIME_LOWER_BOUND = 50000;
    private final int embeddingDimension;
    private final double scaledNeighborInfluence;
    private final int numberOfRelationshipTypes;
    private final SplittableRandom rng;
    // Results; populated by run() and read by hashes().
    private int[] neighborsAggregationHashes;
    private int[] selfAggregationHashes;
    private List<int[]> preAggregationHashes;
    private final ProgressTracker progressTracker;

    /**
     * @param embeddingDimension        dimension passed to the hash computation
     * @param scaledNeighborInfluence   influence factor relating the self prime to the neighbor prime
     * @param numberOfRelationshipTypes number of pre-aggregation hash arrays to generate
     * @param rng                       random source used for all draws made by this task
     * @param progressTracker           tracker that receives one logged step when the task finishes
     */
    HashTask(int embeddingDimension, double scaledNeighborInfluence, int numberOfRelationshipTypes, SplittableRandom rng, ProgressTracker progressTracker) {
        this.embeddingDimension = embeddingDimension;
        this.scaledNeighborInfluence = scaledNeighborInfluence;
        this.numberOfRelationshipTypes = numberOfRelationshipTypes;
        this.rng = rng;
        this.progressTracker = progressTracker;
    }

    /**
     * Creates one task per {@code config.embeddingDensity()}, runs them with the configured
     * concurrency and collects their results.
     *
     * <p>Task {@code i} is seeded with {@code randomSeed + i}, so the result only depends on
     * the arguments and not on how the tasks are scheduled.
     *
     * @return one {@link Hashes} per task, in task-creation order
     */
    public static List<Hashes> compute(int embeddingDimension, double scaledNeighborInfluence, int numberOfRelationshipTypes, HashGNNConfig config, long randomSeed, TerminationFlag terminationFlag, ProgressTracker progressTracker) {
        progressTracker.beginSubTask("Precompute hashes");
        progressTracker.setSteps((long) config.embeddingDensity());

        // Typed list (not raw) so that the HashTask::hashes method reference below resolves.
        List<HashTask> hashTasks = IntStream.range(0, config.embeddingDensity())
            .mapToObj(seedOffset -> new HashTask(
                embeddingDimension,
                scaledNeighborInfluence,
                numberOfRelationshipTypes,
                new SplittableRandom(randomSeed + (long) seedOffset),
                progressTracker
            ))
            .collect(Collectors.toList());

        RunWithConcurrency.builder()
            .concurrency(config.concurrency())
            .tasks(hashTasks)
            .terminationFlag(terminationFlag)
            .run();

        progressTracker.endSubTask("Precompute hashes");
        return hashTasks.stream().map(HashTask::hashes).collect(Collectors.toList());
    }

    /**
     * Draws the primes and generates all hash arrays for this task, then logs one step on
     * the progress tracker.
     */
    @Override
    public void run() {
        // Clamp the influence into [MIN_FINAL_INFLUENCE, MAX_FINAL_INFLUENCE].
        double finalInfluence = Math.max(MIN_FINAL_INFLUENCE, Math.min(MAX_FINAL_INFLUENCE, this.scaledNeighborInfluence));

        // Exclusive upper bound for the prime seed. Presumably chosen so that
        // neighborPrime * finalInfluence (used for selfPrime below) still fits in an int;
        // the 1.001 factor leaves a small margin.
        int primeSeedUpperBound = (int) Math.round(Integer.MAX_VALUE / (Math.max(1.0, finalInfluence) * 1.001));
        int primeSeed = this.rng.nextInt(PRIME_LOWER_BOUND, primeSeedUpperBound);
        int neighborPrime = Primes.nextPrime(primeSeed);

        // With an influence of exactly 1.0 both aggregations share the same prime;
        // otherwise the self prime is the next prime at or above the scaled neighbor prime.
        int selfPrime = Double.compare(this.scaledNeighborInfluence, 1.0) == 0
            ? neighborPrime
            : Primes.nextPrime((int) Math.round((double) neighborPrime * finalInfluence));

        // The order of the calls below fixes the order in which the rng is consumed.
        this.neighborsAggregationHashes = HashGNNCompanion.HashTriple.computeHashesFromTriple(
            this.embeddingDimension,
            HashGNNCompanion.HashTriple.generate(this.rng, neighborPrime)
        );
        this.selfAggregationHashes = HashGNNCompanion.HashTriple.computeHashesFromTriple(
            this.embeddingDimension,
            HashGNNCompanion.HashTriple.generate(this.rng, selfPrime)
        );
        this.preAggregationHashes = IntStream.range(0, this.numberOfRelationshipTypes)
            .mapToObj(unused -> HashGNNCompanion.HashTriple.computeHashesFromTriple(
                this.embeddingDimension,
                HashGNNCompanion.HashTriple.generate(this.rng)
            ))
            .collect(Collectors.toList());

        this.progressTracker.logSteps(1L);
    }

    /**
     * Wraps the current values of the three hash fields; these are only set once
     * {@link #run()} has executed.
     */
    Hashes hashes() {
        return ImmutableHashes.of(this.neighborsAggregationHashes, this.selfAggregationHashes, this.preAggregationHashes);
    }

    /** Value holder for the hashes produced by a single {@link HashTask}. */
    @ValueClass
    interface Hashes {
        int[] neighborsAggregationHashes();

        int[] selfAggregationHashes();

        List<int[]> preAggregationHashes();

        /**
         * Estimates the memory taken by one {@code Hashes} instance.
         *
         * @param ambientDimension length used for each hash array in the estimate
         * @param numRelTypes      number of pre-aggregation hash arrays
         * @return the summed estimate for both aggregation arrays, the pre-aggregation
         *         list with its arrays, and the instance itself
         */
        static long memoryEstimation(int ambientDimension, int numRelTypes) {
            // Both aggregation hash members are plain int[] and are therefore sized as int arrays.
            long neighborAggregation = MemoryUsage.sizeOfIntArray((long) ambientDimension);
            long selfAggregation = MemoryUsage.sizeOfIntArray((long) ambientDimension);
            long preAggregation = MemoryUsage.sizeOfIntArrayList((long) numRelTypes)
                + MemoryUsage.sizeOfIntArray((long) ambientDimension) * (long) numRelTypes;
            return neighborAggregation + selfAggregation + preAggregation + MemoryUsage.sizeOfInstance(Hashes.class);
        }
    }
}

