/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Iterative optimizer rule that pushes a remote {@code REPARTITION} exchange below a
 * {@link GroupIdNode} that has at least one common grouping column.
 *
 * <p>The plan shape
 * <pre>
 *   Exchange(remote, REPARTITION) -&gt; GroupId -&gt; source
 * </pre>
 * is rewritten into
 * <pre>
 *   GroupId -&gt; Exchange(remote, REPARTITION on common grouping columns) -&gt; source
 * </pre>
 * so the data is shuffled before the GroupId node fans rows out to every grouping set.
 *
 * <p>The rewritten exchange is partitioned on the GroupId's common grouping columns, translated to
 * the GroupId's input variables. The rule only fires when those columns are a subset of the
 * original exchange's partition columns, which are translated the same way. Partitioning on such a
 * subset keeps together rows that agree on the original partitioning key (assuming a
 * deterministic, column-value-based partitioning function).
 *
 * <p>Enabled by {@link SystemSessionProperties#shouldPushRemoteExchangeThroughGroupId(Session)}.
 */
public final class PushRemoteExchangeThroughGroupId
        implements Rule<ExchangeNode> {
    /**
     * Partitioning provider catalog for which the original exchange's partitioning handle is reused
     * rather than requesting a new one from {@link Metadata}.
     */
    private static final String SYSTEM_PARTITIONING_PROVIDER_CATALOG = "system";

    private static final Capture<GroupIdNode> GROUP_ID = Capture.newCapture();

    // Remote REPARTITION exchange whose direct source is a GroupIdNode with
    // a non-empty set of common grouping columns.
    private static final Pattern<ExchangeNode> PATTERN = Patterns.exchange()
            .matching(exchange -> exchange.getScope().isRemote())
            .matching(exchange -> exchange.getType() == ExchangeNode.Type.REPARTITION)
            .with(Patterns.source().matching(
                    Patterns.groupId()
                            .capturedAs(GROUP_ID)
                            .matching(groupId -> !groupId.getCommonGroupingColumns().isEmpty())));

    private final Metadata metadata;

    /**
     * @param metadata used to obtain a partitioning handle for the pushed-down exchange when the
     *                 session's partitioning provider catalog is not {@code "system"}
     * @throws NullPointerException if {@code metadata} is null
     */
    public PushRemoteExchangeThroughGroupId(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public Pattern<ExchangeNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.shouldPushRemoteExchangeThroughGroupId(session);
    }

    /**
     * Swaps the matched exchange and its GroupId source.
     *
     * @return the rewritten {@code GroupId -> Exchange -> source} subtree, or
     *         {@link Rule.Result#empty()} when the common grouping columns are not all among the
     *         exchange's original partition columns
     */
    @Override
    public Rule.Result apply(ExchangeNode node, Captures captures, Rule.Context context) {
        Session session = context.getSession();
        GroupIdNode groupIdNode = captures.get(GROUP_ID);
        Map<VariableReferenceExpression, VariableReferenceExpression> groupingColumns = groupIdNode.getGroupingColumns();
        VariableReferenceExpression groupIdVariable = groupIdNode.getGroupIdVariable();
        PartitioningScheme partitioningScheme = node.getPartitioningScheme();

        // Both column lists are expressed in terms of the GroupId's *input* variables so they can be compared.
        Set<VariableReferenceExpression> commonGroupingColumns = groupIdNode.getCommonGroupingColumns();
        List<VariableReferenceExpression> partitionColumns = replaceAlias(commonGroupingColumns, groupingColumns);
        List<VariableReferenceExpression> originalPartitionColumns = partitioningScheme.getPartitioning().getVariableReferences().stream()
                .map(expr -> groupingColumns.getOrDefault(expr, expr))
                .collect(ImmutableList.toImmutableList());
        if (!originalPartitionColumns.containsAll(partitionColumns)) {
            return Rule.Result.empty();
        }

        // Below the GroupId the group-id variable does not exist yet, and grouping-set outputs must be
        // referred to by their input aliases. Inputs and output layout are transformed identically so
        // they stay positionally aligned.
        List<VariableReferenceExpression> inputs = replaceAlias(
                removeVariable(Iterables.getOnlyElement(node.getInputs()), groupIdVariable),
                groupingColumns);
        List<VariableReferenceExpression> outputLayout = replaceAlias(
                removeVariable(partitioningScheme.getOutputLayout(), groupIdVariable),
                groupingColumns);

        PartitioningHandle partitioningHandle =
                SYSTEM_PARTITIONING_PROVIDER_CATALOG.equals(SystemSessionProperties.getPartitioningProviderCatalog(session))
                        ? partitioningScheme.getPartitioning().getHandle()
                        : createPartitioningHandle(session, partitionColumns);

        // NOTE(review): the hash column is carried over unchanged from the original scheme; confirm it
        // is always absent (or still produced) below the GroupId node.
        PartitioningScheme pushedPartitioningScheme = new PartitioningScheme(
                Partitioning.create(partitioningHandle, partitionColumns),
                outputLayout,
                partitioningScheme.getHashColumn(),
                partitioningScheme.isReplicateNullsAndAny(),
                partitioningScheme.getBucketToPartition());

        PlanNode groupIdSource = groupIdNode.getSource();
        ExchangeNode pushedExchange = new ExchangeNode(
                node.getSourceLocation(),
                node.getId(),
                node.getType(),
                node.getScope(),
                pushedPartitioningScheme,
                ImmutableList.of(groupIdSource),
                ImmutableList.of(inputs),
                node.isEnsureSourceOrdering(),
                node.getOrderingScheme());

        return Rule.Result.ofPlanNode(new GroupIdNode(
                node.getSourceLocation(),
                groupIdNode.getId(),
                pushedExchange,
                groupIdNode.getGroupingSets(),
                groupingColumns,
                groupIdNode.getAggregationArguments(),
                groupIdVariable));
    }

    /** Returns {@code variables} in order, with every occurrence of {@code variableToRemove} dropped. */
    private static List<VariableReferenceExpression> removeVariable(List<VariableReferenceExpression> variables, VariableReferenceExpression variableToRemove) {
        return variables.stream()
                .filter(variable -> !variableToRemove.equals(variable))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Maps each variable through {@code mapping} (leaving unmapped variables unchanged), preserving
     * the iteration order of {@code variables}.
     */
    private static List<VariableReferenceExpression> replaceAlias(Collection<VariableReferenceExpression> variables, Map<VariableReferenceExpression, VariableReferenceExpression> mapping) {
        return variables.stream()
                .map(variable -> mapping.getOrDefault(variable, variable))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Requests from {@link Metadata} a partitioning handle for an exchange over columns of the given
     * variables' types, using the session's partitioning provider catalog and hash partition count.
     */
    private PartitioningHandle createPartitioningHandle(Session session, Collection<VariableReferenceExpression> partitioningColumns) {
        return this.metadata.getPartitioningHandleForExchange(
                session,
                SystemSessionProperties.getPartitioningProviderCatalog(session),
                SystemSessionProperties.getHashPartitionCount(session),
                partitioningColumns.stream()
                        .map(VariableReferenceExpression::getType)
                        .collect(ImmutableList.toImmutableList()));
    }
}

