/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.spark;

import io.prestosql.hive.$internal.org.slf4j.Logger;
import io.prestosql.hive.$internal.org.slf4j.LoggerFactory;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.common.ObjectPair;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
import org.apache.hadoop.hive.ql.exec.LimitOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorUtils;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.TerminalOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.apache.hadoop.hive.ql.exec.spark.session.SparkSession;
import org.apache.hadoop.hive.ql.exec.spark.session.SparkSessionManagerImpl;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.stats.StatsUtils;

public class SetSparkReducerParallelism
implements NodeProcessor {
    private static final Logger LOG = LoggerFactory.getLogger(SetSparkReducerParallelism.class.getName());
    private static final String SPARK_DYNAMIC_ALLOCATION_ENABLED = "spark.dynamicAllocation.enabled";
    private ObjectPair<Long, Integer> sparkMemoryAndCores = null;
    private final boolean useOpStats;

    public SetSparkReducerParallelism(HiveConf conf) {
        this.useOpStats = conf.getBoolVar(HiveConf.ConfVars.SPARK_USE_OP_STATS);
    }

    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procContext, Object ... nodeOutputs) throws SemanticException {
        OptimizeSparkProcContext context = (OptimizeSparkProcContext)procContext;
        ReduceSinkOperator sink = (ReduceSinkOperator)nd;
        ReduceSinkDesc desc = (ReduceSinkDesc)sink.getConf();
        Set<ReduceSinkOperator> parentSinks = null;
        int maxReducers = context.getConf().getIntVar(HiveConf.ConfVars.MAXREDUCERS);
        int constantReducers = context.getConf().getIntVar(HiveConf.ConfVars.HADOOPNUMREDUCERS);
        if (!this.useOpStats) {
            parentSinks = OperatorUtils.findOperatorsUpstream(sink, ReduceSinkOperator.class);
            parentSinks.remove(sink);
            if (!context.getVisitedReduceSinks().containsAll(parentSinks)) {
                LOG.debug("Skipping sink " + sink + " for now as we haven't seen all its parents.");
                return false;
            }
        }
        if (context.getVisitedReduceSinks().contains(sink)) {
            LOG.debug("Already processed reduce sink: " + sink.getName());
            return true;
        }
        context.getVisitedReduceSinks().add(sink);
        if (this.needSetParallelism(sink, context.getConf())) {
            if (constantReducers > 0) {
                LOG.info("Parallelism for reduce sink " + sink + " set by user to " + constantReducers);
                desc.setNumReducers(constantReducers);
            } else {
                FileSinkOperator fso = GenSparkUtils.getChildOperator(sink, FileSinkOperator.class);
                if (fso != null) {
                    int numBuckets;
                    String bucketCount = ((FileSinkDesc)fso.getConf()).getTableInfo().getProperties().getProperty("bucket_count");
                    int n = numBuckets = bucketCount == null ? 0 : Integer.parseInt(bucketCount);
                    if (numBuckets > 0) {
                        LOG.info("Set parallelism for reduce sink " + sink + " to: " + numBuckets + " (buckets)");
                        desc.setNumReducers(numBuckets);
                        return false;
                    }
                }
                if (this.useOpStats || parentSinks.isEmpty()) {
                    long numberOfBytes = 0L;
                    if (this.useOpStats) {
                        for (Operator<OperatorDesc> sibling : sink.getChildOperators().get(0).getParentOperators()) {
                            if (sibling.getStatistics() != null) {
                                numberOfBytes = StatsUtils.safeAdd(numberOfBytes, sibling.getStatistics().getDataSize());
                                if (!LOG.isDebugEnabled()) continue;
                                LOG.debug("Sibling " + sibling + " has stats: " + sibling.getStatistics());
                                continue;
                            }
                            LOG.warn("No stats available from: " + sibling);
                        }
                    } else {
                        for (Operator<OperatorDesc> sibling : sink.getChildOperators().get(0).getParentOperators()) {
                            Set<TableScanOperator> sources = OperatorUtils.findOperatorsUpstream(sibling, TableScanOperator.class);
                            for (TableScanOperator source : sources) {
                                if (source.getStatistics() != null) {
                                    numberOfBytes = StatsUtils.safeAdd(numberOfBytes, source.getStatistics().getDataSize());
                                    if (!LOG.isDebugEnabled()) continue;
                                    LOG.debug("Table source " + source + " has stats: " + source.getStatistics());
                                    continue;
                                }
                                LOG.warn("No stats available from table source: " + source);
                            }
                        }
                        LOG.debug("Gathered stats for sink " + sink + ". Total size is " + numberOfBytes + " bytes.");
                    }
                    long bytesPerReducer = context.getConf().getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2L;
                    int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer, maxReducers, false);
                    this.getSparkMemoryAndCores(context);
                    if (this.sparkMemoryAndCores != null && this.sparkMemoryAndCores.getFirst() > 0L && this.sparkMemoryAndCores.getSecond() > 0) {
                        if ((double)this.sparkMemoryAndCores.getFirst().longValue() / (double)bytesPerReducer < 0.5) {
                            LOG.warn("Average load of a reducer is much larger than its available memory. Consider decreasing hive.exec.reducers.bytes.per.reducer");
                        }
                        numReducers = Math.max(numReducers, this.sparkMemoryAndCores.getSecond());
                    }
                    numReducers = Math.min(numReducers, maxReducers);
                    LOG.info("Set parallelism for reduce sink " + sink + " to: " + numReducers + " (calculated)");
                    desc.setNumReducers(numReducers);
                } else {
                    int numberOfReducers = 0;
                    for (ReduceSinkOperator parent : parentSinks) {
                        numberOfReducers = Math.max(numberOfReducers, ((ReduceSinkDesc)parent.getConf()).getNumReducers());
                    }
                    desc.setNumReducers(numberOfReducers);
                    LOG.debug("Set parallelism for sink " + sink + " to " + numberOfReducers + " based on its parents");
                }
                Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> keyCols = ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getKeyCols());
                Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> partCols = ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getPartitionCols());
                if (keyCols != null && keyCols.equals(partCols)) {
                    desc.setReducerTraits(EnumSet.of(ReduceSinkDesc.ReducerTraits.UNIFORM));
                }
            }
        } else {
            LOG.info("Number of reducers for sink " + sink + " was already determined to be: " + desc.getNumReducers());
        }
        return false;
    }

    private boolean needSetParallelism(ReduceSinkOperator reduceSink, HiveConf hiveConf) {
        ReduceSinkDesc desc = (ReduceSinkDesc)reduceSink.getConf();
        if (desc.getNumReducers() <= 0) {
            return true;
        }
        if (desc.getNumReducers() == 1 && desc.hasOrderBy() && hiveConf.getBoolVar(HiveConf.ConfVars.HIVESAMPLINGFORORDERBY) && !desc.isDeduplicated()) {
            Stack<Operator<OperatorDesc>> descendants = new Stack<Operator<OperatorDesc>>();
            List<Operator<OperatorDesc>> children = reduceSink.getChildOperators();
            if (children != null) {
                for (Operator<OperatorDesc> child : children) {
                    descendants.push(child);
                }
            }
            while (descendants.size() != 0) {
                List<Operator<OperatorDesc>> childrenOfDescendant;
                Operator descendant = (Operator)descendants.pop();
                if (descendant instanceof LimitOperator) {
                    return false;
                }
                boolean reachTerminalOperator = descendant instanceof TerminalOperator;
                if (reachTerminalOperator || (childrenOfDescendant = descendant.getChildOperators()) == null) continue;
                for (Operator<OperatorDesc> childOfDescendant : childrenOfDescendant) {
                    descendants.push(childOfDescendant);
                }
            }
            return true;
        }
        return false;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void getSparkMemoryAndCores(OptimizeSparkProcContext context) throws SemanticException {
        if (this.sparkMemoryAndCores != null) {
            return;
        }
        if (context.getConf().getBoolean(SPARK_DYNAMIC_ALLOCATION_ENABLED, false)) {
            this.sparkMemoryAndCores = null;
            return;
        }
        SparkSessionManagerImpl sparkSessionManager = null;
        SparkSession sparkSession = null;
        try {
            sparkSessionManager = SparkSessionManagerImpl.getInstance();
            sparkSession = SparkUtilities.getSparkSession(context.getConf(), sparkSessionManager);
            this.sparkMemoryAndCores = sparkSession.getMemoryAndCores();
        }
        catch (HiveException e) {
            throw new SemanticException("Failed to get a spark session: " + e);
        }
        catch (Exception e) {
            LOG.warn("Failed to get spark memory/core info", e);
        }
        finally {
            if (sparkSession != null && sparkSessionManager != null) {
                try {
                    sparkSessionManager.returnSession(sparkSession);
                }
                catch (HiveException ex) {
                    LOG.error("Failed to return the session to SessionManager: " + ex, ex);
                }
            }
        }
    }
}

