/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.coral.trino.rel2trino;

import com.google.common.collect.ImmutableMap;
import com.linkedin.coral.com.google.common.base.CaseFormat;
import com.linkedin.coral.com.google.common.base.Converter;
import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.com.google.common.collect.ImmutableMultimap;
import com.linkedin.coral.com.google.common.collect.Multimap;
import com.linkedin.coral.common.functions.FunctionReturnTypes;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.trino.rel2trino.CalciteTrinoUDFMap;
import com.linkedin.coral.trino.rel2trino.TrinoTryCastFunction;
import com.linkedin.coral.trino.rel2trino.UDFMapUtils;
import com.linkedin.coral.trino.rel2trino.UDFTransformer;
import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttle;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlMapValueConstructor;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

/**
 * Rewrites the row expressions of a Calcite {@link RelNode} plan so that function calls use the
 * names, signatures and return types expected by Trino.
 *
 * <p>Every relational operator in the plan is visited (inputs first) and its expressions are passed
 * through a {@link TrinoRexConverter}.
 */
public final class Calcite2TrinoUDFConverter {
    private Calcite2TrinoUDFConverter() {
        // Static utility class; not instantiable.
    }

    /**
     * Converts every row expression in {@code calciteNode} and in all of its inputs.
     *
     * @param calciteNode root of the Calcite plan to convert
     * @param configs translation switches; {@link TrinoRexConverter} reads
     *     {@code AVOID_TRANSFORM_TO_DATE_UDF} and {@code CAST_DATE_ADD_TO_STRING}. A missing key or a
     *     {@code null} value is treated as {@code false}.
     * @return the plan whose expressions have been rewritten for Trino
     */
    public static RelNode convertRel(RelNode calciteNode, final Map<String, Boolean> configs) {
        RelShuttleImpl converter = new RelShuttleImpl() {

            @Override
            public RelNode visit(LogicalProject project) {
                return convertExpressions(super.visit(project), project);
            }

            @Override
            public RelNode visit(LogicalFilter inputFilter) {
                return convertExpressions(super.visit(inputFilter), inputFilter);
            }

            @Override
            public RelNode visit(LogicalAggregate aggregate) {
                return convertExpressions(super.visit(aggregate), aggregate);
            }

            @Override
            public RelNode visit(LogicalMatch match) {
                return convertExpressions(super.visit(match), match);
            }

            @Override
            public RelNode visit(TableScan scan) {
                return convertExpressions(super.visit(scan), scan);
            }

            @Override
            public RelNode visit(TableFunctionScan scan) {
                return convertExpressions(super.visit(scan), scan);
            }

            @Override
            public RelNode visit(LogicalValues values) {
                return convertExpressions(super.visit(values), values);
            }

            @Override
            public RelNode visit(LogicalJoin join) {
                return convertExpressions(super.visit(join), join);
            }

            @Override
            public RelNode visit(LogicalCorrelate correlate) {
                return convertExpressions(super.visit(correlate), correlate);
            }

            @Override
            public RelNode visit(LogicalUnion union) {
                return convertExpressions(super.visit(union), union);
            }

            @Override
            public RelNode visit(LogicalIntersect intersect) {
                return convertExpressions(super.visit(intersect), intersect);
            }

            @Override
            public RelNode visit(LogicalMinus minus) {
                return convertExpressions(super.visit(minus), minus);
            }

            @Override
            public RelNode visit(LogicalSort sort) {
                return convertExpressions(super.visit(sort), sort);
            }

            @Override
            public RelNode visit(LogicalExchange exchange) {
                return convertExpressions(super.visit(exchange), exchange);
            }

            @Override
            public RelNode visit(RelNode other) {
                return convertExpressions(super.visit(other), other);
            }

            /**
             * Applies a {@link TrinoRexConverter} built for {@code original} to the expressions of
             * {@code visited}, which is the node returned after its inputs were converted.
             */
            private RelNode convertExpressions(RelNode visited, RelNode original) {
                return visited.accept(new TrinoRexConverter(original, configs));
            }
        };
        return calciteNode.accept(converter);
    }

