/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.spi.block.SortOrder;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.planner.plan.StatisticAggregations;
import io.prestosql.sql.planner.plan.StatisticAggregationsDescriptor;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.ExpressionRewriter;
import io.prestosql.sql.tree.ExpressionTreeRewriter;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.SymbolReference;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class SymbolMapper {
    private final Map<Symbol, Symbol> mapping;

    public SymbolMapper(Map<Symbol, Symbol> mapping) {
        this.mapping = ImmutableMap.copyOf(Objects.requireNonNull(mapping, "mapping is null"));
    }

    public Symbol map(Symbol symbol) {
        Symbol canonical = symbol;
        while (this.mapping.containsKey(canonical) && !this.mapping.get(canonical).equals(canonical)) {
            canonical = this.mapping.get(canonical);
        }
        return canonical;
    }

    public Expression map(Expression value) {
        return ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new ExpressionRewriter<Void>(){

            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Symbol canonical = SymbolMapper.this.map(Symbol.from((Expression)node));
                return canonical.toSymbolReference();
            }
        }, (Expression)value);
    }

    public AggregationNode map(AggregationNode node, PlanNode source) {
        return this.map(node, source, node.getId());
    }

    public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeIdAllocator idAllocator) {
        return this.map(node, source, idAllocator.getNextId());
    }

    private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId) {
        ImmutableMap.Builder aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            aggregations.put((Object)this.map(entry.getKey()), (Object)this.map(entry.getValue()));
        }
        return new AggregationNode(newNodeId, source, (Map<Symbol, AggregationNode.Aggregation>)aggregations.build(), AggregationNode.groupingSets(this.mapAndDistinct(node.getGroupingKeys()), node.getGroupingSetCount(), node.getGlobalGroupingSets()), (List<Symbol>)ImmutableList.of(), node.getStep(), node.getHashSymbol().map(this::map), node.getGroupIdSymbol().map(this::map));
    }

    private AggregationNode.Aggregation map(AggregationNode.Aggregation aggregation) {
        return new AggregationNode.Aggregation((FunctionCall)this.map((Expression)aggregation.getCall()), aggregation.getSignature(), aggregation.getMask().map(this::map));
    }

    public TopNNode map(TopNNode node, PlanNode source, PlanNodeId newNodeId) {
        ImmutableList.Builder symbols = ImmutableList.builder();
        ImmutableMap.Builder orderings = ImmutableMap.builder();
        HashSet<Symbol> seenCanonicals = new HashSet<Symbol>(node.getOrderingScheme().getOrderBy().size());
        for (Symbol symbol : node.getOrderingScheme().getOrderBy()) {
            Symbol canonical = this.map(symbol);
            if (!seenCanonicals.add(canonical)) continue;
            seenCanonicals.add(canonical);
            symbols.add((Object)canonical);
            orderings.put((Object)canonical, (Object)node.getOrderingScheme().getOrdering(symbol));
        }
        return new TopNNode(newNodeId, source, node.getCount(), new OrderingScheme((List<Symbol>)symbols.build(), (Map<Symbol, SortOrder>)orderings.build()), node.getStep());
    }

    public TableWriterNode map(TableWriterNode node, PlanNode source) {
        return this.map(node, source, node.getId());
    }

    public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId newNodeId) {
        ImmutableList columns = (ImmutableList)node.getColumns().stream().map(this::map).collect(ImmutableList.toImmutableList());
        return new TableWriterNode(newNodeId, source, node.getTarget(), this.map(node.getRowCountSymbol()), this.map(node.getFragmentSymbol()), (List<Symbol>)columns, node.getColumnNames(), node.getPartitioningScheme().map((? super T partitioningScheme) -> this.canonicalize((PartitioningScheme)partitioningScheme, source)), node.getStatisticsAggregation().map(this::map), node.getStatisticsAggregationDescriptor().map(this::map));
    }

    public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source) {
        return new StatisticsWriterNode(node.getId(), source, node.getTarget(), node.getRowCountSymbol(), node.isRowCountEnabled(), node.getDescriptor().map(this::map));
    }

    public TableFinishNode map(TableFinishNode node, PlanNode source) {
        return new TableFinishNode(node.getId(), source, node.getTarget(), this.map(node.getRowCountSymbol()), node.getStatisticsAggregation().map(this::map), node.getStatisticsAggregationDescriptor().map((? super T descriptor) -> descriptor.map(this::map)));
    }

    private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source) {
        return new PartitioningScheme(scheme.getPartitioning().translate(this::map), this.mapAndDistinct(source.getOutputSymbols()), scheme.getHashColumn().map(this::map), scheme.isReplicateNullsAndAny(), scheme.getBucketToPartition());
    }

    private StatisticAggregations map(StatisticAggregations statisticAggregations) {
        Map aggregations = (Map)statisticAggregations.getAggregations().entrySet().stream().collect(ImmutableMap.toImmutableMap(entry -> this.map((Symbol)entry.getKey()), entry -> this.map((AggregationNode.Aggregation)entry.getValue())));
        return new StatisticAggregations(aggregations, this.mapAndDistinct(statisticAggregations.getGroupingSymbols()));
    }

    private StatisticAggregationsDescriptor<Symbol> map(StatisticAggregationsDescriptor<Symbol> descriptor) {
        return descriptor.map(this::map);
    }

    private List<Symbol> map(List<Symbol> outputs) {
        return (List)outputs.stream().map(this::map).collect(ImmutableList.toImmutableList());
    }

    private List<Symbol> mapAndDistinct(List<Symbol> outputs) {
        HashSet<Symbol> added = new HashSet<Symbol>();
        ImmutableList.Builder builder = ImmutableList.builder();
        for (Symbol symbol : outputs) {
            Symbol canonical = this.map(symbol);
            if (!added.add(canonical)) continue;
            builder.add((Object)canonical);
        }
        return builder.build();
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private final ImmutableMap.Builder<Symbol, Symbol> mappings = ImmutableMap.builder();

        public SymbolMapper build() {
            return new SymbolMapper((Map<Symbol, Symbol>)this.mappings.build());
        }

        public void put(Symbol from, Symbol to) {
            this.mappings.put((Object)from, (Object)to);
        }
    }
}

