/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.spark.transform;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.DoubleFunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.datavec.api.transform.analysis.DataAnalysis;
import org.datavec.api.transform.analysis.DataVecAnalysisUtils;
import org.datavec.api.transform.analysis.SequenceDataAnalysis;
import org.datavec.api.transform.analysis.quality.QualityAnalysisAddFunction;
import org.datavec.api.transform.analysis.quality.QualityAnalysisCombineFunction;
import org.datavec.api.transform.analysis.quality.QualityAnalysisState;
import org.datavec.api.transform.analysis.sequence.SequenceLengthAnalysis;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.quality.DataQualityAnalysis;
import org.datavec.api.transform.quality.columns.ColumnQuality;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.WritableType;
import org.datavec.api.writable.comparator.Comparators;
import org.datavec.spark.transform.analysis.SelectColumnFunction;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.analysis.SequenceLengthFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisAddFunction;
import org.datavec.spark.transform.analysis.aggregate.AnalysisCombineFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramAddFunction;
import org.datavec.spark.transform.analysis.histogram.HistogramCombineFunction;
import org.datavec.spark.transform.analysis.seqlength.IntToDoubleFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisAddFunction;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisCounter;
import org.datavec.spark.transform.analysis.seqlength.SequenceLengthAnalysisMergeFunction;
import org.datavec.spark.transform.analysis.unique.UniqueAddFunction;
import org.datavec.spark.transform.analysis.unique.UniqueMergeFunction;
import org.datavec.spark.transform.filter.FilterWritablesBySchemaFunction;
import org.datavec.spark.transform.misc.ColumnToKeyPairTransform;
import org.datavec.spark.transform.misc.SumLongsFunction2;
import org.datavec.spark.transform.misc.comparator.Tuple2Comparator;
import org.datavec.spark.transform.utils.adapter.BiFunctionAdapter;
import scala.Tuple2;

/**
 * Spark-based analysis utilities for DataVec data: column statistics, data-quality checks, unique
 * values, sampling, most-frequent values and min/max queries.
 *
 * <p>Non-sequence data is a {@code JavaRDD<List<Writable>>} (one record per element); sequence
 * data is a {@code JavaRDD<List<List<Writable>>>} (one sequence of time steps per element).
 * Sequence overloads flatten the sequences into time steps and delegate to the non-sequence
 * version.
 *
 * <p>Raw Spark function types are used where the generic parameters of the project function
 * classes are not visible from this file; results are re-typed at the boundary.
 */
public class AnalyzeSpark {
    /** Histogram bucket count used by overloads that do not take {@code maxHistogramBuckets}. */
    public static final int DEFAULT_HISTOGRAM_BUCKETS = 30;

