/*
 * Decompiled with CFR 0.152.
 */
package io.cdap.mmds.manager;

import com.google.common.base.Joiner;
import io.cdap.cdap.api.ServiceDiscoverer;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.plugin.PluginConfigurer;
import io.cdap.cdap.api.plugin.PluginContext;
import io.cdap.cdap.api.plugin.PluginProperties;
import io.cdap.cdap.api.spark.service.SparkHttpServicePluginContext;
import io.cdap.cdap.api.spark.sql.DataFrames;
import io.cdap.cdap.etl.api.PipelineConfigurer;
import io.cdap.cdap.etl.api.Transform;
import io.cdap.mmds.NullableMath;
import io.cdap.mmds.data.ColumnSplitStats;
import io.cdap.mmds.data.DataSplitInfo;
import io.cdap.mmds.manager.WranglerFunction;
import io.cdap.mmds.manager.WranglerPipelineConfigurer;
import io.cdap.mmds.splitter.DataSplitResult;
import io.cdap.mmds.splitter.DatasetSplitter;
import io.cdap.mmds.splitter.ToCatHisto;
import io.cdap.mmds.splitter.ToDoubleValues;
import io.cdap.mmds.splitter.ToNumericHisto;
import io.cdap.mmds.stats.CategoricalHisto;
import io.cdap.mmds.stats.NumericHisto;
import io.cdap.mmds.stats.NumericStats;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.twill.filesystem.Location;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/**
 * Runs the Wrangler directives of a data split over the experiment's source data, splits the
 * result into training and test sets, writes both as parquet and computes per-column statistics
 * (categorical histograms and numeric histograms) for each side of the split.
 *
 * <p>Closing this object closes the underlying plugin context.
 */
