/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableMap;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.LongLiteral;
import io.trino.sql.ir.SubscriptExpression;
import io.trino.sql.planner.IrTypeAnalyzer;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.ApplyNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;

/**
 * Optimizer rule that unwraps single-field {@link RowType} operands of {@code IN} and
 * quantified-comparison subquery assignments on an {@link ApplyNode}.
 *
 * <p>For every such assignment whose value has a row type with exactly one field, the field is
 * extracted with a {@code [1]} subscript on both sides of the comparison. The input side is
 * extracted by a projection placed under the apply node's input. The subquery side is extracted
 * by a projection placed under its subquery. The set expression is then rebuilt to compare the
 * two extracted symbols. A projection on top restricts the result to the original node's output
 * symbols.
 *
 * <p>Other assignments are passed through unchanged. These include assignments of other kinds
 * and those whose value is not a single-field row. If no assignment was unwrapped, the rule
 * returns an empty result.
 */
public class UnwrapSingleColumnRowInApply
implements Rule<ApplyNode> {
    private static final Pattern<ApplyNode> PATTERN = Patterns.applyNode();
    private final IrTypeAnalyzer typeAnalyzer;

    public UnwrapSingleColumnRowInApply(IrTypeAnalyzer typeAnalyzer) {
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<ApplyNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(ApplyNode node, Captures captures, Rule.Context context) {
        // Start from identity projections so every existing symbol stays visible; the extracted
        // field symbols are appended to these builders as assignments get unwrapped.
        Assignments.Builder inputAssignments = Assignments.builder()
                .putIdentities(node.getInput().getOutputSymbols());
        Assignments.Builder nestedPlanAssignments = Assignments.builder()
                .putIdentities(node.getSubquery().getOutputSymbols());
        ImmutableMap.Builder<Symbol, ApplyNode.SetExpression> applyAssignments = ImmutableMap.builder();
        boolean applied = false;

        for (Map.Entry<Symbol, ApplyNode.SetExpression> assignment : node.getSubqueryAssignments().entrySet()) {
            Optional<Unwrapping> unwrapped = unwrap(assignment.getValue(), context);
            if (unwrapped.isPresent()) {
                Unwrapping unwrapping = unwrapped.get();
                inputAssignments.add(unwrapping.getInputAssignment());
                nestedPlanAssignments.add(unwrapping.getNestedPlanAssignment());
                applyAssignments.put(assignment.getKey(), unwrapping.getExpression());
                applied = true;
            }
            else {
                applyAssignments.put(assignment);
            }
        }

        if (!applied) {
            return Rule.Result.empty();
        }

        // Ids are allocated in the same order as before: outer projection, input projection,
        // subquery projection (Java evaluates constructor arguments left to right).
        return Rule.Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                new ApplyNode(
                        node.getId(),
                        new ProjectNode(context.getIdAllocator().getNextId(), node.getInput(), inputAssignments.build()),
                        new ProjectNode(context.getIdAllocator().getNextId(), node.getSubquery(), nestedPlanAssignments.build()),
                        applyAssignments.buildOrThrow(),
                        node.getCorrelation(),
                        node.getOriginSubquery()),
                // Restrict the output to the symbols the original apply node exposed.
                Assignments.identity(node.getOutputSymbols())));
    }

    /**
     * Attempts to unwrap a single subquery assignment.
     *
     * @param expression the set expression of the assignment
     * @param context the rule context, used for type analysis and symbol allocation
     * @return the unwrapping for an {@code In} or {@code QuantifiedComparison} whose value is a
     *         single-field row; empty for any other expression
     */
    private Optional<Unwrapping> unwrap(ApplyNode.SetExpression expression, Rule.Context context) {
        if (expression instanceof ApplyNode.In) {
            ApplyNode.In predicate = (ApplyNode.In) expression;
            return unwrapSingleColumnRow(
                    context,
                    predicate.value().toSymbolReference(),
                    predicate.reference().toSymbolReference(),
                    ApplyNode.In::new);
        }
        if (expression instanceof ApplyNode.QuantifiedComparison) {
            ApplyNode.QuantifiedComparison comparison = (ApplyNode.QuantifiedComparison) expression;
            // Keep the original operator and quantifier; only the operands are replaced.
            return unwrapSingleColumnRow(
                    context,
                    comparison.value().toSymbolReference(),
                    comparison.reference().toSymbolReference(),
                    (value, reference) -> new ApplyNode.QuantifiedComparison(comparison.operator(), comparison.quantifier(), value, reference));
        }
        return Optional.empty();
    }

    /**
     * Builds the unwrapping for a comparison between {@code value} and {@code list} when
     * {@code value} is a row with exactly one field.
     *
     * @param context the rule context, used for type analysis and symbol allocation
     * @param value the input-side operand
     * @param list the subquery-side operand
     * @param function rebuilds the set expression from the new input-side and subquery-side symbols
     * @return the unwrapping, or empty if {@code value} is not a single-field row
     */
    private Optional<Unwrapping> unwrapSingleColumnRow(Rule.Context context, Expression value, Expression list, BiFunction<Symbol, Symbol, ApplyNode.SetExpression> function) {
        Type type = this.typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), value);
        if (!(type instanceof RowType)) {
            return Optional.empty();
        }
        RowType rowType = (RowType) type;
        if (rowType.getFields().size() != 1) {
            return Optional.empty();
        }

        // NOTE(review): the subquery-side symbol reuses the field type of `value`, which assumes
        // both operands share the same row type — confirm the analyzer guarantees this.
        Type elementType = rowType.getTypeParameters().get(0);
        Symbol valueSymbol = context.getSymbolAllocator().newSymbol("input", elementType);
        Symbol listSymbol = context.getSymbolAllocator().newSymbol("subquery", elementType);

        // Subscript 1 selects the row's only field.
        Assignments.Assignment inputAssignment = new Assignments.Assignment(valueSymbol, new SubscriptExpression(value, new LongLiteral(1L)));
        Assignments.Assignment nestedPlanAssignment = new Assignments.Assignment(listSymbol, new SubscriptExpression(list, new LongLiteral(1L)));
        ApplyNode.SetExpression comparison = function.apply(valueSymbol, listSymbol);
        return Optional.of(new Unwrapping(comparison, inputAssignment, nestedPlanAssignment));
    }

    /**
     * Result of unwrapping one assignment. It holds the rewritten set expression, the
     * projection that extracts the input-side field, and the projection that extracts the
     * subquery-side field.
     */
    private static final class Unwrapping {
        private final ApplyNode.SetExpression expression;
        private final Assignments.Assignment inputAssignment;
        private final Assignments.Assignment nestedPlanAssignment;

        public Unwrapping(ApplyNode.SetExpression expression, Assignments.Assignment inputAssignment, Assignments.Assignment nestedPlanAssignment) {
            this.expression = Objects.requireNonNull(expression, "expression is null");
            this.inputAssignment = Objects.requireNonNull(inputAssignment, "inputAssignment is null");
            this.nestedPlanAssignment = Objects.requireNonNull(nestedPlanAssignment, "nestedPlanAssignment is null");
        }

        public ApplyNode.SetExpression getExpression() {
            return this.expression;
        }

        public Assignments.Assignment getInputAssignment() {
            return this.inputAssignment;
        }

        public Assignments.Assignment getNestedPlanAssignment() {
            return this.nestedPlanAssignment;
        }
    }
}

