/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WhenClause;
import io.trino.sql.util.AstUtils;
import jakarta.annotation.Nullable;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Rewrites a correlated {@code ApplyNode} whose only subquery assignment is an
 * {@code IN} predicate ({@code value IN (subquery)}) into an uncorrelated plan.
 *
 * <p>The resulting plan shape is:
 * <pre>
 * Project[apply input symbols..., inPredicateOutputSymbol := CASE ... END]
 *   Aggregation[group by probe side outputs]
 *       countMatches     := count FILTER (matchConditionSymbol)
 *       countNullMatches := count FILTER (nullMatchConditionSymbol)
 *     Project[join outputs..., matchConditionSymbol, nullMatchConditionSymbol]
 *       LeftJoin[(probe IS NULL OR probe = build OR build IS NULL) AND correlatedPredicates]
 *         AssignUniqueId[apply input]
 *         Project[decorrelated subquery outputs..., buildSideKnownNonNull := CAST(0 AS bigint)]
 * </pre>
 * The final CASE reproduces SQL three-valued {@code IN} semantics:
 * {@code true} if some build row equals the probe value, otherwise {@code NULL}
 * if some correlated build row compared with a NULL on either side, otherwise
 * {@code false}.
 *
 * <p>The rule only fires when the subquery can be decorrelated by
 * {@code DecorrelatingVisitor} (pulling {@code Filter} predicates up through
 * uncorrelated {@code Project} nodes); otherwise it returns an empty result.
 */
