/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableAggregateFunctionInvocation;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.Aggregate;
import io.substrait.relation.ImmutableProject;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

public class PreCalciteAggregateValidator {
    public static boolean isValidCalciteAggregate(Aggregate aggregate) {
        return aggregate.getMeasures().stream().allMatch(PreCalciteAggregateValidator::isValidCalciteMeasure) && aggregate.getGroupings().stream().allMatch(PreCalciteAggregateValidator::isValidCalciteGrouping);
    }

    private static boolean isValidCalciteMeasure(Aggregate.Measure measure) {
        return measure.getFunction().arguments().stream().allMatch(farg -> PreCalciteAggregateValidator.isSimpleFieldReference(farg)) && measure.getFunction().sort().stream().allMatch(sf -> PreCalciteAggregateValidator.isSimpleFieldReference((FunctionArg)sf.expr())) && measure.getPreMeasureFilter().map(f -> PreCalciteAggregateValidator.isSimpleFieldReference((FunctionArg)f)).orElse(true) != false;
    }

    private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) {
        if (!grouping.getExpressions().stream().allMatch(e -> PreCalciteAggregateValidator.isSimpleFieldReference((FunctionArg)e))) {
            return false;
        }
        List<Integer> groupingFields = grouping.getExpressions().stream().map(expr -> PreCalciteAggregateValidator.getFieldRefOffset((FieldReference)expr)).collect(Collectors.toList());
        return PreCalciteAggregateValidator.isOrdered(groupingFields);
    }

    private static boolean isSimpleFieldReference(FunctionArg e) {
        FieldReference fr;
        return e instanceof FieldReference && (fr = (FieldReference)e).segments().size() == 1 && fr.segments().get(0) instanceof FieldReference.StructField;
    }

    private static int getFieldRefOffset(FieldReference fr) {
        return ((FieldReference.StructField)fr.segments().get(0)).offset();
    }

    private static boolean isOrdered(List<Integer> list) {
        for (int i = 1; i < list.size(); ++i) {
            if (list.get(i - 1) <= list.get(i)) continue;
            return false;
        }
        return true;
    }

    public static class PreCalciteAggregateTransformer {
        private final List<Expression> newExpressions = new ArrayList<Expression>();
        private int expressionOffset;

        private PreCalciteAggregateTransformer(Aggregate aggregate) {
            this.expressionOffset = aggregate.getInput().getRecordType().fields().size();
        }

        public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) {
            PreCalciteAggregateTransformer at = new PreCalciteAggregateTransformer(aggregate);
            List newMeasures = aggregate.getMeasures().stream().map(at::updateMeasure).collect(Collectors.toList());
            List newGroupings = aggregate.getGroupings().stream().map(at::updateGrouping).collect(Collectors.toList());
            ImmutableProject preAggregateProject = Project.builder().input(aggregate.getInput()).expressions(at.newExpressions).build();
            return Aggregate.builder().from(aggregate).input((Rel)preAggregateProject).measures(newMeasures).groupings(newGroupings).build();
        }

        private Aggregate.Measure updateMeasure(Aggregate.Measure measure) {
            AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction();
            List newFunctionArgs = oldAggregateFunctionInvocation.arguments().stream().map(this::projectOutNonFieldReference).collect(Collectors.toList());
            List newSortFields = oldAggregateFunctionInvocation.sort().stream().map(sf -> Expression.SortField.builder().from(sf).expr(this.projectOutNonFieldReference(sf.expr())).build()).collect(Collectors.toList());
            Optional<Expression> newPreMeasureFilter = measure.getPreMeasureFilter().map(this::projectOutNonFieldReference);
            ImmutableAggregateFunctionInvocation newAggregateFunctionInvocation = AggregateFunctionInvocation.builder().from(oldAggregateFunctionInvocation).arguments(newFunctionArgs).sort(newSortFields).build();
            return Aggregate.Measure.builder().function((AggregateFunctionInvocation)newAggregateFunctionInvocation).preMeasureFilter(newPreMeasureFilter).build();
        }

        private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) {
            List newGroupingExpressions = grouping.getExpressions().stream().map(this::projectOut).collect(Collectors.toList());
            return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build();
        }

        private Expression projectOutNonFieldReference(FunctionArg farg) {
            if (farg instanceof Expression) {
                Expression e = (Expression)farg;
                return this.projectOutNonFieldReference(e);
            }
            throw new IllegalArgumentException("cannot handle non-expression argument for aggregate");
        }

        private Expression projectOutNonFieldReference(Expression expr) {
            if (PreCalciteAggregateValidator.isSimpleFieldReference((FunctionArg)expr)) {
                return expr;
            }
            return this.projectOut(expr);
        }

        private Expression projectOut(Expression expr) {
            this.newExpressions.add(expr);
            return ImmutableFieldReference.builder().addSegments((FieldReference.ReferenceSegment)FieldReference.StructField.of((int)this.expressionOffset++)).type(expr.getType()).build();
        }
    }
}

