/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.sparkling.examples;

import ai.h2o.sparkling.H2OContext$;
import ai.h2o.sparkling.ml.algos.H2OAlgorithm;
import ai.h2o.sparkling.ml.algos.H2OAutoML;
import ai.h2o.sparkling.ml.algos.H2ODeepLearning;
import ai.h2o.sparkling.ml.algos.H2OGBM;
import ai.h2o.sparkling.ml.algos.H2OGridSearch;
import ai.h2o.sparkling.ml.algos.H2OXGBoost;
import ai.h2o.sparkling.ml.features.ColumnPruner;
import ai.h2o.sparkling.ml.models.H2OMOJOModel;
import ai.h2o.sparkling.ml.params.H2OAutoMLStoppingCriteriaParams;
import ai.h2o.sparkling.ml.params.H2OCommonParams;
import ai.h2o.sparkling.ml.params.H2ODeepLearningParams;
import ai.h2o.sparkling.ml.params.H2OGBMParams;
import ai.h2o.sparkling.ml.params.H2OXGBoostParams;
import java.io.File;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.Pipeline$;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineModel$;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IDF;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function1;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.Map;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/**
 * Ham-or-spam example: builds Spark ML pipelines that combine text featurization
 * (tokenizer, stop-word removal, hashing TF, IDF) with several H2O estimators (GBM,
 * Deep Learning, AutoML, grid search over GBM, XGBoost). Each pipeline is trained on
 * the SMS dataset, and its predictions are checked on one ham and one spam message.
 *
 * <p>This is the decompiled singleton module of a Scala {@code object}. {@link #MODULE$}
 * holds the single instance, which is assigned in the private constructor.
 */
public final class HamOrSpamDemo$ {
    /** Singleton instance of the Scala object. */
    public static final HamOrSpamDemo$ MODULE$;

    static {
        new HamOrSpamDemo$();
    }

