/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.spark.transform;

import java.util.Comparator;
import java.util.List;
import org.apache.commons.math3.util.Pair;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.datavec.api.transform.DataAction;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.join.Join;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.reduce.IReducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ConvertToSequence;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.analysis.SequenceFlatMapFunction;
import org.datavec.spark.transform.filter.SparkFilterFunction;
import org.datavec.spark.transform.join.ExecuteJoinFunction;
import org.datavec.spark.transform.join.FilterAndFlattenJoinedValues;
import org.datavec.spark.transform.join.MapToJoinValuesFunction;
import org.datavec.spark.transform.misc.ColumnAsKeyPairFunction;
import org.datavec.spark.transform.rank.UnzipForCalculateSortedRankFunction;
import org.datavec.spark.transform.reduce.MapToPairForReducerFunction;
import org.datavec.spark.transform.reduce.ReducerFunction;
import org.datavec.spark.transform.sequence.SparkGroupToSequenceFunction;
import org.datavec.spark.transform.sequence.SparkMapToPairByColumnFunction;
import org.datavec.spark.transform.sequence.SparkSequenceFilterFunction;
import org.datavec.spark.transform.sequence.SparkSequenceTransformFunction;
import org.datavec.spark.transform.transform.SequenceSplitFunction;
import org.datavec.spark.transform.transform.SparkTransformFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Executes a {@link TransformProcess} (and joins) on Spark RDDs.
 *
 * <p>Data is held either as individual records ({@code List<Writable>}) or as sequences
 * ({@code List<List<Writable>>}). At any point during execution exactly one of the two
 * representations is expected to be active; actions such as {@link ConvertToSequence} switch
 * between them.
 */
public class SparkTransformExecutor {
    private static final Logger log = LoggerFactory.getLogger(SparkTransformExecutor.class);

