/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.spark.planner;

import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.execution.TaskSource;
import com.facebook.presto.execution.scheduler.TableWriteInfo;
import com.facebook.presto.spark.PrestoSparkSessionProperties;
import com.facebook.presto.spark.PrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.IntegerIdentityPartitioner;
import com.facebook.presto.spark.classloader_interface.PrestoSparkRow;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.SerializedPrestoSparkTaskDescriptor;
import com.facebook.presto.spark.classloader_interface.SerializedTaskStats;
import com.facebook.presto.spark.classloader_interface.TaskProcessors;
import com.facebook.presto.spark.planner.PrestoSparkPlan;
import com.facebook.presto.spark.planner.PrestoSparkSubPlan;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PlanFragment;
import com.facebook.presto.sql.planner.SystemPartitioningHandle;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.inject.Inject;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.CollectionAccumulator;

/**
 * Translates a {@link PrestoSparkPlan} into Spark RDDs of Integer-keyed {@link PrestoSparkRow}s.
 *
 * <p>The sub-plan tree is walked recursively. Fragments that scan tables are seeded with
 * {@code parallelize} over serialized task descriptors that carry their splits. Fragments that
 * read remote sources consume the RDDs of their child fragments, re-partitioned with an
 * {@link IntegerIdentityPartitioner} where required.
 */
public class PrestoSparkRddFactory {
    private final JsonCodec<PrestoSparkTaskDescriptor> sparkTaskRequestJsonCodec;

    @Inject
    public PrestoSparkRddFactory(JsonCodec<PrestoSparkTaskDescriptor> sparkTaskRequestJsonCodec) {
        this.sparkTaskRequestJsonCodec = Objects.requireNonNull(sparkTaskRequestJsonCodec, "sparkTaskRequestJsonCodec is null");
    }

    /**
     * Creates the RDD that computes the root fragment of {@code prestoSparkPlan}.
     *
     * @param sparkContext context used to seed RDDs for table-scan fragments
     * @param session query session; also supplies the initial Spark partition count and the hash partition count
     * @param prestoSparkPlan plan whose sub-plan tree and table-write info are translated
     * @param taskExecutorFactoryProvider provider handed to every task processor
     * @param taskStatsCollector accumulator handed to every task processor
     * @return RDD of Integer-keyed rows produced by the root fragment
     * @throws IllegalArgumentException if a fragment has an unsupported shape or partitioning
     * @throws UnsupportedOperationException if a fragment has an unsupported number of child stages
     */
    public JavaPairRDD<Integer, PrestoSparkRow> createSparkRdd(
            JavaSparkContext sparkContext,
            Session session,
            PrestoSparkPlan prestoSparkPlan,
            PrestoSparkTaskExecutorFactoryProvider taskExecutorFactoryProvider,
            CollectionAccumulator<SerializedTaskStats> taskStatsCollector) {
        RddFactory rddFactory = new RddFactory(
                session,
                sparkTaskRequestJsonCodec,
                sparkContext,
                taskExecutorFactoryProvider,
                PrestoSparkSessionProperties.getSparkInitialPartitionCount(session),
                SystemSessionProperties.getHashPartitionCount(session),
                taskStatsCollector,
                prestoSparkPlan.getTableWriteInfo());
        return rddFactory.createRdd(prestoSparkPlan.getPlan());
    }

    /**
     * Holds the per-plan state shared while translating every fragment of one plan.
     */
    private static class RddFactory {
        private final Session session;
        private final JsonCodec<PrestoSparkTaskDescriptor> sparkTaskDescriptorJsonCodec;
        private final JavaSparkContext sparkContext;
        private final PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider;
        private final int initialSparkPartitionCount;
        private final int hashPartitionCount;
        private final CollectionAccumulator<SerializedTaskStats> taskStatsCollector;
        private final TableWriteInfo tableWriteInfo;

