/*
 * Decompiled with CFR 0.152.
 */
package water.udf;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import java.util.Arrays;
import org.junit.Assert;
import water.DKV;
import water.Key;
import water.Lockable;
import water.TestUtil;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.util.FrameUtils;

public class CustomMetricUtils {
    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    static void testNullModelRegression(final CFuncRef func) throws Exception {
        final Frame f = TestUtil.Datasets.iris();
        Frame pred = null;
        Model model = null;
        try {
            NullModelParameters params = new NullModelParameters(){
                {
                    this._train = f._key;
                    this._response_column = "sepal_len";
                    this._custom_metric_func = func.toRef();
                }
            };
            model = (Model)new NullModelBuilder(params).trainModel().get();
            pred = model.score(f, null, null, true, func);
            Assert.assertEquals((String)"Null model generates only a single model metrics", (long)1L, (long)model._output.getModelMetrics().length);
            ModelMetrics mm = (ModelMetrics)model._output.getModelMetrics()[0].get();
            Assert.assertEquals((String)"Custom model metrics should compute mean of response column", (double)f.vec("sepal_len").mean(), (double)mm._custom_metric.value, (double)1.0E-8);
        }
        catch (Throwable throwable) {
            FrameUtils.delete((Lockable[])new Lockable[]{f, pred, model});
            DKV.remove((Key)func.getKey());
            throw throwable;
        }
        FrameUtils.delete((Lockable[])new Lockable[]{f, pred, model});
        DKV.remove((Key)func.getKey());
    }

    static class NullModelBuilder
    extends ModelBuilder<NullModel, NullModelParameters, NullModelOutput> {
        public NullModelBuilder(NullModelParameters parms) {
            super((Model.Parameters)parms);
            this.init(false);
        }

        public void init(boolean expensive) {
            super.init(expensive);
        }

        protected ModelBuilder.Driver trainModelImpl() {
            return new ModelBuilder.Driver(){

                public void computeImpl() {
                    this.init(true);
                    NullModel model = new NullModel((Key<NullModel>)this.dest(), (NullModelParameters)_parms, new NullModelOutput(this));
                    try {
                        model.delete_and_lock(_job);
                    }
                    finally {
                        model.unlock(_job);
                    }
                }
            };
        }

        public ModelCategory[] can_build() {
            return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
        }

        public boolean isSupervised() {
            return true;
        }
    }

    static class NullModel
    extends Model<NullModel, NullModelParameters, NullModelOutput> {
        public NullModel(Key<NullModel> selfKey, NullModelParameters parms, NullModelOutput output) {
            super(selfKey, (Model.Parameters)parms, (Model.Output)output);
        }

        public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
            switch (((NullModelOutput)this._output).getModelCategory()) {
                case Binomial: {
                    return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
                }
                case Multinomial: {
                    return new ModelMetricsMultinomial.MetricBuilderMultinomial(((NullModelOutput)this._output).nclasses(), domain, ((NullModelParameters)this._parms)._auc_type);
                }
                case Regression: {
                    return new ModelMetricsRegression.MetricBuilderRegression();
                }
            }
            return null;
        }

        protected double[] score0(double[] data, double[] preds) {
            Arrays.fill(preds, 0.0);
            return preds;
        }
    }

    static class NullModelParameters
    extends Model.Parameters {
        NullModelParameters() {
        }

        public String fullName() {
            return "nullModel";
        }

        public String algoName() {
            return "nullModel";
        }

        public String javaName() {
            return NullModelBuilder.class.getName();
        }

        public long progressUnits() {
            return 1L;
        }
    }

    static class NullModelOutput
    extends Model.Output {
        public NullModelOutput(ModelBuilder b) {
            super(b);
        }
    }
}

