/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.coral.trino.rel2trino.transformers;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rel.rel2sql.SqlImplementor;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.SqlSelect;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;

/**
 * Rewrites the branches of a set operation ({@code UNION}, {@code INTERSECT} or {@code MINUS}) so that
 * a select item of type {@code CHAR} is wrapped in a {@code CAST} to {@code VARCHAR} whenever the least
 * restrictive type of that select-list position, computed across all branches, is {@code VARCHAR}.
 *
 * <p>The call is returned untouched when it has no operands, when any operand is not a {@code SELECT}
 * with a select list, when the branches have select lists of different sizes, or when a {@code *}
 * select item is encountered while deriving types.
 */
public class UnionSqlCallTransformer extends SqlCallTransformer {

    /** Sentinel returned by {@link #uniformSelectListSize(List)} when the branches are not rewritable. */
    private static final int NOT_UNIFORM = -1;

    public UnionSqlCallTransformer(TypeDerivationUtil typeDerivationUtil) {
        super(typeDerivationUtil);
    }

    /**
     * Accepts only set-operation calls.
     *
     * @param sqlCall the call being inspected
     * @return {@code true} if the operator kind is {@code UNION}, {@code INTERSECT} or {@code MINUS}
     */
    @Override
    protected boolean condition(SqlCall sqlCall) {
        SqlKind kind = sqlCall.getOperator().kind;
        return kind == SqlKind.UNION || kind == SqlKind.INTERSECT || kind == SqlKind.MINUS;
    }

    /**
     * Casts {@code CHAR} select items to the column's least restrictive {@code VARCHAR} type.
     *
     * <p>Select lists of the operand {@link SqlSelect} nodes are replaced in place; a new call over the
     * same operand list is created only if at least one select list was rewritten.
     *
     * @param sqlCall the set-operation call
     * @return a new call built from the (mutated) operands if anything was rewritten, otherwise
     *         {@code sqlCall} itself
     */
    @Override
    protected SqlCall transform(SqlCall sqlCall) {
        List<SqlNode> operands = sqlCall.getOperandList();
        if (operands.isEmpty()) {
            return sqlCall;
        }
        int selectListSize = uniformSelectListSize(operands);
        if (selectListSize == NOT_UNIFORM) {
            return sqlCall;
        }

        // For every select-list position, the least restrictive type across all branches;
        // empty when a type could not be derived (or no common type exists) for that position.
        List<Optional<RelDataType>> targetTypes = new ArrayList<>(selectListSize);
        for (int i = 0; i < selectListSize; ++i) {
            List<RelDataType> columnTypes = new ArrayList<>(operands.size());
            boolean derived = true;
            for (SqlNode operand : operands) {
                SqlNode selectItem = ((SqlSelect) operand).getSelectList().getList().get(i);
                if (selectItem.getKind() == SqlKind.IDENTIFIER && ((SqlIdentifier) selectItem).isStar()) {
                    // Positions cannot be matched up across branches once a star is involved.
                    return sqlCall;
                }
                try {
                    columnTypes.add(deriveRelDatatype(selectItem));
                } catch (RuntimeException ignored) {
                    // Best effort: a position whose type cannot be derived is simply left as is.
                    derived = false;
                    break;
                }
            }
            if (derived) {
                targetTypes.add(Optional.ofNullable(leastRestrictive(columnTypes)));
            } else {
                targetTypes.add(Optional.<RelDataType>empty());
            }
        }

        boolean operandsUpdated = false;
        for (SqlNode operand : operands) {
            operandsUpdated |= widenCharItems((SqlSelect) operand, targetTypes);
        }
        if (operandsUpdated) {
            return sqlCall.getOperator().createCall(SqlImplementor.POS, operands);
        }
        return sqlCall;
    }

    /**
     * Checks that every operand is a {@code SELECT} with a non-null select list and that all select
     * lists have the same size.
     *
     * @param operands the operands of the set operation
     * @return the shared select-list size, or {@link #NOT_UNIFORM} if any check fails
     */
    private static int uniformSelectListSize(List<SqlNode> operands) {
        int expected = NOT_UNIFORM;
        for (SqlNode operand : operands) {
            if (operand.getKind() != SqlKind.SELECT) {
                return NOT_UNIFORM;
            }
            SqlNodeList selectList = ((SqlSelect) operand).getSelectList();
            if (selectList == null) {
                return NOT_UNIFORM;
            }
            int size = selectList.getList().size();
            if (expected == NOT_UNIFORM) {
                expected = size;
            } else if (expected != size) {
                return NOT_UNIFORM;
            }
        }
        return expected;
    }