    /**
     * Entry point. Creates (or reuses) a SparkSession and loads the SMS data from
     * {@code ./examples/smalldata/smsData.txt} as an absolute {@code file://} URI. It then
     * initialises the H2OContext and, for each H2O estimator, trains and checks a pipeline:
     * tokenizer -> stop words -> hashing TF -> IDF -> column pruner -> estimator -> column pruner.
     *
     * @param args command-line arguments (unused)
     */
    public void main(String[] args) {
        SparkSession spark = SparkSession$.MODULE$.builder().appName("Ham or Spam Pipeline Demo").getOrCreate();
        String smsDataPath = "./examples/smalldata/smsData.txt";
        // Equivalent to s"file://${new File(smsDataPath).getAbsolutePath}" in the Scala source.
        String smsDataFile = new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"file://", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{new File(smsDataPath).getAbsolutePath()}));
        Dataset<Row> data = this.load(spark, smsDataFile);
        H2OContext$.MODULE$.getOrCreate();
        RegexTokenizer tokenizer = this.createTokenizer();
        StopWordsRemover stopWordsRemover = this.createStopWordsRemover(tokenizer);
        HashingTF hashingTF = this.createHashingTF(stopWordsRemover);
        IDF idf = this.createIDF(hashingTF);
        // Configured with the intermediate feature columns (wordToIndex, filtered, words).
        ColumnPruner columnPruner = this.createColumnPruner(hashingTF, stopWordsRemover, tokenizer);
        // Configured with the IDF output column; placed after the estimator in the pipeline.
        ColumnPruner columnPruner2 = (ColumnPruner)new ColumnPruner().setColumns((String[])((Object[])new String[]{idf.getOutputCol()}));
        Estimator[] estimators = (Estimator[])((Object[])new Estimator[]{this.gbm(), this.deepLearning(), this.autoML(), this.gridSearch(), this.xgboost()});
        // Decompiled Scala closure: estimators.foreach { estimator => ... } capturing the stages above.
        Predef$.MODULE$.refArrayOps((Object[])estimators).foreach((Function1)new Serializable(spark, data, tokenizer, stopWordsRemover, hashingTF, idf, columnPruner, columnPruner2){
            public static final long serialVersionUID = 0L;
            private final SparkSession spark$1;
            private final Dataset data$1;
            private final RegexTokenizer tokenizer$1;
            private final StopWordsRemover stopWordsRemover$1;
            private final HashingTF hashingTF$1;
            private final IDF idf$1;
            private final ColumnPruner columnPruner$1;
            private final ColumnPruner columnPruner2$1;

            public final void apply(Estimator<H2OMOJOModel> estimator) {
                PipelineStage[] stages = (PipelineStage[])((Object[])new PipelineStage[]{this.tokenizer$1, this.stopWordsRemover$1, this.hashingTF$1, this.idf$1, this.columnPruner$1, estimator, this.columnPruner2$1});
                Pipeline pipeline = HamOrSpamDemo$.MODULE$.createPipeline(stages);
                PipelineModel model = HamOrSpamDemo$.MODULE$.trainPipeline(pipeline, (Dataset<Row>)this.data$1);
                HamOrSpamDemo$.MODULE$.assertPredictions(this.spark$1, model);
            }
            {
                this.spark$1 = spark$1;
                this.data$1 = data$1;
                this.tokenizer$1 = tokenizer$1;
                this.stopWordsRemover$1 = stopWordsRemover$1;
                this.hashingTF$1 = hashingTF$1;
                this.idf$1 = idf$1;
                this.columnPruner$1 = columnPruner$1;
                this.columnPruner2$1 = columnPruner2$1;
            }
        });
    }

    /**
     * Builds a pipeline from the given stages, saves it to {@code examples/build/pipeline}
     * (overwriting any previous save) and returns the pipeline loaded back from disk.
     * This exercises pipeline persistence.
     *
     * <p>NOTE(review): this path differs from the {@code build/examples/model} path used by
     * {@link #trainPipeline}. Confirm which directory layout is intended.
     *
     * @param stages ordered pipeline stages
     * @return the pipeline as re-loaded from disk
     */
    public Pipeline createPipeline(PipelineStage[] stages) {
        Pipeline pipeline = new Pipeline().setStages(stages);
        pipeline.write().overwrite().save("examples/build/pipeline");
        return Pipeline$.MODULE$.load("examples/build/pipeline");
    }

    /**
     * Checks that the model classifies a sample ham message as not spam and a sample
     * spam message as spam. Uses Scala's {@code Predef.assert}, which throws on failure.
     *
     * @param spark active session used to build the one-row input frames
     * @param model trained pipeline model
     */
    public void assertPredictions(SparkSession spark, PipelineModel model) {
        Predef$.MODULE$.assert(!this.isSpam(spark, "Michal, h2oworld party tonight in MV?", model));
        Predef$.MODULE$.assert(this.isSpam(spark, "We tried to contact you re your reply to our offer of a Video Handset? 750 anytime any networks mins? UNLIMITED TEXT?", model));
    }

    /**
     * Classifies a single message. The text is wrapped in a one-row DataFrame with a
     * non-nullable string column {@code text} and run through the model.
     *
     * @param spark active session
     * @param smsText message text to classify
     * @param model trained pipeline model
     * @return {@code true} iff the first row's {@code prediction} value equals {@code "spam"}
     */
    public boolean isSpam(SparkSession spark, String smsText, PipelineModel model) {
        StructType smsTextSchema = new StructType((StructField[])((Object[])new StructField[]{new StructField("text", (DataType)StringType$.MODULE$, false, StructField$.MODULE$.apply$default$4())}));
        RDD smsTextRowRDD = spark.sparkContext().parallelize((Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{smsText})), spark.sparkContext().parallelize$default$2(), ClassTag$.MODULE$.apply(String.class)).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Row apply(String x$1) {
                return Row$.MODULE$.apply((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{x$1}));
            }
        }, ClassTag$.MODULE$.apply(Row.class));
        Dataset smsTextDF = spark.createDataFrame(smsTextRowRDD, smsTextSchema);
        Dataset prediction = model.transform(smsTextDF);
        String string = ((Row)prediction.select("prediction", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0])).first()).getString(0);
        String string2 = "spam";
        // Decompiled null-safe Scala equality: string == "spam".
        return !(string != null ? !string.equals(string2) : string2 != null);
    }

    /**
     * Loads the tab-separated SMS dataset. Each line is split at the first tab into
     * {@code label} and {@code text} (both non-nullable strings).
     *
     * <p>A line is kept only if it contains a tab and its label is non-empty. Without the
     * tab check, a non-empty line lacking a tab yields a one-element array, and reading
     * {@code p[1]} throws {@link ArrayIndexOutOfBoundsException} when the RDD is evaluated.
     *
     * @param spark active session
     * @param dataFile URI of the text file to read
     * @return DataFrame with columns {@code label} and {@code text}
     */
    public Dataset<Row> load(SparkSession spark, String dataFile) {
        StructType smsSchema = new StructType((StructField[])((Object[])new StructField[]{new StructField("label", (DataType)StringType$.MODULE$, false, StructField$.MODULE$.apply$default$4()), new StructField("text", (DataType)StringType$.MODULE$, false, StructField$.MODULE$.apply$default$4())}));
        RDD rowRDD = spark.sparkContext().textFile(dataFile, spark.sparkContext().textFile$default$2()).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String[] apply(String x$2) {
                // Limit 2: split only at the first tab, so the message text may itself contain tabs.
                return x$2.split("\t", 2);
            }
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(String.class))).filter((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final boolean apply(String[] r) {
                // Require both a label and a text part; otherwise p[1] below would be out of bounds.
                return r.length == 2 && !r[0].isEmpty();
            }
        }).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Row apply(String[] p) {
                return Row$.MODULE$.apply((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{p[0], p[1]}));
            }
        }, ClassTag$.MODULE$.apply(Row.class));
        return spark.createDataFrame(rowRDD, smsSchema);
    }

    /**
     * Tokenizer from {@code text} to {@code words}. With {@code gaps=false}, the pattern
     * {@code [a-zA-Z]+} matches the tokens themselves. Tokens shorter than 3 characters
     * are discarded.
     */
    public RegexTokenizer createTokenizer() {
        return ((RegexTokenizer)new RegexTokenizer().setInputCol("text").setOutputCol("words")).setMinTokenLength(3).setGaps(false).setPattern("[a-zA-Z]+");
    }

    /**
     * Case-insensitive stop-word remover from the tokenizer's output column to
     * {@code filtered}, using a small fixed stop-word list.
     *
     * @param tokenizer upstream tokenizer whose output column is consumed
     */
    public StopWordsRemover createStopWordsRemover(RegexTokenizer tokenizer) {
        return new StopWordsRemover().setInputCol(tokenizer.getOutputCol()).setOutputCol("filtered").setStopWords((String[])((Object[])new String[]{"the", "a", "", "in", "on", "at", "as", "not", "for"})).setCaseSensitive(false);
    }

    /**
     * Hashing term-frequency stage (1024 features) from the stop-word remover's output
     * column to {@code wordToIndex}.
     *
     * @param stopWordsRemover upstream stage whose output column is consumed
     */
    public HashingTF createHashingTF(StopWordsRemover stopWordsRemover) {
        return new HashingTF().setNumFeatures(1024).setInputCol(stopWordsRemover.getOutputCol()).setOutputCol("wordToIndex");
    }

    /**
     * IDF stage with {@code minDocFreq=4}, from the hashing-TF output column to {@code tf_idf}.
     *
     * @param hashingTF upstream stage whose output column is consumed
     */
    public IDF createIDF(HashingTF hashingTF) {
        return new IDF().setMinDocFreq(4).setInputCol(hashingTF.getOutputCol()).setOutputCol("tf_idf");
    }

    /**
     * Column pruner configured with the intermediate feature columns (hashing-TF,
     * stop-word and tokenizer outputs). Presumably these are dropped before the
     * estimator stage; verify against ColumnPruner's default keep/drop mode.
     */
    public ColumnPruner createColumnPruner(HashingTF hashingTF, StopWordsRemover stopWordsRemover, RegexTokenizer tokenizer) {
        return (ColumnPruner)new ColumnPruner().setColumns((String[])((Object[])new String[]{hashingTF.getOutputCol(), stopWordsRemover.getOutputCol(), tokenizer.getOutputCol()}));
    }

    /**
     * Fits the pipeline, saves the fitted model to {@code build/examples/model}
     * (overwriting any previous save) and returns the model loaded back from disk.
     *
     * @param pipeline pipeline to fit
     * @param data training data with {@code label} and {@code text} columns
     * @return the fitted model as re-loaded from disk
     */
    public PipelineModel trainPipeline(Pipeline pipeline, Dataset<Row> data) {
        PipelineModel model = pipeline.fit(data);
        model.write().overwrite().save("build/examples/model");
        return PipelineModel$.MODULE$.load("build/examples/model");
    }

    /** H2O GBM: split ratio 0.8, seed 1, features {@code tf_idf}, label {@code label}. */
    public H2OGBM gbm() {
        return (H2OGBM)((H2OGBMParams)((H2OCommonParams)((H2OGBMParams)new H2OGBM().setSplitRatio(0.8)).setSeed(1L)).setFeaturesCols("tf_idf", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0]))).setLabelCol("label");
    }

    /**
     * H2O Deep Learning: 10 epochs, L1=0.001, L2=0, seed 1, two hidden layers of 200
     * units, features {@code tf_idf}, label {@code label}.
     */
    public H2ODeepLearning deepLearning() {
        return (H2ODeepLearning)((H2ODeepLearningParams)((H2OCommonParams)new H2ODeepLearning().setEpochs(10.0).setL1(0.001).setL2(0.0).setSeed(1L).setHidden((int[])Array$.MODULE$.apply((Seq)Predef$.MODULE$.wrapIntArray(new int[]{200, 200}), ClassTag$.MODULE$.Int()))).setFeaturesCols("tf_idf", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0]))).setLabelCol("label");
    }

    /**
     * H2O AutoML: label {@code label}, seed 1, max runtime 6000 s, at most 10 models,
     * unknown categorical levels converted to NA. No feature columns are set explicitly.
     */
    public H2OAutoML autoML() {
        return (H2OAutoML)((H2OCommonParams)((H2OAutoMLStoppingCriteriaParams)new H2OAutoML().setLabelCol("label")).setSeed(1L).setMaxRuntimeSecs(6000.0).setMaxModels(10)).setConvertUnknownCategoricalLevelsToNa(true);
    }

    /**
     * Grid search over an H2O GBM (max depth 6, seed 1, features {@code tf_idf}, label
     * {@code label}), with {@code ntrees} taking the values 1 and 30.
     */
    public H2OGridSearch gridSearch() {
        // Decompiled Map("ntrees" -> Array(1, 30).map(_.asInstanceOf[AnyRef])): values boxed to Object[].
        Map hyperParams = (Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"ntrees"), Predef$.MODULE$.intArrayOps(new int[]{1, 30}).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Object apply(int x$3) {
                return BoxesRunTime.boxToInteger((int)x$3);
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.AnyRef())))}));
        H2OGBM algo = (H2OGBM)((H2OCommonParams)((H2OGBMParams)((H2OCommonParams)new H2OGBM().setMaxDepth(6).setSeed(1L)).setFeaturesCols("tf_idf", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0]))).setLabelCol("label")).setConvertUnknownCategoricalLevelsToNa(true);
        return (H2OGridSearch)new H2OGridSearch().setHyperParameters(hyperParams).setAlgo((H2OAlgorithm)algo);
    }

    /**
     * H2O XGBoost: features {@code tf_idf}, label {@code label}, unknown categorical
     * levels converted to NA.
     */
    public H2OXGBoost xgboost() {
        return (H2OXGBoost)((H2OCommonParams)((H2OXGBoostParams)new H2OXGBoost().setFeaturesCols("tf_idf", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0]))).setLabelCol("label")).setConvertUnknownCategoricalLevelsToNa(true);
    }

    /** Singleton constructor; publishes the instance via {@link #MODULE$}. */
    private HamOrSpamDemo$() {
        MODULE$ = this;
    }
}

