/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.DistinctOutputQueryUtil;
import io.trino.sql.planner.optimizations.SymbolMapper;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.Row;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Pushes an aggregation that sits directly on top of a LEFT or RIGHT outer join down into the
 * inner side of that join.
 *
 * <p>The rule only fires when:
 * <ul>
 *   <li>the join is LEFT or RIGHT and has no join filter;</li>
 *   <li>the aggregation has a single grouping set whose keys are exactly the outer side's output symbols;</li>
 *   <li>the outer side is distinct according to {@code DistinctOutputQueryUtil.isDistinct};</li>
 *   <li>every aggregation references only symbols produced by the inner side.</li>
 * </ul>
 *
 * <p>The aggregation is re-grouped on the inner side's equi-join keys and placed under the join. Outer
 * rows with no inner match now get NULL aggregation outputs from the join. Before the rewrite, those
 * rows were aggregated over NULL inner values, and some aggregate functions (for example {@code count})
 * do not return NULL in that case. To keep the original results, the rewritten join is cross-joined
 * with the same aggregations evaluated over a single all-NULL row. Each aggregation output is then
 * projected as {@code COALESCE(pushed_down_value, value_over_null_row)}.
 */
public class PushAggregationThroughOuterJoin
implements Rule<AggregationNode> {
    private static final Capture<JoinNode> JOIN = Capture.newCapture();
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().with(Patterns.source().matching(Patterns.join().capturedAs(JOIN)));

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushAggregationThroughOuterJoin(session);
    }

    @Override
    public Rule.Result apply(AggregationNode aggregation, Captures captures, Rule.Context context) {
        // This rule does not handle aggregation hash symbols; their presence is a caller error.
        Preconditions.checkArgument(aggregation.getHashSymbol().isEmpty(), "unexpected hash symbol");
        JoinNode join = captures.get(JOIN);
        if (!isApplicable(aggregation, join, context)) {
            return Rule.Result.empty();
        }

        JoinNode.Type joinType = join.getType();
        // The pushed-down aggregation groups on the inner side's equi-join keys
        // (the left key of each clause for RIGHT joins, the right key for LEFT joins).
        List<Symbol> groupingKeys = join.getCriteria().stream()
                .map(clause -> joinType == JoinNode.Type.RIGHT ? clause.getLeft() : clause.getRight())
                .collect(ImmutableList.toImmutableList());
        AggregationNode rewrittenAggregation = AggregationNode.builderFrom(aggregation)
                .setSource(getInnerTable(join))
                .setGroupingSets(AggregationNode.singleGroupingSet(groupingKeys))
                .setPreGroupedSymbols(ImmutableList.of())
                .build();

        JoinNode rewrittenJoin = rewriteJoin(join, rewrittenAggregation);
        Optional<PlanNode> resultNode = coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator());
        if (resultNode.isEmpty()) {
            return Rule.Result.empty();
        }
        return Rule.Result.ofPlanNode(resultNode.get());
    }

    /**
     * Returns whether every precondition for pushing {@code aggregation} below {@code join} holds.
     * The checks are ordered so that the more expensive ones (distinctness, symbol containment)
     * only run after the cheaper structural checks pass.
     */
    private static boolean isApplicable(AggregationNode aggregation, JoinNode join, Rule.Context context) {
        if (join.getFilter().isPresent()) {
            return false;
        }
        if (join.getType() != JoinNode.Type.LEFT && join.getType() != JoinNode.Type.RIGHT) {
            return false;
        }
        // getOuterTable/getInnerTable require a LEFT or RIGHT join, which is established above.
        PlanNode outerTable = getOuterTable(join);
        return groupsOnAllColumns(aggregation, outerTable.getOutputSymbols())
                && DistinctOutputQueryUtil.isDistinct(context.getLookup().resolve(outerTable), context.getLookup()::resolve)
                && isAggregationOnSymbols(aggregation, getInnerTable(join));
    }

    /**
     * Rebuilds {@code join} with {@code rewrittenAggregation} in place of its inner side.
     * The inner side contributes only the aggregation output symbols. The outer side keeps all of its
     * output symbols, and every other join property is copied from {@code join}.
     */
    private static JoinNode rewriteJoin(JoinNode join, AggregationNode rewrittenAggregation) {
        List<Symbol> aggregationOutputs = ImmutableList.copyOf(rewrittenAggregation.getAggregations().keySet());
        boolean leftJoin = join.getType() == JoinNode.Type.LEFT;
        PlanNode left = leftJoin ? join.getLeft() : rewrittenAggregation;
        PlanNode right = leftJoin ? rewrittenAggregation : join.getRight();
        List<Symbol> leftOutputs = leftJoin ? join.getLeft().getOutputSymbols() : aggregationOutputs;
        List<Symbol> rightOutputs = leftJoin ? aggregationOutputs : join.getRight().getOutputSymbols();
        return new JoinNode(
                join.getId(),
                join.getType(),
                left,
                right,
                join.getCriteria(),
                leftOutputs,
                rightOutputs,
                // NOTE(review): presumably the "may skip output duplicates" flag; outer rows are known
                // to be distinct here — confirm against the JoinNode constructor.
                false,
                join.getFilter(),
                join.getLeftHashSymbol(),
                join.getRightHashSymbol(),
                join.getDistributionType(),
                join.isSpillable(),
                join.getDynamicFilters(),
                join.getReorderJoinStatsAndCost());
    }

    /** Returns the non-preserved side of a LEFT/RIGHT join (right for LEFT, left for RIGHT). */
    private static PlanNode getInnerTable(JoinNode join) {
        Preconditions.checkState(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN");
        return join.getType() == JoinNode.Type.LEFT ? join.getRight() : join.getLeft();
    }

    /** Returns the preserved side of a LEFT/RIGHT join (left for LEFT, right for RIGHT). */
    private static PlanNode getOuterTable(JoinNode join) {
        Preconditions.checkState(join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT, "expected LEFT or RIGHT JOIN");
        return join.getType() == JoinNode.Type.LEFT ? join.getLeft() : join.getRight();
    }

    /**
     * Returns true when {@code node} has exactly one grouping set and its grouping keys are the same
     * set of symbols as {@code columns} (order and duplicates ignored).
     */
    private static boolean groupsOnAllColumns(AggregationNode node, List<Symbol> columns) {
        return node.getGroupingSetCount() == 1 && new HashSet<>(node.getGroupingKeys()).equals(new HashSet<>(columns));
    }

    /**
     * Wraps {@code outerJoin} so that each aggregation output falls back to the value the aggregation
     * produces over a single all-NULL row.
     *
     * <p>Builds an INNER join with no criteria (a cross join) between {@code outerJoin} and the
     * single-row aggregation over NULLs. On top of it, a projection is added that:
     * <ul>
     *   <li>replaces each aggregation symbol with {@code COALESCE(symbol, overNullSymbol)};</li>
     *   <li>passes every other symbol through unchanged.</li>
     * </ul>
     *
     * @return the projection node; the returned {@code Optional} is always present
     */
    private Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        MappedAggregationInfo aggregationOverNullInfo = createAggregationOverNull(aggregationNode, symbolAllocator, idAllocator);
        AggregationNode aggregationOverNull = aggregationOverNullInfo.getAggregation();
        Map<Symbol, Symbol> sourceAggregationToOverNullMapping = aggregationOverNullInfo.getSymbolMapping();

        JoinNode crossJoin = new JoinNode(
                idAllocator.getNextId(),
                JoinNode.Type.INNER,
                outerJoin,
                aggregationOverNull,
                ImmutableList.of(),
                outerJoin.getOutputSymbols(),
                aggregationOverNull.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());

        Assignments.Builder assignmentsBuilder = Assignments.builder();
        for (Symbol symbol : outerJoin.getOutputSymbols()) {
            Expression reference = symbol.toSymbolReference();
            if (aggregationNode.getAggregations().containsKey(symbol)) {
                Expression overNullReference = sourceAggregationToOverNullMapping.get(symbol).toSymbolReference();
                assignmentsBuilder.put(symbol, new CoalesceExpression(reference, overNullReference));
            }
            else {
                assignmentsBuilder.put(symbol, reference);
            }
        }
        return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, assignmentsBuilder.build()));
    }

    /**
     * Builds a global aggregation that evaluates each of {@code referenceAggregation}'s aggregations
     * over a single row of typed NULLs.
     *
     * <p>Steps:
     * <ol>
     *   <li>A {@link ValuesNode} is created with one fresh NULL-valued symbol per output symbol of the
     *       reference aggregation's source. Each NULL is a {@code CAST(NULL AS type)} of the source
     *       symbol's type.</li>
     *   <li>Every aggregation is remapped onto those NULL symbols and given a fresh output symbol.</li>
     * </ol>
     *
     * @return the new aggregation, plus a mapping from each reference aggregation output symbol to
     *     its over-NULL counterpart
     */
    private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
        ImmutableList.Builder<Symbol> nullSymbols = ImmutableList.builder();
        ImmutableList.Builder<Expression> nullLiterals = ImmutableList.builder();
        ImmutableMap.Builder<Symbol, Symbol> sourcesSymbolMappingBuilder = ImmutableMap.builder();
        for (Symbol sourceSymbol : referenceAggregation.getSource().getOutputSymbols()) {
            Type type = symbolAllocator.getTypes().get(sourceSymbol);
            nullLiterals.add(new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(type)));
            Symbol nullSymbol = symbolAllocator.newSymbol("null", type);
            nullSymbols.add(nullSymbol);
            sourcesSymbolMappingBuilder.put(sourceSymbol, nullSymbol);
        }
        ValuesNode nullRow = new ValuesNode(idAllocator.getNextId(), nullSymbols.build(), ImmutableList.of(new Row(nullLiterals.build())));

        ImmutableMap.Builder<Symbol, Symbol> aggregationsSymbolMappingBuilder = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsOverNullBuilder = ImmutableMap.builder();
        SymbolMapper mapper = SymbolMapper.symbolMapper(sourcesSymbolMappingBuilder.buildOrThrow());
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
            Symbol aggregationSymbol = entry.getKey();
            AggregationNode.Aggregation overNullAggregation = mapper.map(entry.getValue());
            Symbol overNullSymbol = symbolAllocator.newSymbol(overNullAggregation.getResolvedFunction().getSignature().getName(), symbolAllocator.getTypes().get(aggregationSymbol));
            aggregationsOverNullBuilder.put(overNullSymbol, overNullAggregation);
            aggregationsSymbolMappingBuilder.put(aggregationSymbol, overNullSymbol);
        }
        Map<Symbol, Symbol> aggregationsSymbolMapping = aggregationsSymbolMappingBuilder.buildOrThrow();
        AggregationNode aggregationOverNullRow = AggregationNode.singleAggregation(idAllocator.getNextId(), nullRow, aggregationsOverNullBuilder.buildOrThrow(), AggregationNode.globalAggregation());
        return new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping);
    }

    /** Returns true when every aggregation in {@code aggregationNode} references only symbols output by {@code source}. */
    private static boolean isAggregationOnSymbols(AggregationNode aggregationNode, PlanNode source) {
        Set<Symbol> sourceSymbols = ImmutableSet.copyOf(source.getOutputSymbols());
        return aggregationNode.getAggregations().values().stream()
                .allMatch(aggregation -> sourceSymbols.containsAll(SymbolsExtractor.extractUnique(aggregation)));
    }

    /** An aggregation over a NULL row together with the mapping from original to new aggregation symbols. */
    private static final class MappedAggregationInfo {
        private final AggregationNode aggregationNode;
        private final Map<Symbol, Symbol> symbolMapping;

        public MappedAggregationInfo(AggregationNode aggregationNode, Map<Symbol, Symbol> symbolMapping) {
            this.aggregationNode = aggregationNode;
            this.symbolMapping = symbolMapping;
        }

        public Map<Symbol, Symbol> getSymbolMapping() {
            return this.symbolMapping;
        }

        public AggregationNode getAggregation() {
            return this.aggregationNode;
        }
    }
}

