/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.ranking.loss;

import ai.libs.jaicore.ml.ranking.loss.ARankingPredictionPerformanceMeasure;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.loss.IRankingPredictionPerformanceMeasure;

/**
 * Normalized discounted cumulative gain (NDCG) based measure for rankings with a cut-off {@code l}.
 *
 * <p>Relevance values are derived from the expected ranking: the item at (0-based) position {@code i}
 * among the first {@code l} expected items receives relevance {@code -(i + 1)}, which yields the gain
 * {@code 2^-(i+1) - 1}. The DCG of a ranking sums {@code gain / log2(position + 2)} over all positions.
 * The returned value is {@code IDCG / DCG}, with IDCG computed on the expected ranking and DCG on the
 * actual ranking.
 *
 * <p>NOTE(review): because every gain is negative, identical rankings yield {@code 1.0} and deviating
 * rankings yield smaller values, i.e. the value behaves like a score rather than a loss. Confirm the
 * intended orientation with callers before changing it.
 */
public class NDCGLoss
extends ARankingPredictionPerformanceMeasure
implements IRankingPredictionPerformanceMeasure {
    /** Number of leading positions of the expected ranking that are assigned a relevance value. */
    private int l;

    /**
     * Creates the measure.
     *
     * @param l number of leading positions of the expected ranking that receive a relevance value
     */
    public NDCGLoss(int l) {
        this.setL(l);
    }

    /**
     * Computes the per-ranking measure for each pair {@code (expected.get(i), actual.get(i))} and
     * returns the arithmetic mean.
     *
     * @param expected the ground-truth rankings
     * @param actual the predicted rankings, paired index-wise with {@code expected}
     * @return the mean of the per-pair values
     * @throws IllegalArgumentException if the two lists differ in length, or a pair is invalid
     *     (see {@link #loss(IRanking, IRanking)})
     * @throws IllegalStateException if the lists are empty, so there is nothing to average
     */
    @Override
    public double loss(List<? extends IRanking<?>> expected, List<? extends IRanking<?>> actual) {
        if (expected.size() != actual.size()) {
            throw new IllegalArgumentException("Expected and actual ranking lists must have equal length.");
        }
        // Pair the rankings index-wise; each pair contributes one value to the average.
        OptionalDouble res = IntStream.range(0, expected.size())
                .mapToDouble(x -> this.loss(expected.get(x), actual.get(x)))
                .average();
        if (res.isPresent()) {
            return res.getAsDouble();
        }
        throw new IllegalStateException("Could not aggregate NDCG loss: no rankings given");
    }

    /**
     * Computes {@code IDCG / DCG} for a single pair of rankings.
     *
     * @param expected the ground-truth ranking; its first {@code l} items define the relevance values
     * @param actual the predicted ranking
     * @return {@code IDCG / DCG}, or {@code 0.0} if the DCG of {@code actual} is exactly zero
     * @throws IllegalArgumentException if {@code expected} has at most one element, if the rankings
     *     differ in length, or if a ranked item has no relevance value (it is not among the first
     *     {@code l} items of {@code expected})
     */
    @Override
    public double loss(IRanking<?> expected, IRanking<?> actual) {
        if (expected.size() <= 1) {
            throw new IllegalArgumentException("Dyad rankings must have length greater than 1.");
        }
        if (expected.size() != actual.size()) {
            throw new IllegalArgumentException("Dyad rankings must have equal length.");
        }
        // Clamp so that an l larger than the ranking does not index past its end.
        int relevantCount = Math.min(this.l, expected.size());
        Map<Object, Integer> relevance = new HashMap<>();
        for (int i = 0; i < relevantCount; ++i) {
            relevance.put(expected.get(i), -(i + 1));
        }
        double dcg = this.computeDCG(actual, relevance);
        double idcg = this.computeDCG(expected, relevance);
        if (dcg != 0.0) {
            return idcg / dcg;
        }
        return 0.0;
    }

    /**
     * Sums {@code (2^relevance - 1) / log2(position + 2)} over all positions of {@code ranking}.
     *
     * @param ranking the ranking to score
     * @param relevance relevance value for each item of the ranking
     * @return the discounted cumulative gain
     * @throws IllegalArgumentException if an item of {@code ranking} has no entry in {@code relevance}
     */
    private double computeDCG(IRanking<?> ranking, Map<Object, Integer> relevance) {
        double dcg = 0.0;
        for (int i = 0; i < ranking.size(); ++i) {
            Integer rel = relevance.get(ranking.get(i));
            if (rel == null) {
                // Previously this surfaced as a bare NullPointerException on unboxing.
                throw new IllegalArgumentException("Item at position " + i
                        + " has no relevance value: it is not among the first " + this.l
                        + " items of the expected ranking.");
            }
            dcg += (Math.pow(2.0, rel) - 1.0) / this.log2(i + 2.0);
        }
        return dcg;
    }

    /** Returns the base-2 logarithm of {@code x}. */
    private double log2(double x) {
        return Math.log(x) / Math.log(2.0);
    }

    /** @return the number of leading expected positions that receive a relevance value */
    public int getL() {
        return this.l;
    }

    /** @param l the number of leading expected positions that receive a relevance value */
    public void setL(int l) {
        this.l = l;
    }
}

