/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.dyadranking.loss;

import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.loss.DyadRankingLossFunction;

/**
 * Dyad ranking loss based on a Kendall's-tau style distance restricted to the top {@code k} positions of the actual
 * and the predicted ranking, with a configurable penalty {@code p} for pairs whose relative order cannot be inferred.
 *
 * <p>For every pair of dyads (i, j) of the actual ranking, a penalty is accumulated according to these cases:
 * <ol>
 * <li>Both dyads are in the top k of both rankings: penalty 1 if their relative order is strictly reversed,
 * 0 otherwise.</li>
 * <li>Both dyads are in the top k of one ranking, but only one of them is in the top k of the other: penalty 0 if the
 * dyad that is ranked ahead in the first ranking is the one present in the other top k, 1 otherwise.</li>
 * <li>One dyad is only in the actual top k and the other is only in the predicted top k: penalty 1.</li>
 * <li>Both dyads are in the top k of one ranking and neither is in the top k of the other: penalty {@code p}.</li>
 * </ol>
 * All remaining pairs contribute 0. The case structure appears to correspond to the K^(p) distance for top-k lists
 * (Fagin, Kumar, Sivakumar); the returned value is the unnormalized sum of the pair penalties.
 */
public class KendallsTauOfTopK implements DyadRankingLossFunction {

    /**
     * Position assigned to a dyad of the actual ranking that does not occur in the predicted ranking. It is larger
     * than any valid k, so such a dyad is treated as lying outside the predicted top k.
     */
    private static final int NOT_RANKED = Integer.MAX_VALUE;

    /** Number of leading positions of each ranking that are compared. */
    private final int k;

    /** Penalty for pairs that are both in one top k and both absent from the other (case 4). */
    private final double p;

    /**
     * Creates the loss function.
     *
     * @param k number of top positions to compare; must be greater than 1 (validated in {@link #loss})
     * @param p penalty applied to pairs that fall under case 4 (see class documentation)
     */
    public KendallsTauOfTopK(int k, double p) {
        this.k = k;
        this.p = p;
    }

    /**
     * Computes the top-k Kendall's tau distance between the actual and the predicted ranking.
     *
     * <p>Only dyads of {@code actual} are paired; dyads that appear solely in {@code predicted} are ignored. A dyad of
     * {@code actual} that does not occur in {@code predicted} is treated as not being in the predicted top k.
     *
     * @param actual the ground-truth ranking
     * @param predicted the predicted ranking
     * @return the sum of the pair penalties (not normalized)
     * @throws IllegalArgumentException if {@code k <= 1}
     */
    @Override
    public double loss(IDyadRankingInstance actual, IDyadRankingInstance predicted) {
        if (this.k <= 1) {
            throw new IllegalArgumentException("Parameter k must be greater than 1, but was " + this.k + ".");
        }
        int[] predictedPositions = this.lookUpPredictedPositions(actual, predicted);
        int length = predictedPositions.length;
        double kendallsDistance = 0.0;
        for (int actualI = 0; actualI < length - 1; ++actualI) {
            for (int actualJ = actualI + 1; actualJ < length; ++actualJ) {
                kendallsDistance += this.pairPenalty(actualI, predictedPositions[actualI], actualJ, predictedPositions[actualJ]);
            }
        }
        return kendallsDistance;
    }

    /**
     * Finds, for every position of {@code actual}, the position of the same dyad in {@code predicted}.
     * Computed once up front so the pair loop does not repeat the linear search for every pair.
     *
     * @return array whose entry a is the first predicted position whose dyad equals the actual dyad at position a,
     *         or {@link #NOT_RANKED} if there is none
     */
    private int[] lookUpPredictedPositions(IDyadRankingInstance actual, IDyadRankingInstance predicted) {
        int[] positions = new int[actual.length()];
        for (int a = 0; a < positions.length; ++a) {
            Dyad dyad = actual.getDyadAtPosition(a);
            positions[a] = NOT_RANKED;
            for (int q = 0; q < predicted.length(); ++q) {
                if (predicted.getDyadAtPosition(q).equals(dyad)) {
                    positions[a] = q;
                    break;
                }
            }
        }
        return positions;
    }

    /**
     * Computes the penalty for a single pair of dyads given their positions in both rankings.
     * The four cases are mutually exclusive; pairs matching none of them cost nothing.
     */
    private double pairPenalty(int actualI, int predictedI, int actualJ, int predictedJ) {
        boolean iInActual = actualI < this.k;
        boolean jInActual = actualJ < this.k;
        boolean iInPredicted = predictedI < this.k;
        boolean jInPredicted = predictedJ < this.k;

        boolean bothInActual = iInActual && jInActual;
        boolean bothInPredicted = iInPredicted && jInPredicted;
        boolean noneInActual = !iInActual && !jInActual;
        boolean noneInPredicted = !iInPredicted && !jInPredicted;

        // Case 1: both dyads in both top-k lists; penalize only a strict reversal (ties cost nothing).
        if (bothInActual && bothInPredicted) {
            boolean discordant = Integer.compare(actualI, actualJ) * Integer.compare(predictedI, predictedJ) < 0;
            return discordant ? 1.0 : 0.0;
        }
        // Case 2a: both in the actual top k, exactly one in the predicted top k.
        if (bothInActual && iInPredicted != jInPredicted) {
            boolean leaderAlsoPredicted = actualI < actualJ ? iInPredicted : jInPredicted;
            return leaderAlsoPredicted ? 0.0 : 1.0;
        }
        // Case 2b: both in the predicted top k, exactly one in the actual top k.
        if (bothInPredicted && iInActual != jInActual) {
            boolean leaderAlsoActual = predictedI < predictedJ ? iInActual : jInActual;
            return leaderAlsoActual ? 0.0 : 1.0;
        }
        // Case 3: each dyad appears in exactly one top-k list, and they are different lists.
        boolean onlyIActualOnlyJPredicted = iInActual && !jInActual && jInPredicted && !iInPredicted;
        boolean onlyJActualOnlyIPredicted = jInActual && !iInActual && iInPredicted && !jInPredicted;
        if (onlyIActualOnlyJPredicted || onlyJActualOnlyIPredicted) {
            return 1.0;
        }
        // Case 4: both in one top-k list, neither in the other; relative order is unknown.
        if ((bothInActual && noneInPredicted) || (bothInPredicted && noneInActual)) {
            return this.p;
        }
        return 0.0;
    }
}

