/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.fastrp;

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableLong;
import org.jetbrains.annotations.TestOnly;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.DegreePartition;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.fastrp.FastRPBaseConfig;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.features.FeatureConsumer;
import org.neo4j.gds.ml.core.features.FeatureExtraction;
import org.neo4j.gds.ml.core.features.FeatureExtractor;
import org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations;
import org.neo4j.gds.utils.StringFormatting;

/**
 * FastRP node embedding algorithm.
 *
 * <p>Each node first receives an initial vector. The first {@code baseEmbeddingDimension}
 * entries are sparse random values. The remaining entries are a projection of the node's
 * extracted features onto random property vectors.
 *
 * <p>The initial vectors are then propagated over the relationships once per configured
 * iteration weight. In every iteration, the neighbour vectors of a node are combined, divided by
 * the node's degree and L2-normalized. The result is added to the node's final embedding, scaled
 * by that iteration's weight. When {@code nodeSelfInfluence} is non-zero, the L2-normalized
 * initial vector is also added, scaled by {@code nodeSelfInfluence}.
 */
public class FastRP extends Algorithm<FastRP.FastRPResult> {
    /** Inverse density of the random vectors: an entry is non-zero with probability {@code 1 / SPARSITY}. */
    private static final int SPARSITY = 3;
    /** Probability of drawing {@code +value}; {@code -value} is drawn with the same probability. */
    private static final double ENTRY_PROBABILITY = 0.16666666666666666;
    /** L2 norms below this threshold are treated as zero to avoid dividing by (almost) zero. */
    private static final float EPSILON = 2.9387362E-38f;

    private final Graph graph;
    private final int concurrency;
    private final float normalizationStrength;
    private final List<FeatureExtractor> featureExtractors;
    private final Optional<String> relationshipWeightProperty;
    /** NaN when a weight property is configured (so missing weights can be detected), 1.0 otherwise. */
    private final double relationshipWeightFallback;
    /** Number of scalar features produced by {@link #featureExtractors}. */
    private final int inputDimension;
    /** One random projection vector of length {@code propertyDimension} per input feature. */
    private final float[][] propertyVectors;
    /** Final embeddings, one {@code float[embeddingDimension]} per node. */
    private final HugeObjectArray<float[]> embeddings;
    /** Propagation buffer written in even iterations. */
    private final HugeObjectArray<float[]> embeddingA;
    /** Initial vectors; afterwards the propagation buffer written in odd iterations. */
    private final HugeObjectArray<float[]> embeddingB;
    private final EmbeddingCombiner embeddingCombiner;
    private final long randomSeed;
    private final int embeddingDimension;
    /** Number of leading embedding entries that are purely random (not property based). */
    private final int baseEmbeddingDimension;
    private final Number nodeSelfInfluence;
    private final List<Number> iterationWeights;
    private final int minBatchSize;
    private List<DegreePartition> partitions;

    /**
     * Estimates the memory used by the algorithm's internal state.
     *
     * @param config the algorithm configuration
     * @return estimation covering the property vectors and the three per-node embedding arrays
     */
    public static MemoryEstimation memoryEstimation(FastRPBaseConfig config) {
        // Widen before multiplying so that large configurations cannot overflow int arithmetic.
        long propertyVectorEntries = (long) config.featureProperties().size() * config.propertyDimension();
        long embeddingArraySize = MemoryUsage.sizeOfFloatArray(config.embeddingDimension());
        return MemoryEstimations.builder(FastRP.class.getSimpleName())
            .fixed("propertyVectors", MemoryUsage.sizeOfFloatArray(propertyVectorEntries))
            .add("embeddings", HugeObjectArray.memoryEstimation(embeddingArraySize))
            .add("embeddingA", HugeObjectArray.memoryEstimation(embeddingArraySize))
            .add("embeddingB", HugeObjectArray.memoryEstimation(embeddingArraySize))
            .build();
    }

    /**
     * Creates a FastRP instance whose random seed is taken from {@code config.randomSeed()}.
     *
     * @param graph             graph to embed
     * @param config            algorithm configuration
     * @param featureExtractors extractors for the node features projected into the embedding
     * @param progressTracker   receives progress updates
     */
    public FastRP(Graph graph, FastRPBaseConfig config, List<FeatureExtractor> featureExtractors, ProgressTracker progressTracker) {
        this(graph, config, featureExtractors, progressTracker, config.randomSeed());
    }

