/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.Util;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.rowpattern.ExpressionAndValuePointers;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;

/**
 * Optimizer rules that merge two stacked {@link PatternRecognitionNode}s into one.
 * <p>
 * A pattern recognition node may sit on top of another pattern recognition node, either directly
 * or with a {@link ProjectNode} in between. When both nodes describe the same row pattern matching
 * (see {@code patternRecognitionSpecificationsMatch}), they are replaced by a single node. That node
 * carries the window functions and measures of both. The merge is rejected when the parent's window
 * functions or measures read a symbol created by the child, because the merged node reads from the
 * child's source, where such a symbol does not exist.
 */
public final class MergePatternRecognitionNodes {
    private MergePatternRecognitionNodes() {
    }

    /**
     * Returns the rules defined in this class: merging without and with an intermediate projection.
     */
    public static Set<Rule<?>> rules() {
        return ImmutableSet.of(
                new MergePatternRecognitionNodesWithoutProject(),
                new MergePatternRecognitionNodesWithProject());
    }

    /**
     * Checks whether both nodes perform the same row pattern matching. They must agree on:
     * specification, common base frame, rows-per-match option, skip-to labels and position,
     * initial flag, pattern, and variable definitions.
     */
    private static boolean patternRecognitionSpecificationsMatch(PatternRecognitionNode parent, PatternRecognitionNode child) {
        return parent.getSpecification().equals(child.getSpecification())
                && parent.getCommonBaseFrame().equals(child.getCommonBaseFrame())
                && parent.getRowsPerMatch() == child.getRowsPerMatch()
                && parent.getSkipToLabels().equals(child.getSkipToLabels())
                && parent.getSkipToPosition() == child.getSkipToPosition()
                && parent.isInitial() == child.isInitial()
                && parent.getPattern().equals(child.getPattern())
                && equivalent(parent.getVariableDefinitions(), child.getVariableDefinitions());
    }

    /**
     * Checks whether two variable-definition maps define the same labels with equal definitions.
     * <p>
     * This is plain map equality. The key sets must be equal, and every label must map to an equal
     * {@link ExpressionAndValuePointers} in both maps.
     */
    private static boolean equivalent(Map<IrLabel, ExpressionAndValuePointers> parentVariableDefinitions, Map<IrLabel, ExpressionAndValuePointers> childVariableDefinitions) {
        return parentVariableDefinitions.equals(childVariableDefinitions);
    }

    /**
     * Returns the symbols read by the node's window functions, followed by the symbols read by its
     * measures (as reported by {@link ExpressionAndValuePointers#getInputSymbols()}). The stream may
     * contain duplicates.
     * <p>
     * Variable definitions are not included. Both rules merge only nodes whose variable definitions
     * are equal (see {@code patternRecognitionSpecificationsMatch}).
     */
    private static Stream<Symbol> inputSymbols(PatternRecognitionNode node) {
        Stream<Symbol> windowFunctionInputs = node.getWindowFunctions().values().stream()
                .map(SymbolsExtractor::extractAll)
                .flatMap(Collection::stream);
        Stream<Symbol> measureInputs = node.getMeasures().values().stream()
                .map(PatternRecognitionNode.Measure::getExpressionAndValuePointers)
                .map(ExpressionAndValuePointers::getInputSymbols)
                .flatMap(Collection::stream);
        return Streams.concat(windowFunctionInputs, measureInputs);
    }

    /**
     * Returns whether any window function or measure of {@code parent} reads a symbol created by
     * {@code child}. Used when {@code child} is the direct source of {@code parent}.
     */
    private static boolean dependsOnSourceCreatedOutputs(PatternRecognitionNode parent, PatternRecognitionNode child) {
        Set<Symbol> sourceCreatedOutputs = child.getCreatedSymbols();
        return inputSymbols(parent).anyMatch(sourceCreatedOutputs::contains);
    }

    /**
     * Returns whether any window function or measure of {@code parent} reads a symbol created by
     * {@code child}, once traced through the assignments of {@code project}. Used when
     * {@code project} sits between {@code parent} and {@code child}.
     */
    private static boolean dependsOnSourceCreatedOutputs(PatternRecognitionNode parent, ProjectNode project, PatternRecognitionNode child) {
        Set<Symbol> sourceCreatedOutputs = child.getCreatedSymbols();
        Assignments assignments = project.getAssignments();
        // de-duplicate so that each parent input is resolved through the projection only once
        Set<Symbol> parentInputs = inputSymbols(parent).collect(ImmutableSet.toImmutableSet());
        // NOTE(review): assumes every parent input has an assignment in the project
        // (the project is the parent's source) -- TODO confirm
        return parentInputs.stream()
                .map(assignments::get)
                .map(SymbolsExtractor::extractAll)
                .flatMap(Collection::stream)
                .anyMatch(sourceCreatedOutputs::contains);
    }

    /**
     * Returns the non-identity assignments of {@code project} that define symbols read by the window
     * functions or measures of {@code node}. After merging, these assignments feed the merged node,
     * so they must be evaluated below it rather than above it.
     */
    private static Assignments extractPrerequisites(PatternRecognitionNode node, ProjectNode project) {
        Assignments assignments = project.getAssignments();
        Set<Symbol> inputs = inputSymbols(node).collect(ImmutableSet.toImmutableSet());
        return assignments
                .filter(symbol -> !assignments.isIdentity(symbol))
                .filter(inputs::contains);
    }

