/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.project;

import com.google.common.collect.ImmutableList;
import io.trino.operator.project.InputChannels;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.Expressions;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.RowExpressionVisitor;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;

/**
 * Rewrites the page field references ({@link InputReferenceExpression}) of a row expression into
 * densely numbered input parameters, and reports which page channels the rewritten expression reads.
 *
 * <p>Parameters are assigned in the order fields are first encountered during a depth-first,
 * left-to-right traversal of the expression; the same field always maps to the same parameter.
 * The returned {@link InputChannels} lists, for each parameter index, the original channel it reads,
 * together with the channels that are referenced on an unconditionally evaluated path.
 */
public final class PageFieldsToInputParametersRewriter {
    private PageFieldsToInputParametersRewriter() {
    }

    /**
     * Rewrites every field reference in {@code expression} to a dense parameter index.
     *
     * @param expression the expression to rewrite; its root is treated as unconditionally evaluated
     * @return the rewritten expression together with the channels it reads
     * @throws IllegalArgumentException if the expression contains a special form this rewriter does not support
     */
    public static Result rewritePageFieldsToInputParameters(RowExpression expression) {
        Visitor visitor = new Visitor();
        RowExpression rewrittenProjection = expression.accept(visitor, true);
        InputChannels inputChannels = new InputChannels(visitor.getInputChannels(), visitor.getEagerlyLoadedChannels());
        return new Result(rewrittenProjection, inputChannels);
    }

    /**
     * Rewrites field references and records which fields are reached on an unconditionally evaluated path.
     *
     * <p>The {@code Boolean} context is {@code true} while every enclosing node is known to evaluate the
     * visited node. It is propagated as follows:
     * <ul>
     *   <li>call arguments inherit it, unless some argument of the call is a lambda, in which case all
     *       arguments of that call are visited with {@code false};</li>
     *   <li>for IF, SWITCH, BETWEEN, AND, OR and COALESCE only the first argument inherits it; all later
     *       arguments are visited with {@code false};</li>
     *   <li>BIND, IN, WHEN, IS_NULL, NULL_IF, DEREFERENCE and ROW_CONSTRUCTOR pass it to all arguments;</li>
     *   <li>lambda bodies inherit it.</li>
     * </ul>
     *
     * <p>A fresh instance is created for each rewrite; it accumulates state in unsynchronized collections.
     */
    private static final class Visitor
            implements RowExpressionVisitor<RowExpression, Boolean> {
        // Field (channel) index -> assigned input parameter index.
        private final Map<Integer, Integer> fieldToParameter = new HashMap<>();
        // Channels in parameter order: inputChannels.get(p) is the field read by parameter p.
        private final List<Integer> inputChannels = new ArrayList<>();
        // Channels referenced at least once on an unconditionally evaluated path.
        private final Set<Integer> eagerlyLoadedChannels = new HashSet<>();
        // Next parameter index to hand out; always equal to inputChannels.size().
        private int nextParameter;

        private Visitor() {
        }

        /**
         * Returns an immutable snapshot of the referenced channels, indexed by assigned parameter.
         */
        public List<Integer> getInputChannels() {
            return ImmutableList.copyOf(inputChannels);
        }

        /**
         * Returns an immutable snapshot of the channels referenced on an unconditionally evaluated path.
         * The element order is the iteration order of a {@link HashSet} and is therefore unspecified.
         */
        public List<Integer> getEagerlyLoadedChannels() {
            return ImmutableList.copyOf(eagerlyLoadedChannels);
        }

        /**
         * Replaces a page field reference with a reference to its input parameter.
         */
        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, Boolean unconditionallyEvaluated) {
            if (unconditionallyEvaluated) {
                eagerlyLoadedChannels.add(reference.getField());
            }
            int parameter = getParameterForField(reference);
            return Expressions.field(parameter, reference.getType());
        }

