/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.OrderBy;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.SortItem;
import io.trino.sql.tree.SymbolReference;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

public class ExpressionRewriteRuleSet {
    private final ExpressionRewriter rewriter;

    public ExpressionRewriteRuleSet(ExpressionRewriter rewriter) {
        this.rewriter = Objects.requireNonNull(rewriter, "rewriter is null");
    }

    public Set<Rule<?>> rules() {
        return ImmutableSet.of(this.projectExpressionRewrite(), this.aggregationExpressionRewrite(), this.filterExpressionRewrite(), this.joinExpressionRewrite(), this.valuesExpressionRewrite(), this.patternRecognitionExpressionRewrite(), (Object[])new Rule[0]);
    }

    public Rule<?> projectExpressionRewrite() {
        return new ProjectExpressionRewrite(this.rewriter);
    }

    public Rule<?> aggregationExpressionRewrite() {
        return new AggregationExpressionRewrite(this.rewriter);
    }

    public Rule<?> filterExpressionRewrite() {
        return new FilterExpressionRewrite(this.rewriter);
    }

    public Rule<?> joinExpressionRewrite() {
        return new JoinExpressionRewrite(this.rewriter);
    }

    public Rule<?> valuesExpressionRewrite() {
        return new ValuesExpressionRewrite(this.rewriter);
    }

    public Rule<?> patternRecognitionExpressionRewrite() {
        return new PatternRecognitionExpressionRewrite(this.rewriter);
    }

    public static interface ExpressionRewriter {
        public Expression rewrite(Expression var1, Rule.Context var2);
    }

    private static final class ProjectExpressionRewrite
    implements Rule<ProjectNode> {
        private final ExpressionRewriter rewriter;

        ProjectExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<ProjectNode> getPattern() {
            return Patterns.project();
        }

        @Override
        public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
            Assignments assignments = projectNode.getAssignments().rewrite(x -> this.rewriter.rewrite((Expression)x, context));
            if (projectNode.getAssignments().equals(assignments)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new ProjectNode(projectNode.getId(), projectNode.getSource(), assignments));
        }

