/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.query.aggregation.function;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.floats.FloatArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.lang.invoke.LambdaMetafactory;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.CustomObject;
import org.apache.pinot.common.datatable.DataTable;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FilterContext;
import org.apache.pinot.common.request.context.predicate.Predicate;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.config.QueryOptionsUtils;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.ObjectSerDeUtils;
import org.apache.pinot.core.common.Operator;
import org.apache.pinot.core.operator.BaseProjectOperator;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.filter.BaseFilterOperator;
import org.apache.pinot.core.operator.filter.CombinedFilterOperator;
import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
import org.apache.pinot.core.plan.FilterPlanNode;
import org.apache.pinot.core.plan.ProjectPlanNode;
import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.startree.StarTreeUtils;
import org.apache.pinot.segment.spi.AggregationFunctionType;
import org.apache.pinot.segment.spi.SegmentContext;
import org.apache.pinot.segment.spi.index.startree.AggregationFunctionColumnPair;

public class AggregationFunctionUtils {
    private AggregationFunctionUtils() {
    }

    @Nullable
    public static AggregationFunctionColumnPair getStoredFunctionColumnPair(AggregationFunction aggregationFunction) {
        ExpressionContext inputExpression;
        AggregationFunctionType functionType = aggregationFunction.getType();
        if (functionType == AggregationFunctionType.COUNT) {
            return AggregationFunctionColumnPair.COUNT_STAR;
        }
        List<ExpressionContext> inputExpressions = aggregationFunction.getInputExpressions();
        if (inputExpressions.size() == 1 && (inputExpression = inputExpressions.get(0)).getType() == ExpressionContext.Type.IDENTIFIER) {
            return new AggregationFunctionColumnPair(AggregationFunctionColumnPair.getStoredType((AggregationFunctionType)functionType), inputExpression.getIdentifier());
        }
        return null;
    }

    public static Set<ExpressionContext> collectExpressionsToTransform(AggregationFunction[] aggregationFunctions, @Nullable List<ExpressionContext> groupByExpressions) {
        HashSet<ExpressionContext> expressions = new HashSet<ExpressionContext>();
        for (AggregationFunction aggregationFunction : aggregationFunctions) {
            expressions.addAll(aggregationFunction.getInputExpressions());
        }
        if (groupByExpressions != null) {
            expressions.addAll(groupByExpressions);
        }
        return expressions;
    }

