/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.relnode;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.adapter.enumerable.EnumerableAggregate;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.schema.AggregateFunction;
import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.impl.AggregateFunctionImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;
import org.apache.kylin.measure.MeasureTypeFactory;
import org.apache.kylin.measure.basic.BasicMeasureType;
import org.apache.kylin.measure.topn.TopNMeasureType;
import org.apache.kylin.metadata.model.FunctionDesc;
import org.apache.kylin.metadata.model.MeasureDesc;
import org.apache.kylin.metadata.model.ParameterDesc;
import org.apache.kylin.metadata.model.TblColRef;
import org.apache.kylin.query.relnode.ColumnRowType;
import org.apache.kylin.query.relnode.KylinAggregateCall;
import org.apache.kylin.query.relnode.OLAPContext;
import org.apache.kylin.query.relnode.OLAPProjectRel;
import org.apache.kylin.query.relnode.OLAPRel;
import org.apache.kylin.query.schema.OLAPTable;

/**
 * Kylin OLAP implementation of a Calcite {@link Aggregate} (GROUP BY plus aggregate calls).
 *
 * <p>Lifecycle, as driven by the {@link OLAPRel} visitor methods of this class:
 * <ol>
 *   <li>{@link #implementOLAP} builds the output {@link ColumnRowType}. For the first aggregation seen
 *       by its {@link OLAPContext}, it also records the group-by columns and aggregate functions in that
 *       context.</li>
 *   <li>{@link #implementRewrite} may translate the aggregate functions to realization measures. It may
 *       also rewrite the aggregate calls so that they read rewrite fields of the input.</li>
 *   <li>{@link #implementEnumerable} emits an {@link EnumerableAggregate} over the (possibly rewritten)
 *       aggregate calls.</li>
 * </ol>
 */
public class OLAPAggregateRel extends Aggregate implements OLAPRel {

    /**
     * Maps a SQL aggregate name to a Kylin aggregate function name. DISTINCT calls carry a
     * {@code _DISTINCT} suffix, see {@link #getSqlFuncName}. The map is filled by the static initializer.
     */
    static final Map<String, String> AGGR_FUNC_MAP = new HashMap<String, String>();

    static {
        AGGR_FUNC_MAP.put("SUM", "SUM");
        AGGR_FUNC_MAP.put("$SUM0", "SUM");
        AGGR_FUNC_MAP.put("COUNT", "COUNT");
        AGGR_FUNC_MAP.put("COUNT_DISTINCT", "COUNT_DISTINCT");
        AGGR_FUNC_MAP.put("MAX", "MAX");
        AGGR_FUNC_MAP.put("MIN", "MIN");
        AGGR_FUNC_MAP.put("GROUPING", "GROUPING");
        // Each UDAF registered by a measure type factory maps to that factory's aggregate function name.
        Map<String, ?> udafFactories = MeasureTypeFactory.getUDAFFactories();
        for (Map.Entry<String, ?> entry : udafFactories.entrySet()) {
            MeasureTypeFactory<?> factory = (MeasureTypeFactory<?>) entry.getValue();
            AGGR_FUNC_MAP.put(entry.getKey(), factory.getAggrFunctionName());
        }
        // Put last, so these override any UDAF entry registered under the same name.
        AGGR_FUNC_MAP.put("BITMAP_UUID", "BITMAP_UUID");
        AGGR_FUNC_MAP.put("BITMAP_BUILD", "BITMAP_BUILD");
    }

    /** Context this aggregation belongs to; assigned in {@link #implementOLAP}. */
    protected OLAPContext context;
    /** Output columns of this node; rebuilt in implementOLAP and implementRewrite. */
    protected ColumnRowType columnRowType;
    /** True when the context had already seen an aggregation at the time this node was visited. */
    protected boolean afterAggregate;
    /** Aggregate calls actually executed; equal to {@code aggCalls} until replaced in implementRewrite. */
    protected List<AggregateCall> rewriteAggCalls;
    /** Source columns behind the group-by keys; built by {@link #buildGroups()}. */
    protected List<TblColRef> groups;
    /** One FunctionDesc per entry of {@link #rewriteAggCalls}; built by {@link #buildAggregations()}. */
    protected List<FunctionDesc> aggregations;