    /**
     * {@link RexShuttle} that rewrites Calcite/Hive function calls into their Trino form.
     *
     * <p>Some functions get dedicated rewrites, for example {@code from_utc_timestamp},
     * {@code from_unixtime}, {@code collect_list}/{@code collect_set}, {@code substr} and
     * {@code current_timestamp}. Other functions go through {@link CalciteTrinoUDFMap}. Unregistered
     * {@code com.linkedin...} UDFs are renamed to their lower_underscore class name.
     */
    public static class TrinoRexConverter
    extends RexShuttle {
        /**
         * Maps the type family of the left operand of an {@code =} to the right-operand families for
         * which the comparison is rewritten as {@code TRY_CAST(left AS right_type) = right}.
         */
        private static final Multimap<SqlTypeFamily, SqlTypeFamily> SUPPORTED_TYPE_CAST_MAP =
            ImmutableMultimap.<SqlTypeFamily, SqlTypeFamily>builder()
                .putAll(SqlTypeFamily.CHARACTER, SqlTypeFamily.NUMERIC, SqlTypeFamily.BOOLEAN)
                .build();

        /** Config key: when true, {@code to_date} is not rewritten through {@link CalciteTrinoUDFMap}. */
        private static final String AVOID_TRANSFORM_TO_DATE_UDF = "AVOID_TRANSFORM_TO_DATE_UDF";

        /** Config key: when true, the result of {@code date_add} is cast to VARCHAR. */
        private static final String CAST_DATE_ADD_TO_STRING = "CAST_DATE_ADD_TO_STRING";

        private final RexBuilder rexBuilder;
        private final RelDataTypeFactory typeFactory;
        private final RelNode node;
        private final Map<String, Boolean> configs;

        /**
         * Lower-cased operator names whose result is cast to the mapped type after conversion.
         * Built once per converter rather than on every call.
         */
        private final Map<String, RelDataType> returnTypeAdjustments;

        /**
         * Creates a converter for the expressions of {@code node}.
         *
         * @param node relational node whose cluster supplies the {@link RexBuilder} and type factory
         * @param configs translation switches (see {@link Calcite2TrinoUDFConverter#convertRel})
         */
        public TrinoRexConverter(RelNode node, Map<String, Boolean> configs) {
            this.rexBuilder = node.getCluster().getRexBuilder();
            this.typeFactory = node.getCluster().getTypeFactory();
            this.configs = configs;
            this.node = node;
            RelDataType integerType = this.typeFactory.createSqlType(SqlTypeName.INTEGER);
            RelDataType bigintType = this.typeFactory.createSqlType(SqlTypeName.BIGINT);
            this.returnTypeAdjustments = ImmutableMap.of(
                "date_diff", integerType,
                "cardinality", integerType,
                "ceil", bigintType,
                "ceiling", bigintType,
                "floor", bigintType);
        }

        @Override
        public RexNode visitCall(RexCall call) {
            SqlOperator operator = call.getOperator();
            if (operator instanceof GenericProjectFunction) {
                return GenericProjectToTrinoConverter.convertGenericProject(rexBuilder, call, node);
            }
            if (operator instanceof SqlMapValueConstructor) {
                return convertMapValueConstructor(rexBuilder, call);
            }

            String operatorName = operator.getName();
            Optional<RexNode> specialized = visitSpecialCall(operatorName, call);
            if (specialized.isPresent()) {
                return specialized.get();
            }

            UDFTransformer transformer = CalciteTrinoUDFMap.getUDFTransformer(operatorName, call.getOperands().size());
            if (transformer != null && shouldTransformOperator(operatorName)) {
                // Transform the call first; super.visitCall then visits the transformed call's operands.
                RexCall transformed = (RexCall) transformer.transformCall(rexBuilder, call.getOperands());
                return adjustReturnTypeWithCast(rexBuilder, super.visitCall(transformed));
            }
            if (transformer == null && operatorName.startsWith("com.linkedin")) {
                return visitUnregisteredUDF(call);
            }
            RexCall adjusted = adjustInconsistentTypesToEqualityOperator(call);
            return adjustReturnTypeWithCast(rexBuilder, super.visitCall(adjusted));
        }

        /**
         * Applies the dedicated rewrite for {@code operatorName}, if there is one.
         *
         * @return the rewritten expression, or empty when the generic translation path should be used
         */
        private Optional<RexNode> visitSpecialCall(String operatorName, RexCall call) {
            if (operatorName.equalsIgnoreCase("from_utc_timestamp")) {
                return visitFromUtcTimestampCall(call);
            }
            if (operatorName.equalsIgnoreCase("from_unixtime")) {
                return visitFromUnixtime(call);
            }
            if (operatorName.equalsIgnoreCase("cast")) {
                return visitCast(call);
            }
            if (operatorName.equalsIgnoreCase("collect_list") || operatorName.equalsIgnoreCase("collect_set")) {
                return visitCollectListOrSetFunction(call);
            }
            if (operatorName.equalsIgnoreCase("substr")) {
                return visitSubstring(call);
            }
            if (operatorName.equalsIgnoreCase("current_timestamp")) {
                return visitCurrentTimestamp(call);
            }
            return Optional.empty();
        }

        /**
         * Renames an unregistered {@code com.linkedin...} UDF to the lower_underscore form of its simple
         * class name (for example {@code ...FooBar} becomes {@code foo_bar}). The original return type,
         * return-type inference and operand type checker are kept.
         */
        private RexNode visitUnregisteredUDF(RexCall call) {
            List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
            Converter<String, String> caseConverter = CaseFormat.UPPER_CAMEL.converterTo(CaseFormat.LOWER_UNDERSCORE);
            SqlOperator operator = call.getOperator();
            String[] nameSplit = operator.getName().split("\\.");
            String className = nameSplit[nameSplit.length - 1];
            String convertedFunctionName = caseConverter.convert(className);
            SqlUserDefinedFunction convertedFunctionOperator = new SqlUserDefinedFunction(
                new SqlIdentifier(convertedFunctionName, SqlParserPos.ZERO),
                operator.getReturnTypeInference(), null, operator.getOperandTypeChecker(), null, null);
            return rexBuilder.makeCall(call.getType(), convertedFunctionOperator, convertedOperands);
        }

        /**
         * Rewrites {@code collect_list(x)} as {@code array_agg(x)} and {@code collect_set(x)} as
         * {@code array_distinct(array_agg(x))}.
         */
        private Optional<RexNode> visitCollectListOrSetFunction(RexCall call) {
            List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
            SqlOperator arrayAgg = UDFMapUtils.createUDF("array_agg", FunctionReturnTypes.ARRAY_OF_ARG0_TYPE);
            RexNode aggregated = rexBuilder.makeCall(arrayAgg, convertedOperands);
            if (call.getOperator().getName().equalsIgnoreCase("collect_list")) {
                return Optional.of(aggregated);
            }
            SqlOperator arrayDistinct = UDFMapUtils.createUDF("array_distinct", ReturnTypes.ARG0_NULLABLE);
            return Optional.of(rexBuilder.makeCall(arrayDistinct, aggregated));
        }

        /**
         * Rewrites {@code from_unixtime(t [, fmt])} as {@code format_datetime(from_unixtime(t), fmt)}.
         * When no format is given, {@code 'yyyy-MM-dd HH:mm:ss'} is used.
         *
         * <p>The converted operands are used, so nested calls inside the arguments are translated too.
         *
         * @return the rewrite, or empty for any arity other than 1 or 2
         */
        private Optional<RexNode> visitFromUnixtime(RexCall call) {
            List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
            int arity = convertedOperands.size();
            if (arity != 1 && arity != 2) {
                return Optional.empty();
            }
            SqlOperator formatDatetime = UDFMapUtils.createUDF("format_datetime", FunctionReturnTypes.STRING);
            SqlOperator fromUnixtime = UDFMapUtils.createUDF("from_unixtime", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
            RexNode timestamp = rexBuilder.makeCall(fromUnixtime, convertedOperands.get(0));
            RexNode format = arity == 2
                ? convertedOperands.get(1)
                : rexBuilder.makeLiteral("yyyy-MM-dd HH:mm:ss");
            return Optional.of(rexBuilder.makeCall(formatDatetime, timestamp, format));
        }

        /**
         * Rewrites {@code from_utc_timestamp(value, tz)} as
         * {@code CAST(at_timezone(<utc timestamp of value>, $canonicalize_hive_timezone_id(tz)) AS TIMESTAMP(3))}.
         *
         * <p>The UTC timestamp is built as follows:
         * <ul>
         *   <li>integral input: {@code from_unixtime_nanos(CAST(value AS BIGINT) * 1000000)}</li>
         *   <li>DOUBLE/FLOAT/DECIMAL input: {@code from_unixtime(CAST(value AS DOUBLE))}</li>
         *   <li>TIMESTAMP/DATE input: {@code from_unixtime(to_unixtime(with_timezone(value, 'UTC')))}</li>
         * </ul>
         *
         * @return the rewrite, or empty for any other input type
         */
        private Optional<RexNode> visitFromUtcTimestampCall(RexCall call) {
            SqlTypeName inputTypeName = call.getOperands().get(0).getType().getSqlTypeName();
            List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
            RexNode sourceValue = convertedOperands.get(0);
            RexNode timezone = convertedOperands.get(1);

            SqlOperator trinoFromUnixTime = UDFMapUtils.createUDF("from_unixtime", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
            RexNode utcTimestamp;
            if (inputTypeName == SqlTypeName.BIGINT || inputTypeName == SqlTypeName.INTEGER
                || inputTypeName == SqlTypeName.SMALLINT || inputTypeName == SqlTypeName.TINYINT) {
                SqlOperator trinoFromUnixtimeNanos = UDFMapUtils.createUDF("from_unixtime_nanos", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                RelDataType bigintType = typeFactory.createSqlType(SqlTypeName.BIGINT);
                // Scale the value by 1,000,000 before handing it to from_unixtime_nanos.
                RexNode nanos = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
                    rexBuilder.makeCast(bigintType, sourceValue),
                    rexBuilder.makeBigintLiteral(BigDecimal.valueOf(1000000L)));
                utcTimestamp = rexBuilder.makeCall(trinoFromUnixtimeNanos, nanos);
            } else if (inputTypeName == SqlTypeName.DOUBLE || inputTypeName == SqlTypeName.FLOAT
                || inputTypeName == SqlTypeName.DECIMAL) {
                RelDataType doubleType = typeFactory.createSqlType(SqlTypeName.DOUBLE);
                utcTimestamp = rexBuilder.makeCall(trinoFromUnixTime, rexBuilder.makeCast(doubleType, sourceValue));
            } else if (inputTypeName == SqlTypeName.TIMESTAMP || inputTypeName == SqlTypeName.DATE) {
                SqlOperator trinoWithTimeZone = UDFMapUtils.createUDF("with_timezone", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
                SqlOperator trinoToUnixTime = UDFMapUtils.createUDF("to_unixtime", ReturnTypes.explicit(SqlTypeName.DOUBLE));
                RexNode zoned = rexBuilder.makeCall(trinoWithTimeZone, sourceValue, rexBuilder.makeLiteral("UTC"));
                utcTimestamp = rexBuilder.makeCall(trinoFromUnixTime, rexBuilder.makeCall(trinoToUnixTime, zoned));
            } else {
                return Optional.empty();
            }

            SqlOperator trinoAtTimeZone = UDFMapUtils.createUDF("at_timezone", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
            SqlOperator trinoCanonicalizeHiveTimezoneId = UDFMapUtils.createUDF("$canonicalize_hive_timezone_id", ReturnTypes.explicit(SqlTypeName.VARCHAR));
            RexNode shifted = rexBuilder.makeCall(trinoAtTimeZone, utcTimestamp,
                rexBuilder.makeCall(trinoCanonicalizeHiveTimezoneId, timezone));
            RelDataType targetType = typeFactory.createSqlType(SqlTypeName.TIMESTAMP, 3);
            return Optional.of(rexBuilder.makeCast(targetType, shifted));
        }

        /**
         * When the first argument of {@code substr} is not VARCHAR or CHAR, casts it to VARCHAR.
         *
         * @return the rewritten call, or empty when the input is already VARCHAR/CHAR
         */
        private Optional<RexNode> visitSubstring(RexCall call) {
            List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
            RexNode inputOperand = convertedOperands.get(0);
            SqlTypeName inputTypeName = inputOperand.getType().getSqlTypeName();
            if (inputTypeName == SqlTypeName.VARCHAR || inputTypeName == SqlTypeName.CHAR) {
                return Optional.empty();
            }
            List<RexNode> operands = ImmutableList.<RexNode>builder()
                .add(rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.VARCHAR), inputOperand))
                .addAll(convertedOperands.subList(1, convertedOperands.size()))
                .build();
            return Optional.of(rexBuilder.makeCall(call.getOperator(), operands));
        }

        /** Wraps {@code current_timestamp} in {@code CAST(... AS TIMESTAMP(3))}. */
        private Optional<RexNode> visitCurrentTimestamp(RexCall call) {
            RelDataType timestampType = typeFactory.createSqlType(SqlTypeName.TIMESTAMP, 3);
            return Optional.of(rexBuilder.makeCast(timestampType, rexBuilder.makeCall(call.getOperator())));
        }

        /**
         * Rewrites {@code CAST(timestamp AS DECIMAL)} as
         * {@code CAST(to_unixtime(with_timezone(timestamp, 'UTC')) AS DECIMAL)}.
         *
         * @return the rewrite, or empty for any other cast
         */
        private Optional<RexNode> visitCast(RexCall call) {
            if (call.getOperator().getKind() != SqlKind.CAST) {
                return Optional.empty();
            }
            List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
            RexNode leftOperand = convertedOperands.get(0);
            if (call.getType().getSqlTypeName() != SqlTypeName.DECIMAL
                || leftOperand.getType().getSqlTypeName() != SqlTypeName.TIMESTAMP) {
                return Optional.empty();
            }
            SqlOperator trinoToUnixTime = UDFMapUtils.createUDF("to_unixtime", ReturnTypes.explicit(SqlTypeName.DOUBLE));
            SqlOperator trinoWithTimeZone = UDFMapUtils.createUDF("with_timezone", ReturnTypes.explicit(SqlTypeName.TIMESTAMP));
            RexNode zoned = rexBuilder.makeCall(trinoWithTimeZone, leftOperand, rexBuilder.makeLiteral("UTC"));
            return Optional.of(rexBuilder.makeCast(call.getType(), rexBuilder.makeCall(trinoToUnixTime, zoned)));
        }

        /**
         * For {@code left = right}, first strips a CAST from {@code left}. If the resulting type family
         * is mapped to {@code right}'s family in {@link #SUPPORTED_TYPE_CAST_MAP}, rewrites the
         * comparison as {@code TRY_CAST(left AS type(right)) = right}. Otherwise {@code call} is
         * returned unchanged.
         */
        private RexCall adjustInconsistentTypesToEqualityOperator(RexCall call) {
            SqlOperator op = call.getOperator();
            if (op.getKind() != SqlKind.EQUALS) {
                return call;
            }
            RexNode leftOperand = call.getOperands().get(0);
            RexNode rightOperand = call.getOperands().get(1);
            if (leftOperand.getKind() == SqlKind.CAST) {
                leftOperand = ((RexCall) leftOperand).getOperands().get(0);
            }
            SqlTypeFamily leftFamily = leftOperand.getType().getSqlTypeName().getFamily();
            SqlTypeFamily rightFamily = rightOperand.getType().getSqlTypeName().getFamily();
            if (!SUPPORTED_TYPE_CAST_MAP.containsEntry(leftFamily, rightFamily)) {
                return call;
            }
            RexNode tryCastNode = rexBuilder.makeCall(rightOperand.getType(), (SqlOperator) TrinoTryCastFunction.INSTANCE,
                ImmutableList.of(leftOperand));
            return (RexCall) rexBuilder.makeCall(op, tryCastNode, rightOperand);
        }

        /**
         * Rewrites {@code MAP[k1, v1, k2, v2, ...]} as {@code MAP(ARRAY[k1, k2, ...], ARRAY[v1, v2, ...])}.
         * Operands at even positions become keys and operands at odd positions become values.
         */
        private RexNode convertMapValueConstructor(RexBuilder rexBuilder, RexCall call) {
            List<RexNode> sourceOperands = visitList(call.getOperands(), (boolean[]) null);
            List<RexNode> keys = new ArrayList<>();
            List<RexNode> values = new ArrayList<>();
            for (int i = 0; i < sourceOperands.size(); i++) {
                (i % 2 == 0 ? keys : values).add(sourceOperands.get(i));
            }
            List<RexNode> results = new ArrayList<>(2);
            results.add(rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, keys));
            results.add(rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, values));
            return rexBuilder.makeCall(call.getOperator(), results);
        }

        /**
         * Returns false only for {@code to_date} when {@code AVOID_TRANSFORM_TO_DATE_UDF} is enabled.
         */
        private boolean shouldTransformOperator(String operatorName) {
            return !"to_date".equalsIgnoreCase(operatorName) || !isEnabled(AVOID_TRANSFORM_TO_DATE_UDF);
        }

        /** Reads a boolean switch from {@code configs}; a missing key or a {@code null} value means false. */
        private boolean isEnabled(String key) {
            return Boolean.TRUE.equals(configs.get(key));
        }

        /**
         * Casts the result of operators whose Trino return type differs from the expected type.
         * {@code date_diff} and {@code cardinality} become INTEGER; {@code ceil}, {@code ceiling} and
         * {@code floor} become BIGINT. When {@code CAST_DATE_ADD_TO_STRING} is enabled,
         * {@code date_add} becomes VARCHAR. Anything that is not a {@link RexCall} is returned as is.
         */
        private RexNode adjustReturnTypeWithCast(RexBuilder rexBuilder, RexNode call) {
            if (!(call instanceof RexCall)) {
                return call;
            }
            String lowercaseOperatorName = ((RexCall) call).getOperator().getName().toLowerCase(Locale.ROOT);
            RelDataType adjustedType = returnTypeAdjustments.get(lowercaseOperatorName);
            if (adjustedType != null) {
                return rexBuilder.makeCast(adjustedType, call);
            }
            if (isEnabled(CAST_DATE_ADD_TO_STRING) && lowercaseOperatorName.equals("date_add")) {
                return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.VARCHAR), call);
            }
            return call;
        }
    }
}

