/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multilabel.evaluation.loss;

import ai.libs.jaicore.basic.ArrayUtil;
import ai.libs.jaicore.ml.classification.multilabel.evaluation.loss.AMultiLabelClassificationMeasure;
import java.util.Iterator;
import java.util.List;
import java.util.OptionalDouble;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;

/**
 * Ranking loss for multi-label classification.
 *
 * <p>For a single instance, the ranking loss is the fraction of (relevant, irrelevant) label pairs that the
 * predicted label scores order incorrectly. A pair is misordered when the relevant label receives a strictly lower
 * score than the irrelevant one. A pair whose two scores are exactly equal contributes the configured tie loss
 * instead of 0 or 1. The overall loss is the mean of the per-instance losses.
 *
 * <p>Relevant labels are the labels whose expected value equals the maximum of the expected vector. Irrelevant
 * labels are those whose expected value equals its minimum; for a 0/1 vector, 1 means relevant and 0 irrelevant.
 * An instance whose expected values are all identical has no (relevant, irrelevant) pair and contributes a loss
 * of 0.
 */
public class RankLoss
extends AMultiLabelClassificationMeasure {
    /** Tie loss used by the no-argument constructor: tied pairs are not penalized. */
    private static final double DEFAULT_TIE_LOSS = 0.0;

    /** Loss charged for each (relevant, irrelevant) pair whose predicted scores are exactly equal. */
    private final double tieLoss;

    /** Creates a ranking loss that does not penalize tied (relevant, irrelevant) pairs. */
    public RankLoss() {
        this(DEFAULT_TIE_LOSS);
    }

    /**
     * Creates a ranking loss with a custom penalty for ties.
     *
     * @param tieLoss the amount added for every (relevant, irrelevant) pair with identical predicted scores
     */
    public RankLoss(final double tieLoss) {
        this.tieLoss = tieLoss;
    }

    /**
     * Computes the ranking loss of a single instance.
     *
     * @param expected the expected label vector of the instance
     * @param predicted the prediction whose {@code getPrediction()} scores are indexed by label
     * @return the (tie-weighted) number of misordered (relevant, irrelevant) pairs divided by the number of such
     *         pairs, or 0 if the expected vector contains fewer than two distinct values
     */
    private double rankingLoss(final int[] expected, final IMultiLabelClassification predicted) {
        // With fewer than two distinct expected values there is no (relevant, irrelevant) pair to rank.
        // argMax/argMin would return overlapping label sets and labels would be ranked against themselves.
        if (IntStream.of(expected).distinct().count() < 2) {
            return 0.0;
        }
        List<Integer> relevantLabels = ArrayUtil.argMax(expected);
        List<Integer> irrelevantLabels = ArrayUtil.argMin(expected);
        double[] scores = predicted.getPrediction();

        double misorderedPairs = 0.0;
        for (int relevant : relevantLabels) {
            double relevantScore = scores[relevant];
            for (int irrelevant : irrelevantLabels) {
                double irrelevantScore = scores[irrelevant];
                if (relevantScore == irrelevantScore) {
                    misorderedPairs += this.tieLoss;
                } else if (relevantScore < irrelevantScore) {
                    misorderedPairs += 1.0;
                }
            }
        }
        // Normalize by the number of pairs, |relevant| * |irrelevant|. The product is computed in double to avoid
        // int overflow. For a tie loss in [0, 1] the per-instance loss therefore lies in [0, 1].
        return misorderedPairs / ((double) relevantLabels.size() * irrelevantLabels.size());
    }

    /**
     * Computes the mean ranking loss over all instances.
     *
     * @param expected the expected label vectors, one per instance
     * @param predicted the predictions, aligned index-by-index with {@code expected}
     * @return the average per-instance ranking loss
     * @throws IllegalStateException if no instances are given, so that no average can be formed
     */
    @Override
    public double loss(final List<? extends int[]> expected, final List<? extends IMultiLabelClassification> predicted) {
        // Inherited check; presumably validates that both lists are aligned (e.g. equal sizes) - defined in the superclass.
        this.checkConsistency(expected, predicted);
        OptionalDouble res = IntStream.range(0, expected.size())
                .mapToDouble(i -> this.rankingLoss(expected.get(i), predicted.get(i)))
                .average();
        return res.orElseThrow(() -> new IllegalStateException("The ranking loss could not be averaged across all the instances."));
    }
}

