/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.models;

import hex.AUC2;
import hex.Model;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.HypergeometricDistribution;
import org.apache.commons.math3.stat.inference.GTest;
import water.DKV;
import water.Key;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValMapFrame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

public class AstFairnessMetrics
extends AstPrimitive {
    /**
     * Lists the names of the arguments this primitive takes, in the order they are passed.
     *
     * @return the five argument names: model, test frame, protected columns, reference and favourable class
     */
    @Override
    public String[] args() {
        String[] argumentNames = {"model", "test_frame", "protected_columns", "reference", "favourable_class"};
        return argumentNames;
    }

    /**
     * Reports the number of AST slots of a call to this primitive.
     *
     * @return 6, i.e. one more than the five named arguments; {@code apply} reads them from
     *         {@code asts[1]} through {@code asts[5]}
     */
    @Override
    public int nargs() {
        final int slotCount = 1 + 5;
        return slotCount;
    }

    /**
     * Gives the name of this primitive.
     *
     * @return the literal {@code "fairnessMetrics"}
     */
    @Override
    public String str() {
        final String primitiveName = "fairnessMetrics";
        return primitiveName;
    }

    /**
     * Scores the test frame with the model and computes per-group fairness metrics for every
     * combination of levels of the protected columns.
     *
     * <p>Arguments are read from {@code asts[1]} (model), {@code asts[2]} (test frame),
     * {@code asts[3]} (protected column names), {@code asts[4]} (reference level per protected
     * column) and {@code asts[5]} (favourable response class). When the number of reference levels
     * differs from the number of protected columns the reference is ignored and chosen later by
     * {@link FairnessMRTask#getMetrics}.
     *
     * @return a map of frames: {@code "overview"} holds the metrics table, the remaining entries are
     *         the per-group threshold tables produced by {@link FairnessMRTask#getROCInfo}
     * @throws H2OIllegalArgumentException if the model is not a binomial classifier or a protected
     *         column is not categorical
     * @throws RuntimeException if a protected column or the response column is missing from the
     *         frame, a reference level or the favourable class is not in the respective domain, or
     *         the protected columns span more than 1e6 level combinations
     */
    @Override
    public ValMapFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Model model = stk.track(asts[1].exec(env)).getModel();
        Frame fr = stk.track(asts[2].exec(env).getFrame());
        String[] protectedCols = stk.track(asts[3].exec(env)).getStrs();
        String[] reference = stk.track(asts[4].exec(env)).getStrs();
        String favourableClass = stk.track(asts[5].exec(env)).getStr();
        String frameName = asts[2].str();
        String responseColumn = ((Model.Parameters)model._parms)._response_column;
        int responseIdx = fr.find(responseColumn);
        if (!((Model.Output)model._output).isBinomialClassifier()) {
            throw new H2OIllegalArgumentException("Model has to be a binomial model!");
        }
        for (String pc : protectedCols) {
            if (fr.find(pc) == -1) {
                throw new RuntimeException(pc + " was not found in the frame!");
            }
            if (!fr.vec(pc).isCategorical()) {
                throw new H2OIllegalArgumentException(pc + " has to be a categorical column!");
            }
        }
        if (reference.length != protectedCols.length) {
            // No (or an incomplete) reference was supplied; getMetrics picks one itself.
            reference = null;
        } else {
            for (int i = 0; i < protectedCols.length; ++i) {
                if (!ArrayUtils.contains(fr.vec(protectedCols[i]).domain(), reference[i])) {
                    throw new RuntimeException("Reference group is not present in the protected column");
                }
            }
        }
        // The response is looked up by index below; fail with a clear message instead of an
        // opaque index error when the test frame does not carry the response column.
        if (responseIdx == -1) {
            throw new RuntimeException(responseColumn + " was not found in the frame!");
        }
        String[] responseDomain = fr.vec(responseIdx).domain();
        if (!ArrayUtils.contains(responseDomain, favourableClass)) {
            throw new RuntimeException("Favourable class is not present in the response!");
        }
        int favorableClassId = ArrayUtils.find(responseDomain, favourableClass);
        int[] protectedColsIdx = fr.find(protectedCols);
        // One extra level per column is reserved for missing values.
        int[] cardinalities = IntStream.of(protectedColsIdx).map(colId -> fr.vec(colId).cardinality() + 1).toArray();
        double combinations = Arrays.stream(cardinalities).asDoubleStream().reduce((a, b) -> a * b).orElse(Double.MAX_VALUE);
        if (combinations > 1000000.0) {
            throw new RuntimeException("Too many combinations of categories! Maximum number of category combinations is 1e6.");
        }
        // The scored columns are appended after the original ones, so the first of them sits at
        // index fr.numCols() of the combined frame.
        Frame predictions = new Frame(fr).add(model.score(fr));
        DKV.put(predictions);
        try {
            FairnessMRTask fairnessMRTask = (FairnessMRTask)new FairnessMRTask(protectedColsIdx, cardinalities, responseIdx, fr.numCols(), favorableClassId).doAll(predictions);
            Frame metrics = fairnessMRTask.getMetrics(protectedCols, fr, model, reference, frameName);
            Map<String, Frame> results = fairnessMRTask.getROCInfo(model, fr, frameName);
            DKV.put(metrics);
            results.put("overview", metrics);
            return new ValMapFrame(results);
        }
        finally {
            DKV.remove(predictions.getKey());
        }
    }

    /**
     * Map/reduce task that accumulates, for every combination of protected-column levels (a "group"),
     * the confusion-matrix counts, the summed log loss and an {@link AUC2.AUCBuilder}, and then turns
     * those accumulators into result frames.
     *
     * <p>Groups are addressed by an integer key that encodes the level of each protected column in a
     * mixed-radix number whose radices are {@code cardinalities}; the last level of every column
     * ({@code cardinality - 1}) stands for a missing value.
     */
    public static class FairnessMRTask
    extends MRTask {
        /** Both group sizes must be below this for the p-value to be forced onto Fisher's exact test. */
        public static final int GTEST_THRESHOLD = 10000;
        /** Relative tolerance used when deciding whether a table is "at most as likely" as the observed one. */
        public static final double FISHER_TEST_REL_ERROR = 1.0000001;
        /** Frame column indices of the protected columns. */
        int[] protectedColsIdx;
        /** Number of levels per protected column, including the extra slot used for missing values. */
        int[] cardinalities;
        /** Frame column index of the actual response. */
        int responseIdx;
        /**
         * Frame column index of the predicted label. The two following columns are read as class
         * probabilities (presumably p0 and p1 of a binomial prediction frame - confirm against the scorer).
         */
        int predictionIdx;
        // Offsets of the individual accumulators inside one group's slice of _results.
        final int tpIdx = 0;
        final int tnIdx = 1;
        final int fpIdx = 2;
        final int fnIdx = 3;
        final int llsIdx = 4;
        /** Number of accumulator slots per group in {@code _results}. */
        final int essentialMetrics = 5;
        /** Number of distinct group keys (product of {@code cardinalities}). */
        final int maxIndex;
        /** Domain index of the favourable class in the response. */
        final int favourableClass;
        /**
         * Flattened per-group accumulators, {@code essentialMetrics} slots per key. Kept as doubles
         * because the log-loss slot holds a fractional sum; an integer array would truncate it on
         * every row.
         */
        double[] _results;
        /** One AUC builder per group key; {@code null} for groups without rows. */
        AUC2.AUCBuilder[] _aucs;

        /**
         * @param protectedColsIdx frame column indices of the protected columns
         * @param cardinalities    number of levels per protected column including the missing-value slot
         * @param responseIdx      frame column index of the response
         * @param predictionIdx    frame column index of the predicted label
         * @param favourableClass  domain index of the favourable response class
         * @throws RuntimeException if the number of level combinations is too large for the
         *         accumulator array to be indexed with an {@code int}
         */
        public FairnessMRTask(int[] protectedColsIdx, int[] cardinalities, int responseIdx, int predictionIdx, int favourableClass) {
            this.protectedColsIdx = protectedColsIdx;
            this.cardinalities = cardinalities;
            this.responseIdx = responseIdx;
            this.predictionIdx = predictionIdx;
            this.favourableClass = favourableClass;
            double maxIndexDbl = Arrays.stream(cardinalities).asDoubleStream().reduce((a, b) -> a * b).getAsDouble();
            // map() allocates maxIndex * essentialMetrics slots, so the product (not just maxIndex)
            // has to fit into an int.
            int maxGroups = Integer.MAX_VALUE / this.essentialMetrics;
            if (maxIndexDbl > maxGroups) {
                throw new RuntimeException("Too many combinations of categories! Maximum number of category combinations is " + maxGroups + "!");
            }
            this.maxIndex = (int)maxIndexDbl;
        }

        /**
         * Computes the group key of one row; a missing value in a protected column maps to that
         * column's last level.
         */
        private int pColsToKey(Chunk[] cs, int row) {
            int[] indices = new int[this.protectedColsIdx.length];
            for (int i = 0; i < this.protectedColsIdx.length; ++i) {
                Chunk column = cs[this.protectedColsIdx[i]];
                indices[i] = column.isNA(row) ? this.cardinalities[i] - 1 : (int)column.at8(row);
            }
            return this.pColsToKey(indices);
        }

        /**
         * Encodes one level index per protected column into a single group key (mixed-radix, first
         * column is the least significant digit).
         *
         * @param indices level index for each protected column
         * @return the group key
         */
        public int pColsToKey(int[] indices) {
            int result = 0;
            int base = 1;
            for (int i = 0; i < this.protectedColsIdx.length; ++i) {
                result += indices[i] * base;
                base *= this.cardinalities[i];
            }
            return result;
        }

        /**
         * Decodes a group key back into one level index per protected column; the missing-value
         * level is returned as {@code NaN}.
         */
        private double[] keyToPCols(int value) {
            double[] result = new double[this.cardinalities.length];
            for (int i = 0; i < this.cardinalities.length; ++i) {
                int level = value % this.cardinalities[i];
                value /= this.cardinalities[i];
                result[i] = level == this.cardinalities[i] - 1 ? Double.NaN : (double)level;
            }
            return result;
        }

        /**
         * Renders a group key as the comma-separated level names of the protected columns
         * ({@code "NaN"} for the missing-value level), with every character other than ASCII
         * letters, digits and commas replaced by an underscore.
         */
        protected String keyToString(int value, Frame fr) {
            double[] pcolIdx = this.keyToPCols(value);
            StringBuilder result = new StringBuilder();
            for (int i = 0; i < this.protectedColsIdx.length; ++i) {
                if (i > 0) {
                    result.append(",");
                }
                if (Double.isFinite(pcolIdx[i])) {
                    result.append(fr.vec(this.protectedColsIdx[i]).domain()[(int)pcolIdx[i]]);
                } else {
                    result.append("NaN");
                }
            }
            return result.toString().replaceAll("[^A-Za-z0-9,]", "_");
        }

        /**
         * Accumulates the confusion-matrix cell, the log loss and the AUC statistics of every row of
         * the chunk into the row's group. Response and predicted label are flipped when the
         * favourable class is not class 1, so that 1 always means "favourable".
         */
        @Override
        public void map(Chunk[] cs) {
            assert (this._results == null);
            this._results = new double[this.maxIndex * this.essentialMetrics];
            this._aucs = new AUC2.AUCBuilder[this.maxIndex];
            boolean classOneIsFavourable = this.favourableClass == 1;
            for (int row = 0; row < cs[0]._len; ++row) {
                int key = this.pColsToKey(cs, row);
                long response = classOneIsFavourable ? cs[this.responseIdx].at8(row) : 1L - cs[this.responseIdx].at8(row);
                long prediction = classOneIsFavourable ? cs[this.predictionIdx].at8(row) : 1L - cs[this.predictionIdx].at8(row);
                double predictionProb = classOneIsFavourable ? cs[this.predictionIdx + 2].atd(row) : cs[this.predictionIdx + 1].atd(row);
                int base = this.essentialMetrics * key;
                int cell;
                if (response == prediction) {
                    cell = response == 1L ? this.tpIdx : this.tnIdx;
                } else {
                    cell = prediction == 1L ? this.fpIdx : this.fnIdx;
                }
                this._results[base + cell] += 1.0;
                // Only the term of the observed class is evaluated: for a 0/1 response this equals
                // -(y*log(p) + (1-y)*log(1-p)) but does not turn 0*log(0) into NaN when a correct
                // prediction has probability exactly 1.
                this._results[base + this.llsIdx] += response == 1L ? -Math.log(predictionProb) : -Math.log(1.0 - predictionProb);
                if (this._aucs[key] == null) {
                    this._aucs[key] = new AUC2.AUCBuilder(400);
                }
                this._aucs[key].perRow(predictionProb, (int)response, 1.0);
            }
        }

        /**
         * Merges the accumulators of another task into this one; a no-op when both share the same
         * result array.
         */
        public void reduce(MRTask mrt) {
            FairnessMRTask other = (FairnessMRTask)mrt;
            if (this._results == other._results) {
                return;
            }
            for (int i = 0; i < this._results.length; ++i) {
                this._results[i] += other._results[i];
            }
            for (int key = 0; key < this.maxIndex; ++key) {
                if (this._aucs[key] == null) {
                    this._aucs[key] = other._aucs[key];
                } else if (other._aucs[key] != null) {
                    this._aucs[key].reduce(other._aucs[key]);
                }
            }
        }

        /**
         * Builds the metrics frame with one row per non-empty group.
         *
         * <p>Columns are, in order: the protected columns, one column per declared field of
         * {@link FairnessMetrics}, an {@code AIR_<field>} column (the group's value divided by the
         * reference group's value) for every field except {@code total} and {@code relativeSize},
         * and {@code p.value}.
         *
         * @param protectedCols names of the protected columns
         * @param fr            frame the task was run on (used for row count and domains)
         * @param model         model whose key becomes part of the frame key
         * @param reference     reference level per protected column, or {@code null} to use the
         *                      group with the most rows
         * @param frName        frame name that becomes part of the frame key
         * @return the metrics frame (not put into DKV by this method)
         */
        public Frame getMetrics(String[] protectedCols, Frame fr, Model model, String[] reference, String frName) {
            FairnessMetrics[] results = new FairnessMetrics[this.maxIndex];
            long nrows = fr.numRows();
            for (int key = 0; key < this.maxIndex; ++key) {
                int base = key * this.essentialMetrics;
                results[key] = new FairnessMetrics(this._results[base + this.tpIdx], this._results[base + this.tnIdx], this._results[base + this.fpIdx], this._results[base + this.fnIdx], this._results[base + this.llsIdx], this._aucs[key], nrows);
            }

            int referenceIdx = 0;
            if (reference != null) {
                int[] indices = new int[protectedCols.length];
                for (int i = 0; i < protectedCols.length; ++i) {
                    indices[i] = ArrayUtils.find(fr.vec(protectedCols[i]).domain(), reference[i]);
                }
                referenceIdx = this.pColsToKey(indices);
            } else {
                double max = 0.0;
                for (int key = 0; key < this.maxIndex; ++key) {
                    if (results[key].total > max) {
                        max = results[key].total;
                        referenceIdx = key;
                    }
                }
            }

            int emptyResults = 0;
            for (FairnessMetrics fm : results) {
                if (fm.total == 0.0) {
                    ++emptyResults;
                }
            }

            String[] skipAIR = new String[]{"total", "relativeSize"};
            Field[] metrics = FairnessMetrics.class.getDeclaredFields();
            int protectedColsCnt = protectedCols.length;
            int metricsCount = metrics.length + (metrics.length - skipAIR.length) + 1;
            double[][] resultCols = new double[protectedColsCnt + metricsCount][results.length - emptyResults];
            int pValueCol = resultCols.length - 1;
            FairnessMetrics ref = results[referenceIdx];
            int row = 0;
            for (int key = 0; key < this.maxIndex; ++key) {
                FairnessMetrics group = results[key];
                if (group.total == 0.0) {
                    continue;
                }
                double[] decodedKey = this.keyToPCols(key);
                for (int i = 0; i < protectedColsCnt; ++i) {
                    resultCols[i][row] = decodedKey[i];
                }
                // AIR columns are packed densely after the plain metric columns, so the number of
                // skipped metrics seen so far is subtracted from the AIR column position.
                int skipped = 0;
                for (int m = 0; m < metrics.length; ++m) {
                    try {
                        double value = metrics[m].getDouble(group);
                        resultCols[protectedColsCnt + m][row] = value;
                        if (ArrayUtils.contains(skipAIR, metrics[m].getName())) {
                            ++skipped;
                        } else {
                            resultCols[protectedColsCnt + metrics.length + m - skipped][row] = value / metrics[m].getDouble(ref);
                        }
                    }
                    catch (IllegalAccessException e) {
                        throw new RuntimeException(e);
                    }
                }
                try {
                    resultCols[pValueCol][row] = FairnessMRTask.getPValue(ref, group);
                }
                catch (Exception ignored) {
                    // Best effort: a failing significance test must not fail the whole report.
                    resultCols[pValueCol][row] = Double.NaN;
                }
                ++row;
            }

            String[] colNames = new String[protectedColsCnt + metricsCount];
            System.arraycopy(protectedCols, 0, colNames, 0, protectedColsCnt);
            int skippedNames = 0;
            for (int m = 0; m < metrics.length; ++m) {
                String metricName = metrics[m].getName();
                colNames[protectedColsCnt + m] = metricName;
                if (ArrayUtils.contains(skipAIR, metricName)) {
                    ++skippedNames;
                } else {
                    colNames[protectedColsCnt + metrics.length + m - skippedNames] = "AIR_" + metricName;
                }
            }
            colNames[colNames.length - 1] = "p.value";

            Vec[] vecs = new Vec[protectedColsCnt + metricsCount];
            for (int i = 0; i < protectedColsCnt; ++i) {
                vecs[i] = Vec.makeVec(resultCols[i], fr.domains()[this.protectedColsIdx[i]], Vec.newKey());
            }
            for (int i = 0; i < metricsCount; ++i) {
                vecs[protectedColsCnt + i] = Vec.makeVec(resultCols[protectedColsCnt + i], Vec.newKey());
            }
            return new Frame(Key.make("fairness_metrics_" + frName + "_for_model_" + model._key), colNames, vecs);
        }

        /**
         * Builds, for every group that has an AUC builder, a frame with one row per threshold bin
         * holding the threshold, every {@link AUC2.ThresholdCriterion} evaluated at that bin and the
         * bin index. Each frame is put into DKV.
         *
         * @param model  model whose key becomes part of the frame keys
         * @param fr     frame used to resolve level names of the protected columns
         * @param frName frame name that becomes part of the frame keys
         * @return a mutable map from {@code "thresholds_and_metrics_<group>"} to the group's frame
         */
        public Map<String, Frame> getROCInfo(Model model, Frame fr, String frName) {
            Map<String, Frame> result = new HashMap<>();
            for (int id = 0; id < this.maxIndex; ++id) {
                if (this._aucs[id] == null) {
                    continue;
                }
                AUC2 auc = new AUC2(this._aucs[id]);
                AUC2.ThresholdCriterion[] crits = AUC2.ThresholdCriterion.VALUES;
                // Layout: column 0 = threshold, 1..crits.length = criteria, last = bin index.
                int idxCol = crits.length + 1;
                String[] colHeaders = new String[crits.length + 2];
                String[] types = new String[crits.length + 2];
                String[] formats = new String[crits.length + 2];
                colHeaders[0] = "Threshold";
                types[0] = "double";
                formats[0] = "%f";
                for (int c = 0; c < crits.length; ++c) {
                    colHeaders[c + 1] = crits[c].toString();
                    types[c + 1] = crits[c]._isInt ? "long" : "double";
                    formats[c + 1] = crits[c]._isInt ? "%d" : "%f";
                }
                colHeaders[idxCol] = "idx";
                types[idxCol] = "int";
                formats[idxCol] = "%d";
                TwoDimTable thresholdsByMetrics = new TwoDimTable("Metrics for Thresholds", "Binomial metrics as a function of classification thresholds", new String[auc._nBins], colHeaders, types, formats, null);
                for (int bin = 0; bin < auc._nBins; ++bin) {
                    thresholdsByMetrics.set(bin, 0, Double.valueOf(auc._ths[bin]));
                    for (int c = 0; c < crits.length; ++c) {
                        double d = crits[c].exec(auc, bin);
                        thresholdsByMetrics.set(bin, 1 + c, crits[c]._isInt ? (Number)Long.valueOf((long)d) : (Number)Double.valueOf(d));
                    }
                    thresholdsByMetrics.set(bin, idxCol, bin);
                }
                String groupName = this.keyToString(id, fr);
                Frame f = thresholdsByMetrics.asFrame(Key.make("thresholds_and_metrics_" + groupName + "_for_model_" + model._key + "_for_frame_" + frName));
                DKV.put(f);
                result.put("thresholds_and_metrics_" + groupName, f);
            }
            return result;
        }

        /**
         * Two-sided Fisher's exact test on the 2x2 table {@code [[a, b], [c, d]]}: sums the
         * hypergeometric probabilities of all tables with the same margins that are at most as likely
         * as the observed one (within {@link #FISHER_TEST_REL_ERROR}).
         *
         * @return the p-value, or {@code NaN} when the population does not fit into an {@code int}
         */
        private static double fishersTest(long a, long b, long c, long d) {
            long popSize = a + b + c + d;
            if (popSize > Integer.MAX_VALUE) {
                return Double.NaN;
            }
            HypergeometricDistribution hgd = new HypergeometricDistribution((int)popSize, (int)(a + b), (int)(a + c));
            double observed = hgd.probability((int)a);
            double cutoff = observed * FISHER_TEST_REL_ERROR;
            long upper = Math.min(a + b, a + c);
            double pValue = 0.0;
            for (int i = (int)Math.max(a - d, 0L); (long)i <= upper; ++i) {
                double proposal = hgd.probability(i);
                if (proposal <= cutoff) {
                    pValue += proposal;
                }
            }
            return pValue;
        }

        /**
         * Tests whether the selected / not-selected split of a group differs from that of the
         * reference group. Fisher's exact test is used when both groups are smaller than
         * {@link #GTEST_THRESHOLD} or any cell of the table is zero; otherwise a G-test is used.
         */
        private static double getPValue(FairnessMetrics ref, FairnessMetrics results) {
            long a = (long)results.selected;
            long b = (long)ref.selected;
            long c = (long)(results.total - results.selected);
            long d = (long)(ref.total - ref.selected);
            boolean bothSmall = ref.total < GTEST_THRESHOLD && results.total < GTEST_THRESHOLD;
            boolean anyEmptyCell = a == 0L || b == 0L || c == 0L || d == 0L;
            if (bothSmall || anyEmptyCell) {
                return FairnessMRTask.fishersTest(a, b, c, d);
            }
            return new GTest().gTestDataSetsComparison(new long[]{a, c}, new long[]{b, d});
        }
    }

    /**
     * Metrics of a single group derived from its confusion-matrix counts, log-loss sum and AUC
     * builder.
     *
     * <p>{@code FairnessMRTask.getMetrics} enumerates the declared fields of this class via
     * reflection to build its output columns, so every field declared here (and its name) ends up
     * in the resulting frame.
     */
    public static class FairnessMetrics {
        double tp;
        double fp;
        double tn;
        double fn;
        double total;
        double relativeSize;
        double accuracy;
        double precision;
        double f1;
        double tpr;
        double tnr;
        double fpr;
        double fnr;
        double auc;
        double aucpr;
        double gini;
        double selected;
        double selectedRatio;
        double logloss;

        /**
         * @param tp         true positives
         * @param tn         true negatives
         * @param fp         false positives
         * @param fn         false negatives
         * @param logLossSum sum of the per-row log loss of the group
         * @param aucBuilder AUC builder of the group, or {@code null} (AUC metrics become {@code NaN})
         * @param nrows      number of rows the relative group size is measured against
         */
        public FairnessMetrics(double tp, double tn, double fp, double fn, double logLossSum, AUC2.AUCBuilder aucBuilder, double nrows) {
            double groupSize = tp + fp + tn + fn;
            double predictedPositive = tp + fp;
            double actualPositive = tp + fn;
            double actualNegative = tn + fp;

            // Raw counts.
            this.tp = tp;
            this.tn = tn;
            this.fp = fp;
            this.fn = fn;
            this.total = groupSize;

            // Ratios; an empty denominator yields NaN or infinity rather than an exception.
            this.logloss = logLossSum / groupSize;
            this.relativeSize = groupSize / nrows;
            this.accuracy = (tp + tn) / groupSize;
            this.precision = tp / predictedPositive;
            this.f1 = 2.0 * tp / (2.0 * tp + fp + fn);
            this.tpr = tp / actualPositive;
            this.tnr = tn / actualNegative;
            this.fpr = fp / actualNegative;
            this.fnr = fn / actualPositive;

            if (aucBuilder == null) {
                this.auc = Double.NaN;
                this.aucpr = Double.NaN;
                this.gini = Double.NaN;
            } else {
                AUC2 computed = new AUC2(aucBuilder);
                this.auc = computed._auc;
                this.aucpr = computed._pr_auc;
                this.gini = computed._gini;
            }

            this.selected = predictedPositive;
            this.selectedRatio = predictedPositive / groupSize;
        }
    }
}

