/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Range;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.spi.StandardErrorCode;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.PlanNodeSearcher;
import io.prestosql.sql.planner.optimizations.QueryCardinalityUtil;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.EnforceSingleRowNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Cast;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.QualifiedName;
import io.prestosql.sql.tree.SimpleCaseExpression;
import io.prestosql.sql.tree.StringLiteral;
import io.prestosql.sql.tree.WhenClause;
import java.util.List;
import java.util.Optional;

/**
 * Removes the {@link EnforceSingleRowNode} from the subquery of a correlated {@link LateralJoinNode}
 * and replaces the single-row guarantee it provided.
 *
 * <p>The rule matches lateral joins that have a non-empty correlation list and a filter equal to
 * {@code TRUE}. It only fires when an {@link EnforceSingleRowNode} is reachable from the subquery
 * root through {@link ProjectNode}s only. The first such node is removed, and then one of two
 * strategies is used:
 * <ul>
 *   <li>If the cardinality of the remaining subquery is known to lie within {@code [0, 1]}, no
 *       runtime check is planned. The join keeps its original type when the cardinality is exactly
 *       one row. Otherwise it becomes a {@code LEFT} join, so that input rows without a subquery
 *       row are still emitted.</li>
 *   <li>Otherwise the single-row requirement is checked at runtime. Each input row is tagged with
 *       a fresh {@code unique} symbol via {@link AssignUniqueId}, and the subquery is
 *       {@code LEFT}-joined. A {@link MarkDistinctNode} over the join-input columns then flags
 *       repeated input rows. A filter lets distinct rows through and evaluates
 *       {@code fail(SUBQUERY_MULTIPLE_ROWS, ...)} for any repeat. The final projection restores the
 *       original lateral join's output symbols.</li>
 * </ul>
 */
public class TransformCorrelatedScalarSubquery
implements Rule<LateralJoinNode> {
    private static final Pattern<LateralJoinNode> PATTERN = Patterns.lateralJoin()
            .with(Pattern.nonEmpty(Patterns.LateralJoin.correlation()))
            .with(Patterns.LateralJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL));

    @Override
    public Pattern<LateralJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code lateralJoinNode} as described in the class documentation.
     *
     * @param lateralJoinNode the matched correlated lateral join
     * @param captures unused; the pattern captures nothing
     * @param context rule context providing lookup, id allocator and symbol allocator
     * @return {@link Rule.Result#empty()} when no {@link EnforceSingleRowNode} is reachable through
     *     projections only; otherwise the rewritten plan
     */
    @Override
    public Rule.Result apply(LateralJoinNode lateralJoinNode, Captures captures, Rule.Context context) {
        PlanNode subquery = context.getLookup().resolve(lateralJoinNode.getSubquery());

        // Guard first: only a node reachable purely through projections may be removed.
        if (!searchEnforceSingleRow(subquery, context).matches()) {
            return Rule.Result.empty();
        }
        PlanNode rewrittenSubquery = searchEnforceSingleRow(subquery, context).removeFirst();

        Range<Long> subqueryCardinality = QueryCardinalityUtil.extractCardinality(rewrittenSubquery, context.getLookup());
        if (Range.closed(0L, 1L).encloses(subqueryCardinality)) {
            // At most one row is statically guaranteed, so no runtime check is needed.
            boolean producesSingleRow = Range.singleton(1L).encloses(subqueryCardinality);
            LateralJoinNode.Type joinType = producesSingleRow ? lateralJoinNode.getType() : LateralJoinNode.Type.LEFT;
            return Rule.Result.ofPlanNode(new LateralJoinNode(
                    context.getIdAllocator().getNextId(),
                    lateralJoinNode.getInput(),
                    rewrittenSubquery,
                    lateralJoinNode.getCorrelation(),
                    joinType,
                    lateralJoinNode.getFilter(),
                    lateralJoinNode.getOriginSubquery()));
        }

        return Rule.Result.ofPlanNode(enforceSingleRowAtRuntime(lateralJoinNode, rewrittenSubquery, context));
    }

    /** Builds a fresh searcher for an EnforceSingleRowNode reachable from {@code subquery} via ProjectNodes only. */
    private static PlanNodeSearcher searchEnforceSingleRow(PlanNode subquery, Rule.Context context) {
        return PlanNodeSearcher.searchFrom(subquery, context.getLookup())
                .where(EnforceSingleRowNode.class::isInstance)
                .recurseOnlyWhen(ProjectNode.class::isInstance);
    }

    /**
     * Builds the plan that checks at runtime that {@code rewrittenSubquery} yields at most one row
     * per input row of {@code lateralJoinNode}.
     *
     * <p>Plan-node ids are allocated in the order: lateral join, AssignUniqueId, MarkDistinct,
     * Filter, Project.
     */
    private static PlanNode enforceSingleRowAtRuntime(LateralJoinNode lateralJoinNode, PlanNode rewrittenSubquery, Rule.Context context) {
        Symbol unique = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);
        // Java evaluates arguments left to right, so the join id is allocated before the AssignUniqueId id.
        LateralJoinNode rewrittenLateralJoinNode = new LateralJoinNode(
                context.getIdAllocator().getNextId(),
                new AssignUniqueId(context.getIdAllocator().getNextId(), lateralJoinNode.getInput(), unique),
                rewrittenSubquery,
                lateralJoinNode.getCorrelation(),
                LateralJoinNode.Type.LEFT,
                lateralJoinNode.getFilter(),
                lateralJoinNode.getOriginSubquery());

        // Distinctness is computed over the join input's output symbols (the AssignUniqueId node,
        // presumably including `unique`). A repeat therefore means the subquery matched an input row
        // more than once.
        Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BooleanType.BOOLEAN);
        MarkDistinctNode markDistinctNode = new MarkDistinctNode(
                context.getIdAllocator().getNextId(),
                rewrittenLateralJoinNode,
                isDistinct,
                rewrittenLateralJoinNode.getInput().getOutputSymbols(),
                Optional.empty());

        FilterNode filterNode = new FilterNode(
                context.getIdAllocator().getNextId(),
                markDistinctNode,
                failUnlessDistinct(isDistinct));

        // Project back to the original output, dropping the helper symbols.
        return new ProjectNode(
                context.getIdAllocator().getNextId(),
                filterNode,
                Assignments.identity(lateralJoinNode.getOutputSymbols()));
    }

    /**
     * Returns {@code CASE isDistinct WHEN TRUE THEN TRUE ELSE CAST(fail(code, message) AS boolean) END}.
     * The {@code fail} call is expected to raise SUBQUERY_MULTIPLE_ROWS when evaluated.
     */
    private static Expression failUnlessDistinct(Symbol isDistinct) {
        Expression failCall = new FunctionCall(
                QualifiedName.of("fail"),
                ImmutableList.<Expression>of(
                        new LongLiteral(Integer.toString(StandardErrorCode.SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode())),
                        new StringLiteral("Scalar sub-query has returned multiple rows")));
        return new SimpleCaseExpression(
                isDistinct.toSymbolReference(),
                ImmutableList.of(new WhenClause(BooleanLiteral.TRUE_LITERAL, BooleanLiteral.TRUE_LITERAL)),
                Optional.<Expression>of(new Cast(failCall, "boolean")));
    }
}

