/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.optimizing;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.optimizing.IDyadRankingFeatureTransformPLGradientDescendableFunction;
import java.util.Map;

/**
 * Negative log-likelihood of a ranking model in which the score of a dyad is the
 * dot product of a weight vector {@code w} with that dyad's precomputed feature
 * transform (Plackett-Luce form, as suggested by the implemented interface).
 *
 * <p>For a dataset of rankings {@code r_1..r_N}, where ranking {@code r_n} holds
 * {@code M_n} dyads {@code d_{n,1..M_n}}, {@link #apply} computes
 *
 * <pre>
 *   sum_n sum_{m=1..M_n} [ -w.phi(d_{n,m}) + log( sum_{l=m..M_n} exp(w.phi(d_{n,l})) ) ]
 * </pre>
 *
 * where {@code phi(d)} is the vector stored for {@code d} in the feature-transform map.
 * NOTE(review): the dyad at position 0 of an instance is treated as the top of the
 * ranking; confirm this matches the convention of {@code IDyadRankingInstance}.
 *
 * <p>{@link #initialize} must be called before {@link #apply}.
 */
public class DyadRankingFeatureTransformNegativeLogLikelihood
implements IDyadRankingFeatureTransformPLGradientDescendableFunction {
    /** Rankings over which the likelihood is evaluated. */
    private DyadRankingDataset dataset;
    /** For each ranking instance, the feature-transformed vector of each of its dyads. */
    private Map<IDyadRankingInstance, Map<Dyad, Vector>> featureTransforms;

    /**
     * Stores the dataset and feature transforms used by subsequent {@link #apply} calls.
     *
     * @param dataset           rankings to evaluate
     * @param featureTransforms per-instance map from dyad to its feature transform;
     *                          must contain an entry for every dyad of every instance
     */
    @Override
    public void initialize(DyadRankingDataset dataset, Map<IDyadRankingInstance, Map<Dyad, Vector>> featureTransforms) {
        this.dataset = dataset;
        this.featureTransforms = featureTransforms;
    }

    /**
     * Evaluates the negative log-likelihood of the stored rankings for weights {@code w}.
     *
     * @param w weight vector dotted with each dyad's feature transform
     * @return the negative log-likelihood; {@code 0.0} for an empty dataset
     */
    @Override
    public double apply(Vector w) {
        double firstSum = 0.0;
        double secondSum = 0.0;
        int largeN = this.dataset.size();
        for (int smallN = 0; smallN < largeN; ++smallN) {
            IDyadRankingInstance instance = (IDyadRankingInstance) this.dataset.get(smallN);
            // Hoisted: the per-instance transform map is the same for every position.
            Map<Dyad, Vector> transforms = this.featureTransforms.get(instance);
            int mN = instance.length();
            // Walk the ranking bottom-up so the normalizer for position m,
            // sum_{l = m..mN-1} exp(score_l), is a running suffix sum. This covers the
            // whole tail including the last dyad (never an empty sum, so no log(0)),
            // and computes each dot product exactly once: O(M) instead of O(M^2).
            double suffixExpSum = 0.0;
            for (int m = mN - 1; m >= 0; --m) {
                double score = w.dotProduct(transforms.get(instance.getDyadAtPosition(m)));
                firstSum += score;
                suffixExpSum += Math.exp(score);
                secondSum += Math.log(suffixExpSum);
            }
        }
        return -firstSum + secondSum;
    }
}

