package com.huawei.dli.jdbc.utils;

import com.huawei.dli.jdbc.model.EnumSqlJobType;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class SqlUtils {
    /** Leading keywords of statements that do not return a result set. */
    private static final String[] NO_RESULT_QUERY_KEYWORDS = new String[] {
        "INSERT", "LOAD", "UPDATE", "DELETE", "MSCK",
        "DROP", "CREATE", "ALTER", "TRUNCATE", "RENAME",
        "SET", "RESET", "ANALYZE", "REFRESH", "CACHE", "UNCACHE", "CLEAR", "GRANT", "REVOKE"
    };

    /** Read-only lookup view of {@link #NO_RESULT_QUERY_KEYWORDS}. */
    private static final Set<String> NO_RESULT_KEYWORD_SET =
        Collections.unmodifiableSet(new HashSet<>(Arrays.asList(NO_RESULT_QUERY_KEYWORDS)));

    /**
     * One line terminator. {@code \r\n} is tried first so a CRLF ending counts as a single break
     * rather than a {@code \r} break, an empty line, and a {@code \n} break.
     */
    private static final Pattern LINE_BREAK_PATTERN = Pattern.compile("\\r\\n|\\r|\\n");

    /** Marker that identifies a configuration annotation comment line. */
    public static final String SET_CONF_ANNOTATION_KEY = "@set";

    /** {@link String#format} template for an annotation line: {@code -- @set <key>=<value>}. */
    public static final String SET_CONF_ANNOTATION_FORMAT = "-- " + SET_CONF_ANNOTATION_KEY + " %s=%s";

    /**
     * Case-insensitive regex matching a whole annotation line. Group 1 is the key (letters,
     * digits, {@code _}, {@code .}, {@code -}; 1-256 chars), group 2 is the value (1-256 chars).
     */
    public static final String SET_CONF_ANNOTATION_REGEX =
        "(?i)^--\\s{0,10}@set\\s{1,10}([A-Za-z0-9_\\.\\-]{1,256}?)\\s{0,10}=\\s{0,10}([\\s\\S]{1,256})$";

    /** Compiled form of {@link #SET_CONF_ANNOTATION_REGEX}, built once instead of per line. */
    private static final Pattern SET_CONF_ANNOTATION_PATTERN = Pattern.compile(SET_CONF_ANNOTATION_REGEX);

    /**
     * Builds an annotation line that sets one configuration entry.
     *
     * @param key   configuration key
     * @param value configuration value
     * @return a line of the form {@code -- @set key=value}
     */
    public static final String getSetConfSql(String key, String value) {
        return String.format(SqlUtils.SET_CONF_ANNOTATION_FORMAT, key, value);
    }

    /**
     * Decides whether a statement is expected to return a result set.
     *
     * <p>If {@link EnumSqlJobType} classifies {@code sqlType} as a result or no-result type, that
     * classification wins. Otherwise the first non-comment line of {@code sql} is upper-cased and
     * the statement is a query unless it starts with one of the no-result keywords (prefix match).
     *
     * @param sqlType job type name passed to {@link EnumSqlJobType}
     * @param sql     statement text
     * @return {@code true} if the statement is treated as a query
     */
    public static boolean isQuery(String sqlType, String sql) {
        if (EnumSqlJobType.isInHasResultType(sqlType)) {
            return true;
        }
        if (EnumSqlJobType.isInNoResultType(sqlType)) {
            return false;
        }
        // Upper-case once instead of once per keyword.
        String head = removeFrontComment(sql).toUpperCase(Locale.ENGLISH);
        return NO_RESULT_KEYWORD_SET.stream().noneMatch(head::startsWith);
    }

    /**
     * Skips leading blank lines, {@code --} line comments and slash-star block comments.
     *
     * @param sql statement text
     * @return the first non-comment line, with tabs removed and trimmed; if nothing but comments
     *     remain (including an unterminated block comment), the text passed to this call
     */
    private static String removeFrontComment(String sql) {
        List<String> lines = Arrays.asList(LINE_BREAK_PATTERN.split(sql));
        for (int i = 0; i < lines.size(); i++) {
            String content = lines.get(i).replace("\t", "").trim();
            if (content.isEmpty() || content.startsWith("--")) {
                continue;
            }
            if (content.startsWith("/*")) {
                String rest = String.join("\n", lines.subList(i, lines.size()));
                // Look for "*/" only after the opening "/*" so "/*/" is not taken as a closed comment.
                int close = rest.indexOf("*/", rest.indexOf("/*") + 2);
                if (close < 0) {
                    // Unterminated block comment: everything that follows is comment text, so do not
                    // classify lines that live inside it.
                    return sql;
                }
                return removeFrontComment(rest.substring(close + 2));
            }
            return content;
        }
        return sql;
    }

    /**
     * Rejects statements whose first word is a no-result keyword.
     *
     * <p>NOTE(review): unlike {@link #isQuery}, leading comments are not skipped here.
     *
     * @param sql statement text
     * @throws SQLException if the first whitespace-delimited word, upper-cased, is a no-result keyword
     */
    public static void checkForNoResult(String sql) throws SQLException {
        // Trim first: with leading whitespace, split() yields "" as the first token and the keyword
        // check would silently be skipped.
        String[] items = sql.trim().split("\\s+");
        if (items.length > 0) {
            String startWord = items[0].toUpperCase(Locale.US);
            if (NO_RESULT_KEYWORD_SET.contains(startWord)) {
                throw new SQLException("Can not issue DML/DDL/DCL statements with executeQuery()");
            }
        }
    }

    /**
     * Strips leading and trailing runs of configuration annotation lines from a script.
     *
     * <p>From the top (and, separately, from the bottom) blank lines are skipped and annotation
     * lines are consumed until a non-annotation line is reached; every line up to and including
     * the last annotation found in each run is dropped. The remaining lines are joined with
     * {@code \n}.
     *
     * @param sql script text
     * @return the script without the leading/trailing annotation runs; {@code sql} itself when it
     *     does not contain {@code @set} (case-insensitive), is a single line, or nothing would remain
     */
    public static String removeConfAnnotation(String sql) {
        if (!sql.toLowerCase(Locale.ROOT).contains(SqlUtils.SET_CONF_ANNOTATION_KEY)) {
            return sql;
        }
        List<String> lines = Arrays.asList(LINE_BREAK_PATTERN.split(sql));
        if (lines.size() == 1) {
            return sql;
        }
        int frontIndex = lastConfAnnotationIndex(lines);
        List<String> reversed = new ArrayList<>(lines);
        Collections.reverse(reversed);
        int endIndex = lastConfAnnotationIndex(reversed);

        int from = frontIndex >= 0 ? frontIndex + 1 : 0;
        int to = endIndex >= 0 ? lines.size() - endIndex - 1 : lines.size();
        // from >= to: the leading and trailing runs meet or overlap, i.e. nothing but annotations and
        // blank lines. Checking this directly also keeps subList(from, to) from seeing from > to.
        if (from >= to) {
            return sql;
        }
        return String.join("\n", lines.subList(from, to));
    }

    /**
     * Scans the leading comment run of {@code lines} for annotation lines.
     *
     * <p>Blank lines are skipped; the scan stops at the first line that is not an annotation
     * (either a plain {@code --} comment or a non-comment line).
     *
     * @param lines lines to scan, in scan order
     * @return index of the last annotation line seen before the scan stopped, or -1 if none
     */
    private static int lastConfAnnotationIndex(List<String> lines) {
        int index = -1;
        for (int i = 0; i < lines.size(); i++) {
            // trim() already drops edge tabs and the regex's \s handles inner ones, so no extra
            // tab stripping is needed (removing inner tabs would break "--\t@set\tk=v").
            String line = lines.get(i).trim();
            if (line.isEmpty()) {
                continue;
            }
            if (!line.startsWith("--") || !SET_CONF_ANNOTATION_PATTERN.matcher(line).matches()) {
                return index;
            }
            index = i;
        }
        return index;
    }
}
