/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.loss;

import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.dyadranking.algorithm.IDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.loss.DyadRankingLossFunction;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

/**
 * Static helpers for averaging a {@link DyadRankingLossFunction} over the instances of a
 * {@link DyadRankingDataset}.
 */
public final class DyadRankingLossUtil {

    private DyadRankingLossUtil() {
        // Utility class: only static members, never instantiated.
    }

    /**
     * Computes the mean loss between index-aligned ground-truth and predicted rankings: the i-th
     * element of {@code trueOrderings} is compared with the i-th element of
     * {@code predictedOrderings}.
     *
     * @param lossFunction the loss applied to each (actual, predicted) pair
     * @param trueOrderings the ground-truth rankings
     * @param predictedOrderings the predicted rankings, aligned by index with {@code trueOrderings}
     * @return the arithmetic mean of the per-instance losses
     * @throws IllegalArgumentException if the two datasets differ in size, or if they are empty
     *     (the average would otherwise be the undefined value {@code 0/0 = NaN})
     */
    public static double computeAverageLoss(DyadRankingLossFunction lossFunction, DyadRankingDataset trueOrderings, DyadRankingDataset predictedOrderings) {
        if (trueOrderings.size() != predictedOrderings.size()) {
            throw new IllegalArgumentException("The list of predictions and the list of ground truth dyad rankings need to have the same length!");
        }
        requireNonEmpty(trueOrderings);
        double lossSum = 0.0;
        for (int i = 0; i < trueOrderings.size(); ++i) {
            IDyadRankingInstance actual = (IDyadRankingInstance) trueOrderings.get(i);
            IDyadRankingInstance predicted = (IDyadRankingInstance) predictedOrderings.get(i);
            lossSum += lossFunction.loss(actual, predicted);
        }
        return lossSum / trueOrderings.size();
    }

    /**
     * Computes the mean loss of {@code ranker} over {@code trueOrderings}.
     *
     * <p>For every ground-truth ranking, its dyads are first shuffled with {@code random}, so that
     * the instance handed to the ranker does not already carry the ground-truth order. The ranker's
     * prediction for the shuffled instance is then scored against the original (unshuffled)
     * ranking.
     *
     * @param lossFunction the loss applied to each (actual, predicted) pair
     * @param trueOrderings the ground-truth rankings
     * @param ranker the ranker whose predictions are evaluated
     * @param random source of randomness for the shuffles; an identically seeded {@link Random}
     *     reproduces the same shuffles
     * @return the arithmetic mean of the per-instance losses
     * @throws PredictionException if the ranker fails to predict an instance
     * @throws IllegalArgumentException if {@code trueOrderings} is empty
     */
    public static double computeAverageLoss(DyadRankingLossFunction lossFunction, DyadRankingDataset trueOrderings, IDyadRanker ranker, Random random) throws PredictionException {
        requireNonEmpty(trueOrderings);
        double lossSum = 0.0;
        for (int i = 0; i < trueOrderings.size(); ++i) {
            IDyadRankingInstance actual = (IDyadRankingInstance) trueOrderings.get(i);
            // Raw type kept on purpose: the element type (presumably the project's Dyad class -- TODO
            // confirm) is not imported in this file.
            ArrayList shuffledDyads = Lists.newArrayList(actual.iterator());
            Collections.shuffle(shuffledDyads, random);
            DyadRankingInstance shuffledActual = new DyadRankingInstance(shuffledDyads);
            IDyadRankingInstance predicted = (IDyadRankingInstance) ranker.predict(shuffledActual);
            lossSum += lossFunction.loss(actual, predicted);
        }
        return lossSum / trueOrderings.size();
    }

    /**
     * Computes the mean loss of {@code ranker} over {@code trueOrderings}, shuffling the ranker's
     * inputs with a {@link Random} seeded with {@code 0} so repeated calls are reproducible.
     *
     * @param lossFunction the loss applied to each (actual, predicted) pair
     * @param trueOrderings the ground-truth rankings
     * @param ranker the ranker whose predictions are evaluated
     * @return the arithmetic mean of the per-instance losses
     * @throws PredictionException if the ranker fails to predict an instance
     * @throws IllegalArgumentException if {@code trueOrderings} is empty
     */
    public static double computeAverageLoss(DyadRankingLossFunction lossFunction, DyadRankingDataset trueOrderings, IDyadRanker ranker) throws PredictionException {
        return DyadRankingLossUtil.computeAverageLoss(lossFunction, trueOrderings, ranker, new Random(0L));
    }

    /** Rejects an empty dataset, for which an average loss is undefined (0/0 would yield NaN). */
    private static void requireNonEmpty(DyadRankingDataset dataset) {
        if (dataset.size() == 0) {
            throw new IllegalArgumentException("Cannot compute an average loss over an empty dataset of dyad rankings!");
        }
    }
}