    public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunction aggregationFunction, ValueBlock valueBlock) {
        List<ExpressionContext> expressions = aggregationFunction.getInputExpressions();
        int numExpressions = expressions.size();
        if (numExpressions == 0) {
            return Collections.emptyMap();
        }
        if (numExpressions == 1) {
            ExpressionContext expression = expressions.get(0);
            return Collections.singletonMap(expression, valueBlock.getBlockValueSet(expression));
        }
        HashMap<ExpressionContext, BlockValSet> blockValSetMap = new HashMap<ExpressionContext, BlockValSet>();
        for (ExpressionContext expression : expressions) {
            blockValSetMap.put(expression, valueBlock.getBlockValueSet(expression));
        }
        return blockValSetMap;
    }

    public static Map<ExpressionContext, BlockValSet> getBlockValSetMap(AggregationFunctionColumnPair aggregationFunctionColumnPair, ValueBlock valueBlock) {
        ExpressionContext expression = ExpressionContext.forIdentifier((String)aggregationFunctionColumnPair.getColumn());
        BlockValSet blockValSet = valueBlock.getBlockValueSet(aggregationFunctionColumnPair.toColumnName());
        return Collections.singletonMap(expression, blockValSet);
    }

    public static Object getIntermediateResult(DataTable dataTable, DataSchema.ColumnDataType columnDataType, int rowId, int colId) {
        switch (columnDataType.getStoredType()) {
            case INT: {
                return dataTable.getInt(rowId, colId);
            }
            case LONG: {
                return dataTable.getLong(rowId, colId);
            }
            case DOUBLE: {
                return dataTable.getDouble(rowId, colId);
            }
            case OBJECT: {
                CustomObject customObject = dataTable.getCustomObject(rowId, colId);
                return customObject != null ? ObjectSerDeUtils.deserialize(customObject) : null;
            }
        }
        throw new IllegalStateException("Illegal column data type in intermediate result: " + columnDataType);
    }

    public static Comparable getFinalResult(DataTable dataTable, DataSchema.ColumnDataType columnDataType, int rowId, int colId) {
        switch (columnDataType.getStoredType()) {
            case INT: {
                return Integer.valueOf(dataTable.getInt(rowId, colId));
            }
            case LONG: {
                return Long.valueOf(dataTable.getLong(rowId, colId));
            }
            case FLOAT: {
                return Float.valueOf(dataTable.getFloat(rowId, colId));
            }
            case DOUBLE: {
                return Double.valueOf(dataTable.getDouble(rowId, colId));
            }
            case BIG_DECIMAL: {
                return dataTable.getBigDecimal(rowId, colId);
            }
            case STRING: {
                return dataTable.getString(rowId, colId);
            }
            case BYTES: {
                return dataTable.getBytes(rowId, colId);
            }
            case INT_ARRAY: {
                return IntArrayList.wrap((int[])dataTable.getIntArray(rowId, colId));
            }
            case LONG_ARRAY: {
                return LongArrayList.wrap((long[])dataTable.getLongArray(rowId, colId));
            }
            case FLOAT_ARRAY: {
                return FloatArrayList.wrap((float[])dataTable.getFloatArray(rowId, colId));
            }
            case DOUBLE_ARRAY: {
                return DoubleArrayList.wrap((double[])dataTable.getDoubleArray(rowId, colId));
            }
            case STRING_ARRAY: {
                return ObjectArrayList.wrap((Object[])dataTable.getStringArray(rowId, colId));
            }
        }
        throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
    }

    public static Object getConvertedFinalResult(DataTable dataTable, DataSchema.ColumnDataType columnDataType, int rowId, int colId) {
        switch (columnDataType) {
            case INT: {
                return dataTable.getInt(rowId, colId);
            }
            case LONG: {
                return dataTable.getLong(rowId, colId);
            }
            case FLOAT: {
                return Float.valueOf(dataTable.getFloat(rowId, colId));
            }
            case DOUBLE: {
                return dataTable.getDouble(rowId, colId);
            }
            case BIG_DECIMAL: {
                return dataTable.getBigDecimal(rowId, colId);
            }
            case BOOLEAN: {
                return dataTable.getInt(rowId, colId) == 1;
            }
            case TIMESTAMP: {
                return new Timestamp(dataTable.getLong(rowId, colId));
            }
            case STRING: 
            case JSON: {
                return dataTable.getString(rowId, colId);
            }
            case BYTES: {
                return dataTable.getBytes(rowId, colId).getBytes();
            }
            case INT_ARRAY: {
                return dataTable.getIntArray(rowId, colId);
            }
            case LONG_ARRAY: {
                return dataTable.getLongArray(rowId, colId);
            }
            case FLOAT_ARRAY: {
                return dataTable.getFloatArray(rowId, colId);
            }
            case DOUBLE_ARRAY: {
                return dataTable.getDoubleArray(rowId, colId);
            }
            case BOOLEAN_ARRAY: {
                int[] intValues = dataTable.getIntArray(rowId, colId);
                int numValues = intValues.length;
                boolean[] booleanValues = new boolean[numValues];
                for (int i = 0; i < numValues; ++i) {
                    booleanValues[i] = intValues[i] == 1;
                }
                return booleanValues;
            }
            case TIMESTAMP_ARRAY: {
                long[] longValues = dataTable.getLongArray(rowId, colId);
                int numValues = longValues.length;
                Timestamp[] timestampValues = new Timestamp[numValues];
                for (int i = 0; i < numValues; ++i) {
                    timestampValues[i] = new Timestamp(longValues[i]);
                }
                return timestampValues;
            }
            case STRING_ARRAY: {
                return dataTable.getStringArray(rowId, colId);
            }
        }
        throw new IllegalStateException("Illegal column data type in final result: " + columnDataType);
    }

    public static AggregationInfo buildAggregationInfo(SegmentContext segmentContext, QueryContext queryContext, AggregationFunction[] aggregationFunctions, @Nullable FilterContext filter, BaseFilterOperator filterOperator, List<Pair<Predicate, PredicateEvaluator>> predicateEvaluators) {
        Operator<Object> projectOperator = null;
        if (!filterOperator.isResultEmpty()) {
            projectOperator = StarTreeUtils.createStarTreeBasedProjectOperator(segmentContext.getIndexSegment(), queryContext, aggregationFunctions, filter, predicateEvaluators);
        }
        if (projectOperator != null) {
            return new AggregationInfo(aggregationFunctions, (BaseProjectOperator<?>)projectOperator, true);
        }
        Set<ExpressionContext> expressionsToTransform = AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, queryContext.getGroupByExpressions());
        projectOperator = new ProjectPlanNode(segmentContext, queryContext, expressionsToTransform, 10000, filterOperator).run();
        return new AggregationInfo(aggregationFunctions, (BaseProjectOperator<?>)projectOperator, false);
    }

    public static List<AggregationInfo> buildFilteredAggregationInfos(SegmentContext segmentContext, QueryContext queryContext) {
        assert (queryContext.getAggregationFunctions() != null && queryContext.getFilteredAggregationFunctions() != null);
        FilterPlanNode mainFilterPlan = new FilterPlanNode(segmentContext, queryContext);
        BaseFilterOperator mainFilterOperator = mainFilterPlan.run();
        List<Pair<Predicate, PredicateEvaluator>> mainPredicateEvaluators = mainFilterPlan.getPredicateEvaluators();
        if (mainFilterOperator.isResultEmpty()) {
            AggregationFunction[] aggregationFunctions = queryContext.getAggregationFunctions();
            Set<ExpressionContext> expressions = AggregationFunctionUtils.collectExpressionsToTransform(aggregationFunctions, queryContext.getGroupByExpressions());
            Operator projectOperator = new ProjectPlanNode(segmentContext, queryContext, expressions, 10000, mainFilterOperator).run();
            return Collections.singletonList(new AggregationInfo(aggregationFunctions, (BaseProjectOperator<?>)projectOperator, false));
        }
        HashMap<FilterContext, FilteredAggregationContext> filteredAggregationContexts = new HashMap<FilterContext, FilteredAggregationContext>();
        ArrayList<AggregationFunction> nonFilteredFunctions = new ArrayList<AggregationFunction>();
        FilterContext mainFilter = queryContext.getFilter();
        for (Pair<AggregationFunction, FilterContext> pair : queryContext.getFilteredAggregationFunctions()) {
            AggregationFunction aggregationFunction = (AggregationFunction)pair.getLeft();
            FilterContext filter = (FilterContext)pair.getRight();
            if (filter != null) {
                filteredAggregationContexts.computeIfAbsent(filter, (Function<FilterContext, FilteredAggregationContext>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, lambda$buildFilteredAggregationInfos$0(org.apache.pinot.common.request.context.FilterContext org.apache.pinot.common.request.context.FilterContext org.apache.pinot.segment.spi.SegmentContext org.apache.pinot.core.query.request.context.QueryContext org.apache.pinot.core.operator.filter.BaseFilterOperator java.util.List org.apache.pinot.common.request.context.FilterContext ), (Lorg/apache/pinot/common/request/context/FilterContext;)Lorg/apache/pinot/core/query/aggregation/function/AggregationFunctionUtils$FilteredAggregationContext;)((FilterContext)mainFilter, (FilterContext)filter, (SegmentContext)segmentContext, (QueryContext)queryContext, (BaseFilterOperator)mainFilterOperator, mainPredicateEvaluators))._aggregationFunctions.add(aggregationFunction);
                continue;
            }
            nonFilteredFunctions.add(aggregationFunction);
        }
        ArrayList<AggregationInfo> aggregationInfos = new ArrayList<AggregationInfo>();
        for (FilteredAggregationContext filteredAggregationContext : filteredAggregationContexts.values()) {
            BaseFilterOperator filterOperator = filteredAggregationContext._filterOperator;
            if (filterOperator == mainFilterOperator) {
                nonFilteredFunctions.addAll(filteredAggregationContext._aggregationFunctions);
                continue;
            }
            AggregationFunction[] aggregationFunctions = filteredAggregationContext._aggregationFunctions.toArray(new AggregationFunction[0]);
            aggregationInfos.add(AggregationFunctionUtils.buildAggregationInfo(segmentContext, queryContext, aggregationFunctions, filteredAggregationContext._filter, filteredAggregationContext._filterOperator, filteredAggregationContext._predicateEvaluators));
        }
        if (!nonFilteredFunctions.isEmpty() || queryContext.getGroupByExpressions() != null && !QueryOptionsUtils.isFilteredAggregationsSkipEmptyGroups(queryContext.getQueryOptions())) {
            AggregationFunction[] aggregationFunctionArray = nonFilteredFunctions.toArray(new AggregationFunction[0]);
            aggregationInfos.add(AggregationFunctionUtils.buildAggregationInfo(segmentContext, queryContext, aggregationFunctionArray, mainFilter, mainFilterOperator, mainPredicateEvaluators));
        }
        return aggregationInfos;
    }

    public static String getResultColumnName(AggregationFunction aggregationFunction, @Nullable FilterContext filter) {
        Object columnName = aggregationFunction.getResultColumnName();
        if (filter != null) {
            columnName = (String)columnName + " FILTER(WHERE " + filter + ")";
        }
        return columnName;
    }

    private static /* synthetic */ FilteredAggregationContext lambda$buildFilteredAggregationInfos$0(FilterContext mainFilter, FilterContext filter, SegmentContext segmentContext, QueryContext queryContext, BaseFilterOperator mainFilterOperator, List mainPredicateEvaluators, FilterContext k) {
        FilterContext combinedFilter = mainFilter == null ? filter : FilterContext.forAnd(List.of(mainFilter, filter));
        FilterPlanNode subFilterPlan = new FilterPlanNode(segmentContext, queryContext, filter);
        BaseFilterOperator subFilterOperator = subFilterPlan.run();
        BaseFilterOperator combinedFilterOperator = mainFilterOperator.isResultMatchingAll() || subFilterOperator.isResultEmpty() ? subFilterOperator : (subFilterOperator.isResultMatchingAll() ? mainFilterOperator : new CombinedFilterOperator(mainFilterOperator, subFilterOperator, queryContext.getQueryOptions()));
        List<Pair<Predicate, PredicateEvaluator>> subPredicateEvaluators = subFilterPlan.getPredicateEvaluators();
        ArrayList<Pair<Predicate, PredicateEvaluator>> combinedPredicateEvaluators = new ArrayList<Pair<Predicate, PredicateEvaluator>>(mainPredicateEvaluators.size() + subPredicateEvaluators.size());
        combinedPredicateEvaluators.addAll(mainPredicateEvaluators);
        combinedPredicateEvaluators.addAll(subPredicateEvaluators);
        return new FilteredAggregationContext(combinedFilter, combinedFilterOperator, combinedPredicateEvaluators);
    }

    private static class FilteredAggregationContext {
        final FilterContext _filter;
        final BaseFilterOperator _filterOperator;
        final List<Pair<Predicate, PredicateEvaluator>> _predicateEvaluators;
        final List<AggregationFunction> _aggregationFunctions = new ArrayList<AggregationFunction>();

        public FilteredAggregationContext(FilterContext filter, BaseFilterOperator filterOperator, List<Pair<Predicate, PredicateEvaluator>> predicateEvaluators) {
            this._filter = filter;
            this._filterOperator = filterOperator;
            this._predicateEvaluators = predicateEvaluators;
        }
    }

    public static class AggregationInfo {
        private final AggregationFunction[] _functions;
        private final BaseProjectOperator<?> _projectOperator;
        private final boolean _useStarTree;

        public AggregationInfo(AggregationFunction[] functions, BaseProjectOperator<?> projectOperator, boolean useStarTree) {
            this._functions = functions;
            this._projectOperator = projectOperator;
            this._useStarTree = useStarTree;
        }

        public AggregationFunction[] getFunctions() {
            return this._functions;
        }

        public BaseProjectOperator<?> getProjectOperator() {
            return this._projectOperator;
        }

        public boolean isUseStarTree() {
            return this._useStarTree;
        }
    }
}

