/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.loss.dataset;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.ml.classification.loss.dataset.ASingleLabelClassificationPerformanceMeasure;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;

/**
 * Base class for performance measures defined as the area under a curve. The curve is traced by sweeping a
 * decision threshold over the predicted probabilities of a designated positive class.
 * <p>
 * At every distinct threshold, the instances whose score reaches the threshold are classified as positive. The
 * resulting confusion-matrix counts are mapped to one curve point via {@link #getXValue(int, int, int, int)} and
 * {@link #getYValue(int, int, int, int)}, which subclasses define. The area is then computed with the trapezoidal
 * rule.
 */
public abstract class AAreaUnderCurvePerformanceMeasure extends ASingleLabelClassificationPerformanceMeasure {

    /**
     * Label index whose predicted probability is used as the score, and whose occurrences among the expected
     * labels count as positives.
     */
    private final int positiveClass;

    /**
     * Creates a measure for the given positive class.
     *
     * @param positiveClass label index treated as positive. Its predicted probability is the ranking score, and
     *                      expected labels equal to it are counted as positives.
     */
    public AAreaUnderCurvePerformanceMeasure(int positiveClass) {
        this.positiveClass = positiveClass;
    }

    /**
     * Creates a measure that treats label {@code 0} as the positive class.
     */
    public AAreaUnderCurvePerformanceMeasure() {
        this(0);
    }

    /**
     * Returns the positive class label.
     *
     * @return the positive class. The declared type is {@code Object}; the runtime value is an {@link Integer}.
     */
    public Object getPositiveClass() {
        return this.positiveClass;
    }

    /**
     * Pairs each instance's predicted probability of the positive class with its expected label, sorted by
     * descending probability.
     * <p>
     * The sort is stable, so instances with equal probability keep their original relative order.
     *
     * @param expected  the expected (ground-truth) labels
     * @param predicted the predictions, index-aligned with {@code expected} (must have at least as many entries)
     * @return a new mutable list of (probability of positive class, expected label) pairs, highest probability first
     */
    public List<Pair<Double, Integer>> getPredictionList(List<? extends Integer> expected, List<? extends ISingleLabelClassification> predicted) {
        List<Pair<Double, Integer>> predictionsList = new ArrayList<>(expected.size());
        for (int i = 0; i < expected.size(); i++) {
            predictionsList.add(new Pair<Double, Integer>(predicted.get(i).getProbabilityOfLabel(this.positiveClass), expected.get(i)));
        }
        predictionsList.sort((a, b) -> Double.compare(b.getX(), a.getX()));
        return predictionsList;
    }

    /**
     * Computes the area under a piecewise-linear curve using the trapezoidal rule.
     * <p>
     * The points are used in the given order and are not sorted. A segment whose x-coordinate decreases therefore
     * contributes negative area.
     *
     * @param curveCoordinates the curve points as (x, y) pairs, in tracing order
     * @return the accumulated area, or {@code 0.0} if fewer than two points are given
     */
    protected double getAreaUnderCurve(List<Pair<Double, Double>> curveCoordinates) {
        double area = 0.0;
        for (int i = 1; i < curveCoordinates.size(); ++i) {
            Pair<Double, Double> prev = curveCoordinates.get(i - 1);
            Pair<Double, Double> cur = curveCoordinates.get(i);
            double prevY = prev.getY();
            double deltaX = cur.getX() - prev.getX();
            double deltaY = cur.getY() - prevY;
            // Rectangle under the lower point plus the triangle up to the next point (= trapezoid).
            area += prevY * deltaX + deltaX * deltaY / 2.0;
        }
        return area;
    }

    /**
     * Computes the area under the curve defined by {@link #getXValue} and {@link #getYValue}.
     * <p>
     * The sweep starts with every instance classified as negative (tp = fp = 0). It then lowers the threshold
     * through the distinct predicted probabilities. All instances tied at the current threshold are switched to
     * positive together, and one curve point is recorded per threshold.
     *
     * @param expected  the expected (ground-truth) labels
     * @param predicted the predictions, index-aligned with {@code expected}
     * @return the area under the traced curve
     */
    @Override
    public double score(List<? extends Integer> expected, List<? extends ISingleLabelClassification> predicted) {
        this.checkConsistency(expected, predicted);
        List<Pair<Double, Integer>> predictionsList = this.getPredictionList(expected, predicted);
        int n = predictionsList.size();
        List<Pair<Double, Double>> curveCoordinates = new ArrayList<>(n + 1);

        int tp = 0;
        int fp = 0;
        int fn = (int) predictionsList.stream().filter(p -> this.isPositive(p.getY())).count();
        int tn = n - fn;
        curveCoordinates.add(this.curvePoint(tp, fp, tn, fn));

        int index = 0;
        while (index < n) {
            double threshold = predictionsList.get(index).getX();
            // Always consume at least one prediction so the sweep terminates even for NaN scores
            // (NaN >= NaN is false); then absorb every following prediction tied at this threshold.
            do {
                if (this.isPositive(predictionsList.get(index).getY())) {
                    ++tp;
                    --fn;
                } else {
                    ++fp;
                    --tn;
                }
                ++index;
            } while (index < n && predictionsList.get(index).getX() >= threshold);
            curveCoordinates.add(this.curvePoint(tp, fp, tn, fn));
        }
        return this.getAreaUnderCurve(curveCoordinates);
    }

    /**
     * Tells whether an expected label denotes the positive class.
     * <p>
     * The comparison is by value. Using {@code ==} on boxed {@code Integer}s would compare references, which only
     * happens to work inside the {@code Integer} cache range (-128..127).
     *
     * @param label the expected label (may be {@code null}, which is treated as negative)
     * @return {@code true} iff {@code label} equals {@link #getPositiveClass()}
     */
    private boolean isPositive(Integer label) {
        return label != null && label.equals(this.getPositiveClass());
    }

    /**
     * Builds the curve point for the given confusion-matrix counts.
     *
     * @return the (x, y) pair from {@link #getXValue} and {@link #getYValue}
     */
    private Pair<Double, Double> curvePoint(int tp, int fp, int tn, int fn) {
        return new Pair<Double, Double>(this.getXValue(tp, fp, tn, fn), this.getYValue(tp, fp, tn, fn));
    }

    /**
     * Maps confusion-matrix counts at one threshold to the x-coordinate of the curve.
     *
     * @param tp number of true positives
     * @param fp number of false positives
     * @param tn number of true negatives
     * @param fn number of false negatives
     * @return the x-coordinate of the curve point
     */
    public abstract double getXValue(int tp, int fp, int tn, int fn);

    /**
     * Maps confusion-matrix counts at one threshold to the y-coordinate of the curve.
     *
     * @param tp number of true positives
     * @param fp number of false positives
     * @param tn number of true negatives
     * @param fn number of false negatives
     * @return the y-coordinate of the curve point
     */
    public abstract double getYValue(int tp, int fp, int tn, int fn);
}

