/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.node2vec;

import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayList;
import java.util.List;
import java.util.PrimitiveIterator;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.function.BiConsumer;
import java.util.function.LongUnaryOperator;
import java.util.function.ObjDoubleConsumer;
import java.util.function.Supplier;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.collection.primitive.PrimitiveLongCollections;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.embeddings.node2vec.CompressedRandomWalks;
import org.neo4j.gds.embeddings.node2vec.ImmutableResult;
import org.neo4j.gds.embeddings.node2vec.NegativeSampleProducer;
import org.neo4j.gds.embeddings.node2vec.Node2Vec;
import org.neo4j.gds.embeddings.node2vec.Node2VecBaseConfig;
import org.neo4j.gds.embeddings.node2vec.PositiveSampleProducer;
import org.neo4j.gds.embeddings.node2vec.RandomWalkProbabilities;
import org.neo4j.gds.mem.BitUtil;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.FloatVector;
import org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations;
import org.neo4j.gds.utils.StringFormatting;

public class Node2VecModel {
    private final NegativeSampleProducer negativeSamples;
    private final HugeObjectArray<FloatVector> centerEmbeddings;
    private final HugeObjectArray<FloatVector> contextEmbeddings;
    private final Node2VecBaseConfig config;
    private final CompressedRandomWalks walks;
    private final RandomWalkProbabilities randomWalkProbabilities;
    private final ProgressTracker progressTracker;
    private final long randomSeed;

    public static MemoryEstimation memoryEstimation(Node2VecBaseConfig config) {
        long vectorMemoryEstimation = MemoryUsage.sizeOfFloatArray((long)config.embeddingDimension());
        return MemoryEstimations.builder((String)Node2Vec.class.getSimpleName()).perNode("center embeddings", nodeCount -> HugeObjectArray.memoryEstimation((long)nodeCount, (long)vectorMemoryEstimation)).perNode("context embeddings", nodeCount -> HugeObjectArray.memoryEstimation((long)nodeCount, (long)vectorMemoryEstimation)).build();
    }

    Node2VecModel(LongUnaryOperator toOriginalId, long nodeCount, Node2VecBaseConfig config, CompressedRandomWalks walks, RandomWalkProbabilities randomWalkProbabilities, ProgressTracker progressTracker) {
        this.config = config;
        this.walks = walks;
        this.randomWalkProbabilities = randomWalkProbabilities;
        this.progressTracker = progressTracker;
        this.negativeSamples = new NegativeSampleProducer(randomWalkProbabilities.negativeSamplingDistribution());
        this.randomSeed = config.randomSeed().orElseGet(() -> new SplittableRandom().nextLong());
        Random random = new Random();
        this.centerEmbeddings = this.initializeEmbeddings(toOriginalId, nodeCount, config.embeddingDimension(), random);
        this.contextEmbeddings = this.initializeEmbeddings(toOriginalId, nodeCount, config.embeddingDimension(), random);
    }

    Result train() {
        this.progressTracker.beginSubTask();
        double learningRateAlpha = (this.config.initialLearningRate() - this.config.minLearningRate()) / (double)this.config.iterations();
        ArrayList<Double> lossPerIteration = new ArrayList<Double>();
        for (int iteration = 0; iteration < this.config.iterations(); ++iteration) {
            this.progressTracker.beginSubTask();
            this.progressTracker.setVolume(this.walks.size());
            float learningRate = (float)Math.max(this.config.minLearningRate(), this.config.initialLearningRate() - (double)iteration * learningRateAlpha);
            List tasks = PartitionUtils.degreePartitionWithBatchSize((PrimitiveIterator.OfLong)PrimitiveLongCollections.range((long)0L, (long)(this.walks.size() - 1L)), this.walks::walkLength, (long)BitUtil.ceilDiv((long)this.randomWalkProbabilities.sampleCount(), (long)this.config.concurrency()), partition -> {
                PositiveSampleProducer positiveSampleProducer = new PositiveSampleProducer(this.walks.iterator(partition.startNode(), partition.nodeCount()), this.randomWalkProbabilities.positiveSamplingProbabilities(), this.config.windowSize(), this.progressTracker);
                return new TrainingTask(this.centerEmbeddings, this.contextEmbeddings, positiveSampleProducer, this.negativeSamples, learningRate, this.config.negativeSamplingRate(), this.config.embeddingDimension());
            });
            RunWithConcurrency.builder().concurrency(this.config.concurrency()).tasks((Iterable)tasks).run();
            double loss = tasks.stream().mapToDouble(TrainingTask::lossSum).sum();
            this.progressTracker.logInfo(StringFormatting.formatWithLocale((String)"Loss %.4f", (Object[])new Object[]{loss}));
            lossPerIteration.add(loss);
            this.progressTracker.endSubTask();
        }
        this.progressTracker.endSubTask();
        return ImmutableResult.of(this.centerEmbeddings, lossPerIteration);
    }

