/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeUtils;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizer;
import com.facebook.presto.sql.planner.optimizations.PlanOptimizerResult;
import com.facebook.presto.sql.planner.plan.ChildReplacer;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.airlift.slice.Slices;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Plan optimizer that rewrites chains of LEFT joins so that they run over the distinct join keys
 * of their bottom-most input instead of over its full rows.
 *
 * <p>When a left-deep chain of at least two rewritable LEFT joins ends in a scan/filter/project
 * subtree, the rewrite does the following:
 * <ul>
 *   <li>The bottom subtree is replaced by a SINGLE-step aggregation that groups on the accumulated
 *       join keys.</li>
 *   <li>A clone of the bottom subtree, with its key variables renamed, is kept as the "payload".</li>
 *   <li>At the top join, the payload is LEFT-joined back to the keys-only result using null-safe
 *       key equality.</li>
 *   <li>The result is then restricted to the original join's output variables.</li>
 * </ul>
 *
 * <p>The rule runs when {@link #setEnabledForTesting(boolean)} was set, or when
 * {@code SystemSessionProperties.isOptimizePayloadJoins(session)} returns true.
 */
public class PayloadJoinOptimizer
implements PlanOptimizer {
    private final Metadata metadata;
    private boolean isEnabledForTesting;

    /**
     * @param metadata metadata used to resolve functions/types and to clone plan subtrees; must not be null
     */
    public PayloadJoinOptimizer(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Rewrites {@code plan} if the rule is enabled for {@code session}; otherwise returns it unchanged.
     *
     * @return the (possibly) rewritten plan together with a flag telling whether it changed
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        if (!this.isEnabled(session)) {
            return PlanOptimizerResult.optimizerResult(plan, false);
        }
        FunctionAndTypeManager functionAndTypeManager = this.metadata.getFunctionAndTypeManager();
        Rewriter rewriter = new Rewriter(session, this.metadata, types, functionAndTypeManager, idAllocator, variableAllocator);
        PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, new JoinContext());
        return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged());
    }

    @Override
    public void setEnabledForTesting(boolean isSet) {
        this.isEnabledForTesting = isSet;
    }

    @Override
    public boolean isEnabled(Session session) {
        return this.isEnabledForTesting || SystemSessionProperties.isOptimizePayloadJoins(session);
    }

    /**
     * Returns the literal that stands in for NULL when comparing join keys of {@code type}.
     *
     * <p>Numeric types get a BIGINT {@code 0}; VARCHAR types get the empty string.
     *
     * @throws IllegalArgumentException if {@code type} is neither numeric nor varchar
     */
    private static RowExpression zeroForType(Type type) {
        boolean numeric = TypeUtils.isNumericType(type);
        Preconditions.checkArgument(numeric || type instanceof VarcharType, "join key should be of numeric or varchar type");
        if (numeric) {
            // NOTE(review): a BIGINT zero is used for every numeric key type (INTEGER, DOUBLE, DECIMAL, ...);
            // confirm PlannerUtils.coalesce / expression compilation accepts the mixed argument types.
            return Expressions.constant(0L, BigintType.BIGINT);
        }
        return Expressions.constant(Slices.utf8Slice(""), VarcharType.VARCHAR);
    }

    /**
     * Mutable state shared down the left spine of a join chain while it is being rewritten.
     */
    private static class JoinContext {
        // Join keys accumulated from every join visited so far on the current spine.
        private Set<VariableReferenceExpression> joinKeys = new HashSet<>();
        // Original key variable -> renamed variable used inside the payload clone.
        private Map<VariableReferenceExpression, VariableReferenceExpression> joinKeyMap;
        // Key-producing projections that must be recomputed below the distinct-keys aggregation.
        private Map<VariableReferenceExpression, RowExpression> projectionsToPush = new HashMap<>();
        private int numJoins;
        // Clone of the bottom subtree that the top join re-attaches; null when no rejoin is pending.
        private PlanNode payloadNode;

        public Set<VariableReferenceExpression> getJoinKeys() {
            return this.joinKeys;
        }

        public void addKeys(Set<VariableReferenceExpression> keys) {
            this.joinKeys.addAll(keys);
        }

        public Map<VariableReferenceExpression, RowExpression> getProjectionsToPush() {
            return this.projectionsToPush;
        }

        public void addProjectionsToPush(Map<VariableReferenceExpression, RowExpression> map) {
            this.projectionsToPush.putAll(map);
        }

        public Map<VariableReferenceExpression, VariableReferenceExpression> getJoinKeyMap() {
            return this.joinKeyMap;
        }

        public void setJoinKeyMap(Map<VariableReferenceExpression, VariableReferenceExpression> map) {
            this.joinKeyMap = map;
        }

        public PlanNode getPayloadNode() {
            return this.payloadNode;
        }

        public void setPayloadNode(PlanNode payloadNode) {
            this.payloadNode = payloadNode;
        }

        /** Clears all accumulated state. */
        public void reset() {
            this.joinKeys = new HashSet<>();
            this.projectionsToPush = new HashMap<>();
            this.joinKeyMap = null;
            this.numJoins = 0;
            this.payloadNode = null;
        }

        public int getNumJoins() {
            return this.numJoins;
        }

        public void incrementNumJoins() {
            ++this.numJoins;
        }

        /** Returns true when a payload clone was recorded and still has to be re-joined at the top. */
        public boolean needsPayloadRejoin() {
            return this.payloadNode != null;
        }
    }

    /**
     * Plan rewriter implementing the payload-join transformation described on the enclosing class.
     */
    private static class Rewriter
    extends SimplePlanRewriter<JoinContext> {
        private final Session session;
        private final Metadata metadata;
        private final TypeProvider types;
        private final FunctionAndTypeManager functionAndTypeManager;
        private final PlanNodeIdAllocator planNodeIdAllocator;
        private final VariableAllocator variableAllocator;
        private boolean planChanged;

        private Rewriter(Session session, Metadata metadata, TypeProvider types, FunctionAndTypeManager functionAndTypeManager, PlanNodeIdAllocator planNodeIdAllocator, VariableAllocator variableAllocator) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.functionAndTypeManager = Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.planNodeIdAllocator = Objects.requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null");
            this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
        }

        /** Returns whether any rewrite retained in the output plan was performed. */
        public boolean isPlanChanged() {
            return this.planChanged;
        }

        /**
         * Default handling for node types without a dedicated visitor: each child is rewritten with a
         * fresh context, so no join chain continues through such a node.
         */
        @Override
        public PlanNode visitPlan(PlanNode planNode, SimplePlanRewriter.RewriteContext<JoinContext> context) {
            List<PlanNode> newChildren = planNode.getSources().stream()
                    .map(child -> context.rewrite(child, new JoinContext()))
                    .collect(Collectors.toList());
            return ChildReplacer.replaceChildren(planNode, newChildren);
        }

        /**
         * Starts or continues a join chain at a rewritable LEFT join, or falls back to rewriting both
         * children independently.
         */
        @Override
        public PlanNode visitJoin(JoinNode joinNode, SimplePlanRewriter.RewriteContext<JoinContext> context) {
            JoinContext joinContext = context.get();
            Set<VariableReferenceExpression> inputJoinKeys = joinContext.getJoinKeys();
            PlanNode leftNode = joinNode.getLeft();
            PlanNode rightNode = joinNode.getRight();
            // A join reached without keys from an enclosing join is the top of a (potential) chain.
            boolean isTopJoin = inputJoinKeys.isEmpty();
            Set<VariableReferenceExpression> leftColumns = ImmutableSet.copyOf(leftNode.getOutputVariables());
            Set<VariableReferenceExpression> rightJoinKeys = inputJoinKeys.stream()
                    .filter(key -> rightNode.getOutputVariables().contains(key))
                    .collect(ImmutableSet.toImmutableSet());
            Set<VariableReferenceExpression> joinKeys = this.extractJoinKeys(joinNode.getFilter(), joinNode.getCriteria());
            Set<VariableReferenceExpression> leftJoinKeys = Sets.intersection(joinKeys, leftColumns).immutableCopy();

            if (!rightJoinKeys.isEmpty() || !this.needsRewrite(joinNode.getType(), leftColumns, leftJoinKeys)) {
                List<PlanNode> newChildren = joinNode.getSources().stream()
                        .map(child -> this.defaultRewriteJoinChild(child, context, joinNode.isCrossJoin()))
                        .collect(ImmutableList.toImmutableList());
                return ChildReplacer.replaceChildren(joinNode, newChildren);
            }

            // Snapshot so that an abandoned rewrite of this subtree does not leave planChanged set.
            boolean planChangedBefore = this.planChanged;
            joinContext.addKeys(leftJoinKeys);
            joinContext.incrementNumJoins();
            PlanNode newLeftNode = context.rewrite(leftNode, joinContext);
            if (leftNode.equals(newLeftNode)) {
                // The keys-only rewrite did not take. A payload recorded while trying belongs to a
                // distinct-keys plan that is not in the tree, so it must not be re-joined later.
                joinContext.setPayloadNode(null);
                this.planChanged = planChangedBefore;
                newLeftNode = context.rewrite(leftNode, new JoinContext());
                return ChildReplacer.replaceChildren(joinNode, ImmutableList.of(newLeftNode, rightNode));
            }

            List<VariableReferenceExpression> allCols = Stream.concat(newLeftNode.getOutputVariables().stream(), rightNode.getOutputVariables().stream())
                    .collect(ImmutableList.toImmutableList());
            JoinNode newJoinNode = new JoinNode(joinNode.getSourceLocation(), this.planNodeIdAllocator.getNextId(), joinNode.getType(), newLeftNode, rightNode, joinNode.getCriteria(), allCols, joinNode.getFilter(), joinNode.getLeftHashVariable(), joinNode.getRightHashVariable(), joinNode.getDistributionType(), joinNode.getDynamicFilters());

            if (isTopJoin && joinContext.needsPayloadRejoin()) {
                PlanNode payloadJoin = this.transformJoin(newJoinNode, joinContext);
                joinContext.setPayloadNode(null);
                List<VariableReferenceExpression> outputVariables = joinNode.getOutputVariables();
                if (!payloadJoin.getOutputVariables().containsAll(outputVariables)) {
                    // The rewritten chain cannot reproduce the original outputs: keep the original join.
                    this.planChanged = planChangedBefore;
                    return joinNode;
                }
                this.planChanged = true;
                return PlannerUtils.restrictOutput(payloadJoin, this.planNodeIdAllocator, outputVariables);
            }
            this.planChanged = true;
            return newJoinNode;
        }

        /**
         * Rewrites one join child with a fresh context. For cross joins, the child's original output
         * list is restored if the rewrite changed it.
         */
        private PlanNode defaultRewriteJoinChild(PlanNode child, SimplePlanRewriter.RewriteContext<JoinContext> context, boolean isCrossJoin) {
            PlanNode newChild = context.rewrite(child, new JoinContext());
            if (isCrossJoin && !child.getOutputVariables().equals(newChild.getOutputVariables())) {
                return PlannerUtils.restrictOutput(newChild, this.planNodeIdAllocator, child.getOutputVariables());
            }
            return newChild;
        }

        /**
         * Returns true for a LEFT join whose keys are all of supported types and whose left side
         * produces at least one non-key column (otherwise there is no payload to defer).
         */
        private boolean needsRewrite(JoinNode.Type joinType, Set<VariableReferenceExpression> leftColumns, Set<VariableReferenceExpression> joinKeys) {
            return joinType == JoinNode.Type.LEFT
                    && this.supportedJoinKeyTypes(joinKeys)
                    && leftColumns.stream().anyMatch(column -> !joinKeys.contains(column));
        }

        /**
         * Carries the chain's context through a projection. Key-defining expressions are recorded so
         * that they can be pushed below the distinct-keys aggregation, and the projection is rebuilt
         * over the rewritten child. If the rebuilt projection would reference unavailable variables,
         * the original node is returned.
         */
        @Override
        public PlanNode visitProject(ProjectNode projectNode, SimplePlanRewriter.RewriteContext<JoinContext> context) {
            if (PlannerUtils.isScanFilterProject(projectNode)) {
                return this.rewriteScanFilterProject(projectNode, context);
            }
            JoinContext joinContext = context.get();
            PlanNode child = projectNode.getSource();
            Set<VariableReferenceExpression> inputJoinKeys = joinContext.getJoinKeys();
            if (!child.getOutputVariables().containsAll(inputJoinKeys)) {
                Map<VariableReferenceExpression, RowExpression> pushableExpressions = new HashMap<>();
                projectNode.getAssignments().forEach((variable, expression) -> {
                    if (inputJoinKeys.contains(variable) && !variable.equals(expression)) {
                        pushableExpressions.put(variable, expression);
                    }
                });
                joinContext.addProjectionsToPush(pushableExpressions);
            }

            PlanNode newChild = context.rewrite(child, joinContext);
            if (child.equals(newChild)) {
                return projectNode;
            }

            Set<VariableReferenceExpression> joinKeys = joinContext.getJoinKeys();
            Assignments newAssignments = projectNode.getAssignments();
            if (joinContext.needsPayloadRejoin() && !child.getOutputVariables().containsAll(joinKeys)) {
                // Keys are now computed below; pass them through instead of recomputing them here.
                Assignments.Builder builder = Assignments.builder();
                projectNode.getAssignments().forEach((variable, expression) -> {
                    if (joinKeys.contains(variable) && !variable.equals(expression)) {
                        builder.put(variable, variable);
                    }
                    else {
                        builder.put(variable, expression);
                    }
                });
                newAssignments = builder.build();
            }
            Set<VariableReferenceExpression> newChildOutputs = ImmutableSet.copyOf(newChild.getOutputVariables());
            Assignments newProjectAssignments = this.removeHiddenColumns(newAssignments, newChildOutputs, joinContext.getJoinKeys());
            ProjectNode newProjectNode = new ProjectNode(projectNode.getId(), newChild, newProjectAssignments);
            return this.validateProjectAssignments(newProjectNode) ? newProjectNode : projectNode;
        }

        /**
         * Filters on top of a scan are treated as the bottom of a chain. Any other filter ends the
         * chain: its subtree is rewritten with a fresh context.
         */
        @Override
        public PlanNode visitFilter(FilterNode filterNode, SimplePlanRewriter.RewriteContext<JoinContext> context) {
            if (PlannerUtils.isScanFilterProject(filterNode)) {
                return this.rewriteScanFilterProject(filterNode, context);
            }
            return context.defaultRewrite(filterNode, new JoinContext());
        }

        @Override
        public PlanNode visitTableScan(TableScanNode scanNode, SimplePlanRewriter.RewriteContext<JoinContext> context) {
            return this.rewriteScanFilterProject(scanNode, context);
        }

        /**
         * Replaces the bottom scan/filter/project of a chain with a distinct-keys aggregation. This
         * happens only when at least two joins are in the chain and every key (or every input of the
         * expressions that compute it) is available.
         */
        private PlanNode rewriteScanFilterProject(PlanNode planNode, SimplePlanRewriter.RewriteContext<JoinContext> context) {
            JoinContext joinContext = context.get();
            Set<VariableReferenceExpression> joinKeys = joinContext.getJoinKeys();
            if (joinKeys.isEmpty() || joinContext.getNumJoins() < 2) {
                return planNode;
            }
            List<VariableReferenceExpression> outputVariables = planNode.getOutputVariables();
            if (ImmutableSet.copyOf(outputVariables).containsAll(joinKeys)) {
                return this.constructDistinctKeysPlan(planNode, context, joinKeys);
            }
            Map<VariableReferenceExpression, RowExpression> projectionsToPush = joinContext.getProjectionsToPush();
            if (!outputVariables.containsAll(VariablesExtractor.extractUnique(projectionsToPush.values()))) {
                return planNode;
            }
            PlanNode withPushedProjections = PlannerUtils.addProjections(planNode, this.planNodeIdAllocator, projectionsToPush);
            return this.constructDistinctKeysPlan(withPushedProjections, context, joinKeys);
        }

        /**
         * Builds a SINGLE-step aggregation grouping {@code planNode} on {@code joinKeys}. It also
         * records in the context a clone of {@code planNode} whose key variables are renamed, plus the
         * original-to-renamed key map.
         */
        private AggregationNode constructDistinctKeysPlan(PlanNode planNode, SimplePlanRewriter.RewriteContext<JoinContext> context, Set<VariableReferenceExpression> joinKeys) {
            JoinContext joinContext = context.get();
            List<VariableReferenceExpression> groupingKeys = ImmutableList.copyOf(joinKeys);
            AggregationNode distinctKeys = new AggregationNode(
                    planNode.getSourceLocation(),
                    this.planNodeIdAllocator.getNextId(),
                    planNode,
                    ImmutableMap.of(),
                    AggregationNode.singleGroupingSet(groupingKeys),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty(),
                    Optional.empty());
            Map<VariableReferenceExpression, VariableReferenceExpression> renamedKeys = new HashMap<>();
            for (VariableReferenceExpression key : joinKeys) {
                renamedKeys.put(key, this.variableAllocator.newVariable(key.getName(), key.getType()));
            }
            // Store a copy: renamedKeys is handed to clonePlanNode below, which may add entries to it.
            joinContext.setJoinKeyMap(new HashMap<>(renamedKeys));
            PlanNode payload = PlannerUtils.clonePlanNode(planNode, this.session, this.metadata, this.planNodeIdAllocator, planNode.getOutputVariables(), renamedKeys);
            joinContext.setPayloadNode(payload);
            return distinctKeys;
        }

        /**
         * LEFT-joins the recorded payload clone with {@code keysNode} (the rewritten top join).
         *
         * <p>For each key there are two predicates: one on IS NULL flags and one on values with NULL
         * coalesced to a zero literal. Together they act as a null-safe key equality.
         *
         * @throws IllegalStateException if no payload or key map was recorded in {@code context}
         */
        private PlanNode transformJoin(JoinNode keysNode, JoinContext context) {
            PlanNode payloadPlanNode = context.getPayloadNode();
            Set<VariableReferenceExpression> joinKeys = context.getJoinKeys();
            Map<VariableReferenceExpression, VariableReferenceExpression> joinKeyMap = context.getJoinKeyMap();
            Preconditions.checkState(payloadPlanNode != null, "Payload plannode not initialized");
            Preconditions.checkState(joinKeyMap != null, "joinkey map not initialized");
            FunctionResolution functionResolution = new FunctionResolution(this.functionAndTypeManager.getFunctionAndTypeResolver());

            Assignments.Builder assignments = Assignments.builder();
            ImmutableList.Builder<RowExpression> coalesceComparisons = ImmutableList.builder();
            ImmutableList.Builder<RowExpression> nullComparisons = ImmutableList.builder();
            for (VariableReferenceExpression variable : keysNode.getOutputVariables()) {
                assignments.put(variable, variable);
            }
            for (VariableReferenceExpression key : joinKeys) {
                VariableReferenceExpression payloadKey = joinKeyMap.get(key);
                VariableReferenceExpression isNullVar = this.variableAllocator.newVariable(key.getName() + "_NULL", BooleanType.BOOLEAN);
                assignments.put(isNullVar, Expressions.specialForm(SpecialFormExpression.Form.IS_NULL, BooleanType.BOOLEAN, ImmutableList.of(key)));
                CallExpression coalesceComparison = PlannerUtils.equalityPredicate(functionResolution, this.coalesceToZero(payloadKey), this.coalesceToZero(key));
                CallExpression nullComparison = PlannerUtils.equalityPredicate(functionResolution, Expressions.specialForm(SpecialFormExpression.Form.IS_NULL, BooleanType.BOOLEAN, ImmutableList.of(payloadKey)), isNullVar);
                nullComparisons.add(nullComparison);
                coalesceComparisons.add(coalesceComparison);
            }

            ProjectNode keysWithNullFlags = new ProjectNode(this.planNodeIdAllocator.getNextId(), keysNode, assignments.build());
            List<VariableReferenceExpression> resultOutputCols = Stream.concat(payloadPlanNode.getOutputVariables().stream(), keysWithNullFlags.getOutputVariables().stream())
                    .collect(ImmutableList.toImmutableList());
            List<RowExpression> joinCriteria = ImmutableList.<RowExpression>builder()
                    .addAll(nullComparisons.build())
                    .addAll(coalesceComparisons.build())
                    .build();
            // NOTE(review): hash variables are copied from keysNode although the new join has different
            // left/right inputs; confirm they are always empty at this stage.
            return new JoinNode(keysNode.getSourceLocation(), this.planNodeIdAllocator.getNextId(), JoinNode.Type.LEFT, payloadPlanNode, keysWithNullFlags, ImmutableList.of(), resultOutputCols, Optional.of(LogicalRowExpressions.and(joinCriteria)), keysNode.getLeftHashVariable(), keysNode.getRightHashVariable(), keysNode.getDistributionType(), keysNode.getDynamicFilters());
        }

        /**
         * Keeps only the assignments whose inputs are all produced by the new child, in their original
         * order. It then adds an identity assignment for every join key that the child produces but
         * that is not already an output.
         */
        private Assignments removeHiddenColumns(Assignments newAssignments, Set<VariableReferenceExpression> newChildOutputVarSet, Set<VariableReferenceExpression> joinKeys) {
            Assignments.Builder builder = Assignments.builder();
            Set<VariableReferenceExpression> outputKeys = new HashSet<>();
            newAssignments.forEach((variable, expression) -> {
                if (newChildOutputVarSet.containsAll(VariablesExtractor.extractUnique(expression))) {
                    builder.put(variable, expression);
                    outputKeys.add(variable);
                }
            });
            for (VariableReferenceExpression key : joinKeys) {
                if (!outputKeys.contains(key) && newChildOutputVarSet.contains(key)) {
                    builder.put(key, key);
                }
            }
            return builder.build();
        }

        /** Returns true when every assignment of {@code projectNode} only references its source's outputs. */
        private boolean validateProjectAssignments(ProjectNode projectNode) {
            Set<VariableReferenceExpression> inputs = ImmutableSet.copyOf(projectNode.getSource().getOutputVariables());
            return projectNode.getAssignments().entrySet().stream()
                    .map(Map.Entry::getValue)
                    .allMatch(expression -> inputs.containsAll(VariablesExtractor.extractUnique(expression)));
        }

        /** Returns {@code COALESCE(expression, zero)}, where the zero literal comes from {@code zeroForType}. */
        private RowExpression coalesceToZero(RowExpression expression) {
            RowExpression zero = PayloadJoinOptimizer.zeroForType(expression.getType());
            return PlannerUtils.coalesce(ImmutableList.of(expression, zero));
        }

        /** Collects both sides of every equi-join clause plus every variable referenced by the join filter. */
        private Set<VariableReferenceExpression> extractJoinKeys(Optional<RowExpression> filter, List<JoinNode.EquiJoinClause> criteria) {
            ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
            for (JoinNode.EquiJoinClause clause : criteria) {
                builder.add(clause.getLeft());
                builder.add(clause.getRight());
            }
            filter.ifPresent(expression -> builder.addAll(VariablesExtractor.extractAll(expression)));
            return builder.build();
        }

        /** Returns true when every key is VARCHAR or numeric (the types {@code zeroForType} accepts). */
        private boolean supportedJoinKeyTypes(Set<VariableReferenceExpression> joinKeys) {
            return joinKeys.stream().allMatch(key -> key.getType() instanceof VarcharType || TypeUtils.isNumericType(key.getType()));
        }
    }
}

