/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.loss.dataset;

import ai.libs.jaicore.basic.aggregate.reals.Mean;
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;
import org.api4.java.ai.ml.core.evaluation.IPredictionAndGroundTruthTable;
import org.api4.java.ai.ml.core.evaluation.execution.IAggregatedPredictionPerformanceMeasure;
import org.api4.java.ai.ml.core.evaluation.supervised.loss.IDeterministicPredictionPerformanceMeasure;
import org.api4.java.common.aggregate.IAggregateFunction;

/**
 * Aggregated classification performance measures.
 *
 * <p>Each constant pairs a per-evaluation base measure with an aggregation function. The base
 * measure is applied separately to each pair of (expected, predicted) lists, or to each prediction
 * table. The resulting values are then combined into a single number by the aggregation function.
 */
public enum EAggregatedClassifierMetric implements IAggregatedPredictionPerformanceMeasure<Integer, ISingleLabelClassification>
{
    /** Error rate computed per evaluation, then averaged with {@link Mean}. */
    MEAN_ERRORRATE(EClassificationPerformanceMeasure.ERRORRATE, (IAggregateFunction<Double>) new Mean());

    /** Base measure applied to each individual evaluation. */
    private final IDeterministicPredictionPerformanceMeasure<Integer, ISingleLabelClassification> lossFunction;
    /** Function combining the per-evaluation losses into one value. */
    private final IAggregateFunction<Double> aggregation;

    EAggregatedClassifierMetric(final IDeterministicPredictionPerformanceMeasure<Integer, ISingleLabelClassification> lossFunction, final IAggregateFunction<Double> aggregation) {
        this.lossFunction = lossFunction;
        this.aggregation = aggregation;
    }

    /**
     * Computes the aggregated loss over several evaluations given as parallel lists.
     *
     * <p>Entry {@code i} of {@code expected} is paired with entry {@code i} of {@code predicted}.
     * The base measure is applied to each pair and the results are aggregated.
     *
     * @param expected ground-truth labels, one inner list per evaluation
     * @param predicted predictions, one inner list per evaluation, aligned with {@code expected}
     * @return the aggregated loss
     * @throws IllegalArgumentException if the two outer lists differ in length
     */
    public double loss(List<List<? extends Integer>> expected, List<List<? extends ISingleLabelClassification>> predicted) {
        final int n = expected.size();
        // Without this check, a longer predicted list would be silently truncated,
        // and a shorter one would fail with an opaque IndexOutOfBoundsException.
        if (predicted.size() != n) {
            throw new IllegalArgumentException("expected and predicted must contain the same number of evaluations, but got " + n + " and " + predicted.size());
        }
        final List<Double> losses = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            losses.add(this.lossFunction.loss(expected.get(i), predicted.get(i)));
        }
        return this.aggregation.aggregate(losses);
    }

    /**
     * Computes the aggregated loss over several prediction/ground-truth tables.
     *
     * @param pairTables one table per evaluation
     * @return the base measure of each table, combined by the aggregation function
     */
    public double loss(List<IPredictionAndGroundTruthTable<? extends Integer, ? extends ISingleLabelClassification>> pairTables) {
        return this.aggregation.aggregate(pairTables.stream().map(table -> this.lossFunction.loss(table)).collect(Collectors.toList()));
    }

    /**
     * Returns {@code 1.0 - loss(expected, predicted)}.
     *
     * <p>NOTE(review): this complement presumably assumes the loss lies in [0, 1], as the error
     * rate does; confirm before adding constants with other ranges.
     *
     * @param expected ground-truth labels, one inner list per evaluation
     * @param predicted predictions, one inner list per evaluation, aligned with {@code expected}
     * @return one minus the aggregated loss
     * @throws IllegalArgumentException if the two outer lists differ in length
     */
    public double score(List<List<? extends Integer>> expected, List<List<? extends ISingleLabelClassification>> predicted) {
        return 1.0 - this.loss(expected, predicted);
    }

    /**
     * Returns {@code 1.0 - loss(pairTables)}.
     *
     * @param pairTables one table per evaluation
     * @return one minus the aggregated loss
     */
    public double score(List<IPredictionAndGroundTruthTable<? extends Integer, ? extends ISingleLabelClassification>> pairTables) {
        return 1.0 - this.loss(pairTables);
    }

    /**
     * Returns the per-evaluation measure that this metric aggregates.
     *
     * @return the base measure
     */
    public IDeterministicPredictionPerformanceMeasure<Integer, ISingleLabelClassification> getBaseMeasure() {
        return this.lossFunction;
    }
}