    /**
     * Creates the aggregate node.
     *
     * @throws InvalidRelException propagated from the Calcite {@link Aggregate} constructor
     * @throws IllegalArgumentException if {@code traits} does not carry the OLAP convention
     */
    public OLAPAggregateRel(RelOptCluster cluster, RelTraitSet traits, RelNode child, boolean indicator,
            ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls)
            throws InvalidRelException {
        super(cluster, traits, child, indicator, groupSet, groupSets, aggCalls);
        Preconditions.checkArgument(getConvention() == OLAPRel.CONVENTION,
                "OLAPAggregateRel requires the OLAP convention");
        this.afterAggregate = false;
        this.rewriteAggCalls = aggCalls;
        // getRowType() derives and caches the row type on first access; force that now.
        this.rowType = getRowType();
    }

    /** Returns the Calcite aggregation name, suffixed with {@code _DISTINCT} for DISTINCT calls. */
    static String getSqlFuncName(AggregateCall aggCall) {
        String sqlName = aggCall.getAggregation().getName();
        if (aggCall.isDistinct()) {
            sqlName = sqlName + "_DISTINCT";
        }
        return sqlName;
    }

    /**
     * Maps an aggregate call to its Kylin function name via {@link #AGGR_FUNC_MAP}.
     * SINGLE_VALUE is passed through unchanged.
     *
     * @throws IllegalStateException if the aggregation is not supported
     */
    public static String getAggrFuncName(AggregateCall aggCall) {
        if (SqlKind.SINGLE_VALUE == aggCall.getAggregation().kind) {
            return SqlKind.SINGLE_VALUE.sql;
        }
        String sqlName = getSqlFuncName(aggCall);
        String funcName = AGGR_FUNC_MAP.get(sqlName);
        if (funcName == null) {
            throw new IllegalStateException("Non-support aggregation " + sqlName);
        }
        return funcName;
    }

    @Override
    public Aggregate copy(RelTraitSet traitSet, RelNode input, boolean indicator, ImmutableBitSet groupSet,
            List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
        try {
            return new OLAPAggregateRel(getCluster(), traitSet, input, indicator, groupSet, groupSets, aggCalls);
        } catch (InvalidRelException e) {
            throw new IllegalStateException("Can't create OLAPAggregateRel!", e);
        }
    }

    /**
     * Costs the aggregation at 5% of Calcite's default estimate. For non-SIMPLE grouping, the input's
     * cost is added and the total is scaled by 1.5 per grouping set.
     */
    @Override
    public RelOptCost computeSelfCost(RelOptPlanner planner, RelMetadataQuery mq) {
        RelOptCost baseCost = super.computeSelfCost(planner, mq).multiplyBy(0.05);
        if (getGroupType() == Aggregate.Group.SIMPLE) {
            return baseCost;
        }
        return baseCost.plus(planner.getCost(getInput(), mq)).multiplyBy(groupSets.size() * 1.5);
    }

    /**
     * Visits the input, builds this node's column row type and, for the first aggregation of the
     * context, registers group-by columns and aggregate functions with the context.
     *
     * @throws IllegalStateException if a DISTINCT aggregate appears after another aggregation
     */
    @Override
    public void implementOLAP(OLAPRel.OLAPImplementor implementor) {
        implementor.fixSharedOlapTableScan(this);
        implementor.visitChild(getInput(), this);

        this.context = implementor.getContext();
        this.columnRowType = buildColumnRowType();
        this.afterAggregate = context.afterAggregate;

        if (!afterAggregate) {
            addToContextGroupBy(groups);
            context.aggregations.addAll(aggregations);
            context.afterAggregate = true;
            if (context.afterLimit) {
                context.limitPrecedesAggr = true;
            }
        } else {
            for (AggregateCall aggCall : aggCalls) {
                if (aggCall.isDistinct()) {
                    throw new IllegalStateException("Distinct count is only allowed in innermost sub-query.");
                }
            }
        }
    }

