package team.bangbang.common.sql;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/**
 * 数据库工具类
 *
 * @author Bangbang
 * @version 1.0  2021年1月8日
 */
public final class DbUtil {
	private static Pattern sqlUncomment = Pattern.compile("(?ms)('(?:''|[^'])*')|--.*?$|/\\*.*?\\*/|#.*?$|");
	/** 日期风格(yyyy-MM) */
	public static final int YYYY_MM = 1;
	/** 日期风格(MM-dd) */
	public static final int MM_DD = 2;
	/** 日期风格(yyyy-MM-dd) */
	public static final int YYYY_MM_DD = 3;
	/** 日期风格(hh24:mi) */
	public static final int HH24_MI = 4;
	/** 日期风格(hh24:mi:ss) */
	public static final int HH24_MI_SS = 5;
	/** 日期风格(yyyy-MM-dd hh24:mi) */
	public static final int YYYY_MM_DD_HH24_MI = 6;

	/**
	 * 使用日期函数将DATE、DATETIME类型的字段进行格式化，格式化为yyyy-MM-dd的字符串
	 *
	 * @param dbName
	 *            数据库产品名称
	 * @param fieldName
	 *            字段名称
	 * @return 字段格式化函数
	 */
	public static String formatDateField(String dbName, String fieldName) {
		return formatDateField(dbName, fieldName, YYYY_MM_DD);
	}

	/**
	 * 使用日期函数将DATE、DATETIME类型的字段按照指定的风格进行格式化
	 *
	 * @param dbName
	 *            数据库产品名称
	 * @param fieldName
	 *            字段名称
	 * @param style
	 *            日期/时间风格，参见SQLHelper下的日期风格定义
	 * @return 字段格式化函数
	 */
	public static String formatDateField(String dbName, String fieldName,
			int style) {
		switch (style) {
		case YYYY_MM:
			if ("Oracle".equalsIgnoreCase(dbName)) {
				// Oracle
				return "TO_CHAR(" + fieldName + ", 'yyyy-MM')";
			}
			if ("MySQL".equalsIgnoreCase(dbName)) {
				// MySQL
				return "DATE_FORMAT(" + fieldName + ", '%Y-%m')";
			}
			if ("Microsoft SQL Server".equalsIgnoreCase(dbName)) {
				// Microsoft SQL Server
				return "CONVERT(varchar(7), " + fieldName + ", 23)";
			}
			break;
		case MM_DD:
			if ("Oracle".equalsIgnoreCase(dbName)) {
				// Oracle
				return "TO_CHAR(" + fieldName + ", 'MM-dd')";
			}
			if ("MySQL".equalsIgnoreCase(dbName)) {
				// MySQL
				return "DATE_FORMAT(" + fieldName + ", '%m-%d')";
			}
			if ("Microsoft SQL Server".equalsIgnoreCase(dbName)) {
				// Microsoft SQL Server
				return "CONVERT(varchar(5), " + fieldName + ", 110)";
			}
			break;
		case YYYY_MM_DD:
			if ("Oracle".equalsIgnoreCase(dbName)) {
				// Oracle
				return "TO_CHAR(" + fieldName + ", 'yyyy-MM-dd')";
			}
			if ("MySQL".equalsIgnoreCase(dbName)) {
				// MySQL
				return "DATE_FORMAT(" + fieldName + ", '%Y-%m-%d')";
			}
			if ("Microsoft SQL Server".equalsIgnoreCase(dbName)) {
				// Microsoft SQL Server
				return "CONVERT(varchar(10), " + fieldName + ", 23)";
			}
			break;
		case HH24_MI:
			if ("Oracle".equalsIgnoreCase(dbName)) {
				// Oracle
				return "TO_CHAR(" + fieldName + ", 'hh24:mi')";
			}
			if ("MySQL".equalsIgnoreCase(dbName)) {
				// MySQL
				return "DATE_FORMAT(" + fieldName + ", '%H:%i')";
			}
			if ("Microsoft SQL Server".equalsIgnoreCase(dbName)) {
				// Microsoft SQL Server
				return "CONVERT(varchar(5), " + fieldName + ", 24)";
			}
			break;
		case HH24_MI_SS:
			if ("Oracle".equalsIgnoreCase(dbName)) {
				// Oracle
				return "TO_CHAR(" + fieldName + ", 'hh24:mi:ss')";
			}
			if ("MySQL".equalsIgnoreCase(dbName)) {
				// MySQL
				return "DATE_FORMAT(" + fieldName + ", '%H:%i:%s')";
			}
			if ("Microsoft SQL Server".equalsIgnoreCase(dbName)) {
				// Microsoft SQL Server
				return "CONVERT(varchar(8), " + fieldName + ", 24)";
			}
			break;
		case YYYY_MM_DD_HH24_MI:
			if ("Oracle".equalsIgnoreCase(dbName)) {
				// Oracle
				return "TO_CHAR(" + fieldName + ", 'yyyy-MM-dd hh24:mi')";
			}
			if ("MySQL".equalsIgnoreCase(dbName)) {
				// MySQL
				return "DATE_FORMAT(" + fieldName + ", '%Y-%m-%d %H:%i')";
			}
			if ("Microsoft SQL Server".equalsIgnoreCase(dbName)) {
				// Microsoft SQL Server
				return "CONVERT(varchar(16), " + fieldName + ", 20)";
			}
			break;
		}

		return fieldName;
	}

