001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.provider;
017
018import com.mybatisflex.core.FlexConsts;
019import com.mybatisflex.core.dialect.DialectFactory;
020import com.mybatisflex.core.exception.FlexAssert;
021import com.mybatisflex.core.query.CPI;
022import com.mybatisflex.core.query.QueryWrapper;
023import com.mybatisflex.core.row.Row;
024import com.mybatisflex.core.row.RowCPI;
025import com.mybatisflex.core.row.RowMapper;
026import com.mybatisflex.core.table.TableInfo;
027import com.mybatisflex.core.table.TableInfoFactory;
028import com.mybatisflex.core.util.ArrayUtil;
029
030import java.util.*;
031
032@SuppressWarnings({"rawtypes", "DuplicatedCode"})
033public class RowSqlProvider {
034
035    public static final String METHOD_RAW_SQL = "providerRawSql";
036
037    /**
038     * 不让实例化,使用静态方法的模式,效率更高,非静态方法每次都会实例化当前类
039     * 参考源码: {{@link org.apache.ibatis.builder.annotation.ProviderSqlSource#getBoundSql(Object)}
040     */
041    private RowSqlProvider() {
042    }
043
044    /**
045     * 执行原生 sql 的方法
046     *
047     * @param params 方法参数
048     * @return SQL 语句
049     * @see RowMapper#insertBySql(String, Object...)
050     * @see RowMapper#deleteBySql(String, Object...)
051     * @see RowMapper#updateBySql(String, Object...)
052     */
053    public static String providerRawSql(Map params) {
054        return ProviderUtil.getSqlString(params);
055    }
056
057    /**
058     * insert 的 SQL 构建。
059     *
060     * @param params 方法参数
061     * @return SQL 语句
062     * @see RowMapper#insert(String, String, Row)
063     */
064    public static String insert(Map params) {
065        String tableName = ProviderUtil.getTableName(params);
066        String schema = ProviderUtil.getSchemaName(params);
067        Row row = ProviderUtil.getRow(params);
068
069        // 先生成 SQL,再设置参数
070        String sql = DialectFactory.getDialect().forInsertRow(schema, tableName, row);
071        ProviderUtil.setSqlArgs(params, RowCPI.obtainInsertValues(row));
072        return sql;
073    }
074
075    /**
076     * insertBatch 的 SQL 构建。
077     *
078     * @param params 方法参数
079     * @return SQL 语句
080     * @see RowMapper#insertBatchWithFirstRowColumns(String, String, List)
081     */
082    public static String insertBatchWithFirstRowColumns(Map params) {
083        List<Row> rows = ProviderUtil.getRows(params);
084
085        FlexAssert.notEmpty(rows, "rows");
086
087        String tableName = ProviderUtil.getTableName(params);
088        String schema = ProviderUtil.getSchemaName(params);
089
090        // 让所有 row 的列顺序和值的数量与第条数据保持一致
091        // 这个必须 new 一个 LinkedHashSet,因为 keepModifyAttrs 会清除 row 所有的 modifyAttrs
092        Set<String> modifyAttrs = new LinkedHashSet<>(RowCPI.getModifyAttrs(rows.get(0)));
093        rows.forEach(row -> row.keep(modifyAttrs));
094
095        //sql: INSERT INTO `tb_table`(`name`, `sex`) VALUES (?, ?),(?, ?),(?, ?)
096        String sql = DialectFactory.getDialect().forInsertBatchWithFirstRowColumns(schema, tableName, rows);
097
098        Object[] values = new Object[]{};
099        for (Row row : rows) {
100            values = ArrayUtil.concat(values, RowCPI.obtainInsertValues(row));
101        }
102        ProviderUtil.setSqlArgs(params, values);
103
104        return sql;
105    }
106
107    /**
108     * deleteById 的 SQL 构建。
109     *
110     * @param params 方法参数
111     * @return SQL 语句
112     * @see RowMapper#deleteById(String, String, String, Object)
113     */
114    public static String deleteById(Map params) {
115        Object[] primaryValues = ProviderUtil.getPrimaryValues(params);
116
117        FlexAssert.notEmpty(primaryValues, "primaryValues");
118
119        String schema = ProviderUtil.getSchemaName(params);
120        String tableName = ProviderUtil.getTableName(params);
121        String[] primaryKeys = ProviderUtil.getPrimaryKeys(params);
122
123        String sql = DialectFactory.getDialect().forDeleteById(schema, tableName, primaryKeys);
124        ProviderUtil.setSqlArgs(params, primaryValues);
125        return sql;
126    }
127
128    /**
129     * deleteBatchByIds 的 SQL 构建。
130     *
131     * @param params 方法参数
132     * @return SQL 语句
133     * @see RowMapper#deleteBatchByIds(String, String, String, Collection)
134     */
135    public static String deleteBatchByIds(Map params) {
136        String schema = ProviderUtil.getSchemaName(params);
137        String tableName = ProviderUtil.getTableName(params);
138        String[] primaryKeys = ProviderUtil.getPrimaryKeys(params);
139        Object[] primaryValues = ProviderUtil.getPrimaryValues(params);
140
141        String sql = DialectFactory.getDialect().forDeleteBatchByIds(schema, tableName, primaryKeys, primaryValues);
142        ProviderUtil.setSqlArgs(params, primaryValues);
143        return sql;
144    }
145
146    /**
147     * deleteByQuery 的 SQL 构建。
148     *
149     * @param params 方法参数
150     * @return SQL 语句
151     * @see RowMapper#deleteByQuery(String, String, QueryWrapper)
152     */
153    public static String deleteByQuery(Map params) {
154        String schema = ProviderUtil.getSchemaName(params);
155        String tableName = ProviderUtil.getTableName(params);
156        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
157        CPI.setFromIfNecessary(queryWrapper, schema, tableName);
158
159        //优先构建 sql,再构建参数
160        String sql = DialectFactory.getDialect().forDeleteByQuery(queryWrapper);
161        Object[] valueArray = CPI.getValueArray(queryWrapper);
162        ProviderUtil.setSqlArgs(params, valueArray);
163
164        return sql;
165    }
166
167    /**
168     * updateById 的 SQL 构建。
169     *
170     * @param params 方法参数
171     * @return SQL 语句
172     * @see RowMapper#updateById(String, String, Row)
173     */
174    public static String updateById(Map params) {
175        String schema = ProviderUtil.getSchemaName(params);
176        String tableName = ProviderUtil.getTableName(params);
177        Row row = ProviderUtil.getRow(params);
178        String sql = DialectFactory.getDialect().forUpdateById(schema, tableName, row);
179        ProviderUtil.setSqlArgs(params, RowCPI.obtainAllModifyValues(row));
180        return sql;
181    }
182
183    /**
184     * updateByQuery 的 SQL 构建。
185     *
186     * @param params 方法参数
187     * @return SQL 语句
188     * @see RowMapper#updateByQuery(String, String, Row, QueryWrapper)
189     */
190    public static String updateByQuery(Map params) {
191        String schema = ProviderUtil.getSchemaName(params);
192        String tableName = ProviderUtil.getTableName(params);
193        Row data = ProviderUtil.getRow(params);
194
195        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
196        CPI.setFromIfNecessary(queryWrapper, schema, tableName);
197
198        //优先构建 sql,再构建参数
199        String sql = DialectFactory.getDialect().forUpdateByQuery(queryWrapper, data);
200
201        Object[] modifyValues = RowCPI.obtainModifyValues(data);
202        Object[] valueArray = CPI.getValueArray(queryWrapper);
203
204        ProviderUtil.setSqlArgs(params, ArrayUtil.concat(modifyValues, valueArray));
205
206        return sql;
207    }
208
209    /**
210     * updateBatchById 的 SQL 构建。
211     * mysql 等链接配置需要开启 allowMultiQueries=true
212     *
213     * @param params 方法参数
214     * @return SQL 语句
215     * @see RowMapper#updateBatchById(String, String, List)
216     */
217    public static String updateBatchById(Map params) {
218        List<Row> rows = ProviderUtil.getRows(params);
219
220        FlexAssert.notEmpty(rows, "rows");
221
222        String schema = ProviderUtil.getSchemaName(params);
223        String tableName = ProviderUtil.getTableName(params);
224
225        String sql = DialectFactory.getDialect().forUpdateBatchById(schema, tableName, rows);
226
227        Object[] values = FlexConsts.EMPTY_ARRAY;
228        for (Row row : rows) {
229            values = ArrayUtil.concat(values, RowCPI.obtainAllModifyValues(row));
230        }
231        ProviderUtil.setSqlArgs(params, values);
232        return sql;
233    }
234
235    /**
236     * updateEntity 的 SQL 构建。
237     *
238     * @param params 方法参数
239     * @return SQL 语句
240     * @see RowMapper#updateEntity(Object entities)
241     */
242    public static String updateEntity(Map params) {
243        Object entity = ProviderUtil.getEntity(params);
244
245        FlexAssert.notNull(entity, "entity can not be null");
246
247        // 该 Mapper 是通用 Mapper  无法通过 ProviderContext 获取,直接使用 TableInfoFactory
248        TableInfo tableInfo = TableInfoFactory.ofEntityClass(entity.getClass());
249
250        // 执行 onUpdate 监听器
251        tableInfo.invokeOnUpdateListener(entity);
252
253        String sql = DialectFactory.getDialect().forUpdateEntity(tableInfo, entity, false);
254
255        Object[] updateValues = tableInfo.buildUpdateSqlArgs(entity, false, false);
256        Object[] primaryValues = tableInfo.buildPkSqlArgs(entity);
257        Object[] tenantIdArgs = tableInfo.buildTenantIdArgs();
258
259        FlexAssert.assertAreNotNull(primaryValues, "The value of primary key must not be null, entity[%s]", entity);
260
261        ProviderUtil.setSqlArgs(params, ArrayUtil.concat(updateValues, primaryValues, tenantIdArgs));
262        return sql;
263    }
264
265
266    /**
267     * selectOneById 的 SQL 构建。
268     *
269     * @param params 方法参数
270     * @return SQL 语句
271     * @see RowMapper#selectOneById(String, String, String, Object)
272     */
273    public static String selectOneById(Map params) {
274        String schema = ProviderUtil.getSchemaName(params);
275        String tableName = ProviderUtil.getTableName(params);
276        String[] primaryKeys = ProviderUtil.getPrimaryKeys(params);
277        Object[] primaryValues = ProviderUtil.getPrimaryValues(params);
278
279        String sql = DialectFactory.getDialect().forSelectOneById(schema, tableName, primaryKeys, primaryValues);
280        ProviderUtil.setSqlArgs(params, primaryValues);
281
282        return sql;
283    }
284
285    /**
286     * selectListByQuery 的 SQL 构建。
287     *
288     * @param params 方法参数
289     * @return SQL 语句
290     * @see RowMapper#selectListByQuery(String, String, QueryWrapper)
291     */
292    public static String selectListByQuery(Map params) {
293        String schema = ProviderUtil.getSchemaName(params);
294        String tableName = ProviderUtil.getTableName(params);
295
296        QueryWrapper queryWrapper = ProviderUtil.getQueryWrapper(params);
297        CPI.setFromIfNecessary(queryWrapper, schema, tableName);
298
299        //优先构建 sql,再构建参数
300        String sql = DialectFactory.getDialect().forSelectByQuery(queryWrapper);
301
302        Object[] valueArray = CPI.getValueArray(queryWrapper);
303        ProviderUtil.setSqlArgs(params, valueArray);
304
305        return sql;
306    }
307
308}