/*
 * Decompiled with CFR 0.152.
 */
package com.tencent.angel.ml.core.utils;

import com.tencent.angel.exception.AngelException;
import com.tencent.angel.ml.core.graphsubmit.GraphPredictResult;
import com.tencent.angel.ml.core.graphsubmit.SoftmaxPredictResult;
import com.tencent.angel.ml.core.optimizer.loss.LossFunc;
import com.tencent.angel.ml.feature.LabeledData;
import com.tencent.angel.ml.model.MLModel;
import com.tencent.angel.ml.predict.PredictResult;
import com.tencent.angel.utils.Sort;
import com.tencent.angel.worker.storage.DataBlock;
import it.unimi.dsi.fastutil.doubles.DoubleComparator;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple5;

public class ValidationUtils {
    private static final Log LOG = LogFactory.getLog(ValidationUtils.class);
    private static DoubleComparator cmp = new DoubleComparator(){

        public int compare(double i, double i1) {
            if (Math.abs(i - i1) < 1.0E-11) {
                return 0;
            }
            return i - i1 > 1.0E-11 ? 1 : -1;
        }

        public int compare(Double o1, Double o2) {
            if (Math.abs(o1 - o2) < 1.0E-11) {
                return 0;
            }
            return o1 - o2 > 1.0E-11 ? 1 : -1;
        }
    };
    private DataBlock<?> predicted;
    private double[] labels;
    private int totalNum;

