/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.DistinctLimitNode;
import io.trino.sql.planner.plan.GroupIdNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.StatisticAggregations;
import io.trino.sql.planner.plan.StatisticsWriterNode;
import io.trino.sql.planner.plan.TableFinishNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.planner.plan.TopNNode;
import io.trino.sql.planner.plan.TopNRankingNode;
import io.trino.sql.planner.plan.WindowNode;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.SymbolReference;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

public class SymbolMapper {
    private final Function<Symbol, Symbol> mappingFunction;

    public SymbolMapper(Function<Symbol, Symbol> mappingFunction) {
        this.mappingFunction = Objects.requireNonNull(mappingFunction, "mappingFunction is null");
    }

    public static SymbolMapper symbolMapper(Map<Symbol, Symbol> mapping) {
        return new SymbolMapper(symbol -> {
            while (mapping.containsKey(symbol) && !((Symbol)mapping.get(symbol)).equals(symbol)) {
                symbol = (Symbol)mapping.get(symbol);
            }
            return symbol;
        });
    }

    public static SymbolMapper symbolReallocator(Map<Symbol, Symbol> mapping, SymbolAllocator symbolAllocator) {
        return new SymbolMapper(symbol -> {
            if (mapping.containsKey(symbol)) {
                while (mapping.containsKey(symbol) && !((Symbol)mapping.get(symbol)).equals(symbol)) {
                    symbol = (Symbol)mapping.get(symbol);
                }
                mapping.put((Symbol)symbol, (Symbol)symbol);
                return symbol;
            }
            Symbol newSymbol = symbolAllocator.newSymbol((Symbol)symbol);
            mapping.put((Symbol)symbol, newSymbol);
            mapping.put(newSymbol, newSymbol);
            return newSymbol;
        });
    }

    public Symbol map(Symbol symbol) {
        return this.mappingFunction.apply(symbol);
    }

    public List<Symbol> map(List<Symbol> symbols) {
        return (List)symbols.stream().map(this::map).collect(ImmutableList.toImmutableList());
    }

    public List<Symbol> mapAndDistinct(List<Symbol> symbols) {
        return (List)symbols.stream().map(this::map).distinct().collect(ImmutableList.toImmutableList());
    }

