/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.SourceLocation;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

public class RewriteSpatialPartitioningAggregation
implements Rule<AggregationNode> {
    private static final TypeSignature GEOMETRY_TYPE_SIGNATURE = TypeSignature.parseTypeSignature((String)"Geometry");
    private static final QualifiedObjectName NAME = QualifiedObjectName.valueOf((CatalogSchemaName)BuiltInTypeAndFunctionNamespaceManager.DEFAULT_NAMESPACE, (String)"spatial_partitioning");
    private final Pattern<AggregationNode> pattern = Patterns.aggregation().matching(this::hasSpatialPartitioningAggregation);
    private final Metadata metadata;

    public RewriteSpatialPartitioningAggregation(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    private boolean hasSpatialPartitioningAggregation(AggregationNode aggregationNode) {
        return aggregationNode.getAggregations().values().stream().anyMatch(aggregation -> this.metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName().equals((Object)NAME) && aggregation.getArguments().size() == 1);
    }

    @Override
    public Pattern<AggregationNode> getPattern() {
        return this.pattern;
    }

    @Override
    public Rule.Result apply(AggregationNode node, Captures captures, Rule.Context context) {
        ImmutableMap.Builder aggregations = ImmutableMap.builder();
        VariableReferenceExpression partitionCountVariable = context.getVariableAllocator().newVariable("partition_count", (Type)IntegerType.INTEGER);
        ImmutableMap.Builder envelopeAssignments = ImmutableMap.builder();
        for (Map.Entry entry : node.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = (AggregationNode.Aggregation)entry.getValue();
            QualifiedObjectName name = this.metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
            Type geometryType = this.metadata.getType(GEOMETRY_TYPE_SIGNATURE);
            if (name.equals((Object)NAME) && aggregation.getArguments().size() == 1) {
                RowExpression geometry = (RowExpression)Iterables.getOnlyElement((Iterable)aggregation.getArguments());
                VariableReferenceExpression envelopeVariable = context.getVariableAllocator().newVariable((Optional<SourceLocation>)aggregation.getCall().getSourceLocation(), "envelope", geometryType);
                if (RewriteSpatialPartitioningAggregation.isFunctionNameMatch(geometry, "ST_Envelope")) {
                    envelopeAssignments.put((Object)envelopeVariable, (Object)geometry);
                } else {
                    envelopeAssignments.put((Object)envelopeVariable, (Object)OriginalExpressionUtils.castToRowExpression((Expression)new FunctionCall(QualifiedName.of((String)"ST_Envelope"), (List)ImmutableList.of((Object)OriginalExpressionUtils.castToExpression(geometry)))));
                }
                aggregations.put(entry.getKey(), (Object)new AggregationNode.Aggregation(new CallExpression(envelopeVariable.getSourceLocation(), name.getObjectName(), this.metadata.getFunctionAndTypeManager().lookupFunction(NAME.getObjectName(), TypeSignatureProvider.fromTypes(new Type[]{geometryType, IntegerType.INTEGER})), ((VariableReferenceExpression)entry.getKey()).getType(), (List)ImmutableList.of((Object)OriginalExpressionUtils.castToRowExpression((Expression)OriginalExpressionUtils.asSymbolReference(envelopeVariable)), (Object)OriginalExpressionUtils.castToRowExpression((Expression)OriginalExpressionUtils.asSymbolReference(partitionCountVariable)))), Optional.empty(), Optional.empty(), false, aggregation.getMask()));
                continue;
            }
            aggregations.put(entry);
        }
        return Rule.Result.ofPlanNode((PlanNode)new AggregationNode(node.getSourceLocation(), node.getId(), (PlanNode)new ProjectNode(context.getIdAllocator().getNextId(), node.getSource(), Assignments.builder().putAll(AssignmentUtils.identitiesAsSymbolReferences(node.getSource().getOutputVariables())).put(partitionCountVariable, OriginalExpressionUtils.castToRowExpression((Expression)new LongLiteral(Integer.toString(SystemSessionProperties.getHashPartitionCount(context.getSession()))))).putAll((Map)envelopeAssignments.build()).build()), (Map)aggregations.build(), node.getGroupingSets(), node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), node.getGroupIdVariable()));
    }

    private static boolean isFunctionNameMatch(RowExpression rowExpression, String expectedName) {
        if (OriginalExpressionUtils.castToExpression(rowExpression) instanceof FunctionCall) {
            return ((FunctionCall)OriginalExpressionUtils.castToExpression(rowExpression)).getName().toString().equalsIgnoreCase(expectedName);
        }
        return false;
    }
}

