/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.execution.warnings.WarningCollector;
import com.facebook.presto.metadata.FunctionManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SetOperationNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.type.BigintType;
import com.facebook.presto.spi.type.BooleanType;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.SetOperationNodeUtils;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.BooleanLiteral;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.GenericLiteral;
import com.facebook.presto.sql.tree.NullLiteral;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Plan optimizer that replaces {@link IntersectNode} and {@link ExceptNode} with an
 * equivalent plan built from a union, an aggregation, a filter and a projection.
 *
 * <p>For a set operation with N sources the rewritten plan has this shape:
 * <pre>
 * Project(original outputs)
 *   Filter(predicate over the per-source counts)
 *     Aggregation(group by original outputs; count(marker_i) for each source i)
 *       Union
 *         Project(source_0 columns, marker_0 = TRUE, marker_1 = NULL, ...)
 *         Project(source_1 columns, marker_0 = NULL, marker_1 = TRUE, ...)
 *         ...
 * </pre>
 *
 * <ul>
 *   <li>INTERSECT keeps the groups for which every count is {@code >= 1}.</li>
 *   <li>EXCEPT keeps the groups for which the first count is {@code >= 1} and every other
 *       count is {@code = 0}.</li>
 * </ul>
 */
public class ImplementIntersectAndExceptAsUnion
implements PlanOptimizer {
    private final FunctionManager functionManager;

    /**
     * @param functionManager used to resolve the {@code count} aggregation function
     * @throws NullPointerException if {@code functionManager} is null
     */
    public ImplementIntersectAndExceptAsUnion(FunctionManager functionManager) {
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    /**
     * Rewrites every intersect/except node reachable from {@code plan}.
     *
     * @return the rewritten plan
     * @throws NullPointerException if {@code plan}, {@code session}, {@code types},
     *         {@code variableAllocator} or {@code idAllocator} is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(variableAllocator, "variableAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(session, this.functionManager, idAllocator, variableAllocator), plan);
    }

    /**
     * Plan rewriter that performs the actual intersect/except replacement.
     */
    private static class Rewriter
    extends SimplePlanRewriter<Void> {
        private static final String MARKER = "marker";
        private final Session session;
        private final StandardFunctionResolution functionResolution;
        private final PlanNodeIdAllocator idAllocator;
        private final PlanVariableAllocator variableAllocator;

        private Rewriter(Session session, FunctionManager functionManager, PlanNodeIdAllocator idAllocator, PlanVariableAllocator variableAllocator) {
            Objects.requireNonNull(functionManager, "functionManager is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.functionResolution = new FunctionResolution(functionManager);
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
        }

        /**
         * Replaces an intersect node with union + per-source counts + a filter that
         * requires every count to be at least one.
         */
        @Override
        public PlanNode visitIntersect(IntersectNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            // Rewrite the children first so nested set operations are replaced as well.
            List<PlanNode> sources = node.getSources().stream().map(rewriteContext::rewrite).collect(Collectors.toList());

            List<VariableReferenceExpression> markers = this.allocateVariables(sources.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> withMarkers = this.appendMarkers(markers, sources, node);

            List<VariableReferenceExpression> outputs = node.getOutputVariables();
            UnionNode union = this.union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));

            List<VariableReferenceExpression> aggregationOutputs = this.allocateVariables(markers.size(), "count", BigintType.BIGINT);
            AggregationNode aggregation = this.computeCounts(union, outputs, markers, aggregationOutputs);

            FilterNode filterNode = this.addFilterForIntersect(aggregation);
            return this.project(filterNode, outputs);
        }

        /**
         * Replaces an except node with union + per-source counts + a filter that requires
         * the first source's count to be at least one and all other counts to be zero.
         */
        @Override
        public PlanNode visitExcept(ExceptNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            // Rewrite the children first so nested set operations are replaced as well.
            List<PlanNode> sources = node.getSources().stream().map(rewriteContext::rewrite).collect(Collectors.toList());

            List<VariableReferenceExpression> markers = this.allocateVariables(sources.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> withMarkers = this.appendMarkers(markers, sources, node);

            List<VariableReferenceExpression> outputs = node.getOutputVariables();
            UnionNode union = this.union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));

            List<VariableReferenceExpression> aggregationOutputs = this.allocateVariables(markers.size(), "count", BigintType.BIGINT);
            AggregationNode aggregation = this.computeCounts(union, outputs, markers, aggregationOutputs);

            FilterNode filterNode = this.addFilterForExcept(aggregation, aggregationOutputs.get(0), aggregationOutputs.subList(1, aggregationOutputs.size()));
            return this.project(filterNode, outputs);
        }

        /**
         * Allocates {@code count} fresh variables that share the same name hint and type.
         */
        private List<VariableReferenceExpression> allocateVariables(int count, String nameHint, Type type) {
            ImmutableList.Builder<VariableReferenceExpression> variablesBuilder = ImmutableList.builder();
            for (int i = 0; i < count; ++i) {
                variablesBuilder.add(this.variableAllocator.newVariable(nameHint, type));
            }
            return variablesBuilder.build();
        }

        /**
         * Wraps each source in a projection that re-exposes the source's columns and adds
         * one marker column per source (see the single-source overload).
         *
         * @param markers marker variables, one per source
         * @param nodes the already rewritten sources of {@code node}, in source order
         * @param node the set operation whose per-source variable mapping is used
         */
        private List<PlanNode> appendMarkers(List<VariableReferenceExpression> markers, List<PlanNode> nodes, SetOperationNode node) {
            ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
            for (int i = 0; i < nodes.size(); ++i) {
                // No raw Map cast here: the lambda needs the real value type to call getName().
                result.add(this.appendMarkers(nodes.get(i), i, markers, Maps.transformValues(node.sourceVariableMap(i), variable -> new SymbolReference(variable.getName()))));
            }
            return result.build();
        }

        /**
         * Builds the projection for a single source: a fresh variable for every entry of
         * {@code projections}, followed by one boolean column per marker. The marker at
         * {@code markerIndex} is {@code TRUE}; every other marker is {@code CAST(NULL AS boolean)}.
         */
        private PlanNode appendMarkers(PlanNode source, int markerIndex, List<VariableReferenceExpression> markers, Map<VariableReferenceExpression, SymbolReference> projections) {
            Assignments.Builder assignments = Assignments.builder();
            for (Map.Entry<VariableReferenceExpression, SymbolReference> entry : projections.entrySet()) {
                VariableReferenceExpression variable = this.variableAllocator.newVariable(entry.getKey().getName(), entry.getKey().getType());
                assignments.put(variable, OriginalExpressionUtils.castToRowExpression(entry.getValue()));
            }
            for (int i = 0; i < markers.size(); ++i) {
                // Must be typed as Expression: the two branches are a BooleanLiteral and a Cast.
                Expression expression = i == markerIndex ? BooleanLiteral.TRUE_LITERAL : new Cast(new NullLiteral(), "boolean");
                assignments.put(this.variableAllocator.newVariable(markers.get(i).getName(), BooleanType.BOOLEAN), OriginalExpressionUtils.castToRowExpression(expression));
            }
            return new ProjectNode(this.idAllocator.getNextId(), source, assignments.build());
        }

        /**
         * Creates a union over {@code nodes}. The i-th output variable of every source is
         * mapped, by position, to {@code outputs.get(i)}.
         */
        private UnionNode union(List<PlanNode> nodes, List<VariableReferenceExpression> outputs) {
            ImmutableListMultimap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputsToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : nodes) {
                List<VariableReferenceExpression> sourceOutputs = source.getOutputVariables();
                for (int i = 0; i < sourceOutputs.size(); ++i) {
                    outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
                }
            }
            ListMultimap<VariableReferenceExpression, VariableReferenceExpression> mapping = outputsToInputs.build();
            return new UnionNode(this.idAllocator.getNextId(), nodes, ImmutableList.copyOf(mapping.keySet()), SetOperationNodeUtils.fromListMultimap(mapping));
        }

        /**
         * Adds a single-step aggregation on top of the union that groups by
         * {@code originalColumns} and computes {@code count(markers[i])} into
         * {@code aggregationOutputs[i]} for every marker.
         */
        private AggregationNode computeCounts(UnionNode sourceNode, List<VariableReferenceExpression> originalColumns, List<VariableReferenceExpression> markers, List<VariableReferenceExpression> aggregationOutputs) {
            ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (int i = 0; i < markers.size(); ++i) {
                VariableReferenceExpression marker = markers.get(i);
                VariableReferenceExpression output = aggregationOutputs.get(i);
                CallExpression countCall = new CallExpression(
                        "count",
                        this.functionResolution.countFunction(marker.getType()),
                        BigintType.BIGINT,
                        ImmutableList.of(OriginalExpressionUtils.castToRowExpression(OriginalExpressionUtils.asSymbolReference(marker))));
                aggregations.put(output, new AggregationNode.Aggregation(countCall, Optional.empty(), Optional.empty(), false, Optional.empty()));
            }
            return new AggregationNode(this.idAllocator.getNextId(), sourceNode, aggregations.build(), AggregationNode.singleGroupingSet(originalColumns), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
        }

        /**
         * Adds a filter that keeps the groups where every aggregation output is {@code >= 1}.
         */
        private FilterNode addFilterForIntersect(AggregationNode aggregation) {
            ImmutableList.Builder<Expression> predicatesBuilder = ImmutableList.builder();
            for (VariableReferenceExpression column : aggregation.getAggregations().keySet()) {
                predicatesBuilder.add(Rewriter.compareCount(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, column, "1"));
            }
            return new FilterNode(this.idAllocator.getNextId(), aggregation, OriginalExpressionUtils.castToRowExpression(ExpressionUtils.and(predicatesBuilder.build())));
        }

        /**
         * Adds a filter that keeps the groups where {@code firstSource >= 1} and every
         * variable in {@code remainingSources} equals {@code 0}.
         */
        private FilterNode addFilterForExcept(AggregationNode aggregation, VariableReferenceExpression firstSource, List<VariableReferenceExpression> remainingSources) {
            ImmutableList.Builder<Expression> predicatesBuilder = ImmutableList.builder();
            predicatesBuilder.add(Rewriter.compareCount(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, firstSource, "1"));
            for (VariableReferenceExpression variable : remainingSources) {
                predicatesBuilder.add(Rewriter.compareCount(ComparisonExpression.Operator.EQUAL, variable, "0"));
            }
            return new FilterNode(this.idAllocator.getNextId(), aggregation, OriginalExpressionUtils.castToRowExpression(ExpressionUtils.and(predicatesBuilder.build())));
        }

        /**
         * Builds {@code <count> <operator> BIGINT '<value>'} for the given count variable.
         */
        private static Expression compareCount(ComparisonExpression.Operator operator, VariableReferenceExpression count, String value) {
            return new ComparisonExpression(operator, new SymbolReference(count.getName()), new GenericLiteral("BIGINT", value));
        }

        /**
         * Projects {@code node} down to {@code columns} using identity assignments, which
         * drops the count columns produced by the aggregation.
         */
        private ProjectNode project(PlanNode node, List<VariableReferenceExpression> columns) {
            return new ProjectNode(this.idAllocator.getNextId(), node, AssignmentUtils.identityAssignmentsAsSymbolReferences(columns));
        }
    }
}

