/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.optimizing;

import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientFunction;
import java.util.Map;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;

/**
 * Gradient of the negative log-likelihood of a set of rankings under a linear utility
 * {@code w . z} over precomputed dyad feature transforms {@code z} (a Plackett-Luce gradient
 * function, per the implemented interface).
 *
 * <p>For ranking {@code n} of length {@code M_n}, let {@code z_{n,l}} be the feature transform of the
 * dyad ranked at position {@code l}. Component {@code i} of the result is
 *
 * <pre>
 *   sum_n sum_{m=0}^{M_n-2} ( sum_{l>=m} z_{n,l,i} exp(w.z_{n,l}) / sum_{l>=m} exp(w.z_{n,l}) - z_{n,m,i} )
 * </pre>
 *
 * The last position ({@code m = M_n - 1}) is skipped because its term is identically zero.
 *
 * <p>{@link #initialize} must be called before {@link #apply}.
 */
public class DyadRankingFeatureTransformNegativeLogLikelihoodDerivative
implements IDyadRankingFeatureTransformPLGradientFunction {
    /** Rankings the likelihood is evaluated over; set by {@link #initialize}. */
    private IDyadRankingDataset dataset;
    /** For every instance, the precomputed feature transform of each of its dyads; set by {@link #initialize}. */
    private Map<IDyadRankingInstance, Map<IDyad, IVector>> featureTransforms;

    /**
     * Stores the dataset and feature transforms used by subsequent calls to {@link #apply}.
     *
     * @param dataset the rankings the likelihood is computed over
     * @param featureTransforms for each instance of {@code dataset}, the feature transform of each of its dyads
     */
    @Override
    public void initialize(IDyadRankingDataset dataset, Map<IDyadRankingInstance, Map<IDyad, IVector>> featureTransforms) {
        this.dataset = dataset;
        this.featureTransforms = featureTransforms;
    }

    /**
     * Evaluates the gradient of the negative log-likelihood at the given weight vector.
     *
     * @param vector the weight vector {@code w}; its length determines the length of the result
     * @return the gradient with respect to {@code w}
     * @throws IllegalStateException if {@link #initialize} has not been called, or if a ranked dyad
     *         (or a whole instance) has no registered feature transform
     */
    public IVector apply(IVector vector) {
        if (this.dataset == null || this.featureTransforms == null) {
            throw new IllegalStateException("initialize(dataset, featureTransforms) must be called before apply(vector)");
        }
        int dim = vector.length();
        // The two sums are kept apart and combined at the end, mirroring -firstSum + secondSum.
        double[] observed = new double[dim];
        double[] expected = new double[dim];
        int largeN = this.dataset.size();
        for (int smallN = 0; smallN < largeN; ++smallN) {
            IDyadRankingInstance instance = (IDyadRankingInstance) this.dataset.get(smallN);
            this.accumulateInstance(smallN, instance, vector, observed, expected);
        }
        DenseDoubleVector result = new DenseDoubleVector(dim);
        for (int i = 0; i < dim; ++i) {
            result.setValue(i, expected[i] - observed[i]);
        }
        return result;
    }

    /**
     * Adds one ranking's contribution to the observed-feature sums and the expected-feature
     * (softmax-weighted) sums, for every gradient component at once.
     *
     * @param index position of {@code instance} in the dataset, used only in error messages
     * @param instance the ranking whose contribution is added
     * @param vector the weight vector {@code w}
     * @param observed accumulator for {@code sum z_{n,m,i}}, updated in place
     * @param expected accumulator for the softmax-weighted feature averages, updated in place
     */
    private void accumulateInstance(int index, IDyadRankingInstance instance, IVector vector, double[] observed, double[] expected) {
        int mN = instance.getNumberOfRankedElements();
        if (mN < 2) {
            // Only the last position exists, and its term is identically zero.
            return;
        }
        Map<IDyad, IVector> transforms = this.featureTransforms.get(instance);
        if (transforms == null) {
            throw new IllegalStateException("No feature transforms registered for dataset instance " + index);
        }

        // Resolve every ranked dyad's transform and utility once, instead of once per (component, m, l).
        // Both the leading dyad and the remaining alternatives are taken in ranking (label) order.
        IVector[] z = new IVector[mN];
        double[] scores = new double[mN];
        for (int l = 0; l < mN; ++l) {
            IDyad dyad = (IDyad) instance.getLabel().get(l);
            z[l] = transforms.get(dyad);
            if (z[l] == null) {
                throw new IllegalStateException("No feature transform registered for the dyad ranked at position " + l
                        + " of dataset instance " + index);
            }
            scores[l] = vector.dotProduct(z[l]);
        }

        // suffixMax[m] = max_{l >= m} scores[l]. Subtracting it before exp leaves the softmax ratio
        // unchanged but prevents overflow (Inf/Inf = NaN) and underflow of the whole denominator to 0.
        double[] suffixMax = new double[mN];
        double running = Double.NEGATIVE_INFINITY;
        for (int l = mN - 1; l >= 0; --l) {
            running = Math.max(running, scores[l]);
            suffixMax[l] = running;
        }

        int dim = observed.length;
        for (int m = 0; m < mN - 1; ++m) {
            for (int i = 0; i < dim; ++i) {
                observed[i] += z[m].getValue(i);
            }
            // A non-finite maximum cannot be subtracted meaningfully; fall back to unshifted exponentials.
            double shift = Double.isFinite(suffixMax[m]) ? suffixMax[m] : 0.0;
            double[] numerator = new double[dim];
            double denominator = 0.0;
            for (int l = m; l < mN; ++l) {
                double weight = Math.exp(scores[l] - shift);
                denominator += weight;
                for (int i = 0; i < dim; ++i) {
                    numerator[i] += z[l].getValue(i) * weight;
                }
            }
            if (denominator == 0.0) {
                // Only reachable with non-finite scores; skipped, as the term is undefined.
                continue;
            }
            for (int i = 0; i < dim; ++i) {
                expected[i] += numerator[i] / denominator;
            }
        }
    }
}

