/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeDecorrelator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.ExistsPredicate;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.QualifiedName;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites an {@link ApplyNode} whose only subquery assignment is an {@link ExistsPredicate}
 * into a {@link CorrelatedJoinNode}.
 * <p>
 * Two rewrites are used:
 * <ul>
 *   <li><b>Non-default aggregation</b> (tried first when the apply node has correlation symbols):
 *   the subquery is capped with {@code LIMIT 1} and projected to a single {@code TRUE} marker
 *   symbol, then LEFT-joined to the input. The EXISTS symbol becomes
 *   {@code COALESCE(marker, FALSE)}. This is only used if {@code PlanNodeDecorrelator}
 *   can decorrelate the capped subquery.</li>
 *   <li><b>Default aggregation</b> (fallback, and always used when the correlation list is empty):
 *   the subquery is wrapped in a single-step global {@code count()} aggregation and INNER-joined
 *   to the input. The EXISTS symbol becomes {@code count > CAST(0 AS BIGINT)}.</li>
 * </ul>
 */
public class TransformExistsApplyToCorrelatedJoin
implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode();
    /** Aggregate function used by the default-aggregation rewrite. */
    private static final QualifiedName COUNT = QualifiedName.of("count");
    private final Metadata metadata;

    /**
     * Creates the rule.
     *
     * @param metadata used to resolve the {@code count} function and handed to the decorrelator
     * @throws NullPointerException if {@code metadata} is null
     */
    public TransformExistsApplyToCorrelatedJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code parent} when it carries exactly one subquery assignment and that assignment
     * is an EXISTS predicate. Otherwise the rule does not fire.
     */
    @Override
    public Rule.Result apply(ApplyNode parent, Captures captures, Rule.Context context) {
        if (parent.getSubqueryAssignments().size() != 1) {
            return Rule.Result.empty();
        }
        Expression expression = Iterables.getOnlyElement(parent.getSubqueryAssignments().getExpressions());
        if (!(expression instanceof ExistsPredicate)) {
            return Rule.Result.empty();
        }
        // An empty correlation list goes straight to the count-based rewrite.
        // NOTE(review): presumably because the subquery may still reference symbols from a
        // further-outer scope that this node does not list — confirm before changing.
        if (parent.getCorrelation().isEmpty()) {
            return Rule.Result.ofPlanNode(rewriteToDefaultAggregation(parent, context));
        }
        return rewriteToNonDefaultAggregation(parent, context)
                .map(Rule.Result::ofPlanNode)
                .orElseGet(() -> Rule.Result.ofPlanNode(rewriteToDefaultAggregation(parent, context)));
    }

    /**
     * Attempts the LEFT-join rewrite.
     * <p>
     * The subquery is capped with {@code LIMIT 1} and projected to a {@code TRUE} marker symbol.
     * The result is a projection that keeps every input symbol and defines the EXISTS symbol as
     * {@code COALESCE(marker, FALSE)}.
     *
     * @return the rewritten plan, or empty if the capped subquery cannot be decorrelated for the
     *         apply node's correlation symbols
     * @throws IllegalStateException if the subquery still exposes output symbols
     */
    private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, Rule.Context context) {
        Preconditions.checkState(applyNode.getSubquery().getOutputSymbols().isEmpty(), "Expected subquery output symbols to be pruned");
        Symbol subqueryTrue = context.getSymbolAllocator().newSymbol("subqueryTrue", (Type) BooleanType.BOOLEAN);
        // One row is enough to decide EXISTS, so cap the subquery before projecting the marker.
        // The outer ProjectNode id is allocated before the LimitNode id, as in the original code.
        ProjectNode subquery = new ProjectNode(
                context.getIdAllocator().getNextId(),
                new LimitNode(context.getIdAllocator().getNextId(), applyNode.getSubquery(), 1L, false),
                Assignments.of(subqueryTrue, BooleanLiteral.TRUE_LITERAL));

        PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(metadata, context.getSymbolAllocator(), context.getLookup());
        // The decorrelation result is only used as a gate for whether this rewrite applies.
        if (decorrelator.decorrelateFilters(subquery, applyNode.getCorrelation()).isEmpty()) {
            return Optional.empty();
        }

        Symbol exists = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());
        // Input rows with no matching subquery row get a NULL marker from the LEFT join;
        // COALESCE turns that into FALSE.
        List<Expression> existsOperands = ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL);
        Assignments assignments = Assignments.builder()
                .putIdentities(applyNode.getInput().getOutputSymbols())
                .put(exists, new CoalesceExpression(existsOperands))
                .build();

        return Optional.of(new ProjectNode(
                context.getIdAllocator().getNextId(),
                new CorrelatedJoinNode(
                        applyNode.getId(),
                        applyNode.getInput(),
                        subquery,
                        applyNode.getCorrelation(),
                        CorrelatedJoinNode.Type.LEFT,
                        BooleanLiteral.TRUE_LITERAL,
                        applyNode.getOriginSubquery()),
                assignments));
    }

    /**
     * Builds the count-based rewrite.
     * <p>
     * The subquery is wrapped in a single-step global {@code count()} aggregation and projected to
     * {@code exists := count > CAST(0 AS BIGINT)}. That projection is INNER-joined to the input and
     * reuses the apply node's id and correlation list.
     * NOTE(review): the INNER join presumably relies on a global aggregation always yielding
     * exactly one row — confirm.
     */
    private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Rule.Context context) {
        List<TypeSignatureProvider> noArgumentTypes = ImmutableList.of();
        ResolvedFunction countFunction = metadata.resolveFunction(context.getSession(), COUNT, noArgumentTypes);
        // Use one type for both the count symbol and the literal it is compared against.
        Type countType = BigintType.BIGINT;
        Symbol count = context.getSymbolAllocator().newSymbol(COUNT.toString(), countType);
        Symbol exists = Iterables.getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());

        Map<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.of(
                count,
                new AggregationNode.Aggregation(countFunction, ImmutableList.of(), false, Optional.empty(), Optional.empty(), Optional.empty()));
        // EXISTS holds when the subquery produced at least one row.
        Expression existsExpression = new ComparisonExpression(
                ComparisonExpression.Operator.GREATER_THAN,
                count.toSymbolReference(),
                new Cast(new LongLiteral("0"), TypeSignatureTranslator.toSqlType(countType)));

        // The ProjectNode id is allocated before the AggregationNode id, as in the original code.
        return new CorrelatedJoinNode(
                applyNode.getId(),
                applyNode.getInput(),
                new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        new AggregationNode(
                                context.getIdAllocator().getNextId(),
                                applyNode.getSubquery(),
                                aggregations,
                                AggregationNode.globalAggregation(),
                                ImmutableList.of(),
                                AggregationNode.Step.SINGLE,
                                Optional.empty(),
                                Optional.empty()),
                        Assignments.of(exists, existsExpression)),
                applyNode.getCorrelation(),
                CorrelatedJoinNode.Type.INNER,
                BooleanLiteral.TRUE_LITERAL,
                applyNode.getOriginSubquery());
    }
}