    /**
     * Rebuilds {@link #groups} and {@link #aggregations}, then returns the output columns.
     * The output is the group columns, then indicator columns (if {@code indicator}), then one inner
     * column per aggregation.
     *
     * @throws IllegalStateException if the column count differs from the row type's field count
     */
    public ColumnRowType buildColumnRowType() {
        buildGroups();
        buildAggregations();

        ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
        List<TblColRef> columns = Lists.newArrayListWithCapacity(rowType.getFieldCount());
        columns.addAll(getGroupColsOfColumnRowType());

        if (indicator) {
            // NOTE(review): indicators are derived from `groups` (source columns), while the group columns
            // above come from the group set; the counts can differ for multi-source keys -- confirm intent.
            Set<String> usedNames = Sets.newHashSet();
            for (TblColRef groupCol : groups) {
                String base = "i$" + groupCol.getName();
                String name = base;
                int suffix = 0;
                // Make the indicator name unique by appending _0, _1, ... on collision.
                while (usedNames.contains(name)) {
                    name = base + "_" + suffix++;
                }
                usedNames.add(name);
                columns.add(TblColRef.newInnerColumn(name, TblColRef.InnerDataTypeEnum.LITERAL));
            }
        }

        for (int i = 0; i < aggregations.size(); i++) {
            FunctionDesc aggFunc = aggregations.get(i);
            AggregateCall aggCall = rewriteAggCalls.get(i);
            List<TblColRef> operands = Lists.newArrayList();
            String aggOutName;
            if (aggFunc != null) {
                operands.addAll(aggFunc.getColRefs());
                aggOutName = aggFunc.getRewriteFieldName();
            } else {
                // buildAggregations() adds an entry for every call, so this branch is presumably defensive.
                int firstArg = aggCall.getArgList().get(0);
                aggOutName = getSqlFuncName(aggCall) + "_"
                        + inputColumnRowType.getColumnByIndex(firstArg).getIdentity().replace('.', '_') + "_";
                for (int argIndex : aggCall.getArgList()) {
                    operands.add(inputColumnRowType.getColumnByIndex(argIndex));
                }
            }
            TblColRef aggOutCol = TblColRef.newInnerColumn(aggOutName,
                    TblColRef.InnerDataTypeEnum.AGGREGATION_TYPE);
            aggOutCol.setOperator(aggCall.getAggregation());
            aggOutCol.setOperands(operands);
            aggOutCol.getColumnDesc().setId(String.valueOf(i + 1));
            columns.add(aggOutCol);
        }

        Preconditions.checkState(columns.size() == rowType.getFieldCount(),
                "Column count %s does not match row type field count %s", columns.size(),
                rowType.getFieldCount());
        return new ColumnRowType(columns);
    }

    /**
     * Creates the context's rewrite column for a function that needs a rewrite field.
     *
     * @throws IllegalStateException if {@code aggFunc} does not need a rewrite field
     */
    TblColRef buildRewriteColumn(FunctionDesc aggFunc) {
        if (!aggFunc.needRewriteField()) {
            throw new IllegalStateException("buildRewriteColumn on a aggrFunc that does not need rewrite " + aggFunc);
        }
        String colName = aggFunc.getRewriteFieldName();
        return context.firstTableScan.makeRewriteColumn(colName);
    }

    /** Collects the source columns of every group-set position into {@link #groups}. */
    protected void buildGroups() {
        ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
        this.groups = new ArrayList<TblColRef>();
        for (int i = getGroupSet().nextSetBit(0); i >= 0; i = getGroupSet().nextSetBit(i + 1)) {
            groups.addAll(inputColumnRowType.getSourceColumnsByIndex(i));
        }
    }

    /** Returns the input columns at the group-set positions, in ascending position order. */
    public List<TblColRef> getGroupColsOfColumnRowType() {
        ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
        List<TblColRef> groupCols = Lists.newArrayList();
        for (int i = getGroupSet().nextSetBit(0); i >= 0; i = getGroupSet().nextSetBit(i + 1)) {
            groupCols.add(inputColumnRowType.getColumnByIndex(i));
        }
        return groupCols;
    }

