/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.classification.singlelabel;

import ai.libs.jaicore.basic.ArrayUtil;
import ai.libs.jaicore.ml.core.evaluation.Prediction;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.classification.singlelabel.evaluation.ISingleLabelClassification;

/**
 * A single-label classification result: one predicted label (handed to the {@link Prediction}
 * superclass at construction) together with a probability vector indexed by label.
 */
public class SingleLabelClassification
extends Prediction
implements ISingleLabelClassification {
    /** Probability per label; index {@code i} holds the probability of label {@code i}. */
    private final double[] labelProbabilities;

    /**
     * Creates a "hard" classification. The predicted label gets probability 1.0 and every other
     * label gets 0.0.
     *
     * @param numClasses total number of labels; this is the length of the probability vector
     * @param predicted the predicted label, expected to lie in {@code [0, numClasses)}
     * @throws ArrayIndexOutOfBoundsException if {@code predicted} is outside {@code [0, numClasses)}
     */
    public SingleLabelClassification(final int numClasses, final int predicted) {
        super(predicted);
        this.labelProbabilities = new double[numClasses];
        this.labelProbabilities[predicted] = 1.0;
    }

    /**
     * Creates a classification from a label-to-probability map.
     *
     * <p>The predicted label is the key with the strictly highest probability. Among equal maxima,
     * the first one in the map's iteration order wins. The probability vector is sized to hold the
     * largest label present, so labels need not be contiguous. Labels missing from the map get
     * probability 0.0.
     *
     * @param labelProbabilities non-empty map from non-negative label index to probability
     * @throws IllegalArgumentException if the map is empty or contains a negative label
     */
    public SingleLabelClassification(final Map<Integer, Double> labelProbabilities) {
        super(labelWithHighestProbability(labelProbabilities));
        // Size by the largest label, not by map size: sparse label sets such as {0, 5}
        // would otherwise index past the end of the array.
        this.labelProbabilities = new double[requiredVectorLength(labelProbabilities)];
        for (Map.Entry<Integer, Double> entry : labelProbabilities.entrySet()) {
            this.labelProbabilities[entry.getKey()] = entry.getValue();
        }
    }

    /**
     * Creates a classification from a dense probability vector indexed by label.
     *
     * <p>The predicted label is the first element of the list returned by
     * {@code ArrayUtil.argMax}. The array is copied, so later changes by the caller cannot make
     * the stored probabilities disagree with the prediction computed here.
     *
     * @param labelProbabilities probability for each label index
     */
    public SingleLabelClassification(final double[] labelProbabilities) {
        super(ArrayUtil.argMax(labelProbabilities).get(0));
        this.labelProbabilities = labelProbabilities.clone();
    }

    /**
     * Returns the predicted label as a primitive int.
     *
     * @return the predicted label stored by the superclass, unboxed from {@link Integer}
     */
    public int getIntPrediction() {
        return (Integer) super.getPrediction();
    }

    /**
     * Returns the predicted label.
     *
     * @return the same value as {@link #getIntPrediction()}, boxed
     */
    @Override
    public Integer getPrediction() {
        return this.getIntPrediction();
    }

    /**
     * Returns the label that was selected as the prediction at construction time.
     *
     * @return the same value as {@link #getIntPrediction()}, boxed
     */
    @Override
    public Integer getLabelWithHighestProbability() {
        return this.getIntPrediction();
    }

    /**
     * Returns a fresh map from every label index {@code 0..n-1} to its probability. Labels with
     * probability 0.0 are included.
     *
     * @return a new, caller-owned map
     */
    public Map<Integer, Double> getClassDistribution() {
        return this.toLabelMap();
    }

    /**
     * Returns the probability of the given label.
     *
     * @param label label index
     * @return the stored probability for {@code label}
     * @throws ArrayIndexOutOfBoundsException if {@code label} is outside the probability vector
     */
    public double getProbabilityOfLabel(final int label) {
        return this.labelProbabilities[label];
    }

    /**
     * Returns a fresh map from every label index {@code 0..n-1} to its probability. The content
     * is identical to {@link #getClassDistribution()}.
     *
     * @return a new, caller-owned map
     */
    public Map<Integer, Double> getClassConfidence() {
        return this.toLabelMap();
    }

    /** Copies the probability vector into a new map keyed by label index. */
    private Map<Integer, Double> toLabelMap() {
        Map<Integer, Double> byLabel = new HashMap<>();
        for (int label = 0; label < this.labelProbabilities.length; label++) {
            byLabel.put(label, this.labelProbabilities[label]);
        }
        return byLabel;
    }

    /**
     * Returns the key with the strictly highest probability. Among ties, the first one in
     * iteration order wins.
     *
     * @throws IllegalArgumentException if the map is empty
     */
    private static int labelWithHighestProbability(final Map<Integer, Double> labelProbabilities) {
        Map.Entry<Integer, Double> best = null;
        for (Map.Entry<Integer, Double> entry : labelProbabilities.entrySet()) {
            // Strict '<' keeps the earliest entry on ties (same as the original condition).
            if (best == null || best.getValue() < entry.getValue()) {
                best = entry;
            }
        }
        if (best == null) {
            throw new IllegalArgumentException("No prediction contained");
        }
        return best.getKey();
    }

    /**
     * Returns one more than the largest label in the map. This is the vector length needed to
     * index every label directly.
     *
     * @throws IllegalArgumentException if any label is negative
     */
    private static int requiredVectorLength(final Map<Integer, Double> labelProbabilities) {
        int maxLabel = -1;
        for (int label : labelProbabilities.keySet()) {
            if (label < 0) {
                throw new IllegalArgumentException("Label indices must be non-negative, but got " + label);
            }
            maxLabel = Math.max(maxLabel, label);
        }
        return maxLabel + 1;
    }
}

