/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.ExceptNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.IntersectNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SetOperationNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.GenericLiteral;
import io.prestosql.sql.tree.NullLiteral;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SymbolReference;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Plan optimizer that replaces every {@link IntersectNode} and {@link ExceptNode} with a plan
 * built from union, aggregation, filter and projection nodes.
 *
 * <p>For a set operation with {@code n} sources the rewrite is:
 * <ol>
 *   <li>project every (already rewritten) source to freshly allocated copies of its input
 *       symbols plus {@code n} boolean marker columns; the source's own marker is {@code TRUE},
 *       every other marker is {@code CAST(NULL AS boolean)};</li>
 *   <li>combine the projected sources with a {@link UnionNode} whose outputs are the set
 *       operation's output symbols followed by the markers;</li>
 *   <li>group by the output symbols and compute {@code count(marker)} per marker, i.e. the
 *       number of rows each source contributed to the group (NULL markers are not counted);</li>
 *   <li>filter the groups (INTERSECT: every count is at least 1; EXCEPT: the first count is at
 *       least 1 and all remaining counts are 0);</li>
 *   <li>project back to the set operation's original output symbols.</li>
 * </ol>
 */
public class ImplementIntersectAndExceptAsUnion implements PlanOptimizer {
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");
        return SimplePlanRewriter.rewriteWith(new Rewriter(idAllocator, symbolAllocator), plan);
    }

    /**
     * Rewriter performing the INTERSECT / EXCEPT expansion described on the enclosing class.
     * New symbols come from {@code symbolAllocator} and new node ids from {@code idAllocator}.
     */
    private static class Rewriter extends SimplePlanRewriter<Void> {
        /** Name hint for the per-source boolean marker symbols. */
        private static final String MARKER = "marker";
        /** Aggregate function name; also the name hint for the per-source count symbols. */
        private static final String COUNT = "count";
        /** Signature of {@code count(boolean) -> bigint}, applied to each marker column. */
        private static final Signature COUNT_AGGREGATION = new Signature(
                COUNT,
                FunctionKind.AGGREGATE,
                TypeSignature.parseTypeSignature("bigint"),
                TypeSignature.parseTypeSignature("boolean"));

        private final PlanNodeIdAllocator idAllocator;
        private final SymbolAllocator symbolAllocator;

        private Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /** Replaces INTERSECT with union + count + "every count >= 1" filter + projection. */
        @Override
        public PlanNode visitIntersect(IntersectNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            List<PlanNode> sources = node.getSources().stream().map(rewriteContext::rewrite).collect(Collectors.toList());
            List<Symbol> markers = allocateSymbols(sources.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> withMarkers = appendMarkers(markers, sources, node);
            List<Symbol> outputs = node.getOutputSymbols();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));
            List<Symbol> aggregationOutputs = allocateSymbols(markers.size(), COUNT, BigintType.BIGINT);
            AggregationNode aggregation = computeCounts(union, outputs, markers, aggregationOutputs);
            FilterNode filterNode = addFilterForIntersect(aggregation);
            return project(filterNode, outputs);
        }

        /**
         * Replaces EXCEPT with union + count + filter + projection. The filter keeps groups that
         * occur in the first source and in none of the remaining sources.
         */
        @Override
        public PlanNode visitExcept(ExceptNode node, SimplePlanRewriter.RewriteContext<Void> rewriteContext) {
            List<PlanNode> sources = node.getSources().stream().map(rewriteContext::rewrite).collect(Collectors.toList());
            List<Symbol> markers = allocateSymbols(sources.size(), MARKER, BooleanType.BOOLEAN);
            List<PlanNode> withMarkers = appendMarkers(markers, sources, node);
            List<Symbol> outputs = node.getOutputSymbols();
            UnionNode union = union(withMarkers, ImmutableList.copyOf(Iterables.concat(outputs, markers)));
            List<Symbol> aggregationOutputs = allocateSymbols(markers.size(), COUNT, BigintType.BIGINT);
            AggregationNode aggregation = computeCounts(union, outputs, markers, aggregationOutputs);
            // aggregationOutputs.get(i) is the count for source i; index 0 is the EXCEPT's left side.
            FilterNode filterNode = addFilterForExcept(aggregation, aggregationOutputs.get(0), aggregationOutputs.subList(1, aggregationOutputs.size()));
            return project(filterNode, outputs);
        }

        /** Allocates {@code count} new symbols of {@code type}, each named from {@code nameHint}. */
        private List<Symbol> allocateSymbols(int count, String nameHint, Type type) {
            ImmutableList.Builder<Symbol> symbolsBuilder = ImmutableList.builder();
            for (int i = 0; i < count; ++i) {
                symbolsBuilder.add(symbolAllocator.newSymbol(nameHint, type));
            }
            return symbolsBuilder.build();
        }

        /**
         * Wraps each of {@code nodes} in a projection adding the marker columns, using the
         * set operation's per-source symbol mapping ({@code node.sourceSymbolMap(i)}).
         */
        private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> nodes, SetOperationNode node) {
            ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
            for (int i = 0; i < nodes.size(); ++i) {
                result.add(appendMarkers(nodes.get(i), i, markers, node.sourceSymbolMap(i)));
            }
            return result.build();
        }

        /**
         * Projects {@code source} to its set-operation inputs followed by one boolean column per
         * marker: {@code TRUE} at {@code markerIndex}, {@code CAST(NULL AS boolean)} elsewhere.
         *
         * @param source the rewritten source plan
         * @param markerIndex position of this source among the set operation's sources
         * @param markers marker symbols; only their names are reused for the new symbols
         * @param projections set-operation output symbol to this source's input reference
         * @return a projection whose outputs are the input copies followed by the markers
         */
        private PlanNode appendMarkers(PlanNode source, int markerIndex, List<Symbol> markers, Map<Symbol, SymbolReference> projections) {
            Assignments.Builder assignments = Assignments.builder();
            // Expose each input under a fresh symbol named and typed after the output symbol it feeds.
            for (Map.Entry<Symbol, SymbolReference> entry : projections.entrySet()) {
                Symbol outputSymbol = entry.getKey();
                Symbol symbol = symbolAllocator.newSymbol(outputSymbol.getName(), symbolAllocator.getTypes().get(outputSymbol));
                assignments.put(symbol, entry.getValue());
            }
            for (int i = 0; i < markers.size(); ++i) {
                // Typed as Expression: the branches are a BooleanLiteral and a Cast.
                Expression expression = i == markerIndex
                        ? BooleanLiteral.TRUE_LITERAL
                        : new Cast(new NullLiteral(), "boolean");
                assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BooleanType.BOOLEAN), expression);
            }
            return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
        }

        /**
         * Builds a union over {@code nodes} where the i-th output symbol of every node feeds
         * {@code outputs.get(i)}.
         */
        private UnionNode union(List<PlanNode> nodes, List<Symbol> outputs) {
            ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputs = ImmutableListMultimap.builder();
            for (PlanNode source : nodes) {
                List<Symbol> sourceOutputs = source.getOutputSymbols();
                for (int i = 0; i < sourceOutputs.size(); ++i) {
                    outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
                }
            }
            return new UnionNode(idAllocator.getNextId(), nodes, outputsToInputs.build(), outputs);
        }

        /**
         * Groups {@code sourceNode} by {@code originalColumns} and computes, as a single-step
         * aggregation, {@code count(markers[i])} into {@code aggregationOutputs[i]}.
         */
        private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> aggregationOutputs) {
            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
            for (int i = 0; i < markers.size(); ++i) {
                FunctionCall countCall = new FunctionCall(QualifiedName.of(COUNT), ImmutableList.<Expression>of(markers.get(i).toSymbolReference()));
                aggregations.put(aggregationOutputs.get(i), new AggregationNode.Aggregation(countCall, COUNT_AGGREGATION, Optional.empty()));
            }
            return new AggregationNode(
                    idAllocator.getNextId(),
                    sourceNode,
                    aggregations.build(),
                    AggregationNode.singleGroupingSet(originalColumns),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty());
        }

        /** Keeps only groups in which every aggregation output (per-source count) is at least 1. */
        private FilterNode addFilterForIntersect(AggregationNode aggregation) {
            ImmutableList.Builder<Expression> predicates = ImmutableList.builder();
            for (Symbol count : aggregation.getAggregations().keySet()) {
                predicates.add(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, count.toSymbolReference(), bigintLiteral("1")));
            }
            return new FilterNode(idAllocator.getNextId(), aggregation, ExpressionUtils.and(predicates.build()));
        }

        /**
         * Keeps only groups where {@code firstSource} is at least 1 and every symbol in
         * {@code remainingSources} equals 0.
         */
        private FilterNode addFilterForExcept(AggregationNode aggregation, Symbol firstSource, List<Symbol> remainingSources) {
            ImmutableList.Builder<Expression> predicates = ImmutableList.builder();
            predicates.add(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, firstSource.toSymbolReference(), bigintLiteral("1")));
            for (Symbol symbol : remainingSources) {
                predicates.add(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, symbol.toSymbolReference(), bigintLiteral("0")));
            }
            return new FilterNode(idAllocator.getNextId(), aggregation, ExpressionUtils.and(predicates.build()));
        }

        /** Creates a {@code BIGINT} literal with the given textual value. */
        private static GenericLiteral bigintLiteral(String value) {
            return new GenericLiteral("BIGINT", value);
        }

        /** Adds an identity projection restricting {@code node}'s output to {@code columns}. */
        private ProjectNode project(PlanNode node, List<Symbol> columns) {
            return new ProjectNode(idAllocator.getNextId(), node, Assignments.identity(columns));
        }
    }
}

