/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner;

import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.prestosql.Session;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.spi.predicate.TupleDomain;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.DeterminismEvaluator;
import io.prestosql.sql.planner.DomainTranslator;
import io.prestosql.sql.planner.EqualityInference;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Derives an "effective predicate" for a plan node: a conjunction of expressions over the node's
 * output symbols, assembled from the filters, projections, joins and table-scan enforced
 * constraints found beneath it. Node types without a dedicated visitor method contribute
 * {@code TRUE}, so unknown nodes never add information to the result.
 */
public class EffectivePredicateExtractor {
    /** True when an assignment maps a symbol to a plain reference to that same symbol (identity). */
    private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION =
            entry -> entry.getValue().equals(entry.getKey().toSymbolReference());

    /** Turns an assignment {@code symbol := expression} into the equality {@code symbol = expression}. */
    private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY =
            entry -> new ComparisonExpression(
                    ComparisonExpression.Operator.EQUAL,
                    entry.getKey().toSymbolReference(),
                    entry.getValue());

    private final DomainTranslator domainTranslator;
    private final Metadata metadata;

    /**
     * @param domainTranslator converts table-scan {@code TupleDomain}s into expressions; must not be null
     * @param metadata passed through to the visitor; must not be null
     */
    public EffectivePredicateExtractor(DomainTranslator domainTranslator, Metadata metadata) {
        this.domainTranslator = Objects.requireNonNull(domainTranslator, "domainTranslator is null");
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Computes the effective predicate of {@code node}.
     *
     * @param session the session handed to the visitor; must not be null
     * @param node the plan node to analyze
     * @return a conjunction over the node's output symbols ({@code TRUE} when nothing is known)
     */
    public Expression extract(Session session, PlanNode node) {
        return node.accept(new Visitor(domainTranslator, metadata, session), null);
    }

    /** Plan visitor that computes the effective predicate bottom-up. */
    private static final class Visitor
            extends PlanVisitor<Expression, Void> {
        private final DomainTranslator domainTranslator;
        private final Metadata metadata;
        private final Session session;

        public Visitor(DomainTranslator domainTranslator, Metadata metadata, Session session) {
            this.domainTranslator = Objects.requireNonNull(domainTranslator, "domainTranslator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /** Fallback for node types without a dedicated method: nothing is known, so {@code TRUE}. */
        @Override
        protected Expression visitPlan(PlanNode node, Void context) {
            return BooleanLiteral.TRUE_LITERAL;
        }

        /**
         * Restricts the source predicate to the grouping keys. A global aggregation (no grouping
         * keys) has no pass-through symbols, so it yields {@code TRUE}.
         */
        @Override
        public Expression visitAggregation(AggregationNode node, Void context) {
            if (node.getGroupingKeys().isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            Expression underlyingPredicate = node.getSource().accept(this, context);
            return pullExpressionThroughSymbols(underlyingPredicate, node.getGroupingKeys());
        }

        /** Combines the deterministic conjuncts of the filter with the source's predicate. */
        @Override
        public Expression visitFilter(FilterNode node, Void context) {
            Expression underlyingPredicate = node.getSource().accept(this, context);
            Expression predicate = ExpressionUtils.filterDeterministicConjuncts(node.getPredicate());
            return ExpressionUtils.combineConjuncts(predicate, underlyingPredicate);
        }

        /**
         * Keeps only the predicates common to every exchange input, after renaming each input's
         * symbols to the exchange's output symbols by position.
         */
        @Override
        public Expression visitExchange(ExchangeNode node, Void context) {
            return deriveCommonPredicates(node, source -> {
                // Both lists are loop-invariant for a given source; look them up once.
                List<Symbol> inputs = node.getInputs().get(source);
                List<Symbol> outputs = node.getOutputSymbols();
                Map<Symbol, SymbolReference> mappings = new HashMap<>();
                for (int i = 0; i < inputs.size(); i++) {
                    mappings.put(outputs.get(i), inputs.get(i).toSymbolReference());
                }
                return mappings.entrySet();
            });
        }

        /**
         * Adds {@code output = expression} for every non-identity assignment to the source
         * predicate, then restricts the result to the projection's output symbols.
         */
        @Override
        public Expression visitProject(ProjectNode node, Void context) {
            Expression underlyingPredicate = node.getSource().accept(this, context);
            List<Expression> projectionEqualities = node.getAssignments().entrySet().stream()
                    .filter(SYMBOL_MATCHES_EXPRESSION.negate())
                    .map(ENTRY_TO_EQUALITY)
                    .collect(ImmutableList.toImmutableList());
            Expression combined = ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                    .addAll(projectionEqualities)
                    .add(underlyingPredicate)
                    .build());
            return pullExpressionThroughSymbols(combined, node.getOutputSymbols());
        }

        @Override
        public Expression visitTopN(TopNNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitLimit(LimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitAssignUniqueId(AssignUniqueId node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitDistinctLimit(DistinctLimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /**
         * Translates the scan's enforced constraint (keyed by column handle) into an expression
         * over the scan's symbols, using the inverse of the symbol-to-column assignments.
         */
        @Override
        public Expression visitTableScan(TableScanNode node, Void context) {
            Map<ColumnHandle, Symbol> assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
            TupleDomain<ColumnHandle> predicate = node.getEnforcedConstraint();
            return domainTranslator.toPredicate(predicate.simplify().transform(assignments::get));
        }

        @Override
        public Expression visitSort(SortNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitWindow(WindowNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /** Keeps only the predicates common to every union branch, renamed to the union's outputs. */
        @Override
        public Expression visitUnion(UnionNode node, Void context) {
            return deriveCommonPredicates(node, source -> node.outputSymbolMap(source).entries());
        }

        /**
         * Inner joins combine both sides, the equi-join criteria and the join filter. For outer
         * joins, conjuncts involving the side(s) that may be NULL-padded go through
         * {@link #pullNullableConjunctsThroughOuterJoin}; the preserved side's predicate is kept as is.
         */
        @Override
        public Expression visitJoin(JoinNode node, Void context) {
            Expression leftPredicate = node.getLeft().accept(this, context);
            Expression rightPredicate = node.getRight().accept(this, context);
            List<Expression> joinConjuncts = node.getCriteria().stream()
                    .map(JoinNode.EquiJoinClause::toExpression)
                    .collect(ImmutableList.toImmutableList());
            List<Symbol> outputSymbols = node.getOutputSymbols();

            switch (node.getType()) {
                case INNER: {
                    return pullExpressionThroughSymbols(ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(leftPredicate)
                            .add(rightPredicate)
                            .add(ExpressionUtils.combineConjuncts(joinConjuncts))
                            .add(node.getFilter().orElse(BooleanLiteral.TRUE_LITERAL))
                            .build()), outputSymbols);
                }
                case LEFT: {
                    Predicate<Symbol> rightScope = node.getRight().getOutputSymbols()::contains;
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, outputSymbols))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), outputSymbols, rightScope))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputSymbols, rightScope))
                            .build());
                }
                case RIGHT: {
                    Predicate<Symbol> leftScope = node.getLeft().getOutputSymbols()::contains;
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(rightPredicate, outputSymbols))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(leftPredicate), outputSymbols, leftScope))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputSymbols, leftScope))
                            .build());
                }
                case FULL: {
                    Predicate<Symbol> leftScope = node.getLeft().getOutputSymbols()::contains;
                    Predicate<Symbol> rightScope = node.getRight().getOutputSymbols()::contains;
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(leftPredicate), outputSymbols, leftScope))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), outputSymbols, rightScope))
                            .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, outputSymbols, leftScope, rightScope))
                            .build());
                }
            }
            throw new UnsupportedOperationException("Unknown join type: " + node.getType());
        }

        /**
         * Restricts each conjunct to {@code outputSymbols}, then adapts it for rows in which the
         * symbols matched by {@code nullSymbolScopes} may be NULL.
         *
         * <p>A conjunct left with no symbol references (e.g. a {@code FALSE} literal) is replaced by
         * {@code TRUE}, since it cannot be tied to the nullable symbols. The final rewrite is done by
         * {@code ExpressionUtils.expressionOrNullSymbols} — NOTE(review): presumably it ORs in
         * IS NULL checks for the scoped symbols; confirm against that helper.
         */
        @SafeVarargs
        private static Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, Collection<Symbol> outputSymbols, Predicate<Symbol>... nullSymbolScopes) {
            return conjuncts.stream()
                    .map(expression -> pullExpressionThroughSymbols(expression, outputSymbols))
                    .map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? BooleanLiteral.TRUE_LITERAL : expression)
                    .map(ExpressionUtils.expressionOrNullSymbols(nullSymbolScopes))
                    .collect(ImmutableList.toImmutableList());
        }

        /** Semi joins only add a column, so the source-side predicate passes through unchanged. */
        @Override
        public Expression visitSemiJoin(SemiJoinNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        /**
         * Inner spatial joins keep both sides' predicates (restricted to the outputs). Left spatial
         * joins keep the left side and treat the right side as nullable. Other types are rejected.
         */
        @Override
        public Expression visitSpatialJoin(SpatialJoinNode node, Void context) {
            Expression leftPredicate = node.getLeft().accept(this, context);
            Expression rightPredicate = node.getRight().accept(this, context);
            List<Symbol> outputSymbols = node.getOutputSymbols();

            switch (node.getType()) {
                case INNER: {
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, outputSymbols))
                            .add(pullExpressionThroughSymbols(rightPredicate, outputSymbols))
                            .build());
                }
                case LEFT: {
                    Predicate<Symbol> rightScope = node.getRight().getOutputSymbols()::contains;
                    return ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                            .add(pullExpressionThroughSymbols(leftPredicate, outputSymbols))
                            .addAll(pullNullableConjunctsThroughOuterJoin(ExpressionUtils.extractConjuncts(rightPredicate), outputSymbols, rightScope))
                            .build());
                }
            }
            throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
        }

        /**
         * For a multi-source node (union/exchange), computes each source's predicate rewritten
         * onto the node's output symbols, and returns the conjuncts shared by all sources.
         *
         * <p>Conjuncts are compared by expression equality, so logically equivalent but
         * syntactically different conjuncts (e.g. commuted operands) are not recognized as common.
         *
         * @param mapping for a source index, the (output symbol, input symbol reference) pairs
         * @return the conjunction of common conjuncts; {@code TRUE} when the node has no sources
         */
        private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) {
            List<PlanNode> sources = node.getSources();
            if (sources.isEmpty()) {
                // Previously this fell through to iterator.next() and threw NoSuchElementException.
                return BooleanLiteral.TRUE_LITERAL;
            }

            List<Set<Expression>> sourceOutputConjuncts = new ArrayList<>(sources.size());
            for (int i = 0; i < sources.size(); i++) {
                Expression underlyingPredicate = sources.get(i).accept(this, null);
                List<Expression> equalities = mapping.apply(i).stream()
                        .filter(SYMBOL_MATCHES_EXPRESSION.negate())
                        .map(ENTRY_TO_EQUALITY)
                        .collect(ImmutableList.toImmutableList());
                Expression pulled = pullExpressionThroughSymbols(
                        ExpressionUtils.combineConjuncts(ImmutableList.<Expression>builder()
                                .addAll(equalities)
                                .add(underlyingPredicate)
                                .build()),
                        node.getOutputSymbols());
                sourceOutputConjuncts.add(ImmutableSet.copyOf(ExpressionUtils.extractConjuncts(pulled)));
            }

            Iterator<Set<Expression>> iterator = sourceOutputConjuncts.iterator();
            Set<Expression> commonConjuncts = iterator.next();
            while (iterator.hasNext()) {
                commonConjuncts = Sets.intersection(commonConjuncts, iterator.next());
            }
            return ExpressionUtils.combineConjuncts(commonConjuncts);
        }

        /**
         * Rewrites {@code expression} so that it only references {@code symbols}.
         *
         * <p>Deterministic non-inferrable conjuncts are rewritten via equality inference and
         * dropped when no rewrite within the scope exists (a {@code null} result).
         * Non-deterministic conjuncts are always dropped. The equalities that fall entirely within
         * the scope are then appended.
         */
        private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) {
            EqualityInference equalityInference = EqualityInference.createEqualityInference(expression);
            ImmutableList.Builder<Expression> effectiveConjuncts = ImmutableList.builder();
            for (Expression conjunct : EqualityInference.nonInferrableConjuncts(expression)) {
                if (!DeterminismEvaluator.isDeterministic(conjunct)) {
                    continue;
                }
                Expression rewritten = equalityInference.rewriteExpression(conjunct, Predicates.in(symbols));
                if (rewritten != null) {
                    effectiveConjuncts.add(rewritten);
                }
            }
            effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(Predicates.in(symbols)).getScopeEqualities());
            return ExpressionUtils.combineConjuncts(effectiveConjuncts.build());
        }
    }
}

