/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.sparkling.examples;

import ai.h2o.sparkling.ml.algos.H2ODeepLearning;
import ai.h2o.sparkling.ml.algos.H2OGBM;
import ai.h2o.sparkling.ml.models.H2OSupervisedMOJOModel;
import ai.h2o.sparkling.ml.models.H2OTreeBasedSupervisedMOJOModel;
import ai.h2o.sparkling.ml.params.H2OAlgoSharedTreeParams;
import ai.h2o.sparkling.ml.params.H2ODeepLearningParams;
import java.io.File;
import org.apache.spark.h2o.H2OContext$;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession$;
import org.apache.spark.sql.functions$;
import scala.Predef$;
import scala.StringContext;
import scala.Symbol;
import scala.Symbol$;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.runtime.BoxesRunTime;

public final class AirlinesWithWeatherDemo$ {
    public static final AirlinesWithWeatherDemo$ MODULE$;
    private static Symbol symbol$1;
    private static Symbol symbol$2;

    static {
        symbol$1 = Symbol$.MODULE$.apply("Date");
        symbol$2 = Symbol$.MODULE$.apply("Dest");
        new AirlinesWithWeatherDemo$();
    }

    public void main(String[] args) {
        SparkSession spark = SparkSession$.MODULE$.builder().appName("Join of Airlines with Weather Data").getOrCreate();
        String weatherDataPath = "./examples/smalldata/chicago/Chicago_Ohare_International_Airport.csv";
        String weatherDataFile = new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"file://", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{new File(weatherDataPath).getAbsolutePath()}));
        Dataset weatherTable = spark.read().option("header", "true").option("inferSchema", "true").csv(weatherDataFile).withColumn("Date", functions$.MODULE$.to_date(functions$.MODULE$.regexp_replace((Column)spark.implicits().symbolToColumn(symbol$1), "(\\d+)/(\\d+)/(\\d+)", "$3-$2-$1"))).withColumn("Year", functions$.MODULE$.year((Column)spark.implicits().symbolToColumn(symbol$1))).withColumn("Month", functions$.MODULE$.month((Column)spark.implicits().symbolToColumn(symbol$1))).withColumn("DayofMonth", functions$.MODULE$.dayofmonth((Column)spark.implicits().symbolToColumn(symbol$1)));
        String airlinesDataPath = "./examples/smalldata/airlines/allyears2k_headers.csv";
        String airlinesDataFile = new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"file://", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{new File(airlinesDataPath).getAbsolutePath()}));
        Dataset airlinesTable = spark.read().option("header", "true").option("inferSchema", "true").option("nullValue", "NA").csv(airlinesDataFile);
        Dataset flightsToORD = airlinesTable.filter(spark.implicits().symbolToColumn(symbol$2).$eq$eq$eq((Object)"ORD"));
        Predef$.MODULE$.println((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"\\nFlights to ORD: ", "\\n"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToLong((long)flightsToORD.count())})));
        Dataset joined = flightsToORD.join(weatherTable, (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Year", "Month", "DayofMonth"})));
        H2OContext$.MODULE$.getOrCreate();
        H2ODeepLearning dl = (H2ODeepLearning)((H2ODeepLearningParams)new H2ODeepLearning().setLabelCol("ArrDelay").setSplitRatio(0.8)).setEpochs(5.0).setHidden(new int[]{100, 100}).setActivation("RectifierWithDropout");
        H2OSupervisedMOJOModel deepLearningModel = dl.fit(joined);
        Row[] predictionsFromDL = (Row[])deepLearningModel.transform(joined).select("prediction", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0])).collect();
        Predef$.MODULE$.println((Object)Predef$.MODULE$.refArrayOps((Object[])predictionsFromDL).mkString("\n===> Model predictions from DL: ", ", ", ", ...\n"));
        H2OGBM gbm = (H2OGBM)((H2OAlgoSharedTreeParams)new H2OGBM().setLabelCol("ArrDelay").setSplitRatio(0.8)).setNtrees(100);
        H2OTreeBasedSupervisedMOJOModel gbmModel = gbm.fit(joined);
        Row[] predictionsFromGBM = (Row[])gbmModel.transform(joined).select("prediction", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[0])).collect();
        Predef$.MODULE$.println((Object)Predef$.MODULE$.refArrayOps((Object[])predictionsFromGBM).mkString("\n===> Model predictions from GBM: ", ", ", ", ...\n"));
    }

    private AirlinesWithWeatherDemo$() {
        MODULE$ = this;
    }
}

