/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.rowpattern;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.rowpattern.LogicalIndexPointer;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.LabelDereference;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.ProcessingMode;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SymbolReference;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;

public class LogicalIndexExtractor {
    public static ExpressionAndValuePointers rewrite(Expression expression, Map<IrLabel, Set<IrLabel>> subsets, SymbolAllocator symbolAllocator) {
        ImmutableList.Builder layout = ImmutableList.builder();
        ImmutableList.Builder valuePointers = ImmutableList.builder();
        ImmutableSet.Builder classifierSymbols = ImmutableSet.builder();
        ImmutableSet.Builder matchNumberSymbols = ImmutableSet.builder();
        Visitor visitor = new Visitor(subsets, (ImmutableList.Builder<Symbol>)layout, (ImmutableList.Builder<ValuePointer>)valuePointers, (ImmutableSet.Builder<Symbol>)classifierSymbols, (ImmutableSet.Builder<Symbol>)matchNumberSymbols, symbolAllocator);
        Expression rewritten = ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)visitor, (Expression)expression, (Object)LogicalIndexContext.DEFAULT);
        return new ExpressionAndValuePointers(rewritten, (List<Symbol>)layout.build(), (List<ValuePointer>)valuePointers.build(), (Set<Symbol>)classifierSymbols.build(), (Set<Symbol>)matchNumberSymbols.build());
    }

    private LogicalIndexExtractor() {
    }

    public static class ValuePointer {
        private final LogicalIndexPointer logicalIndexPointer;
        private final Symbol inputSymbol;

        @JsonCreator
        public ValuePointer(LogicalIndexPointer logicalIndexPointer, Symbol inputSymbol) {
            this.logicalIndexPointer = Objects.requireNonNull(logicalIndexPointer, "logicalIndexPointer is null");
            this.inputSymbol = Objects.requireNonNull(inputSymbol, "inputSymbol is null");
        }

        @JsonProperty
        public LogicalIndexPointer getLogicalIndexPointer() {
            return this.logicalIndexPointer;
        }

        @JsonProperty
        public Symbol getInputSymbol() {
            return this.inputSymbol;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            ValuePointer o = (ValuePointer)obj;
            return Objects.equals(this.logicalIndexPointer, o.logicalIndexPointer) && Objects.equals(this.inputSymbol, o.inputSymbol);
        }

        public int hashCode() {
            return Objects.hash(this.logicalIndexPointer, this.inputSymbol);
        }
    }

    public static class ExpressionAndValuePointers {
        public static final ExpressionAndValuePointers TRUE = new ExpressionAndValuePointers((Expression)BooleanLiteral.TRUE_LITERAL, (List<Symbol>)ImmutableList.of(), (List<ValuePointer>)ImmutableList.of(), (Set<Symbol>)ImmutableSet.of(), (Set<Symbol>)ImmutableSet.of());
        private final Expression expression;
        private final List<Symbol> layout;
        private final List<ValuePointer> valuePointers;
        private final Set<Symbol> classifierSymbols;
        private final Set<Symbol> matchNumberSymbols;

        @JsonCreator
        public ExpressionAndValuePointers(Expression expression, List<Symbol> layout, List<ValuePointer> valuePointers, Set<Symbol> classifierSymbols, Set<Symbol> matchNumberSymbols) {
            this.expression = Objects.requireNonNull(expression, "expression is null");
            this.layout = Objects.requireNonNull(layout, "layout is null");
            this.valuePointers = Objects.requireNonNull(valuePointers, "valuePointers is null");
            this.classifierSymbols = Objects.requireNonNull(classifierSymbols, "classifierSymbols is null");
            this.matchNumberSymbols = Objects.requireNonNull(matchNumberSymbols, "matchNumberSymbols is null");
        }

        @JsonProperty
        public Expression getExpression() {
            return this.expression;
        }

        @JsonProperty
        public List<Symbol> getLayout() {
            return this.layout;
        }

        @JsonProperty
        public List<ValuePointer> getValuePointers() {
            return this.valuePointers;
        }

        @JsonProperty
        public Set<Symbol> getClassifierSymbols() {
            return this.classifierSymbols;
        }

        @JsonProperty
        public Set<Symbol> getMatchNumberSymbols() {
            return this.matchNumberSymbols;
        }

        public List<Symbol> getInputSymbols() {
            return (List)this.valuePointers.stream().map(ValuePointer::getInputSymbol).filter(symbol -> !this.classifierSymbols.contains(symbol) && !this.matchNumberSymbols.contains(symbol)).collect(ImmutableList.toImmutableList());
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            ExpressionAndValuePointers o = (ExpressionAndValuePointers)obj;
            return Objects.equals(this.expression, o.expression) && Objects.equals(this.layout, o.layout) && Objects.equals(this.valuePointers, o.valuePointers) && Objects.equals(this.classifierSymbols, o.classifierSymbols) && Objects.equals(this.matchNumberSymbols, o.matchNumberSymbols);
        }

        public int hashCode() {
            return Objects.hash(this.expression, this.layout, this.valuePointers, this.classifierSymbols, this.matchNumberSymbols);
        }
    }

    private static class LogicalIndexContext {
        public static final LogicalIndexContext DEFAULT = new LogicalIndexContext((Set<IrLabel>)ImmutableSet.of(), true, true, 0, 0);
        private final Set<IrLabel> label;
        private final boolean running;
        private final boolean last;
        private final int logicalOffset;
        private final int physicalOffset;

        private LogicalIndexContext(Set<IrLabel> label, boolean running, boolean last, int logicalOffset, int physicalOffset) {
            this.label = Objects.requireNonNull(label, "label is null");
            this.running = running;
            this.last = last;
            this.logicalOffset = logicalOffset;
            this.physicalOffset = physicalOffset;
        }

        public LogicalIndexContext withPhysicalOffset(int physicalOffset) {
            return new LogicalIndexContext(this.label, this.running, this.last, this.logicalOffset, physicalOffset);
        }

        public LogicalIndexContext withLogicalOffset(boolean running, boolean last, int logicalOffset) {
            return new LogicalIndexContext(this.label, running, last, logicalOffset, this.physicalOffset);
        }

        public LogicalIndexContext withLabels(Set<IrLabel> labels) {
            return new LogicalIndexContext(labels, this.running, this.last, this.logicalOffset, this.physicalOffset);
        }

        public LogicalIndexPointer toLogicalIndexPointer() {
            return new LogicalIndexPointer(this.label, this.last, this.running, this.logicalOffset, this.physicalOffset);
        }
    }

    private static class Visitor
    extends ExpressionRewriter<LogicalIndexContext> {
        private final Map<IrLabel, Set<IrLabel>> subsets;
        private final ImmutableList.Builder<Symbol> layout;
        private final ImmutableList.Builder<ValuePointer> valuePointers;
        private final ImmutableSet.Builder<Symbol> classifierSymbols;
        private final ImmutableSet.Builder<Symbol> matchNumberSymbols;
        private final SymbolAllocator symbolAllocator;

        public Visitor(Map<IrLabel, Set<IrLabel>> subsets, ImmutableList.Builder<Symbol> layout, ImmutableList.Builder<ValuePointer> valuePointers, ImmutableSet.Builder<Symbol> classifierSymbols, ImmutableSet.Builder<Symbol> matchNumberSymbols, SymbolAllocator symbolAllocator) {
            this.subsets = Objects.requireNonNull(subsets, "subsets is null");
            this.layout = Objects.requireNonNull(layout, "layout is null");
            this.valuePointers = Objects.requireNonNull(valuePointers, "valuePointers is null");
            this.classifierSymbols = Objects.requireNonNull(classifierSymbols, "classifierSymbols is null");
            this.matchNumberSymbols = Objects.requireNonNull(matchNumberSymbols, "matchNumberSymbols is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        protected Expression rewriteExpression(Expression node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            return treeRewriter.defaultRewrite(node, (Object)context);
        }

        public Expression rewriteLabelDereference(LabelDereference node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            Symbol reallocated = this.symbolAllocator.newSymbol(Symbol.from((Expression)node.getReference()));
            this.layout.add((Object)reallocated);
            ImmutableSet labels = this.subsets.get(this.irLabel(node.getLabel()));
            if (labels == null) {
                labels = ImmutableSet.of((Object)this.irLabel(node.getLabel()));
            }
            this.valuePointers.add((Object)new ValuePointer(context.withLabels((Set<IrLabel>)labels).toLogicalIndexPointer(), Symbol.from((Expression)node.getReference())));
            return reallocated.toSymbolReference();
        }

        public Expression rewriteSymbolReference(SymbolReference node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            Symbol reallocated = this.symbolAllocator.newSymbol(Symbol.from((Expression)node));
            this.layout.add((Object)reallocated);
            this.valuePointers.add((Object)new ValuePointer(context.withLabels((Set<IrLabel>)ImmutableSet.of()).toLogicalIndexPointer(), Symbol.from((Expression)node)));
            return reallocated.toSymbolReference();
        }

        public Expression rewriteFunctionCall(FunctionCall node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            if (ExpressionAnalyzer.isPatternRecognitionFunction(node)) {
                String functionName;
                QualifiedName name = node.getName();
                switch (functionName = name.getSuffix().toUpperCase(Locale.ENGLISH)) {
                    case "FIRST": 
                    case "LAST": 
                    case "PREV": 
                    case "NEXT": {
                        return this.rewritePatternNavigationFunction(node, context, treeRewriter);
                    }
                    case "CLASSIFIER": {
                        return this.rewriteClassifierFunction(node, context);
                    }
                    case "MATCH_NUMBER": {
                        return this.rewriteMatchNumberFunction();
                    }
                }
                throw new UnsupportedOperationException("unsupported pattern recognition function type: " + node.getName());
            }
            return super.rewriteFunctionCall(node, (Object)context, treeRewriter);
        }

        private Expression rewritePatternNavigationFunction(FunctionCall node, LogicalIndexContext context, ExpressionTreeRewriter<LogicalIndexContext> treeRewriter) {
            String functionName = node.getName().getSuffix().toUpperCase(Locale.ENGLISH);
            Expression argument = (Expression)node.getArguments().get(0);
            Optional processingMode = node.getProcessingMode();
            OptionalInt offset = OptionalInt.empty();
            if (node.getArguments().size() > 1) {
                offset = OptionalInt.of(Math.toIntExact(((LongLiteral)node.getArguments().get(1)).getValue()));
            }
            switch (functionName) {
                case "PREV": {
                    return treeRewriter.rewrite(argument, (Object)context.withPhysicalOffset(-offset.orElse(1)));
                }
                case "NEXT": {
                    return treeRewriter.rewrite(argument, (Object)context.withPhysicalOffset(offset.orElse(1)));
                }
                case "FIRST": {
                    boolean running = processingMode.isEmpty() || ((ProcessingMode)processingMode.get()).getMode() != ProcessingMode.Mode.FINAL;
                    return treeRewriter.rewrite(argument, (Object)context.withLogicalOffset(running, false, offset.orElse(0)));
                }
                case "LAST": {
                    boolean running = processingMode.isEmpty() || ((ProcessingMode)processingMode.get()).getMode() != ProcessingMode.Mode.FINAL;
                    return treeRewriter.rewrite(argument, (Object)context.withLogicalOffset(running, true, offset.orElse(0)));
                }
            }
            throw new UnsupportedOperationException("unsupported pattern navigation function type: " + node.getName());
        }

        private Expression rewriteClassifierFunction(FunctionCall node, LogicalIndexContext context) {
            IrLabel label;
            Symbol classifierSymbol = this.symbolAllocator.newSymbol("classifier", (Type)VarcharType.VARCHAR);
            this.layout.add((Object)classifierSymbol);
            Object labels = ImmutableSet.of();
            if (!node.getArguments().isEmpty() && (labels = this.subsets.get(label = this.irLabel(((Identifier)Iterables.getOnlyElement((Iterable)node.getArguments())).getCanonicalValue()))) == null) {
                labels = ImmutableSet.of((Object)label);
            }
            this.valuePointers.add((Object)new ValuePointer(context.withLabels((Set<IrLabel>)labels).toLogicalIndexPointer(), classifierSymbol));
            this.classifierSymbols.add((Object)classifierSymbol);
            return classifierSymbol.toSymbolReference();
        }

        private Expression rewriteMatchNumberFunction() {
            Symbol matchNumberSymbol = this.symbolAllocator.newSymbol("match_number", (Type)BigintType.BIGINT);
            this.layout.add((Object)matchNumberSymbol);
            this.valuePointers.add((Object)new ValuePointer(LogicalIndexContext.DEFAULT.toLogicalIndexPointer(), matchNumberSymbol));
            this.matchNumberSymbols.add((Object)matchNumberSymbol);
            return matchNumberSymbol.toSymbolReference();
        }

        private IrLabel irLabel(String label) {
            return new IrLabel(label);
        }
    }
}

