/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.optimizations.PlanNodeDecorrelator;
import com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Rewrites a {@link LateralJoinNode} whose subquery is a scalar aggregation into a plan made of
 * a LEFT join followed by an aggregation and a projection.
 *
 * <p>The produced plan has the following shape (top to bottom):
 * <pre>
 * ProjectNode
 *   AggregationNode   (grouped by every output variable of the join's left side)
 *     JoinNode LEFT   (filter = correlated predicates extracted by the decorrelator)
 *       AssignUniqueId(lateral join input)
 *       ProjectNode   (decorrelated aggregation source + constant-true "non_null" column)
 * </pre>
 *
 * <p>When the aggregation source cannot be decorrelated the original lateral join node is
 * returned untouched.
 */
public class ScalarAggregationToJoinRewriter {
    private final FunctionResolution functionResolution;
    private final VariableAllocator variableAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;
    private final PlanNodeDecorrelator planNodeDecorrelator;

    /**
     * Creates a rewriter.
     *
     * @param functionAndTypeManager used to resolve functions (e.g. {@code count}) and to build the
     *        {@link LogicalRowExpressions} helper handed to the decorrelator
     * @param variableAllocator allocator for the new {@code non_null} and {@code unique} variables
     * @param idAllocator allocator for the ids of the newly created plan nodes
     * @param lookup used to resolve plan node references while decorrelating and searching
     * @throws NullPointerException if any argument is {@code null}
     */
    public ScalarAggregationToJoinRewriter(FunctionAndTypeManager functionAndTypeManager, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        Objects.requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        this.variableAllocator = Objects.requireNonNull(variableAllocator, "variableAllocator is null");
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        // Reuse the resolution helper built above instead of constructing a second one from the same resolver.
        LogicalRowExpressions logicalRowExpressions = new LogicalRowExpressions(
                new RowExpressionDeterminismEvaluator(functionAndTypeManager),
                this.functionResolution,
                functionAndTypeManager);
        this.planNodeDecorrelator = new PlanNodeDecorrelator(idAllocator, variableAllocator, lookup, logicalRowExpressions);
    }

    /**
     * Rewrites the given lateral join over a scalar aggregation into a join-based plan.
     *
     * @param lateralJoinNode the lateral join to rewrite
     * @param aggregation the scalar aggregation found in the lateral join's subquery
     * @return the rewritten plan, or {@code lateralJoinNode} itself when the aggregation source
     *         could not be decorrelated
     */
    public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregation) {
        List<VariableReferenceExpression> correlation = lateralJoinNode.getCorrelation();
        Optional<PlanNodeDecorrelator.DecorrelatedNode> source = this.planNodeDecorrelator.decorrelateFilters(this.lookup.resolve(aggregation.getSource()), correlation);
        if (!source.isPresent()) {
            return lateralJoinNode;
        }
        PlanNode decorrelatedSource = source.get().getNode();

        // Extend the decorrelated source with a constant-true marker column. After the LEFT join
        // the marker is expected to be NULL for input rows without a match, which lets
        // count(non_null) distinguish "no matching rows" from "matching rows".
        VariableReferenceExpression nonNull = this.variableAllocator.newVariable("non_null", BooleanType.BOOLEAN);
        Assignments scalarAggregationSourceAssignments = Assignments.builder()
                .putAll(AssignmentUtils.identityAssignments(decorrelatedSource.getOutputVariables()))
                .put(nonNull, LogicalRowExpressions.TRUE_CONSTANT)
                .build();
        ProjectNode scalarAggregationSourceWithNonNullableVariable = new ProjectNode(
                this.idAllocator.getNextId(),
                decorrelatedSource,
                scalarAggregationSourceAssignments);

        return this.rewriteScalarAggregation(
                lateralJoinNode,
                aggregation,
                scalarAggregationSourceWithNonNullableVariable,
                source.get().getCorrelatedPredicates(),
                nonNull);
    }

    /**
     * Builds the unique-id / LEFT join / aggregation / projection plan.
     *
     * @param lateralJoinNode the lateral join being replaced
     * @param scalarAggregation the original scalar aggregation
     * @param scalarAggregationSource decorrelated aggregation source that already exposes {@code nonNull}
     * @param joinExpression correlated predicates used as the LEFT join filter, if any
     * @param nonNull the constant-true marker variable produced by {@code scalarAggregationSource}
     * @return the rewritten plan, or {@code lateralJoinNode} if no aggregation node could be created
     */
    private PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<RowExpression> joinExpression, VariableReferenceExpression nonNull) {
        // Tag every input row with a unique id so the aggregation above the join can group per input row.
        AssignUniqueId inputWithUniqueColumns = new AssignUniqueId(
                lateralJoinNode.getSourceLocation(),
                this.idAllocator.getNextId(),
                lateralJoinNode.getInput(),
                this.variableAllocator.newVariable(nonNull.getSourceLocation(), "unique", BigintType.BIGINT));

        // No equi-join clauses: the correlated predicates are applied purely as the join filter.
        JoinNode leftOuterJoin = new JoinNode(
                scalarAggregation.getSourceLocation(),
                this.idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                inputWithUniqueColumns,
                scalarAggregationSource,
                ImmutableList.of(),
                ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(inputWithUniqueColumns.getOutputVariables())
                        .addAll(scalarAggregationSource.getOutputVariables())
                        .build(),
                joinExpression,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());

        Optional<AggregationNode> aggregationNode = this.createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
        // NOTE(review): createAggregationNode currently always returns a value, so this guard is defensive.
        if (!aggregationNode.isPresent()) {
            return lateralJoinNode;
        }

        // Look for a projection sitting on top of the subquery, descending only through EnforceSingleRowNode.
        Optional<ProjectNode> subqueryProjection = PlanNodeSearcher.searchFrom(lateralJoinNode.getSubquery(), this.lookup)
                .where(ProjectNode.class::isInstance)
                .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance)
                .findFirst();

        List<VariableReferenceExpression> aggregationOutputVariables = getTruncatedAggregationVariables(lateralJoinNode, aggregationNode.get());

        if (subqueryProjection.isPresent()) {
            // Re-apply the subquery's projection on top of the new aggregation.
            Assignments assignments = Assignments.builder()
                    .putAll(AssignmentUtils.identityAssignments(aggregationOutputVariables))
                    .putAll(subqueryProjection.get().getAssignments())
                    .build();
            return new ProjectNode(this.idAllocator.getNextId(), aggregationNode.get(), assignments);
        }
        return new ProjectNode(
                this.idAllocator.getNextId(),
                aggregationNode.get(),
                AssignmentUtils.identityAssignments(aggregationOutputVariables));
    }

    /**
     * Returns the aggregation's output variables that are also outputs of the lateral join,
     * in the aggregation's output order. This drops helper variables (such as the unique id)
     * that the lateral join does not expose.
     */
    private static List<VariableReferenceExpression> getTruncatedAggregationVariables(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) {
        HashSet<VariableReferenceExpression> applyVariables = new HashSet<>(lateralJoinNode.getOutputVariables());
        return aggregationNode.getOutputVariables().stream()
                .filter(applyVariables::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Creates the aggregation placed above the LEFT join, grouped by all output variables of the
     * join's left side. Every {@code count} aggregation is replaced by {@code count(nonNull)};
     * all other aggregations are copied unchanged.
     *
     * @param scalarAggregation the original scalar aggregation
     * @param leftOuterJoin the join that becomes the new aggregation's source
     * @param nonNull the marker variable counted in place of the original {@code count} call
     * @return the new aggregation node (always present in the current implementation)
     */
    private Optional<AggregationNode> createAggregationNode(AggregationNode scalarAggregation, JoinNode leftOuterJoin, VariableReferenceExpression nonNull) {
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : scalarAggregation.getAggregations().entrySet()) {
            VariableReferenceExpression variable = entry.getKey();
            AggregationNode.Aggregation original = entry.getValue();
            if (this.functionResolution.isCountFunction(original.getFunctionHandle())) {
                Type scalarAggregationSourceType = nonNull.getType();
                // NOTE(review): only the mask of the original count aggregation is carried over; its
                // arguments and the other constructor slots (presumably filter, ordering and the
                // distinct flag) are reset - confirm this is intended for count(x) / count(DISTINCT x).
                aggregations.put(variable, new AggregationNode.Aggregation(
                        new CallExpression(
                                variable.getSourceLocation(),
                                "count",
                                this.functionResolution.countFunction(scalarAggregationSourceType),
                                BigintType.BIGINT,
                                ImmutableList.of(nonNull)),
                        Optional.empty(),
                        Optional.empty(),
                        false,
                        original.getMask()));
            } else {
                aggregations.put(variable, original);
            }
        }

        return Optional.of(new AggregationNode(
                scalarAggregation.getSourceLocation(),
                this.idAllocator.getNextId(),
                leftOuterJoin,
                aggregations.build(),
                AggregationNode.singleGroupingSet(leftOuterJoin.getLeft().getOutputVariables()),
                ImmutableList.of(),
                scalarAggregation.getStep(),
                scalarAggregation.getHashVariable(),
                Optional.empty(),
                Optional.empty()));
    }
}

