/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.ClassifierValuePointer;
import io.trino.sql.planner.rowpattern.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.MatchNumberValuePointer;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import java.lang.runtime.SwitchBootstraps;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Moves computable aggregation arguments out of a {@link PatternRecognitionNode} into a
 * {@link ProjectNode} inserted directly below it.
 *
 * <p>For every {@link AggregationValuePointer} found in the node's variable definitions and
 * measures, each argument that is neither a plain {@link Reference} nor dependent on the
 * pointer's classifier / match-number symbols is assigned to a freshly allocated symbol in the
 * new projection, and the argument is replaced by a reference to that symbol. All other value
 * pointers are kept unchanged. If no argument can be extracted, the rule does not fire.
 */
public class PushDownProjectionsFromPatternRecognition
implements Rule<PatternRecognitionNode> {
    private static final Pattern<PatternRecognitionNode> PATTERN = Patterns.patternRecognition();

    @Override
    public Pattern<PatternRecognitionNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
        // Accumulates the expressions that the new ProjectNode below this node will compute.
        Assignments.Builder assignments = Assignments.builder();

        // Order matters only for symbol allocation: variable definitions first, then measures.
        Map<IrLabel, ExpressionAndValuePointers> rewrittenVariableDefinitions = rewriteVariableDefinitions(node.getVariableDefinitions(), assignments, context);
        Map<Symbol, PatternRecognitionNode.Measure> rewrittenMeasureDefinitions = rewriteMeasureDefinitions(node.getMeasures(), assignments, context);

        if (assignments.build().isEmpty()) {
            // Nothing was extracted, so the plan is left unchanged.
            return Rule.Result.empty();
        }

        // The projection must also pass through every symbol produced by the original source.
        assignments.putIdentities(node.getSource().getOutputSymbols());
        ProjectNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), node.getSource(), assignments.build());

        PatternRecognitionNode patternRecognitionNode = new PatternRecognitionNode(
                node.getId(),
                projectNode,
                node.getSpecification(),
                node.getHashSymbol(),
                node.getPrePartitionedInputs(),
                node.getPreSortedOrderPrefix(),
                node.getWindowFunctions(),
                rewrittenMeasureDefinitions,
                node.getCommonBaseFrame(),
                node.getRowsPerMatch(),
                node.getSkipToLabels(),
                node.getSkipToPosition(),
                node.isInitial(),
                node.getPattern(),
                rewrittenVariableDefinitions);

        // Restrict the rewritten node to the original node's outputs. restrictOutputs returns an
        // Optional; when it is empty (presumably: no pruning needed) the node is used as-is.
        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), patternRecognitionNode, ImmutableSet.copyOf(node.getOutputSymbols()))
                        .orElse(patternRecognitionNode));
    }

    /**
     * Rewrites every variable definition with {@link #rewrite}, preserving the map's keys.
     *
     * @param variableDefinitions the node's DEFINE clauses, keyed by label
     * @param assignments receives the extracted argument expressions
     * @param context rule context used for symbol allocation
     * @return an immutable map with the rewritten definitions
     */
    private static Map<IrLabel, ExpressionAndValuePointers> rewriteVariableDefinitions(Map<IrLabel, ExpressionAndValuePointers> variableDefinitions, Assignments.Builder assignments, Rule.Context context) {
        return variableDefinitions.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        entry -> rewrite(entry.getValue(), assignments, context)));
    }

    /**
     * Rewrites the expression of every measure with {@link #rewrite}, keeping each measure's type.
     *
     * @param measureDefinitions the node's measures, keyed by output symbol
     * @param assignments receives the extracted argument expressions
     * @param context rule context used for symbol allocation
     * @return an immutable map with the rewritten measures
     */
    private static Map<Symbol, PatternRecognitionNode.Measure> rewriteMeasureDefinitions(Map<Symbol, PatternRecognitionNode.Measure> measureDefinitions, Assignments.Builder assignments, Rule.Context context) {
        return measureDefinitions.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(
                        Map.Entry::getKey,
                        entry -> {
                            PatternRecognitionNode.Measure measure = entry.getValue();
                            return new PatternRecognitionNode.Measure(
                                    rewrite(measure.getExpressionAndValuePointers(), assignments, context),
                                    measure.getType());
                        }));
    }

    /**
     * Rewrites the value pointers of one expression.
     *
     * <p>Only aggregation pointers are changed (see {@link #rewriteAggregationPointer}).
     * Classifier, match-number and scalar pointers are returned as-is.
     *
     * @param expression the expression and its value-pointer assignments
     * @param assignments receives the extracted argument expressions
     * @param context rule context used for symbol allocation
     * @return a new {@link ExpressionAndValuePointers} with the same expression and rewritten pointers
     */
    private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers expression, Assignments.Builder assignments, Rule.Context context) {
        ImmutableList.Builder<ExpressionAndValuePointers.Assignment> rewrittenAssignments = ImmutableList.builder();
        for (ExpressionAndValuePointers.Assignment assignment : expression.getAssignments()) {
            // A null pointer throws NullPointerException here, matching the original requireNonNull check.
            ValuePointer rewrittenPointer = switch (assignment.valuePointer()) {
                case ClassifierValuePointer pointer -> pointer;
                case MatchNumberValuePointer pointer -> pointer;
                case ScalarValuePointer pointer -> pointer;
                case AggregationValuePointer pointer -> rewriteAggregationPointer(pointer, assignments, context);
            };
            rewrittenAssignments.add(new ExpressionAndValuePointers.Assignment(assignment.symbol(), rewrittenPointer));
        }
        return new ExpressionAndValuePointers(expression.getExpression(), rewrittenAssignments.build());
    }

    /**
     * Replaces each extractable argument of an aggregation pointer with a reference to a new
     * symbol, and records the original argument expression in {@code assignments}.
     *
     * <p>An argument is kept in place when it is already a plain {@link Reference}, or when it
     * refers to the pointer's classifier or match-number symbol. Those symbols are treated as
     * runtime-evaluated (computed during matching), so such arguments presumably cannot be
     * evaluated in a projection below the node.
     *
     * @param pointer the aggregation pointer to rewrite
     * @param assignments receives one entry per extracted argument
     * @param context rule context used for symbol allocation
     * @return a new pointer with the same function, set descriptor and classifier/match-number
     *         symbols, and the rewritten argument list
     */
    private static AggregationValuePointer rewriteAggregationPointer(AggregationValuePointer pointer, Assignments.Builder assignments, Rule.Context context) {
        Set<Symbol> runtimeEvaluatedSymbols = ImmutableSet.of(pointer.getClassifierSymbol(), pointer.getMatchNumberSymbol()).stream()
                .flatMap(Optional::stream)
                .collect(ImmutableSet.toImmutableSet());

        ImmutableList.Builder<Expression> rewrittenArguments = ImmutableList.builder();
        for (Expression argument : pointer.getArguments()) {
            boolean keepInPlace = argument instanceof Reference
                    || SymbolsExtractor.extractUnique(argument).stream().anyMatch(runtimeEvaluatedSymbols::contains);
            if (keepInPlace) {
                rewrittenArguments.add(argument);
            }
            else {
                Symbol argumentSymbol = context.getSymbolAllocator().newSymbol(argument);
                assignments.put(argumentSymbol, argument);
                rewrittenArguments.add(argumentSymbol.toSymbolReference());
            }
        }

        return new AggregationValuePointer(
                pointer.getFunction(),
                pointer.getSetDescriptor(),
                rewrittenArguments.build(),
                pointer.getClassifierSymbol(),
                pointer.getMatchNumberSymbol());
    }
}

