/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.jdbc.aggregation;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.aggregation.AggregateFunctionPatterns;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionPatterns;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.DecimalType;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;

public class ImplementSum
implements AggregateFunctionRule<JdbcExpression, ParameterizedExpression> {
    private static final Capture<Variable> ARGUMENT = Capture.newCapture();
    private final Function<DecimalType, Optional<JdbcTypeHandle>> decimalTypeHandle;

    public ImplementSum(Function<DecimalType, Optional<JdbcTypeHandle>> decimalTypeHandle) {
        this.decimalTypeHandle = Objects.requireNonNull(decimalTypeHandle, "decimalTypeHandle is null");
    }

    public Pattern<AggregateFunction> getPattern() {
        return Pattern.typeOf(AggregateFunction.class).with(AggregateFunctionPatterns.hasSortOrder().equalTo((Object)false)).with(AggregateFunctionPatterns.hasFilter().equalTo((Object)false)).with(AggregateFunctionPatterns.functionName().equalTo((Object)"sum")).with(AggregateFunctionPatterns.singleArgument().matching(ConnectorExpressionPatterns.variable().capturedAs(ARGUMENT)));
    }

    public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, AggregateFunctionRule.RewriteContext<ParameterizedExpression> context) {
        JdbcTypeHandle resultTypeHandle;
        Variable argument = (Variable)captures.get(ARGUMENT);
        JdbcColumnHandle columnHandle = (JdbcColumnHandle)context.getAssignment(argument.getName());
        if (columnHandle.getColumnType().equals((Object)aggregateFunction.getOutputType())) {
            resultTypeHandle = columnHandle.getJdbcTypeHandle();
        } else if (aggregateFunction.getOutputType() instanceof DecimalType) {
            Optional<JdbcTypeHandle> decimalTypeHandle = this.decimalTypeHandle.apply((DecimalType)aggregateFunction.getOutputType());
            if (decimalTypeHandle.isEmpty()) {
                return Optional.empty();
            }
            resultTypeHandle = decimalTypeHandle.get();
        } else {
            return Optional.empty();
        }
        ParameterizedExpression rewrittenArgument = (ParameterizedExpression)context.rewriteExpression((ConnectorExpression)argument).orElseThrow();
        String function = aggregateFunction.isDistinct() ? "sum(DISTINCT %s)" : "sum(%s)";
        return Optional.of(new JdbcExpression(String.format(function, rewrittenArgument.expression()), rewrittenArgument.parameters(), resultTypeHandle));
    }
}

