/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.ApplyNodeUtil;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Expression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Bundles iterative-optimizer rules that run a single {@link ExpressionRewriter} over the
 * expressions held by project, aggregation, filter, join, values and apply nodes.
 *
 * <p>Every rule compares the rewritten expressions with the originals using {@code equals} and
 * returns {@link Rule.Result#empty()} when nothing changed, so a node is only replaced when the
 * rewriter actually produced a different expression.
 */
public class ExpressionRewriteRuleSet {
    private final ExpressionRewriter rewriter;

    /**
     * @param rewriter transformation applied by every rule created by this set
     * @throws NullPointerException if {@code rewriter} is null
     */
    public ExpressionRewriteRuleSet(ExpressionRewriter rewriter) {
        this.rewriter = Objects.requireNonNull(rewriter, "rewriter is null");
    }

    /**
     * Returns a fresh instance of every rule in this set.
     */
    public Set<Rule<?>> rules() {
        return ImmutableSet.of(
                projectExpressionRewrite(),
                aggregationExpressionRewrite(),
                filterExpressionRewrite(),
                joinExpressionRewrite(),
                valuesExpressionRewrite(),
                applyExpressionRewrite());
    }

    /** Returns a rule rewriting the assignments of {@link ProjectNode}s. */
    public Rule<?> projectExpressionRewrite() {
        return new ProjectExpressionRewrite(this.rewriter);
    }

    /** Returns a rule rewriting call arguments and filters of {@link AggregationNode} aggregations. */
    public Rule<?> aggregationExpressionRewrite() {
        return new AggregationExpressionRewrite(this.rewriter);
    }

    /** Returns a rule rewriting the predicate of {@link FilterNode}s. */
    public Rule<?> filterExpressionRewrite() {
        return new FilterExpressionRewrite(this.rewriter);
    }

    /** Returns a rule rewriting the optional filter of {@link JoinNode}s. */
    public Rule<?> joinExpressionRewrite() {
        return new JoinExpressionRewrite(this.rewriter);
    }

    /** Returns a rule rewriting every cell of {@link ValuesNode} rows. */
    public Rule<?> valuesExpressionRewrite() {
        return new ValuesExpressionRewrite(this.rewriter);
    }

    /** Returns a rule rewriting the subquery assignments of {@link ApplyNode}s. */
    public Rule<?> applyExpressionRewrite() {
        return new ApplyExpressionRewrite(this.rewriter);
    }

    /**
     * Applies {@code rewriter} to {@code rowExpression} when it wraps an AST {@link Expression}
     * (per {@link OriginalExpressionUtils#isExpression}); any other row expression is returned
     * unchanged instead of failing in {@code castToExpression}.
     *
     * @param rewriter transformation to apply
     * @param rowExpression expression to rewrite
     * @param context optimizer context forwarded to the rewriter
     * @return the rewritten expression re-wrapped as a {@link RowExpression}, or the input itself
     */
    private static RowExpression rewriteIfExpression(ExpressionRewriter rewriter, RowExpression rowExpression, Rule.Context context) {
        if (!OriginalExpressionUtils.isExpression(rowExpression)) {
            return rowExpression;
        }
        Expression rewritten = rewriter.rewrite(OriginalExpressionUtils.castToExpression(rowExpression), context);
        return OriginalExpressionUtils.castToRowExpression(rewritten);
    }

    /**
     * Rewrites the subquery assignments of an {@link ApplyNode}; the rewritten assignments are
     * re-validated with {@link ApplyNodeUtil#verifySubquerySupported} before the node is replaced.
     */
    private static final class ApplyExpressionRewrite
    implements Rule<ApplyNode> {
        private final ExpressionRewriter rewriter;

        ApplyExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<ApplyNode> getPattern() {
            return Patterns.applyNode();
        }

        @Override
        public Rule.Result apply(ApplyNode applyNode, Captures captures, Rule.Context context) {
            Assignments subqueryAssignments = AssignmentUtils.rewrite(
                    applyNode.getSubqueryAssignments(),
                    expression -> rewriter.rewrite(expression, context));
            if (applyNode.getSubqueryAssignments().equals(subqueryAssignments)) {
                return Rule.Result.empty();
            }
            // A rewrite may produce a subquery shape that ApplyNode does not support; fail early.
            ApplyNodeUtil.verifySubquerySupported(subqueryAssignments);
            return Rule.Result.ofPlanNode(new ApplyNode(
                    applyNode.getSourceLocation(),
                    applyNode.getId(),
                    applyNode.getInput(),
                    applyNode.getSubquery(),
                    subqueryAssignments,
                    applyNode.getCorrelation(),
                    applyNode.getOriginSubqueryError()));
        }
    }

    /**
     * Rewrites every cell of every row of a {@link ValuesNode}. Cells that are not wrapped AST
     * expressions are kept as they are.
     */
    private static final class ValuesExpressionRewrite
    implements Rule<ValuesNode> {
        private final ExpressionRewriter rewriter;

        ValuesExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<ValuesNode> getPattern() {
            return Patterns.values();
        }

        @Override
        public Rule.Result apply(ValuesNode valuesNode, Captures captures, Rule.Context context) {
            ImmutableList.Builder<List<RowExpression>> rows = ImmutableList.builder();
            for (List<RowExpression> row : valuesNode.getRows()) {
                rows.add(row.stream()
                        .map(cell -> rewriteIfExpression(rewriter, cell, context))
                        .collect(ImmutableList.toImmutableList()));
            }
            List<List<RowExpression>> rewrittenRows = rows.build();
            // List equality is element-wise, so this is true only if every cell is unchanged.
            if (rewrittenRows.equals(valuesNode.getRows())) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ValuesNode(
                    valuesNode.getSourceLocation(),
                    valuesNode.getId(),
                    valuesNode.getOutputVariables(),
                    rewrittenRows));
        }
    }

    /**
     * Rewrites the optional filter of a {@link JoinNode}; a filter that is not a wrapped AST
     * expression is kept as it is.
     */
    private static final class JoinExpressionRewrite
    implements Rule<JoinNode> {
        private final ExpressionRewriter rewriter;

        JoinExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return Patterns.join();
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            Optional<RowExpression> filter = joinNode.getFilter()
                    .map(expression -> rewriteIfExpression(rewriter, expression, context));
            if (filter.equals(joinNode.getFilter())) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new JoinNode(
                    joinNode.getSourceLocation(),
                    joinNode.getId(),
                    joinNode.getType(),
                    joinNode.getLeft(),
                    joinNode.getRight(),
                    joinNode.getCriteria(),
                    joinNode.getOutputVariables(),
                    filter,
                    joinNode.getLeftHashVariable(),
                    joinNode.getRightHashVariable(),
                    joinNode.getDistributionType(),
                    joinNode.getDynamicFilters()));
        }
    }

    /**
     * Rewrites the predicate of a {@link FilterNode}; a predicate that is not a wrapped AST
     * expression is kept as it is.
     */
    private static final class FilterExpressionRewrite
    implements Rule<FilterNode> {
        private final ExpressionRewriter rewriter;

        FilterExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return Patterns.filter();
        }

        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            RowExpression rewritten = rewriteIfExpression(rewriter, filterNode.getPredicate(), context);
            if (filterNode.getPredicate().equals(rewritten)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new FilterNode(
                    filterNode.getSourceLocation(),
                    filterNode.getId(),
                    filterNode.getSource(),
                    rewritten));
        }
    }

    /**
     * Rewrites the call arguments and the optional filter of every aggregation of an
     * {@link AggregationNode}. Ordering, distinctness and mask are carried over untouched.
     */
    private static final class AggregationExpressionRewrite
    implements Rule<AggregationNode> {
        private final ExpressionRewriter rewriter;

        AggregationExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<AggregationNode> getPattern() {
            return Patterns.aggregation();
        }

        @Override
        public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
            return rewriteAggregations(aggregationNode.getAggregations(), context)
                    .map(aggregations -> Rule.Result.ofPlanNode(new AggregationNode(
                            aggregationNode.getSourceLocation(),
                            aggregationNode.getId(),
                            aggregationNode.getSource(),
                            aggregations,
                            aggregationNode.getGroupingSets(),
                            aggregationNode.getPreGroupedVariables(),
                            aggregationNode.getStep(),
                            aggregationNode.getHashVariable(),
                            aggregationNode.getGroupIdVariable())))
                    .orElse(Rule.Result.empty());
        }

        /**
         * Rewrites every aggregation in {@code aggregations}, keeping the original key order.
         * Generic in the key type so the node's own key type flows through unchanged.
         *
         * @return the rewritten map, or empty when no aggregation changed
         */
        private <K> Optional<Map<K, AggregationNode.Aggregation>> rewriteAggregations(
                Map<K, AggregationNode.Aggregation> aggregations,
                Rule.Context context) {
            ImmutableMap.Builder<K, AggregationNode.Aggregation> rewritten = ImmutableMap.builder();
            boolean anyRewritten = false;
            for (Map.Entry<K, AggregationNode.Aggregation> entry : aggregations.entrySet()) {
                AggregationNode.Aggregation original = entry.getValue();
                AggregationNode.Aggregation updated = rewriteAggregation(original, context);
                rewritten.put(entry.getKey(), updated);
                if (!original.equals(updated)) {
                    anyRewritten = true;
                }
            }
            if (!anyRewritten) {
                return Optional.empty();
            }
            return Optional.of(rewritten.build());
        }

        /** Builds a copy of {@code aggregation} whose call arguments and filter are rewritten. */
        private AggregationNode.Aggregation rewriteAggregation(AggregationNode.Aggregation aggregation, Rule.Context context) {
            CallExpression call = aggregation.getCall();
            List<RowExpression> arguments = call.getArguments().stream()
                    .map(argument -> rewriteIfExpression(rewriter, argument, context))
                    .collect(ImmutableList.toImmutableList());
            return new AggregationNode.Aggregation(
                    new CallExpression(
                            call.getSourceLocation(),
                            call.getDisplayName(),
                            call.getFunctionHandle(),
                            call.getType(),
                            arguments),
                    aggregation.getFilter().map(filter -> rewriteIfExpression(rewriter, filter, context)),
                    aggregation.getOrderBy(),
                    aggregation.isDistinct(),
                    aggregation.getMask());
        }
    }

    /** Rewrites the assignment expressions of a {@link ProjectNode}, preserving its locality. */
    private static final class ProjectExpressionRewrite
    implements Rule<ProjectNode> {
        private final ExpressionRewriter rewriter;

        ProjectExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<ProjectNode> getPattern() {
            return Patterns.project();
        }

        @Override
        public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
            Assignments assignments = AssignmentUtils.rewrite(
                    projectNode.getAssignments(),
                    expression -> rewriter.rewrite(expression, context));
            if (projectNode.getAssignments().equals(assignments)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ProjectNode(
                    projectNode.getSourceLocation(),
                    projectNode.getId(),
                    projectNode.getSource(),
                    assignments,
                    projectNode.getLocality()));
        }
    }

    /**
     * Transformation applied by every rule of {@link ExpressionRewriteRuleSet}.
     */
    @FunctionalInterface
    public interface ExpressionRewriter {
        /**
         * @param expression AST expression to transform
         * @param context optimizer context of the rule invocation
         * @return the transformed expression; returning an {@code equals} expression signals
         *         "no change" to the calling rule
         */
        Expression rewrite(Expression expression, Rule.Context context);
    }
}