    /**
     * Replaces the select list of {@code select} with one in which every {@code CHAR} item whose
     * target type is {@code VARCHAR} is cast to that target type. The select list is left untouched
     * when no item needs a cast.
     *
     * @param select the branch to rewrite (mutated in place)
     * @param targetTypes the least restrictive type per select-list position, empty if unknown
     * @return {@code true} if the select list was replaced
     */
    private boolean widenCharItems(SqlSelect select, List<Optional<RelDataType>> targetTypes) {
        List<SqlNode> selectList = select.getSelectList().getList();
        // Allocated lazily so that branches needing no cast are not touched at all.
        List<SqlNode> rewritten = null;
        for (int i = 0; i < selectList.size(); ++i) {
            SqlNode selectItem = selectList.get(i);
            SqlNode replacement = selectItem;
            Optional<RelDataType> targetType = targetTypes.get(i);
            if (targetType.isPresent()
                    && deriveRelDatatype(selectItem).getSqlTypeName() == SqlTypeName.CHAR
                    && targetType.get().getSqlTypeName() == SqlTypeName.VARCHAR) {
                replacement = castNode(selectItem, targetType.get());
                if (rewritten == null) {
                    rewritten = new ArrayList<>(selectList.size());
                    // Carry over the items that preceded the first cast unchanged.
                    rewritten.addAll(selectList.subList(0, i));
                }
            }
            if (rewritten != null) {
                rewritten.add(replacement);
            }
        }
        if (rewritten == null) {
            return false;
        }
        select.setSelectList(new SqlNodeList(rewritten, SqlParserPos.ZERO));
        return true;
    }

    /**
     * Wraps {@code node} in a cast to {@code type}. For an aliased item ({@code expr AS alias}) the
     * cast is applied to the expression and the alias is re-attached: {@code CAST(expr AS type) AS alias}.
     *
     * @param node the select item to cast
     * @param type the type to cast to
     * @return the casted select item
     */
    private SqlNode castNode(SqlNode node, RelDataType type) {
        if (node.getKind() == SqlKind.AS) {
            List<SqlNode> asOperands = ((SqlCall) node).getOperandList();
            SqlNode expression = asOperands.get(0);
            SqlIdentifier alias = (SqlIdentifier) asOperands.get(1);
            return SqlStdOperatorTable.AS.createCall(SqlImplementor.POS, castTo(expression, type), alias);
        }
        return castTo(node, type);
    }

    /** Builds {@code CAST(expression AS type)} using a {@link BindingCastFunction}. */
    private static SqlNode castTo(SqlNode expression, RelDataType type) {
        return new BindingCastFunction().createCall(SqlParserPos.ZERO, expression, getSqlDataTypeSpec(type));
    }

    /**
     * Converts a {@link RelDataType} into a basic {@link SqlDataTypeSpec} carrying the type's SQL type
     * name, precision and scale (no charset), positioned at {@link SqlParserPos#ZERO}.
     */
    private static SqlDataTypeSpec getSqlDataTypeSpec(RelDataType relDataType) {
        SqlBasicTypeNameSpec typeNameSpec = new SqlBasicTypeNameSpec(
                relDataType.getSqlTypeName(),
                relDataType.getPrecision(),
                relDataType.getScale(),
                null,
                SqlParserPos.ZERO);
        return new SqlDataTypeSpec(typeNameSpec, SqlParserPos.ZERO);
    }

    /**
     * Cast function whose {@code deriveType} returns the result of {@code inferReturnType} on a
     * {@link SqlCallBinding} for the call, replacing the inherited {@code deriveType} implementation.
     * NOTE(review): presumably this lets later type derivation handle the injected casts - confirm.
     */
    private static final class BindingCastFunction extends SqlCastFunction {
        @Override
        public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
            return inferReturnType(new SqlCallBinding(validator, scope, call));
        }
    }
}

