/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.SystemSessionProperties;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.FunctionCallBuilder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.QualifiedName;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

public class RewriteSpatialPartitioningAggregation
implements Rule<AggregationNode> {
    private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = new TypeSignature("Geometry", new TypeSignatureParameter[0]);
    private static final String NAME = "spatial_partitioning";
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation().matching(RewriteSpatialPartitioningAggregation::hasSpatialPartitioningAggregation);
    private final PlannerContext plannerContext;

    public RewriteSpatialPartitioningAggregation(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    private static boolean hasSpatialPartitioningAggregation(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().anyMatch(aggregation -> aggregation.getResolvedFunction().getSignature().getName().equals(NAME) && aggregation.getArguments().size() == 1);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
        ResolvedFunction spatialPartitioningFunction = this.plannerContext.getMetadata().resolveFunction(context.getSession(), QualifiedName.of((String)NAME), TypeSignatureProvider.fromTypeSignatures(GEOMETRY_TYPE_SIGNATURE, IntegerType.INTEGER.getTypeSignature()));
        ResolvedFunction stEnvelopeFunction = this.plannerContext.getMetadata().resolveFunction(context.getSession(), QualifiedName.of((String)"ST_Envelope"), TypeSignatureProvider.fromTypeSignatures(GEOMETRY_TYPE_SIGNATURE));
        ImmutableMap.Builder aggregations = ImmutableMap.builder();
        Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", (Type)IntegerType.INTEGER);
        ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            String name = aggregation.getResolvedFunction().getSignature().getName();
            if (name.equals(NAME) && aggregation.getArguments().size() == 1) {
                Expression geometry = (Expression)Iterables.getOnlyElement(aggregation.getArguments());
                Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", this.plannerContext.getTypeManager().getType(GEOMETRY_TYPE_SIGNATURE));
                if (this.isStEnvelopeFunctionCall(geometry, stEnvelopeFunction)) {
                    envelopeAssignments.put((Object)envelopeSymbol, (Object)geometry);
                } else {
                    envelopeAssignments.put((Object)envelopeSymbol, (Object)FunctionCallBuilder.resolve(context.getSession(), this.plannerContext.getMetadata()).setName(QualifiedName.of((String)"ST_Envelope")).addArgument(GEOMETRY_TYPE_SIGNATURE, geometry).build());
                }
                aggregations.put((Object)entry.getKey(), (Object)new AggregationNode.Aggregation(spatialPartitioningFunction, (List<Expression>)ImmutableList.of((Object)envelopeSymbol.toSymbolReference(), (Object)partitionCountSymbol.toSymbolReference()), false, Optional.empty(), Optional.empty(), aggregation.getMask()));
                continue;
            }
            aggregations.put(entry);
        }
        return Rule.Result.ofPlanNode(new AggregationNode(node.getId(), new ProjectNode(context.getIdAllocator().getNextId(), node.getSource(), Assignments.builder().putIdentities(node.getSource().getOutputSymbols()).put(partitionCountSymbol, (Expression)new LongLiteral(Integer.toString(SystemSessionProperties.getHashPartitionCount(context.getSession())))).putAll((Map<Symbol, ? extends Expression>)envelopeAssignments.buildOrThrow()).build()), (Map<Symbol, AggregationNode.Aggregation>)aggregations.buildOrThrow(), node.getGroupingSets(), node.getPreGroupedSymbols(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol()));
    }

    private boolean isStEnvelopeFunctionCall(Expression expression, ResolvedFunction stEnvelopeFunction) {
        if (!(expression instanceof FunctionCall)) {
            return false;
        }
        FunctionCall functionCall = (FunctionCall)expression;
        return this.plannerContext.getMetadata().decodeFunction(functionCall.getName()).getFunctionId().equals(stEnvelopeFunction.getFunctionId());
    }
}

