/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.trino.sql.ir.BindExpression;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.LambdaExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeProvider;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

public final class LambdaCaptureDesugaringRewriter {
    public static Expression rewrite(Expression expression, TypeProvider symbolTypes, SymbolAllocator symbolAllocator) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(symbolTypes, symbolAllocator), expression, new Context());
    }

    private LambdaCaptureDesugaringRewriter() {
    }

    private static class Visitor
    extends ExpressionRewriter<Context> {
        private final TypeProvider symbolTypes;
        private final SymbolAllocator symbolAllocator;

        public Visitor(TypeProvider symbolTypes, SymbolAllocator symbolAllocator) {
            this.symbolTypes = Objects.requireNonNull(symbolTypes, "symbolTypes is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        }

        @Override
        public Expression rewriteLambdaExpression(LambdaExpression node, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            LinkedHashSet<Symbol> referencedSymbols = new LinkedHashSet<Symbol>();
            Expression rewrittenBody = treeRewriter.rewrite(node.getBody(), context.withReferencedSymbols(referencedSymbols));
            List lambdaArguments = (List)node.getArguments().stream().map(Symbol::new).collect(ImmutableList.toImmutableList());
            Sets.SetView captureSymbols = Sets.difference(referencedSymbols, (Set)ImmutableSet.copyOf((Collection)lambdaArguments));
            ImmutableMap.Builder captureSymbolToExtraSymbol = ImmutableMap.builder();
            ImmutableList.Builder newLambdaArguments = ImmutableList.builder();
            for (Symbol captureSymbol : captureSymbols) {
                Symbol extraSymbol = this.symbolAllocator.newSymbol(captureSymbol.getName(), this.symbolTypes.get(captureSymbol));
                captureSymbolToExtraSymbol.put((Object)captureSymbol, (Object)extraSymbol);
                newLambdaArguments.add((Object)extraSymbol.getName());
            }
            newLambdaArguments.addAll(node.getArguments());
            ImmutableMap symbolsMap = captureSymbolToExtraSymbol.buildOrThrow();
            Function<Symbol, Expression> symbolMapping = symbol -> ((Symbol)symbolsMap.getOrDefault(symbol, symbol)).toSymbolReference();
            Expression rewrittenExpression = new LambdaExpression((List<String>)newLambdaArguments.build(), ExpressionSymbolInliner.inlineSymbols(symbolMapping, rewrittenBody));
            if (captureSymbols.size() != 0) {
                List capturedValues = (List)captureSymbols.stream().map(symbol -> new SymbolReference(symbol.getName())).collect(ImmutableList.toImmutableList());
                rewrittenExpression = new BindExpression(capturedValues, rewrittenExpression);
            }
            context.getReferencedSymbols().addAll((Collection<Symbol>)captureSymbols);
            return rewrittenExpression;
        }

        @Override
        public Expression rewriteSymbolReference(SymbolReference node, Context context, ExpressionTreeRewriter<Context> treeRewriter) {
            context.getReferencedSymbols().add(new Symbol(node.getName()));
            return null;
        }
    }

    private static class Context {
        final LinkedHashSet<Symbol> referencedSymbols;

        public Context() {
            this(new LinkedHashSet<Symbol>());
        }

        private Context(LinkedHashSet<Symbol> referencedSymbols) {
            this.referencedSymbols = referencedSymbols;
        }

        public LinkedHashSet<Symbol> getReferencedSymbols() {
            return this.referencedSymbols;
        }

        public Context withReferencedSymbols(LinkedHashSet<Symbol> symbols) {
            return new Context(symbols);
        }
    }
}

