/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.sql;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.mapping.Mappings;

public interface SqlSplittableAggFunction {
    /**
     * Returns the aggregate call to compute on the side of the split that
     * supplies the call's arguments.
     *
     * <p>The implementations in this file simply remap the call's field
     * ordinals with {@code aggregateCall.transform(mapping)}.
     *
     * @param aggregateCall the original aggregate call
     * @param mapping field mapping applied to the call's ordinals
     * @return the rewritten aggregate call
     */
    AggregateCall split(AggregateCall aggregateCall, Mappings.TargetMapping mapping);

    /**
     * Returns the aggregate call to compute on the other side of the split.
     *
     * <p>This is presumably the other input of a join; confirm with callers.
     * May return {@code null} when no aggregate is needed there (as
     * {@code SelfSplitter} does).
     *
     * @param typeFactory factory for any result type the new call needs
     * @param e the original aggregate call
     * @return the companion aggregate call, or {@code null}
     */
    AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e);

    /**
     * Builds the aggregate call that combines the sub-totals computed on the
     * left and/or right side of the split.
     *
     * @param rexBuilder builder for any combining expression
     * @param extra registry for new argument expressions; the ordinal it
     *     returns is used as the argument of the resulting call
     * @param offset not used by the implementations in this file
     * @param inputRowType row type whose fields the sub-total ordinals index
     * @param aggregateCall the original aggregate call
     * @param leftSubTotal ordinal of the left sub-total, or negative if absent
     * @param rightSubTotal ordinal of the right sub-total, or negative if absent
     * @return the top-level aggregate call
     */
    AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra, int offset,
            RelDataType inputRowType, AggregateCall aggregateCall, int leftSubTotal,
            int rightSubTotal);

    /**
     * Returns an expression equal to the aggregate's value when it is applied
     * to a single input row.
     *
     * @param rexBuilder builder for the expression
     * @param inputRowType row type of the input row
     * @param aggregateCall the aggregate call
     * @return the single-row expression
     */
    RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType, AggregateCall aggregateCall);

    /**
     * Tries to fuse {@code top}, applied over the result of {@code bottom},
     * into a single aggregate call.
     *
     * @param top the outer aggregate call
     * @param bottom the inner aggregate call
     * @return the merged call, or {@code null} if the pair cannot be merged
     */
    AggregateCall merge(AggregateCall top, AggregateCall bottom);

    /**
     * Splitter for {@code SUM0}. Sub-totals are re-aggregated with
     * {@code SUM0}; a single row contributes its argument with {@code NULL}
     * replaced by zero.
     */
    class Sum0Splitter extends AbstractSumSplitter {
        public static final Sum0Splitter INSTANCE = new Sum0Splitter();

        @Override
        public SqlAggFunction getMergeAggFunctionOfTopSplit() {
            return SqlStdOperatorTable.SUM0;
        }

        /**
         * Returns the argument column itself when its type is NOT NULL.
         * Otherwise returns {@code COALESCE(arg, 0)} so that a NULL value
         * contributes zero.
         */
        @Override
        public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
                AggregateCall aggregateCall) {
            final int argOrdinal = aggregateCall.getArgList().get(0);
            final RelDataType argType = inputRowType.getFieldList().get(argOrdinal).getType();
            final RexNode argRef = rexBuilder.makeInputRef(argType, argOrdinal);
            if (!argType.isNullable()) {
                return argRef;
            }
            final RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO, argType);
            return rexBuilder.makeCall(SqlStdOperatorTable.COALESCE, argRef, zero);
        }
    }

    /** Splitter for {@code SUM}: sub-totals are re-aggregated with {@code SUM}. */
    class SumSplitter extends AbstractSumSplitter {
        public static final SumSplitter INSTANCE = new SumSplitter();

        /** {@code SUM} of partial sums is the overall sum. */
        @Override
        public SqlAggFunction getMergeAggFunctionOfTopSplit() {
            return SqlStdOperatorTable.SUM;
        }
    }

    /**
     * Shared splitting logic for the {@code SUM} family of aggregate functions.
     *
     * <p>Subclasses only choose which aggregate function re-aggregates the
     * sub-totals produced by {@link #topSplit}.
     */
    abstract class AbstractSumSplitter implements SqlSplittableAggFunction {
        /** A single row contributes the value of the call's first argument column. */
        @Override
        public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
                AggregateCall aggregateCall) {
            final int argOrdinal = aggregateCall.getArgList().get(0);
            final RelDataType argType = inputRowType.getFieldList().get(argOrdinal).getType();
            return rexBuilder.makeInputRef(argType, argOrdinal);
        }

        /** Remaps the call's field ordinals through {@code mapping}. */
        @Override
        public AggregateCall split(AggregateCall aggregateCall, Mappings.TargetMapping mapping) {
            return aggregateCall.transform(mapping);
        }

        /**
         * The other side of the split contributes {@code COUNT(*)} of type
         * {@code BIGINT}. The call {@code e} is not inspected.
         */
        @Override
        public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
            final RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT);
            return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, false,
                    ImmutableIntList.of(), -1, RelCollations.EMPTY, bigintType, null);
        }

        /**
         * Re-aggregates the sub-totals with {@link #getMergeAggFunctionOfTopSplit()}.
         *
         * <p>One present sub-total is used directly. Two are multiplied and the
         * product is cast to the call's type. The resulting expression is
         * registered in {@code extra}, and its ordinal becomes the argument of
         * the returned call.
         *
         * @throws AssertionError if both sub-total ordinals are negative
         */
        @Override
        public AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra, int offset,
                RelDataType inputRowType, AggregateCall aggregateCall, int leftSubTotal,
                int rightSubTotal) {
            final List<RelDataTypeField> fields = inputRowType.getFieldList();
            final List<RexNode> subTotals = new ArrayList<>();
            for (int subTotal : new int[] {leftSubTotal, rightSubTotal}) {
                // A negative ordinal means that side has no sub-total.
                if (subTotal >= 0) {
                    subTotals.add(rexBuilder.makeInputRef(fields.get(subTotal).getType(), subTotal));
                }
            }
            final RexNode combined;
            if (subTotals.size() == 1) {
                combined = subTotals.get(0);
            } else if (subTotals.size() == 2) {
                final RexNode product = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, subTotals);
                combined = rexBuilder.makeAbstractCast(aggregateCall.type, product);
            } else {
                throw new AssertionError("unexpected count " + subTotals);
            }
            final int argOrdinal = extra.register(combined);
            return AggregateCall.create(getMergeAggFunctionOfTopSplit(), false, false, false,
                    ImmutableList.of(argOrdinal), -1, aggregateCall.collation,
                    aggregateCall.type, aggregateCall.name);
        }

        /**
         * Merges {@code SUM} over {@code SUM}, or {@code SUM0} over {@code SUM0}.
         *
         * <p>The result is the bottom call carrying the top call's name.
         * Returns {@code null} for any other combination.
         */
        @Override
        public AggregateCall merge(AggregateCall top, AggregateCall bottom) {
            final SqlKind kind = top.getAggregation().getKind();
            final boolean sameKind = kind == bottom.getAggregation().getKind();
            final boolean isSum = kind == SqlKind.SUM || kind == SqlKind.SUM0;
            if (!(sameKind && isSum)) {
                return null;
            }
            return AggregateCall.create(bottom.getAggregation(), bottom.isDistinct(),
                    bottom.isApproximate(), false, bottom.getArgList(), bottom.filterArg,
                    bottom.getCollation(), bottom.getType(), top.getName());
        }

        /** Returns the aggregate function {@link #topSplit} uses to combine sub-totals. */
        protected abstract SqlAggFunction getMergeAggFunctionOfTopSplit();
    }

    /**
     * Splitter that re-applies the original aggregate call to a single
     * sub-total column. It needs no aggregate on the other side of the split.
     */
    class SelfSplitter implements SqlSplittableAggFunction {
        public static final SelfSplitter INSTANCE = new SelfSplitter();

        /** A single row contributes the value of the call's first argument column. */
        @Override
        public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
                AggregateCall aggregateCall) {
            final int argOrdinal = aggregateCall.getArgList().get(0);
            final RelDataType argType = inputRowType.getFieldList().get(argOrdinal).getType();
            return rexBuilder.makeInputRef(argType, argOrdinal);
        }

        /** Remaps the call's field ordinals through {@code mapping}. */
        @Override
        public AggregateCall split(AggregateCall aggregateCall, Mappings.TargetMapping mapping) {
            return aggregateCall.transform(mapping);
        }

        /** No companion aggregate is required, so this always returns {@code null}. */
        @Override
        public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
            return null;
        }

        /**
         * Copies the original call onto whichever sub-total is present, with
         * no filter and an empty collation.
         *
         * <p>Exactly one of {@code leftSubTotal}/{@code rightSubTotal} must be
         * non-negative, and the call must have no collation. Both conditions
         * are checked only by assertions.
         */
        @Override
        public AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra, int offset,
                RelDataType inputRowType, AggregateCall aggregateCall, int leftSubTotal,
                int rightSubTotal) {
            assert (leftSubTotal >= 0) != (rightSubTotal >= 0);
            assert aggregateCall.collation.getFieldCollations().isEmpty();
            final int subTotal;
            if (leftSubTotal < 0) {
                subTotal = rightSubTotal;
            } else {
                subTotal = leftSubTotal;
            }
            return aggregateCall.copy(ImmutableIntList.of(subTotal), -1, RelCollations.EMPTY);
        }

        /**
         * Merges two calls of the same kind into the bottom call, carrying the
         * top call's name. Returns {@code null} when the kinds differ.
         */
        @Override
        public AggregateCall merge(AggregateCall top, AggregateCall bottom) {
            if (top.getAggregation().getKind() != bottom.getAggregation().getKind()) {
                return null;
            }
            return AggregateCall.create(bottom.getAggregation(), bottom.isDistinct(),
                    bottom.isApproximate(), false, bottom.getArgList(), bottom.filterArg,
                    bottom.getCollation(), bottom.getType(), top.getName());
        }
    }

    /**
     * Splitting strategy for {@code COUNT}.
     *
     * <p>The other side of the split computes {@code COUNT(*)}. When both
     * sides supply a sub-count they are multiplied, and the result is
     * re-aggregated with {@code SUM0}.
     */
    class CountSplitter implements SqlSplittableAggFunction {
        public static final CountSplitter INSTANCE = new CountSplitter();

        /** Remaps the call's field ordinals through {@code mapping}. */
        @Override
        public AggregateCall split(AggregateCall aggregateCall, Mappings.TargetMapping mapping) {
            return aggregateCall.transform(mapping);
        }

        /** Returns {@code COUNT(*)} of type {@code BIGINT}; {@code e} is not inspected. */
        @Override
        public AggregateCall other(RelDataTypeFactory typeFactory, AggregateCall e) {
            return AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, false,
                    ImmutableIntList.of(), -1, RelCollations.EMPTY,
                    typeFactory.createSqlType(SqlTypeName.BIGINT), null);
        }

        /**
         * Combines the left and/or right sub-counts.
         *
         * <p>A single sub-count is used as-is; two are multiplied. The combined
         * expression is registered in {@code extra}, and the returned
         * {@code SUM0} call takes its ordinal as argument.
         *
         * @throws AssertionError if both sub-total ordinals are negative
         */
        @Override
        public AggregateCall topSplit(RexBuilder rexBuilder, Registry<RexNode> extra, int offset,
                RelDataType inputRowType, AggregateCall aggregateCall, int leftSubTotal,
                int rightSubTotal) {
            final List<RexNode> merges = new ArrayList<>();
            // A negative ordinal means that side has no sub-count.
            if (leftSubTotal >= 0) {
                merges.add(rexBuilder.makeInputRef(aggregateCall.type, leftSubTotal));
            }
            if (rightSubTotal >= 0) {
                merges.add(rexBuilder.makeInputRef(aggregateCall.type, rightSubTotal));
            }
            final RexNode node;
            switch (merges.size()) {
                case 1:
                    node = merges.get(0);
                    break;
                case 2:
                    node = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, merges);
                    break;
                default:
                    throw new AssertionError("unexpected count " + merges);
            }
            final int ordinal = extra.register(node);
            return AggregateCall.create(SqlStdOperatorTable.SUM0, false, false, false,
                    ImmutableList.of(ordinal), -1, aggregateCall.collation,
                    aggregateCall.type, aggregateCall.name);
        }

        /**
         * Returns the single-row value of {@code COUNT}.
         *
         * <p>{@code COUNT(*)}, or {@code COUNT} whose arguments are all NOT
         * NULL, becomes the literal {@code 1}. Otherwise the result is
         * {@code CASE WHEN arg0 IS NOT NULL AND ... THEN 1 ELSE 0 END} over the
         * nullable arguments.
         */
        @Override
        public RexNode singleton(RexBuilder rexBuilder, RelDataType inputRowType,
                AggregateCall aggregateCall) {
            final List<RexNode> predicates = new ArrayList<>();
            for (Integer arg : aggregateCall.getArgList()) {
                final RelDataType type = inputRowType.getFieldList().get(arg).getType();
                if (type.isNullable()) {
                    predicates.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
                            rexBuilder.makeInputRef(type, arg)));
                }
            }
            // nullOnEmpty=true: no nullable arguments yields a null predicate.
            final RexNode predicate = RexUtil.composeConjunction(rexBuilder, predicates, true);
            final RexNode rexOne = rexBuilder.makeExactLiteral(BigDecimal.ONE, aggregateCall.getType());
            if (predicate == null) {
                return rexOne;
            }
            return rexBuilder.makeCall(SqlStdOperatorTable.CASE, predicate, rexOne,
                    rexBuilder.makeExactLiteral(BigDecimal.ZERO, aggregateCall.getType()));
        }

        /**
         * Merges a {@code SUM} or {@code SUM0} applied over a {@code COUNT}
         * into the bottom {@code COUNT}, carrying the top call's name.
         * Returns {@code null} otherwise.
         */
        @Override
        public AggregateCall merge(AggregateCall top, AggregateCall bottom) {
            final SqlKind topKind = top.getAggregation().getKind();
            // Summing per-group counts gives the count over the coarser group.
            // SUM0 must be accepted too: topSplit above re-aggregates sub-counts
            // with SUM0, and SUM0 of counts is 0 on empty input, just like COUNT.
            if (bottom.getAggregation().getKind() == SqlKind.COUNT
                    && (topKind == SqlKind.SUM || topKind == SqlKind.SUM0)) {
                return AggregateCall.create(bottom.getAggregation(), bottom.isDistinct(),
                        bottom.isApproximate(), false, bottom.getArgList(), bottom.filterArg,
                        bottom.getCollation(), bottom.getType(), top.getName());
            }
            return null;
        }
    }

    /**
     * Collection in which an element can be registered.
     *
     * <p>In this file the elements are {@code RexNode} expressions, and the
     * returned ordinal is used as the argument index of a new aggregate call.
     *
     * @param <E> element type
     */
    interface Registry<E> {
        /**
         * Registers {@code element} and returns its ordinal.
         *
         * @param element the element to register
         * @return the ordinal assigned to the element
         */
        int register(E element);
    }
}