        private RddFactory(
                Session session,
                JsonCodec<PrestoSparkTaskDescriptor> sparkTaskDescriptorJsonCodec,
                JavaSparkContext sparkContext,
                PrestoSparkTaskExecutorFactoryProvider executorFactoryProvider,
                int initialSparkPartitionCount,
                int hashPartitionCount,
                CollectionAccumulator<SerializedTaskStats> taskStatsCollector,
                TableWriteInfo tableWriteInfo) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.sparkTaskDescriptorJsonCodec = Objects.requireNonNull(sparkTaskDescriptorJsonCodec, "sparkTaskDescriptorJsonCodec is null");
            this.sparkContext = Objects.requireNonNull(sparkContext, "sparkContext is null");
            this.executorFactoryProvider = Objects.requireNonNull(executorFactoryProvider, "executorFactoryProvider is null");
            this.initialSparkPartitionCount = initialSparkPartitionCount;
            this.hashPartitionCount = hashPartitionCount;
            this.taskStatsCollector = Objects.requireNonNull(taskStatsCollector, "taskStatsCollector is null");
            this.tableWriteInfo = Objects.requireNonNull(tableWriteInfo, "tableWriteInfo is null");
        }

        /**
         * Recursively creates the RDD for {@code subPlan}, building the RDDs of its child stages first.
         *
         * <p>Supported shapes: a table-scan stage with no remote sources, a stage with one remote
         * source, or a hash-partitioned stage with exactly two remote sources.
         *
         * @throws IllegalArgumentException if the fragment uses grouped execution, mixes table scans
         *         and remote sources, has mismatched remote sources/children, or has an unsupported partitioning
         * @throws UnsupportedOperationException if the fragment has neither table scans nor one or two child stages
         */
        public JavaPairRDD<Integer, PrestoSparkRow> createRdd(PrestoSparkSubPlan subPlan) {
            PlanFragment fragment = withIdentityBucketToPartition(subPlan.getFragment());
            Preconditions.checkArgument(
                    !fragment.getStageExecutionDescriptor().isStageGroupedExecution(),
                    "unexpected grouped execution fragment: %s",
                    fragment.getId());

            List<PlanNodeId> tableScans = fragment.getTableScanSchedulingOrder();
            List<RemoteSourceNode> remoteSources = fragment.getRemoteSourceNodes();
            Preconditions.checkArgument(
                    tableScans.isEmpty() || remoteSources.isEmpty(),
                    "stages that have both, remote sources and table scans, are not supported");

            if (!tableScans.isEmpty()) {
                return createTableScanStageRdd(subPlan, fragment);
            }

            List<PrestoSparkSubPlan> children = subPlan.getChildren();
            Preconditions.checkArgument(
                    remoteSources.size() == children.size(),
                    "number of remote sources doesn't match the number of child stages: %s != %s",
                    remoteSources.size(),
                    children.size());

            switch (children.size()) {
                case 1:
                    return createSingleInputStageRdd(fragment, remoteSources.get(0), children.get(0));
                case 2:
                    return createTwoInputStageRdd(fragment, remoteSources, children);
                default:
                    throw new UnsupportedOperationException("unsupported number of child stages: " + children.size());
            }
        }

        /**
         * Returns {@code fragment} with an identity bucket-to-partition mapping over
         * {@code hashPartitionCount} buckets when its output partitioning is FIXED_HASH_DISTRIBUTION.
         * Otherwise the fragment is returned unchanged.
         */
        private PlanFragment withIdentityBucketToPartition(PlanFragment fragment) {
            PartitioningHandle outputPartitioning = fragment.getPartitioningScheme().getPartitioning().getHandle();
            if (!outputPartitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
                return fragment;
            }
            // NOTE(review): presumably bucket i == partition i so that IntegerIdentityPartitioner
            // downstream can route rows by their Integer key directly — confirm.
            return fragment.withBucketToPartition(Optional.of(IntStream.range(0, hashPartitionCount).toArray()));
        }

        /**
         * Builds the RDD for a table-scan stage.
         *
         * <p>All splits of the sub-plan are spread over {@code initialSparkPartitionCount} task
         * descriptors. The serialized descriptors are parallelized and processed one partition
         * per task.
         */
        private JavaPairRDD<Integer, PrestoSparkRow> createTableScanStageRdd(PrestoSparkSubPlan subPlan, PlanFragment fragment) {
            Preconditions.checkArgument(
                    fragment.getPartitioning().equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION),
                    "unexpected table scan partitioning: %s",
                    fragment.getPartitioning());

