/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.expression.proto;

import com.google.protobuf.GeneratedMessageV3;
import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.WindowBound;
import io.substrait.expression.WindowFunctionInvocation;
import io.substrait.expression.proto.FunctionCollector;
import io.substrait.proto.Expression;
import io.substrait.proto.FunctionArgument;
import io.substrait.proto.Rel;
import io.substrait.proto.SortField;
import io.substrait.relation.RelVisitor;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ExpressionProtoConverter
implements ExpressionVisitor<io.substrait.proto.Expression, RuntimeException> {
    static final Logger logger = LoggerFactory.getLogger(ExpressionProtoConverter.class);
    private final FunctionCollector functionCollector;
    private final RelVisitor<Rel, RuntimeException> relVisitor;

    public ExpressionProtoConverter(FunctionCollector functionCollector, RelVisitor<Rel, RuntimeException> relVisitor) {
        this.functionCollector = functionCollector;
        this.relVisitor = relVisitor;
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.NullLiteral expr) {
        return this.lit(bldr -> bldr.setNull(expr.type().accept(TypeProtoConverter.INSTANCE)));
    }

    private io.substrait.proto.Expression lit(Consumer<Expression.Literal.Builder> consumer) {
        Expression.Literal.Builder builder = Expression.Literal.newBuilder();
        consumer.accept(builder);
        return io.substrait.proto.Expression.newBuilder().setLiteral(builder).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.BoolLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setBoolean(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.I8Literal expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setI8(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.I16Literal expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setI16(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.I32Literal expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setI32(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.I64Literal expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setI64(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.FP32Literal expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setFp32(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.FP64Literal expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setFp64(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.StrLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setString(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.BinaryLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setBinary(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.TimeLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setTime(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.DateLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setDate(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.TimestampLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setTimestamp(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.TimestampTZLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setTimestampTz(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.IntervalYearLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setIntervalYearToMonth(Expression.Literal.IntervalYearToMonth.newBuilder().setYears(expr.years()).setMonths(expr.months())));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.IntervalDayLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setIntervalDayToSecond(Expression.Literal.IntervalDayToSecond.newBuilder().setDays(expr.days()).setSeconds(expr.seconds())));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.UUIDLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setUuid(expr.toBytes()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.FixedCharLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setFixedChar(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.VarCharLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setVarChar(Expression.Literal.VarChar.newBuilder().setValue(expr.value()).setLength(expr.length())));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.FixedBinaryLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setFixedBinary(expr.value()));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.DecimalLiteral expr) {
        return this.lit(bldr -> bldr.setNullable(expr.nullable()).setDecimal(Expression.Literal.Decimal.newBuilder().setValue(expr.value()).setPrecision(expr.precision()).setScale(expr.scale())));
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.MapLiteral expr) {
        return this.lit(bldr -> {
            List keyValues = expr.values().entrySet().stream().map(e -> {
                Expression.Literal key = this.toLiteral((Expression)e.getKey());
                Expression.Literal value = this.toLiteral((Expression)e.getValue());
                return Expression.Literal.Map.KeyValue.newBuilder().setKey(key).setValue(value).build();
            }).collect(Collectors.toList());
            bldr.setNullable(expr.nullable()).setMap(Expression.Literal.Map.newBuilder().addAllKeyValues(keyValues));
        });
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.ListLiteral expr) {
        return this.lit(bldr -> {
            List values = expr.values().stream().map(this::toLiteral).collect(Collectors.toList());
            bldr.setNullable(expr.nullable()).setList(Expression.Literal.List.newBuilder().addAllValues(values));
        });
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.StructLiteral expr) {
        return this.lit(bldr -> {
            List values = expr.fields().stream().map(this::toLiteral).collect(Collectors.toList());
            bldr.setNullable(expr.nullable()).setStruct(Expression.Literal.Struct.newBuilder().addAllFields(values));
        });
    }

    private Expression.Literal toLiteral(Expression expression) {
        io.substrait.proto.Expression e = expression.accept(this);
        assert (e.getRexTypeCase() == Expression.RexTypeCase.LITERAL);
        return e.getLiteral();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.Switch expr) {
        List clauses = expr.switchClauses().stream().map(s -> Expression.SwitchExpression.IfValue.newBuilder().setIf(this.toLiteral(s.condition())).setThen(s.then().accept(this)).build()).collect(Collectors.toList());
        return io.substrait.proto.Expression.newBuilder().setSwitchExpression(Expression.SwitchExpression.newBuilder().addAllIfs(clauses).setElse(expr.defaultClause().accept(this))).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.IfThen expr) {
        List clauses = expr.ifClauses().stream().map(s -> Expression.IfThen.IfClause.newBuilder().setIf(s.condition().accept(this)).setThen(s.then().accept(this)).build()).collect(Collectors.toList());
        return io.substrait.proto.Expression.newBuilder().setIfThen(Expression.IfThen.newBuilder().addAllIfs(clauses).setElse(expr.elseClause().accept(this))).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.ScalarFunctionInvocation expr) {
        FunctionArg.FuncArgVisitor<FunctionArgument, RuntimeException> argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this);
        return io.substrait.proto.Expression.newBuilder().setScalarFunction(Expression.ScalarFunction.newBuilder().setOutputType(expr.getType().accept(TypeProtoConverter.INSTANCE)).setFunctionReference(this.functionCollector.getFunctionReference(expr.declaration())).addAllArguments(expr.arguments().stream().map(a -> (FunctionArgument)a.accept(expr.declaration(), 0, argVisitor)).collect(Collectors.toList()))).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.Cast expr) {
        return io.substrait.proto.Expression.newBuilder().setCast(Expression.Cast.newBuilder().setInput(expr.input().accept(this)).setType(expr.getType().accept(TypeProtoConverter.INSTANCE))).build();
    }

    private io.substrait.proto.Expression from(Expression expr) {
        return expr.accept(this);
    }

    private List<io.substrait.proto.Expression> from(List<Expression> expr) {
        return expr.stream().map(this::from).collect(Collectors.toList());
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.SingleOrList expr) throws RuntimeException {
        return io.substrait.proto.Expression.newBuilder().setSingularOrList(Expression.SingularOrList.newBuilder().setValue(expr.condition().accept(this)).addAllOptions(this.from(expr.options()))).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.MultiOrList expr) throws RuntimeException {
        return io.substrait.proto.Expression.newBuilder().setMultiOrList(Expression.MultiOrList.newBuilder().addAllValue(this.from(expr.conditions())).addAllOptions(expr.optionCombinations().stream().map(r -> Expression.MultiOrList.Record.newBuilder().addAllFields(this.from(r.values())).build()).collect(Collectors.toList()))).build();
    }

    @Override
    public io.substrait.proto.Expression visit(FieldReference expr) {
        Expression.ReferenceSegment top = null;
        Expression.ReferenceSegment seg = null;
        for (FieldReference.ReferenceSegment segment : expr.segments()) {
            Expression.ReferenceSegment.Builder protoSegment;
            GeneratedMessageV3.Builder bldr;
            if (segment instanceof FieldReference.StructField) {
                FieldReference.StructField f = (FieldReference.StructField)segment;
                bldr = Expression.ReferenceSegment.StructField.newBuilder().setField(f.offset());
                if (seg != null) {
                    bldr.setChild(seg);
                }
                protoSegment = Expression.ReferenceSegment.newBuilder().setStructField((Expression.ReferenceSegment.StructField.Builder)bldr);
            } else if (segment instanceof FieldReference.ListElement) {
                FieldReference.ListElement f = (FieldReference.ListElement)segment;
                bldr = Expression.ReferenceSegment.ListElement.newBuilder().setOffset(f.offset());
                if (seg != null) {
                    bldr.setChild(seg);
                }
                protoSegment = Expression.ReferenceSegment.newBuilder().setListElement((Expression.ReferenceSegment.ListElement.Builder)bldr);
            } else if (segment instanceof FieldReference.MapKey) {
                FieldReference.MapKey f = (FieldReference.MapKey)segment;
                bldr = Expression.ReferenceSegment.MapKey.newBuilder().setMapKey(this.toLiteral(f.key()));
                if (seg != null) {
                    bldr.setChild(seg);
                }
                protoSegment = Expression.ReferenceSegment.newBuilder().setMapKey((Expression.ReferenceSegment.MapKey.Builder)bldr);
            } else {
                throw new IllegalArgumentException("Unhandled type: " + segment);
            }
            Expression.ReferenceSegment builtSegment = protoSegment.build();
            if (top == null) {
                top = builtSegment;
            }
            seg = builtSegment;
        }
        Expression.FieldReference.Builder out = Expression.FieldReference.newBuilder().setDirectReference(top);
        if (expr.inputExpression().isPresent()) {
            out.setExpression(this.from(expr.inputExpression().get()));
        } else if (expr.outerReferenceStepsOut().isPresent()) {
            out.setOuterReference(Expression.FieldReference.OuterReference.newBuilder().setStepsOut(expr.outerReferenceStepsOut().get()));
        } else {
            out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
        }
        return io.substrait.proto.Expression.newBuilder().setSelection(out).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.SetPredicate expr) throws RuntimeException {
        return io.substrait.proto.Expression.newBuilder().setSubquery(Expression.Subquery.newBuilder().setSetPredicate(Expression.Subquery.SetPredicate.newBuilder().setPredicateOp(expr.predicateOp().toProto()).setTuples(expr.tuples().accept(this.relVisitor)).build()).build()).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.ScalarSubquery expr) throws RuntimeException {
        return io.substrait.proto.Expression.newBuilder().setSubquery(Expression.Subquery.newBuilder().setScalar(Expression.Subquery.Scalar.newBuilder().setInput(expr.input().accept(this.relVisitor)).build()).build()).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.InPredicate expr) throws RuntimeException {
        return io.substrait.proto.Expression.newBuilder().setSubquery(Expression.Subquery.newBuilder().setInPredicate(Expression.Subquery.InPredicate.newBuilder().setHaystack(expr.haystack().accept(this.relVisitor)).addAllNeedles(this.from(expr.needles())).build()).build()).build();
    }

    @Override
    public io.substrait.proto.Expression visit(Expression.Window expr) throws RuntimeException {
        int funcReference;
        List partExps = expr.partitionBy().stream().map(e -> e.accept(this)).collect(Collectors.toList());
        Expression.WindowFunction.Builder builder = Expression.WindowFunction.newBuilder();
        if (expr.hasNormalAggregateFunction()) {
            AggregateFunctionInvocation aggMeasureFunc = expr.aggregateFunction().getFunction();
            funcReference = this.functionCollector.getFunctionReference(aggMeasureFunc.declaration());
            FunctionArg.FuncArgVisitor<FunctionArgument, RuntimeException> argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this);
            List args = aggMeasureFunc.arguments().stream().map(a -> (FunctionArgument)a.accept(aggMeasureFunc.declaration(), 0, argVisitor)).collect(Collectors.toList());
            int ordinal = aggMeasureFunc.aggregationPhase().ordinal();
            builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
        } else {
            WindowFunctionInvocation windowFunc = expr.windowFunction().getFunction();
            funcReference = this.functionCollector.getFunctionReference(windowFunc.declaration());
            int ordinal = windowFunc.aggregationPhase().ordinal();
            FunctionArg.FuncArgVisitor<FunctionArgument, RuntimeException> argVisitor = FunctionArg.toProto(TypeProtoConverter.INSTANCE, this);
            List args = windowFunc.arguments().stream().map(a -> (FunctionArgument)a.accept(windowFunc.declaration(), 0, argVisitor)).collect(Collectors.toList());
            builder.setFunctionReference(funcReference).setPhaseValue(ordinal).addAllArguments(args);
        }
        List sortFields = expr.orderBy().stream().map(s -> SortField.newBuilder().setDirection(s.direction().toProto()).setExpr(s.expr().accept(this)).build()).collect(Collectors.toList());
        Expression.WindowFunction.Bound upperBound = this.toBound(expr.upperBound());
        Expression.WindowFunction.Bound lowerBound = this.toBound(expr.lowerBound());
        return io.substrait.proto.Expression.newBuilder().setWindowFunction(builder.addAllPartitions(partExps).addAllSorts(sortFields).setLowerBound(lowerBound).setUpperBound(upperBound).build()).build();
    }

    private Expression.WindowFunction.Bound toBound(WindowBound windowBound) {
        WindowBound.BoundedKind boundedKind = windowBound.boundedKind();
        Expression.WindowFunction.Bound bound = null;
        switch (boundedKind) {
            case CURRENT_ROW: {
                bound = Expression.WindowFunction.Bound.newBuilder().setCurrentRow(Expression.WindowFunction.Bound.CurrentRow.getDefaultInstance()).build();
                break;
            }
            case BOUNDED: {
                WindowBound.BoundedWindowBound boundedWindowBound = (WindowBound.BoundedWindowBound)windowBound;
                Expression offset = boundedWindowBound.offset();
                boolean isPreceding = boundedWindowBound.direction() == WindowBound.Direction.PRECEDING;
                Expression.I32Literal offsetLiteral = (Expression.I32Literal)offset;
                int offsetVal = offsetLiteral.value();
                Expression.WindowFunction.Bound.Unbounded boundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance();
                if (isPreceding) {
                    Expression.WindowFunction.Bound.Preceding offsetProto = Expression.WindowFunction.Bound.Preceding.newBuilder().setOffset(offsetVal).build();
                    bound = Expression.WindowFunction.Bound.newBuilder().setPreceding(offsetProto).build();
                    break;
                }
                Expression.WindowFunction.Bound.Following offsetProto = Expression.WindowFunction.Bound.Following.newBuilder().setOffset(offsetVal).build();
                bound = Expression.WindowFunction.Bound.newBuilder().setFollowing(offsetProto).build();
                break;
            }
            case UNBOUNDED: {
                WindowBound.UnboundedWindowBound unboundedWindowBound = (WindowBound.UnboundedWindowBound)windowBound;
                boolean isPreceding = unboundedWindowBound.direction() == WindowBound.Direction.PRECEDING;
                Expression.WindowFunction.Bound.Unbounded unboundedProto = Expression.WindowFunction.Bound.Unbounded.getDefaultInstance();
                if (isPreceding) {
                    Expression.WindowFunction.Bound.Preceding preceding = Expression.WindowFunction.Bound.Preceding.newBuilder().build();
                    bound = Expression.WindowFunction.Bound.newBuilder().setUnbounded(unboundedProto).setPreceding(preceding).build();
                    break;
                }
                Expression.WindowFunction.Bound.Following following = Expression.WindowFunction.Bound.Following.newBuilder().build();
                bound = Expression.WindowFunction.Bound.newBuilder().setUnbounded(unboundedProto).setFollowing(following).build();
                break;
            }
            default: {
                throw new RuntimeException(String.format("Unexpected Expression.WindowFunction.Bound enum:%s", new Object[]{boundedKind}));
            }
        }
        return bound;
    }
}