        public String toString() {
            return String.format("%s(%s)", this.getClass().getSimpleName(), this.rewriter);
        }
    }

    private static final class AggregationExpressionRewrite
    implements Rule<AggregationNode> {
        private final ExpressionRewriter rewriter;

        AggregationExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<AggregationNode> getPattern() {
            return Patterns.aggregation();
        }

        @Override
        public Rule.Result apply(AggregationNode aggregationNode, Captures captures, Rule.Context context) {
            boolean anyRewritten = false;
            ImmutableMap.Builder aggregations = ImmutableMap.builder();
            for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
                AggregationNode.Aggregation aggregation = entry.getValue();
                FunctionCall call = (FunctionCall)this.rewriter.rewrite((Expression)new FunctionCall(Optional.empty(), QualifiedName.of((String)aggregation.getResolvedFunction().getSignature().getName()), Optional.empty(), aggregation.getFilter().map(symbol -> new SymbolReference(symbol.getName())), aggregation.getOrderingScheme().map(orderBy -> new OrderBy((List)orderBy.getOrderBy().stream().map(symbol -> new SortItem((Expression)new SymbolReference(symbol.getName()), orderBy.getOrdering((Symbol)symbol).isAscending() ? SortItem.Ordering.ASCENDING : SortItem.Ordering.DESCENDING, orderBy.getOrdering((Symbol)symbol).isNullsFirst() ? SortItem.NullOrdering.FIRST : SortItem.NullOrdering.LAST)).collect(ImmutableList.toImmutableList()))), aggregation.isDistinct(), Optional.empty(), Optional.empty(), aggregation.getArguments()), context);
                Verify.verify((boolean)QualifiedName.of((String)ResolvedFunction.extractFunctionName(call.getName())).equals((Object)QualifiedName.of((String)aggregation.getResolvedFunction().getSignature().getName())), (String)"Aggregation function name changed", (Object[])new Object[0]);
                AggregationNode.Aggregation newAggregation = new AggregationNode.Aggregation(aggregation.getResolvedFunction(), call.getArguments(), call.isDistinct(), call.getFilter().map(Symbol::from), call.getOrderBy().map(OrderingScheme::fromOrderBy), aggregation.getMask());
                aggregations.put((Object)entry.getKey(), (Object)newAggregation);
                if (aggregation.equals(newAggregation)) continue;
                anyRewritten = true;
            }
            if (anyRewritten) {
                return Rule.Result.ofPlanNode(AggregationNode.builderFrom(aggregationNode).setAggregations((Map<Symbol, AggregationNode.Aggregation>)aggregations.buildOrThrow()).build());
            }
            return Rule.Result.empty();
        }

        public String toString() {
            return String.format("%s(%s)", this.getClass().getSimpleName(), this.rewriter);
        }
    }

    private static final class FilterExpressionRewrite
    implements Rule<FilterNode> {
        private final ExpressionRewriter rewriter;

        FilterExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<FilterNode> getPattern() {
            return Patterns.filter();
        }

        @Override
        public Rule.Result apply(FilterNode filterNode, Captures captures, Rule.Context context) {
            Expression rewritten = this.rewriter.rewrite(filterNode.getPredicate(), context);
            if (filterNode.getPredicate().equals((Object)rewritten)) {
                return Rule.Result.empty();
            }
            return Rule.Result.ofPlanNode(new FilterNode(filterNode.getId(), filterNode.getSource(), rewritten));
        }

        public String toString() {
            return String.format("%s(%s)", this.getClass().getSimpleName(), this.rewriter);
        }
    }

    private static final class JoinExpressionRewrite
    implements Rule<JoinNode> {
        private final ExpressionRewriter rewriter;

        JoinExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return Patterns.join();
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            Optional<Expression> filter = joinNode.getFilter().map(x -> this.rewriter.rewrite((Expression)x, context));
            if (!joinNode.getFilter().equals(filter)) {
                return Rule.Result.ofPlanNode(new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), joinNode.isMaySkipOutputDuplicates(), filter, joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters(), joinNode.getReorderJoinStatsAndCost()));
            }
            return Rule.Result.empty();
        }

        public String toString() {
            return String.format("%s(%s)", this.getClass().getSimpleName(), this.rewriter);
        }
    }

    private static final class ValuesExpressionRewrite
    implements Rule<ValuesNode> {
        private final ExpressionRewriter rewriter;

        ValuesExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<ValuesNode> getPattern() {
            return Patterns.values();
        }

        @Override
        public Rule.Result apply(ValuesNode valuesNode, Captures captures, Rule.Context context) {
            if (valuesNode.getRows().isEmpty()) {
                return Rule.Result.empty();
            }
            boolean anyRewritten = false;
            ImmutableList.Builder rows = ImmutableList.builder();
            Iterator<Expression> iterator = valuesNode.getRows().get().iterator();
            while (iterator.hasNext()) {
                Expression row;
                Object rewritten = (row = iterator.next()) instanceof Row ? new Row((List)((Row)row).getItems().stream().map(item -> this.rewriter.rewrite((Expression)item, context)).collect(ImmutableList.toImmutableList())) : this.rewriter.rewrite(row, context);
                if (!row.equals(rewritten)) {
                    anyRewritten = true;
                }
                rows.add(rewritten);
            }
            if (anyRewritten) {
                return Rule.Result.ofPlanNode(new ValuesNode(valuesNode.getId(), valuesNode.getOutputSymbols(), (List<Expression>)rows.build()));
            }
            return Rule.Result.empty();
        }

        public String toString() {
            return String.format("%s(%s)", this.getClass().getSimpleName(), this.rewriter);
        }
    }

    private static final class PatternRecognitionExpressionRewrite
    implements Rule<PatternRecognitionNode> {
        private final ExpressionRewriter rewriter;

        PatternRecognitionExpressionRewrite(ExpressionRewriter rewriter) {
            this.rewriter = rewriter;
        }

        @Override
        public Pattern<PatternRecognitionNode> getPattern() {
            return Patterns.patternRecognition();
        }

        @Override
        public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
            boolean anyRewritten = false;
            ImmutableMap.Builder rewrittenMeasures = ImmutableMap.builder();
            for (Map.Entry<Symbol, PatternRecognitionNode.Measure> entry : node.getMeasures().entrySet()) {
                LogicalIndexExtractor.ExpressionAndValuePointers pointers = entry.getValue().getExpressionAndValuePointers();
                Optional<LogicalIndexExtractor.ExpressionAndValuePointers> newPointers = this.rewrite(pointers, context);
                if (newPointers.isPresent()) {
                    anyRewritten = true;
                    rewrittenMeasures.put((Object)entry.getKey(), (Object)new PatternRecognitionNode.Measure(newPointers.get(), entry.getValue().getType()));
                    continue;
                }
                rewrittenMeasures.put(entry);
            }
            ImmutableMap.Builder rewrittenDefinitions = ImmutableMap.builder();
            for (Map.Entry<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> entry : node.getVariableDefinitions().entrySet()) {
                LogicalIndexExtractor.ExpressionAndValuePointers pointers = entry.getValue();
                Optional<LogicalIndexExtractor.ExpressionAndValuePointers> newPointers = this.rewrite(pointers, context);
                if (newPointers.isPresent()) {
                    anyRewritten = true;
                    rewrittenDefinitions.put((Object)entry.getKey(), (Object)newPointers.get());
                    continue;
                }
                rewrittenDefinitions.put(entry);
            }
            if (anyRewritten) {
                return Rule.Result.ofPlanNode(new PatternRecognitionNode(node.getId(), node.getSource(), node.getSpecification(), node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix(), node.getWindowFunctions(), (Map<Symbol, PatternRecognitionNode.Measure>)rewrittenMeasures.buildOrThrow(), node.getCommonBaseFrame(), node.getRowsPerMatch(), node.getSkipToLabel(), node.getSkipToPosition(), node.isInitial(), node.getPattern(), node.getSubsets(), (Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers>)rewrittenDefinitions.buildOrThrow()));
            }
            return Rule.Result.empty();
        }

        private Optional<LogicalIndexExtractor.ExpressionAndValuePointers> rewrite(LogicalIndexExtractor.ExpressionAndValuePointers pointers, Rule.Context context) {
            boolean rewritten = false;
            ImmutableList newLayout = pointers.getLayout();
            ImmutableList newPointers = pointers.getValuePointers();
            Set newClassifierSymbols = pointers.getClassifierSymbols();
            Set newMatchNumberSymbols = pointers.getMatchNumberSymbols();
            Expression newExpression = this.rewriter.rewrite(pointers.getExpression(), context);
            if (!pointers.getExpression().equals((Object)newExpression)) {
                rewritten = true;
                Set<Symbol> newSymbols = SymbolsExtractor.extractUnique(newExpression);
                List<Symbol> layout = pointers.getLayout();
                ImmutableList.Builder newLayoutBuilder = ImmutableList.builder();
                ImmutableList.Builder newPointersBuilder = ImmutableList.builder();
                for (int i = 0; i < layout.size(); ++i) {
                    if (!newSymbols.contains(layout.get(i))) continue;
                    newLayoutBuilder.add((Object)layout.get(i));
                    newPointersBuilder.add((Object)pointers.getValuePointers().get(i));
                }
                newLayout = newLayoutBuilder.build();
                newPointers = newPointersBuilder.build();
                newClassifierSymbols = (Set)pointers.getClassifierSymbols().stream().filter(newSymbols::contains).collect(ImmutableSet.toImmutableSet());
                newMatchNumberSymbols = (Set)pointers.getMatchNumberSymbols().stream().filter(newSymbols::contains).collect(ImmutableSet.toImmutableSet());
            }
            ImmutableList.Builder newPointersBuilder = ImmutableList.builder();
            for (ValuePointer pointer : newPointers) {
                if (pointer instanceof ScalarValuePointer) {
                    newPointersBuilder.add((Object)pointer);
                    continue;
                }
                AggregationValuePointer aggregationPointer = (AggregationValuePointer)pointer;
                ImmutableList.Builder newArguments = ImmutableList.builder();
                for (Expression argument : aggregationPointer.getArguments()) {
                    Expression newArgument = this.rewriter.rewrite(argument, context);
                    if (!newArgument.equals((Object)argument)) {
                        rewritten = true;
                    }
                    newArguments.add((Object)newArgument);
                }
                newPointersBuilder.add((Object)new AggregationValuePointer(aggregationPointer.getFunction(), aggregationPointer.getSetDescriptor(), (List<Expression>)newArguments.build(), aggregationPointer.getClassifierSymbol(), aggregationPointer.getMatchNumberSymbol()));
            }
            newPointers = newPointersBuilder.build();
            if (rewritten) {
                return Optional.of(new LogicalIndexExtractor.ExpressionAndValuePointers(newExpression, (List<Symbol>)newLayout, (List<ValuePointer>)newPointers, newClassifierSymbols, newMatchNumberSymbols));
            }
            return Optional.empty();
        }

        public String toString() {
            return String.format("%s(%s)", this.getClass().getSimpleName(), this.rewriter);
        }
    }
}