        /**
         * Returns the parameter index for the referenced field, assigning the next free index on first use.
         */
        private int getParameterForField(InputReferenceExpression reference) {
            return fieldToParameter.computeIfAbsent(reference.getField(), field -> {
                // First sighting of this field: remember which channel the new parameter reads.
                inputChannels.add(field);
                return nextParameter++;
            });
        }

        /**
         * Rewrites every argument of the call, preserving argument order.
         */
        @Override
        public RowExpression visitCall(CallExpression call, Boolean unconditionallyEvaluated) {
            // NOTE(review): any lambda argument makes the whole call's arguments conditional; presumably
            // because such functions need not evaluate all inputs -- confirm the original rationale.
            boolean containsLambdaExpression = call.getArguments().stream()
                    .anyMatch(LambdaDefinitionExpression.class::isInstance);
            boolean argumentsUnconditionallyEvaluated = unconditionallyEvaluated && !containsLambdaExpression;

            List<RowExpression> rewrittenArguments = call.getArguments().stream()
                    .map(argument -> argument.accept(this, argumentsUnconditionallyEvaluated))
                    .collect(ImmutableList.toImmutableList());
            return new CallExpression(call.getResolvedFunction(), rewrittenArguments);
        }

        /**
         * Rewrites the arguments of a supported special form, preserving argument order.
         *
         * @throws IllegalArgumentException if the special form is not one of the supported forms
         */
        @Override
        public RowExpression visitSpecialForm(SpecialForm specialForm, Boolean unconditionallyEvaluated) {
            List<RowExpression> arguments = specialForm.getArguments();
            List<RowExpression> rewrittenArguments;
            switch (specialForm.getForm()) {
                case IF:
                case SWITCH:
                case BETWEEN:
                case AND:
                case OR:
                case COALESCE:
                    // Only the first operand is treated as unconditionally evaluated; the remaining
                    // operands are conservatively treated as conditional.
                    rewrittenArguments = IntStream.range(0, arguments.size())
                            .mapToObj(index -> arguments.get(index).accept(this, index == 0 && unconditionallyEvaluated))
                            .collect(ImmutableList.toImmutableList());
                    break;
                case BIND:
                case IN:
                case WHEN:
                case IS_NULL:
                case NULL_IF:
                case DEREFERENCE:
                case ROW_CONSTRUCTOR:
                    rewrittenArguments = arguments.stream()
                            .map(argument -> argument.accept(this, unconditionallyEvaluated))
                            .collect(ImmutableList.toImmutableList());
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported special form " + specialForm.getForm());
            }
            return new SpecialForm(specialForm.getForm(), specialForm.getType(), rewrittenArguments, specialForm.getFunctionDependencies());
        }

        /**
         * Constants reference no fields and are returned unchanged.
         */
        @Override
        public RowExpression visitConstant(ConstantExpression literal, Boolean unconditionallyEvaluated) {
            return literal;
        }

        /**
         * Rewrites the lambda body with the inherited context; the lambda's arguments are kept as-is.
         */
        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Boolean unconditionallyEvaluated) {
            return new LambdaDefinitionExpression(lambda.getArguments(), lambda.getBody().accept(this, unconditionallyEvaluated));
        }

        /**
         * Variable references are not page fields and are returned unchanged.
         */
        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Boolean unconditionallyEvaluated) {
            return reference;
        }
    }

    /**
     * Outcome of {@link #rewritePageFieldsToInputParameters(RowExpression)}: the rewritten expression and
     * the channels it reads.
     */
    public static class Result {
        private final RowExpression rewrittenExpression;
        private final InputChannels inputChannels;

        public Result(RowExpression rewrittenExpression, InputChannels inputChannels) {
            this.rewrittenExpression = rewrittenExpression;
            this.inputChannels = inputChannels;
        }

        /**
         * Returns the expression whose field references now point at dense parameter indexes.
         */
        public RowExpression getRewrittenExpression() {
            return rewrittenExpression;
        }

        /**
         * Returns the channels read by the rewritten expression, indexed by parameter.
         */
        public InputChannels getInputChannels() {
            return inputChannels;
        }
    }
}

