/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class PushDownProjectionsFromPatternRecognition
implements Rule<PatternRecognitionNode> {
    private static final Pattern<PatternRecognitionNode> PATTERN = Patterns.patternRecognition();

    @Override
    public Pattern<PatternRecognitionNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
        Assignments.Builder assignments = Assignments.builder();
        Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> rewrittenVariableDefinitions = PushDownProjectionsFromPatternRecognition.rewriteVariableDefinitions(node.getVariableDefinitions(), assignments, context);
        Map<Symbol, PatternRecognitionNode.Measure> rewrittenMeasureDefinitions = PushDownProjectionsFromPatternRecognition.rewriteMeasureDefinitions(node.getMeasures(), assignments, context);
        if (assignments.build().isEmpty()) {
            return Rule.Result.empty();
        }
        assignments.putIdentities(node.getSource().getOutputSymbols());
        ProjectNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), node.getSource(), assignments.build());
        PatternRecognitionNode patternRecognitionNode = new PatternRecognitionNode(node.getId(), projectNode, node.getSpecification(), node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix(), node.getWindowFunctions(), rewrittenMeasureDefinitions, node.getCommonBaseFrame(), node.getRowsPerMatch(), node.getSkipToLabel(), node.getSkipToPosition(), node.isInitial(), node.getPattern(), node.getSubsets(), rewrittenVariableDefinitions);
        return Rule.Result.ofPlanNode(Util.restrictOutputs(context.getIdAllocator(), patternRecognitionNode, (Set<Symbol>)ImmutableSet.copyOf(node.getOutputSymbols())).orElse(patternRecognitionNode));
    }

    private static Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> rewriteVariableDefinitions(Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> variableDefinitions, Assignments.Builder assignments, Rule.Context context) {
        return (Map)variableDefinitions.entrySet().stream().collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> PushDownProjectionsFromPatternRecognition.rewrite((LogicalIndexExtractor.ExpressionAndValuePointers)entry.getValue(), assignments, context)));
    }

    private static Map<Symbol, PatternRecognitionNode.Measure> rewriteMeasureDefinitions(Map<Symbol, PatternRecognitionNode.Measure> measureDefinitions, Assignments.Builder assignments, Rule.Context context) {
        return (Map)measureDefinitions.entrySet().stream().collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> new PatternRecognitionNode.Measure(PushDownProjectionsFromPatternRecognition.rewrite(((PatternRecognitionNode.Measure)entry.getValue()).getExpressionAndValuePointers(), assignments, context), ((PatternRecognitionNode.Measure)entry.getValue()).getType())));
    }

    private static LogicalIndexExtractor.ExpressionAndValuePointers rewrite(LogicalIndexExtractor.ExpressionAndValuePointers expression, Assignments.Builder assignments, Rule.Context context) {
        ImmutableList.Builder rewrittenPointers = ImmutableList.builder();
        for (ValuePointer valuePointer : expression.getValuePointers()) {
            if (valuePointer instanceof ScalarValuePointer) {
                rewrittenPointers.add((Object)valuePointer);
                continue;
            }
            AggregationValuePointer aggregationPointer = (AggregationValuePointer)valuePointer;
            ImmutableSet runtimeEvaluatedSymbols = ImmutableSet.of((Object)aggregationPointer.getClassifierSymbol(), (Object)aggregationPointer.getMatchNumberSymbol());
            List argumentTypes = aggregationPointer.getFunction().getSignature().getArgumentTypes();
            ImmutableList.Builder rewrittenArguments = ImmutableList.builder();
            for (int i = 0; i < aggregationPointer.getArguments().size(); ++i) {
                Expression argument;
                block7: {
                    block6: {
                        argument = aggregationPointer.getArguments().get(i);
                        if (argument instanceof SymbolReference) break block6;
                        if (!SymbolsExtractor.extractUnique(argument).stream().anyMatch(((Set)runtimeEvaluatedSymbols)::contains)) break block7;
                    }
                    rewrittenArguments.add((Object)argument);
                    continue;
                }
                Symbol symbol = context.getSymbolAllocator().newSymbol(argument, (Type)argumentTypes.get(i));
                assignments.put(symbol, argument);
                rewrittenArguments.add((Object)symbol.toSymbolReference());
            }
            rewrittenPointers.add((Object)new AggregationValuePointer(aggregationPointer.getFunction(), aggregationPointer.getSetDescriptor(), (List<Expression>)rewrittenArguments.build(), aggregationPointer.getClassifierSymbol(), aggregationPointer.getMatchNumberSymbol()));
        }
        return new LogicalIndexExtractor.ExpressionAndValuePointers(expression.getExpression(), expression.getLayout(), (List<ValuePointer>)rewrittenPointers.build(), expression.getClassifierSymbols(), expression.getMatchNumberSymbols());
    }
}

