/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.h2o;

import hex.Model;
import hex.ModelMetrics;
import hex.ModelMetricsSupervised;
import hex.splitframe.ShuffleSplitFrame;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.io.File;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.examples.h2o.DemoUtils$;
import org.apache.spark.examples.h2o.TimeSplit;
import org.apache.spark.examples.h2o.TimeTransform;
import org.apache.spark.h2o.H2OContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SchemaRDD;
import org.apache.spark.sql.catalyst.expressions.Row;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Symbol;
import scala.Symbol$;
import scala.collection.Seq;
import scala.collection.immutable.StringOps;
import scala.math.Ordering;
import scala.package$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import water.Key;
import water.Keyed;
import water.fvec.DataFrame;
import water.fvec.Frame;

public final class CitiBikeSharingDemo$ {
    public static final CitiBikeSharingDemo$ MODULE$;

    static {
        new CitiBikeSharingDemo$();
    }

    public void main(String[] args) {
        SparkConf conf = DemoUtils$.MODULE$.configure("Sparkling Water Meetup: Predict occupation of citi bike station in NYC");
        SparkContext sc = new SparkContext(conf);
        H2OContext h2oContext = new H2OContext(sc).start();
        SQLContext sqlContext = new SQLContext(sc);
        DataFrame dataf = new DataFrame(new File("/Users/michal/Devel/projects/h2o/repos/h2o2/bigdata/laptop/citibike-nyc/2013-09.csv"));
        DataFrame startTimeF = dataf.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Symbol[]{Symbol$.MODULE$.apply("starttime")}));
        dataf.add((Frame)new TimeSplit().doIt(startTimeF));
        Predef$.MODULE$.println((Object)dataf);
        SchemaRDD brdd = h2oContext.asSchemaRDD(dataf, sqlContext);
        sqlContext.registerRDDAsTable(brdd, "brdd");
        SchemaRDD bph = sqlContext.sql(new StringOps(Predef$.MODULE$.augmentString("SELECT Days, start_station_id, count(*) bikes\n        |FROM brdd\n        |GROUP BY Days, start_station_id ")).stripMargin());
        Predef$.MODULE$.println((Object)Predef$.MODULE$.refArrayOps((Object[])bph.take(10)).mkString("\n"));
        DataFrame bphf = h2oContext.createDataFrame(bph);
        DataFrame daysVec = bphf.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Symbol[]{Symbol$.MODULE$.apply("Days")}));
        Frame finalTable = bphf.add((Frame)new TimeTransform().doIt(daysVec));
        Predef$.MODULE$.println((Object)finalTable);
        Key[] keys = (Key[])Predef$.MODULE$.refArrayOps((Object[])new String[]{"train.hex", "test.hex", "hold.hex"}).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Key<? extends Keyed<? extends Keyed<?>>> apply(String x$1) {
                return Key.make((String)x$1);
            }
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Key.class)));
        double[] ratios = (double[])Array$.MODULE$.apply((Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{0.6, 0.3, 0.1}), ClassTag$.MODULE$.Double());
        Frame[] frs = ShuffleSplitFrame.shuffleSplitFrame((Frame)finalTable, (Key[])keys, (double[])ratios, (long)1234567689L);
        Frame train = frs[0];
        Frame test = frs[1];
        Frame hold = frs[2];
        dataf.delete();
        GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
        gbmParams._train = h2oContext.dataFrameToKey(train);
        gbmParams._valid = h2oContext.dataFrameToKey(test);
        gbmParams._response_column = h2oContext.symbolToString(Symbol$.MODULE$.apply("bikes"));
        gbmParams._ntrees = 500;
        gbmParams._max_depth = 6;
        GBM gbm = new GBM(gbmParams);
        GBMModel gbmModel = (GBMModel)gbm.trainModel().get();
        gbmModel.score(train).remove();
        gbmModel.score(test).remove();
        gbmModel.score(hold).remove();
        Predef$.MODULE$.println((Object)new StringOps(Predef$.MODULE$.augmentString(new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"\n         |r2 on train: ", "\n         |r2 on test:  ", "\n         |r2 on hold:  ", "\""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToDouble((double)this.r2(gbmModel, train)), BoxesRunTime.boxToDouble((double)this.r2(gbmModel, test)), BoxesRunTime.boxToDouble((double)this.r2(gbmModel, hold))})))).stripMargin());
        sc.stop();
    }

    public double r2(GBMModel model, Frame fr) {
        return ((ModelMetricsSupervised)ModelMetrics.getFromDKV((Model)model, (Frame)fr)).r2();
    }

    public void basicStats(SchemaRDD brdd, SQLContext sqlContext) {
        brdd.first();
        brdd.count();
        sqlContext.registerRDDAsTable(brdd, "brdd");
        SchemaRDD tGBduration = sqlContext.sql("select bikeid, sum(tripduration) from brdd group by bikeid");
        Row[] bottom10 = (Row[])tGBduration.sortBy((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final long apply(Row r) {
                return r.getLong(1);
            }
        }, tGBduration.sortBy$default$2(), tGBduration.sortBy$default$3(), (Ordering)Ordering.Long$.MODULE$, ClassTag$.MODULE$.Long()).take(10);
        Row minDurationBikeId = (Row)tGBduration.min(package$.MODULE$.Ordering().by((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final long apply(Row r) {
                return r.getLong(1);
            }
        }, (Ordering)Ordering.Long$.MODULE$));
        Row row = bottom10[0];
        Row row2 = minDurationBikeId;
        Predef$.MODULE$.assert(!(row != null ? !row.equals(row2) : row2 != null));
        Row maxDurationBikeId = (Row)tGBduration.min(package$.MODULE$.Ordering().by((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final long apply(Row r) {
                return -r.getLong(1);
            }
        }, (Ordering)Ordering.Long$.MODULE$));
    }

    private CitiBikeSharingDemo$() {
        MODULE$ = this;
    }
}

