/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.Model;
import java.util.Arrays;
import jsr166y.CountedCompleter;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.FrameUtils;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Partial dependence of a model's prediction on individual columns of a frame.
 *
 * <p>For every column in {@link #_cols}, the frame is scored once per grid value with that column
 * replaced by a constant vector holding the value. The mean, standard deviation and standard error
 * of the mean of one prediction column are recorded per grid value and stored as a
 * {@link TwoDimTable} in {@link #_partial_dependence_data}.
 */
public class PartialDependence extends Lockable<PartialDependence> {
    /** Job tracking this computation; one unit of work is reported per processed column. */
    public final transient Job _job;
    /** Model whose predictions are evaluated; must be regression or binomial classification. */
    public Key<Model> _model_id;
    /** Frame whose rows are scored. */
    public Key<Frame> _frame_id;
    /** Columns to analyse; when null they are chosen from the model's variable importances (top 10). */
    public String[] _cols;
    /** Index of the weight column in the frame, or a negative value for unweighted statistics. */
    public int _weight_column_index = -1;
    /** When true, an extra grid point holding a missing (NaN) value is appended for every column. */
    public boolean _add_missing_na = false;
    /** Upper bound on grid points per column when no user splits are given; must be at least 2. */
    public int _nbins = 20;
    /** Result tables, one per entry of {@link #_cols}. */
    public TwoDimTable[] _partial_dependence_data;
    /** User-supplied grid values of all columns in {@link #_user_cols}, concatenated in that order. */
    public double[] _user_splits = null;
    /** {@link #_user_splits} sliced into one array per user column; filled by the sanity check. */
    public double[][] _user_split_per_col = null;
    /** Number of values in {@link #_user_splits} that belong to each entry of {@link #_user_cols}. */
    public int[] _num_user_splits = null;
    /** Columns for which user-supplied grid values are given. */
    public String[] _user_cols = null;
    /** Set by the sanity check when non-empty user splits were supplied. */
    public boolean _user_splits_present = false;

    /**
     * Creates a partial-dependence object driven by an existing job.
     *
     * @param dest key under which this object is stored
     * @param j    job used for locking and progress reporting
     */
    public PartialDependence(Key<PartialDependence> dest, Job j) {
        super(dest);
        this._job = j;
    }

    /**
     * Creates a partial-dependence object with a fresh job keyed by {@code dest}.
     *
     * @param dest key under which this object is stored
     */
    public PartialDependence(Key<PartialDependence> dest) {
        this(dest, new Job<PartialDependence>(dest, PartialDependence.class.getName(), "PartialDependence"));
    }

    /**
     * Validates the parameters, locks this object and the input frame, and runs the computation
     * directly in the calling thread.
     *
     * @return this object, with {@link #_partial_dependence_data} populated for processed columns
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public PartialDependence execNested() {
        this.checkSanityAndFillParams();
        this.delete_and_lock(this._job);
        this._frame_id.get().write_lock(this._job._key);
        new PartialDependenceDriver().compute2();
        return this;
    }

    /**
     * Validates the parameters, locks this object and the input frame, and starts {@link #_job}
     * with one unit of work per column.
     *
     * @return the started job
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public Job<PartialDependence> execImpl() {
        this.checkSanityAndFillParams();
        this.delete_and_lock(this._job);
        this._frame_id.get().write_lock(this._job._key);
        this._job.start(new PartialDependenceDriver(), this._cols.length);
        return this._job;
    }

    /**
     * Validates all inputs and derives {@link #_cols} (when unset) and {@link #_user_split_per_col}.
     *
     * @throws IllegalArgumentException on a missing or unsupported model, a missing frame, missing
     *         columns, inconsistent user splits, an invalid weight column or too few bins
     */
    private void checkSanityAndFillParams() {
        // Model and frame are validated regardless of whether _cols was supplied, so that problems
        // surface here rather than as NPEs inside the scoring tasks.
        Model model = this._model_id == null ? null : this._model_id.get();
        if (model == null) {
            throw new IllegalArgumentException("Model not found.");
        }
        Model.Output output = (Model.Output) model._output;
        if (!output.isSupervised() || output.nclasses() > 2) {
            throw new IllegalArgumentException("Partial dependence plots are only implemented for regression and binomial classification models");
        }
        Frame fr = this._frame_id == null ? null : this._frame_id.get();
        if (fr == null) {
            throw new IllegalArgumentException("Frame not found.");
        }
        if (this._cols == null) {
            if (model instanceof Model.GetMostImportantFeatures) {
                this._cols = ((Model.GetMostImportantFeatures) model).getMostImportantFeatures(10);
                if (this._cols != null) {
                    Log.info("Selecting the top " + this._cols.length + " features from the model's variable importances");
                }
            }
            if (this._cols == null) {
                throw new IllegalArgumentException("No columns were specified and none could be selected from the model's variable importances.");
            }
        }
        if (this._nbins < 2) {
            throw new IllegalArgumentException("_nbins must be >=2.");
        }
        if (this._user_splits != null && this._user_splits.length > 0) {
            this.splitUserSplitsPerColumn();
        }
        if (this._weight_column_index >= 0) {
            if (this._weight_column_index >= fr.numCols()) {
                throw new IllegalArgumentException("Weight column index " + this._weight_column_index + " is out of range for a frame with " + fr.numCols() + " columns.");
            }
            Vec weights = fr.vec(this._weight_column_index);
            if (!weights.isNumeric() || weights.isCategorical()) {
                throw new IllegalArgumentException("Weight column " + this._weight_column_index + " must be a numerical column.");
            }
        }
        for (String col : this._cols) {
            Vec v = fr.vec(col);
            if (v == null) {
                throw new IllegalArgumentException("Column " + col + " not found in frame " + this._frame_id + ".");
            }
            if (v.isCategorical() && v.cardinality() > this._nbins) {
                throw new IllegalArgumentException("Column " + col + "'s cardinality of " + v.cardinality() + " > nbins of " + this._nbins);
            }
        }
    }

    /**
     * Slices the flat {@link #_user_splits} array into {@link #_user_split_per_col}.
     *
     * <p>The values of user column {@code c} directly follow those of columns {@code 0..c-1}, so
     * each column's start offset is the running sum of the preceding {@link #_num_user_splits}.
     *
     * @throws IllegalArgumentException if the user-split metadata is missing or inconsistent
     */
    private void splitUserSplitsPerColumn() {
        if (this._user_cols == null || this._num_user_splits == null || this._num_user_splits.length < this._user_cols.length) {
            throw new IllegalArgumentException("User splits require _user_cols and one _num_user_splits entry per user column.");
        }
        this._user_splits_present = true;
        int numUserCols = this._user_cols.length;
        this._user_split_per_col = new double[numUserCols][];
        int start = 0;
        for (int c = 0; c < numUserCols; ++c) {
            int count = this._num_user_splits[c];
            if (count < 0 || start + count > this._user_splits.length) {
                throw new IllegalArgumentException("User splits for column " + this._user_cols[c] + " exceed the " + this._user_splits.length + " supplied split values.");
            }
            double[] splits = new double[count];
            System.arraycopy(this._user_splits, start, splits, 0, count);
            this._user_split_per_col[c] = splits;
            start += count;
        }
    }

    @Override
    public Class<KeyV3.PartialDependenceKeyV3> makeSchema() {
        return KeyV3.PartialDependenceKeyV3.class;
    }

    /**
     * Computes the partial-dependence tables column by column; the grid values of each column are
     * scored as parallel tasks. Releases the frame and object locks on (exceptional) completion.
     */
    private class PartialDependenceDriver
    extends H2O.H2OCountedCompleter<PartialDependenceDriver> {
        private PartialDependenceDriver() {
        }

        @Override
        public void compute2() {
            PartialDependence pd = PartialDependence.this;
            assert (pd._job != null);
            Frame fr = pd._frame_id.get();
            pd._partial_dependence_data = new TwoDimTable[pd._cols.length];
            for (int i = 0; i < pd._cols.length; ++i) {
                String col = pd._cols[i];
                Log.debug("Computing partial dependence of model on '" + col + "'.");
                Vec v = fr.vec(col);
                double[] colVals = this.gridValues(v, col);
                Log.debug("Computing PartialDependence for column " + col + " at the following values: ");
                Log.debug(Arrays.toString(colVals));
                boolean cat = v.isCategorical();
                double[] meanResponse = new double[colVals.length];
                double[] stddevResponse = new double[colVals.length];
                double[] stdErrorOfTheMeanResponse = new double[colVals.length];
                this.scoreGrid(col, cat, colVals, meanResponse, stddevResponse, stdErrorOfTheMeanResponse);
                pd._partial_dependence_data[i] = this.buildTable(col, v, cat, colVals, meanResponse, stddevResponse, stdErrorOfTheMeanResponse);
                pd._job.update(1L);
                pd.update(pd._job);
                if (pd._job.stop_requested()) {
                    break;
                }
            }
            this.tryComplete();
        }

        /**
         * Builds the grid of values at which column {@code col} is evaluated: the user splits when
         * given for this column, otherwise an evenly spaced grid over {@code [min, max]} of at most
         * {@code _nbins} points (fewer for integer columns with a narrow range). A trailing NaN is
         * appended when {@code _add_missing_na} is set.
         */
        private double[] gridValues(Vec v, String col) {
            PartialDependence pd = PartialDependence.this;
            int extra = pd._add_missing_na ? 1 : 0;
            int userColIndex = pd._user_splits_present ? this.indexOfUserCol(col) : -1;
            double[] colVals;
            int actualbins;
            if (userColIndex >= 0) {
                actualbins = pd._num_user_splits[userColIndex];
                // Pads with zeros; the padding slot (if any) is overwritten with NaN below.
                colVals = Arrays.copyOf(pd._user_split_per_col[userColIndex], actualbins + extra);
            } else {
                double min = v.min();
                double max = v.max();
                actualbins = pd._nbins;
                if (v.isInt() && max - min + 1.0 < (double) pd._nbins) {
                    actualbins = (int) (max - min + 1.0);
                }
                colVals = new double[actualbins + extra];
                double delta = actualbins == 1 ? 0.0 : (max - min) / (double) (actualbins - 1);
                for (int j = 0; j < colVals.length; ++j) {
                    colVals[j] = min + (double) j * delta;
                }
            }
            if (pd._add_missing_na) {
                colVals[actualbins] = Double.NaN;
            }
            return colVals;
        }

        /** Returns the position of {@code col} in {@code _user_cols}, or -1 if it has no user splits. */
        private int indexOfUserCol(String col) {
            String[] userCols = PartialDependence.this._user_cols;
            for (int c = 0; c < userCols.length; ++c) {
                if (col.equals(userCols[c])) {
                    return c;
                }
            }
            return -1;
        }

        /**
         * Scores the frame once per grid value (one parallel task each) and waits for all tasks;
         * task {@code k} writes slot {@code k} of the three result arrays.
         */
        private void scoreGrid(final String col, final boolean cat, double[] colVals,
                               final double[] meanResponse, final double[] stddevResponse,
                               final double[] stdErrorOfTheMeanResponse) {
            Futures fs = new Futures();
            for (int k = 0; k < colVals.length; ++k) {
                final double value = colVals[k];
                final int which = k;
                H2O.H2OCountedCompleter pdp = new H2O.H2OCountedCompleter() {
                    @Override
                    public void compute2() {
                        PartialDependence pd = PartialDependence.this;
                        Frame frame = pd._frame_id.get();
                        // Shallow copy of the frame with `col` swapped for a constant vector.
                        Frame test = new Frame(frame.names(), frame.vecs());
                        Vec orig = test.remove(col);
                        Vec cons = orig.makeCon(value);
                        try {
                            if (cat) {
                                cons.setDomain(frame.vec(col).domain());
                            }
                            test.add(col, cons);
                            Frame preds = null;
                            try {
                                preds = pd._model_id.get().score(test, Key.make().toString(), pd._job, false);
                                PartialDependenceDriver.this.recordStats(frame, preds, which, meanResponse, stddevResponse, stdErrorOfTheMeanResponse);
                            } finally {
                                if (preds != null) {
                                    preds.remove();
                                }
                            }
                        } finally {
                            // Free the temporary constant column even if scoring failed.
                            cons.remove();
                        }
                        tryComplete();
                    }
                };
                fs.add(H2O.submitTask(pdp));
            }
            fs.blockForPending();
        }

        /**
         * Stores mean, standard deviation and standard error of the mean of one prediction column
         * of {@code preds} at index {@code which}; stores NaN when there are no predictions.
         * Prediction column 2 is used for models with 2 classes (presumably the second-class
         * probability) and column 0 for regression (nclasses == 1).
         *
         * @throws RuntimeException from {@code H2O.unimpl()} for any other number of classes
         */
        private void recordStats(Frame frame, Frame preds, int which, double[] mean, double[] sd, double[] se) {
            if (preds == null || preds.numRows() == 0L) {
                mean[which] = Double.NaN;
                sd[which] = Double.NaN;
                se[which] = Double.NaN;
                return;
            }
            int nclasses = ((Model.Output) PartialDependence.this._model_id.get()._output).nclasses();
            int predCol;
            if (nclasses == 2) {
                predCol = 2;
            } else if (nclasses == 1) {
                predCol = 0;
            } else {
                throw H2O.unimpl();
            }
            if (PartialDependence.this._weight_column_index >= 0) {
                FrameUtils.CalculateWeightMeanSTD calMeansSTD = this.getWeightedStat(frame, preds, predCol);
                mean[which] = calMeansSTD.getWeightedMean();
                sd[which] = calMeansSTD.getWeightedSigma();
            } else {
                mean[which] = preds.vec(predCol).mean();
                sd[which] = preds.vec(predCol).sigma();
            }
            se[which] = sd[which] / Math.sqrt(preds.numRows());
        }

        /** Builds the result table for one column: grid value, mean, stddev and standard error. */
        private TwoDimTable buildTable(String col, Vec v, boolean cat, double[] colVals,
                                       double[] meanResponse, double[] stddevResponse,
                                       double[] stdErrorOfTheMeanResponse) {
            PartialDependence pd = PartialDependence.this;
            TwoDimTable table = new TwoDimTable("PartialDependence",
                    "Partial Dependence Plot of model " + pd._model_id + " on column '" + col + "'",
                    new String[colVals.length],
                    new String[]{col, "mean_response", "stddev_response", "std_error_mean_response"},
                    new String[]{cat ? "string" : "double", "double", "double", "double"},
                    new String[]{cat ? "%s" : "%5f", "%5f", "%5f", "%5f"},
                    null);
            String[] domain = cat ? v.domain() : null;
            for (int j = 0; j < meanResponse.length; ++j) {
                if (!cat) {
                    table.set(j, 0, colVals[j]);
                } else if (pd._add_missing_na && Double.isNaN(colVals[j])) {
                    table.set(j, 0, ".missing(NA)");
                } else {
                    // Categorical grid values are level indices into the column's domain.
                    table.set(j, 0, domain[(int) colVals[j]]);
                }
                table.set(j, 1, meanResponse[j]);
                table.set(j, 2, stddevResponse[j]);
                table.set(j, 3, stdErrorOfTheMeanResponse[j]);
            }
            return table;
        }

        /**
         * Computes the weighted mean and sigma of prediction column {@code targetIndex}, using the
         * configured weight column of {@code dataFrame}.
         */
        public FrameUtils.CalculateWeightMeanSTD getWeightedStat(Frame dataFrame, Frame pred, int targetIndex) {
            FrameUtils.CalculateWeightMeanSTD calMeansSTD = new FrameUtils.CalculateWeightMeanSTD();
            calMeansSTD.doAll(pred.vec(targetIndex), dataFrame.vec(PartialDependence.this._weight_column_index));
            return calMeansSTD;
        }

        /** Releases the frame's write lock (if the frame still exists) and this object's lock. */
        private void releaseLocks() {
            PartialDependence pd = PartialDependence.this;
            Frame fr = pd._frame_id.get();
            try {
                if (fr != null) {
                    fr.unlock(pd._job._key);
                }
            } finally {
                pd.unlock(pd._job);
            }
        }

        @Override
        public void onCompletion(CountedCompleter caller) {
            this.releaseLocks();
        }

        @Override
        public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
            this.releaseLocks();
            return true;
        }
    }
}

