001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.provider;
017
import com.mybatisflex.core.FlexConsts;
import com.mybatisflex.core.dialect.DialectFactory;
import com.mybatisflex.core.exception.FlexAssert;
import com.mybatisflex.core.exception.FlexExceptions;
import com.mybatisflex.core.query.CPI;
import com.mybatisflex.core.query.QueryWrapper;
import com.mybatisflex.core.row.Row;
import com.mybatisflex.core.row.RowCPI;
import com.mybatisflex.core.row.RowMapper;
import com.mybatisflex.core.table.TableInfo;
import com.mybatisflex.core.table.TableInfoFactory;
import com.mybatisflex.core.util.ArrayUtil;

import java.util.*;
import java.util.function.Function;
032
033@SuppressWarnings({"rawtypes", "DuplicatedCode"})
034public class RowSqlProvider {
035
036    public static final String METHOD_RAW_SQL = "providerRawSql";
037
038    /**
039     * 不让实例化,使用静态方法的模式,效率更高,非静态方法每次都会实例化当前类
040     * 参考源码: {{@link org.apache.ibatis.builder.annotation.ProviderSqlSource#getBoundSql(Object)}
041     */
042    private RowSqlProvider() {
043    }
044
045    /**
046     * 执行原生 sql 的方法
047     *
048     * @param params 方法参数
049     * @return SQL 语句
050     * @see RowMapper#insertBySql(String, Object...)
051     * @see RowMapper#deleteBySql(String, Object...)
052     * @see RowMapper#updateBySql(String, Object...)
053     */
054    public static String providerRawSql(Map params) {
055        return ProviderUtil.getSqlString(params);
056    }
057
058    /**
059     * insert 的 SQL 构建。
060     *
061     * @param params 方法参数
062     * @return SQL 语句
063     * @see RowMapper#insert(String, String, Row)
064     */
065    public static String insert(Map params) {
066        String tableName = ProviderUtil.getTableName(params);
067        String schema = ProviderUtil.getSchemaName(params);
068        Row row = ProviderUtil.getRow(params);
069        ProviderUtil.setSqlArgs(params, RowCPI.obtainModifyValues(row));
070        return DialectFactory.getDialect().forInsertRow(schema, tableName, row);
071    }
072
073    /**
074     * insertBatch 的 SQL 构建。
075     *
076     * @param params 方法参数
077     * @return SQL 语句
078     * @see RowMapper#insertBatchWithFirstRowColumns(String, String, List)
079     */
080    public static String insertBatchWithFirstRowColumns(Map params) {
081        List<Row> rows = ProviderUtil.getRows(params);
082
083        FlexAssert.notEmpty(rows, "rows can not be null or empty.");
084
085        String tableName = ProviderUtil.getTableName(params);
086        String schema = ProviderUtil.getSchemaName(params);
087
088        // 让所有 row 的列顺序和值的数量与第条数据保持一致
089        // 这个必须 new 一个 LinkedHashSet,因为 keepModifyAttrs 会清除 row 所有的 modifyAttrs
090        Set<String> modifyAttrs = new LinkedHashSet<>(RowCPI.getModifyAttrs(rows.get(0)));
091        rows.forEach(row -> row.keep(modifyAttrs));
092
093        Object[] values = new Object[]{};
094        for (Row row : rows) {
095            values = ArrayUtil.concat(values, RowCPI.obtainModifyValues(row));
096        }
097        ProviderUtil.setSqlArgs(params, values);
098
099        //sql: INSERT INTO `tb_table`(`name`, `sex`) VALUES (?, ?),(?, ?),(?, ?)
100        return DialectFactory.getDialect().forInsertBatchWithFirstRowColumns(schema, tableName, rows);
101    }
102
103    /**
104     * deleteById 的 SQL 构建。
105     *
106     * @param params 方法参数
107     * @return SQL 语句
108     * @see RowMapper#deleteById(String, String, String, Object)
109     */
110    public static String deleteById(Map params) {
111        Object[] primaryValues = ProviderUtil.getPrimaryValues(params);
112
113        FlexAssert.notEmpty(primaryValues, "primaryValue can not be null or empty.");
114
115        String schema = ProviderUtil.getSchemaName(params);
116        String tableName = ProviderUtil.getTableName(params);
117        String[] primaryKeys = ProviderUtil.getPrimaryKeys(params);
118
119        ProviderUtil.setSqlArgs(params, primaryValues);
120
121        return DialectFactory.getDialect().forDeleteById(schema, tableName, primaryKeys);
122    }
123
124    /**
125     * deleteBatchByIds 的 SQL 构建。
126     *
127     * @param params 方法参数
128     * @return SQL 语句
129     * @see RowMapper#deleteBatchByIds(String, String, String, Collection)
130     */
131    public static String deleteBatchByIds(Map params) {
132        String schema = ProviderUtil.getSchemaName(params);
133        String tableName = ProviderUtil.getTableName(params);
134        String[] primaryKeys = ProviderUtil.getPrimaryKeys(params);
135        Object[] primaryValues = ProviderUtil.getPrimaryValues(params);
136
137        ProviderUtil.setSqlArgs(params, primaryValues);
138        return DialectFactory.getDialect().forDeleteBatchByIds(schema, tableName, primaryKeys, primaryValues);
139    }
140
141    /**
142     * deleteByQuery 的 SQL 构建。
143     *
144     * @param params 方法参数
145     * @return SQL 语句
146     * @see RowMapper#deleteByQuery(String, String, QueryWrapper)
147     */
148    public static String deleteByQuery(Map params) {
149        String schema = ProviderUtil.getSchemaName(params);
150        String tableName = ProviderUtil.getTableName(params);
151        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
152        CPI.setFromIfNecessary(queryWrapper, schema, tableName);
153
154        //优先构建 sql,再构建参数
155        String sql = DialectFactory.getDialect().forDeleteByQuery(queryWrapper);
156
157        Object[] valueArray = CPI.getValueArray(queryWrapper);
158        ProviderUtil.setSqlArgs(params, valueArray);
159
160        return sql;
161    }
162
163    /**
164     * updateById 的 SQL 构建。
165     *
166     * @param params 方法参数
167     * @return SQL 语句
168     * @see RowMapper#updateById(String, String, Row)
169     */
170    public static String updateById(Map params) {
171        String schema = ProviderUtil.getSchemaName(params);
172        String tableName = ProviderUtil.getTableName(params);
173        Row row = ProviderUtil.getRow(params);
174        ProviderUtil.setSqlArgs(params, RowCPI.obtainAllModifyValues(row));
175        return DialectFactory.getDialect().forUpdateById(schema, tableName, row);
176    }
177
178    /**
179     * updateByQuery 的 SQL 构建。
180     *
181     * @param params 方法参数
182     * @return SQL 语句
183     * @see RowMapper#updateByQuery(String, String, Row, QueryWrapper)
184     */
185    public static String updateByQuery(Map params) {
186        String schema = ProviderUtil.getSchemaName(params);
187        String tableName = ProviderUtil.getTableName(params);
188        Row data = ProviderUtil.getRow(params);
189
190        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
191        CPI.setFromIfNecessary(queryWrapper, schema, tableName);
192
193        //优先构建 sql,再构建参数
194        String sql = DialectFactory.getDialect().forUpdateByQuery(queryWrapper, data);
195
196        Object[] modifyValues = RowCPI.obtainModifyValues(data);
197        Object[] valueArray = CPI.getValueArray(queryWrapper);
198
199        ProviderUtil.setSqlArgs(params, ArrayUtil.concat(modifyValues, valueArray));
200
201        return sql;
202    }
203
204    /**
205     * updateBatchById 的 SQL 构建。
206     * mysql 等链接配置需要开启 allowMultiQueries=true
207     *
208     * @param params 方法参数
209     * @return SQL 语句
210     * @see RowMapper#updateBatchById(String, String, List)
211     */
212    public static String updateBatchById(Map params) {
213        List<Row> rows = ProviderUtil.getRows(params);
214
215        FlexAssert.notEmpty(rows, "rows can not be null or empty.");
216
217        String schema = ProviderUtil.getSchemaName(params);
218        String tableName = ProviderUtil.getTableName(params);
219
220        Object[] values = FlexConsts.EMPTY_ARRAY;
221        for (Row row : rows) {
222            values = ArrayUtil.concat(values, RowCPI.obtainAllModifyValues(row));
223        }
224        ProviderUtil.setSqlArgs(params, values);
225        return DialectFactory.getDialect().forUpdateBatchById(schema, tableName, rows);
226    }
227
228    /**
229     * updateEntity 的 SQL 构建。
230     *
231     * @param params 方法参数
232     * @return SQL 语句
233     * @see RowMapper#updateEntity(Object entities)
234     */
235    public static String updateEntity(Map params) {
236        Object entity = ProviderUtil.getEntity(params);
237
238        FlexAssert.notNull(entity, "entity can not be null");
239
240        // 该 Mapper 是通用 Mapper  无法通过 ProviderContext 获取,直接使用 TableInfoFactory
241        TableInfo tableInfo = TableInfoFactory.ofEntityClass(entity.getClass());
242
243        // 执行 onUpdate 监听器
244        tableInfo.invokeOnUpdateListener(entity);
245
246        Object[] updateValues = tableInfo.buildUpdateSqlArgs(entity, false, false);
247        Object[] primaryValues = tableInfo.buildPkSqlArgs(entity);
248        Object[] tenantIdArgs = tableInfo.buildTenantIdArgs();
249
250        FlexExceptions.assertAreNotNull(primaryValues, "The value of primary key must not be null, entity[%s]", entity);
251
252        ProviderUtil.setSqlArgs(params, ArrayUtil.concat(updateValues, primaryValues, tenantIdArgs));
253
254        return DialectFactory.getDialect().forUpdateEntity(tableInfo, entity, false);
255    }
256
257    /**
258     * 执行类似 update table set field=field+1 where ... 的场景
259     *
260     * @param params 方法参数
261     * @return SQL 语句
262     * @see RowMapper#updateNumberAddByQuery(String, String, String, Number, QueryWrapper)
263     */
264    public static String updateNumberAddByQuery(Map params) {
265        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
266        String schema = ProviderUtil.getSchemaName(params);
267        String tableName = ProviderUtil.getTableName(params);
268        String fieldName = ProviderUtil.getFieldName(params);
269        Number value = (Number) ProviderUtil.getValue(params);
270
271        //优先构建 sql,再构建参数
272        String sql = DialectFactory.getDialect().forUpdateNumberAddByQuery(schema
273            , tableName, fieldName, value, queryWrapper);
274
275        Object[] queryParams = CPI.getValueArray(queryWrapper);
276        ProviderUtil.setSqlArgs(params, queryParams);
277        return sql;
278    }
279
280    /**
281     * selectOneById 的 SQL 构建。
282     *
283     * @param params 方法参数
284     * @return SQL 语句
285     * @see RowMapper#selectOneById(String, String, String, Object)
286     */
287    public static String selectOneById(Map params) {
288        String schema = ProviderUtil.getSchemaName(params);
289        String tableName = ProviderUtil.getTableName(params);
290        String[] primaryKeys = ProviderUtil.getPrimaryKeys(params);
291        Object[] primaryValues = ProviderUtil.getPrimaryValues(params);
292
293        ProviderUtil.setSqlArgs(params, primaryValues);
294
295        return DialectFactory.getDialect().forSelectOneById(schema, tableName, primaryKeys, primaryValues);
296    }
297
298    /**
299     * selectListByQuery 的 SQL 构建。
300     *
301     * @param params 方法参数
302     * @return SQL 语句
303     * @see RowMapper#selectListByQuery(String, String, QueryWrapper)
304     */
305    public static String selectListByQuery(Map params) {
306        String schema = ProviderUtil.getSchemaName(params);
307        String tableName = ProviderUtil.getTableName(params);
308
309        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
310        CPI.setFromIfNecessary(queryWrapper, schema, tableName);
311
312        //优先构建 sql,再构建参数
313        String sql = DialectFactory.getDialect().forSelectByQuery(queryWrapper);
314
315        Object[] valueArray = CPI.getValueArray(queryWrapper);
316        ProviderUtil.setSqlArgs(params, valueArray);
317
318        return sql;
319    }
320
321}