/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.sparkling.ml.features;

import ai.h2o.automl.targetencoding.BlendingParams;
import ai.h2o.automl.targetencoding.TargetEncoderBuilder;
import ai.h2o.automl.targetencoding.TargetEncoderModel;
import ai.h2o.sparkling.ml.features.H2OTargetEncoder$;
import ai.h2o.sparkling.ml.features.H2OTargetEncoderBase;
import ai.h2o.sparkling.ml.features.H2OTargetEncoderBase$class;
import ai.h2o.sparkling.ml.features.H2OTargetEncoderHoldoutStrategy;
import ai.h2o.sparkling.ml.models.H2OTargetEncoderModel;
import ai.h2o.sparkling.ml.params.H2OAlgoParamsHelper$;
import ai.h2o.sparkling.ml.params.H2OTargetEncoderParams$class;
import java.io.IOException;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.h2o.H2OContext$;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.h2o.param.NullableStringParam;
import org.apache.spark.ml.param.BooleanParam;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.util.DefaultParamsWritable;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.MLReader;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.ml.util.MLWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.StructType;
import scala.Function0;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import water.fvec.Frame;
import water.fvec.H2OFrame;

@ScalaSignature(bytes="\u0006\u0001\u0005}e\u0001B\u0001\u0003\u00015\u0011\u0001\u0003\u0013\u001aP)\u0006\u0014x-\u001a;F]\u000e|G-\u001a:\u000b\u0005\r!\u0011\u0001\u00034fCR,(/Z:\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0005ta\u0006\u00148\u000e\\5oO*\u0011\u0011BC\u0001\u0004QJz'\"A\u0006\u0002\u0005\u0005L7\u0001A\n\u0005\u00019y2\u0005E\u0002\u0010/ei\u0011\u0001\u0005\u0006\u0003\u000bEQ!AE\n\u0002\u000bM\u0004\u0018M]6\u000b\u0005Q)\u0012AB1qC\u000eDWMC\u0001\u0017\u0003\ry'oZ\u0005\u00031A\u0011\u0011\"R:uS6\fGo\u001c:\u0011\u0005iiR\"A\u000e\u000b\u0005q!\u0011AB7pI\u0016d7/\u0003\u0002\u001f7\t)\u0002JM(UCJ<W\r^#oG>$WM]'pI\u0016d\u0007C\u0001\u0011\"\u001b\u0005\u0011\u0011B\u0001\u0012\u0003\u0005QA%g\u0014+be\u001e,G/\u00128d_\u0012,'OQ1tKB\u0011AeJ\u0007\u0002K)\u0011a\u0005E\u0001\u0005kRLG.\u0003\u0002)K\t)B)\u001a4bk2$\b+\u0019:b[N<&/\u001b;bE2,\u0007\u0002\u0003\u0016\u0001\u0005\u000b\u0007I\u0011I\u0016\u0002\u0007ULG-F\u0001-!\ti3G\u0004\u0002/c5\tqFC\u00011\u0003\u0015\u00198-\u00197b\u0013\t\u0011t&\u0001\u0004Qe\u0016$WMZ\u0005\u0003iU\u0012aa\u0015;sS:<'B\u0001\u001a0\u0011!9\u0004A!A!\u0002\u0013a\u0013\u0001B;jI\u0002BQ!\u000f\u0001\u0005\u0002i\na\u0001P5oSRtDCA\u001e=!\t\u0001\u0003\u0001C\u0003+q\u0001\u0007A\u0006C\u0003:\u0001\u0011\u0005a\bF\u0001<\u0011\u0015\u0001\u0005\u0001\"\u0011B\u0003\r1\u0017\u000e\u001e\u000b\u00033\tCQaQ 
A\u0002\u0011\u000bq\u0001Z1uCN,G\u000f\r\u0002F\u001bB\u0019a)S&\u000e\u0003\u001dS!\u0001S\t\u0002\u0007M\fH.\u0003\u0002K\u000f\n9A)\u0019;bg\u0016$\bC\u0001'N\u0019\u0001!\u0011B\u0014\"\u0002\u0002\u0003\u0005)\u0011A(\u0003\u0007}#\u0013'\u0005\u0002Q'B\u0011a&U\u0005\u0003%>\u0012qAT8uQ&tw\r\u0005\u0002/)&\u0011Qk\f\u0002\u0004\u0003:L\b\"B,\u0001\t\u0013A\u0016\u0001\u0007;sC&tG+\u0019:hKR,enY8eS:<Wj\u001c3fYR\u0011\u0011,\u0019\t\u00035~k\u0011a\u0017\u0006\u00039v\u000ba\u0002^1sO\u0016$XM\\2pI&twM\u0003\u0002_\u0011\u00051\u0011-\u001e;p[2L!\u0001Y.\u0003%Q\u000b'oZ3u\u000b:\u001cw\u000eZ3s\u001b>$W\r\u001c\u0005\u0006EZ\u0003\raY\u0001\u000eiJ\f\u0017N\\5oO\u001a\u0013\u0018-\\3\u0011\u0005\u0011\u001chBA3q\u001d\t1wN\u0004\u0002h]:\u0011\u0001.\u001c\b\u0003S2l\u0011A\u001b\u0006\u0003W2\ta\u0001\u0010:p_Rt\u0014\"\u0001\f\n\u0005Q)\u0012B\u0001\n\u0014\u0013\tI\u0011#\u0003\u0002re\u00069\u0001/Y2lC\u001e,'BA\u0005\u0012\u0013\t!XOA\u0003Ge\u0006lWM\u0003\u0002re\")q\u000f\u0001C!q\u0006!1m\u001c9z)\tY\u0014\u0010C\u0003{m\u0002\u000710A\u0003fqR\u0014\u0018\r\u0005\u0002}\u007f6\tQP\u0003\u0002\u007f!\u0005)\u0001/\u0019:b[&\u0019\u0011\u0011A?\u0003\u0011A\u000b'/Y7NCBDq!!\u0002\u0001\t\u0003\t9!\u0001\u0006tKR4u\u000e\u001c3D_2$B!!\u0003\u0002\f5\t\u0001\u0001C\u0004\u0002\u000e\u0005\r\u0001\u0019\u0001\u0017\u0002\u000bY\fG.^3\t\u000f\u0005E\u0001\u0001\"\u0001\u0002\u0014\u0005Y1/\u001a;MC\n,GnQ8m)\u0011\tI!!\u0006\t\u000f\u00055\u0011q\u0002a\u0001Y!9\u0011\u0011\u0004\u0001\u0005\u0002\u0005m\u0011\u0001D:fi&s\u0007/\u001e;D_2\u001cH\u0003BA\u0005\u0003;A\u0001\"a\b\u0002\u0018\u0001\u0007\u0011\u0011E\u0001\u0007m\u0006dW/Z:\u0011\t9\n\u0019\u0003L\u0005\u0004\u0003Ky#!B!se\u0006L\bbBA\u0015\u0001\u0011\u0005\u00111F\u0001\u0013g\u0016$\bj\u001c7e_V$8\u000b\u001e:bi\u0016<\u0017\u0010\u0006\u0003\u0002\n\u00055\u0002bBA\u0007\u0003O\u0001\r\u0001\f\u0005\b\u0003c\u0001A\u0011AA\u001a\u0003Q\u0019X\r\u001e\"mK:$W\rZ!wO\u0016s\u0017M\u00197fIR!\
u0011\u0011BA\u001b\u0011!\ti!a\fA\u0002\u0005]\u0002c\u0001\u0018\u0002:%\u0019\u00111H\u0018\u0003\u000f\t{w\u000e\\3b]\"9\u0011q\b\u0001\u0005\u0002\u0005\u0005\u0013\u0001H:fi\ncWM\u001c3fI\u00063x-\u00138gY\u0016\u001cG/[8o!>Lg\u000e\u001e\u000b\u0005\u0003\u0013\t\u0019\u0005\u0003\u0005\u0002\u000e\u0005u\u0002\u0019AA#!\rq\u0013qI\u0005\u0004\u0003\u0013z#A\u0002#pk\ndW\rC\u0004\u0002N\u0001!\t!a\u0014\u0002-M,GO\u00117f]\u0012,G-\u0011<h'6|w\u000e\u001e5j]\u001e$B!!\u0003\u0002R!A\u0011QBA&\u0001\u0004\t)\u0005C\u0004\u0002V\u0001!\t!a\u0016\u0002\u0011M,GOT8jg\u0016$B!!\u0003\u0002Z!A\u0011QBA*\u0001\u0004\t)\u0005C\u0004\u0002^\u0001!\t!a\u0018\u0002\u0019M,GOT8jg\u0016\u001cV-\u001a3\u0015\t\u0005%\u0011\u0011\r\u0005\t\u0003\u001b\tY\u00061\u0001\u0002dA\u0019a&!\u001a\n\u0007\u0005\u001dtF\u0001\u0003M_:<waBA6\u0005!\u0005\u0011QN\u0001\u0011\u0011JzE+\u0019:hKR,enY8eKJ\u00042\u0001IA8\r\u0019\t!\u0001#\u0001\u0002rMA\u0011qNA:\u0003s\ny\bE\u0002/\u0003kJ1!a\u001e0\u0005\u0019\te.\u001f*fMB!A%a\u001f<\u0013\r\ti(\n\u0002\u0016\t\u00164\u0017-\u001e7u!\u0006\u0014\u0018-\\:SK\u0006$\u0017M\u00197f!\rq\u0013\u0011Q\u0005\u0004\u0003\u0007{#\u0001D*fe&\fG.\u001b>bE2,\u0007bB\u001d\u0002p\u0011\u0005\u0011q\u0011\u000b\u0003\u0003[B!\"a#\u0002p\u0005\u0005I\u0011BAG\u0003-\u0011X-\u00193SKN|GN^3\u0015\u0005\u0005=\u0005\u0003BAI\u00037k!!a%\u000b\t\u0005U\u0015qS\u0001\u0005Y\u0006twM\u0003\u0002\u0002\u001a\u0006!!.\u0019<b\u0013\u0011\ti*a%\u0003\r=\u0013'.Z2u\u0001")
/**
 * Spark ML {@link Estimator} that trains an H2O target-encoding model on a Spark {@link Dataset}
 * and returns it wrapped in an {@link H2OTargetEncoderModel}.
 *
 * <p>NOTE(review): this file is CFR output decompiled from Scala. The {@code private final} param
 * fields below are assigned through the {@code ai$h2o$...$_setter_$..._$eq} methods, presumably
 * invoked by the trait initializer {@code H2OTargetEncoderParams$class.$init$} called from the
 * constructor. Most other members simply forward to the Scala trait implementation classes.
 */
