/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.rowpattern;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.ExpressionTreeUtils;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.rowpattern.AggregateArgumentsRewriter;
import io.trino.sql.planner.rowpattern.AggregatedSetDescriptor;
import io.trino.sql.planner.rowpattern.AggregationValuePointer;
import io.trino.sql.planner.rowpattern.LogicalIndexPointer;
import io.trino.sql.planner.rowpattern.ScalarValuePointer;
import io.trino.sql.planner.rowpattern.ValuePointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.LabelDereference;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.ProcessingMode;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SymbolReference;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

public class LogicalIndexExtractor {
    public static ExpressionAndValuePointers rewrite(Expression expression, Map<IrLabel, Set<IrLabel>> subsets, SymbolAllocator symbolAllocator, Session session, Metadata metadata) {
        ImmutableList.Builder layout = ImmutableList.builder();
        ImmutableList.Builder valuePointers = ImmutableList.builder();
        ImmutableSet.Builder classifierSymbols = ImmutableSet.builder();
        ImmutableSet.Builder matchNumberSymbols = ImmutableSet.builder();
        Visitor visitor = new Visitor(subsets, (ImmutableList.Builder<Symbol>)layout, (ImmutableList.Builder<ValuePointer>)valuePointers, (ImmutableSet.Builder<Symbol>)classifierSymbols, (ImmutableSet.Builder<Symbol>)matchNumberSymbols, symbolAllocator, session, metadata);
        Expression rewritten = ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)visitor, (Expression)expression, (Object)LogicalIndexContext.DEFAULT);
        return new ExpressionAndValuePointers(rewritten, (List<Symbol>)layout.build(), (List<ValuePointer>)valuePointers.build(), (Set<Symbol>)classifierSymbols.build(), (Set<Symbol>)matchNumberSymbols.build());
    }

    private LogicalIndexExtractor() {
    }

    private static class Visitor
    extends ExpressionRewriter<LogicalIndexContext> {
        private final Map<IrLabel, Set<IrLabel>> subsets;
        private final ImmutableList.Builder<Symbol> layout;
        private final ImmutableList.Builder<ValuePointer> valuePointers;
        private final ImmutableSet.Builder<Symbol> classifierSymbols;
        private final ImmutableSet.Builder<Symbol> matchNumberSymbols;
        private final SymbolAllocator symbolAllocator;
        private final Session session;
        private final Metadata metadata;

        public Visitor(Map<IrLabel, Set<IrLabel>> subsets, ImmutableList.Builder<Symbol> layout, ImmutableList.Builder<ValuePointer> valuePointers, ImmutableSet.Builder<Symbol> classifierSymbols, ImmutableSet.Builder<Symbol> matchNumberSymbols, SymbolAllocator symbolAllocator, Session session, Metadata metadata) {
            this.subsets = Objects.requireNonNull(subsets, "subsets is null");
            this.layout = Objects.requireNonNull(layout, "layout is null");
            this.valuePointers = Objects.requireNonNull(valuePointers, "valuePointers is null");
            this.classifierSymbols = Objects.requireNonNull(classifierSymbols, "classifierSymbols is null");
            this.matchNumberSymbols = Objects.requireNonNull(matchNumberSymbols, "matchNumberSymbols is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        protected Expression rewriteExpression(Expression node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            return treeRewriter.defaultRewrite(node, (Object)context);
        }

        public Expression rewriteLabelDereference(LabelDereference node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            Symbol referenced = Symbol.from((Expression)node.getReference().orElseThrow());
            Symbol reallocated = this.symbolAllocator.newSymbol(referenced);
            this.layout.add((Object)reallocated);
            ImmutableSet labels = this.subsets.get(this.irLabel(node.getLabel()));
            if (labels == null) {
                labels = ImmutableSet.of((Object)this.irLabel(node.getLabel()));
            }
            this.valuePointers.add((Object)new ScalarValuePointer(context.withLabels((Set<IrLabel>)labels).toLogicalIndexPointer(), referenced));
            return reallocated.toSymbolReference();
        }

        public Expression rewriteSymbolReference(SymbolReference node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            Symbol reallocated = this.symbolAllocator.newSymbol(Symbol.from((Expression)node));
            this.layout.add((Object)reallocated);
            this.valuePointers.add((Object)new ScalarValuePointer(context.withLabels((Set<IrLabel>)ImmutableSet.of()).toLogicalIndexPointer(), Symbol.from((Expression)node)));
            return reallocated.toSymbolReference();
        }

        public Expression rewriteFunctionCall(FunctionCall node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            if (ExpressionAnalyzer.isPatternRecognitionFunction(node)) {
                String functionName;
                QualifiedName name = node.getName();
                return switch (functionName = name.getSuffix().toUpperCase(Locale.ENGLISH)) {
                    case "FIRST", "LAST", "PREV", "NEXT" -> this.rewritePatternNavigationFunction(node, context, treeRewriter);
                    case "CLASSIFIER" -> this.rewriteClassifierFunction(node, context);
                    case "MATCH_NUMBER" -> this.rewriteMatchNumberFunction();
                    default -> throw new UnsupportedOperationException("unsupported pattern recognition function type: " + node.getName());
                };
            }
            ResolvedFunction resolvedFunction = this.metadata.decodeFunction(node.getName());
            if (resolvedFunction.getFunctionKind() == FunctionKind.AGGREGATE) {
                Type type = resolvedFunction.getSignature().getReturnType();
                Symbol aggregationSymbol = this.symbolAllocator.newSymbol((Expression)node, type);
                this.layout.add((Object)aggregationSymbol);
                Symbol classifierSymbol = this.symbolAllocator.newSymbol("classifier", (Type)VarcharType.VARCHAR);
                Symbol matchNumberSymbol = this.symbolAllocator.newSymbol("match_number", (Type)BigintType.BIGINT);
                List<Expression> rewrittenArguments = AggregateArgumentsRewriter.rewrite(node.getArguments(), classifierSymbol, matchNumberSymbol);
                AggregationValuePointer descriptor = new AggregationValuePointer(resolvedFunction, new AggregatedSetDescriptor(this.extractLabels(node), node.getProcessingMode().isEmpty() || ((ProcessingMode)node.getProcessingMode().get()).getMode() != ProcessingMode.Mode.FINAL), rewrittenArguments, classifierSymbol, matchNumberSymbol);
                this.valuePointers.add((Object)descriptor);
                return aggregationSymbol.toSymbolReference();
            }
            return super.rewriteFunctionCall(node, (Object)context, treeRewriter);
        }

        private Set<IrLabel> extractLabels(FunctionCall node) {
            FunctionCall classifier;
            if (node.getArguments().isEmpty()) {
                return ImmutableSet.of();
            }
            List<LabelDereference> labeledDereferences = ExpressionTreeUtils.extractExpressions(node.getArguments(), LabelDereference.class);
            if (!labeledDereferences.isEmpty()) {
                IrLabel label = this.irLabel(labeledDereferences.get(0).getLabel());
                ImmutableSet labels = this.subsets.get(label);
                if (labels == null) {
                    labels = ImmutableSet.of((Object)label);
                }
                return labels;
            }
            Optional<FunctionCall> classifierCall = ExpressionTreeUtils.extractExpressions(node.getArguments(), FunctionCall.class).stream().filter(ExpressionAnalyzer::isPatternRecognitionFunction).filter(function -> function.getName().getSuffix().toUpperCase(Locale.ENGLISH).equals("CLASSIFIER")).findFirst();
            if (classifierCall.isPresent() && !(classifier = classifierCall.get()).getArguments().isEmpty()) {
                IrLabel label = this.irLabel(((Identifier)Iterables.getOnlyElement((Iterable)classifier.getArguments())).getCanonicalValue());
                ImmutableSet labels = this.subsets.get(label);
                if (labels == null) {
                    labels = ImmutableSet.of((Object)label);
                }
                return labels;
            }
            return ImmutableSet.of();
        }

        private Expression rewritePatternNavigationFunction(FunctionCall node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            String functionName = node.getName().getSuffix().toUpperCase(Locale.ENGLISH);
            Expression argument = (Expression)node.getArguments().get(0);
            Optional processingMode = node.getProcessingMode();
            OptionalInt offset = OptionalInt.empty();
            if (node.getArguments().size() > 1) {
                offset = OptionalInt.of(Math.toIntExact(((LongLiteral)node.getArguments().get(1)).getParsedValue()));
            }
            return switch (functionName) {
                case "PREV" -> treeRewriter.rewrite(argument, (Object)context.withPhysicalOffset(-offset.orElse(1)));
                case "NEXT" -> treeRewriter.rewrite(argument, (Object)context.withPhysicalOffset(offset.orElse(1)));
                case "FIRST" -> treeRewriter.rewrite(argument, (Object)context.withLogicalOffset(processingMode.isEmpty() || ((ProcessingMode)processingMode.get()).getMode() != ProcessingMode.Mode.FINAL, false, offset.orElse(0)));
                case "LAST" -> treeRewriter.rewrite(argument, (Object)context.withLogicalOffset(processingMode.isEmpty() || ((ProcessingMode)processingMode.get()).getMode() != ProcessingMode.Mode.FINAL, true, offset.orElse(0)));
                default -> throw new UnsupportedOperationException("unsupported pattern navigation function type: " + node.getName());
            };
        }

        private Expression rewriteClassifierFunction(FunctionCall node, LogicalIndexContext context) {
            IrLabel label;
            Symbol classifierSymbol = this.symbolAllocator.newSymbol("classifier", (Type)VarcharType.VARCHAR);
            this.layout.add((Object)classifierSymbol);
            Object labels = ImmutableSet.of();
            if (!node.getArguments().isEmpty() && (labels = this.subsets.get(label = this.irLabel(((Identifier)Iterables.getOnlyElement((Iterable)node.getArguments())).getCanonicalValue()))) == null) {
                labels = ImmutableSet.of((Object)label);
            }
            this.valuePointers.add((Object)new ScalarValuePointer(context.withLabels((Set<IrLabel>)labels).toLogicalIndexPointer(), classifierSymbol));
            this.classifierSymbols.add((Object)classifierSymbol);
            return classifierSymbol.toSymbolReference();
        }

        private Expression rewriteMatchNumberFunction() {
            Symbol matchNumberSymbol = this.symbolAllocator.newSymbol("match_number", (Type)BigintType.BIGINT);
            this.layout.add((Object)matchNumberSymbol);
            this.valuePointers.add((Object)new ScalarValuePointer(LogicalIndexContext.DEFAULT.toLogicalIndexPointer(), matchNumberSymbol));
            this.matchNumberSymbols.add((Object)matchNumberSymbol);
            return matchNumberSymbol.toSymbolReference();
        }

        private IrLabel irLabel(String label) {
            return new IrLabel(label);
        }
    }

    private static class LogicalIndexContext {
        public static final LogicalIndexContext DEFAULT = new LogicalIndexContext((Set<IrLabel>)ImmutableSet.of(), true, true, 0, 0);
        private final Set<IrLabel> label;
        private final boolean running;
        private final boolean last;
        private final int logicalOffset;
        private final int physicalOffset;

        private LogicalIndexContext(Set<IrLabel> label, boolean running, boolean last, int logicalOffset, int physicalOffset) {
            this.label = Objects.requireNonNull(label, "label is null");
            this.running = running;
            this.last = last;
            this.logicalOffset = logicalOffset;
            this.physicalOffset = physicalOffset;
        }

        public LogicalIndexContext withPhysicalOffset(int physicalOffset) {
            return new LogicalIndexContext(this.label, this.running, this.last, this.logicalOffset, physicalOffset);
        }

        public LogicalIndexContext withLogicalOffset(boolean running, boolean last, int logicalOffset) {
            return new LogicalIndexContext(this.label, running, last, logicalOffset, this.physicalOffset);
        }

        public LogicalIndexContext withLabels(Set<IrLabel> labels) {
            return new LogicalIndexContext(labels, this.running, this.last, this.logicalOffset, this.physicalOffset);
        }

        public LogicalIndexPointer toLogicalIndexPointer() {
            return new LogicalIndexPointer(this.label, this.last, this.running, this.logicalOffset, this.physicalOffset);
        }
    }

    public static class ExpressionAndValuePointers {
        public static final ExpressionAndValuePointers TRUE = new ExpressionAndValuePointers((Expression)BooleanLiteral.TRUE_LITERAL, (List<Symbol>)ImmutableList.of(), (List<ValuePointer>)ImmutableList.of(), (Set<Symbol>)ImmutableSet.of(), (Set<Symbol>)ImmutableSet.of());
        private final Expression expression;
        private final List<Symbol> layout;
        private final List<ValuePointer> valuePointers;
        private final Set<Symbol> classifierSymbols;
        private final Set<Symbol> matchNumberSymbols;

        @JsonCreator
        public ExpressionAndValuePointers(Expression expression, List<Symbol> layout, List<ValuePointer> valuePointers, Set<Symbol> classifierSymbols, Set<Symbol> matchNumberSymbols) {
            this.expression = Objects.requireNonNull(expression, "expression is null");
            this.layout = Objects.requireNonNull(layout, "layout is null");
            this.valuePointers = Objects.requireNonNull(valuePointers, "valuePointers is null");
            Preconditions.checkArgument((layout.size() == valuePointers.size() ? 1 : 0) != 0, (Object)"layout and valuePointers sizes don't match");
            this.classifierSymbols = Objects.requireNonNull(classifierSymbols, "classifierSymbols is null");
            this.matchNumberSymbols = Objects.requireNonNull(matchNumberSymbols, "matchNumberSymbols is null");
        }

        @JsonProperty
        public Expression getExpression() {
            return this.expression;
        }

        @JsonProperty
        public List<Symbol> getLayout() {
            return this.layout;
        }

        @JsonProperty
        public List<ValuePointer> getValuePointers() {
            return this.valuePointers;
        }

        @JsonProperty
        public Set<Symbol> getClassifierSymbols() {
            return this.classifierSymbols;
        }

        @JsonProperty
        public Set<Symbol> getMatchNumberSymbols() {
            return this.matchNumberSymbols;
        }

        public List<Symbol> getInputSymbols() {
            ImmutableList.Builder inputSymbols = ImmutableList.builder();
            for (ValuePointer valuePointer : this.valuePointers) {
                if (valuePointer instanceof ScalarValuePointer) {
                    ScalarValuePointer pointer = (ScalarValuePointer)valuePointer;
                    Symbol symbol = pointer.getInputSymbol();
                    if (this.classifierSymbols.contains(symbol) || this.matchNumberSymbols.contains(symbol)) continue;
                    inputSymbols.add((Object)symbol);
                    continue;
                }
                if (valuePointer instanceof AggregationValuePointer) {
                    inputSymbols.addAll(((AggregationValuePointer)valuePointer).getInputSymbols());
                    continue;
                }
                throw new UnsupportedOperationException("unexpected ValuePointer type: " + valuePointer.getClass().getSimpleName());
            }
            return inputSymbols.build();
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            ExpressionAndValuePointers o = (ExpressionAndValuePointers)obj;
            return Objects.equals(this.expression, o.expression) && Objects.equals(this.layout, o.layout) && Objects.equals(this.valuePointers, o.valuePointers) && Objects.equals(this.classifierSymbols, o.classifierSymbols) && Objects.equals(this.matchNumberSymbols, o.matchNumberSymbols);
        }

        public int hashCode() {
            return Objects.hash(this.expression, this.layout, this.valuePointers, this.classifierSymbols, this.matchNumberSymbols);
        }
    }
}

