/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.expression.FieldReference;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.CallConverter;
import io.substrait.isthmus.FeatureBoard;
import io.substrait.isthmus.ImmutableFeatureBoard;
import io.substrait.isthmus.OuterReferenceResolver;
import io.substrait.isthmus.RelNodeVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.calcite.rel.CreateTable;
import io.substrait.isthmus.calcite.rel.CreateView;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.CallConverters;
import io.substrait.isthmus.expression.LiteralConverter;
import io.substrait.isthmus.expression.RexExpressionConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractDdlRel;
import io.substrait.relation.AbstractUpdate;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Fetch;
import io.substrait.relation.ImmutableAggregate;
import io.substrait.relation.ImmutableFetch;
import io.substrait.relation.ImmutableMeasure;
import io.substrait.relation.ImmutableTransformExpression;
import io.substrait.relation.Join;
import io.substrait.relation.NamedDdl;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Match;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.immutables.value.Value;

@Value.Enclosing
/**
 * Converts a Calcite {@link RelNode} tree into the equivalent Substrait {@link Rel} tree.
 *
 * <p>Relational operators are dispatched through {@link RelNodeVisitor}. Row expressions are
 * delegated to a {@link RexExpressionConverter}, and aggregate calls to an
 * {@link AggregateFunctionConverter}. Operators without a dedicated conversion delegate to the
 * superclass. Unknown custom nodes reaching {@link #visitOther(RelNode)} raise
 * {@link UnsupportedOperationException}.
 */
