/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.coral.trino.rel2trino;

import com.google.common.collect.ImmutableList;
import com.linkedin.coral.calcite.$internal.com.google.common.collect.ImmutableCollection;
import com.linkedin.coral.calcite.$internal.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.FunctionReturnTypes;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

public class Calcite2TrinoUDFConverter {
    private Calcite2TrinoUDFConverter() {
    }

    /**
     * Rewrites the row expressions of every node in a relational tree so that selected
     * function calls are replaced by Trino-compatible equivalents.
     *
     * <p>Each visitor first delegates to {@link RelShuttleImpl} (which visits the node's inputs)
     * and then applies a fresh {@link TrinoRexConverter} to the resulting node's expressions.
     *
     * @param calciteNode root of the relational tree to convert
     * @param configs conversion flags handed to every {@link TrinoRexConverter}
     * @return the rewritten relational tree
     */
    public static RelNode convertRel(RelNode calciteNode, final Map<String, Boolean> configs) {
        RelShuttleImpl converter = new RelShuttleImpl() {

            @Override
            public RelNode visit(LogicalProject project) {
                return super.visit(project).accept(this.getTrinoRexConverter(project));
            }

            @Override
            public RelNode visit(LogicalFilter inputFilter) {
                return super.visit(inputFilter).accept(this.getTrinoRexConverter(inputFilter));
            }

            @Override
            public RelNode visit(LogicalAggregate aggregate) {
                return super.visit(aggregate).accept(this.getTrinoRexConverter(aggregate));
            }

            @Override
            public RelNode visit(LogicalMatch match) {
                return super.visit(match).accept(this.getTrinoRexConverter(match));
            }

            @Override
            public RelNode visit(TableScan scan) {
                return super.visit(scan).accept(this.getTrinoRexConverter(scan));
            }

            @Override
            public RelNode visit(TableFunctionScan scan) {
                return super.visit(scan).accept(this.getTrinoRexConverter(scan));
            }

            @Override
            public RelNode visit(LogicalValues values) {
                return super.visit(values).accept(this.getTrinoRexConverter(values));
            }

            @Override
            public RelNode visit(LogicalJoin join) {
                return super.visit(join).accept(this.getTrinoRexConverter(join));
            }

            @Override
            public RelNode visit(LogicalCorrelate correlate) {
                return super.visit(correlate).accept(this.getTrinoRexConverter(correlate));
            }

            @Override
            public RelNode visit(LogicalUnion union) {
                return super.visit(union).accept(this.getTrinoRexConverter(union));
            }

            @Override
            public RelNode visit(LogicalIntersect intersect) {
                return super.visit(intersect).accept(this.getTrinoRexConverter(intersect));
            }

            @Override
            public RelNode visit(LogicalMinus minus) {
                return super.visit(minus).accept(this.getTrinoRexConverter(minus));
            }

            @Override
            public RelNode visit(LogicalSort sort) {
                return super.visit(sort).accept(this.getTrinoRexConverter(sort));
            }

            @Override
            public RelNode visit(LogicalExchange exchange) {
                return super.visit(exchange).accept(this.getTrinoRexConverter(exchange));
            }

            @Override
            public RelNode visit(RelNode other) {
                return super.visit(other).accept(this.getTrinoRexConverter(other));
            }

            /** Creates an expression converter bound to {@code node}'s cluster. */
            private TrinoRexConverter getTrinoRexConverter(RelNode node) {
                return new TrinoRexConverter(node, configs);
            }
        };
        return calciteNode.accept(converter);
    }

    /**
     * Builds a user-defined SQL operator that renders as a plain call to {@code functionName}.
     *
     * @param functionName unqualified function name to emit
     * @param typeInference return-type inference for the operator
     * @return an operator with no operand type inference/checking
     */
    private static SqlOperator createSqlOperatorOfFunction(String functionName, SqlReturnTypeInference typeInference) {
        // Single-component identifier; replaces an uncompilable decompiled
        // (List<String>) ImmutableList.of((Object) functionName) cast.
        SqlIdentifier sqlIdentifier = new SqlIdentifier(functionName, SqlParserPos.ZERO);
        return new SqlUserDefinedFunction(sqlIdentifier, typeInference, null, null, null, null);
    }

