/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Triple;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCArraySerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/**
 * Multi-class ROC evaluation using a one-vs-all approach.
 *
 * <p>One {@link ROC} instance is kept per output column. On each {@link #eval} call, column
 * {@code i} of the labels is evaluated against column {@code i} of the predictions by the i-th
 * {@link ROC}. Per-class AUROC/AUPRC, curves and positive/negative counts are then available by
 * class index. {@link #calculateAverageAUC()} and {@link #calculateAverageAUCPR()} return the
 * unweighted mean over all classes.
 */
public class ROCMultiClass
extends BaseEvaluation<ROCMultiClass> {
    /** Number of decimal places used by {@link #stats()}. */
    public static final int DEFAULT_STATS_PRECISION = 4;
    /**
     * Passed to every underlying {@link ROC}. {@link #stats(int)} reports thresholded mode when
     * this is greater than 0.
     */
    private int thresholdSteps;
    /** Passed to every underlying {@link ROC}. */
    private boolean rocRemoveRedundantPts;
    /** One ROC per class. This is null until the first {@link #eval} call, and again after {@link #reset()}. */
    @JsonSerialize(using=ROCArraySerializer.class)
    private ROC[] underlying;
    /** Optional class names used by {@link #stats(int)}. May be null. */
    private List<String> labels;
    /**
     * Axis forwarded to {@code BaseEvaluation.reshapeAndExtractNotMasked}. It presumably selects
     * the class dimension of the label/prediction arrays (TODO confirm). It is not part of
     * {@link #equals}/{@link #hashCode}.
     */
    protected int axis = 1;

    /**
     * Fully-specified constructor, used by {@link #newInstance()}.
     *
     * @param axis                  axis forwarded to the reshape/mask helper
     * @param thresholdSteps        threshold steps for each underlying {@link ROC}
     * @param rocRemoveRedundantPts redundant-point removal flag for each underlying {@link ROC}
     * @param labels                optional class names (may be null)
     */
    protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) {
        this.thresholdSteps = thresholdSteps;
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
        this.axis = axis;
        this.labels = labels;
    }

    /** Creates an evaluation with {@code thresholdSteps = 0} and redundant-point removal enabled. */
    public ROCMultiClass() {
        this(0);
    }

    /**
     * Creates an evaluation with redundant-point removal enabled.
     *
     * @param thresholdSteps threshold steps for each underlying {@link ROC}
     */
    public ROCMultiClass(int thresholdSteps) {
        this(thresholdSteps, true);
    }

    /**
     * @param thresholdSteps        threshold steps for each underlying {@link ROC}
     * @param rocRemoveRedundantPts redundant-point removal flag for each underlying {@link ROC}
     */
    public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts) {
        this.thresholdSteps = thresholdSteps;
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    public void setAxis(int axis) {
        this.axis = axis;
    }

    public int getAxis() {
        return this.axis;
    }

    /** Discards all collected data. The per-class ROC instances are recreated on the next {@link #eval}. */
    @Override
    public void reset() {
        this.underlying = null;
    }

    /** Returns {@link #stats(int)} using {@link #DEFAULT_STATS_PRECISION} decimal places. */
    @Override
    public String stats() {
        return this.stats(DEFAULT_STATS_PRECISION);
    }

    /**
     * Builds a text table with one row per class (label, AUC, #positive, #negative).
     * The table is followed by the average AUC on its own line.
     *
     * @param printPrecision number of decimal places for AUC values
     * @return the formatted statistics, or a "No Data" marker if nothing has been evaluated
     */
    public String stats(int printPrecision) {
        StringBuilder sb = new StringBuilder();
        // The label column is at least 15 characters wide, and wider if any label name is longer.
        int maxLabelsLength = 15;
        if (this.labels != null) {
            for (String s : this.labels) {
                maxLabelsLength = Math.max(s.length(), maxLabelsLength);
            }
        }
        String patternHeader = "%-" + (maxLabelsLength + 5) + "s%-12s%-10s%-10s";
        String header = String.format(patternHeader, "Label", "AUC", "# Pos", "# Neg");
        String pattern = "%-" + (maxLabelsLength + 5) + "s%-12." + printPrecision + "f%-10d%-10d";
        sb.append(header);
        if (this.underlying == null) {
            sb.append("\n-- No Data --\n");
            return sb.toString();
        }
        for (int i = 0; i < this.underlying.length; ++i) {
            double auc = this.calculateAUC(i);
            // Fall back to the class index when no label names, or too few of them, were provided.
            String label = (this.labels != null && i < this.labels.size()) ? this.labels.get(i) : String.valueOf(i);
            sb.append("\n").append(String.format(pattern, label, auc, this.getCountActualPositive(i), this.getCountActualNegative(i)));
        }
        // Put the summary on its own line instead of appending it to the last table row.
        sb.append("\nAverage AUC: ").append(String.format("%-12." + printPrecision + "f", this.calculateAverageAUC()));
        if (this.thresholdSteps > 0) {
            sb.append("\n");
            sb.append("[Note: Thresholded AUC/AUPRC calculation used with ").append(this.thresholdSteps).append(" steps; accuracy may be reduced compared to exact mode]");
        }
        return sb.toString();
    }

    /**
     * Evaluates one batch. Each label/prediction column is fed to the corresponding per-class ROC.
     *
     * @throws IllegalStateException    if per-output masking remains after reshaping (not supported)
     * @throws IllegalArgumentException if the number of classes differs from a previous call
     */
    @Override
    public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData) {
        Triple<INDArray, INDArray, INDArray> p = BaseEvaluation.reshapeAndExtractNotMasked(labels, predictions, mask, this.axis);
        if (p == null) {
            // The helper returned nothing to evaluate for this call.
            return;
        }
        INDArray labels2d = p.getFirst();
        INDArray predictions2d = p.getSecond();
        INDArray maskArray = p.getThird();
        Preconditions.checkState(maskArray == null, "Per-output masking for ROCMultiClass is not supported");
        // Make labels share the predictions' data type before handing both to ROC.
        if (labels2d.dataType() != predictions2d.dataType()) {
            labels2d = labels2d.castTo(predictions2d.dataType());
        }
        int n = (int)labels2d.size(1);
        // Lazily create one ROC per class on the first call.
        if (this.underlying == null) {
            this.underlying = new ROC[n];
            for (int i = 0; i < n; ++i) {
                this.underlying[i] = new ROC(this.thresholdSteps, this.rocRemoveRedundantPts);
            }
        }
        if ((long)this.underlying.length != labels2d.size(1)) {
            throw new IllegalArgumentException("Cannot evaluate data: number of label classes does not match previous call. Got " + labels2d.size(1) + " labels (from array shape " + Arrays.toString(labels2d.shape()) + ") vs. expected number of label classes = " + this.underlying.length);
        }
        for (int i = 0; i < n; ++i) {
            INDArray prob = predictions2d.getColumn(i, true);
            INDArray label = labels2d.getColumn(i, true);
            // Guard against scalar results so ROC always receives a 2d column.
            if (prob.rank() == 0) {
                prob = prob.reshape(1L, 1L);
            }
            if (label.rank() == 0) {
                label = label.reshape(1L, 1L);
            }
            this.underlying[i].eval(label, prob);
        }
    }

    /** @return the ROC curve for class {@code classIdx} */
    public RocCurve getRocCurve(int classIdx) {
        this.assertIndex(classIdx);
        return this.underlying[classIdx].getRocCurve();
    }

    /** @return the precision/recall curve for class {@code classIdx} */
    public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx) {
        this.assertIndex(classIdx);
        return this.underlying[classIdx].getPrecisionRecallCurve();
    }

    /** @return the area under the ROC curve for class {@code classIdx} */
    public double calculateAUC(int classIdx) {
        this.assertIndex(classIdx);
        return this.underlying[classIdx].calculateAUC();
    }

    /** @return the area under the precision/recall curve for class {@code classIdx} */
    public double calculateAUCPR(int classIdx) {
        this.assertIndex(classIdx);
        return this.underlying[classIdx].calculateAUCPR();
    }

    /**
     * @return the unweighted mean of the per-class AUROC values
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if there are zero classes
     */
    public double calculateAverageAUC() {
        this.assertIndex(0);
        double sum = 0.0;
        for (int i = 0; i < this.underlying.length; ++i) {
            sum += this.calculateAUC(i);
        }
        return sum / (double)this.underlying.length;
    }

    /**
     * @return the unweighted mean of the per-class AUPRC values
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if there are zero classes
     */
    public double calculateAverageAUCPR() {
        // Same guard as calculateAverageAUC(). Without it, the loop throws a bare NPE when no data exists.
        this.assertIndex(0);
        double sum = 0.0;
        for (int i = 0; i < this.underlying.length; ++i) {
            sum += this.calculateAUCPR(i);
        }
        return sum / (double)this.underlying.length;
    }

    /** @return the number of actual positives seen for class {@code outputNum} */
    public long getCountActualPositive(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].getCountActualPositive();
    }

    /** @return the number of actual negatives seen for class {@code outputNum} */
    public long getCountActualNegative(int outputNum) {
        this.assertIndex(outputNum);
        return this.underlying[outputNum].getCountActualNegative();
    }

    /**
     * Merges another evaluation's per-class ROC data into this one.
     *
     * @throws UnsupportedOperationException if both sides have data for a different number of classes
     */
    @Override
    public void merge(ROCMultiClass other) {
        if (this.underlying == null) {
            // NOTE(review): this adopts other's array by reference, so later merges into this
            // instance also mutate other's ROC objects. Confirm callers do not reuse 'other'.
            this.underlying = other.underlying;
            return;
        }
        if (other.underlying == null) {
            return;
        }
        if (this.underlying.length != other.underlying.length) {
            throw new UnsupportedOperationException("Cannot merge ROCMultiClass: this expects " + this.underlying.length + " outputs, other expects " + other.underlying.length + " outputs");
        }
        for (int i = 0; i < this.underlying.length; ++i) {
            this.underlying[i].merge(other.underlying[i]);
        }
    }

    /** @return the number of classes seen so far, or -1 if no data has been collected */
    public int getNumClasses() {
        if (this.underlying == null) {
            return -1;
        }
        return this.underlying.length;
    }

    /**
     * Verifies that data exists and that {@code classIdx} is a valid class index.
     *
     * @throws IllegalStateException    if no data has been collected
     * @throws IllegalArgumentException if {@code classIdx} is out of range
     */
    private void assertIndex(int classIdx) {
        if (this.underlying == null) {
            throw new IllegalStateException("Cannot get results: no data has been collected");
        }
        if (classIdx < 0 || classIdx >= this.underlying.length) {
            throw new IllegalArgumentException("Invalid class index (" + classIdx + "): must be in range 0 to numClasses = " + this.underlying.length);
        }
    }

    /** Deserializes a {@code ROCMultiClass} from JSON via the inherited {@code fromJson(String, Class)}. */
    public static ROCMultiClass fromJson(String json) {
        return ROCMultiClass.fromJson(json, ROCMultiClass.class);
    }

    /**
     * @param metric metric to compute
     * @param idx    class index
     * @return the per-class value of {@code metric}
     */
    public double scoreForMetric(Metric metric, int idx) {
        this.assertIndex(idx);
        switch (metric) {
            case AUROC: {
                return this.calculateAUC(idx);
            }
            case AUPRC: {
                return this.calculateAUCPR(idx);
            }
        }
        throw new IllegalStateException("Unknown metric: " + metric);
    }

    /**
     * Returns the class-averaged value for a {@link Metric}.
     *
     * @throws IllegalStateException if {@code metric} is not a {@link Metric} of this class
     */
    @Override
    public double getValue(IMetric metric) {
        if (metric == Metric.AUPRC) {
            return this.calculateAverageAUCPR();
        }
        if (metric == Metric.AUROC) {
            return this.calculateAverageAUC();
        }
        throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
    }

    /** @return a new, empty evaluation with the same configuration (axis, thresholds, labels) */
    @Override
    public ROCMultiClass newInstance() {
        return new ROCMultiClass(this.axis, this.thresholdSteps, this.rocRemoveRedundantPts, this.labels);
    }

    public int getThresholdSteps() {
        return this.thresholdSteps;
    }

    public boolean isRocRemoveRedundantPts() {
        return this.rocRemoveRedundantPts;
    }

    public ROC[] getUnderlying() {
        return this.underlying;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    public void setThresholdSteps(int thresholdSteps) {
        this.thresholdSteps = thresholdSteps;
    }

    public void setRocRemoveRedundantPts(boolean rocRemoveRedundantPts) {
        this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    }

    public void setUnderlying(ROC[] underlying) {
        this.underlying = underlying;
    }

    public void setLabels(List<String> labels) {
        this.labels = labels;
    }

    @Override
    public String toString() {
        return "ROCMultiClass(thresholdSteps=" + this.getThresholdSteps() + ", rocRemoveRedundantPts=" + this.isRocRemoveRedundantPts() + ", underlying=" + Arrays.deepToString(this.getUnderlying()) + ", labels=" + this.getLabels() + ", axis=" + this.getAxis() + ")";
    }

    /**
     * Equality compares the superclass state, thresholdSteps, rocRemoveRedundantPts, the
     * per-class ROC array (deeply) and the labels. {@code axis} is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ROCMultiClass)) {
            return false;
        }
        ROCMultiClass other = (ROCMultiClass)o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        if (this.getThresholdSteps() != other.getThresholdSteps()) {
            return false;
        }
        if (this.isRocRemoveRedundantPts() != other.isRocRemoveRedundantPts()) {
            return false;
        }
        if (!Arrays.deepEquals(this.getUnderlying(), other.getUnderlying())) {
            return false;
        }
        List<String> thisLabels = this.getLabels();
        List<String> otherLabels = other.getLabels();
        return thisLabels == null ? otherLabels == null : thisLabels.equals(otherLabels);
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof ROCMultiClass;
    }

    /** Hash over the same fields as {@link #equals}. {@code axis} is excluded. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = super.hashCode();
        result = result * PRIME + this.getThresholdSteps();
        result = result * PRIME + (this.isRocRemoveRedundantPts() ? 79 : 97);
        result = result * PRIME + Arrays.deepHashCode(this.getUnderlying());
        List<String> labelList = this.getLabels();
        result = result * PRIME + (labelList == null ? 43 : labelList.hashCode());
        return result;
    }

    /** Metrics supported by {@link ROCMultiClass}. Both are to be maximized. */
    public static enum Metric implements IMetric
    {
        AUROC,
        AUPRC;


        @Override
        public Class<? extends IEvaluation> getEvaluationClass() {
            return ROCMultiClass.class;
        }

        @Override
        public boolean minimize() {
            return false;
        }
    }
}

