/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.FunctionCollector;
import io.substrait.function.SimpleExtension;
import io.substrait.proto.AggregateFunction;
import io.substrait.proto.AggregateRel;
import io.substrait.proto.CrossRel;
import io.substrait.proto.Expression;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.JoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.RelCommon;
import io.substrait.proto.SetRel;
import io.substrait.proto.SortField;
import io.substrait.proto.SortRel;
import io.substrait.proto.Type;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.RelVisitor;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class RelProtoConverter
implements RelVisitor<io.substrait.proto.Rel, RuntimeException> {
    private final ExpressionProtoConverter protoConverter;
    private final FunctionCollector functionCollector;

    public RelProtoConverter(FunctionCollector functionCollector) {
        this.functionCollector = functionCollector;
        this.protoConverter = new ExpressionProtoConverter(functionCollector, this);
    }

    private List<Expression> toProto(Collection<io.substrait.expression.Expression> expressions) {
        return expressions.stream().map(this::toProto).collect(Collectors.toList());
    }

    private Expression toProto(io.substrait.expression.Expression expression) {
        return expression.accept(this.protoConverter);
    }

    public io.substrait.proto.Rel toProto(Rel rel) {
        return rel.accept(this);
    }

    private Type toProto(io.substrait.type.Type type) {
        return type.accept(TypeProtoConverter.INSTANCE);
    }

    private List<SortField> toProtoS(Collection<Expression.SortField> sorts) {
        return sorts.stream().map(s -> SortField.newBuilder().setDirection(s.direction().toProto()).setExpr(this.toProto(s.expr())).build()).collect(Collectors.toList());
    }

    @Override
    public io.substrait.proto.Rel visit(Aggregate aggregate) throws RuntimeException {
        AggregateRel.Builder builder = AggregateRel.newBuilder().setInput(this.toProto(aggregate.getInput())).setCommon(this.common(aggregate)).addAllGroupings(aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())).addAllMeasures(aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList()));
        return io.substrait.proto.Rel.newBuilder().setAggregate(builder).build();
    }

    private AggregateRel.Measure toProto(Aggregate.Measure measure) {
        FunctionArg.FuncArgVisitor<FunctionArgument, RuntimeException> argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this.protoConverter);
        List<FunctionArg> args = measure.getFunction().arguments();
        SimpleExtension.AggregateFunctionVariant aggFuncDef = measure.getFunction().declaration();
        AggregateFunction.Builder func = AggregateFunction.newBuilder().setPhase(measure.getFunction().aggregationPhase().toProto()).setInvocation(measure.getFunction().invocation()).setOutputType(this.toProto(measure.getFunction().getType())).addAllArguments(IntStream.range(0, args.size()).mapToObj(i -> (FunctionArgument)((FunctionArg)args.get(i)).accept(aggFuncDef, i, argVisitor)).collect(Collectors.toList())).addAllSorts(this.toProtoS(measure.getFunction().sort())).setFunctionReference(this.functionCollector.getFunctionReference(measure.getFunction().declaration()));
        AggregateRel.Measure.Builder builder = AggregateRel.Measure.newBuilder().setMeasure(func);
        measure.getPreMeasureFilter().ifPresent(f -> builder.setFilter(this.toProto((io.substrait.expression.Expression)f)));
        return builder.build();
    }

    private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) {
        return AggregateRel.Grouping.newBuilder().addAllGroupingExpressions(this.toProto(grouping.getExpressions())).build();
    }

    @Override
    public io.substrait.proto.Rel visit(EmptyScan emptyScan) throws RuntimeException {
        return io.substrait.proto.Rel.newBuilder().setRead(ReadRel.newBuilder().setCommon(this.common(emptyScan)).setVirtualTable(ReadRel.VirtualTable.newBuilder().build()).setBaseSchema(emptyScan.getInitialSchema().toProto()).build()).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Fetch fetch) throws RuntimeException {
        FetchRel.Builder builder = FetchRel.newBuilder().setCommon(this.common(fetch)).setInput(this.toProto(fetch.getInput())).setOffset(fetch.getOffset());
        fetch.getCount().ifPresent(f -> builder.setCount(f));
        return io.substrait.proto.Rel.newBuilder().setFetch(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Filter filter) throws RuntimeException {
        FilterRel.Builder builder = FilterRel.newBuilder().setCommon(this.common(filter)).setInput(this.toProto(filter.getInput())).setCondition(filter.getCondition().accept(this.protoConverter));
        return io.substrait.proto.Rel.newBuilder().setFilter(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Join join) throws RuntimeException {
        JoinRel.Builder builder = JoinRel.newBuilder().setCommon(this.common(join)).setLeft(this.toProto(join.getLeft())).setRight(this.toProto(join.getRight())).setType(join.getJoinType().toProto());
        join.getCondition().ifPresent(t -> builder.setExpression(this.toProto((io.substrait.expression.Expression)t)));
        return io.substrait.proto.Rel.newBuilder().setJoin(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Set set) throws RuntimeException {
        SetRel.Builder builder = SetRel.newBuilder().setCommon(this.common(set)).setOp(set.getSetOp().toProto());
        set.getInputs().forEach(inputRel -> builder.addInputs(this.toProto((Rel)inputRel)));
        return io.substrait.proto.Rel.newBuilder().setSet(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(NamedScan namedScan) throws RuntimeException {
        return io.substrait.proto.Rel.newBuilder().setRead(ReadRel.newBuilder().setCommon(this.common(namedScan)).setNamedTable(ReadRel.NamedTable.newBuilder().addAllNames(namedScan.getNames())).setBaseSchema(namedScan.getInitialSchema().toProto()).build()).build();
    }

    @Override
    public io.substrait.proto.Rel visit(LocalFiles localFiles) throws RuntimeException {
        ReadRel.Builder builder = ReadRel.newBuilder().setCommon(this.common(localFiles)).setLocalFiles(ReadRel.LocalFiles.newBuilder().addAllItems(localFiles.getItems().stream().map(FileOrFiles::toProto).collect(Collectors.toList())).build()).setBaseSchema(localFiles.getInitialSchema().toProto());
        localFiles.getFilter().ifPresent(t -> builder.setFilter(this.toProto((io.substrait.expression.Expression)t)));
        return io.substrait.proto.Rel.newBuilder().setRead(builder.build()).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Project project) throws RuntimeException {
        ProjectRel.Builder builder = ProjectRel.newBuilder().setCommon(this.common(project)).setInput(this.toProto(project.getInput())).addAllExpressions(project.getExpressions().stream().map(this::toProto).collect(Collectors.toList()));
        return io.substrait.proto.Rel.newBuilder().setProject(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Sort sort) throws RuntimeException {
        SortRel.Builder builder = SortRel.newBuilder().setCommon(this.common(sort)).setInput(this.toProto(sort.getInput())).addAllSorts(this.toProtoS(sort.getSortFields()));
        return io.substrait.proto.Rel.newBuilder().setSort(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(Cross cross) throws RuntimeException {
        CrossRel.Builder builder = CrossRel.newBuilder().setCommon(this.common(cross)).setLeft(this.toProto(cross.getLeft())).setRight(this.toProto(cross.getRight()));
        return io.substrait.proto.Rel.newBuilder().setCross(builder).build();
    }

    @Override
    public io.substrait.proto.Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException {
        return io.substrait.proto.Rel.newBuilder().setRead(ReadRel.newBuilder().setCommon(this.common(virtualTableScan)).setVirtualTable(ReadRel.VirtualTable.newBuilder().addAllValues(virtualTableScan.getRows().stream().map(this::toProto).map(t -> t.getLiteral().getStruct()).collect(Collectors.toList())).build()).setBaseSchema(virtualTableScan.getInitialSchema().toProto()).build()).build();
    }

    private RelCommon common(Rel rel) {
        RelCommon.Builder builder = RelCommon.newBuilder();
        Rel.Remap remap = rel.getRemap().orElse(null);
        if (remap != null) {
            builder.setEmit(RelCommon.Emit.newBuilder().addAllOutputMapping(remap.indices()));
        } else {
            builder.setDirect(RelCommon.Direct.getDefaultInstance());
        }
        return builder.build();
    }
}

