/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.multilabel;

import ai.libs.jaicore.ml.core.evaluation.Prediction;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.multilabel.evaluation.IMultiLabelClassification;

public class MultiLabelClassification
extends Prediction
implements IMultiLabelClassification {
    private static final double DEFAULT_THRESHOLD = 0.5;
    private double[] threshold;

    public MultiLabelClassification(double[] predicted) {
        this(predicted, 0.5);
    }

    public MultiLabelClassification(double[] predicted, double threshold) {
        this(predicted, IntStream.range(0, predicted.length).mapToDouble(x -> threshold).toArray());
    }

    public MultiLabelClassification(double[] predicted, double[] threshold) {
        super(predicted);
        this.threshold = threshold;
    }

    public double[] getPrediction() {
        return (double[])super.getPrediction();
    }

    public int[] getThresholdedPrediction() {
        return IntStream.range(0, this.getPrediction().length).map(x -> this.getPrediction()[x] >= this.threshold[x] ? 1 : 0).toArray();
    }

    public int[] getPrediction(double threshold) {
        return IntStream.range(0, this.getPrediction().length).map(x -> this.getPrediction()[x] >= threshold ? 1 : 0).toArray();
    }

    public int[] getPrediction(double[] threshold) {
        return IntStream.range(0, this.getPrediction().length).map(x -> this.getPrediction()[x] >= threshold[x] ? 1 : 0).toArray();
    }

    public int[] getRelevantLabels(double threshold) {
        return IntStream.range(0, this.getPrediction().length).filter(x -> this.getPrediction()[x] >= threshold).toArray();
    }

    public int[] getIrrelevantLabels(double threshold) {
        return IntStream.range(0, this.getPrediction().length).filter(x -> this.getPrediction()[x] < threshold).toArray();
    }
}

