/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.metadata.Metadata;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.SortExpressionContext;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.SymbolReference;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Extracts a sort expression from a filter expression.
 *
 * <p>A sort expression is a bare reference to one of the build symbols that is compared, using
 * {@code >}, {@code >=}, {@code <} or {@code <=} (or a BETWEEN predicate), against an expression
 * that references no build symbols. Each such comparison is recorded as a search expression of a
 * {@link SortExpressionContext}; comparisons sharing the same sort expression are grouped together.
 */
public final class SortExpressionExtractor {
    private SortExpressionExtractor() {
    }

    /**
     * Picks the sort expression candidate with the most search expressions from {@code filter}.
     *
     * <p>The filter is split into conjuncts and non-deterministic conjuncts are skipped. Every
     * remaining conjunct that qualifies (see the class description) yields a candidate; candidates
     * with equal sort expressions are merged by concatenating their search expressions.
     *
     * @param metadata metadata used to decide whether a conjunct is deterministic
     * @param buildSymbols symbols allowed to act as the sort expression
     * @param filter the filter expression to analyze
     * @return the candidate with the most search expressions, or empty if no conjunct qualifies
     */
    public static Optional<SortExpressionContext> extractSortExpression(Metadata metadata, Set<Symbol> buildSymbols, Expression filter) {
        List<Expression> filterConjuncts = ExpressionUtils.extractConjuncts(filter);
        SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols);

        return filterConjuncts.stream()
                .filter(expression -> DeterminismEvaluator.isDeterministic(expression, metadata))
                .map(visitor::process)
                .flatMap(Optional::stream)
                // Group candidates by sort expression, concatenating their search expressions.
                .collect(Collectors.toMap(SortExpressionContext::getSortExpression, Function.identity(), SortExpressionExtractor::merge))
                .values()
                .stream()
                // Heuristic: prefer the sort expression with the most search expressions assigned to it.
                // sorted() is stable, so ties go to whichever candidate the grouping map iterates first;
                // Collectors.toMap does not specify that iteration order.
                .sorted(Comparator.comparingInt((SortExpressionContext candidate) -> candidate.getSearchExpressions().size()).reversed())
                .findFirst();
    }

    /**
     * Merges two contexts that share the same sort expression.
     *
     * @return a context with that sort expression and the search expressions of {@code left}
     *         followed by those of {@code right}
     * @throws IllegalArgumentException if the sort expressions differ
     */
    private static SortExpressionContext merge(SortExpressionContext left, SortExpressionContext right) {
        Preconditions.checkArgument(
                left.getSortExpression().equals(right.getSortExpression()),
                "Cannot merge contexts with different sort expressions: %s and %s",
                left.getSortExpression(),
                right.getSortExpression());
        List<Expression> searchExpressions = ImmutableList.<Expression>builder()
                .addAll(left.getSearchExpressions())
                .addAll(right.getSearchExpressions())
                .build();
        return new SortExpressionContext(left.getSortExpression(), searchExpressions);
    }

    /**
     * Returns {@code expression} as a symbol reference if it is a bare reference to a build symbol.
     *
     * @return the reference, or empty if {@code expression} is not a {@link SymbolReference} or names
     *         a symbol outside {@code buildLayout}
     */
    private static Optional<SymbolReference> asBuildSymbolReference(Set<Symbol> buildLayout, Expression expression) {
        if (expression instanceof SymbolReference symbolReference && buildLayout.contains(new Symbol(symbolReference.getName()))) {
            return Optional.of(symbolReference);
        }
        return Optional.empty();
    }

    /**
     * Maps a single conjunct to a {@link SortExpressionContext} when it is a qualifying inequality
     * or BETWEEN predicate; every other expression maps to {@link Optional#empty()}.
     */
    private static class SortExpressionVisitor
            extends AstVisitor<Optional<SortExpressionContext>, Void> {
        private final Set<Symbol> buildSymbols;
        // Created once per visitor so the set of build symbol names is computed only once.
        private final BuildSymbolReferenceFinder buildSymbolReferenceFinder;

        public SortExpressionVisitor(Set<Symbol> buildSymbols) {
            this.buildSymbols = buildSymbols;
            this.buildSymbolReferenceFinder = new BuildSymbolReferenceFinder(buildSymbols);
        }

        @Override
        protected Optional<SortExpressionContext> visitExpression(Expression expression, Void context) {
            return Optional.empty();
        }

        @Override
        protected Optional<SortExpressionContext> visitComparisonExpression(ComparisonExpression comparison, Void context) {
            return switch (comparison.getOperator()) {
                case GREATER_THAN, GREATER_THAN_OR_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL -> extractFromInequality(comparison);
                default -> Optional.empty();
            };
        }

        /**
         * Uses the lower bound ({@code value >= min}) if it qualifies; otherwise tries the upper bound
         * ({@code value <= max}). At most one of the two bounds becomes a search expression.
         */
        @Override
        protected Optional<SortExpressionContext> visitBetweenPredicate(BetweenPredicate node, Void context) {
            return visitComparisonExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin()), context)
                    .or(() -> visitComparisonExpression(new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax()), context));
        }

        /**
         * Builds a context for an inequality whose one side is a bare build symbol and whose other side
         * references no build symbols.
         */
        private Optional<SortExpressionContext> extractFromInequality(ComparisonExpression comparison) {
            // Prefer a build symbol on the right-hand side; the left-hand side is tried only when the
            // right-hand side is not a bare build symbol.
            Optional<SymbolReference> sortChannel = asBuildSymbolReference(buildSymbols, comparison.getRight());
            Expression otherSide = comparison.getLeft();
            if (sortChannel.isEmpty()) {
                sortChannel = asBuildSymbolReference(buildSymbols, comparison.getLeft());
                otherSide = comparison.getRight();
            }
            if (sortChannel.isEmpty() || hasBuildSymbolReference(otherSide)) {
                return Optional.empty();
            }
            return Optional.of(new SortExpressionContext(sortChannel.get(), Collections.singletonList(comparison)));
        }

        private boolean hasBuildSymbolReference(Expression expression) {
            return buildSymbolReferenceFinder.process(expression);
        }
    }

    /**
     * Reports whether a node tree contains a {@link SymbolReference} to any of the given build symbols,
     * recursing through {@link Node#getChildren()}.
     */
    private static class BuildSymbolReferenceFinder
            extends AstVisitor<Boolean, Void> {
        private final Set<String> buildSymbols;

        public BuildSymbolReferenceFinder(Set<Symbol> buildSymbols) {
            this.buildSymbols = buildSymbols.stream()
                    .map(Symbol::getName)
                    .collect(ImmutableSet.toImmutableSet());
        }

        @Override
        protected Boolean visitNode(Node node, Void context) {
            for (Node child : node.getChildren()) {
                if (process(child, context)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        protected Boolean visitSymbolReference(SymbolReference symbolReference, Void context) {
            return buildSymbols.contains(symbolReference.getName());
        }
    }
}

