/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Logical;
import io.trino.sql.planner.DeterminismEvaluator;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rewrites boolean expressions by pulling out predicates that are shared by every
 * term of a logical expression.
 *
 * <p>For example (assuming the {@link IrUtils} helpers flatten and join terms the way
 * their names suggest), {@code (A AND B) OR (A AND C)} becomes {@code A AND (B OR C)}.
 * Only predicates accepted by {@link DeterminismEvaluator#isDeterministic} are
 * candidates for extraction.
 */
public final class ExtractCommonPredicatesExpressionRewriter {
    /**
     * Rewrites {@code expression}, starting the traversal in the root-node context.
     *
     * @param expression the expression to rewrite
     * @return the result produced by {@link ExpressionTreeRewriter#rewriteWith}
     */
    public static Expression extractCommonPredicates(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(), expression, NodeContext.ROOT_NODE);
    }

    private ExtractCommonPredicatesExpressionRewriter() {
    }

    /**
     * Rewriter that tracks whether the node being visited is the root of the expression tree.
     */
    private static final class Visitor
            extends ExpressionRewriter<NodeContext> {
        private Visitor() {
        }

        /**
         * Handles every node type that has no dedicated override in this visitor.
         *
         * <p>For the root node the traversal is restarted with {@link NodeContext#NOT_ROOT_NODE},
         * so anything reached through a non-logical root is no longer treated as the root.
         */
        @Override
        protected Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
            }
            // NOTE(review): null presumably tells ExpressionTreeRewriter to fall back to its
            // default handling of this node - confirm against ExpressionTreeRewriter.
            return null;
        }

        /**
         * Rewrites the terms of a logical expression, extracts the predicates common to all
         * of them, and, only when {@code node} is the root and the result is an OR,
         * attempts to distribute it via {@link #distributeIfPossible}.
         */
        @Override
        public Expression rewriteLogical(Logical node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            // Rewrite each term first; terms are never treated as the root.
            List<Expression> rewrittenTerms = IrUtils.extractPredicates(node.operator(), node).stream()
                    .map(term -> treeRewriter.rewrite(term, NodeContext.NOT_ROOT_NODE))
                    .collect(ImmutableList.toImmutableList());
            Expression expression = IrUtils.combinePredicates(node.operator(), rewrittenTerms);

            // Recombining may yield something that is no longer a logical expression.
            if (!(expression instanceof Logical)) {
                return expression;
            }

            Expression simplified = extractCommonPredicates((Logical) expression);

            if (context.isRootNode() && simplified instanceof Logical) {
                Logical logical = (Logical) simplified;
                if (logical.operator() == Logical.Operator.OR) {
                    return distributeIfPossible(logical);
                }
            }
            return simplified;
        }

        /**
         * Factors out the deterministic predicates that occur in every term of {@code node}.
         *
         * <p>The result joins, with the flipped operator, the common predicates and the
         * recombination of whatever remains of each term.
         */
        private Expression extractCommonPredicates(Logical node) {
            List<List<Expression>> subPredicates = getSubPredicates(node);

            // Intersection of the deterministic predicates of every term; empty when there are no terms.
            Set<Expression> commonPredicates = ImmutableSet.copyOf(subPredicates.stream()
                    .map(this::filterDeterministicPredicates)
                    .reduce(Sets::intersection)
                    .orElse(Collections.emptySet()));

            // What is left of each term once the common predicates are removed.
            List<List<Expression>> uncorrelatedSubPredicates = subPredicates.stream()
                    .map(predicateList -> removeAll(predicateList, commonPredicates))
                    .collect(ImmutableList.toImmutableList());

            Logical.Operator flippedOperator = node.operator().flip();

            List<Expression> uncorrelatedPredicates = uncorrelatedSubPredicates.stream()
                    .map(remaining -> IrUtils.combinePredicates(flippedOperator, remaining))
                    .collect(ImmutableList.toImmutableList());
            Expression combinedUncorrelatedPredicates = IrUtils.combinePredicates(node.operator(), uncorrelatedPredicates);

            return IrUtils.combinePredicates(flippedOperator, ImmutableList.<Expression>builder()
                    .addAll(commonPredicates)
                    .add(combinedUncorrelatedPredicates)
                    .build());
        }

        /**
         * Splits {@code expression} into its terms and each term into its own predicates.
         *
         * @return one list per term of {@code expression}, in term order
         */
        private static List<List<Expression>> getSubPredicates(Logical expression) {
            return IrUtils.extractPredicates(expression.operator(), expression).stream()
                    .map(Visitor::flattenTerm)
                    .collect(ImmutableList.toImmutableList());
        }

        /**
         * Returns the predicates of a single term: a logical term is expanded through
         * {@link IrUtils#extractPredicates}, anything else is wrapped in a one-element list.
         */
        private static List<Expression> flattenTerm(Expression term) {
            if (term instanceof Logical) {
                return IrUtils.extractPredicates((Logical) term);
            }
            return ImmutableList.of(term);
        }

        /**
         * Distributes {@code expression} over its terms, e.g. {@code (A AND B) OR (C AND D)}
         * into {@code (A OR C) AND (A OR D) AND (B OR C) AND (B OR D)}, unless the expression
         * is non-deterministic or the result would contain more than twice as many base
         * expressions as the input.
         *
         * @return the distributed expression, or {@code expression} itself when distribution is skipped
         */
        private Expression distributeIfPossible(Logical expression) {
            if (!DeterminismEvaluator.isDeterministic(expression)) {
                return expression;
            }

            List<Set<Expression>> subPredicates = getSubPredicates(expression).stream()
                    .map(ImmutableSet::copyOf)
                    .collect(Collectors.toList());

            // Summed as long so that neither the sum nor the doubling below can wrap around.
            long originalBaseExpressions = subPredicates.stream()
                    .mapToLong(Set::size)
                    .sum();

            // Size of the cross product times the number of terms in each generated clause.
            int newBaseExpressions;
            try {
                newBaseExpressions = Math.multiplyExact(
                        subPredicates.stream()
                                .mapToInt(Set::size)
                                .reduce(Math::multiplyExact)
                                .getAsInt(),
                        subPredicates.size());
            }
            catch (ArithmeticException e) {
                // Overflow means the distributed form would be far too large; keep the input as is.
                return expression;
            }

            if (newBaseExpressions > originalBaseExpressions * 2) {
                // Guard against blow-up of the expression size.
                return expression;
            }

            Set<List<Expression>> crossProduct = Sets.cartesianProduct(subPredicates);
            List<Expression> clauses = crossProduct.stream()
                    .map(expressions -> IrUtils.combinePredicates(expression.operator(), expressions))
                    .collect(ImmutableList.toImmutableList());
            return IrUtils.combinePredicates(expression.operator().flip(), clauses);
        }

        /**
         * Returns the distinct predicates of {@code predicates} that are deterministic.
         */
        private Set<Expression> filterDeterministicPredicates(List<Expression> predicates) {
            return predicates.stream()
                    .filter(DeterminismEvaluator::isDeterministic)
                    .collect(Collectors.toSet());
        }

        /**
         * Returns the elements of {@code collection}, in encounter order, that are not
         * contained in {@code elementsToRemove}.
         */
        private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove) {
            return collection.stream()
                    .filter(element -> !elementsToRemove.contains(element))
                    .collect(ImmutableList.toImmutableList());
        }
    }

    /**
     * Marks whether the node currently being rewritten is the root of the expression tree.
     */
    private enum NodeContext {
        ROOT_NODE,
        NOT_ROOT_NODE;

        boolean isRootNode() {
            return this == ROOT_NODE;
        }
    }
}

