/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.loss;

import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.loss.DyadRankingLossFunction;
import java.util.HashMap;
import java.util.Map;

/**
 * Dyad ranking measure based on the normalized discounted cumulative gain (NDCG), truncated
 * after the first {@code l} ranking positions (NDCG@l).
 *
 * <p>Relevance values are derived from the actual ranking. The dyad at zero-indexed position
 * {@code i < l} receives relevance {@code -(i + 1)}, so higher-ranked dyads are more relevant.
 * Every dyad outside the actual top {@code l} shares the next-lower level {@code -(l + 1)}.
 * All gains {@code 2^rel - 1} are therefore negative, and the returned value is
 * {@code IDCG / DCG}. This value is {@code 1} for a perfect prediction and smaller otherwise.
 *
 * <p>NOTE(review): despite the class name, larger values indicate better agreement; confirm
 * that callers interpret the result accordingly.
 */
public class NDCGLoss implements DyadRankingLossFunction {

    /** Number of leading ranking positions that contribute to the cumulative gain. */
    private int l;

    /**
     * Creates the measure.
     *
     * @param l the position up to which the cumulative gain is computed (zero-indexed,
     *          exclusive); values larger than the ranking length evaluate the whole ranking
     */
    public NDCGLoss(int l) {
        this.setL(l);
    }

    /**
     * Computes {@code IDCG / DCG} of {@code predicted} with respect to {@code actual},
     * considering only the first {@code min(l, length)} positions.
     *
     * @param actual    the ground-truth ranking
     * @param predicted the predicted ranking; must have the same length as {@code actual}
     * @return the normalized ratio, {@code 1.0} for a perfect prediction
     * @throws IllegalArgumentException if {@code actual} has length at most 1, if the two
     *                                  rankings differ in length, or if {@code l < 1}
     */
    @Override
    public double loss(IDyadRankingInstance actual, IDyadRankingInstance predicted) {
        if (actual.length() <= 1) {
            throw new IllegalArgumentException("Dyad rankings must have length greater than 1.");
        }
        if (actual.length() != predicted.length()) {
            throw new IllegalArgumentException("Dyad rankings must have equal length.");
        }
        if (this.l < 1) {
            throw new IllegalArgumentException("The cut-off position l must be at least 1, but was " + this.l + ".");
        }
        // Positions beyond the end of the ranking cannot be evaluated, so clamp the cut-off.
        int cutoff = Math.min(this.l, actual.length());

        // Dyads are matched by HashMap lookup, i.e. via Dyad#equals/hashCode.
        // NOTE(review): confirm that Dyad implements value equality.
        Map<Dyad, Integer> relevance = new HashMap<>();
        for (int i = 0; i < cutoff; ++i) {
            relevance.put(actual.getDyadAtPosition(i), -(i + 1));
        }
        // All dyads outside the actual top-`cutoff` share the next-lower relevance level, so a
        // prediction is never rewarded for promoting them (and no lookup can miss).
        int unrankedRelevance = -(cutoff + 1);

        double dcg = this.computeDCG(predicted, relevance, cutoff, unrankedRelevance);
        double idcg = this.computeDCG(actual, relevance, cutoff, unrankedRelevance);
        // Both sums are negative (all gains are < 0), so the ratio is positive; the guard only
        // protects against a degenerate zero denominator.
        if (dcg != 0.0) {
            return idcg / dcg;
        }
        return 0.0;
    }

    /**
     * Sums {@code (2^rel - 1) / log2(i + 2)} over the first {@code cutoff} positions of
     * {@code ranking}.
     *
     * @param ranking           the ranking to score
     * @param relevance         relevance of the dyads in the actual top-{@code cutoff}
     * @param cutoff            number of leading positions to include
     * @param unrankedRelevance relevance assigned to dyads absent from {@code relevance}
     * @return the discounted cumulative gain
     */
    private double computeDCG(IDyadRankingInstance ranking, Map<Dyad, Integer> relevance, int cutoff, int unrankedRelevance) {
        double dcg = 0.0;
        for (int i = 0; i < cutoff; ++i) {
            int rel = relevance.getOrDefault(ranking.getDyadAtPosition(i), unrankedRelevance);
            dcg += (Math.pow(2.0, rel) - 1.0) / this.log2(i + 2.0);
        }
        return dcg;
    }

    /** Returns the base-2 logarithm of {@code x}. */
    private double log2(double x) {
        return Math.log(x) / Math.log(2.0);
    }

    /**
     * @return the position up to which the cumulative gain is computed (zero-indexed, exclusive)
     */
    public int getL() {
        return this.l;
    }

    /**
     * Sets the cut-off position; validity ({@code l >= 1}) is checked when {@link #loss} is called.
     *
     * @param l the position up to which the cumulative gain is computed (zero-indexed, exclusive)
     */
    public void setL(int l) {
        this.l = l;
    }
}

