/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.Iterables;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.sql.planner.DeterminismEvaluator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.Cardinality;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.Row;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;

/**
 * Replaces an unconditional join with a {@link ProjectNode}. An unconditional join has no join
 * criteria, and its filter is either absent or {@code TRUE}. The rule applies when one side is
 * a single deterministic constant row. The resulting projection runs over the other side and
 * appends the constant values as assignments.
 *
 * <p>A constant side may only be inlined when that cannot lose rows the join would produce.
 * Sometimes the join preserves the constant side: the left side of a LEFT join, the right side
 * of a RIGHT join, or either side of a FULL join. In that case the constant row would be emitted
 * even if the opposite side were empty. So the opposite side must be reported as producing at
 * least one row ({@code isAtLeastScalar()}).
 */
public class ReplaceJoinOverConstantWithProject
implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(ReplaceJoinOverConstantWithProject::isUnconditional);
    private final Metadata metadata;

    public ReplaceJoinOverConstantWithProject(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts to rewrite the join into a projection over one of its sources.
     *
     * <p>Returns an empty result if either side's extracted cardinality reports
     * {@code isEmpty()}, or if neither side can be inlined for this join type. When both sides
     * qualify, the left side is inlined (projected over the right source).
     */
    @Override
    public Result apply(JoinNode node, Captures captures, Context context) {
        Cardinality leftCardinality = QueryCardinalityUtil.extractCardinality(node.getLeft(), context.getLookup());
        if (leftCardinality.isEmpty()) {
            return Result.empty();
        }
        Cardinality rightCardinality = QueryCardinalityUtil.extractCardinality(node.getRight(), context.getLookup());
        if (rightCardinality.isEmpty()) {
            return Result.empty();
        }

        PlanNode left = context.getLookup().resolve(node.getLeft());
        PlanNode right = context.getLookup().resolve(node.getRight());
        boolean canInlineLeftSource = canInlineJoinSource(left);
        boolean canInlineRightSource = canInlineJoinSource(right);

        // A side the join preserves (outer side) may only be inlined when the opposite side has
        // at least one row; otherwise its null-padded output row would be lost.
        boolean inlineLeft = switch (node.getType()) {
            case INNER, RIGHT -> canInlineLeftSource;
            case LEFT, FULL -> canInlineLeftSource && rightCardinality.isAtLeastScalar();
        };
        boolean inlineRight = switch (node.getType()) {
            case INNER, LEFT -> canInlineRightSource;
            case RIGHT, FULL -> canInlineRightSource && leftCardinality.isAtLeastScalar();
        };

        if (inlineLeft) {
            return Result.ofPlanNode(appendProjection(right, node.getRightOutputSymbols(), left, node.getLeftOutputSymbols(), context.getIdAllocator()));
        }
        if (inlineRight) {
            return Result.ofPlanNode(appendProjection(left, node.getLeftOutputSymbols(), right, node.getRightOutputSymbols(), context.getIdAllocator()));
        }
        return Result.empty();
    }

    /**
     * Returns whether the join has no criteria and its filter is either absent or the
     * {@code TRUE} literal.
     */
    private static boolean isUnconditional(JoinNode joinNode) {
        return joinNode.getCriteria().isEmpty()
                && joinNode.getFilter().map(filter -> filter.equals(BooleanLiteral.TRUE_LITERAL)).orElse(true);
    }

    /**
     * Returns whether the source is a single constant row that exposes at least one output
     * symbol, i.e. whether its values can be appended as projection assignments.
     */
    private boolean canInlineJoinSource(PlanNode source) {
        return isSingleConstantRow(source) && !source.getOutputSymbols().isEmpty();
    }

    /**
     * Returns whether the node is a {@link ValuesNode} with a row count of one. If that node has
     * no row expressions, it qualifies as is. If it has row expressions, the single row must be a
     * deterministic {@link Row} literal.
     */
    private boolean isSingleConstantRow(PlanNode node) {
        if (!(node instanceof ValuesNode values)) {
            return false;
        }
        if (values.getRowCount() != 1) {
            return false;
        }
        if (values.getRows().isEmpty()) {
            return true;
        }
        Expression row = Iterables.getOnlyElement(values.getRows().get());
        if (!DeterminismEvaluator.isDeterministic(row, this.metadata)) {
            return false;
        }
        return row instanceof Row;
    }

    /**
     * Builds a projection over {@code source}. The projection passes {@code sourceOutputs}
     * through unchanged. It then assigns each symbol in {@code constantOutputs} the matching item
     * of the single row in {@code constantSource}.
     *
     * <p>Callers must only pass a {@code constantSource} accepted by {@link #canInlineJoinSource}.
     * That source is cast to {@link ValuesNode}, and its single row is cast to {@link Row}.
     */
    private ProjectNode appendProjection(PlanNode source, List<Symbol> sourceOutputs, PlanNode constantSource, List<Symbol> constantOutputs, PlanNodeIdAllocator idAllocator) {
        ValuesNode values = (ValuesNode) constantSource;
        Row row = (Row) Iterables.getOnlyElement(values.getRows().orElseThrow());
        List<Symbol> valuesOutputs = values.getOutputSymbols();

        // Values output symbols correspond positionally to the row's items.
        var mapping = new HashMap<Symbol, Expression>();
        for (int i = 0; i < valuesOutputs.size(); i++) {
            mapping.put(valuesOutputs.get(i), row.getItems().get(i));
        }

        Assignments.Builder assignments = Assignments.builder().putIdentities(sourceOutputs);
        for (Symbol symbol : constantOutputs) {
            assignments.put(symbol, mapping.get(symbol));
        }
        return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
    }
}