    /**
     * Analyzes sequence data using {@link #DEFAULT_HISTOGRAM_BUCKETS} histogram buckets.
     *
     * @see #analyzeSequence(Schema, JavaRDD, int)
     */
    public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        return analyzeSequence(schema, data, DEFAULT_HISTOGRAM_BUCKETS);
    }

    /**
     * Analyzes sequence data. It computes per-column statistics over all time steps, plus
     * sequence-length statistics and a histogram of sequence lengths.
     *
     * <p>{@code data} is cached and is left cached on return. The flattened per-time-step RDD is
     * an internal intermediate; it is unpersisted as soon as the column analysis completes.
     *
     * @param schema              schema of a single time step
     * @param data                sequences to analyze
     * @param maxHistogramBuckets upper bound on histogram buckets; applied to the column
     *                            histograms and to the sequence-length histogram
     * @return the column analysis combined with the sequence-length analysis
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static SequenceDataAnalysis analyzeSequence(Schema schema, JavaRDD<List<List<Writable>>> data, int maxHistogramBuckets) {
        data.cache();

        JavaRDD<List<Writable>> steps = flattenSequences(data);
        // analyze() caches its input; steps is private to this method, so release it here
        DataAnalysis da = analyze(schema, steps, maxHistogramBuckets);
        steps.unpersist();

        JavaRDD<Integer> seqLengths = data.map((Function) new SequenceLengthFunction());
        seqLengths.cache();
        SequenceLengthAnalysisCounter counter = (SequenceLengthAnalysisCounter) seqLengths.aggregate(
                new SequenceLengthAnalysisCounter(),
                (Function2) new SequenceLengthAnalysisAddFunction(),
                (Function2) new SequenceLengthAnalysisMergeFunction());

        int minLength = counter.getMinLengthSeen();
        int maxLength = counter.getMaxLengthSeen();
        Tuple2<double[], long[]> hist;
        if (minLength == maxLength) {
            // Zero-width range: build the single-bucket histogram directly rather than via Spark
            hist = new Tuple2<double[], long[]>(new double[] {minLength}, new long[] {counter.getCountTotal()});
        } else {
            // One bucket per unit of length for narrow ranges; otherwise cap at maxHistogramBuckets
            int nBuckets = Math.min(maxLength - minLength, maxHistogramBuckets);
            hist = seqLengths.mapToDouble((DoubleFunction) new IntToDoubleFunction()).histogram(nBuckets);
        }
        seqLengths.unpersist();

        SequenceLengthAnalysis lengthAnalysis = SequenceLengthAnalysis.builder()
                .totalNumSequences(counter.getCountTotal())
                .minSeqLength(minLength)
                .maxSeqLength(maxLength)
                .countZeroLength(counter.getCountZeroLength())
                .countOneLength(counter.getCountOneLength())
                .meanLength(counter.getMean())
                .histogramBuckets(hist._1())
                .histogramBucketCounts(hist._2())
                .build();
        return new SequenceDataAnalysis(schema, da.getColumnAnalysis(), lengthAnalysis);
    }

    /**
     * Analyzes non-sequence data using {@link #DEFAULT_HISTOGRAM_BUCKETS} histogram buckets.
     *
     * @see #analyze(Schema, JavaRDD, int)
     */
    public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data) {
        return analyze(schema, data, DEFAULT_HISTOGRAM_BUCKETS);
    }

    /**
     * Computes per-column statistics and histograms. It makes two aggregation passes over
     * {@code data}, which is cached and left cached on return.
     *
     * <p>The first pass accumulates per-column counters. Those counters are converted into column
     * analyses and also fill {@code minsMaxes}. The second pass uses those bounds to build
     * histograms, which are then merged into the column analyses.
     *
     * @param schema              schema of each record
     * @param data                records to analyze
     * @param maxHistogramBuckets number of histogram buckets passed to {@code HistogramAddFunction}
     * @return per-column analysis for {@code schema}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data, int maxHistogramBuckets) {
        data.cache();
        List columnTypes = schema.getColumnTypes();
        // Both aggregations start from a null zero value; the add/combine functions build the state
        List counters = (List) data.aggregate(null,
                (Function2) new AnalysisAddFunction(schema),
                (Function2) new AnalysisCombineFunction());
        // One [min, max] pair per counter, filled in by convertCounters and read by the histogram pass
        double[][] minsMaxes = new double[counters.size()][2];
        List columnAnalyses = DataVecAnalysisUtils.convertCounters(counters, minsMaxes, columnTypes);
        List histogramCounters = (List) data.aggregate(null,
                (Function2) new HistogramAddFunction(maxHistogramBuckets, schema, minsMaxes),
                (Function2) new HistogramCombineFunction());
        DataVecAnalysisUtils.mergeCounters(columnAnalyses, histogramCounters);
        return new DataAnalysis(schema, columnAnalyses);
    }

    /**
     * Randomly samples up to {@code count} values (without replacement) from one column.
     *
     * @param count      number of values to sample
     * @param columnName column to sample from
     * @param schema     schema used to resolve {@code columnName}
     * @param data       records to sample
     * @return the sampled values, collected to the driver
     */
    public static List<Writable> sampleFromColumn(int count, String columnName, Schema schema, JavaRDD<List<Writable>> data) {
        return selectColumn(data, columnName, schema).takeSample(false, count);
    }

    /**
     * Randomly samples up to {@code count} values (without replacement) from one column, taken
     * across all time steps of all sequences.
     */
    public static List<Writable> sampleFromColumnSequence(int count, String columnName, Schema schema, JavaRDD<List<List<Writable>>> sequenceData) {
        return sampleFromColumn(count, columnName, schema, flattenSequences(sequenceData));
    }

    /**
     * Returns the distinct values of one column, collected to the driver.
     *
     * @param columnName column to inspect
     * @param schema     schema used to resolve {@code columnName}
     * @param data       records to scan
     * @return distinct values in unspecified order
     */
    public static List<Writable> getUnique(String columnName, Schema schema, JavaRDD<List<Writable>> data) {
        return selectColumn(data, columnName, schema).distinct().collect();
    }

    /**
     * Returns the distinct values of several columns in a single aggregation pass.
     *
     * <p>If the aggregation yields no state (presumably because {@code data} is empty), each
     * requested column maps to an empty list instead of failing.
     *
     * @param columnNames columns to inspect
     * @param schema      schema used to resolve the column names
     * @param data        records to scan
     * @return map from column name to its distinct values
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Map<String, List<Writable>> getUnique(List<String> columnNames, Schema schema, JavaRDD<List<Writable>> data) {
        Map<String, ? extends Collection<Writable>> uniqueByColumn =
                (Map<String, ? extends Collection<Writable>>) data.aggregate(null,
                        (Function2) new UniqueAddFunction(columnNames, schema),
                        (Function2) new UniqueMergeFunction());
        Map<String, List<Writable>> out = new HashMap<String, List<Writable>>();
        if (uniqueByColumn == null) {
            // aggregate() started from a null zero value and nothing replaced it
            for (String name : columnNames) {
                out.put(name, new ArrayList<Writable>());
            }
            return out;
        }
        for (Map.Entry<String, ? extends Collection<Writable>> entry : uniqueByColumn.entrySet()) {
            out.put(entry.getKey(), new ArrayList<Writable>(entry.getValue()));
        }
        return out;
    }

    /** Returns the distinct values of one column across all time steps of all sequences. */
    public static List<Writable> getUniqueSequence(String columnName, Schema schema, JavaRDD<List<List<Writable>>> sequenceData) {
        return getUnique(columnName, schema, flattenSequences(sequenceData));
    }

    /** Returns the distinct values of several columns across all time steps of all sequences. */
    public static Map<String, List<Writable>> getUniqueSequence(List<String> columnNames, Schema schema, JavaRDD<List<List<Writable>>> sequenceData) {
        return getUnique(columnNames, schema, flattenSequences(sequenceData));
    }

    /** Randomly samples up to {@code count} records (without replacement). */
    public static List<List<Writable>> sample(int count, JavaRDD<List<Writable>> data) {
        return data.takeSample(false, count);
    }

    /** Randomly samples up to {@code count} whole sequences (without replacement). */
    public static List<List<List<Writable>>> sampleSequence(int count, JavaRDD<List<List<Writable>>> data) {
        return data.takeSample(false, count);
    }

    /** Runs {@link #analyzeQuality(Schema, JavaRDD)} over all time steps of all sequences. */
    public static DataQualityAnalysis analyzeQualitySequence(Schema schema, JavaRDD<List<List<Writable>>> data) {
        return analyzeQuality(schema, flattenSequences(data));
    }

    /**
     * Computes per-column data-quality information in a single aggregation pass.
     *
     * @param schema schema each record is checked against
     * @param data   records to check
     * @return one {@link ColumnQuality} per aggregated column state, in the order returned by the
     *         aggregation
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static DataQualityAnalysis analyzeQuality(Schema schema, JavaRDD<List<Writable>> data) {
        int nColumns = schema.numColumns();
        List<QualityAnalysisState> states = (List<QualityAnalysisState>) data.aggregate(null,
                new BiFunctionAdapter(new QualityAnalysisAddFunction(schema)),
                new BiFunctionAdapter(new QualityAnalysisCombineFunction()));
        List<ColumnQuality> qualities = new ArrayList<ColumnQuality>(nColumns);
        for (QualityAnalysisState state : states) {
            qualities.add(state.getColumnQuality());
        }
        return new DataQualityAnalysis(schema, qualities);
    }

    /**
     * Samples values of one column that fail validation against the column's metadata. Missing
     * values are not ignored.
     *
     * @see #sampleInvalidFromColumn(int, String, Schema, JavaRDD, boolean)
     */
    public static List<Writable> sampleInvalidFromColumn(int numToSample, String columnName, Schema schema, JavaRDD<List<Writable>> data) {
        return sampleInvalidFromColumn(numToSample, columnName, schema, data, false);
    }

    /**
     * Randomly samples up to {@code numToSample} values (without replacement) of one column that
     * are selected by {@code FilterWritablesBySchemaFunction} configured for invalid values.
     *
     * @param numToSample   number of invalid values to sample
     * @param columnName    column to inspect
     * @param schema        schema providing the column's index and metadata
     * @param data          records to scan
     * @param ignoreMissing forwarded to {@code FilterWritablesBySchemaFunction}
     * @return the sampled invalid values
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static List<Writable> sampleInvalidFromColumn(int numToSample, String columnName, Schema schema, JavaRDD<List<Writable>> data, boolean ignoreMissing) {
        JavaRDD<Writable> column = selectColumn(data, columnName, schema);
        ColumnMetaData meta = schema.getMetaData(columnName);
        JavaRDD<Writable> invalid = column.filter((Function) new FilterWritablesBySchemaFunction(meta, false, ignoreMissing));
        return invalid.takeSample(false, numToSample);
    }

    /**
     * Samples invalid values of one column across all time steps of all sequences. Missing values
     * are not ignored.
     */
    public static List<Writable> sampleInvalidFromColumnSequence(int numToSample, String columnName, Schema schema, JavaRDD<List<List<Writable>>> data) {
        return sampleInvalidFromColumn(numToSample, columnName, schema, flattenSequences(data));
    }

    /**
     * Counts occurrences of each value in one column and returns the top {@code nMostFrequent}.
     *
     * <p>Entries are in the order produced by {@code takeOrdered} with
     * {@code Tuple2Comparator(false)}. NOTE(review): presumably this means most frequent first;
     * confirm against {@code Tuple2Comparator}.
     *
     * @param nMostFrequent maximum number of entries to return
     * @param columnName    column to count
     * @param schema        schema used to resolve {@code columnName}
     * @param data          records to scan
     * @return insertion-ordered map from value to its occurrence count
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Map<Writable, Long> sampleMostFrequentFromColumn(int nMostFrequent, String columnName, Schema schema, JavaRDD<List<Writable>> data) {
        int columnIdx = schema.getIndexOfColumn(columnName);
        JavaPairRDD countsByValue = data
                .mapToPair((PairFunction) new ColumnToKeyPairTransform(columnIdx))
                .reduceByKey((Function2) new SumLongsFunction2());
        // takeOrdered already returns elements in comparator order, so no re-sort is needed
        List<Tuple2<Writable, Long>> top = countsByValue.takeOrdered(nMostFrequent, new Tuple2Comparator(false));
        Map<Writable, Long> result = new LinkedHashMap<Writable, Long>();
        for (Tuple2<Writable, Long> pair : top) {
            result.put(pair._1(), pair._2());
        }
        return result;
    }

    /**
     * Returns the minimum value of a column, compared with the comparator for the column's
     * writable type.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Writable min(JavaRDD<List<Writable>> allData, String columnName, Schema schema) {
        // Raw receiver keeps this independent of Comparators.forType's exact generic signature
        JavaRDD column = selectColumn(allData, columnName, schema);
        return (Writable) column.min(Comparators.forType((WritableType) schema.getType(columnName).getWritableType()));
    }

    /**
     * Returns the maximum value of a column, compared with the comparator for the column's
     * writable type.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public static Writable max(JavaRDD<List<Writable>> allData, String columnName, Schema schema) {
        // Raw receiver keeps this independent of Comparators.forType's exact generic signature
        JavaRDD column = selectColumn(allData, columnName, schema);
        return (Writable) column.max(Comparators.forType((WritableType) schema.getType(columnName).getWritableType()));
    }

    /** Projects each record onto the column named {@code columnName} (lazy Spark map). */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static JavaRDD<Writable> selectColumn(JavaRDD<List<Writable>> data, String columnName, Schema schema) {
        int colIdx = schema.getIndexOfColumn(columnName);
        return data.map((Function) new SelectColumnFunction(colIdx));
    }

    /** Flattens each sequence into its individual time steps via {@code SequenceFlatMapFunction}. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static JavaRDD<List<Writable>> flattenSequences(JavaRDD<List<List<Writable>>> sequences) {
        return sequences.flatMap((FlatMapFunction) new SequenceFlatMapFunction());
    }
}

