/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.sql.ir.Bind;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.Lambda;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/**
 * Desugars implicit lambda captures into explicit {@link Bind} expressions.
 *
 * <p>A lambda whose body references symbols other than its own arguments is rewritten so that
 * each captured symbol becomes a fresh leading lambda argument, and the captured values are
 * supplied through a {@code Bind}:
 *
 * <pre>{@code
 * x -> f(x, captureSymbol)
 *     becomes
 * bind(captureSymbol, (extraSymbol, x) -> f(x, extraSymbol))
 * }</pre>
 */
public final class LambdaCaptureDesugaringRewriter {
    /**
     * Rewrites every lambda in {@code expression} so that it no longer captures symbols implicitly.
     *
     * @param expression the expression to rewrite
     * @param symbolAllocator allocator used to create the fresh argument symbols that replace captures
     * @return the rewritten expression
     */
    public static Expression rewrite(Expression expression, SymbolAllocator symbolAllocator) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(symbolAllocator), expression, new Context());
    }

    private LambdaCaptureDesugaringRewriter() {
    }

    private static class Visitor
            extends ExpressionRewriter<Context> {
        private final SymbolAllocator symbolAllocator;

        public Visitor(SymbolAllocator symbolAllocator) {
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        /**
         * Rewrites the lambda body, then turns every symbol the body references (other than the
         * lambda's own arguments) into a fresh leading argument bound via {@link Bind}.
         */
        @Override
        public Expression rewriteLambda(Lambda node, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            // Collect the body's referenced symbols into a fresh insertion-ordered set so that the
            // generated arguments and bound values come out in a deterministic order.
            LinkedHashSet<Symbol> referencedSymbols = new LinkedHashSet<>();
            Expression rewrittenBody = treeRewriter.rewrite(node.body(), context.withReferencedSymbols(referencedSymbols));

            List<Symbol> lambdaArguments = node.arguments();
            // Symbols referenced in the body that are not the lambda's own arguments are captures.
            // Snapshot the difference: a live SetView is recomputed on every traversal and size() call.
            Set<Symbol> captureSymbols = Sets.difference(referencedSymbols, ImmutableSet.copyOf(lambdaArguments)).immutableCopy();

            ImmutableMap.Builder<Symbol, Symbol> captureSymbolToExtraSymbol = ImmutableMap.builder();
            ImmutableList.Builder<Symbol> newLambdaArguments = ImmutableList.builder();
            for (Symbol captureSymbol : captureSymbols) {
                Symbol extraSymbol = symbolAllocator.newSymbol(captureSymbol.name(), captureSymbol.type());
                captureSymbolToExtraSymbol.put(captureSymbol, extraSymbol);
                newLambdaArguments.add(extraSymbol);
            }
            // The fresh capture arguments come first, followed by the lambda's original arguments.
            newLambdaArguments.addAll(lambdaArguments);

            // Inside the body, replace each captured symbol with its fresh argument; leave all others as-is.
            Map<Symbol, Symbol> symbolsMap = captureSymbolToExtraSymbol.buildOrThrow();
            Function<Symbol, Expression> symbolMapping = symbol -> symbolsMap.getOrDefault(symbol, symbol).toSymbolReference();
            Lambda lambda = new Lambda(newLambdaArguments.build(), ExpressionSymbolInliner.inlineSymbols(symbolMapping, rewrittenBody));

            Expression rewrittenExpression = lambda;
            if (!captureSymbols.isEmpty()) {
                // Bind the current values of the captured symbols to the new leading arguments.
                List<Expression> capturedValues = captureSymbols.stream()
                        .<Expression>map(symbol -> new Reference(symbol.type(), symbol.name()))
                        .collect(ImmutableList.toImmutableList());
                rewrittenExpression = new Bind(capturedValues, lambda);
            }

            // Propagate the captures outward so an enclosing lambda also treats them as referenced.
            context.getReferencedSymbols().addAll(captureSymbols);
            return rewrittenExpression;
        }

        /**
         * Records the referenced symbol in the current scope's set. Returns {@code null},
         * presumably deferring to the tree rewriter's default handling (node left unchanged).
         * NOTE(review): confirm against the ExpressionRewriter contract.
         */
        @Override
        public Expression rewriteReference(Reference node, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            context.getReferencedSymbols().add(new Symbol(node.type(), node.name()));
            return null;
        }
    }

    /**
     * Rewrite context carrying the set into which referenced symbols of the current lambda scope
     * are collected.
     */
    private static class Context {
        // Insertion-ordered so that the derived capture order is deterministic.
        private final LinkedHashSet<Symbol> referencedSymbols;

        public Context() {
            this(new LinkedHashSet<>());
        }

        private Context(LinkedHashSet<Symbol> referencedSymbols) {
            this.referencedSymbols = referencedSymbols;
        }

        public LinkedHashSet<Symbol> getReferencedSymbols() {
            return referencedSymbols;
        }

        /** Returns a new context that collects referenced symbols into {@code symbols}. */
        public Context withReferencedSymbols(LinkedHashSet<Symbol> symbols) {
            return new Context(symbols);
        }
    }
}