    /**
     * Builds one {@link FunctionDesc} per entry of {@link #rewriteAggCalls}.
     * The argument columns become parameters, with two exceptions: PERCENTILE and PERCENTILE_APPROX
     * keep only their first argument, and SUM arguments go through
     * {@link #rewriteCastInSumIfNecessary}.
     */
    void buildAggregations() {
        ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
        this.aggregations = new ArrayList<FunctionDesc>();
        for (AggregateCall aggCall : rewriteAggCalls) {
            String sqlFuncName = getSqlFuncName(aggCall);
            List<Integer> argList = aggCall.getArgList();
            List<ParameterDesc> parameters = Lists.newArrayList();
            if (!argList.isEmpty()) {
                boolean firstArgOnly = "PERCENTILE".equals(sqlFuncName) || "PERCENTILE_APPROX".equals(sqlFuncName);
                List<Integer> args = firstArgOnly ? argList.subList(0, 1) : argList;
                for (Integer index : args) {
                    TblColRef column = "SUM".equals(sqlFuncName)
                            ? rewriteCastInSumIfNecessary(aggCall, inputColumnRowType, index)
                            : inputColumnRowType.getColumnByIndex(index);
                    parameters.add(ParameterDesc.newInstance(column));
                }
            }
            String expression = getAggrFuncName(aggCall);
            aggregations.add(FunctionDesc.newInstance(expression, parameters, null));
        }
    }

    /**
     * Unwraps a cast column for a SUM call that returns BIGINT.
     *
     * <p>The unwrap happens only when the input is an {@link OLAPProjectRel} and the argument is a
     * cast inner column whose first operand is a non-inner, integer-family column. In that case the
     * argument is replaced by that operand. Note that this MUTATES the input's column list at
     * {@code index}.
     *
     * @return the (possibly replaced) column at {@code index}
     */
    private TblColRef rewriteCastInSumIfNecessary(AggregateCall aggCall, ColumnRowType inputColumnRowType,
            Integer index) {
        TblColRef column = inputColumnRowType.getColumnByIndex(index);
        if (getInput() instanceof OLAPProjectRel && SqlTypeUtil.isBigint(aggCall.type)
                && column.isCastInnerColumn()) {
            TblColRef innerColumn = (TblColRef) column.getOperands().get(0);
            if (!innerColumn.isInnerColumn() && innerColumn.getType().isIntegerFamily()) {
                inputColumnRowType.getAllColumns().set(index, innerColumn);
                column = inputColumnRowType.getColumnByIndex(index);
            }
        }
        return column;
    }

    /**
     * Returns true when the context has a realization, this is not an after-aggregation node, and the
     * query is not answered by a table index.
     */
    public boolean needRewrite() {
        return context.realization != null && !afterAggregate && !context.isAnsweredByTableIndex();
    }

    /**
     * Rewrite phase. When {@link #needRewrite()} holds, the work happens in three steps:
     * <ul>
     *   <li>aggregations are translated to measures;</li>
     *   <li>rewrite fields and metrics columns are registered before the input is visited;</li>
     *   <li>after the visit, each non-GROUPING aggregate call is rewritten.</li>
     * </ul>
     * The row type and column row type are always rebuilt at the end.
     */
    @Override
    public void implementRewrite(OLAPRel.RewriteImplementor implementor) {
        if (needRewrite()) {
            translateAggregation();
            buildRewriteFieldsAndMetricsColumns();
        }

        implementor.visitChild(this, getInput());

        if (needRewrite()) {
            this.rewriteAggCalls = new ArrayList<AggregateCall>(aggCalls.size());
            for (int i = 0; i < aggCalls.size(); i++) {
                AggregateCall aggCall = aggCalls.get(i);
                if (SqlStdOperatorTable.GROUPING == aggCall.getAggregation()) {
                    rewriteAggCalls.add(aggCall);
                    continue;
                }
                // NOTE(review): assumes context.aggregations is index-aligned with aggCalls. However,
                // translateAggregation() omits aggregate-on-constant functions from context.aggregations,
                // which would shift later indices. Confirm whether aggregations.get(i) is intended here.
                FunctionDesc cubeFunc = context.aggregations.get(i);
                rewriteAggCalls.add(rewriteAggCall(aggCall, cubeFunc));
            }
        }

        this.rowType = deriveRowType();
        this.columnRowType = buildColumnRowType();
    }