    private HugeObjectArray<FloatVector> initializeEmbeddings(LongUnaryOperator toOriginalNodeId, long nodeCount, int embeddingDimensions, Random random) {
        double bound;
        HugeObjectArray embeddings = HugeObjectArray.newArray(FloatVector.class, (long)nodeCount);
        switch (this.config.embeddingInitializer()) {
            case UNIFORM: {
                bound = 1.0;
                break;
            }
            case NORMALIZED: {
                bound = 0.5 / (double)embeddingDimensions;
                break;
            }
            default: {
                throw new IllegalStateException("Missing implementation for: " + this.config.embeddingInitializer());
            }
        }
        for (long i = 0L; i < nodeCount; ++i) {
            random.setSeed(toOriginalNodeId.applyAsLong(i) + this.randomSeed);
            float[] data = random.doubles((long)((long)embeddingDimensions), (double)(-bound), (double)bound).collect((Supplier<FloatConsumer>)LambdaMetafactory.metafactory(null, null, null, ()Ljava/lang/Object;, lambda$initializeEmbeddings$4(int ), ()Lorg/neo4j/gds/embeddings/node2vec/Node2VecModel$FloatConsumer;)((int)embeddingDimensions), (ObjDoubleConsumer<FloatConsumer>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;D)V, add(double ), (Lorg/neo4j/gds/embeddings/node2vec/Node2VecModel$FloatConsumer;D)V)(), (BiConsumer<FloatConsumer, FloatConsumer>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;Ljava/lang/Object;)V, addAll(org.neo4j.gds.embeddings.node2vec.Node2VecModel$FloatConsumer ), (Lorg/neo4j/gds/embeddings/node2vec/Node2VecModel$FloatConsumer;Lorg/neo4j/gds/embeddings/node2vec/Node2VecModel$FloatConsumer;)V)()).values;
            embeddings.set(i, (Object)new FloatVector(data));
        }
        return embeddings;
    }

    private static /* synthetic */ FloatConsumer lambda$initializeEmbeddings$4(int embeddingDimensions) {
        return new FloatConsumer(embeddingDimensions);
    }

    @ValueClass
    public static interface Result {
        public HugeObjectArray<FloatVector> embeddings();

        public List<Double> lossPerIteration();
    }

    static class FloatConsumer {
        float[] values;
        int index;

        FloatConsumer(int length) {
            this.values = new float[length];
        }

        void add(double value) {
            this.values[this.index++] = (float)value;
        }

        void addAll(FloatConsumer other) {
            System.arraycopy(other.values, 0, this.values, this.index, other.index);
            this.index += other.index;
        }
    }

    private static final class TrainingTask
    implements Runnable {
        private final HugeObjectArray<FloatVector> centerEmbeddings;
        private final HugeObjectArray<FloatVector> contextEmbeddings;
        private final PositiveSampleProducer positiveSampleProducer;
        private final NegativeSampleProducer negativeSampleProducer;
        private final FloatVector centerGradientBuffer;
        private final FloatVector contextGradientBuffer;
        private final int negativeSamplingRate;
        private final float learningRate;
        private double lossSum;

        private TrainingTask(HugeObjectArray<FloatVector> centerEmbeddings, HugeObjectArray<FloatVector> contextEmbeddings, PositiveSampleProducer positiveSampleProducer, NegativeSampleProducer negativeSampleProducer, float learningRate, int negativeSamplingRate, int embeddingDimensions) {
            this.centerEmbeddings = centerEmbeddings;
            this.contextEmbeddings = contextEmbeddings;
            this.positiveSampleProducer = positiveSampleProducer;
            this.negativeSampleProducer = negativeSampleProducer;
            this.learningRate = learningRate;
            this.negativeSamplingRate = negativeSamplingRate;
            this.centerGradientBuffer = new FloatVector(embeddingDimensions);
            this.contextGradientBuffer = new FloatVector(embeddingDimensions);
        }

        @Override
        public void run() {
            long[] buffer = new long[2];
            while (this.positiveSampleProducer.next(buffer)) {
                this.trainSample(buffer[0], buffer[1], true);
                for (int i = 0; i < this.negativeSamplingRate; ++i) {
                    this.trainSample(buffer[0], this.negativeSampleProducer.next(), false);
                }
            }
        }

        private void trainSample(long center, long context, boolean positive) {
            FloatVector centerEmbedding = (FloatVector)this.centerEmbeddings.get(center);
            FloatVector contextEmbedding = (FloatVector)this.contextEmbeddings.get(context);
            float affinity = centerEmbedding.innerProduct(contextEmbedding);
            float positiveSigmoid = (float)Sigmoid.sigmoid((double)affinity);
            float negativeSigmoid = 1.0f - positiveSigmoid;
            this.lossSum -= positive ? Math.log(positiveSigmoid) : Math.log(negativeSigmoid);
            float gradient = positive ? -negativeSigmoid : positiveSigmoid;
            float scaledGradient = -gradient * this.learningRate;
            FloatVectorOperations.scale((float[])contextEmbedding.data(), (float)scaledGradient, (float[])this.centerGradientBuffer.data());
            FloatVectorOperations.scale((float[])centerEmbedding.data(), (float)scaledGradient, (float[])this.contextGradientBuffer.data());
            FloatVectorOperations.addInPlace((float[])centerEmbedding.data(), (float[])this.centerGradientBuffer.data());
            FloatVectorOperations.addInPlace((float[])contextEmbedding.data(), (float[])this.contextGradientBuffer.data());
        }

        double lossSum() {
            return this.lossSum;
        }
    }
}

