/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.CallConverter;
import io.substrait.isthmus.FeatureBoard;
import io.substrait.isthmus.ImmutableFeatureBoard;
import io.substrait.isthmus.OuterReferenceResolver;
import io.substrait.isthmus.RelNodeVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.CallConverters;
import io.substrait.isthmus.expression.LiteralConverter;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.ImmutableFetch;
import io.substrait.relation.ImmutableMeasure;
import io.substrait.relation.ImmutableSort;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.core.Match;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Converts a Calcite {@link RelNode} tree into the equivalent Substrait {@link Rel} tree.
 *
 * <p>Each {@code visit} overload handles one Calcite relational operator. Operators without a
 * dedicated conversion here delegate to the {@link RelNodeVisitor} base implementation, and
 * {@link #visitOther(RelNode)} rejects unknown nodes with an {@link UnsupportedOperationException}.
 * Row expressions are translated through a {@link RexExpressionConverter}; aggregate calls through
 * an {@link AggregateFunctionConverter}.
 */
@Value.Enclosing
public class SubstraitRelVisitor extends RelNodeVisitor<Rel, RuntimeException> {
    static final Logger logger = LoggerFactory.getLogger(SubstraitRelVisitor.class);

    /** Feature board used when the caller does not supply one. */
    private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();

    /** Non-nullable boolean literal {@code true}; an INNER join on this condition is a cross join. */
    private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

    protected final RexExpressionConverter rexExpressionConverter;
    protected final AggregateFunctionConverter aggregateFunctionConverter;
    protected final TypeConverter typeConverter;
    protected final FeatureBoard featureBoard;

    /**
     * Depth of each correlated field access, computed by {@link OuterReferenceResolver}. Only
     * populated by {@link #popFieldAccessDepthMap(RelNode)}, which the static {@code convert}
     * entry points call; it stays {@code null} on a visitor that was constructed directly.
     */
    private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

    /**
     * Creates a visitor using the default {@link FeatureBoard}.
     *
     * @param typeFactory Calcite type factory used by the function converters
     * @param extensions  Substrait extension collection providing function signatures
     */
    public SubstraitRelVisitor(RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) {
        this(typeFactory, extensions, FEATURES_DEFAULT);
    }

    /**
     * Creates a visitor whose scalar, aggregate and window converters are built from
     * {@code extensions}, using {@link TypeConverter#DEFAULT} for type translation.
     *
     * @param typeFactory Calcite type factory used by the function converters
     * @param extensions  Substrait extension collection providing function signatures
     * @param features    feature flags controlling conversion choices (e.g. cross-join policy)
     */
    public SubstraitRelVisitor(RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
        this.typeConverter = TypeConverter.DEFAULT;
        List<CallConverter> converters = new ArrayList<>();
        converters.addAll(CallConverters.defaults(this.typeConverter));
        converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory));
        converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
        this.aggregateFunctionConverter = new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory);
        WindowFunctionConverter windowFunctionConverter = new WindowFunctionConverter(extensions.windowFunctions(), typeFactory);
        this.rexExpressionConverter = new RexExpressionConverter(this, converters, windowFunctionConverter, this.typeConverter);
        this.featureBoard = features;
    }

    /**
     * Creates a visitor from caller-supplied converters.
     *
     * @param typeFactory               Calcite type factory used to build the SEARCH converter
     * @param scalarFunctionConverter   converter for scalar function calls
     * @param aggregateFunctionConverter converter for aggregate calls
     * @param windowFunctionConverter   converter for window function calls
     * @param typeConverter             Calcite-to-Substrait type translator
     * @param features                  feature flags controlling conversion choices
     */
    public SubstraitRelVisitor(RelDataTypeFactory typeFactory, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter, FeatureBoard features) {
        // Assign the type converter first, matching the other constructor, so it is set before
        // `this` is handed to the RexExpressionConverter below.
        this.typeConverter = typeConverter;
        List<CallConverter> converters = new ArrayList<>();
        converters.addAll(CallConverters.defaults(typeConverter));
        converters.add(scalarFunctionConverter);
        converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
        this.aggregateFunctionConverter = aggregateFunctionConverter;
        this.rexExpressionConverter = new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter);
        this.featureBoard = features;
    }

    /** Translates a Calcite row expression into a Substrait expression. */
    protected Expression toExpression(RexNode node) {
        return (Expression) node.accept(this.rexExpressionConverter);
    }

    /** Converts a table scan into a {@link NamedScan} over the table's qualified name. */
    @Override
    public Rel visit(TableScan scan) {
        NamedStruct type = this.typeConverter.toNamedStruct(scan.getRowType());
        return NamedScan.builder()
                .initialSchema(type)
                .addAllNames(scan.getTable().getQualifiedName())
                .build();
    }

    @Override
    public Rel visit(TableFunctionScan scan) {
        return super.visit(scan);
    }

    /**
     * Converts a VALUES node: no tuples become an {@link EmptyScan}; otherwise each tuple becomes
     * a non-nullable struct literal row of a {@link VirtualTableScan}.
     */
    @Override
    public Rel visit(Values values) {
        NamedStruct type = this.typeConverter.toNamedStruct(values.getRowType());
        if (values.getTuples().isEmpty()) {
            return EmptyScan.builder().initialSchema(type).build();
        }
        LiteralConverter literalConverter = new LiteralConverter(this.typeConverter);
        var structs = values.getTuples().stream()
                .map(tuple -> {
                    var fields = tuple.stream()
                            .map(literal -> literalConverter.convert(literal))
                            .collect(Collectors.toUnmodifiableList());
                    return ExpressionCreator.struct(false, fields);
                })
                .collect(Collectors.toUnmodifiableList());
        return VirtualTableScan.builder().addAllDfsNames(type.names()).addAllRows(structs).build();
    }

    /** Converts a Calcite filter into a Substrait {@link Filter}. */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Filter filter) {
        Expression condition = this.toExpression(filter.getCondition());
        return Filter.builder().condition(condition).input(this.apply(filter.getInput())).build();
    }

    @Override
    public Rel visit(Calc calc) {
        return super.visit(calc);
    }

    /**
     * Converts a Calcite projection. Substrait's Project appends expressions to the input row, so a
     * remap is attached that keeps only the projected columns (those after the input's fields).
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Project project) {
        List<Expression> expressions = project.getProjects().stream()
                .map(this::toExpression)
                .collect(Collectors.toList());
        int inputFieldCount = project.getInput().getRowType().getFieldCount();
        return Project.builder()
                .remap(Rel.Remap.offset(inputFieldCount, expressions.size()))
                .expressions(expressions)
                .input(this.apply(project.getInput()))
                .build();
    }

    /**
     * Converts a Calcite join. An INNER join whose condition is literal {@code true} becomes a
     * {@link Cross} when the feature board requests {@link CrossJoinPolicy#KEEP_AS_CROSS_JOIN}.
     *
     * @throws UnsupportedOperationException if the Calcite join type has no Substrait counterpart
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Join join) {
        Rel left = this.apply(join.getLeft());
        Rel right = this.apply(join.getRight());
        Expression condition = this.toExpression(join.getCondition());
        Join.JoinType joinType;
        switch (join.getJoinType()) {
            case INNER:
                joinType = Join.JoinType.INNER;
                break;
            case LEFT:
                joinType = Join.JoinType.LEFT;
                break;
            case RIGHT:
                joinType = Join.JoinType.RIGHT;
                break;
            case FULL:
                joinType = Join.JoinType.OUTER;
                break;
            case SEMI:
                joinType = Join.JoinType.SEMI;
                break;
            case ANTI:
                joinType = Join.JoinType.ANTI;
                break;
            default:
                throw new UnsupportedOperationException("Unsupported join type: " + join.getJoinType());
        }
        if (joinType == Join.JoinType.INNER
                && TRUE.equals(condition)
                && this.featureBoard.crossJoinPolicy() == CrossJoinPolicy.KEEP_AS_CROSS_JOIN) {
            return Cross.builder().left(left).right(right).build();
        }
        return Join.builder().condition(condition).joinType(joinType).left(left).right(right).build();
    }

    /**
     * Handles a correlated join. This method does not build a Substrait relation itself: it
     * traverses both inputs, rejects join types other than INNER and LEFT, and then defers to the
     * base visitor.
     *
     * @throws IllegalArgumentException if the correlate's join type is not INNER or LEFT
     */
    @Override
    public Rel visit(Correlate correlate) {
        // Inputs are traversed first so that failures inside them surface before the type check.
        this.apply(correlate.getLeft());
        this.apply(correlate.getRight());
        switch (correlate.getJoinType()) {
            case INNER:
            case LEFT:
                break;
            default:
                throw new IllegalArgumentException("Invalid correlated join type: " + correlate.getJoinType());
        }
        return super.visit(correlate);
    }

    /** Converts UNION / UNION ALL into a Substrait {@link Set}. */
    @Override
    public Rel visit(Union union) {
        return this.toSet(union.getInputs(), union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT);
    }

    /** Converts INTERSECT / INTERSECT ALL into a Substrait {@link Set}. */
    @Override
    public Rel visit(Intersect intersect) {
        return this.toSet(intersect.getInputs(), intersect.all ? Set.SetOp.INTERSECTION_MULTISET : Set.SetOp.INTERSECTION_PRIMARY);
    }

    /** Converts EXCEPT / EXCEPT ALL into a Substrait {@link Set}. */
    @Override
    public Rel visit(Minus minus) {
        return this.toSet(minus.getInputs(), minus.all ? Set.SetOp.MINUS_MULTISET : Set.SetOp.MINUS_PRIMARY);
    }

    /** Converts all set-operation inputs and combines them with the given operation. */
    private Rel toSet(List<RelNode> inputs, Set.SetOp setOp) {
        return Set.builder().inputs(this.apply(inputs)).setOp(setOp).build();
    }

    /**
     * Converts a Calcite aggregate. Each grouping set (or the single group set when
     * {@code groupSets} is null) becomes a Substrait grouping; each aggregate call becomes a measure.
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
        Rel input = this.apply(aggregate.getInput());
        Stream<ImmutableBitSet> sets = aggregate.groupSets != null
                ? aggregate.groupSets.stream()
                : Stream.of(aggregate.getGroupSet());
        List<io.substrait.relation.Aggregate.Grouping> groupings = sets
                .filter(set -> set != null)
                .map(set -> this.fromGroupSet(set, input))
                .collect(Collectors.toList());
        List<io.substrait.relation.Aggregate.Measure> measures = aggregate.getAggCallList().stream()
                .map(call -> this.fromAggCall(aggregate.getInput(), input.getRecordType(), call))
                .collect(Collectors.toList());
        return io.substrait.relation.Aggregate.builder()
                .input(input)
                .addAllGroupings(groupings)
                .addAllMeasures(measures)
                .build();
    }

    /** Builds a grouping that references each input field whose ordinal is set in {@code bitSet}. */
    io.substrait.relation.Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {
        var references = bitSet.asList().stream()
                .map(ordinal -> FieldReference.newInputRelReference(ordinal, input))
                .collect(Collectors.toList());
        return io.substrait.relation.Aggregate.Grouping.builder().addAllExpressions(references).build();
    }

    /**
     * Converts one aggregate call into a Substrait measure.
     *
     * @throws UnsupportedOperationException if no aggregate function binding matches the call
     */
    io.substrait.relation.Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) {
        AggregateFunctionInvocation invocation = this.aggregateFunctionConverter
                .convert(input, inputType, call, t -> (Expression) t.accept(this.rexExpressionConverter))
                .orElseThrow(() -> new UnsupportedOperationException("Unable to find binding for call " + call));
        ImmutableMeasure.Builder builder = io.substrait.relation.Aggregate.Measure.builder().function(invocation);
        // -1 is Calcite's "no filter column" sentinel; otherwise filterArg is an input field ordinal.
        if (call.filterArg != -1) {
            builder.preMeasureFilter(FieldReference.newRootStructReference(call.filterArg, inputType));
        }
        return builder.build();
    }

    @Override
    public Rel visit(Match match) {
        return super.visit(match);
    }

    /**
     * Converts a Calcite sort. Without OFFSET/FETCH the result is a plain {@link Sort}; otherwise
     * the sort is wrapped in a {@link Fetch} (offset defaults to 0, count only set when FETCH is).
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Sort sort) {
        Rel input = this.apply(sort.getInput());
        List<Expression.SortField> fields = sort.getCollation().getFieldCollations().stream()
                .map(collation -> toSortField(collation, input.getRecordType()))
                .collect(Collectors.toList());
        ImmutableSort convertedSort = Sort.builder().addAllSortFields(fields).input(input).build();
        if (sort.fetch == null && sort.offset == null) {
            return convertedSort;
        }
        long offset = Optional.ofNullable(sort.offset).map(this::asLong).orElse(0L);
        ImmutableFetch.Builder builder = Fetch.builder().input(convertedSort).offset(offset);
        if (sort.fetch == null) {
            return builder.build();
        }
        return builder.count(this.asLong(sort.fetch)).build();
    }

    /**
     * Evaluates an OFFSET/FETCH expression that must translate to an i64 or i32 literal.
     *
     * @throws UnsupportedOperationException if the expression is not such a literal
     */
    private long asLong(RexNode rex) {
        Expression expr = this.toExpression(rex);
        if (expr instanceof Expression.I64Literal) {
            return ((Expression.I64Literal) expr).value();
        }
        if (expr instanceof Expression.I32Literal) {
            return ((Expression.I32Literal) expr).value();
        }
        throw new UnsupportedOperationException("Unknown type: " + rex);
    }

    /**
     * Converts a Calcite field collation into a Substrait sort field. Strict and non-strict
     * directions map identically; any null direction other than LAST sorts nulls first.
     *
     * @throws IllegalArgumentException if the collation direction is not recognised
     */
    public static Expression.SortField toSortField(RelFieldCollation collation, Type.Struct inputType) {
        boolean nullsLast = collation.nullDirection == RelFieldCollation.NullDirection.LAST;
        Expression.SortDirection direction;
        switch (collation.direction) {
            case ASCENDING:
            case STRICTLY_ASCENDING:
                direction = nullsLast ? Expression.SortDirection.ASC_NULLS_LAST : Expression.SortDirection.ASC_NULLS_FIRST;
                break;
            case DESCENDING:
            case STRICTLY_DESCENDING:
                direction = nullsLast ? Expression.SortDirection.DESC_NULLS_LAST : Expression.SortDirection.DESC_NULLS_FIRST;
                break;
            case CLUSTERED:
                direction = Expression.SortDirection.CLUSTERED;
                break;
            default:
                throw new IllegalArgumentException("Unsupported sort direction: " + collation.direction);
        }
        return Expression.SortField.builder()
                .expr(FieldReference.newRootStructReference(collation.getFieldIndex(), inputType))
                .direction(direction)
                .build();
    }

    @Override
    public Rel visit(Exchange exchange) {
        return super.visit(exchange);
    }

    @Override
    public Rel visit(TableModify modify) {
        return super.visit(modify);
    }

    /** Rejects any node type this visitor has no conversion for. */
    @Override
    public Rel visitOther(RelNode other) {
        throw new UnsupportedOperationException("Unable to handle node: " + other);
    }

    /** Runs {@link OuterReferenceResolver} over {@code root} and stores its field-access depths. */
    private void popFieldAccessDepthMap(RelNode root) {
        OuterReferenceResolver resolver = new OuterReferenceResolver();
        resolver.apply(root);
        this.fieldAccessDepthMap = resolver.getFieldAccessDepthMap();
    }

    /**
     * Returns the recorded depth for {@code fieldAccess}, or {@code null} if none was recorded.
     *
     * <p>NOTE(review): the depth map is only populated via the static {@code convert} methods; on a
     * directly-constructed visitor this throws {@link NullPointerException}.
     */
    public Integer getFieldAccessDepth(RexFieldAccess fieldAccess) {
        return this.fieldAccessDepthMap.get(fieldAccess);
    }

    /** Converts a single Calcite node (and its subtree) into a Substrait relation. */
    public Rel apply(RelNode r) {
        return this.reverseAccept(r);
    }

    /** Converts each Calcite node in order, returning the Substrait relations in the same order. */
    public List<Rel> apply(List<RelNode> inputs) {
        return inputs.stream().map(node -> this.apply(node)).collect(Collectors.toList());
    }

    /** Converts {@code root.rel} using the default {@link FeatureBoard}. */
    public static Rel convert(RelRoot root, SimpleExtension.ExtensionCollection extensions) {
        return convert(root.rel, extensions, FEATURES_DEFAULT);
    }

    /** Converts {@code root.rel} using the supplied {@link FeatureBoard}. */
    public static Rel convert(RelRoot root, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
        return convert(root.rel, extensions, features);
    }

    /** Builds a visitor for {@code rel}'s cluster, resolves outer references, and converts. */
    private static Rel convert(RelNode rel, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
        SubstraitRelVisitor visitor = new SubstraitRelVisitor(rel.getCluster().getTypeFactory(), extensions, features);
        visitor.popFieldAccessDepthMap(rel);
        return visitor.apply(rel);
    }

    /** How an INNER join with a literal {@code true} condition is emitted. */
    public static enum CrossJoinPolicy {
        KEEP_AS_CROSS_JOIN,
        CONVERT_TO_INNER_JOIN;
    }

    /**
     * Holder for a {@link CrossJoinPolicy}, defaulting to {@link CrossJoinPolicy#CONVERT_TO_INNER_JOIN}.
     * NOTE(review): within this file the visitor reads the policy from {@link FeatureBoard}, not
     * from this class.
     */
    public static class Options {
        private final CrossJoinPolicy crossJoinPolicy;

        public Options() {
            this(CrossJoinPolicy.CONVERT_TO_INNER_JOIN);
        }

        public Options(CrossJoinPolicy crossJoinPolicy) {
            this.crossJoinPolicy = crossJoinPolicy;
        }

        public CrossJoinPolicy getCrossJoinPolicy() {
            return this.crossJoinPolicy;
        }
    }
}

