/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SetOperationNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FrameBound;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WindowFrame;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Rewrites a non-union {@link SetOperationNode} into a plan built from a {@link UnionNode}
 * of its sources, where each source row is tagged with per-source boolean marker columns,
 * followed by either an aggregation or a window that applies {@code count} to every marker.
 *
 * <p>The resulting plan and the symbols carrying the per-source counts (and, for the
 * window-based variant, the row number) are returned as a {@link TranslationResult}.
 */
public class SetOperationNodeTranslator {
    private static final String MARKER = "marker";
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final ResolvedFunction countFunction;
    private final ResolvedFunction rowNumberFunction;

    /**
     * Creates a translator and eagerly resolves the {@code count(boolean)} and
     * {@code row_number()} functions used by the generated plans.
     *
     * @param session session passed to {@link Metadata#resolveFunction}
     * @param metadata metadata used to resolve the functions; must not be null
     * @param symbolAllocator allocator for new symbols; must not be null
     * @param idAllocator allocator for new plan node ids; must not be null
     * @throws NullPointerException if {@code symbolAllocator}, {@code idAllocator} or {@code metadata} is null
     */
    public SetOperationNodeTranslator(Session session, Metadata metadata, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "SymbolAllocator is null");
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        Objects.requireNonNull(metadata, "metadata is null");
        this.countFunction = metadata.resolveFunction(session, QualifiedName.of("count"), TypeSignatureProvider.fromTypes(BooleanType.BOOLEAN));
        this.rowNumberFunction = metadata.resolveFunction(session, QualifiedName.of("row_number"), ImmutableList.of());
    }

    /**
     * Builds the aggregation-based plan: sources are projected with markers, unioned, and
     * grouped by the set operation's output symbols with one {@code count} per marker.
     *
     * @param node the set operation to translate; must not be a {@link UnionNode}
     * @return the aggregation node together with the count symbols (no row number symbol)
     * @throws IllegalArgumentException if {@code node} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlanForDistinct(SetOperationNode node) {
        Preconditions.checkArgument(!(node instanceof UnionNode), "Cannot simplify a UnionNode");
        List<Symbol> markers = allocateSymbols(node.getSources().size(), MARKER, BooleanType.BOOLEAN);
        UnionNode union = unionWithMarkers(node, markers);
        List<Symbol> aggregationOutputs = allocateSymbols(markers.size(), "count", BigintType.BIGINT);
        AggregationNode aggregation = computeCounts(union, node.getOutputSymbols(), markers, aggregationOutputs);
        return new TranslationResult(aggregation, aggregationOutputs);
    }

    /**
     * Builds the window-based plan: sources are projected with markers, unioned, and a window
     * partitioned by the set operation's output symbols computes one {@code count} per marker
     * plus a {@code row_number}. A final identity projection keeps the original outputs, the
     * counts and the row number.
     *
     * @param node the set operation to translate; must not be a {@link UnionNode}
     * @return the projection node together with the count symbols and the row number symbol
     * @throws IllegalArgumentException if {@code node} is a {@link UnionNode}
     */
    public TranslationResult makeSetContainmentPlanForAll(SetOperationNode node) {
        Preconditions.checkArgument(!(node instanceof UnionNode), "Cannot simplify a UnionNode");
        List<Symbol> markers = allocateSymbols(node.getSources().size(), MARKER, BooleanType.BOOLEAN);
        UnionNode union = unionWithMarkers(node, markers);
        List<Symbol> outputs = node.getOutputSymbols();
        List<Symbol> countOutputs = allocateSymbols(markers.size(), "count", BigintType.BIGINT);
        Symbol rowNumberSymbol = symbolAllocator.newSymbol("row_number", BigintType.BIGINT);
        WindowNode window = appendCounts(union, outputs, markers, countOutputs, rowNumberSymbol);
        List<Symbol> projected = ImmutableList.copyOf(Iterables.concat(outputs, countOutputs, ImmutableList.of(rowNumberSymbol)));
        ProjectNode project = new ProjectNode(idAllocator.getNextId(), window, Assignments.identity(projected));
        return new TranslationResult(project, countOutputs, Optional.of(rowNumberSymbol));
    }

    /**
     * Projects every source of {@code node} with the marker columns and unions the results.
     * The union's outputs are the set operation's output symbols followed by the markers.
     */
    private UnionNode unionWithMarkers(SetOperationNode node, List<Symbol> markers) {
        List<PlanNode> withMarkers = appendMarkers(markers, node.getSources(), node);
        List<Symbol> unionOutputs = ImmutableList.copyOf(Iterables.concat(node.getOutputSymbols(), markers));
        return union(withMarkers, unionOutputs);
    }

    /** Allocates {@code count} fresh symbols of the given type, all using the same name hint. */
    private List<Symbol> allocateSymbols(int count, String nameHint, Type type) {
        ImmutableList.Builder<Symbol> symbols = ImmutableList.builder();
        for (int i = 0; i < count; ++i) {
            symbols.add(symbolAllocator.newSymbol(nameHint, type));
        }
        return symbols.build();
    }

    /**
     * Wraps each source in a projection that adds the marker columns; the source at
     * position {@code i} uses {@code node.sourceSymbolMap(i)} for its column mapping.
     */
    private List<PlanNode> appendMarkers(List<Symbol> markers, List<PlanNode> nodes, SetOperationNode node) {
        ImmutableList.Builder<PlanNode> result = ImmutableList.builder();
        for (int i = 0; i < nodes.size(); ++i) {
            result.add(appendMarkers(idAllocator, symbolAllocator, nodes.get(i), i, markers, node.sourceSymbolMap(i)));
        }
        return result.build();
    }

    /**
     * Builds a projection over {@code source} that re-exposes every entry of
     * {@code projections} under a freshly allocated symbol and then appends one boolean
     * column per marker: {@code TRUE} at {@code markerIndex}, {@code CAST(NULL AS boolean)}
     * everywhere else.
     */
    private static PlanNode appendMarkers(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, PlanNode source, int markerIndex, List<Symbol> markers, Map<Symbol, SymbolReference> projections) {
        Assignments.Builder assignments = Assignments.builder();
        // Carry over the existing set-operation columns under new symbols of the same type.
        for (Map.Entry<Symbol, SymbolReference> entry : projections.entrySet()) {
            Symbol symbol = symbolAllocator.newSymbol(entry.getKey().getName(), symbolAllocator.getTypes().get(entry.getKey()));
            assignments.put(symbol, entry.getValue());
        }
        // Append the marker columns. The local must be typed as Expression: one branch is a
        // BooleanLiteral and the other a Cast, which share only the Expression supertype.
        for (int i = 0; i < markers.size(); ++i) {
            Expression expression = (i == markerIndex)
                    ? BooleanLiteral.TRUE_LITERAL
                    : new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(BooleanType.BOOLEAN));
            assignments.put(symbolAllocator.newSymbol(markers.get(i).getName(), BooleanType.BOOLEAN), expression);
        }
        return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
    }

    /**
     * Creates a {@link UnionNode} over {@code nodes}, mapping each union output to the
     * source output symbol found at the same position in every source.
     */
    private UnionNode union(List<PlanNode> nodes, List<Symbol> outputs) {
        ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputs = ImmutableListMultimap.builder();
        for (PlanNode source : nodes) {
            List<Symbol> sourceOutputs = source.getOutputSymbols();
            for (int i = 0; i < sourceOutputs.size(); ++i) {
                outputsToInputs.put(outputs.get(i), sourceOutputs.get(i));
            }
        }
        return new UnionNode(idAllocator.getNextId(), nodes, outputsToInputs.build(), outputs);
    }

    /**
     * Creates a single-step aggregation grouped by {@code originalColumns} that applies the
     * resolved {@code count} function to each marker, writing to the matching entry of
     * {@code aggregationOutputs}.
     */
    private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> aggregationOutputs) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (int i = 0; i < markers.size(); ++i) {
            Symbol output = aggregationOutputs.get(i);
            aggregations.put(output, new AggregationNode.Aggregation(
                    countFunction,
                    ImmutableList.of(markers.get(i).toSymbolReference()),
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty()));
        }
        return AggregationNode.singleAggregation(idAllocator.getNextId(), sourceNode, aggregations.buildOrThrow(), AggregationNode.singleGroupingSet(originalColumns));
    }

    /**
     * Creates a window partitioned by {@code originalColumns} (no ordering) that applies the
     * resolved {@code count} function to each marker and {@code row_number} into
     * {@code rowNumberSymbol}. All functions share a ROWS frame spanning
     * UNBOUNDED PRECEDING to UNBOUNDED FOLLOWING.
     */
    private WindowNode appendCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> countOutputs, Symbol rowNumberSymbol) {
        ImmutableMap.Builder<Symbol, WindowNode.Function> functions = ImmutableMap.builder();
        WindowNode.Frame defaultFrame = new WindowNode.Frame(
                WindowFrame.Type.ROWS,
                FrameBound.Type.UNBOUNDED_PRECEDING,
                Optional.empty(),
                Optional.empty(),
                FrameBound.Type.UNBOUNDED_FOLLOWING,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        for (int i = 0; i < markers.size(); ++i) {
            Symbol output = countOutputs.get(i);
            functions.put(output, new WindowNode.Function(countFunction, ImmutableList.of(markers.get(i).toSymbolReference()), defaultFrame, false));
        }
        functions.put(rowNumberSymbol, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), defaultFrame, false));
        return new WindowNode(
                idAllocator.getNextId(),
                sourceNode,
                new WindowNode.Specification(originalColumns, Optional.empty()),
                functions.buildOrThrow(),
                Optional.empty(),
                ImmutableSet.of(),
                0);
    }

    /**
     * Outcome of a translation: the generated plan, the symbols holding the per-source
     * counts, and optionally the symbol holding the row number.
     */
    public static class TranslationResult {
        private final PlanNode planNode;
        private final List<Symbol> countSymbols;
        private final Optional<Symbol> rowNumberSymbol;

        /** Creates a result without a row number symbol. */
        public TranslationResult(PlanNode planNode, List<Symbol> countSymbols) {
            this(planNode, countSymbols, Optional.empty());
        }

        /**
         * @param planNode root of the generated plan; must not be null
         * @param countSymbols count symbols; must not be null, copied defensively into an immutable list
         * @param rowNumberSymbol row number symbol if the plan produces one; must not be null
         * @throws NullPointerException if any argument is null
         */
        public TranslationResult(PlanNode planNode, List<Symbol> countSymbols, Optional<Symbol> rowNumberSymbol) {
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
            this.countSymbols = ImmutableList.copyOf(Objects.requireNonNull(countSymbols, "countSymbols is null"));
            this.rowNumberSymbol = Objects.requireNonNull(rowNumberSymbol, "rowNumberSymbol is null");
        }

        /** Returns the root of the generated plan. */
        public PlanNode getPlanNode() {
            return planNode;
        }

        /** Returns the immutable list of count symbols. */
        public List<Symbol> getCountSymbols() {
            return countSymbols;
        }

        /**
         * Returns the row number symbol.
         *
         * @throws IllegalStateException if this result was created without a row number symbol
         */
        public Symbol getRowNumberSymbol() {
            Preconditions.checkState(rowNumberSymbol.isPresent(), "rowNumberSymbol is empty");
            return rowNumberSymbol.get();
        }
    }
}

