/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.Model;
import java.util.ArrayList;
import java.util.Arrays;
import jsr166y.CountedCompleter;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.util.FrameUtils;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Computes partial dependence tables for a supervised model on a frame.
 *
 * <p>For every requested column (1D) or column pair (2D), and for every selected
 * prediction column, the model is re-scored on a copy of the frame in which the
 * column(s) are replaced by a constant grid value. The mean, standard deviation and
 * standard error of the mean of that prediction column are stored per grid point in
 * {@link #_partial_dependence_data}, one {@link TwoDimTable} per
 * (column-or-pair, target) combination: first all 1D tables, then all 2D tables.
 */
public class PartialDependence
extends Lockable<PartialDependence> {
    /** Job used to track progress and cancellation of this computation. */
    public final transient Job _job;
    public Key<Model> _model_id;
    public Key<Frame> _frame_id;
    /** When >= 0, statistics are computed on this single row of the frame only. */
    public long _row_index = -1L;
    /** Columns for 1D partial dependence; selected automatically when both this and {@link #_col_pairs_2dpdp} are null. */
    public String[] _cols;
    /** Union of the 1D columns and the columns appearing in 2D pairs (filled during validation). */
    public ArrayList<String> _cols_1d_2d;
    /** Index of a numeric weight column in the frame, or -1 for unweighted statistics. */
    public int _weight_column_index = -1;
    /** When true, every grid gets one extra NaN point (shown as ".missing(NA)" for categoricals). */
    public boolean _add_missing_na = false;
    /** Number of grid points per numeric column; categorical cardinality may not exceed it. */
    public int _nbins = 20;
    /** Class names to report; required for multinomial models and rejected otherwise. */
    public String[] _targets;
    public TwoDimTable[] _partial_dependence_data;
    /** Flat list of user-specified grid values, partitioned per {@link #_user_cols} by {@link #_num_user_splits}. */
    public double[] _user_splits = null;
    public double[][] _user_split_per_col = null;
    public int[] _num_user_splits = null;
    public String[] _user_cols = null;
    public boolean _user_splits_present = false;
    /** Column-name pairs for 2D partial dependence. */
    public String[][] _col_pairs_2dpdp = null;
    /** Number of 2D tables: number of pairs times number of prediction columns. */
    public int _num_2D_pairs = 0;
    /** Number of 1D tables: number of columns times number of prediction columns. */
    public int _num_1D = 0;
    public int _predictor_column;
    /** Prediction-frame column indices whose values are averaged (one per target for multinomial models). */
    public int[] _predictor_columns;

    public PartialDependence(Key<PartialDependence> dest, Job j) {
        super(dest);
        this._job = j;
    }

    public PartialDependence(Key<PartialDependence> dest) {
        this(dest, new Job<PartialDependence>(dest, PartialDependence.class.getName(), "PartialDependence"));
    }

    /**
     * Validates the parameters and runs the computation synchronously on the calling thread.
     * Locks this object and write-locks the frame; the driver's completion handler releases both.
     *
     * @return this instance, with {@link #_partial_dependence_data} filled
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public PartialDependence execNested() {
        this.checkSanityAndFillParams();
        this.delete_and_lock(this._job);
        this._frame_id.get().write_lock(this._job._key);
        new PartialDependenceDriver().compute2();
        return this;
    }

    /**
     * Validates the parameters and starts the computation as a background job with one
     * unit of work per output table.
     *
     * @return the started job
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public Job<PartialDependence> execImpl() {
        this.checkSanityAndFillParams();
        this.delete_and_lock(this._job);
        this._frame_id.get().write_lock(this._job._key);
        this._job.start(new PartialDependenceDriver(), this._num_1D + this._num_2D_pairs);
        return this._job;
    }

    /**
     * Maps a class name to its column in the prediction frame: the class's position in
     * the model's class names plus one (presumably because column 0 holds the predicted
     * label — assumption, not visible here).
     *
     * @throws IllegalArgumentException if the model has no such class
     */
    private int findTargetClassPredictorIndex(Model m, String target) {
        int index = Arrays.asList(((Model.Output) m._output).classNames()).indexOf(target);
        if (index == -1) {
            throw new IllegalArgumentException("Incorrect target class: " + target + ".");
        }
        return index + 1;
    }

    /** Applies {@link #findTargetClassPredictorIndex} to every requested target, preserving order. */
    private int[] findTargetClassPredictorIndices(Model m, String[] targets) {
        int[] result = new int[targets.length];
        for (int i = 0; i < targets.length; ++i) {
            result[i] = this.findTargetClassPredictorIndex(m, targets[i]);
        }
        return result;
    }

    /**
     * Validates all parameters and fills the derived fields (prediction columns, column
     * lists, table counts, per-column user splits).
     *
     * @throws IllegalArgumentException on any invalid or inconsistent parameter
     */
    private void checkSanityAndFillParams() {
        Model m = this._model_id == null ? null : this._model_id.get();
        if (m == null) {
            throw new IllegalArgumentException("Model not found.");
        }
        if (!((Model.Output) m._output).isSupervised()) {
            throw new IllegalArgumentException("Partial dependence plots are only implemented for supervised models");
        }
        this.fillPredictorColumns(m);
        // Checked up front: every later validation step dereferences the frame.
        Frame fr = this._frame_id == null ? null : this._frame_id.get();
        if (fr == null) {
            throw new IllegalArgumentException("Frame not found.");
        }
        this.fillColumns(m);
        this._num_1D = this._cols == null ? 0 : this._cols.length * this._predictor_columns.length;
        if (this._nbins < 2) {
            throw new IllegalArgumentException("_nbins must be >=2.");
        }
        this.fillUserSplits();
        this.validateFrameColumns(fr);
    }

    /** Chooses the prediction-frame column(s) to average based on the model's number of classes. */
    private void fillPredictorColumns(Model m) {
        int nclasses = ((Model.Output) m._output).nclasses();
        if (nclasses <= 2 && this._targets != null) {
            throw new IllegalArgumentException("Targets parameter is available only for multinomial classification.");
        }
        if (nclasses == 1) {
            this._predictor_column = 0;
            this._predictor_columns = new int[]{this._predictor_column};
        } else if (nclasses == 2) {
            // Third prediction column; presumably the positive-class probability (assumption).
            this._predictor_column = 2;
            this._predictor_columns = new int[]{this._predictor_column};
        } else {
            if (this._targets == null) {
                throw new IllegalArgumentException("Targets parameter has to be set for multinomial classification.");
            }
            this._predictor_columns = this.findTargetClassPredictorIndices(m, this._targets);
        }
    }

    /** Builds {@link #_cols_1d_2d} from the user's columns/pairs, or picks default columns from the model. */
    private void fillColumns(Model m) {
        if (this._cols != null || this._col_pairs_2dpdp != null) {
            this._cols_1d_2d = new ArrayList<>();
            if (this._cols != null) {
                this._cols_1d_2d.addAll(Arrays.asList(this._cols));
            }
            if (this._col_pairs_2dpdp != null) {
                this._num_2D_pairs = this._col_pairs_2dpdp.length * this._predictor_columns.length;
                // Iterate the pairs themselves: _num_2D_pairs is scaled by the number of
                // targets and would run past the end of the array.
                for (String[] pair : this._col_pairs_2dpdp) {
                    if (!this._cols_1d_2d.contains(pair[0])) {
                        this._cols_1d_2d.add(pair[0]);
                    }
                    if (!this._cols_1d_2d.contains(pair[1])) {
                        this._cols_1d_2d.add(pair[1]);
                    }
                }
            }
            return;
        }
        if (m instanceof Model.GetMostImportantFeatures) {
            this._cols = ((Model.GetMostImportantFeatures) m).getMostImportantFeatures(10);
            if (this._cols != null) {
                Log.info("Selecting the top " + this._cols.length + " features from the model's variable importances.");
            }
        } else {
            this._cols = ((Model.Output) m._output)._names;
            if (this._cols != null) {
                Log.info("Selecting all features from the training data.");
            }
        }
        this._cols_1d_2d = new ArrayList<>();
        if (this._cols != null) {
            this._cols_1d_2d.addAll(Arrays.asList(this._cols));
        }
    }

    /** Splits the flat {@link #_user_splits} array into one grid per entry of {@link #_user_cols}. */
    private void fillUserSplits() {
        if (this._user_splits == null || this._user_splits.length == 0) {
            return;
        }
        if (this._user_cols == null || this._num_user_splits == null || this._num_user_splits.length < this._user_cols.length) {
            throw new IllegalArgumentException("User splits require _user_cols and _num_user_splits with one entry per user column.");
        }
        this._user_splits_present = true;
        int numUserSplits = this._user_cols.length;
        this._user_split_per_col = new double[numUserSplits][];
        int start = 0;
        for (int cindex = 0; cindex < numUserSplits; ++cindex) {
            int splitNum = this._num_user_splits[cindex];
            if (start + splitNum > this._user_splits.length) {
                throw new IllegalArgumentException("_num_user_splits requests more values than _user_splits provides.");
            }
            this._user_split_per_col[cindex] = Arrays.copyOfRange(this._user_splits, start, start + splitNum);
            start += splitNum;
        }
    }

    /** Checks the weight column and every column in {@link #_cols_1d_2d} against the frame. */
    private void validateFrameColumns(Frame fr) {
        if (this._weight_column_index >= 0) {
            if (this._weight_column_index >= fr.numCols()) {
                throw new IllegalArgumentException("Weight column " + this._weight_column_index + " is out of range for a frame with " + fr.numCols() + " columns.");
            }
            Vec weights = fr.vec(this._weight_column_index);
            if (!weights.isNumeric() || weights.isCategorical()) {
                throw new IllegalArgumentException("Weight column " + this._weight_column_index + " must be a numerical column.");
            }
        }
        for (String col : this._cols_1d_2d) {
            Vec v = fr.vec(col);
            if (v == null) {
                throw new IllegalArgumentException("Column " + col + " not found in frame " + this._frame_id + ".");
            }
            if (v.isCategorical() && v.cardinality() > this._nbins) {
                throw new IllegalArgumentException("Column " + col + "'s cardinality of " + v.cardinality() + " > nbins of " + this._nbins);
            }
        }
    }

    /**
     * Returns the grid of values at which partial dependence of {@code col} is evaluated.
     *
     * <p>If the column has user splits, those values are used. Otherwise the grid is
     * {@code actualbins} evenly spaced points from {@code v.min()} to {@code v.max()}.
     * For integer columns spanning fewer than {@code _nbins} values, the grid shrinks
     * to one point per integer. When {@link #_add_missing_na} is set, a trailing NaN is
     * appended.
     *
     * @param col        column name
     * @param actualbins requested grid size when no user splits apply
     * @param v          the column's vec, used for min/max/type
     */
    double[] extractColValues(String col, int actualbins, Vec v) {
        int extra = this._add_missing_na ? 1 : 0;
        int userColIndex = this._user_splits_present ? Arrays.asList(this._user_cols).indexOf(col) : -1;
        double[] colVals;
        if (userColIndex >= 0) {
            actualbins = this._num_user_splits[userColIndex];
            // copyOf pads with 0.0; the padding slot (if any) is overwritten with NaN below.
            colVals = Arrays.copyOf(this._user_split_per_col[userColIndex], actualbins + extra);
        } else {
            double min = v.min();
            double max = v.max();
            if (v.isInt() && max - min + 1.0 < (double) this._nbins) {
                actualbins = (int) (max - min + 1.0);
            }
            colVals = new double[actualbins + extra];
            double delta = actualbins == 1 ? 0.0 : (max - min) / (double) (actualbins - 1);
            for (int j = 0; j < colVals.length; ++j) {
                colVals[j] = min + (double) j * delta;
            }
        }
        if (this._add_missing_na) {
            colVals[actualbins] = Double.NaN;
        }
        Log.debug("Computing PartialDependence for column " + col + " at the following values: ");
        Log.debug(Arrays.toString(colVals));
        return colVals;
    }

    /**
     * Creates the (still empty) result table for one 1D column ({@code col2 == null}) or one 2D pair.
     *
     * @param target class name shown in the 1D title, or null when no targets were given
     * @param rows   number of grid points (table rows)
     */
    private TwoDimTable makePdpTable(String col, String col2, boolean cat, boolean cat2, String target, int rows) {
        if (col2 == null) {
            String classSuffix = target == null ? "" : " and class " + target;
            String description = "Partial Dependence Plot of model " + this._model_id + " on column '" + col + "'" + classSuffix
                    + (this._row_index < 0L ? (target == null ? "." : "") : " for row index " + this._row_index);
            return new TwoDimTable("PartialDependence", description, new String[rows],
                    new String[]{col, "mean_response", "stddev_response", "std_error_mean_response"},
                    new String[]{cat ? "string" : "double", "double", "double", "double"},
                    new String[]{cat ? "%s" : "%5f", "%5f", "%5f", "%5f"}, null);
        }
        String description = this._row_index < 0L
                ? "2D Partial Dependence Plot of model " + this._model_id + " on 1st column '" + col + "' and 2nd column '" + col2 + "'"
                : "Partial Dependence Plot of model " + this._model_id + " on columns '" + col + "', '" + col2 + "' for row " + this._row_index;
        return new TwoDimTable("2D-PartialDependence", description, new String[rows],
                new String[]{col, col2, "mean_response", "stddev_response", "std_error_mean_response"},
                new String[]{cat ? "string" : "double", cat2 ? "string" : "double", "double", "double", "double"},
                new String[]{cat ? "%s" : "%5f", cat2 ? "%s" : "%5f", "%5f", "%5f", "%5f"}, null);
    }

    /**
     * Writes one grid value into a table cell. Categorical values are written as their
     * domain label, or ".missing(NA)" for the NaN point. Numeric values are written as-is.
     */
    private void setLevelCell(TwoDimTable table, int row, int col, Vec v, boolean cat, double value) {
        if (!cat) {
            table.set(row, col, value);
        } else if (this._add_missing_na && Double.isNaN(value)) {
            table.set(row, col, ".missing(NA)");
        } else {
            table.set(row, col, v.domain()[(int) value]);
        }
    }

    @Override
    public Class<KeyV3.PartialDependenceKeyV3> makeSchema() {
        return KeyV3.PartialDependenceKeyV3.class;
    }

    /** Top-level task: builds every output table in order and releases locks on completion. */
    private class PartialDependenceDriver
    extends H2O.H2OCountedCompleter<PartialDependenceDriver> {
        private PartialDependenceDriver() {
        }

        /**
         * Produces {@code _num_1D + _num_2D_pairs} tables. Table {@code i} uses prediction
         * column {@code i % numTargets}. Its column (1D) or pair (2D) index is {@code i / numTargets},
         * counted from the start of its phase. Grid points of one table are scored in
         * parallel. The loop stops early if the job is asked to stop.
         */
        @Override
        public void compute2() {
            final PartialDependence pd = PartialDependence.this;
            assert (pd._job != null);
            Frame fr = pd._frame_id.get();
            int numTargets = pd._predictor_columns.length;
            int numTables = pd._num_1D + pd._num_2D_pairs;
            pd._partial_dependence_data = new TwoDimTable[numTables];
            for (int i = 0; i < numTables; ++i) {
                boolean workingOn1D = i < pd._num_1D;
                int whichPredictorColumn = i % numTargets;
                String col;
                String col2;
                if (workingOn1D) {
                    col = pd._cols[i / numTargets];
                    col2 = null;
                } else {
                    String[] pair = pd._col_pairs_2dpdp[(i - pd._num_1D) / numTargets];
                    col = pair[0];
                    col2 = pair[1];
                }
                String target = pd._targets == null ? null : pd._targets[whichPredictorColumn];
                Log.debug("Computing partial dependence of model on '" + col + "'" + (pd._targets == null ? "." : " and class " + target + "."));
                Vec v1 = fr.vec(col);
                Vec v2 = workingOn1D ? null : fr.vec(col2);
                double[] colVals = pd.extractColValues(col, pd._nbins, v1);
                double[] col2Vals = workingOn1D ? null : pd.extractColValues(col2, pd._nbins, v2);
                boolean cat = v1.isCategorical();
                boolean cat2 = !workingOn1D && v2.isCategorical();
                // Flattened grid: point k is (colVals[k / len2], col2Vals[k % len2]); len2 == 1 for 1D.
                int len2 = workingOn1D ? 1 : col2Vals.length;
                int responseLength = colVals.length * len2;
                double[] meanResponse = new double[responseLength];
                double[] stddevResponse = new double[responseLength];
                double[] stdErrorOfTheMeanResponse = new double[responseLength];
                int predictorColumn = pd._predictor_columns[whichPredictorColumn];
                Futures fs = new Futures();
                for (int k = 0; k < responseLength; ++k) {
                    double value = colVals[k / len2];
                    double value2 = workingOn1D ? -1.0 : col2Vals[k % len2];
                    fs.add(H2O.submitTask(new CalculatePdpPerBin(col, col2, value, value2, cat, cat2, k, !workingOn1D,
                            meanResponse, stddevResponse, stdErrorOfTheMeanResponse, predictorColumn)));
                }
                fs.blockForPending();
                TwoDimTable table = pd.makePdpTable(col, col2, cat, cat2, target, responseLength);
                pd._partial_dependence_data[i] = table;
                for (int j = 0; j < responseLength; ++j) {
                    int c = 0;
                    pd.setLevelCell(table, j, c++, v1, cat, colVals[j / len2]);
                    if (!workingOn1D) {
                        pd.setLevelCell(table, j, c++, v2, cat2, col2Vals[j % len2]);
                    }
                    table.set(j, c++, meanResponse[j]);
                    table.set(j, c++, stddevResponse[j]);
                    table.set(j, c, stdErrorOfTheMeanResponse[j]);
                }
                pd._job.update(1L);
                pd.update(pd._job);
                if (pd._job.stop_requested()) {
                    break;
                }
            }
            this.tryComplete();
        }

        /** Computes the weighted mean/sigma of prediction column {@code targetIndex} using the frame's weight column. */
        public FrameUtils.CalculateWeightMeanSTD getWeightedStat(Frame dataFrame, Frame pred, int targetIndex) {
            FrameUtils.CalculateWeightMeanSTD calMeansSTD = new FrameUtils.CalculateWeightMeanSTD();
            calMeansSTD.doAll(pred.vec(targetIndex), dataFrame.vec(PartialDependence.this._weight_column_index));
            return calMeansSTD;
        }

        /** Releases the frame write lock and this object's lock taken in execNested/execImpl. */
        private void releaseLocks() {
            PartialDependence.this._frame_id.get().unlock(PartialDependence.this._job._key);
            PartialDependence.this.unlock(PartialDependence.this._job);
        }

        @Override
        public void onCompletion(CountedCompleter caller) {
            this.releaseLocks();
        }

        @Override
        public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
            this.releaseLocks();
            return true;
        }

        /**
         * Scores the model with {@code _col} (and {@code _col2} for 2D) fixed to constant
         * values. Writes mean, standard deviation and standard error of the selected
         * prediction column into slot {@code _pdp_row_index} of the shared result arrays.
         * Each task writes only its own slot.
         */
        private class CalculatePdpPerBin
        extends H2O.H2OCountedCompleter<CalculatePdpPerBin> {
            final String _col;
            final String _col2;
            final double _value;
            final double _value2;
            final boolean _workOn2D;
            final int _pdp_row_index;
            final boolean _col1_cat;
            final boolean _col2_cat;
            final double[] _meanResponse;
            final double[] _stddevResponse;
            final double[] _stdErrorOfTheMeanResponse;
            final int _predictorColumn;

            CalculatePdpPerBin(String col, String col2, double value, double value2, boolean cat, boolean cat2, int which, boolean workon2D, double[] meanResp, double[] stddevResp, double[] stdErrMeanResp, int predictorColumn) {
                this._col = col;
                this._col2 = col2;
                this._value = value;
                this._value2 = value2;
                this._workOn2D = workon2D;
                this._pdp_row_index = which;
                this._col1_cat = cat;
                this._col2_cat = cat2;
                this._meanResponse = meanResp;
                this._stddevResponse = stddevResp;
                this._stdErrorOfTheMeanResponse = stdErrMeanResp;
                this._predictorColumn = predictorColumn;
            }

            @Override
            public void compute2() {
                final PartialDependence pd = PartialDependence.this;
                boolean singleRow = pd._row_index >= 0L;
                Frame fr = singleRow
                        ? Rapids.exec("(rows " + pd._frame_id + "  " + pd._row_index + ")").getFrame()
                        : pd._frame_id.get();
                Vec cons = null;
                Vec cons2 = null;
                Frame preds = null;
                // All temporaries are released in finally so a scoring failure does not leak them.
                try {
                    Frame test = new Frame(fr.names(), fr.vecs());
                    Vec orig = test.remove(this._col);
                    cons = orig.makeCon(this._value);
                    if (this._col1_cat) {
                        cons.setDomain(fr.vec(this._col).domain());
                    }
                    test.add(this._col, cons);
                    if (this._workOn2D) {
                        Vec orig2 = test.remove(this._col2);
                        cons2 = orig2.makeCon(this._value2);
                        if (this._col2_cat) {
                            cons2.setDomain(fr.vec(this._col2).domain());
                        }
                        test.add(this._col2, cons2);
                    }
                    preds = pd._model_id.get().score(test, Key.make().toString(), pd._job, false);
                    int slot = this._pdp_row_index;
                    if (preds == null || preds.numRows() == 0L) {
                        this._meanResponse[slot] = Double.NaN;
                        this._stddevResponse[slot] = Double.NaN;
                        this._stdErrorOfTheMeanResponse[slot] = Double.NaN;
                    } else {
                        double mean;
                        double sigma;
                        if (pd._weight_column_index >= 0) {
                            FrameUtils.CalculateWeightMeanSTD stats = PartialDependenceDriver.this.getWeightedStat(fr, preds, this._predictorColumn);
                            mean = stats.getWeightedMean();
                            sigma = stats.getWeightedSigma();
                        } else {
                            Vec predicted = preds.vec(this._predictorColumn);
                            mean = predicted.mean();
                            sigma = predicted.sigma();
                        }
                        this._meanResponse[slot] = mean;
                        this._stddevResponse[slot] = sigma;
                        this._stdErrorOfTheMeanResponse[slot] = sigma / Math.sqrt(preds.numRows());
                    }
                } finally {
                    if (preds != null) {
                        preds.remove();
                    }
                    if (cons != null) {
                        cons.remove();
                    }
                    if (cons2 != null) {
                        cons2.remove();
                    }
                    if (singleRow) {
                        // The single-row frame is a temporary created above.
                        fr.remove();
                    }
                }
                this.tryComplete();
            }
        }
    }
}

