/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.core;

import java.util.List;
import org.assertj.core.api.AbstractDoubleAssert;
import org.assertj.core.api.AssertionsForInterfaceTypes;
import org.assertj.core.data.Offset;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Scalar;

/**
 * Test mixin that compares the gradient produced by back-propagation with a
 * forward finite-difference approximation of the same gradient.
 *
 * <p>Implementing test classes may override {@link #tolerance()} and
 * {@link #epsilon()} to tune the comparison.
 */
public interface FiniteDifferenceTest {
    /**
     * Failure message template. Its arguments are, in order: the auto-grad
     * gradient, the finite-difference gradient, and the flat tensor index.
     */
    String FAIL_MESSAGE = "AutoGrad of %f and FiniteDifference gradients of %f differs for coordinate %s more than the tolerance.";

    /**
     * @return the maximum absolute difference allowed between the two gradients
     */
    default double tolerance() {
        return 1.0E-5;
    }

    /**
     * @return the step added to a single weight entry before the loss is re-evaluated
     */
    default double epsilon() {
        return 1.0E-4;
    }

    /**
     * Single-weight convenience overload; delegates to the list-based variant.
     *
     * @param weightVariable the weights whose gradient is checked
     * @param loss           the scalar loss to differentiate
     */
    default void finiteDifferenceShouldApproximateGradient(Weights<?> weightVariable, Variable<Scalar> loss) {
        this.finiteDifferenceShouldApproximateGradient(List.of(weightVariable), loss);
    }

    /**
     * For every entry of every given weight tensor, asserts that the gradient
     * obtained via {@code ComputationContext#backward} matches
     * {@code (loss(w + epsilon) - loss(w)) / epsilon} within {@link #tolerance()}.
     *
     * <p>Note: the perturbation applied to an entry is not undone. Each entry
     * of the weights is left shifted by {@link #epsilon()} once it has been
     * checked, and subsequent entries are checked against that shifted state.
     *
     * @param weightVariables the weights whose gradients are checked
     * @param loss            the scalar loss to differentiate
     */
    default void finiteDifferenceShouldApproximateGradient(List<Weights<?>> weightVariables, Variable<Scalar> loss) {
        // Read the step once so the perturbation and the divisor always agree.
        final double step = this.epsilon();
        for (Weights<?> variable : weightVariables) {
            for (int tensorIndex = 0; tensorIndex < Dimensions.totalSize((int[])variable.dimensions()); ++tensorIndex) {
                // Baseline loss and back-propagated gradient at the current weights.
                ComputationContext ctx = new ComputationContext();
                double forwardLoss = ((Scalar)ctx.forward(loss)).value();
                ctx.backward(loss);
                double autoGradient = ctx.gradient(variable).dataAt(tensorIndex);

                // Perturb a single coordinate and re-evaluate the loss in a fresh context.
                variable.data().addDataAt(tensorIndex, step);
                ComputationContext perturbedCtx = new ComputationContext();
                double forwardLossOnPerturbedData = ((Scalar)perturbedCtx.forward(loss)).value();
                double finiteDifferenceGrad = (forwardLossOnPerturbedData - forwardLoss) / step;

                // Argument order must follow FAIL_MESSAGE: auto-grad first, finite difference second.
                AssertionsForInterfaceTypes.assertThat(finiteDifferenceGrad)
                    .isNotNaN()
                    .withFailMessage(FAIL_MESSAGE, autoGradient, finiteDifferenceGrad, tensorIndex)
                    .isEqualTo(autoGradient, Offset.offset(this.tolerance()));
            }
        }
    }
}

