/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multilabel.evaluation.loss;

import ai.libs.jaicore.ml.classification.loss.dataset.APredictionPerformanceMeasure;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;
import org.api4.java.ai.ml.classification.multilabel.evaluation.loss.IMultiLabelClassificationPredictionPerformanceMeasure;

/**
 * Base class for multi-label classification performance measures whose expected values are
 * {@code int[]} and whose predictions are {@link IMultiLabelClassification} instances.
 *
 * <p>Holds a decision threshold (0.5 unless specified) and provides helpers to turn lists of
 * predictions into matrices (one row per prediction), to transpose such matrices, and to turn a
 * thresholded prediction into a set.
 */
public abstract class AMultiLabelClassificationMeasure
extends APredictionPerformanceMeasure<int[], IMultiLabelClassification>
implements IMultiLabelClassificationPredictionPerformanceMeasure {
    /** Threshold used by the no-argument constructor. */
    private static final double DEFAULT_THRESHOLD = 0.5;

    /** Threshold passed to {@link IMultiLabelClassification#getPrediction(double)}. */
    private final double threshold;

    /**
     * Creates a measure that thresholds predictions with the given value.
     *
     * @param threshold the threshold handed to {@code getPrediction(double)} when building
     *     thresholded relevance matrices
     */
    protected AMultiLabelClassificationMeasure(double threshold) {
        this.threshold = threshold;
    }

    /** Creates a measure using the default threshold of 0.5. */
    protected AMultiLabelClassificationMeasure() {
        // Delegate so the default value is defined in exactly one place.
        this(DEFAULT_THRESHOLD);
    }

    /**
     * Returns the threshold configured for this measure.
     *
     * @return the threshold passed to {@code getPrediction(double)}
     */
    public double getThreshold() {
        return this.threshold;
    }

    /**
     * Collects the raw {@code getPrediction()} array of every classification into a matrix with
     * one row per classification, in list order.
     *
     * <p>Rows are the arrays returned by {@code getPrediction()} themselves, not copies; whether
     * those arrays are shared with the classification depends on its implementation.
     *
     * @param classificationList the predictions to collect; must not be {@code null}
     * @return a matrix whose row {@code i} is {@code classificationList.get(i).getPrediction()}
     */
    protected double[][] listToRelevanceMatrix(List<IMultiLabelClassification> classificationList) {
        // Stream instead of index-based get(i) so non-random-access lists stay linear.
        return classificationList.stream()
                .map(classification -> classification.getPrediction())
                .toArray(double[][]::new);
    }

    /**
     * Collects the thresholded prediction of every classification into a matrix with one row per
     * classification, in list order, using {@link #getThreshold()} as the threshold.
     *
     * @param classificationList the predictions to collect; must not be {@code null}
     * @return a matrix whose row {@code i} is
     *     {@code classificationList.get(i).getPrediction(getThreshold())}
     */
    protected int[][] listToThresholdedRelevanceMatrix(List<IMultiLabelClassification> classificationList) {
        return classificationList.stream()
                .map(classification -> classification.getPrediction(this.threshold))
                .toArray(int[][]::new);
    }

    /**
     * Returns the transpose of a rectangular {@code int} matrix.
     *
     * @param matrix the matrix to transpose; every row must have the same length
     * @return a new matrix {@code out} with {@code out[c][r] == matrix[r][c]}; an empty matrix if
     *     {@code matrix} has no rows
     * @throws IllegalArgumentException if the rows of {@code matrix} differ in length
     */
    protected int[][] transposeMatrix(int[][] matrix) {
        if (matrix.length == 0) {
            // Previously this indexed matrix[0] and threw ArrayIndexOutOfBoundsException.
            return new int[0][];
        }
        int rows = matrix.length;
        int columns = matrix[0].length;
        int[][] out = new int[columns][rows];
        for (int r = 0; r < rows; ++r) {
            // Reject ragged input instead of failing opaquely or silently dropping values.
            if (matrix[r].length != columns) {
                throw new IllegalArgumentException("Row " + r + " has length " + matrix[r].length
                        + " but row 0 has length " + columns);
            }
            for (int c = 0; c < columns; ++c) {
                out[c][r] = matrix[r][c];
            }
        }
        return out;
    }

    /**
     * Returns the transpose of a rectangular {@code double} matrix.
     *
     * @param matrix the matrix to transpose; every row must have the same length
     * @return a new matrix {@code out} with {@code out[c][r] == matrix[r][c]}; an empty matrix if
     *     {@code matrix} has no rows
     * @throws IllegalArgumentException if the rows of {@code matrix} differ in length
     */
    protected double[][] transposeMatrix(double[][] matrix) {
        if (matrix.length == 0) {
            // Previously this indexed matrix[0] and threw ArrayIndexOutOfBoundsException.
            return new double[0][];
        }
        int rows = matrix.length;
        int columns = matrix[0].length;
        double[][] out = new double[columns][rows];
        for (int r = 0; r < rows; ++r) {
            // Reject ragged input instead of failing opaquely or silently dropping values.
            if (matrix[r].length != columns) {
                throw new IllegalArgumentException("Row " + r + " has length " + matrix[r].length
                        + " but row 0 has length " + columns);
            }
            for (int c = 0; c < columns; ++c) {
                out[c][r] = matrix[r][c];
            }
        }
        return out;
    }

    /**
     * Returns the distinct values of {@code prediction.getThresholdedPrediction()} as a set.
     *
     * <p>NOTE(review): this relies on the prediction's own thresholding, not on
     * {@link #getThreshold()}. Also, if {@code getThresholdedPrediction()} returns a 0/1 vector
     * rather than label indices, the resulting set can only contain 0 and/or 1. Confirm both
     * points against the {@link IMultiLabelClassification} contract.
     *
     * @param prediction the prediction to convert; must not be {@code null}
     * @return a mutable set of the distinct values in the thresholded prediction
     */
    protected Set<Integer> getThresholdedPredictionAsSet(IMultiLabelClassification prediction) {
        return Arrays.stream(prediction.getThresholdedPrediction()).boxed().collect(Collectors.toSet());
    }
}

