/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.prestosql.sql.ExpressionUtils;
import io.prestosql.sql.planner.DeterminismEvaluator;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.LogicalBinaryExpression;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class ExtractCommonPredicatesExpressionRewriter {
    public static Expression extractCommonPredicates(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new Visitor(), (Expression)expression, (Object)((Object)NodeContext.ROOT_NODE));
    }

    private ExtractCommonPredicatesExpressionRewriter() {
    }

    private static enum NodeContext {
        ROOT_NODE,
        NOT_ROOT_NODE;


        boolean isRootNode() {
            return this == ROOT_NODE;
        }
    }

    private static class Visitor
    extends ExpressionRewriter<NodeContext> {
        private Visitor() {
        }

        public Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, (Object)NodeContext.NOT_ROOT_NODE);
            }
            return null;
        }

        public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            Expression expression = ExpressionUtils.combinePredicates(node.getOperator(), (Collection)ExpressionUtils.extractPredicates(node.getOperator(), (Expression)node).stream().map(subExpression -> treeRewriter.rewrite(subExpression, (Object)NodeContext.NOT_ROOT_NODE)).collect(ImmutableList.toImmutableList()));
            if (!(expression instanceof LogicalBinaryExpression)) {
                return expression;
            }
            Expression simplified = Visitor.extractCommonPredicates((LogicalBinaryExpression)expression);
            if (context.isRootNode() && simplified instanceof LogicalBinaryExpression && ((LogicalBinaryExpression)simplified).getOperator() == LogicalBinaryExpression.Operator.OR) {
                return Visitor.distributeIfPossible((LogicalBinaryExpression)simplified);
            }
            return simplified;
        }

        private static Expression extractCommonPredicates(LogicalBinaryExpression node) {
            List<List<Expression>> subPredicates = Visitor.getSubPredicates(node);
            ImmutableSet commonPredicates = ImmutableSet.copyOf((Collection)subPredicates.stream().map(Visitor::filterDeterministicPredicates).reduce(Sets::intersection).orElse(Collections.emptySet()));
            List uncorrelatedSubPredicates = (List)subPredicates.stream().map(arg_0 -> Visitor.lambda$extractCommonPredicates$1((Set)commonPredicates, arg_0)).collect(ImmutableList.toImmutableList());
            LogicalBinaryExpression.Operator flippedOperator = node.getOperator().flip();
            List uncorrelatedPredicates = (List)uncorrelatedSubPredicates.stream().map(predicate -> ExpressionUtils.combinePredicates(flippedOperator, predicate)).collect(ImmutableList.toImmutableList());
            Expression combinedUncorrelatedPredicates = ExpressionUtils.combinePredicates(node.getOperator(), uncorrelatedPredicates);
            return ExpressionUtils.combinePredicates(flippedOperator, (Collection<Expression>)ImmutableList.builder().addAll((Iterable)commonPredicates).add((Object)combinedUncorrelatedPredicates).build());
        }

        private static List<List<Expression>> getSubPredicates(LogicalBinaryExpression expression) {
            return (List)ExpressionUtils.extractPredicates(expression.getOperator(), (Expression)expression).stream().map(predicate -> predicate instanceof LogicalBinaryExpression ? ExpressionUtils.extractPredicates((LogicalBinaryExpression)predicate) : ImmutableList.of((Object)predicate)).collect(ImmutableList.toImmutableList());
        }

        private static Expression distributeIfPossible(LogicalBinaryExpression expression) {
            int newBaseExpressions;
            if (!DeterminismEvaluator.isDeterministic((Expression)expression)) {
                return expression;
            }
            List subPredicates = Visitor.getSubPredicates(expression).stream().map(ImmutableSet::copyOf).collect(Collectors.toList());
            int originalBaseExpressions = subPredicates.stream().mapToInt(Set::size).sum();
            try {
                newBaseExpressions = Math.multiplyExact(subPredicates.stream().mapToInt(Set::size).reduce(Math::multiplyExact).getAsInt(), subPredicates.size());
            }
            catch (ArithmeticException e) {
                return expression;
            }
            if (newBaseExpressions > originalBaseExpressions * 2) {
                return expression;
            }
            Set crossProduct = Sets.cartesianProduct(subPredicates);
            return ExpressionUtils.combinePredicates(expression.getOperator().flip(), (Collection)crossProduct.stream().map(expressions -> ExpressionUtils.combinePredicates(expression.getOperator(), expressions)).collect(ImmutableList.toImmutableList()));
        }

        private static Set<Expression> filterDeterministicPredicates(List<Expression> predicates) {
            return predicates.stream().filter(DeterminismEvaluator::isDeterministic).collect(Collectors.toSet());
        }

        private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove) {
            return (List)collection.stream().filter(element -> !elementsToRemove.contains(element)).collect(ImmutableList.toImmutableList());
        }

        private static /* synthetic */ List lambda$extractCommonPredicates$1(Set commonPredicates, List predicateList) {
            return Visitor.removeAll(predicateList, commonPredicates);
        }
    }
}

