/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.isotonic.IsotonicRegression;
import hex.isotonic.IsotonicRegressionModel;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;

public class CalibrationHelper {
    public static void initCalibration(ModelBuilderWithCalibration builder, ParamsWithCalibration parms, boolean expensive) {
        Frame cf = parms.getCalibrationFrame();
        if (cf != null) {
            if (!parms.calibrateModel()) {
                builder.getModelBuilder().warn("_calibration_frame", "Calibration frame was specified but calibration was not requested.");
            }
            Frame adaptedCf = builder.getModelBuilder().init_adaptFrameToTrain(cf, "Calibration Frame", "_calibration_frame", expensive);
            builder.setCalibrationFrame(adaptedCf);
        }
        if (parms.calibrateModel()) {
            if (builder.getModelBuilder().nclasses() != 2) {
                builder.getModelBuilder().error("_calibrate_model", "Model calibration is only currently supported for binomial models.");
            }
            if (cf == null) {
                builder.getModelBuilder().error("_calibrate_model", "Calibration frame was not specified.");
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static <M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> Model<?, ?, ?> buildCalibrationModel(ModelBuilderWithCalibration<M, P, O> builder, ParamsWithCalibration parms, Job job, M model) {
        CalibrationMethod calibrationMethod = parms.getCalibrationMethod() == CalibrationMethod.AUTO ? CalibrationMethod.PlattScaling : parms.getCalibrationMethod();
        Key calibInputKey = Key.make();
        try {
            ModelBuilder<?, ?, ?> calibrationModelBuilder;
            Scope.enter();
            job.update(0L, "Calibrating probabilities");
            Frame calib = builder.getCalibrationFrame();
            Vec calibWeights = parms.getParams()._weights_column != null ? calib.vec(parms.getParams()._weights_column) : null;
            Frame calibPredict = Scope.track((Frame[])new Frame[]{model.score(calib, null, job, false)});
            int calibVecIdx = calibrationMethod.getCalibratedVecIdx();
            Frame calibInput = new Frame(calibInputKey, new String[]{"p", "response"}, new Vec[]{calibPredict.vec(calibVecIdx), calib.vec(parms.getParams()._response_column)});
            if (calibWeights != null) {
                calibInput.add("weights", calibWeights);
            }
            DKV.put((Keyed)calibInput);
            switch (calibrationMethod) {
                case PlattScaling: {
                    calibrationModelBuilder = CalibrationHelper.makePlattScalingModelBuilder(calibInput, calibWeights != null);
                    break;
                }
                case IsotonicRegression: {
                    calibrationModelBuilder = CalibrationHelper.makeIsotonicRegressionModelBuilder(calibInput, calibWeights != null);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported calibration method: " + (Object)((Object)calibrationMethod));
                }
            }
            Model model2 = (Model)calibrationModelBuilder.trainModel().get();
            return model2;
        }
        finally {
            Scope.exit((Key[])new Key[0]);
            DKV.remove((Key)calibInputKey);
        }
    }

    static ModelBuilder<?, ?, ?> makePlattScalingModelBuilder(Frame calibInput, boolean hasWeights) {
        Key calibModelKey = Key.make();
        Job calibJob = new Job(calibModelKey, ModelBuilder.javaName((String)"glm"), "Platt Scaling (GLM)");
        GLM calibBuilder = (GLM)ModelBuilder.make((String)"GLM", (Job)calibJob, (Key)calibModelKey);
        ((GLMModel.GLMParameters)calibBuilder._parms)._intercept = true;
        ((GLMModel.GLMParameters)calibBuilder._parms)._response_column = "response";
        ((GLMModel.GLMParameters)calibBuilder._parms)._train = calibInput._key;
        ((GLMModel.GLMParameters)calibBuilder._parms)._family = GLMModel.GLMParameters.Family.binomial;
        ((GLMModel.GLMParameters)calibBuilder._parms)._lambda = new double[]{0.0};
        if (hasWeights) {
            ((GLMModel.GLMParameters)calibBuilder._parms)._weights_column = "weights";
        }
        return calibBuilder;
    }

    static ModelBuilder<?, ?, ?> makeIsotonicRegressionModelBuilder(Frame calibInput, boolean hasWeights) {
        Key calibModelKey = Key.make();
        Job calibJob = new Job(calibModelKey, ModelBuilder.javaName((String)"isotonicregression"), "Isotonic Regression Calibration");
        IsotonicRegression calibBuilder = (IsotonicRegression)ModelBuilder.make((String)"isotonicregression", (Job)calibJob, (Key)calibModelKey);
        ((IsotonicRegressionModel.IsotonicRegressionParameters)calibBuilder._parms)._response_column = "response";
        ((IsotonicRegressionModel.IsotonicRegressionParameters)calibBuilder._parms)._train = calibInput._key;
        ((IsotonicRegressionModel.IsotonicRegressionParameters)calibBuilder._parms)._out_of_bounds = IsotonicRegressionModel.OutOfBoundsHandling.Clip;
        if (hasWeights) {
            ((IsotonicRegressionModel.IsotonicRegressionParameters)calibBuilder._parms)._weights_column = "weights";
        }
        return calibBuilder;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public static Frame postProcessPredictions(Frame predictFr, Job j, OutputWithCalibration output) {
        if (output.calibrationModel() == null) {
            return predictFr;
        }
        if (output.getModelCategory() == ModelCategory.Binomial) {
            Key jobKey = j != null ? j._key : null;
            Key calibInputKey = Key.make();
            Frame calibOutput = null;
            Frame toUnlock = null;
            try {
                Vec[] calPredictions;
                Model<?, ?, ?> calibModel = output.calibrationModel();
                int calibVecIdx = output.getCalibrationMethod().getCalibratedVecIdx();
                String[] calibFeatureNames = calibModel._output.features();
                assert (calibFeatureNames.length == 1);
                Frame calibInput = new Frame(calibInputKey, calibFeatureNames, new Vec[]{predictFr.vec(calibVecIdx)});
                calibOutput = calibModel.score(calibInput);
                if (calibModel instanceof GLMModel) {
                    assert (calibOutput._names.length == 3);
                    calPredictions = calibOutput.remove(new int[]{1, 2});
                } else if (calibModel instanceof IsotonicRegressionModel) {
                    assert (calibOutput._names.length == 1);
                    Vec p1 = calibOutput.remove(0);
                    Vec p0 = ((P0Task)new P0Task().doAll((byte)3, new Vec[]{p1})).outputFrame().lastVec();
                    calPredictions = new Vec[]{p0, p1};
                } else {
                    throw new UnsupportedOperationException("Unsupported calibration model: " + calibModel);
                }
                predictFr.write_lock(jobKey);
                toUnlock = predictFr;
                for (int i = 0; i < calPredictions.length; ++i) {
                    predictFr.add("cal_" + predictFr.name(1 + i), calPredictions[i]);
                }
                Frame frame = (Frame)predictFr.update(jobKey);
                return frame;
            }
            finally {
                if (toUnlock != null) {
                    predictFr.unlock(jobKey);
                }
                DKV.remove((Key)calibInputKey);
                if (calibOutput != null) {
                    calibOutput.remove();
                }
            }
        }
        throw H2O.unimpl((String)"Calibration is only supported for binomial models");
    }

    private static class P0Task
    extends MRTask<P0Task> {
        private P0Task() {
        }

        public void map(Chunk c, NewChunk nc) {
            for (int i = 0; i < c._len; ++i) {
                if (c.isNA(i)) {
                    nc.addNA();
                    continue;
                }
                double p1 = c.atd(i);
                nc.addNum(1.0 - p1);
            }
        }
    }

    public static interface OutputWithCalibration {
        public ModelCategory getModelCategory();

        public Model<?, ?, ?> calibrationModel();

        public void setCalibrationModel(Model<?, ?, ?> var1);

        default public CalibrationMethod getCalibrationMethod() {
            if (!1.$assertionsDisabled && !this.isCalibrated()) {
                throw new AssertionError();
            }
            return this.calibrationModel() instanceof IsotonicRegressionModel ? CalibrationMethod.IsotonicRegression : CalibrationMethod.PlattScaling;
        }

        default public boolean isCalibrated() {
            return this.calibrationModel() != null;
        }

        static {
            if (1.$assertionsDisabled) {
                // empty if block
            }
        }
    }

    public static interface ParamsWithCalibration {
        public Model.Parameters getParams();

        public Frame getCalibrationFrame();

        public boolean calibrateModel();

        public CalibrationMethod getCalibrationMethod();

        public void setCalibrationMethod(CalibrationMethod var1);
    }

    public static interface ModelBuilderWithCalibration<M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> {
        public ModelBuilder<M, P, O> getModelBuilder();

        public Frame getCalibrationFrame();

        public void setCalibrationFrame(Frame var1);
    }

    public static enum CalibrationMethod {
        AUTO("auto", -1),
        PlattScaling("platt", 1),
        IsotonicRegression("isotonic", 2);

        private final int _calibVecIdx;
        private final String _id;

        private CalibrationMethod(String id, int calibVecIdx) {
            this._calibVecIdx = calibVecIdx;
            this._id = id;
        }

        private int getCalibratedVecIdx() {
            return this._calibVecIdx;
        }

        public String getId() {
            return this._id;
        }
    }
}

