/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.zeroshot.inputoptimization;

import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.zeroshot.inputoptimization.InputOptimizerLoss;
import ai.libs.jaicore.ml.ranking.dyad.learner.zeroshot.util.InputOptListener;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;

/**
 * Optimizes the <em>input</em> (not the weights) of the network wrapped by a {@link PLNetDyadRanker}
 * by repeated gradient steps on the input vector.
 *
 * <p>Each step computes the derivative of an {@link InputOptimizerLoss} with respect to the input,
 * adds accumulated penalty terms that push masked entries back toward the interval [0, 1], and
 * moves the input against that gradient. The method returns the best input seen: the one with the
 * highest network output among those whose masked entries all lie within [0, 1]. If no step
 * improves on the starting point, a copy of the original input is returned.
 */
public class PLNetInputOptimizer {
    private InputOptListener listener;

    /**
     * Optimizes {@code input} with a constant learning rate, optionally restricted to an index range.
     *
     * @param plNet the ranker whose network is used to evaluate and differentiate the input
     * @param input the starting input; it is not modified (a copy is optimized)
     * @param loss the loss whose gradient drives the optimization
     * @param learningRate step size used for every step
     * @param numSteps number of gradient steps to perform
     * @param indexRange range {@code [first, second)} of input positions that may be changed
     *        ({@code NDArrayIndex.interval} semantics, end exclusive), or {@code null} to allow all
     * @return the best input found (see class documentation)
     * @throws IllegalArgumentException if {@code indexRange} lies outside the input or is inverted
     */
    public INDArray optimizeInput(PLNetDyadRanker plNet, INDArray input, InputOptimizerLoss loss, double learningRate, int numSteps, Pair<Integer, Integer> indexRange) {
        return this.optimizeInput(plNet, input, loss, learningRate, numSteps, createMask(input, indexRange));
    }

    /**
     * Optimizes {@code input} with a linearly interpolated learning rate, optionally restricted to an
     * index range.
     *
     * @param plNet the ranker whose network is used to evaluate and differentiate the input
     * @param input the starting input; it is not modified (a copy is optimized)
     * @param loss the loss whose gradient drives the optimization
     * @param initialLearningRate step size used for the first step
     * @param finalLearningRate step size the schedule interpolates toward
     * @param numSteps number of gradient steps to perform
     * @param indexRange range {@code [first, second)} of input positions that may be changed
     *        ({@code NDArrayIndex.interval} semantics, end exclusive), or {@code null} to allow all
     * @return the best input found (see class documentation)
     * @throws IllegalArgumentException if {@code indexRange} lies outside the input or is inverted
     */
    public INDArray optimizeInput(PLNetDyadRanker plNet, INDArray input, InputOptimizerLoss loss, double initialLearningRate, double finalLearningRate, int numSteps, Pair<Integer, Integer> indexRange) {
        return this.optimizeInput(plNet, input, loss, initialLearningRate, finalLearningRate, numSteps, createMask(input, indexRange));
    }

    /**
     * Optimizes {@code input} with a constant learning rate, restricted by an explicit mask.
     *
     * @param plNet the ranker whose network is used to evaluate and differentiate the input
     * @param input the starting input; it is not modified (a copy is optimized)
     * @param loss the loss whose gradient drives the optimization
     * @param learningRate step size used for every step
     * @param numSteps number of gradient steps to perform
     * @param inputMask multiplied element-wise into each gradient step; positions with mask 0 are
     *        never changed
     * @return the best input found (see class documentation)
     */
    public INDArray optimizeInput(PLNetDyadRanker plNet, INDArray input, InputOptimizerLoss loss, double learningRate, int numSteps, INDArray inputMask) {
        return this.optimizeInput(plNet, input, loss, learningRate, learningRate, numSteps, inputMask);
    }

