/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.spi.block.Block;
import io.trino.spi.block.SqlRow;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.DomainTranslator;
import io.trino.sql.planner.EqualityInference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SortNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.planner.plan.WindowNode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;

public class EffectivePredicateExtractor {
    private static final Predicate<Map.Entry<Symbol, ? extends Expression>> SYMBOL_MATCHES_EXPRESSION = entry -> ((Expression)entry.getValue()).equals(((Symbol)entry.getKey()).toSymbolReference());
    private static final Function<Map.Entry<Symbol, ? extends Expression>, Expression> ENTRY_TO_EQUALITY = entry -> {
        SymbolReference reference = ((Symbol)entry.getKey()).toSymbolReference();
        Expression expression = (Expression)entry.getValue();
        return new ComparisonExpression(ComparisonExpression.Operator.EQUAL, reference, expression);
    };
    private final PlannerContext plannerContext;
    private final DomainTranslator domainTranslator;
    private final boolean useTableProperties;

    public EffectivePredicateExtractor(DomainTranslator domainTranslator, PlannerContext plannerContext, boolean useTableProperties) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.domainTranslator = Objects.requireNonNull(domainTranslator, "domainTranslator is null");
        this.useTableProperties = useTableProperties;
    }

    public Expression extract(Session session, PlanNode node) {
        return node.accept(new Visitor(this.domainTranslator, this.plannerContext, session, this.useTableProperties), null);
    }

    private static class Visitor
    extends PlanVisitor<Expression, Void> {
        private final DomainTranslator domainTranslator;
        private final PlannerContext plannerContext;
        private final Metadata metadata;
        private final Session session;
        private final boolean useTableProperties;

        public Visitor(DomainTranslator domainTranslator, PlannerContext plannerContext, Session session, boolean useTableProperties) {
            this.domainTranslator = Objects.requireNonNull(domainTranslator, "domainTranslator is null");
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.metadata = plannerContext.getMetadata();
            this.session = Objects.requireNonNull(session, "session is null");
            this.useTableProperties = useTableProperties;
        }

        @Override
        protected Expression visitPlan(PlanNode node, Void context) {
            return BooleanLiteral.TRUE_LITERAL;
        }

        @Override
        public Expression visitAggregation(AggregationNode node, Void context) {
            if (node.getGroupingKeys().isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            Expression underlyingPredicate = node.getSource().accept(this, context);
            return this.pullExpressionThroughSymbols(underlyingPredicate, node.getGroupingKeys());
        }

        @Override
        public Expression visitFilter(FilterNode node, Void context) {
            Expression underlyingPredicate = node.getSource().accept(this, context);
            Expression predicate = node.getPredicate();
            predicate = IrUtils.filterDeterministicConjuncts(this.metadata, predicate);
            return IrUtils.combineConjuncts(predicate, underlyingPredicate);
        }

        @Override
        public Expression visitExchange(ExchangeNode node, Void context) {
            return this.deriveCommonPredicates(node, source -> {
                HashMap<Symbol, SymbolReference> mappings = new HashMap<Symbol, SymbolReference>();
                for (int i = 0; i < node.getInputs().get((int)source).size(); ++i) {
                    mappings.put(node.getOutputSymbols().get(i), node.getInputs().get((int)source).get(i).toSymbolReference());
                }
                return mappings.entrySet();
            });
        }

        @Override
        public Expression visitProject(ProjectNode node, Void context) {
            Expression underlyingPredicate = node.getSource().accept(this, context);
            List nonIdentityAssignments = (List)node.getAssignments().entrySet().stream().filter(SYMBOL_MATCHES_EXPRESSION.negate()).collect(ImmutableList.toImmutableList());
            Set newlyAssignedSymbols = (Set)nonIdentityAssignments.stream().map(Map.Entry::getKey).collect(ImmutableSet.toImmutableSet());
            List validUnderlyingEqualities = (List)IrUtils.extractConjuncts(underlyingPredicate).stream().filter(expression -> Sets.intersection(SymbolsExtractor.extractUnique(expression), (Set)newlyAssignedSymbols).isEmpty()).collect(ImmutableList.toImmutableList());
            List projectionEqualities = (List)nonIdentityAssignments.stream().filter(assignment -> Sets.intersection(SymbolsExtractor.extractUnique((Expression)assignment.getValue()), (Set)newlyAssignedSymbols).isEmpty()).map(ENTRY_TO_EQUALITY).collect(ImmutableList.toImmutableList());
            return this.pullExpressionThroughSymbols(IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().addAll((Iterable)projectionEqualities).addAll((Iterable)validUnderlyingEqualities).build()), node.getOutputSymbols());
        }

        @Override
        public Expression visitTopN(TopNNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitLimit(LimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitAssignUniqueId(AssignUniqueId node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitDistinctLimit(DistinctLimitNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitTableScan(TableScanNode node, Void context) {
            ImmutableBiMap assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
            TupleDomain<ColumnHandle> predicate = node.getEnforcedConstraint();
            if (this.useTableProperties) {
                predicate = this.metadata.getTableProperties(this.session, node.getTable()).getPredicate();
            }
            return this.domainTranslator.toPredicate((TupleDomain<Symbol>)predicate.simplify().filter((arg_0, arg_1) -> Visitor.lambda$visitTableScan$3((Map)assignments, arg_0, arg_1)).transformKeys(((Map)assignments)::get));
        }

        @Override
        public Expression visitSort(SortNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitWindow(WindowNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitPatternRecognition(PatternRecognitionNode node, Void context) {
            Expression sourcePredicate = node.getSource().accept(this, context);
            return this.pullExpressionThroughSymbols(sourcePredicate, node.getOutputSymbols());
        }

        @Override
        public Expression visitUnion(UnionNode node, Void context) {
            return this.deriveCommonPredicates(node, source -> node.outputSymbolMap((int)source).entries());
        }

        @Override
        public Expression visitUnnest(UnnestNode node, Void context) {
            return BooleanLiteral.TRUE_LITERAL;
        }

        @Override
        public Expression visitJoin(JoinNode node, Void context) {
            Expression leftPredicate = node.getLeft().accept(this, context);
            Expression rightPredicate = node.getRight().accept(this, context);
            List joinConjuncts = (List)node.getCriteria().stream().map(JoinNode.EquiJoinClause::toExpression).collect(ImmutableList.toImmutableList());
            return switch (node.getType()) {
                default -> throw new MatchException(null, null);
                case JoinType.INNER -> this.pullExpressionThroughSymbols(IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().add((Object)leftPredicate).add((Object)rightPredicate).add((Object)IrUtils.combineConjuncts(joinConjuncts)).add((Object)node.getFilter().orElse(BooleanLiteral.TRUE_LITERAL)).build()), node.getOutputSymbols());
                case JoinType.LEFT -> {
                    Predicate[] v1 = new Predicate[1];
                    v1[0] = node.getRight().getOutputSymbols()::contains;
                    Predicate[] v2 = new Predicate[1];
                    v2[0] = node.getRight().getOutputSymbols()::contains;
                    yield IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().add((Object)this.pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())).addAll(this.pullNullableConjunctsThroughOuterJoin(IrUtils.extractConjuncts(rightPredicate), node.getOutputSymbols(), v1)).addAll(this.pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), v2)).build());
                }
                case JoinType.RIGHT -> {
                    Predicate[] v3 = new Predicate[1];
                    v3[0] = node.getLeft().getOutputSymbols()::contains;
                    Predicate[] v4 = new Predicate[1];
                    v4[0] = node.getLeft().getOutputSymbols()::contains;
                    yield IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().add((Object)this.pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())).addAll(this.pullNullableConjunctsThroughOuterJoin(IrUtils.extractConjuncts(leftPredicate), node.getOutputSymbols(), v3)).addAll(this.pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), v4)).build());
                }
                case JoinType.FULL -> {
                    Predicate[] v5 = new Predicate[1];
                    v5[0] = node.getLeft().getOutputSymbols()::contains;
                    Predicate[] v6 = new Predicate[1];
                    v6[0] = node.getRight().getOutputSymbols()::contains;
                    Predicate[] v7 = new Predicate[2];
                    v7[0] = node.getLeft().getOutputSymbols()::contains;
                    v7[1] = node.getRight().getOutputSymbols()::contains;
                    yield IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().addAll(this.pullNullableConjunctsThroughOuterJoin(IrUtils.extractConjuncts(leftPredicate), node.getOutputSymbols(), v5)).addAll(this.pullNullableConjunctsThroughOuterJoin(IrUtils.extractConjuncts(rightPredicate), node.getOutputSymbols(), v6)).addAll(this.pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), v7)).build());
                }
            };
        }

        @Override
        public Expression visitValues(ValuesNode node, Void context) {
            if (node.getOutputSymbols().isEmpty()) {
                return BooleanLiteral.TRUE_LITERAL;
            }
            Preconditions.checkState((boolean)node.getRows().isPresent(), (Object)"rows is empty");
            boolean[] hasNull = new boolean[node.getOutputSymbols().size()];
            boolean[] hasNaN = new boolean[node.getOutputSymbols().size()];
            boolean[] nonDeterministic = new boolean[node.getOutputSymbols().size()];
            ImmutableList.Builder builders = ImmutableList.builder();
            for (int i = 0; i < node.getOutputSymbols().size(); ++i) {
                builders.add((Object)ImmutableList.builder());
            }
            ImmutableList valuesBuilders = builders.build();
            for (Expression row : node.getRows().get()) {
                if (row instanceof Row) {
                    for (int i = 0; i < node.getOutputSymbols().size(); ++i) {
                        Expression value = ((Row)row).getItems().get(i);
                        if (!DeterminismEvaluator.isDeterministic(value)) {
                            nonDeterministic[i] = true;
                            continue;
                        }
                        IrExpressionInterpreter interpreter = new IrExpressionInterpreter(value, this.plannerContext, this.session);
                        Object item = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
                        if (item instanceof Expression) {
                            return BooleanLiteral.TRUE_LITERAL;
                        }
                        if (item == null) {
                            hasNull[i] = true;
                            continue;
                        }
                        Type type = node.getOutputSymbols().get(i).getType();
                        if (!type.isComparable() && !type.isOrderable()) {
                            return BooleanLiteral.TRUE_LITERAL;
                        }
                        if (this.hasNestedNulls(type, item)) {
                            return BooleanLiteral.TRUE_LITERAL;
                        }
                        if (TypeUtils.isFloatingPointNaN((Type)type, (Object)item)) {
                            hasNaN[i] = true;
                        }
                        ((ImmutableList.Builder)valuesBuilders.get(i)).add(item);
                    }
                    continue;
                }
                if (!DeterminismEvaluator.isDeterministic(row)) {
                    return BooleanLiteral.TRUE_LITERAL;
                }
                IrExpressionInterpreter interpreter = new IrExpressionInterpreter(row, this.plannerContext, this.session);
                Object evaluated = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
                if (evaluated instanceof Expression) {
                    return BooleanLiteral.TRUE_LITERAL;
                }
                SqlRow sqlRow = (SqlRow)evaluated;
                int rawIndex = sqlRow.getRawIndex();
                for (int i = 0; i < node.getOutputSymbols().size(); ++i) {
                    Block fieldBlock;
                    Type type = node.getOutputSymbols().get(i).getType();
                    Object item = TypeUtils.readNativeValue((Type)type, (Block)(fieldBlock = sqlRow.getRawFieldBlock(i)), (int)rawIndex);
                    if (item == null) {
                        hasNull[i] = true;
                        continue;
                    }
                    if (!type.isComparable() && !type.isOrderable()) {
                        return BooleanLiteral.TRUE_LITERAL;
                    }
                    if (this.hasNestedNulls(type, item)) {
                        return BooleanLiteral.TRUE_LITERAL;
                    }
                    if (TypeUtils.isFloatingPointNaN((Type)type, (Object)item)) {
                        hasNaN[i] = true;
                    }
                    ((ImmutableList.Builder)valuesBuilders.get(i)).add(item);
                }
            }
            ImmutableMap.Builder domains = ImmutableMap.builder();
            for (int i = 0; i < node.getOutputSymbols().size(); ++i) {
                Symbol symbol = node.getOutputSymbols().get(i);
                Type type = symbol.getType();
                if (nonDeterministic[i]) continue;
                ImmutableList values = ((ImmutableList.Builder)valuesBuilders.get(i)).build();
                Domain domain = values.isEmpty() ? Domain.none((Type)type) : (hasNaN[i] ? Domain.notNull((Type)type) : Domain.multipleValues((Type)type, (List)values));
                if (hasNull[i]) {
                    domain = domain.union(Domain.onlyNull((Type)type));
                }
                domains.put((Object)symbol, (Object)domain);
            }
            return this.domainTranslator.toPredicate((TupleDomain<Symbol>)TupleDomain.withColumnDomains((Map)domains.buildOrThrow()).simplify());
        }

        private boolean hasNestedNulls(Type type, Object value) {
            block3: {
                block2: {
                    if (!(type instanceof RowType)) break block2;
                    RowType rowType = (RowType)type;
                    SqlRow sqlRow = (SqlRow)value;
                    int rawIndex = sqlRow.getRawIndex();
                    for (int i = 0; i < rowType.getFields().size(); ++i) {
                        Type elementType = ((RowType.Field)rowType.getFields().get(i)).getType();
                        Block fieldBlock = sqlRow.getRawFieldBlock(i);
                        if (!fieldBlock.isNull(rawIndex) && !this.elementHasNulls(elementType, fieldBlock, rawIndex)) continue;
                        return true;
                    }
                    break block3;
                }
                if (!(type instanceof ArrayType)) break block3;
                ArrayType arrayType = (ArrayType)type;
                Block container = (Block)value;
                Type elementType = arrayType.getElementType();
                for (int i = 0; i < container.getPositionCount(); ++i) {
                    if (!container.isNull(i) && !this.elementHasNulls(elementType, container, i)) continue;
                    return true;
                }
            }
            return false;
        }

        private boolean elementHasNulls(Type elementType, Block container, int position) {
            if (elementType instanceof RowType) {
                RowType rowType = (RowType)elementType;
                SqlRow element = rowType.getObject(container, position);
                return this.hasNestedNulls(elementType, element);
            }
            if (elementType instanceof ArrayType) {
                Block element = (Block)elementType.getObject(container, position);
                return this.hasNestedNulls(elementType, element);
            }
            return false;
        }

        @SafeVarargs
        private Iterable<Expression> pullNullableConjunctsThroughOuterJoin(List<Expression> conjuncts, Collection<Symbol> outputSymbols, Predicate<Symbol> ... nullSymbolScopes) {
            return (Iterable)conjuncts.stream().map(expression -> this.pullExpressionThroughSymbols((Expression)expression, outputSymbols)).map(expression -> SymbolsExtractor.extractAll(expression).isEmpty() ? BooleanLiteral.TRUE_LITERAL : expression).map(IrUtils.expressionOrNullSymbols(nullSymbolScopes)).collect(ImmutableList.toImmutableList());
        }

        @Override
        public Expression visitSemiJoin(SemiJoinNode node, Void context) {
            return node.getSource().accept(this, context);
        }

        @Override
        public Expression visitSpatialJoin(SpatialJoinNode node, Void context) {
            Expression leftPredicate = node.getLeft().accept(this, context);
            Expression rightPredicate = node.getRight().accept(this, context);
            return switch (node.getType()) {
                default -> throw new MatchException(null, null);
                case SpatialJoinNode.Type.INNER -> IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().add((Object)this.pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())).add((Object)this.pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())).build());
                case SpatialJoinNode.Type.LEFT -> {
                    Predicate[] v1 = new Predicate[1];
                    v1[0] = node.getRight().getOutputSymbols()::contains;
                    yield IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().add((Object)this.pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())).addAll(this.pullNullableConjunctsThroughOuterJoin(IrUtils.extractConjuncts(rightPredicate), node.getOutputSymbols(), v1)).build());
                }
            };
        }

        private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) {
            ArrayList<ImmutableSet> sourceOutputConjuncts = new ArrayList<ImmutableSet>();
            for (int i = 0; i < node.getSources().size(); ++i) {
                Expression underlyingPredicate = node.getSources().get(i).accept(this, null);
                List equalities = (List)mapping.apply(i).stream().filter(SYMBOL_MATCHES_EXPRESSION.negate()).map(ENTRY_TO_EQUALITY).collect(ImmutableList.toImmutableList());
                sourceOutputConjuncts.add(ImmutableSet.copyOf(IrUtils.extractConjuncts(this.pullExpressionThroughSymbols(IrUtils.combineConjuncts((Collection<Expression>)ImmutableList.builder().addAll((Iterable)equalities).add((Object)underlyingPredicate).build()), node.getOutputSymbols()))));
            }
            Iterator iterator = sourceOutputConjuncts.iterator();
            Set potentialOutputConjuncts = (Set)iterator.next();
            while (iterator.hasNext()) {
                potentialOutputConjuncts = Sets.intersection((Set)potentialOutputConjuncts, (Set)((Set)iterator.next()));
            }
            return IrUtils.combineConjuncts(potentialOutputConjuncts);
        }

        private Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) {
            EqualityInference equalityInference = new EqualityInference(expression);
            ImmutableList.Builder effectiveConjuncts = ImmutableList.builder();
            ImmutableSet scope = ImmutableSet.copyOf(symbols);
            EqualityInference.nonInferrableConjuncts(expression).forEach(arg_0 -> Visitor.lambda$pullExpressionThroughSymbols$7(equalityInference, (Set)scope, effectiveConjuncts, arg_0));
            effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy((Set<Symbol>)scope).getScopeEqualities());
            return IrUtils.combineConjuncts((Collection<Expression>)effectiveConjuncts.build());
        }

        private static /* synthetic */ void lambda$pullExpressionThroughSymbols$7(EqualityInference equalityInference, Set scope, ImmutableList.Builder effectiveConjuncts, Expression conjunct) {
            Expression rewritten;
            if (DeterminismEvaluator.isDeterministic(conjunct) && (rewritten = equalityInference.rewrite(conjunct, scope)) != null) {
                effectiveConjuncts.add((Object)rewritten);
            }
        }

        private static /* synthetic */ boolean lambda$visitTableScan$3(Map assignments, ColumnHandle columnHandle, Domain domain) {
            return assignments.containsKey(columnHandle);
        }
    }
}