    /**
     * Rewrites a single aggregate call against its matched function.
     *
     * <p>The call is returned unchanged when no pre-calculated fields are available and
     * {@code cubeFunc} needs a rewrite field. Otherwise two steps apply:
     * <ul>
     *   <li>the call is rewritten via {@link #rewriteAggregateCall} if {@code cubeFunc.needRewrite()};</li>
     *   <li>the result is wrapped in a {@link KylinAggregateCall} if the measure type needs a rewrite
     *       field.</li>
     * </ul>
     */
    protected AggregateCall rewriteAggCall(AggregateCall aggCall, FunctionDesc cubeFunc) {
        if (noPrecaculatedFieldsAvailable() && cubeFunc.needRewriteField()) {
            logger.info(aggCall + " skip rewriteAggregateCall because no pre-aggregated field available");
            return aggCall;
        }
        AggregateCall result = aggCall;
        if (cubeFunc.needRewrite()) {
            result = rewriteAggregateCall(result, cubeFunc);
        }
        if (cubeFunc.getMeasureType() != null && cubeFunc.getMeasureType().needRewriteField()) {
            result = new KylinAggregateCall(result, cubeFunc);
        }
        return result;
    }

    /**
     * Replaces each non-dimension-as-metric aggregation with the matching realization measure.
     *
     * <p>If no measure matches, a count-on-column becomes COUNT(1) when the config enables
     * {@code isReplaceColCountWithCountStar}; otherwise the original function is kept. Afterwards
     * {@code context.aggregations} is reset to the new list, minus aggregate-on-constant functions.
     * The method does nothing when no pre-calculated fields are available.
     */
    protected void translateAggregation() {
        if (noPrecaculatedFieldsAvailable()) {
            return;
        }
        List<MeasureDesc> measures = context.realization.getMeasures();
        List<FunctionDesc> newAggrs = Lists.newArrayListWithCapacity(aggregations.size());
        for (FunctionDesc aggFunc : aggregations) {
            if (aggFunc.isDimensionAsMetric()) {
                newAggrs.add(aggFunc);
                continue;
            }
            FunctionDesc newAgg = findInMeasures(aggFunc, measures);
            if (newAgg == null && aggFunc.isCountOnColumn()
                    && context.realization.getConfig().isReplaceColCountWithCountStar()) {
                newAgg = FunctionDesc.newCountOne();
            }
            newAggrs.add(newAgg != null ? newAgg : aggFunc);
        }

        aggregations.clear();
        aggregations.addAll(newAggrs);

        context.aggregations.clear();
        for (FunctionDesc agg : aggregations) {
            if (!agg.isAggregateOnConstant()) {
                context.aggregations.add(agg);
            }
        }
    }

    /**
     * Finds the measure function that can answer {@code aggFunc}.
     *
     * <p>Lookup order:
     * <ol>
     *   <li>a measure whose function equals {@code aggFunc};</li>
     *   <li>for non-basic measures, the TopN internal measure equal to {@code aggFunc};</li>
     *   <li>for INTERSECT_COUNT / BITMAP_UUID / BITMAP_BUILD, a "bitmap" measure with the same first
     *       parameter.</li>
     * </ol>
     *
     * @return the matching function, or null when none matches
     */
    private FunctionDesc findInMeasures(FunctionDesc aggFunc, List<MeasureDesc> measures) {
        for (MeasureDesc m : measures) {
            if (aggFunc.equals(m.getFunction())) {
                return m.getFunction();
            }
        }
        for (MeasureDesc m : measures) {
            FunctionDesc measureFunc = m.getFunction();
            if (measureFunc.getMeasureType() instanceof BasicMeasureType) {
                continue;
            }
            if (measureFunc.getMeasureType() instanceof TopNMeasureType) {
                FunctionDesc internalTopn = TopNMeasureType.getTopnInternalMeasure(measureFunc);
                if (aggFunc.equals(internalTopn)) {
                    return internalTopn;
                }
            }
            if (isBitmapAggregation(aggFunc) && "bitmap".equals(measureFunc.getReturnType())
                    && !aggFunc.getParameters().isEmpty()
                    && aggFunc.getParameters().get(0).equals(measureFunc.getParameters().get(0))) {
                return measureFunc;
            }
        }
        return null;
    }

