/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.Metadata;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.FieldReference;
import io.trino.sql.ir.Row;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.type.UnknownType;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

/**
 * Expression rule set that simplifies a field reference whose base is a row constructor,
 * optionally wrapped in a chain of row-typed {@code CAST} / {@code $try_cast} conversions.
 * <p>
 * Example: referencing field {@code i} of {@code CAST(ROW(a, b) AS ROW(T1, T2))} becomes
 * {@code CAST(item_i AS Ti)}. The selected row item is extracted directly. Then the
 * conversion that each peeled-off cast applied to that field is re-applied to the item,
 * innermost first. {@code $try_cast} levels are re-applied as {@code $try_cast} and plain
 * casts as {@code CAST}. A level whose target field type is {@link UnknownType} is dropped
 * rather than re-applied.
 */
public class UnwrapRowSubscript
extends ExpressionRewriteRuleSet {
    public UnwrapRowSubscript(PlannerContext context) {
        super((Expression expression, Rule.Context ruleContext) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(context.getMetadata()), expression));
    }

    private static class Rewriter
    extends ExpressionRewriter<Void> {
        private final Metadata metadata;

        public Rewriter(Metadata metadata) {
            this.metadata = metadata;
        }

        /**
         * Rewrites a field reference, unwrapping it when its (rewritten) base is a row
         * constructor reachable through row-to-row casts only.
         *
         * @param node the field reference being rewritten
         * @param context unused rewrite context
         * @param treeRewriter the driving tree rewriter, used to rewrite the base first
         * @return the unwrapped field expression, or the field reference with a rewritten base
         */
        @Override
        public Expression rewriteSubscript(FieldReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            Expression rewrittenBase = treeRewriter.rewrite(node.base(), context);

            // Peel row-to-row conversions off the base and record, for each level, the target
            // type of the referenced field. push() adds to the front, so the innermost
            // conversion ends up on top and is re-applied first.
            Deque<Coercion> coercions = new ArrayDeque<>();
            Expression base = rewrittenBase;
            while (base.type() instanceof RowType rowType) {
                Expression operand;
                boolean safe;
                if (base instanceof Cast cast) {
                    operand = cast.expression();
                    safe = false;
                }
                else if (base instanceof Call call && call.function().name().equals(GlobalFunctionCatalog.builtinFunctionName("$try_cast"))) {
                    operand = call.arguments().getFirst();
                    safe = true;
                }
                else {
                    break;
                }

                Type fieldType = rowType.getFields().get(node.field()).getType();
                if (!(fieldType instanceof UnknownType)) {
                    coercions.push(new Coercion(fieldType, safe));
                }
                base = operand;
            }

            if (base instanceof Row row) {
                Expression result = row.items().get(node.field());
                while (!coercions.isEmpty()) {
                    Coercion coercion = coercions.pop();
                    if (coercion.safe()) {
                        result = new Call(
                                metadata.getCoercion(GlobalFunctionCatalog.builtinFunctionName("$try_cast"), result.type(), coercion.type()),
                                ImmutableList.of(result));
                    }
                    else {
                        result = new Cast(result, coercion.type());
                    }
                }
                return result;
            }

            // Nothing to unwrap. Reuse the base rewritten above rather than delegating to
            // treeRewriter.defaultRewrite(node, ...). That call would rewrite node.base() a
            // second time, doubling the work at every level of nested field references.
            if (rewrittenBase == node.base()) {
                return node;
            }
            return new FieldReference(rewrittenBase, node.field());
        }
    }

    /**
     * One field-level conversion peeled off a row cast.
     *
     * @param type target type of the referenced field at that cast level
     * @param safe {@code true} if the level was a {@code $try_cast}, {@code false} for a plain {@code CAST}
     */
    private record Coercion(Type type, boolean safe) {
    }
}