    /**
     * Creates a FastRP instance with an explicit random seed.
     *
     * @param graph             graph to embed
     * @param config            algorithm configuration
     * @param featureExtractors extractors for the node features projected into the embedding
     * @param progressTracker   receives progress updates
     * @param randomSeed        seed for all random vectors; {@link System#nanoTime()} is used when empty
     */
    public FastRP(Graph graph, FastRPBaseConfig config, List<FeatureExtractor> featureExtractors, ProgressTracker progressTracker, Optional<Long> randomSeed) {
        super(progressTracker);
        this.graph = graph;
        this.featureExtractors = featureExtractors;
        this.relationshipWeightProperty = config.relationshipWeightProperty();
        // With a weight property configured, missing weights surface as NaN and are rejected in the first iteration.
        this.relationshipWeightFallback = this.relationshipWeightProperty.map(s -> Double.NaN).orElse(1.0);
        this.inputDimension = FeatureExtraction.featureCount(featureExtractors);
        this.randomSeed = improveSeed(randomSeed.orElseGet(System::nanoTime));
        this.minBatchSize = config.minBatchSize();
        this.propertyVectors = new float[this.inputDimension][config.propertyDimension()];
        this.embeddings = HugeObjectArray.newArray(float[].class, graph.nodeCount());
        this.embeddingA = HugeObjectArray.newArray(float[].class, graph.nodeCount());
        this.embeddingB = HugeObjectArray.newArray(float[].class, graph.nodeCount());
        this.embeddingDimension = config.embeddingDimension();
        this.baseEmbeddingDimension = config.embeddingDimension() - config.propertyDimension();
        this.iterationWeights = config.iterationWeights();
        this.nodeSelfInfluence = config.nodeSelfInfluence();
        this.normalizationStrength = config.normalizationStrength();
        this.concurrency = config.concurrency();
        // Only weighted graphs pay for the per-relationship multiply-add.
        this.embeddingCombiner = graph.hasRelationshipProperty()
            ? FastRP::addArrayValuesWeighted
            : (lhs, rhs, ignoredWeight) -> FloatVectorOperations.addInPlace(lhs, rhs);
        // Must run after embeddingDimension has been assigned.
        this.embeddings.setAll(i -> new float[this.embeddingDimension]);
    }

    /**
     * Runs all phases of the algorithm.
     *
     * @return the final embeddings
     */
    @Override
    public FastRPResult compute() {
        this.progressTracker.beginSubTask();
        this.initDegreePartition();
        this.initPropertyVectors();
        this.initRandomVectors();
        this.addInitialVectorsToEmbedding();
        this.propagateEmbeddings();
        this.progressTracker.endSubTask();
        return new FastRPResult(this.embeddings);
    }

    /** Releases the two propagation buffers; the result embeddings stay available. */
    public void release() {
        this.embeddingA.release();
        this.embeddingB.release();
    }

    /** Computes the degree-based node partitions used by the embedding-addition and propagation phases. */
    public void initDegreePartition() {
        this.partitions = PartitionUtils.degreePartition(this.graph, this.concurrency, Function.identity(), Optional.of(this.minBatchSize));
    }

    /** Fills every property vector with sparse random entries derived from the algorithm seed. */
    void initPropertyVectors() {
        int propertyDimension = this.embeddingDimension - this.baseEmbeddingDimension;
        float entryValue = (float) Math.sqrt(SPARSITY) / (float) Math.sqrt(propertyDimension);
        HighQualityRandom random = new HighQualityRandom(this.randomSeed);
        // The rows were allocated with length propertyDimension in the constructor; fill them in place.
        for (float[] propertyVector : this.propertyVectors) {
            for (int d = 0; d < propertyDimension; ++d) {
                propertyVector[d] = computeRandomEntry(random, entryValue);
            }
        }
    }

    /** Computes every node's initial vector into {@code embeddingB} and zeroes {@code embeddingA}, in parallel. */
    void initRandomVectors() {
        this.progressTracker.beginSubTask();
        float sqrtEmbeddingDimension = (float) Math.sqrt(this.baseEmbeddingDimension);
        List<Runnable> tasks = PartitionUtils.rangePartition(
            this.concurrency,
            this.graph.nodeCount(),
            partition -> new InitRandomVectorTask(partition, sqrtEmbeddingDimension),
            Optional.of(this.minBatchSize)
        );
        RunWithConcurrency.builder().concurrency(this.concurrency).tasks(tasks).run();
        this.progressTracker.endSubTask();
    }