public class DataSplitStatsGenerator
implements AutoCloseable {
    private static final Logger LOG = LoggerFactory.getLogger(DataSplitStatsGenerator.class);
    private final SparkSession sparkSession;
    private final DatasetSplitter splitter;
    private final SparkHttpServicePluginContext pluginContext;
    private final ServiceDiscoverer serviceDiscoverer;
    private final PipelineConfigurer pipelineConfigurer;

    /**
     * @param sparkSession session used to read the source data and build the data frames
     * @param splitter strategy that divides the wrangled data into training and test sets
     * @param pluginContext context used to instantiate the Wrangler plugin; closed by {@link #close()}
     * @param serviceDiscoverer handed to the {@link WranglerFunction} that is run over the data
     */
    public DataSplitStatsGenerator(SparkSession sparkSession, DatasetSplitter splitter, SparkHttpServicePluginContext pluginContext, ServiceDiscoverer serviceDiscoverer) {
        this.sparkSession = sparkSession;
        this.splitter = splitter;
        this.pluginContext = pluginContext;
        this.serviceDiscoverer = serviceDiscoverer;
        this.pipelineConfigurer = new WranglerPipelineConfigurer((PluginConfigurer)pluginContext);
    }

    /**
     * Wrangles the experiment's source data, splits it, persists both splits as parquet under the
     * split location ({@code train} and {@code test} sub-paths) and gathers column statistics.
     *
     * @param dataSplitInfo the split definition (schema, directives, split parameters), the
     *                      experiment providing the source path, and the output location
     * @return the paths of the written training and test data together with the column statistics
     * @throws IOException if the training or test location cannot be derived from the split location
     * @throws IllegalStateException if the Wrangler plugin cannot be found
     */
    public DataSplitResult split(DataSplitInfo dataSplitInfo) throws IOException {
        Schema schema = dataSplitInfo.getDataSplit().getSchema();
        PluginProperties wranglerProperties = PluginProperties.builder()
            .add("schema", schema.toString())
            .add("field", "*")
            .add("directives", Joiner.on("\n").join((Iterable<?>)dataSplitInfo.getDataSplit().getDirectives()))
            .add("threshold", "-1")
            .add("precondition", "false")
            .build();
        Transform<?, ?> wrangler = (Transform<?, ?>)this.pluginContext.usePlugin("transform", "Wrangler", "wrangler", wranglerProperties);
        if (wrangler == null) {
            throw new IllegalStateException("Could not find wrangler plugin. Please make sure it has been deployed with MMDS as a parent.");
        }
        wrangler.configurePipeline(this.pipelineConfigurer);

        JavaRDD<Row> rowRDD = this.sparkSession.read().format("text")
            .load(dataSplitInfo.getExperiment().getSrcpath())
            .javaRDD()
            .flatMap(new WranglerFunction(schema, (PluginContext)this.pluginContext, this.serviceDiscoverer));
        StructType rowType = (StructType)DataFrames.toDataType(schema);

        // Everything below caches data on the executors; the finally block releases it again so that
        // repeated calls do not accumulate cached datasets.
        Dataset<Row> rawData = this.sparkSession.createDataFrame(rowRDD, rowType).cache();
        Dataset<Row> trainingSplit = null;
        Dataset<Row> testSplit = null;
        try {
            long start = DataSplitStatsGenerator.nowSeconds();
            Dataset<Row>[] split = this.splitter.split(rawData, dataSplitInfo.getDataSplit().getParams());
            trainingSplit = split[0].cache();
            testSplit = split[1].cache();

            Location splitLocation = dataSplitInfo.getSplitLocation();
            String trainingPath = splitLocation.append("train").toURI().getPath();
            String testPath = splitLocation.append("test").toURI().getPath();
            trainingSplit.write().mode(SaveMode.Overwrite).format("parquet").save(trainingPath);
            testSplit.write().mode(SaveMode.Overwrite).format("parquet").save(testPath);
            long splitEnd = DataSplitStatsGenerator.nowSeconds();
            LOG.info("Time to split = {} seconds", splitEnd - start);

            List<ColumnSplitStats> stats = this.getStats(trainingSplit, testSplit, schema);
            long statsEnd = DataSplitStatsGenerator.nowSeconds();
            LOG.info("Time to get stats = {} seconds", statsEnd - splitEnd);
            return new DataSplitResult(trainingPath, testPath, stats);
        } finally {
            // All Spark actions above are eager (save / collectAsMap), so nothing returned to the
            // caller still depends on the cached data.
            DataSplitStatsGenerator.release(testSplit);
            DataSplitStatsGenerator.release(trainingSplit);
            DataSplitStatsGenerator.release(rawData);
        }
    }

    /**
     * Computes statistics for every BOOLEAN/STRING (categorical) and INT/LONG/FLOAT/DOUBLE
     * (numeric) field of the schema. Fields of any other type are not profiled.
     *
     * @param train the training split
     * @param test the test split
     * @param schema schema describing the columns of both splits
     * @return categorical column stats followed by numeric column stats
     */
    private List<ColumnSplitStats> getStats(Dataset<Row> train, Dataset<Row> test, Schema schema) {
        List<ColumnSplitStats> stats = new ArrayList<>(schema.getFields().size());
        List<Column> categoricalColumns = new ArrayList<>();
        List<String> categoricalNames = new ArrayList<>();
        List<Column> numericColumns = new ArrayList<>();
        List<String> numericNames = new ArrayList<>();
        for (Schema.Field field : schema.getFields()) {
            String fieldName = field.getName();
            Schema fieldSchema = field.getSchema();
            // Nullable fields are classified by their underlying, non-null type.
            fieldSchema = fieldSchema.isNullable() ? fieldSchema.getNonNullable() : fieldSchema;
            Column col = new Column(fieldName);
            switch (fieldSchema.getType()) {
                case BOOLEAN: {
                    categoricalColumns.add(col.cast(DataTypes.StringType));
                    categoricalNames.add(fieldName);
                    break;
                }
                case STRING: {
                    categoricalColumns.add(col);
                    categoricalNames.add(fieldName);
                    break;
                }
                case INT:
                case LONG:
                case FLOAT: {
                    numericColumns.add(col.cast(DataTypes.DoubleType));
                    numericNames.add(fieldName);
                    break;
                }
                case DOUBLE: {
                    numericColumns.add(col);
                    numericNames.add(fieldName);
                    break;
                }
                default: {
                    // Other types contribute neither categorical nor numeric statistics.
                    break;
                }
            }
        }

        // Without columns of a kind there is nothing to aggregate, so the full-data Spark jobs
        // (and a shuffle with zero partitions) are skipped for that kind.
        if (!categoricalNames.isEmpty()) {
            DataSplitStatsGenerator.addCategoricalStats(train, test, categoricalColumns, categoricalNames, stats);
        }
        if (!numericNames.isEmpty()) {
            DataSplitStatsGenerator.addNumericStats(train, test, numericColumns, numericNames, stats);
        }
        return stats;
    }

    /** Builds train/test categorical histograms and appends one entry per training column to {@code stats}. */
    private static void addCategoricalStats(Dataset<Row> train, Dataset<Row> test, List<Column> columns, List<String> names, List<ColumnSplitStats> stats) {
        Column[] selection = columns.toArray(new Column[0]);
        long start = DataSplitStatsGenerator.nowSeconds();
        Map<String, CategoricalHisto> trainHistograms = DataSplitStatsGenerator.categoricalHistograms(train.select(selection), names);
        Map<String, CategoricalHisto> testHistograms = DataSplitStatsGenerator.categoricalHistograms(test.select(selection), names);
        long end = DataSplitStatsGenerator.nowSeconds();
        LOG.info("Time to get categorical stats = {} seconds", end - start);
        for (Map.Entry<String, CategoricalHisto> entry : trainHistograms.entrySet()) {
            String columnName = entry.getKey();
            // NOTE(review): the test histogram is null when the test split produced no entry for
            // this column; this is passed through unchanged - confirm ColumnSplitStats accepts it.
            stats.add(new ColumnSplitStats(columnName, entry.getValue(), testHistograms.get(columnName)));
        }
    }

    /** Merges the per-row categorical histograms of one split into a single histogram per column. */
    private static Map<String, CategoricalHisto> categoricalHistograms(Dataset<Row> split, List<String> names) {
        return split.javaRDD()
            .flatMapToPair(new ToCatHisto(names))
            .reduceByKey(CategoricalHisto::merge, names.size())
            .collectAsMap();
    }

    /**
     * Builds train/test numeric histograms in two passes and appends one entry per training column
     * to {@code stats}. The first pass finds each column's overall min and max across both splits;
     * the second pass builds the histograms using those bounds.
     */
    private static void addNumericStats(Dataset<Row> train, Dataset<Row> test, List<Column> columns, List<String> names, List<ColumnSplitStats> stats) {
        Column[] selection = columns.toArray(new Column[0]);
        int numNumeric = names.size();
        long start = DataSplitStatsGenerator.nowSeconds();
        JavaPairRDD<String, Double> trainValues = train.select(selection).javaRDD().flatMapToPair(new ToDoubleValues(names));
        JavaPairRDD<String, Double> testValues = test.select(selection).javaRDD().flatMapToPair(new ToDoubleValues(names));
        Map<String, NumericStats> trainStats = trainValues.mapValues(NumericStats::new).reduceByKey(NumericStats::merge, numNumeric).collectAsMap();
        Map<String, NumericStats> testStats = testValues.mapValues(NumericStats::new).reduceByKey(NumericStats::merge, numNumeric).collectAsMap();
        Map<String, Tuple2<Double, Double>> columnMinMax = DataSplitStatsGenerator.combineMinMax(trainStats, testStats);
        long firstPassEnd = DataSplitStatsGenerator.nowSeconds();
        LOG.info("Time to get numeric stats, 1st pass = {} seconds", firstPassEnd - start);

        Map<String, NumericHisto> trainHistos = trainValues.mapToPair(new ToNumericHisto(columnMinMax)).reduceByKey(NumericHisto::merge, numNumeric).collectAsMap();
        Map<String, NumericHisto> testHistos = testValues.mapToPair(new ToNumericHisto(columnMinMax)).reduceByKey(NumericHisto::merge, numNumeric).collectAsMap();
        long secondPassEnd = DataSplitStatsGenerator.nowSeconds();
        LOG.info("Time to get numeric stats, 2nd pass = {} seconds", secondPassEnd - firstPassEnd);
        for (Map.Entry<String, NumericHisto> entry : trainHistos.entrySet()) {
            String columnName = entry.getKey();
            // NOTE(review): as for categorical columns, the test histogram may be null here.
            stats.add(new ColumnSplitStats(columnName, entry.getValue(), testHistos.get(columnName)));
        }
    }

    /**
     * Combines the per-split numeric stats into one (min, max) pair per column. A column that only
     * has stats on one side (e.g. the other split is empty) uses that side's bounds alone instead
     * of failing on the missing entry.
     */
    private static Map<String, Tuple2<Double, Double>> combineMinMax(Map<String, NumericStats> trainStats, Map<String, NumericStats> testStats) {
        Map<String, Tuple2<Double, Double>> columnMinMax = new HashMap<>();
        for (Map.Entry<String, NumericStats> entry : trainStats.entrySet()) {
            NumericStats trainColumn = entry.getValue();
            NumericStats testColumn = testStats.get(entry.getKey());
            Double min = trainColumn.getMin();
            Double max = trainColumn.getMax();
            if (testColumn != null) {
                min = NullableMath.min(min, testColumn.getMin());
                max = NullableMath.max(max, testColumn.getMax());
            }
            columnMinMax.put(entry.getKey(), new Tuple2<>(min, max));
        }
        for (Map.Entry<String, NumericStats> entry : testStats.entrySet()) {
            if (!columnMinMax.containsKey(entry.getKey())) {
                NumericStats testColumn = entry.getValue();
                columnMinMax.put(entry.getKey(), new Tuple2<>(testColumn.getMin(), testColumn.getMax()));
            }
        }
        return columnMinMax;
    }

    /** Current wall-clock time truncated to whole seconds; used only for the timing log lines. */
    private static long nowSeconds() {
        return TimeUnit.SECONDS.convert(System.currentTimeMillis(), TimeUnit.MILLISECONDS);
    }

    /** Unpersists a cached dataset, never letting a cleanup failure mask the caller's own exception. */
    private static void release(Dataset<Row> dataset) {
        if (dataset == null) {
            return;
        }
        try {
            dataset.unpersist();
        } catch (RuntimeException e) {
            LOG.warn("Failed to unpersist cached dataset", e);
        }
    }

    /**
     * Closes the plugin context supplied at construction.
     *
     * @throws Exception if closing the plugin context fails
     */
    @Override
    public void close() throws Exception {
        this.pluginContext.close();
    }
}

