/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.prestosql.metadata.FunctionRegistry;
import io.prestosql.spi.type.BigintType;
import io.prestosql.spi.type.BooleanType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.optimizations.PlanNodeDecorrelator;
import io.prestosql.sql.planner.optimizations.PlanNodeSearcher;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.EnforceSingleRowNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LateralJoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.QualifiedName;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Rewrites a {@link LateralJoinNode} whose subquery is a scalar (global) aggregation into an
 * uncorrelated plan.
 *
 * <p>The lateral input is tagged with a unique-id column ({@link AssignUniqueId}) and is
 * LEFT-joined to the decorrelated aggregation source, using the correlated predicates as the join
 * filter. The aggregation is then re-grouped by every output symbol of that tagged input.
 *
 * <p>{@code count} calls are redirected to a synthetic {@code non_null} column that is projected
 * as {@code TRUE} on every row of the aggregation source. That column is NULL on the LEFT join's
 * unmatched rows, so {@code count(non_null)} yields 0 for input rows without a match.
 */
public class ScalarAggregationToJoinRewriter {
    private static final QualifiedName COUNT = QualifiedName.of("count");
    private final FunctionRegistry functionRegistry;
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Lookup lookup;
    private final PlanNodeDecorrelator planNodeDecorrelator;

    /**
     * @param functionRegistry used to resolve the rewritten {@code count} function
     * @param symbolAllocator allocates the {@code non_null} and {@code unique} symbols
     * @param idAllocator allocates ids for every plan node created by the rewrite
     * @param lookup resolves plan-node references while decorrelating and searching
     * @throws NullPointerException if any argument is null
     */
    public ScalarAggregationToJoinRewriter(FunctionRegistry functionRegistry, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry is null");
        this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
        this.lookup = Objects.requireNonNull(lookup, "lookup is null");
        this.planNodeDecorrelator = new PlanNodeDecorrelator(idAllocator, lookup);
    }