    /** Adds each node's L2-normalized initial vector, scaled by {@code nodeSelfInfluence}, to its embedding. */
    void addInitialVectorsToEmbedding() {
        // A self influence of exactly zero contributes nothing: skip the work and its progress sub task.
        if (Float.compare(this.nodeSelfInfluence.floatValue(), 0.0f) == 0) {
            return;
        }
        this.progressTracker.beginSubTask();
        List<Runnable> tasks = this.partitions.stream()
            .map(AddInitialStateToEmbeddingTask::new)
            .collect(Collectors.toList());
        RunWithConcurrency.builder().concurrency(this.concurrency).tasks(tasks).run();
        this.progressTracker.endSubTask();
    }

    /** Runs one propagation iteration per iteration weight, alternating between the two buffers. */
    void propagateEmbeddings() {
        this.progressTracker.beginSubTask();
        for (int i = 0; i < this.iterationWeights.size(); ++i) {
            this.progressTracker.beginSubTask();
            boolean evenIteration = i % 2 == 0;
            // Iteration 0 reads the initial vectors in embeddingB and writes embeddingA; odd iterations swap roles.
            HugeObjectArray<float[]> currentEmbeddings = evenIteration ? this.embeddingA : this.embeddingB;
            HugeObjectArray<float[]> previousEmbeddings = evenIteration ? this.embeddingB : this.embeddingA;
            float iterationWeight = this.iterationWeights.get(i).floatValue();
            boolean firstIteration = i == 0;
            List<Runnable> tasks = this.partitions.stream()
                .map(partition -> new PropagateEmbeddingsTask(partition, currentEmbeddings, previousEmbeddings, iterationWeight, firstIteration))
                .collect(Collectors.toList());
            RunWithConcurrency.builder().concurrency(this.concurrency).tasks(tasks).run();
            this.progressTracker.endSubTask();
        }
        this.progressTracker.endSubTask();
    }

    @TestOnly
    HugeObjectArray<float[]> currentEmbedding(int iteration) {
        return iteration % 2 == 0 ? this.embeddingA : this.embeddingB;
    }

    @TestOnly
    float[][] propertyVectors() {
        return this.propertyVectors;
    }

    @TestOnly
    HugeObjectArray<float[]> embeddings() {
        return this.embeddings;
    }

    /** Accumulates {@code lhs[i] += rhs[i] * weight} using a fused multiply-add in double precision. */
    private static void addArrayValuesWeighted(float[] lhs, float[] rhs, double weight) {
        for (int i = 0; i < lhs.length; ++i) {
            lhs[i] = (float) Math.fma(rhs[i], weight, lhs[i]);
        }
    }

    /**
     * Draws one sparse random entry.
     *
     * @return {@code entryValue} or {@code -entryValue}, each with probability {@link #ENTRY_PROBABILITY},
     *         otherwise {@code 0}
     */
    private static float computeRandomEntry(Random random, float entryValue) {
        double randomValue = random.nextDouble();
        if (randomValue < ENTRY_PROBABILITY) {
            return entryValue;
        }
        if (randomValue < 2 * ENTRY_PROBABILITY) {
            return -entryValue;
        }
        return 0.0f;
    }

    /** Scrambles the user-provided seed by drawing one value from a generator seeded with it. */
    private static long improveSeed(long randomSeed) {
        return new HighQualityRandom(randomSeed).nextLong();
    }

    /** Holds the computed embeddings, one {@code float[]} per node. */
    public static class FastRPResult {
        private final HugeObjectArray<float[]> embeddings;

        public FastRPResult(HugeObjectArray<float[]> embeddings) {
            this.embeddings = embeddings;
        }

        public HugeObjectArray<float[]> embeddings() {
            return this.embeddings;
        }
    }