    /**
     * {@link RexShuttle} that rewrites individual function calls into Trino-compatible forms.
     *
     * <p>Handled calls: {@link GenericProjectFunction} calls, {@code from_utc_timestamp},
     * {@code from_unixtime}, {@code CAST}, {@code substr}, and {@code concat}. Any call whose
     * rewrite is not applicable falls back to {@link RexShuttle#visitCall}, which converts
     * only the operands.
     */
    public static class TrinoRexConverter
    extends RexShuttle {
        private final RexBuilder rexBuilder;
        private final RelDataTypeFactory typeFactory;
        private final RelNode node;
        // Stored for callers/subclasses; not read by any rewrite in this class.
        private final Map<String, Boolean> configs;

        /**
         * @param node relational node whose expressions are being rewritten; supplies the
         *     {@link RexBuilder} and type factory
         * @param configs conversion flags
         */
        public TrinoRexConverter(RelNode node, Map<String, Boolean> configs) {
            this.rexBuilder = node.getCluster().getRexBuilder();
            this.typeFactory = node.getCluster().getTypeFactory();
            this.configs = configs;
            this.node = node;
        }

        @Override
        public RexNode visitCall(RexCall call) {
            if (call.getOperator() instanceof GenericProjectFunction) {
                return GenericProjectToTrinoConverter.convertGenericProject(this.rexBuilder, call, this.node);
            }
            String operatorName = call.getOperator().getName();
            Optional<RexNode> modifiedCall = Optional.empty();
            if (operatorName.equalsIgnoreCase("from_utc_timestamp")) {
                modifiedCall = this.visitFromUtcTimestampCall(call);
            } else if (operatorName.equalsIgnoreCase("from_unixtime")) {
                modifiedCall = this.visitFromUnixtime(call);
            } else if (operatorName.equalsIgnoreCase("cast")) {
                modifiedCall = this.visitCast(call);
            } else if (operatorName.equalsIgnoreCase("substr")) {
                modifiedCall = this.visitSubstring(call);
            } else if (operatorName.equalsIgnoreCase("concat")) {
                modifiedCall = this.visitConcat(call);
            }
            if (modifiedCall.isPresent()) {
                return modifiedCall.get();
            }
            return super.visitCall(call);
        }

        /** Returns true when {@code type} is VARCHAR or CHAR. */
        private static boolean isCharacterType(RelDataType type) {
            SqlTypeName typeName = type.getSqlTypeName();
            return typeName == SqlTypeName.VARCHAR || typeName == SqlTypeName.CHAR;
        }

        /**
         * Rewrites {@code concat(a, b, ...)} so that every non-character operand is cast to VARCHAR.
         * Always produces a rewrite.
         */
        private Optional<RexNode> visitConcat(RexCall call) {
            List<RexNode> convertedOperands = this.visitList(call.getOperands(), (boolean[]) null);
            RelDataType varcharType = this.typeFactory.createSqlType(SqlTypeName.VARCHAR);
            List<RexNode> castOperands = new ArrayList<>(convertedOperands.size());
            for (RexNode operand : convertedOperands) {
                castOperands.add(isCharacterType(operand.getType())
                        ? operand
                        : this.rexBuilder.makeCast(varcharType, operand));
            }
            return Optional.of(this.rexBuilder.makeCall(call.getOperator(), castOperands));
        }

        /**
         * Rewrites {@code from_unixtime(x [, fmt])} to
         * {@code format_datetime(from_unixtime(x), fmt)}, defaulting {@code fmt} to
         * {@code 'yyyy-MM-dd HH:mm:ss'}. Returns empty for any other operand count.
         */
        private Optional<RexNode> visitFromUnixtime(RexCall call) {
            List<RexNode> convertedOperands = this.visitList(call.getOperands(), (boolean[]) null);
            int operandCount = convertedOperands.size();
            if (operandCount != 1 && operandCount != 2) {
                return Optional.empty();
            }
            SqlOperator formatDatetime = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("format_datetime", FunctionReturnTypes.STRING);
            SqlOperator fromUnixtime = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("from_unixtime", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
            // Use the *converted* operands so nested calls are also translated to Trino.
            RexNode timestamp = this.rexBuilder.makeCall(fromUnixtime, convertedOperands.get(0));
            RexNode format = operandCount == 1
                    ? this.rexBuilder.makeLiteral("yyyy-MM-dd HH:mm:ss")
                    : convertedOperands.get(1);
            return Optional.of(this.rexBuilder.makeCall(formatDatetime, timestamp, format));
        }

        /**
         * Rewrites {@code from_utc_timestamp(value, tz)} to
         * {@code CAST(at_timezone(<utc timestamp of value>, $canonicalize_hive_timezone_id(tz)) AS TIMESTAMP(3))}.
         *
         * <p>The UTC timestamp is derived from the input type:
         * <ul>
         *   <li>integral types: {@code from_unixtime_nanos(CAST(value AS BIGINT) * 1000000)},
         *       i.e. the value is scaled as milliseconds;</li>
         *   <li>DOUBLE/FLOAT/DECIMAL: {@code from_unixtime(CAST(value AS DOUBLE))};</li>
         *   <li>TIMESTAMP/DATE: {@code from_unixtime(to_unixtime(with_timezone(value, 'UTC')))}.</li>
         * </ul>
         * Returns empty for other input types or when the call does not have exactly two operands.
         */
        private Optional<RexNode> visitFromUtcTimestampCall(RexCall call) {
            if (call.getOperands().size() != 2) {
                // Previously get(1) threw IndexOutOfBoundsException here; fall back instead.
                return Optional.empty();
            }
            SqlTypeName inputTypeName = call.getOperands().get(0).getType().getSqlTypeName();
            if (inputTypeName == null) {
                return Optional.empty();
            }
            List<RexNode> convertedOperands = this.visitList(call.getOperands(), (boolean[]) null);
            RexNode sourceValue = convertedOperands.get(0);
            RexNode timezone = convertedOperands.get(1);

            RexNode utcTimestamp;
            switch (inputTypeName) {
                case BIGINT:
                case INTEGER:
                case SMALLINT:
                case TINYINT: {
                    SqlOperator fromUnixtimeNanos = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("from_unixtime_nanos", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                    RelDataType bigintType = this.typeFactory.createSqlType(SqlTypeName.BIGINT);
                    RexNode nanos = this.rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
                            this.rexBuilder.makeCast(bigintType, sourceValue),
                            this.rexBuilder.makeBigintLiteral(BigDecimal.valueOf(1000000L)));
                    utcTimestamp = this.rexBuilder.makeCall(fromUnixtimeNanos, nanos);
                    break;
                }
                case DOUBLE:
                case FLOAT:
                case DECIMAL: {
                    SqlOperator fromUnixTime = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("from_unixtime", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                    RelDataType doubleType = this.typeFactory.createSqlType(SqlTypeName.DOUBLE);
                    utcTimestamp = this.rexBuilder.makeCall(fromUnixTime, this.rexBuilder.makeCast(doubleType, sourceValue));
                    break;
                }
                case TIMESTAMP:
                case DATE: {
                    SqlOperator fromUnixTime = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("from_unixtime", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                    SqlOperator toUnixTime = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("to_unixtime", ReturnTypes.explicit(SqlTypeName.DOUBLE));
                    SqlOperator withTimeZone = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("with_timezone", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                    RexNode utcZoned = this.rexBuilder.makeCall(withTimeZone, sourceValue, this.rexBuilder.makeLiteral("UTC"));
                    utcTimestamp = this.rexBuilder.makeCall(fromUnixTime, this.rexBuilder.makeCall(toUnixTime, utcZoned));
                    break;
                }
                default:
                    return Optional.empty();
            }

            SqlOperator atTimeZone = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("at_timezone", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
            SqlOperator canonicalizeHiveTimezoneId = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("$canonicalize_hive_timezone_id", ReturnTypes.explicit(SqlTypeName.VARCHAR));
            RelDataType targetType = this.typeFactory.createSqlType(SqlTypeName.TIMESTAMP, 3);
            RexNode shifted = this.rexBuilder.makeCall(atTimeZone, utcTimestamp,
                    this.rexBuilder.makeCall(canonicalizeHiveTimezoneId, timezone));
            return Optional.of(this.rexBuilder.makeCast(targetType, shifted));
        }

        /**
         * Rewrites {@code substr(x, ...)} so that a non-character first operand is cast to VARCHAR.
         * Returns empty when the call has no operands or the first operand is already VARCHAR/CHAR.
         */
        private Optional<RexNode> visitSubstring(RexCall call) {
            List<RexNode> convertedOperands = this.visitList(call.getOperands(), (boolean[]) null);
            if (convertedOperands.isEmpty()) {
                return Optional.empty();
            }
            RexNode inputOperand = convertedOperands.get(0);
            if (isCharacterType(inputOperand.getType())) {
                return Optional.empty();
            }
            List<RexNode> operands = new ArrayList<>(convertedOperands.size());
            operands.add(this.rexBuilder.makeCast(this.typeFactory.createSqlType(SqlTypeName.VARCHAR), inputOperand));
            operands.addAll(convertedOperands.subList(1, convertedOperands.size()));
            return Optional.of(this.rexBuilder.makeCall(call.getOperator(), operands));
        }

        /**
         * Rewrites two CAST shapes.
         * <ul>
         *   <li>TIMESTAMP to DECIMAL becomes
         *       {@code CAST(to_unixtime(with_timezone(x, 'UTC')) AS <decimal type>)}.</li>
         *   <li>BINARY/VARBINARY to CHAR/VARCHAR becomes {@code from_utf8(x)}.</li>
         * </ul>
         * Returns empty for any other cast, or when the operator is not of kind CAST.
         */
        private Optional<RexNode> visitCast(RexCall call) {
            if (call.getOperator().getKind() != SqlKind.CAST) {
                return Optional.empty();
            }
            List<RexNode> convertedOperands = this.visitList(call.getOperands(), (boolean[]) null);
            RexNode leftOperand = convertedOperands.get(0);
            SqlTypeName targetTypeName = call.getType().getSqlTypeName();
            SqlTypeName sourceTypeName = leftOperand.getType().getSqlTypeName();

            if (targetTypeName == SqlTypeName.DECIMAL && sourceTypeName == SqlTypeName.TIMESTAMP) {
                SqlOperator trinoToUnixTime = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("to_unixtime", ReturnTypes.explicit(SqlTypeName.DOUBLE));
                SqlOperator trinoWithTimeZone = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("with_timezone", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                RexNode utcZoned = this.rexBuilder.makeCall(trinoWithTimeZone, leftOperand, this.rexBuilder.makeLiteral("UTC"));
                return Optional.of(this.rexBuilder.makeCast(call.getType(), this.rexBuilder.makeCall(trinoToUnixTime, utcZoned)));
            }

            boolean targetIsCharacter = isCharacterType(call.getType());
            boolean sourceIsBinary = sourceTypeName == SqlTypeName.VARBINARY || sourceTypeName == SqlTypeName.BINARY;
            if (targetIsCharacter && sourceIsBinary) {
                SqlOperator fromUTF8 = Calcite2TrinoUDFConverter.createSqlOperatorOfFunction("from_utf8", ReturnTypes.explicit(SqlTypeName.VARCHAR));
                return Optional.of(this.rexBuilder.makeCall(fromUTF8, leftOperand));
            }
            return Optional.empty();
        }
    }
}