    /** True for the functions that can be answered by a "bitmap" measure. */
    private static boolean isBitmapAggregation(FunctionDesc aggFunc) {
        String expression = aggFunc.getExpression();
        return "INTERSECT_COUNT".equalsIgnoreCase(expression) || "BITMAP_UUID".equalsIgnoreCase(expression)
                || "BITMAP_BUILD".equalsIgnoreCase(expression);
    }

    /**
     * Registers state with the context for every non-dimension-as-metric aggregation.
     * For functions that need a rewrite field, it registers the field's type and the rewrite column.
     * It also adds each non-inner argument column that belongs to the context tables as a metrics
     * column.
     */
    protected void buildRewriteFieldsAndMetricsColumns() {
        ColumnRowType inputColumnRowType = ((OLAPRel) getInput()).getColumnRowType();
        RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
        for (int i = 0; i < aggregations.size(); i++) {
            FunctionDesc aggFunc = aggregations.get(i);
            if (aggFunc.isDimensionAsMetric()) {
                continue;
            }
            if (aggFunc.needRewriteField()) {
                String rewriteFieldName = aggFunc.getRewriteFieldName();
                RelDataType rewriteFieldType = OLAPTable.createSqlType(typeFactory, aggFunc.getRewriteFieldType(),
                        true);
                context.rewriteFields.put(rewriteFieldName, rewriteFieldType);
                context.metricsColumns.add(buildRewriteColumn(aggFunc));
            }
            for (Integer index : rewriteAggCalls.get(i).getArgList()) {
                TblColRef column = inputColumnRowType.getColumnByIndex(index);
                if (!column.isInnerColumn() && context.belongToContextTables(column)) {
                    context.metricsColumns.add(column);
                }
            }
        }
    }

    /** Adds every non-inner column belonging to the context tables to the context's group-by columns. */
    protected void addToContextGroupBy(List<TblColRef> colRefs) {
        for (TblColRef col : colRefs) {
            if (!col.isInnerColumn() && context.belongToContextTables(col)) {
                context.getGroupByColumns().add(col);
            }
        }
    }

    /**
     * Returns true when the context has no pre-calculated fields, or when
     * {@code RewriteImplementor.needRewrite} is false for it.
     */
    public boolean noPrecaculatedFieldsAvailable() {
        return !context.hasPrecalculatedFields() || !OLAPRel.RewriteImplementor.needRewrite(context);
    }

    /**
     * Builds the replacement call for {@code aggCall}.
     *
     * <p>The aggregation function is chosen as follows: COUNT becomes SUM0, and a measure-type UDAF
     * registered under the call's name becomes a custom {@link SqlUserDefinedAggFunction}. For a UDAF,
     * arguments are truncated to what its {@code add} method accepts. If the function needs a rewrite
     * field, the first argument is pointed at that input field.
     *
     * @throws IllegalStateException if the required rewrite field is missing from the input row type
     */
    protected AggregateCall rewriteAggregateCall(AggregateCall aggCall, FunctionDesc func) {
        String callName = getSqlFuncName(aggCall);
        RelDataType fieldType = aggCall.getType();
        SqlAggFunction newAgg = aggCall.getAggregation();

        Map<String, Class<?>> udafMap = func.getMeasureType().getRewriteCalciteAggrFunctions();
        boolean hasUdaf = udafMap != null && udafMap.containsKey(callName);

        if (func.isCount()) {
            newAgg = SqlStdOperatorTable.SUM0;
        } else if (hasUdaf) {
            newAgg = createCustomAggFunction(callName, fieldType, udafMap.get(callName));
        }

        List<Integer> newArgList = Lists.newArrayList(aggCall.getArgList());
        if (hasUdaf) {
            newArgList = truncArgList(newArgList, udafMap.get(callName));
        }

        if (func.needRewriteField()) {
            String rewriteFieldName = func.getRewriteFieldName();
            RelDataType inputRowType = getInput().getRowType();
            RelDataTypeField field = inputRowType.getField(rewriteFieldName, true, false);
            if (field == null) {
                throw new IllegalStateException(
                        "Rewrite field " + rewriteFieldName + " not found in input row type " + inputRowType);
            }
            if (newArgList.isEmpty()) {
                newArgList.add(field.getIndex());
            } else {
                newArgList.set(0, field.getIndex());
            }
        }
        return new AggregateCall(newAgg, false, newArgList, fieldType, callName);
    }

