/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.relation;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.relation.HasExtension;
import io.substrait.relation.ImmutableAggregate;
import io.substrait.relation.ImmutableGrouping;
import io.substrait.relation.ImmutableMeasure;
import io.substrait.relation.RelVisitor;
import io.substrait.relation.SingleInputRel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import io.substrait.util.VisitationContext;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
 * Relation that evaluates a list of {@link Measure}s over its single input, partitioned by
 * zero or more {@link Grouping}s.
 *
 * <p>Concrete instances are generated by Immutables; obtain one through {@link #builder()}.
 */
@Value.Immutable
public abstract class Aggregate
extends SingleInputRel
implements HasExtension {
    /**
     * Returns the groupings of this aggregate, in declaration order.
     *
     * @return the groupings; more than one entry switches {@link #deriveRecordType()} to its
     *     multi-grouping behaviour
     */
    public abstract List<Grouping> getGroupings();

    /**
     * Returns the measures computed by this aggregate, in declaration order.
     *
     * @return the measures
     */
    public abstract List<Measure> getMeasures();

    /**
     * Derives the output record type: grouping column types first, measure result types after.
     *
     * <p>With zero or one grouping, every grouping expression contributes its own type, in order
     * and without de-duplication. With several groupings, each distinct expression contributes a
     * single column (first-seen order), and its type is made nullable unless the expression is
     * present in every grouping.
     *
     * @return a required struct describing the output columns
     */
    @Override
    protected Type.Struct deriveRecordType() {
        List<Grouping> groupings = this.getGroupings();
        Stream<Type> groupingTypes;
        if (groupings.size() <= 1) {
            groupingTypes = groupings.stream()
                .flatMap(grouping -> grouping.getExpressions().stream())
                .map(Expression::getType);
        } else {
            groupingTypes = this.multiGroupingTypes(groupings);
        }
        return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, this.measureTypes()));
    }

    /** Streams the result type of each measure's aggregate function, in measure order. */
    private Stream<Type> measureTypes() {
        return this.getMeasures().stream().map(measure -> measure.getFunction().getType());
    }

    /**
     * Streams one column type per distinct grouping expression across all groupings.
     *
     * <p>A parameterized {@code LinkedHashSet<Expression>} is required here: it keeps first-seen
     * order while dropping duplicates, and a raw set would erase the element type so the
     * expressions could no longer be asked for their type.
     */
    private Stream<Type> multiGroupingTypes(List<Grouping> groupings) {
        LinkedHashSet<Expression> uniqueGroupingExpressions = groupings.stream()
            .flatMap(grouping -> grouping.getExpressions().stream())
            .collect(Collectors.toCollection(LinkedHashSet::new));
        return uniqueGroupingExpressions.stream()
            .map(expression -> Aggregate.groupingColumnType(groupings, expression));
    }

    /**
     * Picks the column type for one grouping expression: its own type when every grouping
     * contains it, otherwise the nullable variant of that type.
     */
    private static Type groupingColumnType(List<Grouping> groupings, Expression expression) {
        // NOTE(review): presumably nullable because a grouping that omits the expression
        // yields no value for this column - confirm against the Substrait specification.
        boolean appearsInAllSets = groupings.stream()
            .allMatch(grouping -> grouping.getExpressions().contains(expression));
        if (appearsInAllSets) {
            return expression.getType();
        }
        return TypeCreator.asNullable(expression.getType());
    }

    /**
     * Dispatches to the visitor overload for {@code Aggregate}.
     *
     * @param visitor the visitor to invoke
     * @param context the context handed through to the visitor
     * @return whatever the visitor returns
     * @throws E if the visitor throws
     */
    @Override
    public <O, C extends VisitationContext, E extends Exception> O accept(RelVisitor<O, C, E> visitor, C context) throws E {
        return visitor.visit(this, context);
    }

    /**
     * Creates a builder for the generated immutable implementation.
     *
     * @return a new {@link ImmutableAggregate.Builder}
     */
    public static ImmutableAggregate.Builder builder() {
        return ImmutableAggregate.builder();
    }

    /** A single aggregate function invocation together with an optional pre-measure filter. */
    @Value.Immutable
    public static abstract class Measure {
        /**
         * Returns the aggregate function invocation of this measure.
         *
         * @return the invocation; its type is used as the measure's output column type
         */
        public abstract AggregateFunctionInvocation getFunction();

        /**
         * Returns the pre-measure filter expression, if one was set.
         *
         * @return the filter, or empty when absent
         */
        public abstract Optional<Expression> getPreMeasureFilter();

        /**
         * Creates a builder for the generated immutable implementation.
         *
         * @return a new {@link ImmutableMeasure.Builder}
         */
        public static ImmutableMeasure.Builder builder() {
            return ImmutableMeasure.builder();
        }
    }

    /** An ordered list of expressions forming one grouping of the aggregate. */
    @Value.Immutable
    public static abstract class Grouping {
        /**
         * Returns the expressions of this grouping, in declaration order.
         *
         * @return the grouping expressions
         */
        public abstract List<Expression> getExpressions();

        /**
         * Creates a builder for the generated immutable implementation.
         *
         * @return a new {@link ImmutableGrouping.Builder}
         */
        public static ImmutableGrouping.Builder builder() {
            return ImmutableGrouping.builder();
        }
    }
}

