/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.metadata.FunctionKind;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.spi.type.IntegerType;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.VarcharType;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.FunctionCall;
import io.prestosql.sql.tree.LongLiteral;
import io.prestosql.sql.tree.QualifiedName;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Planner rule that rewrites single-argument {@code spatial_partitioning(geometry)}
 * aggregations into the two-argument internal form
 * {@code spatial_partitioning(envelope, partition_count)}.
 *
 * <p>For every matching aggregation the rule:
 * <ul>
 *   <li>allocates a fresh {@code envelope} symbol and projects either the original
 *       argument (when it is already an {@code ST_Envelope} call) or
 *       {@code ST_Envelope(argument)} onto it;</li>
 *   <li>replaces the aggregation call with one that takes the envelope symbol and a
 *       shared {@code partition_count} symbol, using {@link #INTERNAL_SIGNATURE} and
 *       keeping the original aggregation's mask.</li>
 * </ul>
 * The {@code partition_count} symbol is projected as a literal holding the session's
 * hash partition count. All other aggregations of the node are passed through unchanged.
 */
public class RewriteSpatialPartitioningAggregation
        implements Rule<AggregationNode> {
    private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = TypeSignature.parseTypeSignature("Geometry");
    private static final String NAME = "spatial_partitioning";
    private static final String ENVELOPE_FUNCTION = "ST_Envelope";
    // Signature of the rewritten call: (Geometry, integer) -> varchar.
    private static final Signature INTERNAL_SIGNATURE = new Signature(
            NAME,
            FunctionKind.AGGREGATE,
            VarcharType.VARCHAR.getTypeSignature(),
            GEOMETRY_TYPE_SIGNATURE,
            IntegerType.INTEGER.getTypeSignature());
    private static final Pattern<AggregationNode> PATTERN = Patterns.aggregation()
            .matching(RewriteSpatialPartitioningAggregation::hasSpatialPartitioningAggregation);

    private final Metadata metadata;

    /**
     * @param metadata used to resolve the {@code Geometry} type for the envelope symbols
     * @throws NullPointerException if {@code metadata} is null
     */
    public RewriteSpatialPartitioningAggregation(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    /**
     * Returns whether the node contains at least one aggregation this rule rewrites.
     */
    private static boolean hasSpatialPartitioningAggregation(AggregationNode aggregation) {
        return aggregation.getAggregations().values().stream()
                .map(AggregationNode.Aggregation::getCall)
                .anyMatch(RewriteSpatialPartitioningAggregation::isSpatialPartitioningCall);
    }

    /**
     * Returns whether {@code call} is the user-facing, single-argument form of
     * {@code spatial_partitioning}. The rewritten form has two arguments, so it no
     * longer satisfies this predicate.
     */
    private static boolean isSpatialPartitioningCall(FunctionCall call) {
        return call.getName().toString().equals(NAME) && call.getArguments().size() == 1;
    }

    /**
     * Returns {@code geometry} unchanged when it is already an {@code ST_Envelope} call
     * (name compared case-insensitively); otherwise wraps it in {@code ST_Envelope(...)}.
     */
    private static Expression toEnvelope(Expression geometry) {
        boolean alreadyEnvelope = geometry instanceof FunctionCall
                && ((FunctionCall) geometry).getName().toString().equalsIgnoreCase(ENVELOPE_FUNCTION);
        if (alreadyEnvelope) {
            return geometry;
        }
        return new FunctionCall(QualifiedName.of(ENVELOPE_FUNCTION), ImmutableList.<Expression>of(geometry));
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched aggregation node.
     *
     * @param node the aggregation node matched by {@link #getPattern()}
     * @param captures unused by this rule
     * @param context supplies the symbol allocator, plan node id allocator and session
     * @return a result wrapping a new aggregation node (same id as {@code node}) whose
     *         source is a new projection over the original source
     */
    @Override
    public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        ImmutableMap.Builder<Symbol, Expression> envelopeAssignments = ImmutableMap.builder();
        // A single partition-count symbol is shared by every rewritten aggregation.
        Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", IntegerType.INTEGER);

        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            FunctionCall call = aggregation.getCall();
            if (!isSpatialPartitioningCall(call)) {
                // Not ours: keep the aggregation exactly as it was.
                aggregations.put(entry);
                continue;
            }

            Expression geometry = Iterables.getOnlyElement(call.getArguments());
            Symbol envelopeSymbol = context.getSymbolAllocator()
                    .newSymbol("envelope", this.metadata.getType(GEOMETRY_TYPE_SIGNATURE));
            envelopeAssignments.put(envelopeSymbol, toEnvelope(geometry));

            FunctionCall rewrittenCall = new FunctionCall(
                    call.getName(),
                    ImmutableList.<Expression>of(
                            envelopeSymbol.toSymbolReference(),
                            partitionCountSymbol.toSymbolReference()));
            aggregations.put(
                    entry.getKey(),
                    new AggregationNode.Aggregation(rewrittenCall, INTERNAL_SIGNATURE, aggregation.getMask()));
        }

        // Projection feeding the aggregation: all source outputs, the partition count
        // literal, and one envelope expression per rewritten aggregation.
        Expression partitionCount = new LongLiteral(
                Integer.toString(SystemSessionProperties.getHashPartitionCount(context.getSession())));
        Assignments projections = Assignments.builder()
                .putIdentities(node.getSource().getOutputSymbols())
                .put(partitionCountSymbol, partitionCount)
                .putAll(envelopeAssignments.build())
                .build();
        ProjectNode projection = new ProjectNode(
                context.getIdAllocator().getNextId(),
                node.getSource(),
                projections);

        return Rule.Result.ofPlanNode(new AggregationNode(
                node.getId(),
                projection,
                aggregations.build(),
                node.getGroupingSets(),
                node.getPreGroupedSymbols(),
                node.getStep(),
                node.getHashSymbol(),
                node.getGroupIdSymbol()));
    }
}