    public ValidationUtils(DataBlock<LabeledData> dataBlock, MLModel model) {
        long startTime = System.currentTimeMillis();
        this.totalNum = dataBlock.size();
        this.predicted = model.predict(dataBlock);
        this.labels = new double[this.totalNum];
        try {
            for (int i = 0; i < this.totalNum; ++i) {
                this.labels[i] = ((LabeledData)dataBlock.loopingRead()).getY();
            }
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        long cost = System.currentTimeMillis() - startTime;
        LOG.debug((Object)String.format("validate samples is %s, and the cost is %d ms", this.totalNum, cost));
    }

    public double calLossPrecision(LossFunc lossFunc) throws IOException, InterruptedException {
        long startTime = System.currentTimeMillis();
        double loss = 0.0;
        int truePos = 0;
        int falsePos = 0;
        int trueNeg = 0;
        int falseNeg = 0;
        for (int i = 0; i < this.totalNum; ++i) {
            PredictResult predRes = (PredictResult)this.predicted.get(i);
            if (predRes.pred() * this.labels[i] >= 0.0) {
                if (predRes.pred() > 0.0) {
                    ++truePos;
                } else {
                    ++trueNeg;
                }
            } else if (predRes.pred() > 0.0) {
                ++falsePos;
            } else {
                ++falseNeg;
            }
            loss += lossFunc.loss(predRes.pred(), this.labels[i]);
        }
        long cost = System.currentTimeMillis() - startTime;
        double precision = (double)(truePos + trueNeg) / (double)this.totalNum;
        double trueRecall = (double)truePos / (double)(truePos + falseNeg);
        double falseRecall = (double)trueNeg / (double)(trueNeg + falsePos);
        LOG.debug((Object)String.format("validate cost %d ms, loss= %.5f, precision=%.5f, trueRecall=%.5f, falseRecall=%.5f", cost, loss, precision, trueRecall, falseRecall));
        LOG.debug((Object)String.format("Validation TP=%d, TN=%d, FP=%d, FN=%d", truePos, trueNeg, falsePos, falseNeg));
        return loss;
    }

    public Tuple2<Double, Double> calMulMetrics(LossFunc lossFunc) throws IOException, InterruptedException {
        long startTime = System.currentTimeMillis();
        double loss = 0.0;
        HashMap<Double, Double> truePos = new HashMap<Double, Double>();
        for (int i = 0; i < this.totalNum; ++i) {
            PredictResult predRes = (PredictResult)this.predicted.get(i);
            if (predRes.label() == this.labels[i]) {
                double count = truePos.getOrDefault(this.labels[i], 0.0);
                truePos.put(this.labels[i], count + 1.0);
            }
            if (predRes instanceof GraphPredictResult) {
                loss += lossFunc.loss(predRes.proba(), this.labels[i]);
                continue;
            }
            if (predRes instanceof SoftmaxPredictResult) {
                loss += lossFunc.loss(((SoftmaxPredictResult)predRes).trueProba(), this.labels[i]);
                continue;
            }
            throw new AngelException("PredictResult Error!");
        }
        long cost = System.currentTimeMillis() - startTime;
        double sum = 0.0;
        Iterator iterator = truePos.values().iterator();
        while (iterator.hasNext()) {
            double count = (Double)iterator.next();
            sum += count;
        }
        double accuracy = sum / (double)this.totalNum;
        LOG.debug((Object)String.format("validate cost %d ms, loss= %.5f, accuracy=%.5f", cost, loss, accuracy));
        return new Tuple2((Object)loss, (Object)accuracy);
    }

    public Tuple5<Double, Double, Double, Double, Double> calMetrics(LossFunc lossFunc) throws IOException, InterruptedException {
        LOG.debug((Object)("Start calculate loss and auc, sample number: " + this.totalNum));
        long startTime = System.currentTimeMillis();
        double loss = 0.0;
        double[] scoresArray = new double[this.totalNum];
        double[] labelsArray = new double[this.totalNum];
        double truePos = 0.0;
        double falsePos = 0.0;
        double trueNeg = 0.0;
        double falseNeg = 0.0;
        for (int i = 0; i < this.totalNum; ++i) {
            PredictResult predRes = (PredictResult)this.predicted.get(i);
            if (predRes.pred() * this.labels[i] >= 0.0) {
                if (predRes.pred() > 0.0) {
                    truePos += 1.0;
                } else {
                    trueNeg += 1.0;
                }
            } else if (predRes.pred() > 0.0) {
                falsePos += 1.0;
            } else {
                falseNeg += 1.0;
            }
            scoresArray[i] = predRes.proba();
            labelsArray[i] = this.labels[i];
            loss += lossFunc.loss(predRes.pred(), this.labels[i]);
        }
        double precision = (truePos + trueNeg) / (double)this.totalNum;
        Tuple3<Double, Double, Double> tuple3 = this.calAUC(scoresArray, labelsArray, truePos, trueNeg, falsePos, falseNeg);
        double aucResult = (Double)tuple3._1();
        double trueRecall = (Double)tuple3._2();
        double falseRecall = (Double)tuple3._3();
        long cost = System.currentTimeMillis() - startTime;
        LOG.debug((Object)String.format("validate cost %d ms, loss= %.5f, precision=%.5f, trueRecall=%.5f, falseRecall=%.5f", cost, loss, precision, trueRecall, falseRecall));
        return new Tuple5((Object)loss, (Object)precision, (Object)aucResult, (Object)trueRecall, (Object)falseRecall);
    }

    private Tuple3<Double, Double, Double> calAUC(double[] scoresArray, double[] labelsArray, double truePos, double trueNeg, double falsePos, double falseNeg) {
        long startTime = System.currentTimeMillis();
        Sort.quickSort((double[])scoresArray, (double[])labelsArray, (int)0, (int)this.totalNum, (DoubleComparator)cmp);
        LOG.debug((Object)("Sort cost " + (System.currentTimeMillis() - startTime) + "ms, Scores list size: " + scoresArray.length + ", sorted values:" + scoresArray[0] + "," + scoresArray[scoresArray.length / 5] + "," + scoresArray[scoresArray.length / 3] + "," + scoresArray[scoresArray.length / 2] + "," + scoresArray[scoresArray.length - 1]));
        long M = 1L;
        long N = 1L;
        for (int i = 0; i < this.totalNum; ++i) {
            if (labelsArray[i] == 1.0) {
                ++M;
                continue;
            }
            ++N;
        }
        double sigma = 0.0;
        for (long i = (long)(this.totalNum - 1); i >= 0L; --i) {
            if (labelsArray[(int)i] != 1.0) continue;
            sigma += (double)(i + 1L);
        }
        double aucResult = (sigma - (double)((M + 1L) * M / 2L)) / (double)M / (double)N;
        LOG.debug((Object)("M = " + M + ", N = " + N + ", sigma = " + sigma + ", AUC = " + aucResult));
        double trueRecall = truePos / (truePos + falseNeg);
        double falseRecall = trueNeg / (trueNeg + falsePos);
        LOG.debug((Object)String.format("validate cost %d ms, auc=%.3f, trueRecall=%.3f, falseRecall=%.3f", System.currentTimeMillis() - startTime, aucResult, trueRecall, falseRecall));
        LOG.debug((Object)String.format("Validation TP=%.0f, TN=%.0f, FP=%.0f, FN=%.0f", truePos, trueNeg, falsePos, falseNeg));
        return new Tuple3((Object)aucResult, (Object)trueRecall, (Object)falseRecall);
    }

    public Tuple4<Double, Double, Double, Double> calMSER2() throws IOException, InterruptedException {
        long startTime = System.currentTimeMillis();
        double uLoss = 0.0;
        double maeLossSum = 0.0;
        double trueSum2 = 0.0;
        double trueSum = 0.0;
        for (int i = 0; i < this.totalNum; ++i) {
            PredictResult predRes = (PredictResult)this.predicted.get(i);
            uLoss += Math.pow(predRes.pred() - this.labels[i], 2.0);
            maeLossSum += Math.abs(predRes.pred() - this.labels[i]);
            trueSum2 += Math.pow(this.labels[i], 2.0);
            trueSum += this.labels[i];
        }
        double MSE = uLoss / (double)this.totalNum;
        double RMSE = Math.sqrt(MSE);
        double MAE = maeLossSum / (double)this.totalNum;
        double trueAvg = trueSum / (double)this.totalNum;
        double R2 = 1.0 - uLoss / (trueSum2 - (double)this.totalNum * trueAvg * trueAvg);
        LOG.info((Object)String.format("validate %d samples cost %d ms, MSE= %.5f ,RMSE= %.5f ,MAE=%.5f ,R2= %.5f", this.totalNum, System.currentTimeMillis() - startTime, MSE, RMSE, MAE, R2));
        return new Tuple4((Object)MSE, (Object)RMSE, (Object)MAE, (Object)R2);
    }
}

