/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.optimizing;

import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.optimizing.IDyadRankingFeatureTransformPLGradientFunction;
import java.util.Map;

/**
 * Gradient of the negative log-likelihood of a dyad ranking model in which the score of a dyad is
 * {@code exp(w^T z)}, where {@code z} is the dyad's precomputed feature transform.
 *
 * <p>For a weight vector {@code w}, component {@code i} of the gradient is
 *
 * <pre>
 *   sum_n sum_{m=0}^{M_n - 2} ( E_nm[z_i] - z_nm_i )
 *
 *   E_nm[z_i] = sum_{l=m}^{M_n - 1} z_nl_i * exp(w^T z_nl) / sum_{l=m}^{M_n - 1} exp(w^T z_nl)
 * </pre>
 *
 * where n ranges over the rankings of the dataset, M_n is the length of ranking n and z_nl is the
 * feature transform of the dyad at position l of ranking n. This is the gradient of the
 * Plackett-Luce negative log-likelihood with skills exp(w^T z): "expected" minus "observed" feature
 * sums over all choice stages.
 *
 * <p>{@link #initialize} must be called before {@link #apply}.
 */
public class DyadRankingFeatureTransformNegativeLogLikelihoodDerivative
implements IDyadRankingFeatureTransformPLGradientFunction {
    private DyadRankingDataset dataset;
    private Map<IDyadRankingInstance, Map<Dyad, Vector>> featureTransforms;

    /**
     * Stores the data the gradient is evaluated on.
     *
     * @param dataset the rankings; each element is cast to {@link IDyadRankingInstance}
     * @param featureTransforms for every ranking, the feature transform vector of each of its dyads
     */
    @Override
    public void initialize(DyadRankingDataset dataset, Map<IDyadRankingInstance, Map<Dyad, Vector>> featureTransforms) {
        this.dataset = dataset;
        this.featureTransforms = featureTransforms;
    }

    /**
     * Evaluates the gradient at the given weight vector.
     *
     * @param vector the weight vector {@code w}
     * @return a new vector of the same length as {@code vector} holding the gradient
     * @throws IllegalStateException if {@link #initialize} has not been called, or a ranking (of
     *     length at least 2) or one of its dyads has no feature transform
     */
    @Override
    public Vector apply(Vector vector) {
        if (this.dataset == null || this.featureTransforms == null) {
            throw new IllegalStateException("initialize(dataset, featureTransforms) must be called before apply(vector)");
        }
        int dimension = vector.length();
        // Accumulated over all choice stages of all rankings; the gradient is expected - observed.
        double[] observedFeatureSum = new double[dimension];
        double[] expectedFeatureSum = new double[dimension];
        int numRankings = this.dataset.size();
        for (int n = 0; n < numRankings; ++n) {
            IDyadRankingInstance instance = (IDyadRankingInstance) this.dataset.get(n);
            this.accumulateRanking(n, instance, vector, observedFeatureSum, expectedFeatureSum);
        }
        DenseDoubleVector result = new DenseDoubleVector(dimension);
        for (int i = 0; i < dimension; ++i) {
            result.setValue(i, -observedFeatureSum[i] + expectedFeatureSum[i]);
        }
        return result;
    }

    /**
     * Adds the contribution of one ranking to the running observed/expected feature sums.
     *
     * @param index position of the ranking in the dataset (used only in error messages)
     * @param instance the ranking
     * @param w the weight vector
     * @param observedFeatureSum running sum of the feature vectors of the dyad chosen at each stage
     * @param expectedFeatureSum running sum of the score-weighted average feature vector of each stage
     */
    private void accumulateRanking(int index, IDyadRankingInstance instance, Vector w, double[] observedFeatureSum,
            double[] expectedFeatureSum) {
        int length = instance.length();
        if (length < 2) {
            // A ranking with fewer than two dyads has no choice stage and contributes nothing.
            return;
        }
        Map<Dyad, Vector> transformsOfInstance = this.featureTransforms.get(instance);
        if (transformsOfInstance == null) {
            throw new IllegalStateException("No feature transforms for the ranking at index " + index);
        }

        // Look up each dyad's feature vector and compute its score w^T z exactly once, instead of
        // once per stage and per gradient coordinate.
        Vector[] features = new Vector[length];
        double[] scores = new double[length];
        for (int l = 0; l < length; ++l) {
            Vector z = transformsOfInstance.get(instance.getDyadAtPosition(l));
            if (z == null) {
                throw new IllegalStateException(
                        "No feature transform for the dyad at position " + l + " of the ranking at index " + index);
            }
            features[l] = z;
            scores[l] = w.dotProduct(z);
        }

        int dimension = w.length();
        for (int m = 0; m < length - 1; ++m) {
            // Shift by the largest remaining score before exponentiating (log-sum-exp trick). The
            // weighted average is mathematically unchanged, but large scores can no longer overflow
            // exp to Infinity (Infinity/Infinity = NaN), and the normalizer is at least 1, so it
            // cannot underflow to 0.
            double maxScore = Double.NEGATIVE_INFINITY;
            for (int l = m; l < length; ++l) {
                maxScore = Math.max(maxScore, scores[l]);
            }
            double normalizer = 0.0;
            double[] weightedFeatureSum = new double[dimension];
            for (int l = m; l < length; ++l) {
                double weight = Math.exp(scores[l] - maxScore);
                normalizer += weight;
                Vector z = features[l];
                for (int i = 0; i < dimension; ++i) {
                    weightedFeatureSum[i] += weight * z.getValue(i);
                }
            }
            Vector chosen = features[m];
            for (int i = 0; i < dimension; ++i) {
                observedFeatureSum[i] += chosen.getValue(i);
                expectedFeatureSum[i] += weightedFeatureSum[i] / normalizer;
            }
        }
    }
}