    /**
     * Truncates {@code argList} to the smallest argument count accepted by any public {@code add}
     * method of {@code udafClazz}. The count excludes the leading accumulator parameter of a Calcite
     * UDAF {@code add} method. The result is a view over {@code argList}.
     */
    List<Integer> truncArgList(List<Integer> argList, Class<?> udafClazz) {
        int argListLength = argList.size();
        for (Method method : udafClazz.getMethods()) {
            if ("add".equals(method.getName())) {
                argListLength = Math.min(method.getParameterTypes().length - 1, argListLength);
            }
        }
        return argList.subList(0, argListLength);
    }

    /**
     * Wraps a UDAF implementation class as a Calcite {@link SqlAggFunction} with an explicit return
     * type. Operand families come from the implementation's parameter types; a parameter type with no
     * family maps to ANY.
     *
     * @throws IllegalStateException if Calcite cannot derive an aggregate function from the class
     */
    SqlAggFunction createCustomAggFunction(String funcName, RelDataType returnType, Class<?> customAggFuncClz) {
        RelDataTypeFactory typeFactory = getCluster().getTypeFactory();
        SqlIdentifier sqlIdentifier = new SqlIdentifier(funcName, new SqlParserPos(1, 1));
        AggregateFunctionImpl aggFunction = AggregateFunctionImpl.create(customAggFuncClz);
        if (aggFunction == null) {
            throw new IllegalStateException("Class " + customAggFuncClz + " is not a valid aggregate function for "
                    + funcName);
        }
        List<RelDataType> argTypes = new ArrayList<RelDataType>();
        List<SqlTypeFamily> typeFamilies = new ArrayList<SqlTypeFamily>();
        for (FunctionParameter parameter : aggFunction.getParameters()) {
            RelDataType type = parameter.getType(typeFactory);
            argTypes.add(type);
            typeFamilies.add(Util.first(type.getSqlTypeName().getFamily(), SqlTypeFamily.ANY));
        }
        return new SqlUserDefinedAggFunction(sqlIdentifier, ReturnTypes.explicit(returnType),
                InferTypes.explicit(argTypes), OperandTypes.family(typeFamilies), aggFunction, false, false,
                typeFactory);
    }

    /** Converts to an {@link EnumerableAggregate} that executes {@link #rewriteAggCalls}. */
    @Override
    public EnumerableRel implementEnumerable(List<EnumerableRel> inputs) {
        try {
            return new EnumerableAggregate(getCluster(), getCluster().traitSetOf(EnumerableConvention.INSTANCE),
                    sole(inputs), indicator, groupSet, groupSets, rewriteAggCalls);
        } catch (InvalidRelException e) {
            throw new IllegalStateException("Can't create EnumerableAggregate!", e);
        }
    }

    @Override
    public OLAPContext getContext() {
        return context;
    }

    @Override
    public ColumnRowType getColumnRowType() {
        return columnRowType;
    }

    /** Delegates to the (OLAP) input. */
    @Override
    public boolean hasSubQuery() {
        return ((OLAPRel) getInput()).hasSubQuery();
    }

    /** Replaces {@code trait} in this node's trait set and returns the previous trait set. */
    @Override
    public RelTraitSet replaceTraitSet(RelTrait trait) {
        RelTraitSet oldTraitSet = this.traitSet;
        this.traitSet = this.traitSet.replace(trait);
        return oldTraitSet;
    }

    public List<AggregateCall> getRewriteAggCalls() {
        return rewriteAggCalls;
    }

    /** Adds a {@code ctx} item of the form {@code <contextId>@<realization>}, or "" without a context. */
    @Override
    public RelWriter explainTerms(RelWriter pw) {
        RelWriter writer = super.explainTerms(pw);
        String ctx = context == null ? "" : context.id + "@" + context.realization;
        return writer.item("ctx", ctx);
    }

    public List<TblColRef> getGroups() {
        return groups;
    }

    public void setGroups(List<TblColRef> groups) {
        this.groups = groups;
    }
}