public class SubstraitRelVisitor
extends RelNodeVisitor<Rel, RuntimeException> {
    private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build();

    /** Non-nullable boolean literal {@code true}; an INNER join on this condition becomes a Cross. */
    private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true);

    protected final RexExpressionConverter rexExpressionConverter;
    protected final AggregateFunctionConverter aggregateFunctionConverter;
    protected final TypeConverter typeConverter;
    protected final FeatureBoard featureBoard;

    /** Populated by {@link #popFieldAccessDepthMap(RelNode)}; null until then. */
    private Map<RexFieldAccess, Integer> fieldAccessDepthMap;

    /** Creates a visitor using the given extensions and the default feature board. */
    public SubstraitRelVisitor(RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) {
        this(typeFactory, extensions, FEATURES_DEFAULT);
    }

    /**
     * Creates a visitor whose scalar, aggregate and window converters are built from
     * {@code extensions}, using {@link TypeConverter#DEFAULT}.
     */
    public SubstraitRelVisitor(RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
        this(
            typeFactory,
            new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
            new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory),
            new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
            TypeConverter.DEFAULT,
            features);
    }

    /** Creates a visitor from explicitly supplied converters. */
    public SubstraitRelVisitor(RelDataTypeFactory typeFactory, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter, FeatureBoard features) {
        // Assign plain fields before `this` is handed to RexExpressionConverter so it never
        // observes a partially initialised visitor (e.g. a null typeConverter).
        this.typeConverter = typeConverter;
        this.featureBoard = features;
        this.aggregateFunctionConverter = aggregateFunctionConverter;
        List<CallConverter> converters = new ArrayList<>();
        converters.addAll(CallConverters.defaults(typeConverter));
        converters.add(scalarFunctionConverter);
        converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory)));
        this.rexExpressionConverter = new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter);
    }

    /** Converts a Calcite row expression into a Substrait expression. */
    protected Expression toExpression(RexNode node) {
        return node.accept(this.rexExpressionConverter);
    }

    @Override
    public Rel visit(TableScan scan) {
        NamedStruct type = typeConverter.toNamedStruct(scan.getRowType());
        return NamedScan.builder()
            .initialSchema(type)
            .addAllNames(scan.getTable().getQualifiedName())
            .build();
    }

    @Override
    public Rel visit(TableFunctionScan scan) {
        return super.visit(scan);
    }

    /** Empty VALUES become an EmptyScan; otherwise each tuple becomes a struct row of a VirtualTableScan. */
    @Override
    public Rel visit(Values values) {
        NamedStruct type = typeConverter.toNamedStruct(values.getRowType());
        if (values.getTuples().isEmpty()) {
            return EmptyScan.builder().initialSchema(type).build();
        }
        LiteralConverter literalConverter = new LiteralConverter(typeConverter);
        List<Expression.StructLiteral> rows = values.getTuples().stream()
            .map(tuple -> {
                List<Expression.Literal> fields = tuple.stream()
                    .map(literal -> literalConverter.convert(literal))
                    .collect(Collectors.toUnmodifiableList());
                return ExpressionCreator.struct(false, fields);
            })
            .collect(Collectors.toUnmodifiableList());
        return VirtualTableScan.builder().initialSchema(type).addAllRows(rows).build();
    }

    @Override
    public Rel visit(Filter filter) {
        Expression condition = toExpression(filter.getCondition());
        return io.substrait.relation.Filter.builder()
            .condition(condition)
            .input(apply(filter.getInput()))
            .build();
    }

    @Override
    public Rel visit(Calc calc) {
        return super.visit(calc);
    }

    /**
     * Substrait Project appends its expressions to the input fields, so a remap keeps only the
     * appended expressions, matching Calcite's Project row type.
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Project project) {
        List<Expression> expressions = project.getProjects().stream()
            .map(this::toExpression)
            .collect(Collectors.toList());
        int inputFieldCount = project.getInput().getRowType().getFieldCount();
        return Project.builder()
            .remap(Rel.Remap.offset(inputFieldCount, expressions.size()))
            .expressions(expressions)
            .input(apply(project.getInput()))
            .build();
    }

    @Override
    public Rel visit(org.apache.calcite.rel.core.Join join) {
        Rel left = apply(join.getLeft());
        Rel right = apply(join.getRight());
        Expression condition = toExpression(join.getCondition());
        io.substrait.relation.Join.JoinType joinType = asJoinType(join);
        // An inner join whose condition is literally TRUE is a Cartesian product.
        if (joinType == io.substrait.relation.Join.JoinType.INNER && TRUE.equals(condition)) {
            return Cross.builder().left(left).right(right).build();
        }
        return io.substrait.relation.Join.builder()
            .condition(condition)
            .joinType(joinType)
            .left(left)
            .right(right)
            .build();
    }

    /**
     * Maps a Calcite join type to its Substrait counterpart.
     *
     * @throws UnsupportedOperationException for join types without a mapping
     */
    private io.substrait.relation.Join.JoinType asJoinType(org.apache.calcite.rel.core.Join join) {
        JoinRelType type = join.getJoinType();
        switch (type) {
            case INNER:
                return io.substrait.relation.Join.JoinType.INNER;
            case LEFT:
                return io.substrait.relation.Join.JoinType.LEFT;
            case RIGHT:
                return io.substrait.relation.Join.JoinType.RIGHT;
            case FULL:
                return io.substrait.relation.Join.JoinType.OUTER;
            case SEMI:
                return io.substrait.relation.Join.JoinType.LEFT_SEMI;
            case ANTI:
                return io.substrait.relation.Join.JoinType.LEFT_ANTI;
            default:
                throw new UnsupportedOperationException("Unsupported join type: " + type);
        }
    }

    /** Converts both inputs, then defers to the superclass for the Correlate itself. */
    @Override
    public Rel visit(Correlate correlate) {
        apply(correlate.getLeft());
        apply(correlate.getRight());
        return super.visit(correlate);
    }

    @Override
    public Rel visit(Union union) {
        List<Rel> inputs = apply(union.getInputs());
        Set.SetOp setOp = union.all ? Set.SetOp.UNION_ALL : Set.SetOp.UNION_DISTINCT;
        return Set.builder().inputs(inputs).setOp(setOp).build();
    }

    @Override
    public Rel visit(Intersect intersect) {
        List<Rel> inputs = apply(intersect.getInputs());
        Set.SetOp setOp = intersect.all ? Set.SetOp.INTERSECTION_MULTISET_ALL : Set.SetOp.INTERSECTION_MULTISET;
        return Set.builder().inputs(inputs).setOp(setOp).build();
    }

    @Override
    public Rel visit(Minus minus) {
        List<Rel> inputs = apply(minus.getInputs());
        Set.SetOp setOp = minus.all ? Set.SetOp.MINUS_PRIMARY_ALL : Set.SetOp.MINUS_PRIMARY;
        return Set.builder().inputs(inputs).setOp(setOp).build();
    }

    /**
     * Converts a Calcite Aggregate, including GROUPING SETS and {@code GROUP_ID()} calls.
     *
     * <p>{@code GROUP_ID()} calls are not converted to measures. When there is more than one
     * grouping set, a remap is attached so the output columns line up with Calcite's row type:
     * the distinct group keys followed by every aggregate call in original order, with each
     * {@code GROUP_ID()} call mapped onto the column following the grouping keys and measures.
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
        Rel input = apply(aggregate.getInput());
        Stream<ImmutableBitSet> sets = aggregate.groupSets != null
            ? aggregate.groupSets.stream()
            : Stream.of(aggregate.getGroupSet());
        List<io.substrait.relation.Aggregate.Grouping> groupings = sets
            .filter(s -> s != null)
            .map(s -> fromGroupSet(s, input))
            .collect(Collectors.toList());

        List<AggregateCall> allCalls = aggregate.getAggCallList();
        List<AggregateCall> groupIdCalls = allCalls.stream()
            .filter(call -> call.getAggregation().equals(SqlStdOperatorTable.GROUP_ID))
            .collect(Collectors.toList());
        List<AggregateCall> filteredAggCalls = allCalls.stream()
            .filter(call -> !groupIdCalls.contains(call))
            .collect(Collectors.toList());
        List<io.substrait.relation.Aggregate.Measure> measures = filteredAggCalls.stream()
            .map(call -> fromAggCall(aggregate.getInput(), input.getRecordType(), call))
            .collect(Collectors.toList());

        ImmutableAggregate.Builder builder = io.substrait.relation.Aggregate.builder()
            .input(input)
            .addAllGroupings(groupings)
            .addAllMeasures(measures);

        if (groupings.size() > 1) {
            // Both branches need the number of DISTINCT grouping expressions: Calcite's row type
            // starts with one column per distinct group key, and a key shared by several grouping
            // sets (e.g. ROLLUP) must only be counted once.
            int groupingExprCount = Math.toIntExact(groupings.stream()
                .flatMap(g -> g.getExpressions().stream())
                .distinct()
                .count());
            if (groupIdCalls.isEmpty()) {
                // Drop the trailing grouping-set index column.
                builder.remap(Rel.Remap.offset(0, groupingExprCount + measures.size()));
            } else {
                // The grouping-set index sits right after the measures.
                int groupingSetIndex = groupingExprCount + measures.size();
                List<Integer> remap = IntStream.range(0, groupingExprCount)
                    .boxed()
                    .collect(Collectors.toCollection(ArrayList::new));
                for (AggregateCall call : allCalls) {
                    if (filteredAggCalls.contains(call)) {
                        remap.add(groupingExprCount + filteredAggCalls.indexOf(call));
                    } else if (groupIdCalls.contains(call)) {
                        remap.add(groupingSetIndex);
                    } else {
                        throw new IllegalStateException("encountered AggregateCall that is neither in filteredAggCalls nor in groupIdCalls" + call);
                    }
                }
                builder.remap(Rel.Remap.of(remap));
            }
        }
        return builder.build();
    }

    /** Builds a Substrait grouping of field references for every bit set in {@code bitSet}. */
    io.substrait.relation.Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) {
        List<FieldReference> references = bitSet.asList().stream()
            .map(index -> FieldReference.newInputRelReference(index, input))
            .collect(Collectors.toList());
        return io.substrait.relation.Aggregate.Grouping.builder().addAllExpressions(references).build();
    }

    /**
     * Converts one aggregate call into a Substrait measure. A Calcite {@code FILTER} argument
     * ({@code filterArg != -1}) becomes a pre-measure filter referencing that input field.
     *
     * @throws UnsupportedOperationException if no function binding exists for the call
     */
    io.substrait.relation.Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) {
        AggregateFunctionInvocation invocation = aggregateFunctionConverter
            .convert(input, inputType, call, this::toExpression)
            .orElseThrow(() -> new UnsupportedOperationException("Unable to find binding for call " + call));
        ImmutableMeasure.Builder builder = io.substrait.relation.Aggregate.Measure.builder().function(invocation);
        if (call.filterArg != -1) {
            builder.preMeasureFilter(FieldReference.newRootStructReference(call.filterArg, inputType));
        }
        return builder.build();
    }

    @Override
    public Rel visit(Match match) {
        return super.visit(match);
    }

    /**
     * Converts a Calcite Sort into a Substrait Sort (when there are collations) wrapped in a
     * Fetch (when OFFSET and/or FETCH is present). A missing offset defaults to 0 and a missing
     * fetch to an empty count.
     */
    @Override
    public Rel visit(org.apache.calcite.rel.core.Sort sort) {
        Rel input = apply(sort.getInput());
        Rel output = input;
        List<RelFieldCollation> collations = sort.getCollation().getFieldCollations();
        if (!collations.isEmpty()) {
            List<Expression.SortField> fields = collations.stream()
                .map(collation -> toSortField(collation, input.getRecordType()))
                .collect(Collectors.toList());
            output = Sort.builder().addAllSortFields(fields).input(output).build();
        }
        if (sort.fetch != null || sort.offset != null) {
            long offset = sort.offset == null ? 0L : asLong(sort.offset);
            OptionalLong count = sort.fetch == null ? OptionalLong.empty() : OptionalLong.of(asLong(sort.fetch));
            output = Fetch.builder().input(output).offset(offset).count(count).build();
        }
        return output;
    }

    /**
     * Evaluates a row expression that must convert to an I64 or I32 literal.
     *
     * @throws UnsupportedOperationException for any other expression
     */
    private long asLong(RexNode rex) {
        Expression expr = toExpression(rex);
        if (expr instanceof Expression.I64Literal) {
            return ((Expression.I64Literal) expr).value();
        }
        if (expr instanceof Expression.I32Literal) {
            return ((Expression.I32Literal) expr).value();
        }
        throw new UnsupportedOperationException("Unknown type: " + rex);
    }

    /** Builds a Substrait sort field referencing the collation's field index in {@code inputType}. */
    public static Expression.SortField toSortField(RelFieldCollation collation, Type.Struct inputType) {
        Expression.SortDirection direction = asSortDirection(collation);
        return Expression.SortField.builder()
            .expr(FieldReference.newRootStructReference(collation.getFieldIndex(), inputType))
            .direction(direction)
            .build();
    }

    /**
     * Maps a Calcite collation direction and null ordering to a Substrait sort direction.
     * Any null direction other than LAST is treated as nulls-first.
     */
    private static Expression.SortDirection asSortDirection(RelFieldCollation collation) {
        RelFieldCollation.Direction direction = collation.direction;
        boolean nullsLast = collation.nullDirection == RelFieldCollation.NullDirection.LAST;
        switch (direction) {
            case ASCENDING:
            case STRICTLY_ASCENDING:
                return nullsLast ? Expression.SortDirection.ASC_NULLS_LAST : Expression.SortDirection.ASC_NULLS_FIRST;
            case DESCENDING:
            case STRICTLY_DESCENDING:
                return nullsLast ? Expression.SortDirection.DESC_NULLS_LAST : Expression.SortDirection.DESC_NULLS_FIRST;
            case CLUSTERED:
                return Expression.SortDirection.CLUSTERED;
            default:
                throw new IllegalArgumentException("Unsupported collation direction: " + direction);
        }
    }

    @Override
    public Rel visit(Exchange exchange) {
        return super.visit(exchange);
    }

    /** INSERT/DELETE become a NamedWrite, UPDATE a NamedUpdate; other operations go to the superclass. */
    @Override
    public Rel visit(TableModify modify) {
        switch (modify.getOperation()) {
            case INSERT:
            case DELETE:
                return toNamedWrite(modify);
            case UPDATE:
                return toNamedUpdate(modify);
            default:
                return super.visit(modify);
        }
    }

    /** Converts an INSERT or DELETE TableModify into a NamedWrite that outputs the modified records. */
    private Rel toNamedWrite(TableModify modify) {
        Rel input = apply(modify.getInput());
        AbstractWriteRel.WriteOp op = modify.getOperation() == TableModify.Operation.INSERT
            ? AbstractWriteRel.WriteOp.INSERT
            : AbstractWriteRel.WriteOp.DELETE;
        assert modify.getTable() != null;
        return NamedWrite.builder()
            .input(input)
            .tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
            .operation(op)
            .createMode(AbstractWriteRel.CreateMode.UNSPECIFIED)
            .outputMode(AbstractWriteRel.OutputMode.MODIFIED_RECORDS)
            .names(modify.getTable().getQualifiedName())
            .build();
    }

    /**
     * Converts an UPDATE TableModify into a NamedUpdate with one transformation per updated
     * column.
     *
     * @throws IllegalStateException if an updated column is not in the table schema
     */
    private Rel toNamedUpdate(TableModify modify) {
        assert modify.getTable() != null;
        RelNode input = modify.getInput();
        // NOTE(review): only a Filter that is the *direct* input is recognised as the WHERE
        // clause. getSourceExpressions() expects the input may instead be a Project (possibly
        // over a Filter); in that case the condition here defaults to TRUE — confirm intended.
        Expression condition = input instanceof Filter
            ? toExpression(((Filter) input).getCondition())
            : TRUE;
        List<String> updateColumnNames = modify.getUpdateColumnList();
        List<RexNode> sourceExpressions = getSourceExpressions(modify);
        List<String> allTableColumnNames = modify.getTable().getRowType().getFieldNames();
        List<ImmutableTransformExpression> transformations = new ArrayList<>();
        for (int i = 0; i < updateColumnNames.size(); ++i) {
            String colName = updateColumnNames.get(i);
            RexNode rexExpr = sourceExpressions.get(i);
            int columnIndex = allTableColumnNames.indexOf(colName);
            if (columnIndex == -1) {
                throw new IllegalStateException("Updated column '" + colName + "' not found in table schema.");
            }
            Expression substraitExpr = toExpression(rexExpr);
            transformations.add(AbstractUpdate.TransformExpression.builder()
                .columnTarget(columnIndex)
                .transformation(substraitExpr)
                .build());
        }
        return NamedUpdate.builder()
            .tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
            .names(modify.getTable().getQualifiedName())
            .condition(condition)
            .transformations(transformations)
            .build();
    }

    /**
     * Returns the UPDATE source expressions (empty if absent). When the input is a Project,
     * input references are replaced by the projected expressions they point to.
     */
    private List<RexNode> getSourceExpressions(TableModify modify) {
        List<RexNode> results = modify.getSourceExpressionList();
        if (results == null) {
            return Collections.emptyList();
        }
        RelNode input = modify.getInput();
        if (input instanceof org.apache.calcite.rel.core.Project) {
            return resolveProjectedRefs(results, (org.apache.calcite.rel.core.Project) input);
        }
        return results;
    }

    /** Replaces each top-level RexInputRef with the project expression at that index. */
    private List<RexNode> resolveProjectedRefs(List<RexNode> expressions, org.apache.calcite.rel.core.Project project) {
        List<RexNode> projects = project.getProjects();
        return expressions.stream()
            .map(expression -> expression instanceof RexInputRef
                ? projects.get(((RexInputRef) expression).getIndex())
                : expression)
            .collect(Collectors.toList());
    }

    /** Converts the row type of {@code queryRelRoot} into a Substrait named struct. */
    private NamedStruct getSchema(RelNode queryRelRoot) {
        RelDataType rowType = queryRelRoot.getRowType();
        return typeConverter.toNamedStruct(rowType);
    }

    /** Converts CREATE TABLE ... AS into a CTAS NamedWrite that replaces an existing table. */
    public Rel handleCreateTable(CreateTable createTable) {
        RelNode input = createTable.getInput();
        Rel inputRel = apply(input);
        NamedStruct schema = getSchema(input);
        return NamedWrite.builder()
            .input(inputRel)
            .tableSchema(schema)
            .operation(AbstractWriteRel.WriteOp.CTAS)
            .createMode(AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS)
            .outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
            .names(createTable.getTableName())
            .build();
    }

    /** Converts CREATE VIEW into a NamedDdl CREATE VIEW with an empty defaults struct. */
    public Rel handleCreateView(CreateView createView) {
        RelNode input = createView.getInput();
        Rel inputRel = apply(input);
        Expression.StructLiteral defaults = ExpressionCreator.struct(false, new Expression.Literal[0]);
        return NamedDdl.builder()
            .viewDefinition(inputRel)
            .tableSchema(getSchema(input))
            .tableDefaults(defaults)
            .operation(AbstractDdlRel.DdlOp.CREATE)
            .object(AbstractDdlRel.DdlObject.VIEW)
            .names(createView.getViewName())
            .build();
    }

    /**
     * Handles the custom DDL nodes.
     *
     * @throws UnsupportedOperationException for any other node
     */
    @Override
    public Rel visitOther(RelNode other) {
        if (other instanceof CreateTable) {
            return handleCreateTable((CreateTable) other);
        }
        if (other instanceof CreateView) {
            return handleCreateView((CreateView) other);
        }
        throw new UnsupportedOperationException("Unable to handle node: " + other);
    }

    /** Resolves outer (correlated) field-access depths for the tree rooted at {@code root}. */
    protected void popFieldAccessDepthMap(RelNode root) {
        OuterReferenceResolver resolver = new OuterReferenceResolver();
        resolver.apply(root);
        this.fieldAccessDepthMap = resolver.getFieldAccessDepthMap();
    }

    /**
     * Returns the recorded depth for {@code fieldAccess}, or null if none was recorded. Requires
     * {@link #popFieldAccessDepthMap(RelNode)} to have run first (the static {@code convert}
     * methods do this); otherwise the backing map is null.
     */
    public Integer getFieldAccessDepth(RexFieldAccess fieldAccess) {
        return this.fieldAccessDepthMap.get(fieldAccess);
    }

    /** Converts a single Calcite node by dispatching it through this visitor. */
    public Rel apply(RelNode r) {
        return reverseAccept(r);
    }

    /** Converts each input in order. */
    public List<Rel> apply(List<RelNode> inputs) {
        return inputs.stream()
            .map(inputRel -> apply(inputRel))
            .collect(Collectors.toList());
    }

    public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) {
        return convert(relRoot, extensions, FEATURES_DEFAULT);
    }

    /** Converts {@code relRoot.project()}; output names come from the validated row type. */
    public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) {
        visitor.popFieldAccessDepthMap(relRoot.rel);
        Rel rel = visitor.apply(relRoot.project());
        List<String> names = visitor.typeConverter.toNamedStruct(relRoot.validatedRowType).names();
        return Plan.Root.builder().input(rel).names(names).build();
    }

    public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
        return convert(relRoot, new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features));
    }

    public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) {
        return convert(relNode, extensions, FEATURES_DEFAULT);
    }

    public static Rel convert(RelNode relNode, SubstraitRelVisitor visitor) {
        visitor.popFieldAccessDepthMap(relNode);
        return visitor.apply(relNode);
    }

    public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) {
        return convert(relNode, new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features));
    }
}