    /**
     * Computes one propagation iteration for the nodes of a partition.
     *
     * <p>For each node, the neighbour vectors of the previous iteration are combined and divided
     * by the degree (at least 1), then L2-normalized. The result is added, scaled by the iteration
     * weight, to the final embedding.
     */
    private final class PropagateEmbeddingsTask implements Runnable {
        private final Partition partition;
        private final HugeObjectArray<float[]> currentEmbeddings;
        private final HugeObjectArray<float[]> previousEmbeddings;
        private final float iterationWeight;
        private final Graph concurrentGraph;
        private final boolean firstIteration;

        private PropagateEmbeddingsTask(Partition partition, HugeObjectArray<float[]> currentEmbeddings, HugeObjectArray<float[]> previousEmbeddings, float iterationWeight, boolean firstIteration) {
            this.partition = partition;
            this.currentEmbeddings = currentEmbeddings;
            this.previousEmbeddings = previousEmbeddings;
            this.iterationWeight = iterationWeight;
            // Relationships are traversed through a per-task copy obtained via concurrentCopy().
            this.concurrentGraph = graph.concurrentCopy();
            this.firstIteration = firstIteration;
        }

        @Override
        public void run() {
            MutableLong degrees = new MutableLong(0L);
            this.partition.consume(nodeId -> {
                float[] embedding = embeddings.get(nodeId);
                float[] currentEmbedding = this.currentEmbeddings.get(nodeId);
                // Discard whatever this buffer held before (initial vectors or an earlier iteration).
                Arrays.fill(currentEmbedding, 0.0f);
                this.concurrentGraph.forEachRelationship(nodeId, relationshipWeightFallback, (source, target, weight) -> {
                    // NaN can only come from the fallback, i.e. a configured weight property is missing.
                    if (this.firstIteration && Double.isNaN(weight)) {
                        throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                            "Missing relationship property `%s` on relationship between nodes with ids `%d` and `%d`.",
                            relationshipWeightProperty.orElse(""),
                            graph.toOriginalNodeId(source),
                            graph.toOriginalNodeId(target)
                        ));
                    }
                    embeddingCombiner.combine(currentEmbedding, this.previousEmbeddings.get(target), weight);
                    return true;
                });
                int degree = graph.degree(nodeId);
                int adjustedDegree = degree == 0 ? 1 : degree;
                float degreeScale = 1.0f / (float) adjustedDegree;
                FloatVectorOperations.scale(currentEmbedding, degreeScale);
                FloatVectorOperations.l2Normalize(currentEmbedding);
                FloatVectorOperations.addWeightedInPlace(embedding, currentEmbedding, this.iterationWeight);
                degrees.add(degree);
            });
            progressTracker.logProgress(degrees.longValue());
        }
    }

    /** Adds the L2-normalized initial vector (scaled by {@code nodeSelfInfluence}) to each node's embedding. */
    private final class AddInitialStateToEmbeddingTask implements Runnable {
        private final Partition partition;

        private AddInitialStateToEmbeddingTask(Partition partition) {
            this.partition = partition;
        }

        @Override
        public void run() {
            this.partition.consume(nodeId -> {
                float[] initialVector = embeddingB.get(nodeId);
                float l2Norm = FloatVectorOperations.l2Norm(initialVector);
                // An (almost) all-zero vector is left unnormalized instead of dividing by ~0.
                float adjustedL2Norm = l2Norm < EPSILON ? 1.0f : l2Norm;
                FloatVectorOperations.addWeightedInPlace(embeddings.get(nodeId), initialVector, nodeSelfInfluence.floatValue() / adjustedL2Norm);
            });
            progressTracker.logProgress(this.partition.nodeCount());
        }
    }

    /**
     * Computes the initial vectors of a partition's nodes into {@code embeddingB}.
     *
     * <p>It also allocates zeroed vectors in {@code embeddingA}. The random part of each vector is
     * derived from the algorithm seed XOR the node's original id, so it does not depend on how the
     * nodes are partitioned.
     */
    private final class InitRandomVectorTask implements Runnable {
        private final float sqrtSparsity = (float) Math.sqrt(SPARSITY);
        private final Partition partition;
        private final float sqrtEmbeddingDimension;
        private final PropertyVectorAdder propertyVectorAdder;

        private InitRandomVectorTask(Partition partition, float sqrtEmbeddingDimension) {
            this.partition = partition;
            this.sqrtEmbeddingDimension = sqrtEmbeddingDimension;
            this.propertyVectorAdder = new PropertyVectorAdder();
        }

        @Override
        public void run() {
            HighQualityRandom random = new HighQualityRandom(randomSeed);
            this.partition.consume(nodeId -> {
                int degree = graph.degree(nodeId);
                float scaling = degree == 0 ? 1.0f : (float) Math.pow(degree, normalizationStrength);
                float entryValue = scaling * this.sqrtSparsity / this.sqrtEmbeddingDimension;
                random.reseed(randomSeed ^ graph.toOriginalNodeId(nodeId));
                float[] randomVector = this.computeRandomVector(nodeId, random, entryValue);
                embeddingB.set(nodeId, randomVector);
                embeddingA.set(nodeId, new float[embeddingDimension]);
            });
            progressTracker.logProgress(this.partition.nodeCount());
        }

        /** Builds one initial vector: random base entries followed by the projected node features. */
        private float[] computeRandomVector(long nodeId, Random random, float entryValue) {
            float[] randomVector = new float[embeddingDimension];
            for (int i = 0; i < baseEmbeddingDimension; ++i) {
                randomVector[i] = computeRandomEntry(random, entryValue);
            }
            this.propertyVectorAdder.setRandomVector(randomVector);
            // NOTE(review): -1 is passed as the second argument, as in the original code; its meaning is defined by FeatureExtraction.
            FeatureExtraction.extract(nodeId, -1L, featureExtractors, this.propertyVectorAdder);
            return randomVector;
        }

        /** Adds each extracted feature value times its property vector to the property section of the current vector. */
        private class PropertyVectorAdder implements FeatureConsumer {
            private float[] randomVector;

            private PropertyVectorAdder() {
            }

            void setRandomVector(float[] randomVector) {
                this.randomVector = randomVector;
            }

            @Override
            public void acceptScalar(long ignored, int offset, double value) {
                this.addScaledPropertyVector(propertyVectors[offset], (float) value);
            }

            @Override
            public void acceptArray(long ignored, int offset, double[] values) {
                for (int j = 0; j < values.length; ++j) {
                    this.addScaledPropertyVector(propertyVectors[offset + j], (float) values[j]);
                }
            }

            /** Adds {@code scale * propertyVector} to entries {@code [baseEmbeddingDimension, embeddingDimension)}. */
            private void addScaledPropertyVector(float[] propertyVector, float scale) {
                for (int i = baseEmbeddingDimension; i < embeddingDimension; ++i) {
                    this.randomVector[i] += scale * propertyVector[i - baseEmbeddingDimension];
                }
            }
        }
    }

    /** Accumulates a neighbour vector into a target vector. */
    @FunctionalInterface
    private interface EmbeddingCombiner {
        /** Adds {@code rhs} into {@code lhs}; implementations may scale by {@code weight} or ignore it. */
        void combine(float[] lhs, float[] rhs, double weight);
    }

    /**
     * Reseedable 64-bit generator.
     *
     * <p>It combines a linear congruential step ({@code u}), an xorshift step ({@code v}) and a
     * multiply-with-carry step ({@code w}). The structure appears to follow the Numerical Recipes
     * {@code Ran} generator (unverified).
     */
    private static final class HighQualityRandom extends Random {
        private long u;
        private long v;
        private long w;

        public HighQualityRandom(long seed) {
            this.reseed(seed);
        }

        /** Resets the generator state deterministically from {@code seed}. */
        public void reseed(long seed) {
            this.v = 4101842887655102017L;
            this.w = 1L;
            this.u = seed ^ this.v;
            this.nextLong();
            this.v = this.u;
            this.nextLong();
            this.w = this.v;
            this.nextLong();
        }

        @Override
        public long nextLong() {
            this.u = this.u * 2862933555777941757L + 7046029254386353087L;
            this.v ^= this.v >>> 17;
            this.v ^= this.v << 31;
            this.v ^= this.v >>> 8;
            this.w = 4294957665L * this.w + (this.w >>> 32);
            long x = this.u ^ this.u << 21;
            x ^= x >>> 35;
            x ^= x << 4;
            return (x + this.v) ^ this.w;
        }

        /** Returns the top {@code bits} bits of {@link #nextLong()}, as required by {@link Random#next(int)}. */
        @Override
        protected int next(int bits) {
            return (int) (this.nextLong() >>> (64 - bits));
        }
    }
}