    /**
     * Attempts to decorrelate {@code lateralJoinNode}, whose subquery contains {@code aggregation}.
     *
     * @param lateralJoinNode the lateral join to rewrite
     * @param aggregation the scalar aggregation found inside the lateral subquery
     * @return the rewritten plan, or {@code lateralJoinNode} itself when the aggregation's source
     *     filters cannot be decorrelated
     */
    public PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode aggregation) {
        List<Symbol> correlation = lateralJoinNode.getCorrelation();
        Optional<PlanNodeDecorrelator.DecorrelatedNode> source = this.planNodeDecorrelator.decorrelateFilters(this.lookup.resolve(aggregation.getSource()), correlation);
        if (!source.isPresent()) {
            return lateralJoinNode;
        }
        PlanNode decorrelatedSource = source.get().getNode();
        // TRUE on every source row; after the LEFT join it is NULL exactly on unmatched rows.
        Symbol nonNull = this.symbolAllocator.newSymbol("non_null", BooleanType.BOOLEAN);
        Assignments scalarAggregationSourceAssignments = Assignments.builder()
                .putIdentities(decorrelatedSource.getOutputSymbols())
                .put(nonNull, BooleanLiteral.TRUE_LITERAL)
                .build();
        ProjectNode scalarAggregationSourceWithNonNullableSymbol = new ProjectNode(this.idAllocator.getNextId(), decorrelatedSource, scalarAggregationSourceAssignments);
        return this.rewriteScalarAggregation(lateralJoinNode, aggregation, scalarAggregationSourceWithNonNullableSymbol, source.get().getCorrelatedPredicates(), nonNull);
    }

    /**
     * Builds the LEFT join + re-grouped aggregation, then projects it back to the lateral join's
     * outputs, re-applying the first projection found in the subquery (searching only through
     * {@link EnforceSingleRowNode}s), if any.
     */
    private PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<Expression> joinExpression, Symbol nonNull) {
        // Presumably tags each input row so duplicate input rows are not merged into one group.
        AssignUniqueId inputWithUniqueColumns = new AssignUniqueId(this.idAllocator.getNextId(), lateralJoinNode.getInput(), this.symbolAllocator.newSymbol("unique", BigintType.BIGINT));
        List<Symbol> joinOutputSymbols = ImmutableList.<Symbol>builder()
                .addAll(inputWithUniqueColumns.getOutputSymbols())
                .addAll(scalarAggregationSource.getOutputSymbols())
                .build();
        JoinNode leftOuterJoin = new JoinNode(
                this.idAllocator.getNextId(),
                JoinNode.Type.LEFT,
                inputWithUniqueColumns,
                scalarAggregationSource,
                ImmutableList.<JoinNode.EquiJoinClause>of(),
                joinOutputSymbols,
                joinExpression,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        AggregationNode aggregationNode = this.createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
        Optional<ProjectNode> subqueryProjection = PlanNodeSearcher.searchFrom(lateralJoinNode.getSubquery(), this.lookup)
                .where(ProjectNode.class::isInstance)
                .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance)
                .findFirst();
        List<Symbol> aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregationNode);
        if (subqueryProjection.isPresent()) {
            Assignments assignments = Assignments.builder()
                    .putIdentities(aggregationOutputSymbols)
                    .putAll(subqueryProjection.get().getAssignments())
                    .build();
            return new ProjectNode(this.idAllocator.getNextId(), aggregationNode, assignments);
        }
        return new ProjectNode(this.idAllocator.getNextId(), aggregationNode, Assignments.identity(aggregationOutputSymbols));
    }

    /**
     * Returns the aggregation's output symbols that are also outputs of the lateral join, in the
     * aggregation's output order.
     */
    private static List<Symbol> getTruncatedAggregationSymbols(LateralJoinNode lateralJoinNode, AggregationNode aggregationNode) {
        Set<Symbol> applySymbols = new HashSet<>(lateralJoinNode.getOutputSymbols());
        return aggregationNode.getOutputSymbols().stream()
                .filter(applySymbols::contains)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Copies {@code scalarAggregation} on top of {@code leftOuterJoin}, grouped by all symbols of
     * the join's left side. Each {@code count} call is replaced by
     * {@code count(nonNullableAggregationSourceSymbol)}, keeping the original mask; other
     * aggregations are reused unchanged.
     */
    private AggregationNode createAggregationNode(AggregationNode scalarAggregation, JoinNode leftOuterJoin, Symbol nonNullableAggregationSourceSymbol) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : scalarAggregation.getAggregations().entrySet()) {
            Symbol symbol = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();
            FunctionCall call = aggregation.getCall();
            if (call.getName().equals(COUNT)) {
                // NOTE(review): this replaces every count call, including count(expr). If expr can
                // be NULL on matched rows, those rows are now counted too -- confirm this is intended.
                List<TypeSignature> argumentTypes = ImmutableList.of(
                        this.symbolAllocator.getTypes().get(nonNullableAggregationSourceSymbol).getTypeSignature());
                FunctionCall countNonNull = new FunctionCall(COUNT, ImmutableList.<Expression>of(nonNullableAggregationSourceSymbol.toSymbolReference()));
                aggregations.put(symbol, new AggregationNode.Aggregation(
                        countNonNull,
                        this.functionRegistry.resolveFunction(COUNT, TypeSignatureProvider.fromTypeSignatures(argumentTypes)),
                        aggregation.getMask()));
            } else {
                aggregations.put(symbol, aggregation);
            }
        }
        return new AggregationNode(
                this.idAllocator.getNextId(),
                leftOuterJoin,
                aggregations.build(),
                AggregationNode.singleGroupingSet(leftOuterJoin.getLeft().getOutputSymbols()),
                ImmutableList.<Symbol>of(),
                scalarAggregation.getStep(),
                scalarAggregation.getHashSymbol(),
                Optional.empty());
    }
}