public class H2OTargetEncoder
extends Estimator<H2OTargetEncoderModel>
implements H2OTargetEncoderBase,
DefaultParamsWritable {
    /**
     * Substring of the H2O error message that signals an unsupported multi-class label
     * (matched textually in {@link #trainTargetEncodingModel(Frame)}).
     */
    private static final String MULTI_CLASS_TARGET_MESSAGE = "We do not support multi-class target case";

    private final String uid;
    private final NullableStringParam foldCol;
    private final Param<String> labelCol;
    private final StringArrayParam inputCols;
    private final Param<String> holdoutStrategy;
    private final BooleanParam blendedAvgEnabled;
    private final DoubleParam blendedAvgInflectionPoint;
    private final DoubleParam blendedAvgSmoothing;
    private final DoubleParam noise;
    private final LongParam noiseSeed;

    /**
     * Loads a previously saved estimator from {@code string}; delegates to the companion object.
     */
    public static Object load(String string) {
        return H2OTargetEncoder$.MODULE$.load(string);
    }

    /** Returns the reader provided by the companion object. */
    public static MLReader<H2OTargetEncoder> read() {
        return H2OTargetEncoder$.MODULE$.read();
    }

    /** Returns the writer supplied by the {@code DefaultParamsWritable} trait implementation. */
    public MLWriter write() {
        return DefaultParamsWritable.class.write((DefaultParamsWritable)this);
    }

    /**
     * Saves this estimator to {@code path} via the {@code MLWritable} trait implementation.
     *
     * @throws IOException as declared by the save contract
     */
    public void save(String path) throws IOException {
        MLWritable.class.save((MLWritable)this, (String)path);
    }

    /** Delegates schema derivation to the {@code H2OTargetEncoderBase} trait implementation. */
    @Override
    public StructType transformSchema(StructType schema) {
        return H2OTargetEncoderBase$class.transformSchema(this, schema);
    }

    /**
     * Delegates to the {@code H2OTargetEncoderBase} trait implementation; used by {@link #fit}
     * on the freshly converted training frame before training.
     */
    @Override
    public void convertRelevantColumnsToCategorical(Frame frame) {
        H2OTargetEncoderBase$class.convertRelevantColumnsToCategorical(this, frame);
    }

    // ---- Param accessors (return the fields populated during trait initialization) ----

    @Override
    public final NullableStringParam foldCol() {
        return this.foldCol;
    }

    @Override
    public final Param<String> labelCol() {
        return this.labelCol;
    }

    @Override
    public final StringArrayParam inputCols() {
        return this.inputCols;
    }

    @Override
    public final Param<String> holdoutStrategy() {
        return this.holdoutStrategy;
    }

    @Override
    public final BooleanParam blendedAvgEnabled() {
        return this.blendedAvgEnabled;
    }

    @Override
    public final DoubleParam blendedAvgInflectionPoint() {
        return this.blendedAvgInflectionPoint;
    }

    @Override
    public final DoubleParam blendedAvgSmoothing() {
        return this.blendedAvgSmoothing;
    }

    @Override
    public final DoubleParam noise() {
        return this.noise;
    }

    @Override
    public final LongParam noiseSeed() {
        return this.noiseSeed;
    }

    // ---- Scala trait field setters (decompiler artifacts; not meant to be called directly) ----

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$foldCol_$eq(NullableStringParam x$1) {
        this.foldCol = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$labelCol_$eq(Param x$1) {
        this.labelCol = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$inputCols_$eq(StringArrayParam x$1) {
        this.inputCols = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$holdoutStrategy_$eq(Param x$1) {
        this.holdoutStrategy = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$blendedAvgEnabled_$eq(BooleanParam x$1) {
        this.blendedAvgEnabled = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$blendedAvgInflectionPoint_$eq(DoubleParam x$1) {
        this.blendedAvgInflectionPoint = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$blendedAvgSmoothing_$eq(DoubleParam x$1) {
        this.blendedAvgSmoothing = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$noise_$eq(DoubleParam x$1) {
        this.noise = x$1;
    }

    @Override
    public final void ai$h2o$sparkling$ml$params$H2OTargetEncoderParams$_setter_$noiseSeed_$eq(LongParam x$1) {
        this.noiseSeed = x$1;
    }

    // ---- Param getters (forwarded to the H2OTargetEncoderParams trait implementation) ----

    @Override
    public final String[] possibleHoldoutStrategyValues() {
        return H2OTargetEncoderParams$class.possibleHoldoutStrategyValues(this);
    }

    @Override
    public String getFoldCol() {
        return H2OTargetEncoderParams$class.getFoldCol(this);
    }

    @Override
    public String getLabelCol() {
        return H2OTargetEncoderParams$class.getLabelCol(this);
    }

    @Override
    public String[] getInputCols() {
        return H2OTargetEncoderParams$class.getInputCols(this);
    }

    @Override
    public String[] getOutputCols() {
        return H2OTargetEncoderParams$class.getOutputCols(this);
    }

    @Override
    public String getHoldoutStrategy() {
        return H2OTargetEncoderParams$class.getHoldoutStrategy(this);
    }

    @Override
    public boolean getBlendedAvgEnabled() {
        return H2OTargetEncoderParams$class.getBlendedAvgEnabled(this);
    }

    @Override
    public double getBlendedAvgInflectionPoint() {
        return H2OTargetEncoderParams$class.getBlendedAvgInflectionPoint(this);
    }

    @Override
    public double getBlendedAvgSmoothing() {
        return H2OTargetEncoderParams$class.getBlendedAvgSmoothing(this);
    }

    @Override
    public double getNoise() {
        return H2OTargetEncoderParams$class.getNoise(this);
    }

    @Override
    public long getNoiseSeed() {
        return H2OTargetEncoderParams$class.getNoiseSeed(this);
    }

    /** Returns the unique identifier of this estimator. */
    public String uid() {
        return this.uid;
    }

    /**
     * Trains a target-encoding model on {@code dataset}.
     *
     * <p>The dataset is converted to an H2O frame through the {@link H2OContext} obtained for the
     * current (or newly created) {@code SparkSession}. The relevant columns are then converted to
     * categorical and the H2O model is trained. The result is wrapped in an
     * {@link H2OTargetEncoderModel} that shares this estimator's uid, has this estimator set as its
     * parent, and receives a copy of this estimator's param values.
     *
     * @param dataset training data
     * @return the fitted Spark model wrapper
     */
    public H2OTargetEncoderModel fit(Dataset<?> dataset) {
        H2OContext h2oContext = H2OContext$.MODULE$.getOrCreate(SparkSession$.MODULE$.builder().getOrCreate());
        H2OFrame input = h2oContext.asH2OFrame(dataset.toDF());
        this.convertRelevantColumnsToCategorical((Frame)input);
        TargetEncoderModel targetEncoderModel = this.trainTargetEncodingModel((Frame)input);
        H2OTargetEncoderModel model = (H2OTargetEncoderModel)new H2OTargetEncoderModel(this.uid(), targetEncoderModel).setParent(this);
        return (H2OTargetEncoderModel)this.copyValues(model, this.copyValues$default$2());
    }

    /**
     * Builds and trains the underlying H2O {@link TargetEncoderModel}.
     *
     * <p>Only the blending settings, label column, fold column and input columns are passed to
     * H2O here. NOTE(review): holdoutStrategy, noise and noiseSeed are not set on the training
     * parameters in this method; presumably the model applies them at transform time — confirm.
     *
     * @param trainingFrame H2O frame whose key is used as the training frame
     * @return the trained H2O model, read from the builder after {@code trainModel().get()}
     * @throws RuntimeException if H2O rejects the label column as multi-class; the original
     *     {@link IllegalStateException} is attached as the cause
     */
    private TargetEncoderModel trainTargetEncodingModel(Frame trainingFrame) {
        try {
            TargetEncoderModel.TargetEncoderParameters targetEncoderParameters = new TargetEncoderModel.TargetEncoderParameters();
            targetEncoderParameters._withBlending = Predef$.MODULE$.boolean2Boolean(this.getBlendedAvgEnabled());
            targetEncoderParameters._blendingParams = new BlendingParams(this.getBlendedAvgInflectionPoint(), this.getBlendedAvgSmoothing());
            targetEncoderParameters._response_column = this.getLabelCol();
            targetEncoderParameters._teFoldColumnName = this.getFoldCol();
            targetEncoderParameters._columnNamesToEncode = this.getInputCols();
            targetEncoderParameters.setTrain(trainingFrame._key);
            TargetEncoderBuilder builder = new TargetEncoderBuilder(targetEncoderParameters);
            builder.trainModel().get();
            return builder.getTargetEncoderModel();
        }
        catch (IllegalStateException illegalStateException) {
            // Translate H2O's multi-class rejection into a clearer message, keeping the original
            // exception as the cause. A null message must not trigger an NPE that would mask the
            // real failure; any other IllegalStateException propagates unchanged.
            String message = illegalStateException.getMessage();
            if (message != null && message.contains(MULTI_CLASS_TARGET_MESSAGE)) {
                throw new RuntimeException("The label column can not contain more than two unique values.", illegalStateException);
            }
            throw illegalStateException;
        }
    }

    /** Creates a copy of this estimator with {@code extra} params applied (Spark default copy). */
    public H2OTargetEncoder copy(ParamMap extra) {
        return (H2OTargetEncoder)this.defaultCopy(extra);
    }

    /** Sets the fold column name. */
    public H2OTargetEncoder setFoldCol(String value) {
        return (H2OTargetEncoder)this.set(this.foldCol(), value);
    }

    /** Sets the label (response) column name. */
    public H2OTargetEncoder setLabelCol(String value) {
        return (H2OTargetEncoder)this.set(this.labelCol(), value);
    }

    /** Sets the names of the columns to target-encode. */
    public H2OTargetEncoder setInputCols(String[] values) {
        return (H2OTargetEncoder)this.set((Param)this.inputCols(), values);
    }

    /**
     * Sets the holdout strategy. The value is stored after passing through
     * {@code H2OAlgoParamsHelper.getValidatedEnumValue} for {@link H2OTargetEncoderHoldoutStrategy}.
     */
    public H2OTargetEncoder setHoldoutStrategy(String value) {
        return (H2OTargetEncoder)this.set(this.holdoutStrategy(), H2OAlgoParamsHelper$.MODULE$.getValidatedEnumValue(value, ClassTag$.MODULE$.apply(H2OTargetEncoderHoldoutStrategy.class)));
    }

    /** Enables or disables blended averaging. */
    public H2OTargetEncoder setBlendedAvgEnabled(boolean value) {
        return (H2OTargetEncoder)this.set((Param)this.blendedAvgEnabled(), BoxesRunTime.boxToBoolean((boolean)value));
    }

    /** Sets the blended-average inflection point. */
    public H2OTargetEncoder setBlendedAvgInflectionPoint(double value) {
        return (H2OTargetEncoder)this.set((Param)this.blendedAvgInflectionPoint(), BoxesRunTime.boxToDouble((double)value));
    }

    /**
     * Sets the blended-average smoothing.
     *
     * @throws IllegalArgumentException (via Scala {@code require}) if {@code value <= 0}
     */
    public H2OTargetEncoder setBlendedAvgSmoothing(double value) {
        Predef$.MODULE$.require(value > 0.0, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "The smoothing value has to be a positive number.";
            }
        });
        return (H2OTargetEncoder)this.set((Param)this.blendedAvgSmoothing(), BoxesRunTime.boxToDouble((double)value));
    }

    /**
     * Sets the noise amount.
     *
     * @throws IllegalArgumentException (via Scala {@code require}) if {@code value < 0}
     */
    public H2OTargetEncoder setNoise(double value) {
        Predef$.MODULE$.require(value >= 0.0, (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final String apply() {
                return "Noise can't be a negative value.";
            }
        });
        return (H2OTargetEncoder)this.set((Param)this.noise(), BoxesRunTime.boxToDouble((double)value));
    }

    /** Sets the seed used for noise generation. */
    public H2OTargetEncoder setNoiseSeed(long value) {
        return (H2OTargetEncoder)this.set((Param)this.noiseSeed(), BoxesRunTime.boxToLong((long)value));
    }

    /**
     * Creates an estimator with the given uid. The trait initializers run in Scala linearization
     * order; they populate the param fields.
     */
    public H2OTargetEncoder(String uid) {
        this.uid = uid;
        H2OTargetEncoderParams$class.$init$(this);
        H2OTargetEncoderBase$class.$init$(this);
        MLWritable.class.$init$((MLWritable)this);
        DefaultParamsWritable.class.$init$((DefaultParamsWritable)this);
    }

    /** Creates an estimator with a random uid prefixed by {@code "H2OTargetEncoder"}. */
    public H2OTargetEncoder() {
        this(Identifiable$.MODULE$.randomUID("H2OTargetEncoder"));
    }
}

