/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.RowExpressionRewriteRuleSet;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

public class RewriteCaseExpressionPredicate
extends RowExpressionRewriteRuleSet {
    public RewriteCaseExpressionPredicate(FunctionAndTypeManager functionAndTypeManager) {
        super(new Rewriter(functionAndTypeManager));
    }

    @Override
    public boolean isRewriterEnabled(Session session) {
        return SystemSessionProperties.isOptimizeCaseExpressionPredicate(session);
    }

    @Override
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(this.filterRowExpressionRewriteRule(), this.joinRowExpressionRewriteRule());
    }

    private static class CaseExpressionPredicateRewriter
    extends RowExpressionRewriter<Void> {
        private final FunctionResolution functionResolution;
        private final LogicalRowExpressions logicalRowExpressions;

        private CaseExpressionPredicateRewriter(FunctionAndTypeManager functionAndTypeManager) {
            this.functionResolution = new FunctionResolution(functionAndTypeManager);
            this.logicalRowExpressions = new LogicalRowExpressions((DeterminismEvaluator)new RowExpressionDeterminismEvaluator(functionAndTypeManager), (StandardFunctionResolution)this.functionResolution, (FunctionMetadataManager)functionAndTypeManager);
        }

        public RowExpression rewriteCall(CallExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter) {
            if (this.functionResolution.isComparisonFunction(node.getFunctionHandle()) && node.getArguments().size() == 2) {
                RowExpression left = (RowExpression)node.getArguments().get(0);
                RowExpression right = (RowExpression)node.getArguments().get(1);
                if (this.isCaseExpression(left) && this.isSimpleExpression(right)) {
                    return this.processCaseExpression(left, expression -> LogicalRowExpressions.replaceArguments((CallExpression)node, (RowExpression[])new RowExpression[]{expression, right}), right);
                }
                if (this.isCaseExpression(right) && this.isSimpleExpression(left)) {
                    return this.processCaseExpression(right, expression -> LogicalRowExpressions.replaceArguments((CallExpression)node, (RowExpression[])new RowExpression[]{left, expression}), left);
                }
            }
            return null;
        }

        private boolean isCaseExpression(RowExpression expression) {
            if (this.logicalRowExpressions.isCastExpression(expression)) {
                expression = (RowExpression)((CallExpression)expression).getArguments().get(0);
            }
            return expression instanceof SpecialFormExpression && ((SpecialFormExpression)expression).getForm().equals((Object)SpecialFormExpression.Form.SWITCH);
        }

        private boolean isSimpleExpression(RowExpression expression) {
            if (this.logicalRowExpressions.isCastExpression(expression)) {
                return this.isSimpleExpression((RowExpression)((CallExpression)expression).getArguments().get(0));
            }
            return expression instanceof ConstantExpression || expression instanceof VariableReferenceExpression || expression instanceof InputReferenceExpression;
        }

        private RowExpression processCaseExpression(RowExpression expression, Function<RowExpression, RowExpression> comparisonExpressionGenerator, RowExpression value) {
            if (expression instanceof SpecialFormExpression) {
                Preconditions.checkArgument((boolean)this.logicalRowExpressions.isCaseExpression(expression), (Object)"expression must be a CASE expression");
                return this.processCaseExpression((SpecialFormExpression)expression, Optional.empty(), comparisonExpressionGenerator, value);
            }
            Preconditions.checkArgument((boolean)this.logicalRowExpressions.isCastExpression(expression), (Object)"expression must be a CAST expression");
            Preconditions.checkArgument((boolean)this.logicalRowExpressions.isCaseExpression((RowExpression)((CallExpression)expression).getArguments().get(0)), (Object)"expression argument must be a CASE expression");
            return this.processCaseExpression((SpecialFormExpression)((CallExpression)expression).getArguments().get(0), Optional.of((CallExpression)expression), comparisonExpressionGenerator, value);
        }

        private RowExpression processCaseExpression(SpecialFormExpression caseExpression, Optional<CallExpression> castExpression, Function<RowExpression, RowExpression> comparisonExpressionGenerator, RowExpression value) {
            List<RowExpression> whenClauses;
            Optional<RowExpression> caseOperand = this.getCaseOperand((RowExpression)caseExpression.getArguments().get(0));
            Optional<RowExpression> elseResult = Optional.empty();
            int argumentsSize = caseExpression.getArguments().size();
            RowExpression last = (RowExpression)caseExpression.getArguments().get(argumentsSize - 1);
            if (last instanceof SpecialFormExpression && ((SpecialFormExpression)last).getForm().equals((Object)SpecialFormExpression.Form.WHEN)) {
                whenClauses = caseExpression.getArguments().subList(1, argumentsSize);
            } else {
                whenClauses = caseExpression.getArguments().subList(1, argumentsSize - 1);
                elseResult = Optional.of(last);
            }
            if (caseOperand.isPresent() ? !this.canRewriteSimpleCaseExpression(whenClauses) : !this.canRewriteSearchedCaseExpression(whenClauses)) {
                return null;
            }
            ImmutableList.Builder andExpressions = new ImmutableList.Builder();
            ImmutableList.Builder invertedOperands = new ImmutableList.Builder();
            for (RowExpression whenClause : whenClauses) {
                RowExpression whenOperand = (RowExpression)((SpecialFormExpression)whenClause).getArguments().get(0);
                if (caseOperand.isPresent()) {
                    whenOperand = this.logicalRowExpressions.equalsCallExpression(caseOperand.get(), whenOperand);
                }
                RowExpression whenResult = (RowExpression)((SpecialFormExpression)whenClause).getArguments().get(1);
                if (castExpression.isPresent()) {
                    whenResult = LogicalRowExpressions.replaceArguments((CallExpression)castExpression.get(), (RowExpression[])new RowExpression[]{whenResult});
                }
                RowExpression comparisonExpression = comparisonExpressionGenerator.apply(whenResult);
                andExpressions.add((Object)LogicalRowExpressions.and((RowExpression[])new RowExpression[]{comparisonExpression, whenOperand}));
                invertedOperands.add((Object)this.logicalRowExpressions.notCallExpression(whenOperand));
            }
            RowExpression elseCondition = LogicalRowExpressions.and((RowExpression[])new RowExpression[]{this.getElseExpression(castExpression, value, elseResult, comparisonExpressionGenerator), LogicalRowExpressions.and((Collection)invertedOperands.build())});
            andExpressions.add((Object)elseCondition);
            return LogicalRowExpressions.or((Collection)andExpressions.build());
        }

        private RowExpression getElseExpression(Optional<CallExpression> castExpression, RowExpression value, Optional<RowExpression> elseValue, Function<RowExpression, RowExpression> comparisonExpressionGenerator) {
            return elseValue.map(elseVal -> (RowExpression)comparisonExpressionGenerator.apply(castExpression.map(castExp -> LogicalRowExpressions.replaceArguments((CallExpression)castExp, (RowExpression[])new RowExpression[]{elseVal})).orElse((RowExpression)elseVal))).orElse((RowExpression)new SpecialFormExpression(SpecialFormExpression.Form.IS_NULL, (Type)BooleanType.BOOLEAN, new RowExpression[]{value}));
        }

        private Optional<RowExpression> getCaseOperand(RowExpression expression) {
            boolean searchedCase = expression instanceof ConstantExpression && expression.getType() == BooleanType.BOOLEAN && ((ConstantExpression)expression).getValue() == Boolean.TRUE;
            return searchedCase ? Optional.empty() : Optional.of(expression);
        }

        private boolean canRewriteSimpleCaseExpression(List<RowExpression> whenClauses) {
            List<RowExpression> whenOperands = whenClauses.stream().map(x -> (RowExpression)((SpecialFormExpression)x).getArguments().get(0)).collect(Collectors.toList());
            return this.allExpressionsAreConstantAndUnique(whenOperands);
        }

        private boolean canRewriteSearchedCaseExpression(List<RowExpression> whenClauses) {
            if (!this.allAreEqualsExpression(whenClauses) || !this.allLHSOperandsAreUnique(whenClauses)) {
                return false;
            }
            List<RowExpression> rhsExpressions = whenClauses.stream().map(whenClause -> (RowExpression)((SpecialFormExpression)whenClause).getArguments().get(0)).map(whenOperand -> (RowExpression)((CallExpression)whenOperand).getArguments().get(1)).collect(Collectors.toList());
            return this.allExpressionsAreConstantAndUnique(rhsExpressions);
        }

        private boolean allLHSOperandsAreUnique(List<RowExpression> whenClauses) {
            return whenClauses.stream().map(whenClause -> (RowExpression)((SpecialFormExpression)whenClause).getArguments().get(0)).map(whenOperand -> (RowExpression)((CallExpression)whenOperand).getArguments().get(0)).distinct().count() == 1L;
        }

        private boolean allAreEqualsExpression(List<RowExpression> whenClauses) {
            return whenClauses.stream().map(whenClause -> (RowExpression)((SpecialFormExpression)whenClause).getArguments().get(0)).allMatch(arg_0 -> ((LogicalRowExpressions)this.logicalRowExpressions).isEqualsExpression(arg_0));
        }

        private boolean allExpressionsAreConstantAndUnique(List<RowExpression> expressions) {
            HashSet<RowExpression> expressionSet = new HashSet<RowExpression>();
            for (RowExpression expression : expressions) {
                if (!this.isConstantExpression(expression) || expressionSet.contains(expression)) {
                    return false;
                }
                expressionSet.add(expression);
            }
            return true;
        }

        private boolean isConstantExpression(RowExpression expression) {
            if (this.logicalRowExpressions.isCastExpression(expression)) {
                return this.isConstantExpression((RowExpression)((CallExpression)expression).getArguments().get(0));
            }
            return expression instanceof ConstantExpression;
        }
    }

    private static class Rewriter
    implements RowExpressionRewriteRuleSet.PlanRowExpressionRewriter {
        private final CaseExpressionPredicateRewriter caseExpressionPredicateRewriter;

        public Rewriter(FunctionAndTypeManager functionAndTypeManager) {
            Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.caseExpressionPredicateRewriter = new CaseExpressionPredicateRewriter(functionAndTypeManager);
        }

        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            return RowExpressionTreeRewriter.rewriteWith((RowExpressionRewriter)this.caseExpressionPredicateRewriter, (RowExpression)expression);
        }
    }
}