    /**
     * Builds a single node that replaces {@code parent} stacked on {@code child}.
     * <p>
     * The new node reads from the child's source. It takes its id and all matching properties from
     * {@code parent}, and it carries the window functions and measures of both nodes.
     *
     * @throws IllegalArgumentException if both nodes define a window function, or both define a
     *         measure, for the same symbol (raised by {@code ImmutableMap.Builder#buildOrThrow})
     */
    private static PatternRecognitionNode merge(PatternRecognitionNode parent, PatternRecognitionNode child) {
        Map<Symbol, WindowNode.Function> windowFunctions = ImmutableMap.<Symbol, WindowNode.Function>builder()
                .putAll(parent.getWindowFunctions())
                .putAll(child.getWindowFunctions())
                .buildOrThrow();
        Map<Symbol, PatternRecognitionNode.Measure> measures = ImmutableMap.<Symbol, PatternRecognitionNode.Measure>builder()
                .putAll(parent.getMeasures())
                .putAll(child.getMeasures())
                .buildOrThrow();

        return new PatternRecognitionNode(
                parent.getId(),
                child.getSource(),
                parent.getSpecification(),
                parent.getHashSymbol(),
                parent.getPrePartitionedInputs(),
                parent.getPreSortedOrderPrefix(),
                windowFunctions,
                measures,
                parent.getCommonBaseFrame(),
                parent.getRowsPerMatch(),
                parent.getSkipToLabels(),
                parent.getSkipToPosition(),
                parent.isInitial(),
                parent.getPattern(),
                parent.getVariableDefinitions());
    }

    /**
     * Merges a pattern recognition node whose direct source is another pattern recognition node.
     */
    public static final class MergePatternRecognitionNodesWithoutProject implements Rule<PatternRecognitionNode> {
        private static final Capture<PatternRecognitionNode> CHILD = Capture.newCapture();
        private static final Pattern<PatternRecognitionNode> PATTERN = Patterns.patternRecognition()
                .with(Patterns.source().matching(Patterns.patternRecognition().capturedAs(CHILD)));

        @Override
        public Pattern<PatternRecognitionNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
            PatternRecognitionNode child = captures.get(CHILD);

            if (!patternRecognitionSpecificationsMatch(node, child)) {
                return Rule.Result.empty();
            }
            if (dependsOnSourceCreatedOutputs(node, child)) {
                return Rule.Result.empty();
            }

            PatternRecognitionNode result = merge(node, child);
            // the merged node also outputs the child's symbols; restrict it to what the parent exposed
            return Rule.Result.ofPlanNode(
                    Util.restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(node.getOutputSymbols()))
                            .orElse(result));
        }
    }

    /**
     * Merges a pattern recognition node whose source is a projection over another pattern
     * recognition node.
     * <p>
     * Project assignments that compute inputs of the parent's window functions or measures are
     * moved into a new projection below the merged node. The remaining assignments are kept in a
     * projection above it.
     */
    public static final class MergePatternRecognitionNodesWithProject implements Rule<PatternRecognitionNode> {
        private static final Capture<ProjectNode> PROJECT = Capture.newCapture();
        private static final Capture<PatternRecognitionNode> CHILD = Capture.newCapture();
        private static final Pattern<PatternRecognitionNode> PATTERN = Patterns.patternRecognition()
                .with(Patterns.source().matching(Patterns.project().capturedAs(PROJECT)
                        .with(Patterns.source().matching(Patterns.patternRecognition().capturedAs(CHILD)))));

        @Override
        public Pattern<PatternRecognitionNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
            ProjectNode project = captures.get(PROJECT);
            PatternRecognitionNode child = captures.get(CHILD);

            if (!patternRecognitionSpecificationsMatch(node, child)) {
                return Rule.Result.empty();
            }
            if (dependsOnSourceCreatedOutputs(node, project, child)) {
                return Rule.Result.empty();
            }

            PatternRecognitionNode merged = merge(node, child);
            Assignments prerequisites = extractPrerequisites(node, project);

            ProjectNode result;
            if (prerequisites.isEmpty()) {
                // the parent only reads pass-through symbols: keep the whole projection above the merged node
                result = new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        merged,
                        Assignments.builder()
                                .putIdentities(merged.getOutputSymbols())
                                .putAll(project.getAssignments())
                                .build());
            } else {
                Assignments remainingAssignments = project.getAssignments()
                        .filter(symbol -> !prerequisites.getSymbols().contains(symbol));

                // evaluate the parent's computed inputs below the merged node
                PlanNode mergedSource = merged.getSource();
                List<PlanNode> newSources = ImmutableList.of(new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        mergedSource,
                        Assignments.builder()
                                .putIdentities(mergedSource.getOutputSymbols())
                                .putAll(prerequisites)
                                .build()));
                merged = (PatternRecognitionNode) merged.replaceChildren(newSources);

                result = new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        merged,
                        Assignments.builder()
                                .putIdentities(merged.getOutputSymbols())
                                .putAll(remainingAssignments)
                                .build());
            }

            return Rule.Result.ofPlanNode(
                    Util.restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(node.getOutputSymbols()))
                            .orElse(result));
        }
    }
}

