/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.common.type.Varchars;
import com.facebook.presto.metadata.CastType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.PrestoWarning;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.StandardWarningCode;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCodeSupplier;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult;
import com.facebook.presto.sql.planner.plan.ChildReplacer;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.planner.plan.WindowNode;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.type.TypeUtils;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * Plan optimizer that applies key-based sampling to a query plan.
 *
 * <p>When enabled (through the key-based-sampling session property, or through
 * {@link #setEnabledForTesting(boolean)}), the plan is rewritten so that beneath joins, semi-joins,
 * aggregations, windows, row-number / top-N-row-number nodes and distinct-limit nodes a
 * {@link FilterNode} of the form
 * {@code samplingFunction(key_as_varchar) <= samplingPercentage} is inserted. The sampling function
 * name and the percentage both come from session properties. Every sampled key is recorded and,
 * unless running in testing mode, reported to the user as a warning.
 */
public class KeyBasedSampler
implements PlanOptimizer {
    private final Metadata metadata;
    private boolean isEnabledForTesting;

    public KeyBasedSampler(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public void setEnabledForTesting(boolean isSet) {
        this.isEnabledForTesting = isSet;
    }

    @Override
    public boolean isEnabled(Session session) {
        return this.isEnabledForTesting || SystemSessionProperties.isKeyBasedSamplingEnabled(session);
    }

    /**
     * Rewrites {@code plan} with sampling filters when this optimizer is enabled.
     *
     * <p>Outside of testing mode a warning is always emitted: either the list of sampled
     * columns, or a note that the query structure did not allow sampling.
     *
     * @return the (possibly) rewritten plan; the "triggered" flag is {@code true} only if at
     *     least one sampling filter was actually added
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!this.isEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }
        List<String> sampledFields = new ArrayList<>(2);
        PlanNode rewritten = SimplePlanRewriter.rewriteWith(new Rewriter(session, this.metadata.getFunctionAndTypeManager(), idAllocator, sampledFields), plan, null);
        if (!this.isEnabledForTesting) {
            if (!sampledFields.isEmpty()) {
                // The percentage session value is multiplied by 100 for display, so it is
                // presumably a fraction in [0, 1] — NOTE(review): confirm against the property definition.
                warningCollector.add(new PrestoWarning(StandardWarningCode.SAMPLED_FIELDS, String.format("Sampled the following columns/derived columns at %s percent:%n\t%s", SystemSessionProperties.getKeyBasedSamplingPercentage(session) * 100.0, String.join("\n\t", sampledFields))));
            } else {
                warningCollector.add(new PrestoWarning(StandardWarningCode.SEMANTIC_WARNING, "Sampling could not be performed due to the query structure"));
            }
        }
        // Each sampling filter added by the rewriter appends exactly one entry to sampledFields,
        // so the plan was changed if and only if that list is non-empty.
        return PlanOptimizerResult.optimizerResult(rewritten, !sampledFields.isEmpty());
    }

    /**
     * Plan rewriter that inserts sampling filters and records a description of every sampled key
     * into the caller-supplied {@code sampledFields} list.
     */
    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private final Session session;
        private final FunctionAndTypeManager functionAndTypeManager;
        private final PlanNodeIdAllocator idAllocator;
        private final List<String> sampledFields;

        private Rewriter(Session session, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator idAllocator, List<String> sampledFields) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.sampledFields = Objects.requireNonNull(sampledFields, "sampledFields is null");
        }

        /**
         * Wraps {@code source} in a filter keeping rows where
         * {@code samplingFunction(key_as_varchar) <= samplingPercentage}.
         *
         * @param source the node to filter
         * @param samplingKey the key to sample on; when empty, {@code source} is returned unchanged
         * @param functionAndTypeManager used to resolve the cast, sampling function and comparison
         * @return a new {@link FilterNode} over {@code source}, or {@code source} itself when no key
         * @throws PrestoException with {@code FUNCTION_NOT_FOUND} if the configured sampling
         *     function cannot be resolved
         */
        private PlanNode addSamplingFilter(PlanNode source, Optional<VariableReferenceExpression> samplingKey, FunctionAndTypeManager functionAndTypeManager) {
            if (!samplingKey.isPresent()) {
                return source;
            }
            RowExpression key = samplingKey.get();

            // The sampling function always receives a VARCHAR argument; non-varchar keys are cast.
            RowExpression arg = Varchars.isVarcharType(key.getType())
                    ? key
                    : Expressions.call("CAST", functionAndTypeManager.lookupCast(CastType.CAST, key.getType(), VarcharType.VARCHAR), VarcharType.VARCHAR, key);

            String samplingFunction = SystemSessionProperties.getKeyBasedSamplingFunction(this.session);
            CallExpression sampledArg;
            try {
                sampledArg = Expressions.call(functionAndTypeManager, samplingFunction, DoubleType.DOUBLE, arg);
            }
            catch (PrestoException prestoException) {
                // NOTE(review): "not cannot be resolved" reads like a typo; kept verbatim in case
                // tests or callers match on this message text.
                throw new PrestoException(StandardErrorCode.FUNCTION_NOT_FOUND, String.format("Sampling function: %s not cannot be resolved", samplingFunction), prestoException);
            }

            CallExpression predicate = Expressions.call(
                    ComparisonExpression.Operator.LESS_THAN_OR_EQUAL.name(),
                    functionAndTypeManager.resolveOperator(OperatorType.LESS_THAN_OR_EQUAL, TypeSignatureProvider.fromTypes(DoubleType.DOUBLE, DoubleType.DOUBLE)),
                    BooleanType.BOOLEAN,
                    sampledArg,
                    new ConstantExpression(arg.getSourceLocation(), SystemSessionProperties.getKeyBasedSamplingPercentage(this.session), DoubleType.DOUBLE));

            FilterNode filterNode = new FilterNode(source.getSourceLocation(), this.idAllocator.getNextId(), source, predicate);
            this.sampledFields.add(String.format("%s from %s", key, describeSampledSource(source)));
            return filterNode;
        }

        /**
         * Describes where a sampled key comes from for the user-facing warning.
         *
         * <p>Descends through filters and projections from {@code node}. It returns the table handle
         * of the scan found there, or the id of the node where the descent stopped.
         */
        private static String describeSampledSource(PlanNode node) {
            PlanNode current = node;
            while (current instanceof FilterNode || current instanceof ProjectNode) {
                current = current.getSources().get(0);
            }
            if (current instanceof TableScanNode) {
                return ((TableScanNode) current).getTable().getConnectorHandle().toString();
            }
            return "plan node: " + current.getId();
        }

        /**
         * Picks the key to sample on: the first integral-typed key if any, otherwise the first
         * varchar-typed key, otherwise empty.
         */
        private Optional<VariableReferenceExpression> findSuitableKey(List<VariableReferenceExpression> keys) {
            Optional<VariableReferenceExpression> key = keys.stream()
                    .filter(x -> TypeUtils.isIntegralType(x.getType().getTypeSignature(), this.functionAndTypeManager))
                    .findFirst();
            if (!key.isPresent()) {
                key = keys.stream().filter(x -> Varchars.isVarcharType(x.getType())).findFirst();
            }
            return key;
        }

        /**
         * Handles single-source nodes. Rewrites {@code source} first; if nothing beneath it was
         * sampled, adds a sampling filter on a key chosen from {@code keys}. Then re-attaches the
         * result as the only child of {@code planNode}.
         */
        private PlanNode sampleSourceNodeWithKey(PlanNode planNode, PlanNode source, List<VariableReferenceExpression> keys) {
            PlanNode rewrittenSource = Rewriter.rewriteWith(this, source);
            if (rewrittenSource == source) {
                rewrittenSource = this.addSamplingFilter(source, this.findSuitableKey(keys), this.functionAndTypeManager);
            }
            return ChildReplacer.replaceChildren(planNode, ImmutableList.of(rewrittenSource));
        }

        @Override
        public PlanNode visitJoin(JoinNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode left = node.getLeft();
            PlanNode right = node.getRight();
            PlanNode rewrittenLeft = Rewriter.rewriteWith(this, left);
            PlanNode rewrittenRight = Rewriter.rewriteWith(this, right);
            // If at least one side was left unchanged, sample BOTH sides on the same equi-join
            // clause (presumably so equal join keys are sampled consistently). A side that was
            // already rewritten is filtered again.
            if (left == rewrittenLeft || right == rewrittenRight) {
                Optional<EquiJoinClause> equiJoinClause = node.getCriteria().stream()
                        .filter(x -> TypeUtils.isIntegralType(x.getLeft().getType().getTypeSignature(), this.functionAndTypeManager))
                        .findFirst();
                if (!equiJoinClause.isPresent()) {
                    equiJoinClause = node.getCriteria().stream().filter(x -> Varchars.isVarcharType(x.getLeft().getType())).findFirst();
                }
                if (equiJoinClause.isPresent()) {
                    rewrittenLeft = this.addSamplingFilter(rewrittenLeft, Optional.of(equiJoinClause.get().getLeft()), this.functionAndTypeManager);
                    rewrittenRight = this.addSamplingFilter(rewrittenRight, Optional.of(equiJoinClause.get().getRight()), this.functionAndTypeManager);
                }
            }
            return ChildReplacer.replaceChildren(node, ImmutableList.of(rewrittenLeft, rewrittenRight));
        }

        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PlanNode source = node.getSource();
            PlanNode filteringSource = node.getFilteringSource();
            PlanNode rewrittenSource = Rewriter.rewriteWith(this, source);
            PlanNode rewrittenFilteringSource = Rewriter.rewriteWith(this, filteringSource);
            // Same policy as visitJoin: if either side is unchanged, sample both on their join variables.
            if (rewrittenSource == source || rewrittenFilteringSource == filteringSource) {
                rewrittenSource = this.addSamplingFilter(rewrittenSource, this.findSuitableKey(ImmutableList.of(node.getSourceJoinVariable())), this.functionAndTypeManager);
                rewrittenFilteringSource = this.addSamplingFilter(rewrittenFilteringSource, this.findSuitableKey(ImmutableList.of(node.getFilteringSourceJoinVariable())), this.functionAndTypeManager);
            }
            return ChildReplacer.replaceChildren(node, ImmutableList.of(rewrittenSource, rewrittenFilteringSource));
        }

        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            return this.sampleSourceNodeWithKey(node, node.getSource(), node.getGroupingKeys());
        }

        @Override
        public PlanNode visitWindow(WindowNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            return this.sampleSourceNodeWithKey(node, node.getSource(), node.getPartitionBy());
        }

        @Override
        public PlanNode visitRowNumber(RowNumberNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            return this.sampleSourceNodeWithKey(node, node.getSource(), node.getPartitionBy());
        }

        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            return this.sampleSourceNodeWithKey(node, node.getSource(), node.getPartitionBy());
        }

        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            return this.sampleSourceNodeWithKey(node, node.getSource(), node.getDistinctVariables());
        }
    }
}

