001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
import java.lang.reflect.Array;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.StringJoiner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.mybatisflex.core.constant.SqlConsts.*;
025
026public class SqlUtil {
027
028    private SqlUtil() {
029    }
030
031    public static void keepColumnSafely(String column) {
032        if (StringUtil.isBlank(column)) {
033            throw new IllegalArgumentException("Column must not be empty");
034        } else {
035            column = column.trim();
036        }
037
038        int strLen = column.length();
039        for (int i = 0; i < strLen; ++i) {
040            char ch = column.charAt(i);
041            if (Character.isWhitespace(ch)) {
042                throw new IllegalArgumentException("Column must not has space char.");
043            }
044            if (isUnSafeChar(ch)) {
045                throw new IllegalArgumentException("Column has unsafe char: [" + ch + "].");
046            }
047        }
048    }
049
050
051    /**
052     * 仅支持字母、数字、下划线、空格、逗号、小数点(支持多个字段排序)
053     */
054    private static final String SQL_ORDER_BY_PATTERN = "[a-zA-Z0-9_\\ \\,\\.]+";
055
056    public static void keepOrderBySqlSafely(String value) {
057        if (!value.matches(SQL_ORDER_BY_PATTERN)) {
058            throw new IllegalArgumentException("Order By sql not safe, order by string: " + value);
059        }
060    }
061
062
063    private static final char[] UN_SAFE_CHARS = "'`\"<>&+=#-;".toCharArray();
064
065    private static boolean isUnSafeChar(char ch) {
066        for (char c : UN_SAFE_CHARS) {
067            if (c == ch) {
068                return true;
069            }
070        }
071        return false;
072    }
073
074
075    /**
076     * 根据数据库响应结果判断数据库操作是否成功。
077     *
078     * @param result 数据库操作返回影响条数
079     * @return {@code true} 操作成功,{@code false} 操作失败。
080     */
081    public static boolean toBool(Number result) {
082        return result != null && result.intValue() > 0;
083    }
084
085
086    /**
087     * 根据数据库响应结果判断数据库操作是否成功。
088     * 有 1 条数据成功便算成功
089     *
090     * @param results 操作数据的响应成功条数
091     * @return {@code true} 操作成功,{@code false} 操作失败。
092     */
093    public static boolean toBool(int[] results) {
094        for (int result : results) {
095            if (result > 0) {
096                return true;
097            }
098        }
099        return false;
100    }
101
102
103    /**
104     * 替换 sql 中的问号 ?
105     *
106     * @param sql    sql 内容
107     * @param params 参数
108     * @return 完整的 sql
109     */
110    public static String replaceSqlParams(String sql, Object[] params) {
111        if (params != null && params.length > 0) {
112            for (Object value : params) {
113                // null
114                if (value == null) {
115                    sql = sql.replaceFirst("\\?", "null");
116                }
117                // number
118                else if (value instanceof Number || value instanceof Boolean) {
119                    sql = sql.replaceFirst("\\?", value.toString());
120                }
121                // array
122                else if (ClassUtil.isArray(value.getClass())) {
123                    StringJoiner joiner = new StringJoiner(",");
124                    for (int i = 0; i < Array.getLength(value); i++) {
125                        joiner.add(String.valueOf(Array.get(value, i)));
126                    }
127                    sql = sql.replaceFirst("\\?", "[" + joiner + "]");
128                }
129                // other
130                else {
131                    StringBuilder sb = new StringBuilder();
132                    sb.append("'");
133                    if (value instanceof Date) {
134                        sb.append(DateUtil.toDateTimeString((Date) value));
135                    } else if (value instanceof LocalDateTime) {
136                        sb.append(DateUtil.toDateTimeString(DateUtil.toDate((LocalDateTime) value)));
137                    } else {
138                        sb.append(value);
139                    }
140                    sb.append("'");
141                    sql = sql.replaceFirst("\\?", Matcher.quoteReplacement(sb.toString()));
142                }
143            }
144        }
145        return sql;
146    }
147
148
149    public static String buildSqlParamPlaceholder(int count) {
150        StringBuilder sb = new StringBuilder(BRACKET_LEFT);
151        for (int i = 0; i < count; i++) {
152            sb.append(PLACEHOLDER);
153            if (i != count - 1) {
154                sb.append(DELIMITER);
155            }
156        }
157        sb.append(BRACKET_RIGHT);
158        return sb.toString();
159    }
160
161}