/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.embeddings.graphsage;

import java.util.stream.IntStream;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.RelationshipWeights;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.functions.SingleParentVariable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Unsupervised GraphSAGE loss computed from a matrix of combined embeddings.
 *
 * <p>The rows of the combined embedding matrix are split into three consecutive buckets of size
 * {@code b = rows / SAMPLING_BUCKETS}:
 * <ul>
 *   <li>rows {@code [0, b)} hold the batch nodes,</li>
 *   <li>rows {@code [b, 2b)} hold one positive sample per batch node,</li>
 *   <li>rows {@code [2b, 3b)} hold one negative sample per batch node.</li>
 * </ul>
 * Any trailing rows beyond {@code 3b} are ignored (they are never written in the gradient).
 *
 * <p>For batch node {@code u} with positive sample {@code v} and negative sample {@code n}, the
 * per-node loss is
 * <pre>
 *   -w(u, v)^ALPHA * log(sigmoid(z_u . z_v)) - Q * log(sigmoid(-(z_u . z_n)))
 * </pre>
 * where {@code w} is the relationship weight looked up for the two node ids (1.0 if the lookup
 * returns NaN) and {@code Q} is the negative sampling factor. The reported loss is the mean of the
 * per-node losses; an input with fewer than three rows yields a loss of 0 and no gradient.
 */
public class GraphSageLoss extends SingleParentVariable<Matrix, Scalar> {
    /** Number of row buckets: batch nodes, positive samples, negative samples. */
    private static final int SAMPLING_BUCKETS = 3;
    /** Index of the bucket holding the negative samples. */
    private static final int NEGATIVE_NODES_OFFSET = 2;
    /** Exponent applied to the relationship weight of each positive pair. */
    private static final double ALPHA = 1.0;

    private final RelationshipWeights relationshipWeights;
    private final Variable<Matrix> combinedEmbeddings;
    private final long[] batch;
    private final int negativeSamplingFactor;

    /**
     * @param relationshipWeights    weight lookup for a (batch node, positive sample) pair
     * @param combinedEmbeddings     parent variable producing the bucketed embedding matrix
     * @param batch                  presumably node ids, indexed by embedding row (must cover at
     *                               least the first two buckets)
     * @param negativeSamplingFactor multiplier {@code Q} applied to each negative-sample term
     */
    GraphSageLoss(RelationshipWeights relationshipWeights, Variable<Matrix> combinedEmbeddings, long[] batch, int negativeSamplingFactor) {
        super(combinedEmbeddings, Dimensions.scalar());
        this.relationshipWeights = relationshipWeights;
        this.combinedEmbeddings = combinedEmbeddings;
        this.batch = batch;
        this.negativeSamplingFactor = negativeSamplingFactor;
    }

    /**
     * Computes the mean GraphSAGE loss over all batch nodes.
     *
     * @param ctx computation context holding the data of {@code combinedEmbeddings}
     * @return the mean loss, or 0 when the embedding matrix has fewer than three rows
     */
    public Scalar apply(ComputationContext ctx) {
        Matrix embeddings = (Matrix) ctx.data(this.combinedEmbeddings);
        int bucketSize = bucketSize(embeddings);
        if (bucketSize == 0) {
            // No complete (node, positive, negative) triple: averaging over zero terms would
            // produce 0/0 = NaN, which would poison any downstream training step.
            return new Scalar(0.0);
        }
        int negativeNodesOffset = NEGATIVE_NODES_OFFSET * bucketSize;
        double loss = IntStream.range(0, bucketSize).mapToDouble(nodeIdx -> {
            int positiveNodeIdx = nodeIdx + bucketSize;
            int negativeNodeIdx = nodeIdx + negativeNodesOffset;
            double positiveAffinity = affinity(embeddings, nodeIdx, positiveNodeIdx);
            double negativeAffinity = affinity(embeddings, nodeIdx, negativeNodeIdx);
            double weightFactor = relationshipWeightFactor(this.batch[nodeIdx], this.batch[positiveNodeIdx]);
            return -weightFactor * logSigmoid(positiveAffinity)
                - this.negativeSamplingFactor * logSigmoid(-negativeAffinity);
        }).sum();
        return new Scalar(loss / bucketSize);
    }

    /** Number of rows in each of the three buckets; 0 when there are fewer than three rows. */
    private static int bucketSize(Matrix embeddings) {
        return embeddings.rows() / SAMPLING_BUCKETS;
    }

    /**
     * Returns the weight factor {@code w(u, v)^ALPHA} for a positive pair.
     * A NaN weight from the lookup is replaced by 1.0.
     */
    private double relationshipWeightFactor(long nodeId, long positiveNodeId) {
        double relationshipWeight = this.relationshipWeights.weight(nodeId, positiveNodeId);
        if (Double.isNaN(relationshipWeight)) {
            // NOTE(review): NaN presumably means "no weight available" — confirm against the
            // RelationshipWeights contract.
            relationshipWeight = 1.0;
        }
        return Math.pow(relationshipWeight, ALPHA);
    }

