/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.glm.GLM;
import hex.glm.GLMModel;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;

/**
 * Helpers for Platt scaling: calibrating a binomial model's predicted probabilities by fitting a
 * binomial GLM (with intercept, {@code lambda = 0}) on the model's scores over a separate
 * calibration frame.
 *
 * <p>Algorithms opt in by implementing {@link ModelBuilderWithCalibration},
 * {@link ParamsWithCalibration} and {@link OutputWithCalibration}.
 */
public class PlattScalingHelper {

    /**
     * Validates the calibration settings and, when a calibration frame is supplied, adapts it to the
     * training frame and stores the adapted frame on the builder.
     *
     * <p>Reports through the model builder's {@code warn}/{@code error}:
     * <ul>
     *   <li>a warning if a calibration frame is given but calibration was not requested;</li>
     *   <li>an error if calibration is requested for a model whose {@code nclasses() != 2};</li>
     *   <li>an error if calibration is requested but no calibration frame was given.</li>
     * </ul>
     *
     * @param builder   the builder being initialized; receives the adapted calibration frame
     * @param parms     parameters carrying the calibration frame and the calibrate flag
     * @param expensive forwarded to {@code init_adaptFrameToTrain}
     */
    public static void initCalibration(ModelBuilderWithCalibration builder, ParamsWithCalibration parms, boolean expensive) {
        ModelBuilder mb = builder.getModelBuilder();
        Frame cf = parms.getCalibrationFrame();
        if (cf != null) {
            if (!parms.calibrateModel()) {
                mb.warn("_calibration_frame", "Calibration frame was specified but calibration was not requested.");
            }
            // The adapted frame (not the user-supplied one) is what calibration later scores.
            Frame adaptedCf = mb.init_adaptFrameToTrain(cf, "Calibration Frame", "_calibration_frame", expensive);
            builder.setCalibrationFrame(adaptedCf);
        }
        if (parms.calibrateModel()) {
            if (mb.nclasses() != 2) {
                mb.error("_calibrate_model", "Model calibration is only currently supported for binomial models.");
            }
            if (cf == null) {
                mb.error("_calibrate_model", "Calibration frame was not specified.");
            }
        }
    }

    /**
     * Trains the Platt-scaling GLM for {@code model}.
     *
     * <p>The builder's calibration frame is scored with {@code model}. A temporary frame is then
     * built with the columns {@code "p"} (column 1 of the predictions), {@code "response"} (the
     * calibration frame's response column) and, if the parameters name a weights column,
     * {@code "weights"}. A binomial GLM with an intercept and {@code lambda = 0} is fitted on that
     * frame.
     *
     * <p>The scored frame is tracked in a {@link Scope} and the temporary input frame's key is
     * removed from the DKV before returning, whether training succeeds or throws.
     *
     * @param builder the builder holding the (adapted) calibration frame
     * @param parms   parameters providing the response and optional weights column names
     * @param job     job used for scoring and progress reporting
     * @param model   the trained model whose probabilities are being calibrated
     * @return the trained calibration GLM
     */
    public static <M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> GLMModel buildCalibrationModel(ModelBuilderWithCalibration<M, P, O> builder, ParamsWithCalibration parms, Job job, M model) {
        Key<Frame> calibInputKey = Key.make();
        // Enter the scope before the try so the finally never exits a scope that was not entered.
        Scope.enter();
        try {
            job.update(0L, "Calibrating probabilities");
            Frame calib = builder.getCalibrationFrame();
            Vec calibWeights = parms.getParams()._weights_column != null ? calib.vec(parms.getParams()._weights_column) : null;
            // Scope-tracked: this frame and its vecs are deleted by Scope.exit() below.
            Frame calibPredict = Scope.track(model.score(calib, null, job, false));
            Frame calibInput = new Frame(calibInputKey, new String[]{"p", "response"}, new Vec[]{calibPredict.vec(1), calib.vec(parms.getParams()._response_column)});
            if (calibWeights != null) {
                calibInput.add("weights", calibWeights);
            }
            // The GLM builder refers to its training frame by key, so it must be published in the DKV.
            DKV.put(calibInput);

            Key<Model> calibModelKey = Key.make();
            Job calibJob = new Job(calibModelKey, ModelBuilder.javaName("glm"), "Platt Scaling (GLM)");
            GLM calibBuilder = (GLM) ModelBuilder.make("GLM", calibJob, calibModelKey);
            GLMModel.GLMParameters calibParms = calibBuilder._parms;
            calibParms._intercept = true;
            calibParms._response_column = "response";
            calibParms._train = calibInput._key;
            calibParms._family = GLMModel.GLMParameters.Family.binomial;
            calibParms._lambda = new double[]{0.0};
            if (calibWeights != null) {
                calibParms._weights_column = "weights";
            }
            return (GLMModel) calibBuilder.trainModel().get();
        } finally {
            // Drop only the temporary frame header (its vecs are borrowed from calib / calibPredict).
            // This runs first so the key is not leaked if Scope.exit() throws.
            DKV.remove(calibInputKey);
            Scope.exit();
        }
    }