	/**
	 * 拼合OR语句
	 *
	 * @param fieldName
	 *            字段名称
	 * @param ids
	 *            用于匹配的字段值
	 * @return 拼合产生的SQL条件OR语句，格式如：fieldName=id_1 or filedName=id_2 or ...
	 *         fieldName=id_3
	 */
	public static String getOrSQL(String fieldName, Object[] ids) {
		if (ids == null || ids.length == 0)
			return "";

		StringBuffer sb = new StringBuffer();
		for (int i = 0; i < ids.length; i++) {
			Object obj = ids[i];
			if (obj == null)
				continue;
			if (sb.length() > 0)
				sb.append(" or ");
			sb.append(fieldName).append(" = ");
			if (obj instanceof String) {
				sb.append("'").append(obj).append("'");
			} else {
				sb.append(obj);
			}
		}

		return sb.toString();
	}

	/**
	 * 针对默认数据库把字符串中的单引号变为双单引号，正斜线变为双正斜线，<br>
	 * 用于SQL语句中的数值转换，所有字符型字段的值在拼接SQL语句前，必须调用本方法加以预处理。<br>
	 * 其中正斜线变为双正斜线可能只适用于MySQL数据库，其他数据库需要进行测试
	 *
	 * @param strValue 字符串
	 *
	 * @return String 单引号变为双单引号处理后的字符串
	 */
	public static String getDataString(String strValue) {
		// 数据库名称
		String dbName = "";// SQLPool.getDatabaseName();
		return getDataString(strValue, dbName);
	}

	/**
	 * 把字符串中的单引号变为双单引号，正斜线变为双正斜线，用于SQL语句中的数值转换，<br>
	 * 所有字符型字段的值在拼接SQL语句前，必须调用本方法加以预处理。<br>
	 * 其中正斜线变为双正斜线可能只适用于MySQL数据库，其他数据库需要进行测试
	 *
	 * @param strValue 字符串
	 * @param dbName 数据库名称
	 *
	 * @return String 单引号变为双单引号处理后的字符串
	 */
	public static String getDataString(String strValue, String dbName) {

		StringBuffer stbResult = new StringBuffer("");

		// not null
		if (strValue != null && strValue.length() > 0) {
			for (int i = 0; i < strValue.length(); i++) {
				if (strValue.charAt(i) == '\'') {
					stbResult.append("\'\'");
				} else if (strValue.charAt(i) == '\\'
						&& dbName.equalsIgnoreCase("MySQL")) {
					// MySQL
					stbResult.append("\\\\");
				} else if (strValue.charAt(i) == '&'
						&& dbName.equalsIgnoreCase("Oracle")) {
					// Oracle
					stbResult.append("'||'&'||'");
				} else {
					stbResult.append(strValue.charAt(i));
				}
			} // end for

			return stbResult.toString();
		} else { // is null
			return "";
		}
	}

	/**
	 * 将SQL脚本语句（包含注释，多条SQL以半角分号“;”间隔）拆分为多条SQL
	 *
	 * 在拆分的过程中会过滤掉注释和空行，去除语句末尾的半角分号“;”
	 *
	 * @param sqls SQL脚本语句
	 * @return 多条SQL
	 */
	public static String[] splitSQLs(String sqls) {
		if (sqls == null || sqls.length() == 0) {
			return new String[0];
		}

		// 去除注释
		sqls = sqlUncomment.matcher(sqls).replaceAll("$1");

		// 脚本列表
		List<String> ls = new ArrayList<String>();
		// 复制一份大写脚本，用于定位
		String u_sqls = sqls.toUpperCase();
		int start = 0;
		int end = 0;

		// 整个sql脚本长度
		int len = sqls.length();
		while(true) {
			start = findSqlStart(u_sqls, start);
			if (start < 0) break;

			end = findSqlEnd(u_sqls, start);
			if (end < start || end == len) break;

			ls.add(sqls.substring(start, end));

			start = end + 1;
		}

		String[] ss = new String[ls.size()];
		return ls.toArray(ss);
	}

	/**
	 * 找出SQL中的sql语句开始位置
	 *
	 * @param sqls 大写SQL
	 * @param fromPosition 从某个位置往后查找
	 * @return sql语句开始位置
	 */
	private static int findSqlStart(String sqls, int fromPosition) {
		// 查找第一个[INSERT|DELETE|UPDATE|SELECT]
		int n1 = sqls.indexOf("INSERT ", fromPosition);
		int n2 = sqls.indexOf("DELETE ", fromPosition);
		int n3 = sqls.indexOf("UPDATE ", fromPosition);
		int n4 = sqls.indexOf("SELECT ", fromPosition);

		int start = Integer.MAX_VALUE;
		if (n1 > 0) start = Math.min(n1, start);
		if (n2 > 0) start = Math.min(n2, start);
		if (n3 > 0) start = Math.min(n3, start);
		if (n4 > 0) start = Math.min(n4, start);

		return start;
	}

	/**
	 * 找出SQL中的sql语句开始位置
	 *
	 * @param sqls 大写SQL
	 * @param start sql语句开始位置
	 * @return sql语句结束位置
	 */
	private static int findSqlEnd(String sqls, int start) {
		// 分号位置
		while(true) {
			int n1 = sqls.indexOf(";", start);
			int n2 = sqls.indexOf("'", start);
			// ;没有 '没有
			if (n1 < 0 && n2 < 0) {
				return sqls.length();
			}

			// ;没有 '有
			if (n1 < 0 && n2 > 0) {
				return sqls.length();
			}

			// ;有 '没有
			if (n1 > 0 && n2 < 0) {
				return n1;
			}

			// ;有 '有
			// ;在前
			if (n1 > 0 && n1 < n2) {
				return n1;
			}

			// ;有 '有
			// '在前
			// 到下一个'位置
			start = sqls.indexOf("'", n2+1);
			if (start > 0) {
				start++;
			} else {
				// '不成对，语句有错误
				return sqls.length();
			}
		}
	}
}