    /** Dot product of embedding rows {@code batchIdx} and {@code otherBatchIdx}. */
    private static double affinity(Matrix embeddingData, int batchIdx, int otherBatchIdx) {
        int embeddingDimension = embeddingData.cols();
        double sum = 0.0;
        for (int col = 0; col < embeddingDimension; ++col) {
            sum += embeddingData.dataAt(batchIdx, col) * embeddingData.dataAt(otherBatchIdx, col);
        }
        return sum;
    }

    /**
     * Numerically stable {@code log(sigmoid(x)) = -log(1 + e^-x)}.
     * Splitting on the sign keeps the argument of {@code exp} non-positive, so it never overflows
     * and the result never collapses to {@code -Infinity} for finite, moderately large |x|.
     */
    private static double logSigmoid(double x) {
        return x >= 0 ? -Math.log1p(Math.exp(-x)) : x - Math.log1p(Math.exp(x));
    }

    /**
     * Computes the gradient of the mean loss with respect to the combined embedding matrix.
     *
     * <p>Uses {@code d/da -log(sigmoid(a)) = -sigmoid(-a)} for the positive term and
     * {@code d/da -log(sigmoid(-a)) = sigmoid(a)} for the negative term, chained through the dot
     * products of the affinities.
     *
     * @param ctx computation context holding the data of {@code combinedEmbeddings}
     * @return a matrix with the same dimensions as the embeddings; rows outside the three buckets
     *     are left as created by {@code createWithSameDimensions()}
     * @throws IllegalStateException if the parent is not {@code combinedEmbeddings}
     */
    public Matrix gradientForParent(ComputationContext ctx) {
        if (this.parent != this.combinedEmbeddings) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(
                "This variable only has a single parent. Expected %s but got %s",
                this.combinedEmbeddings,
                this.parent
            ));
        }
        Matrix embeddings = (Matrix) ctx.data(this.combinedEmbeddings);
        Matrix gradientResult = embeddings.createWithSameDimensions();
        int bucketSize = bucketSize(embeddings);
        if (bucketSize == 0) {
            // Consistent with apply(): no triple contributes, and dividing by zero below would
            // turn every entry into NaN. Assumes createWithSameDimensions() is zero-filled, as the
            // never-written trailing rows in the general case already do.
            return gradientResult;
        }
        int negativeNodesBucketOffset = NEGATIVE_NODES_OFFSET * bucketSize;
        int embeddingDimension = embeddings.cols();
        for (int bucketIdx = 0; bucketIdx < bucketSize; ++bucketIdx) {
            int positiveNodeIdx = bucketIdx + bucketSize;
            int negativeNodeIdx = bucketIdx + negativeNodesBucketOffset;
            double positiveAffinity = affinity(embeddings, bucketIdx, positiveNodeIdx);
            double negativeAffinity = affinity(embeddings, bucketIdx, negativeNodeIdx);
            double relationshipWeightFactor = relationshipWeightFactor(this.batch[bucketIdx], this.batch[positiveNodeIdx]);
            // w * sigmoid(-positiveAffinity) and Q * sigmoid(negativeAffinity)
            double weightedPositiveLogistic = relationshipWeightFactor * logisticFunction(positiveAffinity);
            double weightedNegativeLogistic = this.negativeSamplingFactor * logisticFunction(-negativeAffinity);
            for (int embeddingIdx = 0; embeddingIdx < embeddingDimension; ++embeddingIdx) {
                computeGradientForEmbeddingIdx(embeddings, gradientResult, bucketIdx, positiveNodeIdx, negativeNodeIdx, weightedPositiveLogistic, weightedNegativeLogistic, embeddingIdx);
            }
        }
        // Each row index belongs to exactly one triple, so set (not add) above is sufficient;
        // divide by the batch size to match the mean taken in apply().
        gradientResult.mapInPlace(value -> value / bucketSize);
        return gradientResult;
    }

    /**
     * Writes the gradient entries of one embedding column for a (node, positive, negative) triple.
     * The node row receives contributions from both its positive and negative sample; the sample
     * rows each receive the node's embedding scaled by their weighted logistic term.
     */
    private static void computeGradientForEmbeddingIdx(Matrix embeddings, Matrix gradientResult, int batchIdx, int positiveNodeIdx, int negativeNodeIdx, double weightedPositiveLogistic, double weightedNegativeLogistic, int embeddingIdx) {
        double scaledPositiveExampleGradient = -embeddings.dataAt(positiveNodeIdx, embeddingIdx) * weightedPositiveLogistic;
        double scaledNegativeExampleGradient = weightedNegativeLogistic * embeddings.dataAt(negativeNodeIdx, embeddingIdx);
        gradientResult.setDataAt(batchIdx, embeddingIdx, scaledPositiveExampleGradient + scaledNegativeExampleGradient);
        double currentEmbeddingValue = embeddings.dataAt(batchIdx, embeddingIdx);
        gradientResult.setDataAt(positiveNodeIdx, embeddingIdx, -currentEmbeddingValue * weightedPositiveLogistic);
        gradientResult.setDataAt(negativeNodeIdx, embeddingIdx, weightedNegativeLogistic * currentEmbeddingValue);
    }

    /** Returns {@code 1 / (1 + e^affinity)}, i.e. {@code sigmoid(-affinity)}. */
    private static double logisticFunction(double affinity) {
        return 1.0 / (1.0 + Math.exp(affinity));
    }
}

