/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.isthmus;

import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.TypeExpression;
import io.substrait.isthmus.PreCalciteAggregateValidator;
import io.substrait.isthmus.SqlToSubstrait;
import io.substrait.isthmus.TypeConverter;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.ExpressionRexConverter;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.AbstractRelVisitor;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.RelVisitor;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.Program;
import org.apache.calcite.tools.RelBuilder;

/**
 * Converts a Substrait relation tree ({@link Rel}) into a Calcite relational expression tree
 * ({@link RelNode}).
 *
 * <p>Each {@code visit} method converts its input relations recursively, builds the matching
 * Calcite node with the shared {@link RelBuilder}, and finally applies the relation's optional
 * output remap as a projection (see {@link #applyRemap(RelNode, Optional)}). Scalar expressions are
 * converted with the {@link ExpressionRexConverter}; aggregate measures are resolved to Calcite
 * aggregate functions through the {@link AggregateFunctionConverter}.
 *
 * <p>NOTE(review): the shared {@link RelBuilder} is stateful (nodes are pushed and built on it), so
 * an instance is presumably not safe for concurrent use.
 */
public class SubstraitRelNodeConverter
extends AbstractRelVisitor<RelNode, RuntimeException> {
    protected final RelDataTypeFactory typeFactory;
    protected final ScalarFunctionConverter scalarFunctionConverter;
    protected final AggregateFunctionConverter aggregateFunctionConverter;
    protected final ExpressionRexConverter expressionRexConverter;
    protected final RelBuilder relBuilder;
    protected final RexBuilder rexBuilder;
    private final TypeConverter typeConverter;

    /**
     * Creates a converter whose scalar, aggregate and window function converters are built from
     * {@code extensions}, using {@link TypeConverter#DEFAULT} for type conversion.
     *
     * @param extensions extension collection supplying the function declarations
     * @param typeFactory Calcite type factory used for every produced type
     * @param relBuilder builder used to construct the Calcite relational expressions
     */
    public SubstraitRelNodeConverter(SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory, RelBuilder relBuilder) {
        this(
                typeFactory,
                relBuilder,
                new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory),
                new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory),
                new WindowFunctionConverter(extensions.windowFunctions(), typeFactory),
                TypeConverter.DEFAULT);
    }

    /**
     * Creates a converter from explicit function converters; an {@link ExpressionRexConverter} is
     * built from the scalar and window converters.
     */
    public SubstraitRelNodeConverter(RelDataTypeFactory typeFactory, RelBuilder relBuilder, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter) {
        this(
                typeFactory,
                relBuilder,
                scalarFunctionConverter,
                aggregateFunctionConverter,
                windowFunctionConverter,
                typeConverter,
                new ExpressionRexConverter(typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter));
    }

    /**
     * Creates a converter from fully explicit collaborators.
     *
     * <p>The given {@code expressionRexConverter} is given a back-reference to this converter via
     * {@code setRelNodeConverter}.
     */
    public SubstraitRelNodeConverter(RelDataTypeFactory typeFactory, RelBuilder relBuilder, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, TypeConverter typeConverter, ExpressionRexConverter expressionRexConverter) {
        this.typeFactory = typeFactory;
        this.typeConverter = typeConverter;
        this.relBuilder = relBuilder;
        this.rexBuilder = new RexBuilder(typeFactory);
        this.scalarFunctionConverter = scalarFunctionConverter;
        this.aggregateFunctionConverter = aggregateFunctionConverter;
        this.expressionRexConverter = expressionRexConverter;
        // NOTE(review): publishes `this` before construction completes (presumably so relations
        // nested inside expressions can be converted back); keep this as the last statement.
        this.expressionRexConverter.setRelNodeConverter(this);
    }

    /**
     * Converts {@code relRoot} to a Calcite {@link RelNode}, using a fresh {@link RelBuilder} whose
     * default schema is the root schema of {@code catalogReader} and the extension collection
     * {@link SqlToSubstrait#EXTENSION_COLLECTION}.
     *
     * @param relRoot Substrait relation to convert
     * @param relOptCluster cluster supplying the type factory
     * @param catalogReader catalog whose root schema resolves table names
     * @param parserConfig parser configuration for the builder's framework config
     * @return the converted Calcite relational expression
     */
    public static RelNode convert(Rel relRoot, RelOptCluster relOptCluster, Prepare.CatalogReader catalogReader, SqlParser.Config parserConfig) {
        // The raw, null list selects the traitDefs(List) overload rather than traitDefs(varargs).
        @SuppressWarnings({"rawtypes", "unchecked"})
        FrameworkConfig config = Frameworks.newConfigBuilder()
                .parserConfig(parserConfig)
                .defaultSchema(catalogReader.getRootSchema().plus())
                .traitDefs((List) null)
                .programs(new Program[0])
                .build();
        RelBuilder builder = RelBuilder.create(config);
        return relRoot.accept(new SubstraitRelNodeConverter(SqlToSubstrait.EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), builder));
    }

    /** Converts a {@link Filter} into a Calcite filter on the converted input. */
    @Override
    public RelNode visit(Filter filter) throws RuntimeException {
        RelNode input = filter.getInput().accept(this);
        RexNode filterCondition = filter.getCondition().accept(this.expressionRexConverter);
        RelNode node = this.relBuilder.push(input).filter(filterCondition).build();
        return this.applyRemap(node, filter.getRemap());
    }

    /** Converts a {@link NamedScan} into a Calcite table scan of the qualified table name. */
    @Override
    public RelNode visit(NamedScan namedScan) throws RuntimeException {
        RelNode node = this.relBuilder.scan(namedScan.getNames()).build();
        return this.applyRemap(node, namedScan.getRemap());
    }

    /** Local file scans are not supported; always throws via {@link #visitFallback(Rel)}. */
    @Override
    public RelNode visit(LocalFiles localFiles) throws RuntimeException {
        return this.visitFallback(localFiles);
    }

    /**
     * Converts a {@link Project}. The Calcite projection emits every input field first, followed by
     * the converted project expressions; the remap then selects the final output columns.
     */
    @Override
    public RelNode visit(Project project) throws RuntimeException {
        RelNode child = project.getInput().accept(this);
        Stream<RexNode> directOutputs = IntStream.range(0, child.getRowType().getFieldCount())
                .mapToObj(fieldIndex -> this.rexBuilder.makeInputRef(child, fieldIndex));
        Stream<RexNode> exprs = project.getExpressions().stream()
                .map(expr -> expr.accept(this.expressionRexConverter));
        List<RexNode> rexExprs = Stream.concat(directOutputs, exprs).collect(Collectors.toList());
        RelNode node = this.relBuilder.push(child).project(rexExprs).build();
        return this.applyRemap(node, project.getRemap());
    }

    /** Converts a {@link Cross} product into an inner join on the literal {@code TRUE}. */
    @Override
    public RelNode visit(Cross cross) throws RuntimeException {
        RelNode left = cross.getLeft().accept(this);
        RelNode right = cross.getRight().accept(this);
        RelNode node = this.relBuilder.push(left).push(right)
                .join(JoinRelType.INNER, this.relBuilder.literal(true))
                .build();
        return this.applyRemap(node, cross.getRemap());
    }

    /**
     * Converts a {@link Join}. A join without a condition joins on the literal {@code TRUE}.
     *
     * @throws UnsupportedOperationException for the {@code UNKNOWN} join type or any join type this
     *     converter does not map
     */
    @Override
    public RelNode visit(Join join) throws RuntimeException {
        RelNode left = join.getLeft().accept(this);
        RelNode right = join.getRight().accept(this);
        RexNode condition = join.getCondition()
                .map(c -> c.accept(this.expressionRexConverter))
                .orElse(this.relBuilder.literal(true));
        JoinRelType joinType;
        switch (join.getJoinType()) {
            case INNER:
                joinType = JoinRelType.INNER;
                break;
            case LEFT:
                joinType = JoinRelType.LEFT;
                break;
            case RIGHT:
                joinType = JoinRelType.RIGHT;
                break;
            case OUTER:
                // Substrait's OUTER join is Calcite's FULL (outer) join.
                joinType = JoinRelType.FULL;
                break;
            case SEMI:
                joinType = JoinRelType.SEMI;
                break;
            case ANTI:
                joinType = JoinRelType.ANTI;
                break;
            case UNKNOWN:
                throw new UnsupportedOperationException("Unknown join type is not supported");
            default:
                throw new UnsupportedOperationException(String.format("Unsupported join type: %s", join.getJoinType()));
        }
        RelNode node = this.relBuilder.push(left).push(right).join(joinType, condition).build();
        return this.applyRemap(node, join.getRemap());
    }

    /**
     * Converts a {@link Set} operation. All converted inputs are pushed onto the builder, and the
     * Calcite set operator then combines that many inputs ({@code all=true} for the multiset / ALL
     * variants).
     *
     * @throws UnsupportedOperationException for the {@code UNKNOWN} operation or any operation this
     *     converter does not map
     */
    @Override
    public RelNode visit(Set set) throws RuntimeException {
        int numInputs = set.getInputs().size();
        for (Rel input : set.getInputs()) {
            this.relBuilder.push(input.accept(this));
        }
        RelBuilder builder;
        switch (set.getSetOp()) {
            case MINUS_PRIMARY:
                builder = this.relBuilder.minus(false, numInputs);
                break;
            case MINUS_MULTISET:
                builder = this.relBuilder.minus(true, numInputs);
                break;
            case INTERSECTION_PRIMARY:
                builder = this.relBuilder.intersect(false, numInputs);
                break;
            case INTERSECTION_MULTISET:
                builder = this.relBuilder.intersect(true, numInputs);
                break;
            case UNION_DISTINCT:
                builder = this.relBuilder.union(false, numInputs);
                break;
            case UNION_ALL:
                builder = this.relBuilder.union(true, numInputs);
                break;
            case UNKNOWN:
                throw new UnsupportedOperationException("Unknown set operation is not supported");
            default:
                throw new UnsupportedOperationException(String.format("Unsupported set operation: %s", set.getSetOp()));
        }
        RelNode node = builder.build();
        return this.applyRemap(node, set.getRemap());
    }

    /**
     * Converts an {@link Aggregate}. Aggregates that {@link PreCalciteAggregateValidator} reports as
     * not directly representable in Calcite are first rewritten by its transformer.
     */
    @Override
    public RelNode visit(Aggregate aggregate) throws RuntimeException {
        Aggregate validAggregate = PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)
                ? aggregate
                : PreCalciteAggregateValidator.PreCalciteAggregateTransformer.transformToValidCalciteAggregate(aggregate);
        RelNode child = validAggregate.getInput().accept(this);
        // One expression list per grouping set; the flattened list is the union of all of them.
        List<List<RexNode>> groupExprLists = validAggregate.getGroupings().stream()
                .map(grouping -> grouping.getExpressions().stream()
                        .map(expr -> expr.accept(this.expressionRexConverter))
                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
        List<RexNode> groupExprs = groupExprLists.stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toList());
        RelBuilder.GroupKey groupKey = this.relBuilder.groupKey(groupExprs, groupExprLists);
        List<AggregateCall> aggregateCalls = validAggregate.getMeasures().stream()
                .map(this::fromMeasure)
                .collect(Collectors.toList());
        RelNode node = this.relBuilder.push(child).aggregate(groupKey, aggregateCalls).build();
        return this.applyRemap(node, validAggregate.getRemap());
    }

    /**
     * Builds a Calcite {@link AggregateCall} for one aggregate measure.
     *
     * @throws IllegalArgumentException if no Calcite operator is bound to the Substrait function,
     *     or the bound operator is not an aggregate function
     * @throws UnsupportedOperationException if an argument or the pre-measure filter is not a plain
     *     input field reference
     */
    private AggregateCall fromMeasure(Aggregate.Measure measure) {
        String functionName = measure.getFunction().declaration().name();
        List<FunctionArg> functionArgs = measure.getFunction().arguments();
        List<RexNode> arguments = IntStream.range(0, functionArgs.size())
                .mapToObj(i -> functionArgs.get(i).accept(measure.getFunction().declaration(), i, this.expressionRexConverter))
                .collect(Collectors.toList());
        SqlOperator operator = this.aggregateFunctionConverter
                .getSqlOperatorFromSubstraitFunc(measure.getFunction().declaration().key(), measure.getFunction().outputType())
                .orElseThrow(() -> new IllegalArgumentException(String.format("Unable to find binding for call %s", functionName)));
        // AggregateCall refers to its arguments by input field index, so each one must be a plain
        // field reference; report anything else clearly instead of failing with a ClassCastException.
        List<Integer> argIndex = new ArrayList<>(arguments.size());
        for (RexNode arg : arguments) {
            this.checkRexInputRefOnly(arg, "argument list", functionName);
            argIndex.add(((RexInputRef) arg).getIndex());
        }
        boolean distinct = Expression.AggregationInvocation.DISTINCT.equals(measure.getFunction().invocation());
        RelDataType returnType = this.typeConverter.toCalcite(this.typeFactory, (TypeExpression) measure.getFunction().getType());
        if (!(operator instanceof SqlAggFunction)) {
            String msg = String.format("Unable to convert non-aggregate operator: %s for substrait aggregate function %s", operator, functionName);
            throw new IllegalArgumentException(msg);
        }
        SqlAggFunction aggFunction = (SqlAggFunction) operator;
        // -1 means "no filter" for AggregateCall.
        int filterArg = -1;
        if (measure.getPreMeasureFilter().isPresent()) {
            RexNode filter = measure.getPreMeasureFilter().get().accept(this.expressionRexConverter);
            this.checkRexInputRefOnly(filter, "filter", functionName);
            filterArg = ((RexInputRef) filter).getIndex();
        }
        RelCollation relCollation = RelCollations.EMPTY;
        if (!measure.getFunction().sort().isEmpty()) {
            List<RelFieldCollation> fieldCollations = measure.getFunction().sort().stream()
                    .map(this::toRelFieldCollation)
                    .collect(Collectors.toList());
            relCollation = RelCollations.of(fieldCollations);
        }
        return AggregateCall.create(aggFunction, distinct, false, false, Collections.emptyList(), argIndex, filterArg, null, relCollation, returnType, null);
    }

    /**
     * Converts a {@link Sort}. An empty sort-field list adds no ordering; in both cases the
     * relation's remap is applied to the result.
     */
    @Override
    public RelNode visit(Sort sort) throws RuntimeException {
        RelNode child = sort.getInput().accept(this);
        List<RelFieldCollation> relFieldCollations = sort.getSortFields().stream()
                .map(this::toRelFieldCollation)
                .collect(Collectors.toList());
        RelBuilder builder = this.relBuilder.push(child);
        if (relFieldCollations.isEmpty()) {
            builder.sort(Collections.<RexNode>emptyList());
        } else {
            builder.sort(RelCollations.of(relFieldCollations));
        }
        // Previously the empty-sort path returned early and silently dropped the remap.
        return this.applyRemap(builder.build(), sort.getRemap());
    }

    /**
     * Converts a {@link Fetch} into a Calcite limit. An absent count is passed to
     * {@link RelBuilder#limit(int, int)} as {@code -1} (no limit).
     *
     * @throws IllegalArgumentException if the offset or count does not fit in an {@code int}
     */
    @Override
    public RelNode visit(Fetch fetch) throws RuntimeException {
        RelNode child = fetch.getInput().accept(this);
        OptionalLong optCount = fetch.getCount();
        long count = optCount.orElse(-1L);
        long offset = fetch.getOffset();
        if (offset > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(String.format("offset is overflowed as an integer: %d", offset));
        }
        if (count > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(String.format("count is overflowed as an integer: %d", count));
        }
        RelNode node = this.relBuilder.push(child).limit((int) offset, (int) count).build();
        return this.applyRemap(node, fetch.getRemap());
    }

    /**
     * Maps a Substrait sort field to a Calcite field collation.
     *
     * @throws UnsupportedOperationException if the sort expression is not a field reference
     * @throws RuntimeException for an unrecognised sort direction
     */
    private RelFieldCollation toRelFieldCollation(Expression.SortField sortField) {
        RexNode rex = sortField.expr().accept(this.expressionRexConverter);
        if (!(rex instanceof RexSlot)) {
            throw new UnsupportedOperationException(String.format("Sort expression %s is not a field reference and is not supported", rex));
        }
        int fieldIndex = ((RexSlot) rex).getIndex();
        Expression.SortDirection sortDirection = sortField.direction();
        RelFieldCollation.Direction fieldDirection = RelFieldCollation.Direction.ASCENDING;
        RelFieldCollation.NullDirection nullDirection = RelFieldCollation.NullDirection.UNSPECIFIED;
        switch (sortDirection) {
            case ASC_NULLS_FIRST:
                nullDirection = RelFieldCollation.NullDirection.FIRST;
                break;
            case ASC_NULLS_LAST:
                nullDirection = RelFieldCollation.NullDirection.LAST;
                break;
            case DESC_NULLS_FIRST:
                nullDirection = RelFieldCollation.NullDirection.FIRST;
                fieldDirection = RelFieldCollation.Direction.DESCENDING;
                break;
            case DESC_NULLS_LAST:
                nullDirection = RelFieldCollation.NullDirection.LAST;
                fieldDirection = RelFieldCollation.Direction.DESCENDING;
                break;
            case CLUSTERED:
                fieldDirection = RelFieldCollation.Direction.CLUSTERED;
                break;
            default:
                throw new RuntimeException(String.format("Unexpected Expression.SortDirection enum: %s !", sortDirection));
        }
        return new RelFieldCollation(fieldIndex, fieldDirection, nullDirection);
    }

    /**
     * Rejects relation types this converter does not handle.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RelNode visitFallback(Rel rel) throws RuntimeException {
        throw new UnsupportedOperationException(String.format("Rel %s of type %s not handled by visitor type %s.", rel, rel.getClass().getCanonicalName(), this.getClass().getCanonicalName()));
    }

    /**
     * Applies {@code remap} to {@code relNode} when present; otherwise returns {@code relNode}
     * unchanged.
     */
    protected RelNode applyRemap(RelNode relNode, Optional<Rel.Remap> remap) {
        if (remap.isPresent()) {
            return this.applyRemap(relNode, remap.get());
        }
        return relNode;
    }

    /**
     * Projects {@code relNode} onto the input fields listed in {@code remap}, in remap order.
     *
     * <p>Fields are referenced by position rather than looked up by name, so row types containing
     * duplicate field names are remapped correctly and each lookup is O(1).
     *
     * @throws IllegalArgumentException if a remap index is outside the input row type
     */
    private RelNode applyRemap(RelNode relNode, Rel.Remap remap) {
        int fieldCount = relNode.getRowType().getFieldCount();
        List<RexNode> rexList = new ArrayList<>(remap.indices().size());
        for (Integer index : remap.indices()) {
            if (index < 0 || index >= fieldCount) {
                throw new IllegalArgumentException(String.format("Remap index %d is out of range for a row type with %d fields", index, fieldCount));
            }
            rexList.add(this.rexBuilder.makeInputRef(relNode, index));
        }
        return this.relBuilder.push(relNode).project(rexList).build();
    }

    /**
     * Ensures {@code rexNode} is a plain input field reference.
     *
     * @param rexNode converted expression to check
     * @param context where the expression appears (used in the error message)
     * @param aggName aggregate function name (used in the error message)
     * @throws UnsupportedOperationException if {@code rexNode} is not a {@link RexInputRef}
     */
    private void checkRexInputRefOnly(RexNode rexNode, String context, String aggName) {
        if (!(rexNode instanceof RexInputRef)) {
            throw new UnsupportedOperationException(String.format("Compound expression %s in %s of agg function %s is not implemented yet.", rexNode, context, aggName));
        }
    }
}