public class TransformCorrelatedInPredicateToJoin
        implements Rule<ApplyNode> {
    // Matches only ApplyNodes that carry at least one correlation symbol.
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode().with(Pattern.nonEmpty(Patterns.Apply.correlation()));
    private final Metadata metadata;

    /**
     * @param metadata used to resolve the builtin {@code count} function; must not be null
     */
    public TransformCorrelatedInPredicateToJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Entry point: accepts only applies with exactly one subquery assignment that is an
     * {@code InPredicate}; every other shape yields {@code Rule.Result.empty()}.
     */
    @Override
    public Rule.Result apply(ApplyNode apply, Captures captures, Rule.Context context) {
        Assignments subqueryAssignments = apply.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Rule.Result.empty();
        }
        Expression assignmentExpression = Iterables.getOnlyElement(subqueryAssignments.getExpressions());
        if (!(assignmentExpression instanceof InPredicate)) {
            return Rule.Result.empty();
        }
        InPredicate inPredicate = (InPredicate) assignmentExpression;
        Symbol inPredicateOutputSymbol = Iterables.getOnlyElement(subqueryAssignments.getSymbols());
        return this.apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator());
    }

    /**
     * Decorrelates the subquery and, if that succeeds, replaces the apply with its
     * join/aggregation equivalent.
     */
    private Rule.Result apply(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation())
                .decorrelate(apply.getSubquery());
        if (decorrelated.isEmpty()) {
            return Rule.Result.empty();
        }
        PlanNode projection = this.buildInPredicateEquivalent(apply, inPredicate, inPredicateOutputSymbol, decorrelated.get(), idAllocator, symbolAllocator);
        return Rule.Result.ofPlanNode(projection);
    }

    /**
     * Builds the plan described in the class documentation.
     *
     * <p>NOTE: the order of {@code idAllocator}/{@code symbolAllocator} calls below determines the
     * generated node ids and symbol names; keep it stable.
     */
    private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
        Expression correlationCondition = ExpressionUtils.and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        // Probe side: the apply input with an extra unique-id symbol. Grouping by all probe
        // outputs below is presumably meant to keep every input row (even duplicates) in its own
        // group - relies on AssignUniqueId producing distinct values; TODO confirm.
        AssignUniqueId probeSide = new AssignUniqueId(
                idAllocator.getNextId(),
                apply.getInput(),
                symbolAllocator.newSymbol("unique", BigintType.BIGINT));

        // Build side: the decorrelated subquery plus a constant marker column. After the LEFT
        // join the marker is NULL exactly for null-padded (unmatched) probe rows, which lets us
        // tell "matched a build row holding NULL" apart from "matched nothing".
        Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BigintType.BIGINT);
        ProjectNode buildSide = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putIdentities(decorrelatedBuildSource.getOutputSymbols())
                        .put(buildSideKnownNonNull, bigint(0L))
                        .build());

        // Assumes the IN value and value list are plain symbol references - TODO confirm that
        // Symbol.from rejects anything else.
        Symbol probeSideSymbol = Symbol.from(inPredicate.getValue());
        Symbol buildSideSymbol = Symbol.from(inPredicate.getValueList());

        // Keep every correlated build row that is either equal to the probe value or could make
        // the comparison NULL (a NULL on either side).
        Expression joinExpression = ExpressionUtils.and(
                ExpressionUtils.or(
                        new IsNullPredicate(probeSideSymbol.toSymbolReference()),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeSideSymbol.toSymbolReference(), buildSideSymbol.toSymbolReference()),
                        new IsNullPredicate(buildSideSymbol.toSymbolReference())),
                correlationCondition);
        JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);

        // For a joined row, "both sides non-null" leaves equality as the only satisfied
        // disjunct of the join condition, i.e. a definite match.
        Symbol matchConditionSymbol = symbolAllocator.newSymbol("matchConditionSymbol", BooleanType.BOOLEAN);
        Expression matchCondition = ExpressionUtils.and(isNotNull(probeSideSymbol), isNotNull(buildSideSymbol));

        // A real build row was joined (marker non-null) but one side was NULL: the
        // comparison for that row is NULL rather than true.
        Symbol nullMatchConditionSymbol = symbolAllocator.newSymbol("nullMatchConditionSymbol", BooleanType.BOOLEAN);
        Expression nullMatchCondition = ExpressionUtils.and(isNotNull(buildSideKnownNonNull), not(matchCondition));

        ProjectNode preProjection = new ProjectNode(
                idAllocator.getNextId(),
                leftOuterJoin,
                Assignments.builder()
                        .putIdentities(leftOuterJoin.getOutputSymbols())
                        .put(matchConditionSymbol, matchCondition)
                        .put(nullMatchConditionSymbol, nullMatchCondition)
                        .build());

        Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BigintType.BIGINT);
        Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BigintType.BIGINT);
        AggregationNode aggregation = AggregationNode.singleAggregation(
                idAllocator.getNextId(),
                preProjection,
                ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
                        .put(countMatchesSymbol, this.countWithFilter(matchConditionSymbol))
                        .put(countNullMatchesSymbol, this.countWithFilter(nullMatchConditionSymbol))
                        .buildOrThrow(),
                AggregationNode.singleGroupingSet(probeSide.getOutputSymbols()));

        // Three-valued IN: any definite match -> true; else any NULL comparison -> NULL; else false
        // (this also covers an empty correlated subquery).
        Expression inPredicateEquivalent = new SearchedCaseExpression(
                ImmutableList.of(
                        new WhenClause(isGreaterThan(countMatchesSymbol, 0L), booleanConstant(true)),
                        new WhenClause(isGreaterThan(countNullMatchesSymbol, 0L), booleanConstant(null))),
                Optional.of(booleanConstant(false)));

        // Expose only the original input symbols plus the IN result, as the apply did.
        return new ProjectNode(
                idAllocator.getNextId(),
                aggregation,
                Assignments.builder()
                        .putIdentities(apply.getInput().getOutputSymbols())
                        .put(inPredicateOutputSymbol, inPredicateEquivalent)
                        .build());
    }

    /**
     * Creates a LEFT join with no equi-join clauses: the whole condition is carried as the join
     * filter, and all probe-side and build-side symbols are output.
     */
    private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, Expression joinExpression) {
        return new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                probeSide,
                buildSide,
                ImmutableList.of(),
                probeSide.getOutputSymbols(),
                buildSide.getOutputSymbols(),
                false,
                Optional.of(joinExpression),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());
    }

    /**
     * Returns an argument-less {@code count} aggregation restricted by {@code filter}
     * (presumably {@code count(*) FILTER (WHERE filter)}).
     */
    private AggregationNode.Aggregation countWithFilter(Symbol filter) {
        return new AggregationNode.Aggregation(
                this.metadata.resolveBuiltinFunction("count", ImmutableList.of()),
                ImmutableList.of(),
                false,
                Optional.of(filter),
                Optional.empty(),
                Optional.empty());
    }

    /** Returns {@code symbol > CAST(value AS bigint)}. */
    private static Expression isGreaterThan(Symbol symbol, long value) {
        return new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, symbol.toSymbolReference(), bigint(value));
    }

    /** Returns {@code NOT booleanExpression}. */
    private static Expression not(Expression booleanExpression) {
        return new NotExpression(booleanExpression);
    }

    /** Returns {@code symbol IS NOT NULL}. */
    private static Expression isNotNull(Symbol symbol) {
        return new IsNotNullPredicate(symbol.toSymbolReference());
    }

    /** Returns {@code CAST(value AS bigint)}. */
    private static Expression bigint(long value) {
        return new Cast(new LongLiteral(String.valueOf(value)), TypeSignatureTranslator.toSqlType(BigintType.BIGINT));
    }

    /**
     * Returns a boolean literal for {@code value}, or {@code CAST(NULL AS boolean)} when
     * {@code value} is null.
     */
    private static Expression booleanConstant(@Nullable Boolean value) {
        if (value == null) {
            return new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(BooleanType.BOOLEAN));
        }
        return new BooleanLiteral(value.toString());
    }

    /**
     * Removes correlation from a subquery by pulling {@code Filter} predicates (which may
     * reference correlation symbols) up through {@code Project} nodes. Any other node that
     * references a correlation symbol, directly or in its sources, makes decorrelation fail
     * ({@code Optional.empty()}).
     */
    private static class DecorrelatingVisitor
            extends PlanVisitor<Optional<Decorrelated>, PlanNode> {
        private final Lookup lookup;
        private final Set<Symbol> correlation;

        public DecorrelatingVisitor(Lookup lookup, Iterable<Symbol> correlation) {
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(Objects.requireNonNull(correlation, "correlation is null"));
        }

        /** Resolves {@code reference} through the lookup and visits the resolved node. */
        public Optional<Decorrelated> decorrelate(PlanNode reference) {
            return this.lookup.resolve(reference).accept(this, reference);
        }

        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference) {
            if (this.isCorrelatedShallowly(node)) {
                // A projection that itself references correlation symbols is not handled.
                return Optional.empty();
            }
            return this.decorrelate(node.getSource()).map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder().putAll(node.getAssignments());
                // Predicates pulled up from below may reference symbols this projection would
                // otherwise drop: pass each such non-correlation symbol through unchanged so the
                // predicates remain evaluable above the rebuilt projection.
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(AstUtils::preOrder)
                        .filter(SymbolReference.class::isInstance)
                        .map(SymbolReference.class::cast)
                        .map(symbolReference -> Symbol.from(symbolReference))
                        .filter(symbol -> !this.correlation.contains(symbol))
                        .forEach(assignments::putIdentity);
                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(node.getId(), decorrelated.getDecorrelatedNode(), assignments.build()));
            });
        }

        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference) {
            // The filter node is dropped and its entire predicate (correlated and uncorrelated
            // conjuncts alike) is pulled up into the correlated predicate list.
            return this.decorrelate(node.getSource()).map(decorrelated -> new Decorrelated(
                    ImmutableList.<Expression>builder()
                            .addAll(decorrelated.getCorrelatedPredicates())
                            .add(node.getPredicate())
                            .build(),
                    decorrelated.getDecorrelatedNode()));
        }

        @Override
        protected Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference) {
            // Any other node is kept as-is, provided neither it nor anything below it is correlated.
            if (this.isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            return Optional.of(new Decorrelated(ImmutableList.of(), reference));
        }

        /** True if {@code node} or any (resolved) descendant references a correlation symbol. */
        private boolean isCorrelatedRecursively(PlanNode node) {
            if (this.isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(this.lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /** True if the symbols extracted from {@code node} itself (non-recursively) include a correlation symbol. */
        private boolean isCorrelatedShallowly(PlanNode node) {
            return SymbolsExtractor.extractUniqueNonRecursive(node).stream().anyMatch(this.correlation::contains);
        }
    }

    /**
     * Result of decorrelation: the predicates that were pulled out of the subquery (to be
     * applied as join conditions) and the remaining, correlation-free plan.
     */
    private static class Decorrelated {
        private final List<Expression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<Expression> correlatedPredicates, PlanNode decorrelatedNode) {
            this.correlatedPredicates = ImmutableList.copyOf(Objects.requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = Objects.requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        public List<Expression> getCorrelatedPredicates() {
            return this.correlatedPredicates;
        }

        public PlanNode getDecorrelatedNode() {
            return this.decorrelatedNode;
        }
    }
}