            List<ScheduledSplit> scheduledSplits = subPlan.getTaskSources().stream()
                    .flatMap(taskSource -> taskSource.getSplits().stream())
                    .collect(ImmutableList.toImmutableList());
            List<SerializedPrestoSparkTaskDescriptor> serializedRequests = assignSplitsToTasks(scheduledSplits, initialSparkPartitionCount).stream()
                    .map(splits -> serialize(createTaskDescriptor(fragment, splits)))
                    .collect(ImmutableList.toImmutableList());

            return sparkContext.parallelize(serializedRequests, initialSparkPartitionCount)
                    .mapPartitionsToPair(TaskProcessors.createTaskProcessor(executorFactoryProvider, taskStatsCollector));
        }

        /**
         * Builds the RDD for a stage that reads a single remote source.
         *
         * <p>COORDINATOR_DISTRIBUTION stages return the child RDD unchanged. FIXED_HASH_DISTRIBUTION
         * stages re-partition the child output into {@code hashPartitionCount} partitions, and
         * SINGLE_DISTRIBUTION stages into one partition, before running the fragment over each
         * partition.
         *
         * @throws IllegalArgumentException for any other partitioning, or if the remote source
         *         does not read exactly the child fragment
         */
        private JavaPairRDD<Integer, PrestoSparkRow> createSingleInputStageRdd(
                PlanFragment fragment,
                RemoteSourceNode remoteSource,
                PrestoSparkSubPlan childSubPlan) {
            JavaPairRDD<Integer, PrestoSparkRow> childRdd = createRdd(childSubPlan);
            PartitioningHandle partitioning = fragment.getPartitioning();
            if (partitioning.equals(SystemPartitioningHandle.COORDINATOR_DISTRIBUTION)) {
                // No Spark task is created for a coordinator-distributed fragment here.
                // NOTE(review): presumably the caller runs it on the driver — confirm.
                return childRdd;
            }

            checkSourceFragment(remoteSource, childSubPlan.getFragment());

            int partitionCount;
            if (partitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
                partitionCount = hashPartitionCount;
            }
            else if (partitioning.equals(SystemPartitioningHandle.SINGLE_DISTRIBUTION)) {
                partitionCount = 1;
            }
            else {
                throw new IllegalArgumentException("Unsupported fragment partitioning: " + partitioning);
            }

            // Intermediate stages receive their input through the shuffle, so the descriptor carries no splits.
            SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = serialize(createTaskDescriptor(fragment, ImmutableList.of()));
            String planNodeId = remoteSource.getId().toString();
            return childRdd
                    .partitionBy(new IntegerIdentityPartitioner(partitionCount))
                    .mapPartitionsToPair(TaskProcessors.createTaskProcessor(executorFactoryProvider, serializedTaskDescriptor, planNodeId, taskStatsCollector));
        }

        /**
         * Builds the RDD for a FIXED_HASH_DISTRIBUTION stage with two remote sources (for example a join).
         *
         * <p>Both child RDDs are re-partitioned into {@code hashPartitionCount} partitions and
         * zipped partition-by-partition. The left/right roles follow the order of
         * {@code remoteSources} and {@code children}.
         */
        private JavaPairRDD<Integer, PrestoSparkRow> createTwoInputStageRdd(
                PlanFragment fragment,
                List<RemoteSourceNode> remoteSources,
                List<PrestoSparkSubPlan> children) {
            PrestoSparkSubPlan leftSubPlan = children.get(0);
            PrestoSparkSubPlan rightSubPlan = children.get(1);
            RemoteSourceNode leftRemoteSource = remoteSources.get(0);
            RemoteSourceNode rightRemoteSource = remoteSources.get(1);

            JavaPairRDD<Integer, PrestoSparkRow> leftChildRdd = createRdd(leftSubPlan);
            JavaPairRDD<Integer, PrestoSparkRow> rightChildRdd = createRdd(rightSubPlan);

            // Remote sources and children are matched by position; make sure they line up.
            checkSourceFragment(leftRemoteSource, leftSubPlan.getFragment());
            checkSourceFragment(rightRemoteSource, rightSubPlan.getFragment());

            PartitioningHandle partitioning = fragment.getPartitioning();
            Preconditions.checkArgument(
                    partitioning.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION),
                    "unexpected partitioning for a stage with two remote sources: %s",
                    partitioning);

            SerializedPrestoSparkTaskDescriptor serializedTaskDescriptor = serialize(createTaskDescriptor(fragment, ImmutableList.of()));
            JavaPairRDD<Integer, PrestoSparkRow> shuffledLeftChildRdd = leftChildRdd.partitionBy(new IntegerIdentityPartitioner(hashPartitionCount));
            JavaPairRDD<Integer, PrestoSparkRow> shuffledRightChildRdd = rightChildRdd.partitionBy(new IntegerIdentityPartitioner(hashPartitionCount));
            return JavaPairRDD.fromJavaRDD(shuffledLeftChildRdd.zipPartitions(
                    shuffledRightChildRdd,
                    TaskProcessors.createTaskProcessor(
                            executorFactoryProvider,
                            serializedTaskDescriptor,
                            leftRemoteSource.getId().toString(),
                            rightRemoteSource.getId().toString(),
                            taskStatsCollector)));
        }

        /**
         * Verifies that {@code remoteSource} reads exactly one fragment and that this fragment is {@code sourceFragment}.
         *
         * @throws IllegalArgumentException if either condition does not hold
         */
        private static void checkSourceFragment(RemoteSourceNode remoteSource, PlanFragment sourceFragment) {
            Preconditions.checkArgument(
                    remoteSource.getSourceFragmentIds().size() == 1,
                    "expected to have exactly only a single source fragment");
            Preconditions.checkArgument(
                    sourceFragment.getId().equals(Iterables.getOnlyElement(remoteSource.getSourceFragmentIds())),
                    "remote source %s does not read from child fragment %s",
                    remoteSource.getId(),
                    sourceFragment.getId());
        }

        /**
         * Distributes splits over {@code numTasks} buckets by hashing each split's plan node id and sequence id.
         *
         * @return exactly {@code numTasks} mutable lists, some of which may be empty
         * @throws IllegalArgumentException if {@code numTasks} is not positive
         */
        private static List<List<ScheduledSplit>> assignSplitsToTasks(List<ScheduledSplit> scheduledSplits, int numTasks) {
            Preconditions.checkArgument(numTasks > 0, "numTasks must be positive: %s", numTasks);
            List<List<ScheduledSplit>> assignedSplits = new ArrayList<>(numTasks);
            for (int i = 0; i < numTasks; i++) {
                assignedSplits.add(new ArrayList<>());
            }
            for (ScheduledSplit split : scheduledSplits) {
                // floorMod keeps the bucket index non-negative even when the hash code is negative.
                int taskId = Math.floorMod(Objects.hash(split.getPlanNodeId(), split.getSequenceId()), numTasks);
                assignedSplits.get(taskId).add(split);
            }
            return assignedSplits;
        }

        /**
         * Builds a task descriptor for {@code fragment}.
         *
         * <p>{@code splits} are grouped into one {@link TaskSource} per plan node, and each group
         * is marked as having no more splits.
         */
        private PrestoSparkTaskDescriptor createTaskDescriptor(PlanFragment fragment, List<ScheduledSplit> splits) {
            Map<PlanNodeId, Set<ScheduledSplit>> splitsByPlanNode = splits.stream()
                    .collect(Collectors.groupingBy(ScheduledSplit::getPlanNodeId, Collectors.toSet()));
            List<TaskSource> taskSourceByPlanNode = splitsByPlanNode.entrySet().stream()
                    .map(entry -> new TaskSource(entry.getKey(), entry.getValue(), ImmutableSet.of(), true))
                    .collect(ImmutableList.toImmutableList());
            return new PrestoSparkTaskDescriptor(
                    session.toSessionRepresentation(),
                    session.getIdentity().getExtraCredentials(),
                    fragment,
                    taskSourceByPlanNode,
                    tableWriteInfo);
        }

        /**
         * JSON-encodes {@code taskDescriptor} and wraps the bytes for shipping to Spark tasks.
         */
        private SerializedPrestoSparkTaskDescriptor serialize(PrestoSparkTaskDescriptor taskDescriptor) {
            return new SerializedPrestoSparkTaskDescriptor(sparkTaskDescriptorJsonCodec.toJsonBytes(taskDescriptor));
        }
    }
}