    /**
     * Executes the transform process on record data, returning record data.
     *
     * @param inputWritables input records
     * @param sequence       the transform process to execute
     * @return the transformed records
     * @throws IllegalStateException if the final schema of the process is a {@link SequenceSchema}
     */
    public JavaRDD<List<Writable>> execute(JavaRDD<List<Writable>> inputWritables, TransformProcess sequence) {
        if (sequence.getFinalSchema() instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        return execute(inputWritables, null, sequence).getFirst();
    }

    /**
     * Executes the transform process on record data, returning sequence data.
     *
     * @param inputWritables input records
     * @param sequence       the transform process to execute
     * @return the resulting sequences
     * @throws IllegalStateException if the final schema of the process is not a {@link SequenceSchema}
     */
    public JavaRDD<List<List<Writable>>> executeToSequence(JavaRDD<List<Writable>> inputWritables, TransformProcess sequence) {
        if (!(sequence.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }
        return execute(inputWritables, null, sequence).getSecond();
    }

    /**
     * Executes the transform process on sequence data, returning record data.
     *
     * @param inputSequence input sequences
     * @param sequence      the transform process to execute
     * @return the resulting records
     * @throws IllegalStateException if the final schema of the process is a {@link SequenceSchema}
     */
    public JavaRDD<List<Writable>> executeSequenceToSeparate(JavaRDD<List<List<Writable>>> inputSequence, TransformProcess sequence) {
        if (sequence.getFinalSchema() instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        return execute(null, inputSequence, sequence).getFirst();
    }

    /**
     * Executes the transform process on sequence data, returning sequence data.
     *
     * @param inputSequence input sequences
     * @param sequence      the transform process to execute
     * @return the transformed sequences
     * @throws IllegalStateException if the final schema of the process is not a {@link SequenceSchema}
     */
    public JavaRDD<List<List<Writable>>> executeSequenceToSequence(JavaRDD<List<List<Writable>>> inputSequence, TransformProcess sequence) {
        if (!(sequence.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }
        return execute(null, inputSequence, sequence).getSecond();
    }

    /**
     * Applies every {@link DataAction} of the transform process, in order, to whichever
     * representation (records or sequences) is currently active.
     *
     * @param inputWritables initial record data, or {@code null} when starting from sequences
     * @param inputSequence  initial sequence data, or {@code null} when starting from records
     * @param sequence       the transform process whose action list is executed
     * @return a pair of (records, sequences) after the last action; the branches below null out
     *         the representation that is no longer active
     * @throws IllegalStateException if an action needs a representation that is currently null
     * @throws RuntimeException      if an action carries none of the supported operations
     */
    // The Spark function arguments are passed through raw casts (the generic signatures of the
    // project function classes are not visible in this file), so the resulting RDD assignments
    // are unchecked. Intermediate pair RDDs are kept raw for the same reason.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Pair<JavaRDD<List<Writable>>, JavaRDD<List<List<Writable>>>> execute(JavaRDD<List<Writable>> inputWritables, JavaRDD<List<List<Writable>>> inputSequence, TransformProcess sequence) {
        JavaRDD<List<Writable>> currentWritables = inputWritables;
        JavaRDD<List<List<Writable>>> currentSequence = inputSequence;
        List<DataAction> list = sequence.getActionList();
        int count = 1;
        for (DataAction d : list) {
            log.info("Starting execution of stage {} of {}", count, list.size());
            if (d.getTransform() != null) {
                Transform t = d.getTransform();
                if (currentWritables != null) {
                    currentWritables = currentWritables.map((Function) new SparkTransformFunction(t));
                } else if (currentSequence != null) {
                    currentSequence = currentSequence.map((Function) new SparkSequenceTransformFunction(t));
                } else {
                    throw new IllegalStateException("Cannot execute Transform operation: current writables and current sequence are both null");
                }
            } else if (d.getFilter() != null) {
                Filter f = d.getFilter();
                if (currentWritables != null) {
                    currentWritables = currentWritables.filter((Function) new SparkFilterFunction(f));
                } else if (currentSequence != null) {
                    currentSequence = currentSequence.filter((Function) new SparkSequenceFilterFunction(f));
                } else {
                    throw new IllegalStateException("Cannot execute Filter operation: current writables and current sequence are both null");
                }
            } else if (d.getConvertToSequence() != null) {
                ConvertToSequence cts = d.getConvertToSequence();
                // Guard like the sibling branches do, instead of failing with a bare NullPointerException.
                if (currentWritables == null) {
                    throw new IllegalStateException("Cannot execute ConvertToSequence operation: current writables are null");
                }
                Schema schema = cts.getInputSchema();
                int colIdx = schema.getIndexOfColumn(cts.getKeyColumn());
                // Key each record by the key column, group by that key, then build one sequence per group.
                JavaPairRDD withKey = currentWritables.mapToPair((PairFunction) new SparkMapToPairByColumnFunction(colIdx));
                JavaPairRDD grouped = withKey.groupByKey();
                currentSequence = grouped.map((Function) new SparkGroupToSequenceFunction(cts.getComparator()));
                currentWritables = null;
            } else if (d.getConvertFromSequence() != null) {
                if (currentSequence == null) {
                    throw new IllegalStateException("Cannot execute ConvertFromSequence operation: current sequence is null");
                }
                currentWritables = currentSequence.flatMap((FlatMapFunction) new SequenceFlatMapFunction());
                currentSequence = null;
            } else if (d.getSequenceSplit() != null) {
                SequenceSplit sequenceSplit = d.getSequenceSplit();
                if (currentSequence == null) {
                    throw new IllegalStateException("Error during execution of SequenceSplit: currentSequence is null");
                }
                currentSequence = currentSequence.flatMap((FlatMapFunction) new SequenceSplitFunction(sequenceSplit));
            } else if (d.getReducer() != null) {
                IReducer reducer = d.getReducer();
                if (currentWritables == null) {
                    throw new IllegalStateException("Error during execution of reduction: current writables are null. Trying to execute a reduce operation on a sequence?");
                }
                JavaPairRDD pair = currentWritables.mapToPair((PairFunction) new MapToPairForReducerFunction(reducer));
                currentWritables = pair.groupByKey().map((Function) new ReducerFunction(reducer));
            } else if (d.getCalculateSortedRank() != null) {
                CalculateSortedRank csr = d.getCalculateSortedRank();
                if (currentWritables == null) {
                    throw new IllegalStateException("Error during execution of CalculateSortedRank: current writables are null. Trying to execute a CalculateSortedRank operation on a sequence? (not currently supported)");
                }
                Comparator comparator = csr.getComparator();
                String sortColumn = csr.getSortOnColumn();
                int sortColumnIdx = csr.getInputSchema().getIndexOfColumn(sortColumn);
                boolean ascending = csr.isAscending();
                // Key by the sort column, sort by that key, then pair every record with its index
                // in the sorted order before handing it to the unzip function.
                JavaPairRDD pairRDD = currentWritables.mapToPair((PairFunction) new ColumnAsKeyPairFunction(sortColumnIdx));
                pairRDD = pairRDD.sortByKey(comparator, ascending);
                JavaPairRDD zipped = pairRDD.zipWithIndex();
                currentWritables = zipped.map((Function) new UnzipForCalculateSortedRankFunction());
            } else {
                throw new RuntimeException("Unknown/not implemented action: " + d);
            }
            ++count;
        }
        log.info("Completed {} of {} execution steps", count - 1, list.size());
        return new Pair<>(currentWritables, currentSequence);
    }

    /**
     * Executes a join of two record RDDs.
     *
     * <p>Both sides are mapped to pairs with {@link MapToJoinValuesFunction} (flagged
     * {@code true} for the left side, {@code false} for the right), unioned, grouped by key,
     * passed through {@link ExecuteJoinFunction}, and finally flattened with
     * {@link FilterAndFlattenJoinedValues} for the join's type.
     *
     * @param join  the join definition
     * @param left  records of the left side
     * @param right records of the right side
     * @return the joined records
     */
    // Raw types are retained: the key/value types of the intermediate pair RDDs are declared by
    // project classes that are not imported in this file.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public JavaRDD<List<Writable>> executeJoin(Join join, JavaRDD<List<Writable>> left, JavaRDD<List<Writable>> right) {
        JavaPairRDD leftJV = left.mapToPair((PairFunction) new MapToJoinValuesFunction(true, join));
        JavaPairRDD rightJV = right.mapToPair((PairFunction) new MapToJoinValuesFunction(false, join));
        JavaPairRDD both = leftJV.union(rightJV);
        JavaPairRDD grouped = both.groupByKey();
        JavaRDD joined = grouped.map((Function) new ExecuteJoinFunction(join));
        return joined.flatMap((FlatMapFunction) new FilterAndFlattenJoinedValues(join.getJoinType()));
    }
}