    public Expression map(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new ExpressionRewriter<Void>(){

            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Symbol canonical = SymbolMapper.this.map(Symbol.from((Expression)node));
                return canonical.toSymbolReference();
            }
        }, (Expression)expression);
    }

    public AggregationNode map(AggregationNode node, PlanNode source) {
        return this.map(node, source, node.getId());
    }

    public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId) {
        ImmutableMap.Builder aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            aggregations.put((Object)this.map(entry.getKey()), (Object)this.map(entry.getValue()));
        }
        return new AggregationNode(newNodeId, source, (Map<Symbol, AggregationNode.Aggregation>)aggregations.build(), AggregationNode.groupingSets(this.mapAndDistinct(node.getGroupingKeys()), node.getGroupingSetCount(), node.getGlobalGroupingSets()), (List<Symbol>)ImmutableList.of(), node.getStep(), node.getHashSymbol().map(this::map), node.getGroupIdSymbol().map(this::map));
    }

    public AggregationNode.Aggregation map(AggregationNode.Aggregation aggregation) {
        return new AggregationNode.Aggregation(aggregation.getResolvedFunction(), (List)aggregation.getArguments().stream().map(this::map).collect(ImmutableList.toImmutableList()), aggregation.isDistinct(), aggregation.getFilter().map(this::map), aggregation.getOrderingScheme().map(this::map), aggregation.getMask().map(this::map));
    }

    public GroupIdNode map(GroupIdNode node, PlanNode source) {
        HashMap<Symbol, Symbol> newGroupingMappings = new HashMap<Symbol, Symbol>();
        ImmutableList.Builder newGroupingSets = ImmutableList.builder();
        for (List<Symbol> groupingSet : node.getGroupingSets()) {
            ImmutableList.Builder newGroupingSet = ImmutableList.builder();
            for (Symbol output : groupingSet) {
                Symbol newOutput = this.map(output);
                newGroupingMappings.putIfAbsent(newOutput, this.map(node.getGroupingColumns().get(output)));
                newGroupingSet.add((Object)newOutput);
            }
            newGroupingSets.add((Object)newGroupingSet.build());
        }
        return new GroupIdNode(node.getId(), source, (List<List<Symbol>>)newGroupingSets.build(), newGroupingMappings, this.mapAndDistinct(node.getAggregationArguments()), this.map(node.getGroupIdSymbol()));
    }

    public WindowNode map(WindowNode node, PlanNode source) {
        ImmutableMap.Builder newFunctions = ImmutableMap.builder();
        node.getWindowFunctions().forEach((symbol, function) -> {
            List newArguments = (List)function.getArguments().stream().map(this::map).collect(ImmutableList.toImmutableList());
            WindowNode.Frame newFrame = this.map(function.getFrame());
            newFunctions.put((Object)this.map((Symbol)symbol), (Object)new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls()));
        });
        return new WindowNode(node.getId(), source, this.mapAndDistinct(node.getSpecification()), (Map<Symbol, WindowNode.Function>)newFunctions.build(), node.getHashSymbol().map(this::map), (Set)node.getPrePartitionedInputs().stream().map(this::map).collect(ImmutableSet.toImmutableSet()), node.getPreSortedOrderPrefix());
    }

    private WindowNode.Frame map(WindowNode.Frame frame) {
        return new WindowNode.Frame(frame.getType(), frame.getStartType(), frame.getStartValue().map(this::map), frame.getSortKeyCoercedForFrameStartComparison().map(this::map), frame.getEndType(), frame.getEndValue().map(this::map), frame.getSortKeyCoercedForFrameEndComparison().map(this::map), frame.getOriginalStartValue(), frame.getOriginalEndValue());
    }

    private WindowNode.Specification mapAndDistinct(WindowNode.Specification specification) {
        return new WindowNode.Specification(this.mapAndDistinct(specification.getPartitionBy()), specification.getOrderingScheme().map(this::map));
    }

    public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) {
        ImmutableMap.Builder newFunctions = ImmutableMap.builder();
        node.getWindowFunctions().forEach((symbol, function) -> {
            List newArguments = (List)function.getArguments().stream().map(this::map).collect(ImmutableList.toImmutableList());
            WindowNode.Frame newFrame = this.map(function.getFrame());
            newFunctions.put((Object)this.map((Symbol)symbol), (Object)new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls()));
        });
        ImmutableMap.Builder newMeasures = ImmutableMap.builder();
        node.getMeasures().forEach((symbol, measure) -> {
            LogicalIndexExtractor.ExpressionAndValuePointers newExpression = this.map(measure.getExpressionAndValuePointers());
            newMeasures.put((Object)this.map((Symbol)symbol), (Object)new PatternRecognitionNode.Measure(newExpression, measure.getType()));
        });
        ImmutableMap.Builder newVariableDefinitions = ImmutableMap.builder();
        node.getVariableDefinitions().forEach((label, expression) -> newVariableDefinitions.put(label, (Object)this.map((LogicalIndexExtractor.ExpressionAndValuePointers)expression)));
        return new PatternRecognitionNode(node.getId(), source, this.mapAndDistinct(node.getSpecification()), node.getHashSymbol().map(this::map), (Set)node.getPrePartitionedInputs().stream().map(this::map).collect(ImmutableSet.toImmutableSet()), node.getPreSortedOrderPrefix(), (Map<Symbol, WindowNode.Function>)newFunctions.build(), (Map<Symbol, PatternRecognitionNode.Measure>)newMeasures.build(), node.getCommonBaseFrame().map(this::map), node.getRowsPerMatch(), node.getSkipToLabel(), node.getSkipToPosition(), node.isInitial(), node.getPattern(), node.getSubsets(), (Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers>)newVariableDefinitions.build());
    }

    private LogicalIndexExtractor.ExpressionAndValuePointers map(LogicalIndexExtractor.ExpressionAndValuePointers expressionAndValuePointers) {
        Set<Symbol> syntheticClassifierSymbols = expressionAndValuePointers.getClassifierSymbols();
        Set<Symbol> syntheticMatchNumberSymbols = expressionAndValuePointers.getMatchNumberSymbols();
        List newValuePointers = (List)expressionAndValuePointers.getValuePointers().stream().map((? super T pointer) -> {
            Symbol inputSymbol = pointer.getInputSymbol();
            if (syntheticClassifierSymbols.contains(inputSymbol) || syntheticMatchNumberSymbols.contains(inputSymbol)) {
                return pointer;
            }
            return new LogicalIndexExtractor.ValuePointer(pointer.getLogicalIndexPointer(), this.map(inputSymbol));
        }).collect(ImmutableList.toImmutableList());
        return new LogicalIndexExtractor.ExpressionAndValuePointers(expressionAndValuePointers.getExpression(), expressionAndValuePointers.getLayout(), newValuePointers, syntheticClassifierSymbols, syntheticMatchNumberSymbols);
    }

    public LimitNode map(LimitNode node, PlanNode source) {
        return new LimitNode(node.getId(), source, node.getCount(), node.getTiesResolvingScheme().map(this::map), node.isPartial(), (List)node.getPreSortedInputs().stream().map(this::map).collect(ImmutableList.toImmutableList()));
    }

    public OrderingScheme map(OrderingScheme orderingScheme) {
        ImmutableList.Builder newSymbols = ImmutableList.builder();
        ImmutableMap.Builder newOrderings = ImmutableMap.builder();
        HashSet<Symbol> added = new HashSet<Symbol>(orderingScheme.getOrderBy().size());
        for (Symbol symbol : orderingScheme.getOrderBy()) {
            Symbol canonical = this.map(symbol);
            if (!added.add(canonical)) continue;
            newSymbols.add((Object)canonical);
            newOrderings.put((Object)canonical, (Object)orderingScheme.getOrdering(symbol));
        }
        return new OrderingScheme((List<Symbol>)newSymbols.build(), (Map<Symbol, SortOrder>)newOrderings.build());
    }

    public DistinctLimitNode map(DistinctLimitNode node, PlanNode source) {
        return new DistinctLimitNode(node.getId(), source, node.getLimit(), node.isPartial(), this.mapAndDistinct(node.getDistinctSymbols()), node.getHashSymbol().map(this::map));
    }

    public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) {
        return new StatisticsWriterNode(node.getId(), source, node.getTarget(), this.map(node.getRowCountSymbol()), node.isRowCountEnabled(), node.getDescriptor().map(this::map));
    }

    public TableWriterNode map(TableWriterNode node, PlanNode source) {
        return this.map(node, source, node.getId());
    }

    public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId newId) {
        return new TableWriterNode(newId, source, node.getTarget(), this.map(node.getRowCountSymbol()), this.map(node.getFragmentSymbol()), this.map(node.getColumns()), node.getColumnNames(), node.getNotNullColumnSymbols(), node.getPartitioningScheme().map((? super T partitioningScheme) -> this.map((PartitioningScheme)partitioningScheme, source.getOutputSymbols())), node.getPreferredPartitioningScheme().map((? super T partitioningScheme) -> this.map((PartitioningScheme)partitioningScheme, source.getOutputSymbols())), node.getStatisticsAggregation().map(this::map), node.getStatisticsAggregationDescriptor().map((? super T descriptor) -> descriptor.map(this::map)));
    }

    public PartitioningScheme map(PartitioningScheme scheme, List<Symbol> sourceLayout) {
        return new PartitioningScheme(scheme.getPartitioning().translate(this::map), this.mapAndDistinct(sourceLayout), scheme.getHashColumn().map(this::map), scheme.isReplicateNullsAndAny(), scheme.getBucketToPartition());
    }

    public TableFinishNode map(TableFinishNode node, PlanNode source) {
        return new TableFinishNode(node.getId(), source, node.getTarget(), this.map(node.getRowCountSymbol()), node.getStatisticsAggregation().map(this::map), node.getStatisticsAggregationDescriptor().map((? super T descriptor) -> descriptor.map(this::map)));
    }

    private StatisticAggregations map(StatisticAggregations statisticAggregations) {
        Map aggregations = (Map)statisticAggregations.getAggregations().entrySet().stream().collect(ImmutableMap.toImmutableMap(entry -> this.map((Symbol)entry.getKey()), entry -> this.map((AggregationNode.Aggregation)entry.getValue())));
        return new StatisticAggregations(aggregations, this.mapAndDistinct(statisticAggregations.getGroupingSymbols()));
    }

    public RowNumberNode map(RowNumberNode node, PlanNode source) {
        return new RowNumberNode(node.getId(), source, this.mapAndDistinct(node.getPartitionBy()), node.isOrderSensitive(), this.map(node.getRowNumberSymbol()), node.getMaxRowCountPerPartition(), node.getHashSymbol().map(this::map));
    }

    public TopNRankingNode map(TopNRankingNode node, PlanNode source) {
        return new TopNRankingNode(node.getId(), source, this.mapAndDistinct(node.getSpecification()), node.getRankingType(), this.map(node.getRankingSymbol()), node.getMaxRankingPerPartition(), node.isPartial(), node.getHashSymbol().map(this::map));
    }

    public TopNNode map(TopNNode node, PlanNode source) {
        return this.map(node, source, node.getId());
    }

    public TopNNode map(TopNNode node, PlanNode source, PlanNodeId nodeId) {
        return new TopNNode(nodeId, source, node.getCount(), this.map(node.getOrderingScheme()), node.getStep());
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private final ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();

        public void put(Symbol from, Symbol to) {
            this.mappings.put((Object)from, (Object)to);
        }

        public SymbolMapper build() {
            return SymbolMapper.symbolMapper((Map<Symbol, Symbol>)this.mappings.build());
        }
    }
}