    /**
     * Appends calibrated probabilities to a prediction frame.
     *
     * <p>If {@code output} has no calibration model, {@code predictFr} is returned unchanged. For
     * binomial models, column 1 of {@code predictFr} is scored by the calibration GLM. Columns 1 and
     * 2 of the GLM's output are then appended to {@code predictFr} as
     * {@code "cal_" + predictFr.name(1 + i)}, and the frame is updated in the DKV under a write lock.
     *
     * @param predictFr prediction frame produced by the (uncalibrated) model; modified in place
     * @param j2        job owning the write lock, or {@code null}
     * @param output    model output providing the calibration model and model category
     * @return the updated prediction frame (or {@code predictFr} itself when no calibration model exists)
     * @throws RuntimeException from {@code H2O.unimpl} for non-binomial models with a calibration model
     */
    public static Frame postProcessPredictions(Frame predictFr, Job j2, OutputWithCalibration output) {
        if (output.calibrationModel() == null) {
            return predictFr;
        }
        if (output.getModelCategory() != ModelCategory.Binomial) {
            throw H2O.unimpl("Calibration is only supported for binomial models");
        }
        Key jobKey = j2 != null ? j2._key : null;
        Key<Frame> calibInputKey = Key.make();
        Frame calibOutput = null;
        // Only unlock what we actually locked: unlocking an unlocked frame would fail in the finally
        // block and hide whatever exception got us there.
        boolean locked = false;
        try {
            Frame calibInput = new Frame(calibInputKey, new String[]{"p"}, new Vec[]{predictFr.vec(1)});
            calibOutput = output.calibrationModel().score(calibInput);
            assert calibOutput._names.length == 3;
            // Detach the two probability columns so calibOutput.remove() below does not delete them.
            Vec[] calPredictions = calibOutput.remove(new int[]{1, 2});
            predictFr.write_lock(jobKey);
            locked = true;
            for (int i = 0; i < calPredictions.length; ++i) {
                predictFr.add("cal_" + predictFr.name(1 + i), calPredictions[i]);
            }
            return (Frame) predictFr.update(jobKey);
        } finally {
            if (locked) {
                predictFr.unlock(jobKey);
            }
            DKV.remove(calibInputKey);
            if (calibOutput != null) {
                calibOutput.remove();
            }
        }
    }

    /** Model output that may carry a Platt-scaling calibration model. */
    public static interface OutputWithCalibration {
        /** @return the category of the model producing this output */
        public ModelCategory getModelCategory();

        /** @return the calibration GLM, or {@code null} when the model was not calibrated */
        public GLMModel calibrationModel();
    }

    /** Model parameters that expose calibration settings. */
    public static interface ParamsWithCalibration {
        /** @return the underlying model parameters (response/weights column names are read from here) */
        public Model.Parameters getParams();

        /** @return the user-supplied calibration frame, or {@code null} if none was given */
        public Frame getCalibrationFrame();

        /** @return whether calibration was requested */
        public boolean calibrateModel();
    }

    /** Model builder that holds a calibration frame. */
    public static interface ModelBuilderWithCalibration<M extends Model<M, P, O>, P extends Model.Parameters, O extends Model.Output> {
        /** @return the underlying model builder, used for warnings, errors and frame adaptation */
        public ModelBuilder getModelBuilder();

        /** @return the calibration frame as adapted during initialization */
        public Frame getCalibrationFrame();

        /** Stores the (adapted) calibration frame. */
        public void setCalibrationFrame(Frame var1);
    }
}