    /**
     * Core optimization loop: gradient steps on a copy of {@code input}, with a learning rate
     * interpolated linearly from {@code initialLearningRate} toward {@code finalLearningRate}.
     *
     * @param plNet the ranker whose network is used to evaluate and differentiate the input
     * @param input the starting input; it is not modified (a copy is optimized)
     * @param loss the loss whose gradient drives the optimization
     * @param initialLearningRate step size used for the first step
     * @param finalLearningRate step size the schedule interpolates toward
     * @param numSteps number of gradient steps to perform; values {@code <= 0} return a copy of
     *        {@code input}
     * @param inputMask multiplied element-wise into each gradient step; positions with mask 0 are
     *        never changed and are ignored by the [0, 1] feasibility check
     * @return the input with the highest network output among the start point and all feasible
     *         iterates
     */
    public INDArray optimizeInput(PLNetDyadRanker plNet, INDArray input, InputOptimizerLoss loss, double initialLearningRate, double finalLearningRate, int numSteps, INDArray inputMask) {
        MultiLayerNetwork net = plNet.getPlNet();
        INDArray inp = input.dup();
        // Accumulated penalty multipliers: alphas grow while entries are below 0, betas while above 1.
        INDArray alphas = Nd4j.zeros(inp.shape());
        INDArray betas = Nd4j.zeros(inp.shape());
        INDArray incumbent = inp.dup();
        double incumbentOutput = net.output(inp).getDouble(0L);

        for (int i = 0; i < numSteps; ++i) {
            // NOTE(review): progress = i / numSteps never reaches 1, so finalLearningRate itself is
            // never applied; the last step uses a rate one interpolation step short of it.
            double progress = (double) i / (double) numSteps;
            double learningRate = (1.0 - progress) * initialLearningRate + progress * finalLearningRate;

            INDArray grad = computeInputDerivative(plNet, inp, loss);
            // Apply the penalties from the previous step before updating them.
            grad.subi(alphas);
            grad.addi(betas);
            // alphas <- max(0, alphas - x), betas <- max(0, betas + x - 1)
            alphas.subi(inp);
            betas.addi(inp.sub(1.0));
            BooleanIndexing.replaceWhere(alphas, 0.0, Conditions.lessThan(0.0));
            BooleanIndexing.replaceWhere(betas, 0.0, Conditions.lessThan(0.0));

            grad.muli(inputMask);
            grad.muli(learningRate);
            inp.subi(grad);

            double output = net.output(inp).getDouble(0L);
            if (this.listener != null) {
                // The listener receives the live working array, which later steps keep mutating.
                this.listener.reportOptimizationStep(inp, output);
            }
            if (output > incumbentOutput && isWithinUnitInterval(inp, inputMask)) {
                incumbent = inp.dup();
                incumbentOutput = output;
            }
        }
        return incumbent;
    }

    /**
     * Builds the optimization mask for {@code input}. The mask is all ones when {@code indexRange}
     * is {@code null}; otherwise it is ones on {@code [first, second)} and zeros elsewhere.
     */
    private static INDArray createMask(INDArray input, Pair<Integer, Integer> indexRange) {
        long length = input.length();
        if (indexRange == null) {
            return Nd4j.ones(new long[] { length });
        }
        int from = indexRange.getFirst();
        int to = indexRange.getSecond();
        if (from < 0 || to > length || from > to) {
            throw new IllegalArgumentException("Invalid index range [" + from + ", " + to + ") for input of length " + length);
        }
        INDArray mask = Nd4j.zeros(new long[] { length });
        mask.get(NDArrayIndex.interval(from, to)).assign(1.0);
        return mask;
    }

    /**
     * Returns whether every masked entry of {@code candidate} lies in [0, 1]. Unmasked entries
     * become 0 after masking and therefore always pass.
     */
    private static boolean isWithinUnitInterval(INDArray candidate, INDArray inputMask) {
        INDArray masked = candidate.mul(inputMask);
        return BooleanIndexing.and(masked, Conditions.greaterThanOrEqual(0.0))
                && BooleanIndexing.and(masked, Conditions.lessThanOrEqual(1.0));
    }

    /**
     * Computes the derivative of {@code loss} with respect to the network input at {@code input}.
     * The loss gradient at the network output is back-propagated through the network, and the
     * input-activation gradient is returned.
     */
    private static INDArray computeInputDerivative(PLNetDyadRanker plNet, INDArray input, InputOptimizerLoss loss) {
        MultiLayerNetwork net = plNet.getPlNet();
        INDArray output = net.output(input);
        INDArray lossGradient = Nd4j.create(new double[] { loss.lossGradient(output) });
        // backpropGradient relies on the activations stored by a preceding forward pass.
        net.setInput(input);
        net.feedForward(false, false);
        Pair<?, INDArray> gradients = net.backpropGradient(lossGradient, null);
        return gradients.getSecond();
    }

    /**
     * Registers a listener that is notified with the current input and network output after every
     * optimization step. Pass {@code null} to disable notifications.
     *
     * @param listener the listener to notify, or {@code null}
     */
    public void setListener(InputOptListener listener) {
        this.listener = listener;
    }
}

