/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.dyad.learner.optimizing;

import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientDescendableFunction;
import java.util.Map;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;

/**
 * Negative log-likelihood of a linear scoring vector {@code w} over a dyad
 * ranking dataset, evaluated on pre-computed feature transforms of the dyads.
 *
 * <p>For every ranking in the dataset and every position {@code m} in it, the
 * function accumulates
 * {@code log(sum_{l >= m} exp(w . f(dyad_l))) - w . f(dyad_m)}, where
 * {@code f(dyad)} is the feature transform stored for that dyad. This matches
 * the Plackett-Luce ("PL") negative log-likelihood suggested by the name of the
 * implemented interface.
 *
 * <p>{@link #initialize(IDyadRankingDataset, Map)} must be called before
 * {@link #apply(IVector)}; otherwise the dataset reference is {@code null}.
 */
public class DyadRankingFeatureTransformNegativeLogLikelihood
implements IDyadRankingFeatureTransformPLGradientDescendableFunction {
    /** Rankings over which the likelihood is evaluated. */
    private IDyadRankingDataset dataset;
    /** Feature transform of every dyad, grouped by the ranking instance it belongs to. */
    private Map<IDyadRankingInstance, Map<IDyad, IVector>> featureTransforms;

    /**
     * Stores the dataset and the feature transforms used by {@link #apply(IVector)}.
     * Both references are kept as-is (no defensive copy).
     *
     * @param dataset the dyad rankings to evaluate
     * @param featureTransforms for each ranking instance, the feature vector of each of its dyads
     */
    @Override
    public void initialize(IDyadRankingDataset dataset, Map<IDyadRankingInstance, Map<IDyad, IVector>> featureTransforms) {
        this.dataset = dataset;
        this.featureTransforms = featureTransforms;
    }

    /**
     * Evaluates the negative log-likelihood for the given weight vector.
     *
     * @param w the weight vector whose dot product with a dyad's feature transform is the dyad's utility
     * @return the sum, over all rankings and positions, of the log of the normalising sum minus the
     *         utility of the dyad at that position; {@code 0.0} for an empty dataset
     */
    public double apply(IVector w) {
        double firstSum = 0.0;
        double secondSum = 0.0;
        int largeN = this.dataset.size();
        for (int smallN = 0; smallN < largeN; ++smallN) {
            IDyadRankingInstance instance = (IDyadRankingInstance)this.dataset.get(smallN);
            int mN = instance.getNumberOfRankedElements();
            Map<IDyad, IVector> transforms = this.featureTransforms.get(instance);

            // Compute each dyad's utility once; the normalising sums below reuse the
            // exponentials instead of recomputing the dot product for every (m, l) pair.
            double[] expUtilities = new double[mN];
            for (int m = 0; m < mN; ++m) {
                IDyad dyad = (IDyad)instance.getLabel().get(m);
                double utility = w.dotProduct(transforms.get(dyad));
                firstSum += utility;
                expUtilities[m] = Math.exp(utility);
            }

            for (int m = 0; m < mN; ++m) {
                double innerSum = 0.0;
                // The normalising sum runs over positions m..mN-1 inclusive. Stopping at
                // mN-2 would leave the sum empty for the last position and yield log(0) = -Infinity.
                for (int l = m; l < mN; ++l) {
                    innerSum += expUtilities[l];
                }
                secondSum += Math.log(innerSum);
            }
        }
        return -firstSum + secondSum;
    }
}

