/*
 * Decompiled with CFR 0.152.
 */
package water.util;

import hex.Interaction;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.AtomicUtils;
import water.util.Log;
import water.util.TwoDimTable;

public class Tabulate
extends Keyed<Tabulate> {
    public final Job<Tabulate> _job;
    public Frame _dataset;
    public Key[] _vecs = new Key[2];
    public String _predictor;
    public String _response;
    public String _weight;
    int _nbins_predictor = 20;
    int _nbins_response = 10;
    double[][] _count_data;
    double[][] _response_data;
    public TwoDimTable _count_table;
    public TwoDimTable _response_table;
    private final Stats[] _stats = new Stats[2];

    public Tabulate() {
        this._job = new Job(Key.make(), Tabulate.class.getName(), "Tabulate job");
    }

    private int bins(int v) {
        return v == 1 ? this._nbins_response : this._nbins_predictor;
    }

    private int res(int v) {
        int missing = this._stats[v]._missing;
        if (this._stats[v]._isCategorical) {
            return this._stats[v]._cardinality + missing;
        }
        return this.bins(v) + missing;
    }

    private int bin(int v, double val) {
        int b;
        if (Double.isNaN(val)) {
            return 0;
        }
        int bins = this.bins(v);
        if (this._stats[v]._isCategorical) {
            assert ((double)((int)val) == val);
            b = (int)val;
        } else {
            double d = (this._stats[v]._max - this._stats[v]._min) / (double)bins;
            b = (int)((val - this._stats[v]._min) / d);
            assert (b >= 0 && b <= bins);
            b = Math.min(b, bins - 1);
        }
        return b + this._stats[v]._missing;
    }

    private String labelForBin(int v, int b) {
        int missing = this._stats[v]._missing;
        if (missing == 1 && b == 0) {
            return "missing(NA)";
        }
        if (missing == 1) {
            --b;
        }
        if (this._stats[v]._isCategorical) {
            return this._stats[v]._domain[b];
        }
        int bins = this.bins(v);
        if (this._stats[v]._isInt && this._stats[v]._max - this._stats[v]._min + 1.0 <= (double)bins) {
            return Integer.toString((int)(this._stats[v]._min + (double)b));
        }
        double d = (this._stats[v]._max - this._stats[v]._min) / (double)bins;
        return String.format("%5f", this._stats[v]._min + ((double)b + 0.5) * d);
    }

    public Tabulate execImpl() {
        Vec w;
        if (this._dataset == null) {
            throw new H2OIllegalArgumentException("Dataset not found");
        }
        if (this._nbins_predictor < 1) {
            throw new H2OIllegalArgumentException("Number of bins for predictor must be >= 1");
        }
        if (this._nbins_response < 1) {
            throw new H2OIllegalArgumentException("Number of bins for response must be >= 1");
        }
        Vec x = this._dataset.vec(this._predictor);
        if (x == null) {
            throw new H2OIllegalArgumentException("Predictor column " + this._predictor + " not found");
        }
        if (x.cardinality() > this._nbins_predictor) {
            Interaction in = new Interaction();
            in._source_frame = this._dataset._key;
            in._factor_columns = new String[]{this._predictor};
            in._max_factors = this._nbins_predictor - 1;
            in.execImpl(null);
            x = ((Frame)in._job._result.get()).anyVec();
        } else if (x.isInt() && x.max() - x.min() + 1.0 <= (double)this._nbins_predictor) {
            x = x.toCategoricalVec();
        }
        Vec y = this._dataset.vec(this._response);
        if (y == null) {
            throw new H2OIllegalArgumentException("Response column " + this._response + " not found");
        }
        if (y.cardinality() > this._nbins_response) {
            Interaction in = new Interaction();
            in._source_frame = this._dataset._key;
            in._factor_columns = new String[]{this._response};
            in._max_factors = this._nbins_response - 1;
            in.execImpl(null);
            y = ((Frame)in._job._result.get()).anyVec();
        } else if (y.isInt() && y.max() - y.min() + 1.0 <= (double)this._nbins_response) {
            y = y.toCategoricalVec();
        }
        if (y != null && y.cardinality() > 2) {
            Log.warn("Response column has more than two factor levels - mean response depends on lexicographic order of factors!");
        }
        if ((w = this._dataset.vec(this._weight)) != null && !w.isNumeric() && w.min() < 0.0) {
            throw new H2OIllegalArgumentException("Observation weights must be numeric with values >= 0");
        }
        if (x != null) {
            this._vecs[0] = x._key;
            this._stats[0] = new Stats(x);
        }
        if (y != null) {
            this._vecs[1] = y._key;
            this._stats[1] = new Stats(y);
        }
        Tabulate sp = w != null ? ((CoOccurrence)new CoOccurrence((Tabulate)this).doAll((Vec[])new Vec[]{x, y, w}))._sp : ((CoOccurrence)new CoOccurrence((Tabulate)this).doAll((Vec[])new Vec[]{x, y}))._sp;
        this._count_table = sp.tabulationTwoDimTable();
        this._response_table = sp.responseCharTwoDimTable();
        Log.info(this._count_table.toString(2, false));
        Log.info(this._response_table.toString(2, false));
        return sp;
    }

    public TwoDimTable tabulationTwoDimTable() {
        if (this._response_data == null) {
            return null;
        }
        int predN = this._count_data.length;
        int respN = this._count_data[0].length;
        String tableHeader = "(Weighted) co-occurrence counts of '" + this._predictor + "' and '" + this._response + "'";
        String[] rowHeaders = new String[predN * respN];
        String[] colHeaders = new String[3];
        String[] colTypes = new String[colHeaders.length];
        String[] colFormats = new String[colHeaders.length];
        colHeaders[0] = this._predictor;
        colHeaders[1] = this._response;
        colTypes[0] = "string";
        colFormats[0] = "%s";
        colTypes[1] = "string";
        colFormats[1] = "%s";
        colHeaders[2] = "counts";
        colTypes[2] = "double";
        colFormats[2] = "%f";
        TwoDimTable table = new TwoDimTable(tableHeader, null, rowHeaders, colHeaders, colTypes, colFormats, null);
        for (int p = 0; p < predN; ++p) {
            String plabel = this.labelForBin(0, p);
            for (int r = 0; r < respN; ++r) {
                String rlabel = this.labelForBin(1, r);
                for (int c = 0; c < 3; ++c) {
                    table.set(r * predN + p, 0, plabel);
                    table.set(r * predN + p, 1, rlabel);
                    table.set(r * predN + p, 2, this._count_data[p][r]);
                }
            }
        }
        return table;
    }

    public TwoDimTable responseCharTwoDimTable() {
        if (this._response_data == null) {
            return null;
        }
        String tableHeader = "Mean value of '" + this._response + "' and (weighted) counts for '" + this._predictor + "' values";
        int predN = this._count_data.length;
        String[] rowHeaders = new String[predN];
        String[] colHeaders = new String[3];
        String[] colTypes = new String[colHeaders.length];
        String[] colFormats = new String[colHeaders.length];
        colHeaders[0] = this._predictor;
        colTypes[0] = "string";
        colFormats[0] = "%s";
        colHeaders[1] = "mean " + this._response;
        colTypes[2] = "double";
        colFormats[2] = "%f";
        colHeaders[2] = "counts";
        colTypes[1] = "double";
        colFormats[1] = "%f";
        TwoDimTable table = new TwoDimTable(tableHeader, null, rowHeaders, colHeaders, colTypes, colFormats, null);
        for (int p = 0; p < predN; ++p) {
            String plabel = this.labelForBin(0, p);
            table.set(p, 0, plabel);
            table.set(p, 1, this._response_data[p][0]);
            table.set(p, 2, this._response_data[p][1]);
        }
        return table;
    }

    private static class CoOccurrence
    extends MRTask<CoOccurrence> {
        final Tabulate _sp;

        CoOccurrence(Tabulate sp) {
            this._sp = sp;
        }

        @Override
        protected void setupLocal() {
            this._sp._count_data = new double[this._sp.res(0)][this._sp.res(1)];
            this._sp._response_data = new double[this._sp.res(0)][2];
        }

        @Override
        public void map(Chunk x, Chunk y) {
            this.map(x, y, null);
        }

        @Override
        public void map(Chunk x, Chunk y, Chunk w) {
            for (int r = 0; r < x.len(); ++r) {
                double weight;
                int xbin = this._sp.bin(0, x.atd(r));
                int ybin = this._sp.bin(1, y.atd(r));
                double d = weight = w != null ? w.atd(r) : 1.0;
                if (Double.isNaN(weight)) continue;
                AtomicUtils.DoubleArray.add(this._sp._count_data[xbin], ybin, weight);
                if (y.isNA(r)) continue;
                AtomicUtils.DoubleArray.add(this._sp._response_data[xbin], 0, weight * y.atd(r));
                AtomicUtils.DoubleArray.add(this._sp._response_data[xbin], 1, weight);
            }
        }

        @Override
        public void reduce(CoOccurrence mrt) {
            if (this._sp._response_data == mrt._sp._response_data) {
                return;
            }
            ArrayUtils.add(this._sp._response_data, mrt._sp._response_data);
        }

        @Override
        protected void postGlobal() {
            for (int i = 0; i < this._sp._response_data.length; ++i) {
                double[] dArray = this._sp._response_data[i];
                dArray[0] = dArray[0] / this._sp._response_data[i][1];
            }
        }
    }

    private static class Stats
    extends Iced {
        final double _min;
        final double _max;
        final boolean _isCategorical;
        final boolean _isInt;
        final int _cardinality;
        final int _missing;
        final String[] _domain;

        Stats(Vec v) {
            this._min = v.min();
            this._max = v.max();
            this._isCategorical = v.isCategorical();
            this._isInt = v.isInt();
            this._cardinality = v.cardinality();
            this._missing = v.naCnt() > 0L ? 1 : 0;
            this._domain = v.domain();
        }
    }
}

