/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.functions;

import io.trino.hive.$internal.com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.CanAggregateDistinct;
import org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction;

/**
 * Hive's {@code sum} aggregate function, registered with {@link SqlKind#SUM} in the
 * {@link SqlFunctionCategory#NUMERIC} category.
 *
 * <p>The caller supplies the return-type inference, operand-type inference and operand-type
 * checker. They are kept in fields so that the splitter returned by {@link #unwrap(Class)} can
 * build new function instances with the same configuration.
 */
public class HiveSqlSumAggFunction extends SqlAggFunction implements CanAggregateDistinct {
    /** Whether this instance was created for a DISTINCT aggregation; reported by {@link #isDistinct()}. */
    final boolean isDistinct;
    /** Return-type inference passed to the constructor; reused when the splitter creates a new SUM. */
    final SqlReturnTypeInference returnTypeInference;
    /** Operand-type inference passed to the constructor; reused by the splitter. */
    final SqlOperandTypeInference operandTypeInference;
    /** Operand-type checker passed to the constructor; reused by the splitter. */
    final SqlOperandTypeChecker operandTypeChecker;

    /**
     * Creates a {@code sum} aggregate function.
     *
     * @param isDistinct whether the aggregation is DISTINCT
     * @param returnTypeInference strategy used to infer the result type
     * @param operandTypeInference strategy used to infer operand types
     * @param operandTypeChecker strategy used to validate operand types
     */
    public HiveSqlSumAggFunction(boolean isDistinct, SqlReturnTypeInference returnTypeInference,
            SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker) {
        super("sum", SqlKind.SUM, returnTypeInference, operandTypeInference, operandTypeChecker,
                SqlFunctionCategory.NUMERIC);
        this.returnTypeInference = returnTypeInference;
        this.operandTypeChecker = operandTypeChecker;
        this.operandTypeInference = operandTypeInference;
        this.isDistinct = isDistinct;
    }

    @Override
    public boolean isDistinct() {
        return this.isDistinct;
    }

    /**
     * Returns a {@link HiveSumSplitter} when asked for {@link SqlSplittableAggFunction}, otherwise
     * defers to the superclass.
     *
     * @param clazz the requested interface/class
     * @param <T> the requested type
     * @return an instance of {@code clazz}, or whatever the superclass returns
     */
    @Override
    public <T> T unwrap(Class<T> clazz) {
        if (clazz == SqlSplittableAggFunction.class) {
            return clazz.cast(new HiveSumSplitter());
        }
        return super.unwrap(clazz);
    }

    /**
     * A {@link SqlSplittableAggFunction.SumSplitter} that creates Hive-specific COUNT and SUM
     * functions instead of the standard Calcite ones.
     *
     * <p>Non-static on purpose: it reads the enclosing function's inference strategies and
     * DISTINCT flag.
     */
    class HiveSumSplitter extends SqlSplittableAggFunction.SumSplitter {

        /**
         * Returns a zero-argument {@code count} call with a nullable BIGINT result type. It is
         * presumably used by the splitting rule for the input side that does not contain the
         * summed column (TODO confirm against the caller).
         *
         * <p>NOTE(review): the {@link HiveSqlCountAggFunction} receives this SUM's
         * {@code isDistinct} flag, while the {@link AggregateCall} itself is created non-distinct.
         * This is kept as-is to preserve existing behavior.
         *
         * @param typeFactory factory used to build the BIGINT result type
         * @param e the aggregate call being split (not consulted here)
         * @return the COUNT aggregate call, named {@code "count"}
         */
        @Override
        public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
            final RelDataType countRetType = typeFactory.createTypeWithNullability(
                    typeFactory.createSqlType(SqlTypeName.BIGINT), true);
            final HiveSqlCountAggFunction count = new HiveSqlCountAggFunction(
                    HiveSqlSumAggFunction.this.isDistinct,
                    ReturnTypes.explicit(countRetType),
                    HiveSqlSumAggFunction.this.operandTypeInference,
                    HiveSqlSumAggFunction.this.operandTypeChecker);
            return AggregateCall.create(count, false, ImmutableIntList.of(), -1, countRetType, "count");
        }

        /**
         * Builds the top-level SUM after the aggregate has been split across two inputs.
         *
         * <p>If both subtotal ordinals are present, the expression summed is their product, cast
         * to {@code aggregateCall.type}. If only one is present, that subtotal column is summed
         * directly.
         *
         * @param rexBuilder builder used to create input references and expressions
         * @param extra registry that receives the expression to sum; the ordinal it returns
         *     becomes the argument of the new SUM call
         * @param offset not used by this implementation
         * @param inputRowType row type whose fields the subtotal ordinals refer to
         * @param aggregateCall the original aggregate call; its type and name are reused
         * @param leftSubTotal ordinal of the left subtotal field, or negative if absent
         * @param rightSubTotal ordinal of the right subtotal field, or negative if absent
         * @return a SUM aggregate call (distinct flag {@code false}) over the registered expression
         * @throws AssertionError if neither subtotal ordinal is non-negative
         */
        @Override
        public AggregateCall topSplit(RexBuilder rexBuilder, SqlSplittableAggFunction.Registry<RexNode> extra,
                int offset, RelDataType inputRowType, AggregateCall aggregateCall,
                int leftSubTotal, int rightSubTotal) {
            final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
            final List<RexNode> merges = new ArrayList<>(2);
            if (leftSubTotal >= 0) {
                merges.add(rexBuilder.makeInputRef(fieldList.get(leftSubTotal).getType(), leftSubTotal));
            }
            if (rightSubTotal >= 0) {
                merges.add(rexBuilder.makeInputRef(fieldList.get(rightSubTotal).getType(), rightSubTotal));
            }
            final RexNode node;
            switch (merges.size()) {
                case 1:
                    node = merges.get(0);
                    break;
                case 2:
                    // Multiply the two subtotals, then cast back to the original SUM's result type.
                    node = rexBuilder.makeAbstractCast(aggregateCall.type,
                            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges));
                    break;
                default:
                    throw new AssertionError("unexpected count " + merges);
            }
            final int ordinal = extra.register(node);
            final HiveSqlSumAggFunction sum = new HiveSqlSumAggFunction(
                    HiveSqlSumAggFunction.this.isDistinct,
                    HiveSqlSumAggFunction.this.returnTypeInference,
                    HiveSqlSumAggFunction.this.operandTypeInference,
                    HiveSqlSumAggFunction.this.operandTypeChecker);
            return AggregateCall.create(sum, false, ImmutableList.of(ordinal), -1,
                    aggregateCall.type, aggregateCall.name);
        }
    }
}

