001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
018import com.mybatisflex.core.BaseMapper;
019import com.mybatisflex.core.FlexGlobalConfig;
020import com.mybatisflex.core.constant.SqlConsts;
021import com.mybatisflex.core.dialect.DbType;
022import com.mybatisflex.core.dialect.DialectFactory;
023import com.mybatisflex.core.exception.FlexExceptions;
024import com.mybatisflex.core.field.FieldQuery;
025import com.mybatisflex.core.field.FieldQueryBuilder;
026import com.mybatisflex.core.field.FieldQueryManager;
027import com.mybatisflex.core.paginate.Page;
028import com.mybatisflex.core.query.*;
029import com.mybatisflex.core.relation.RelationManager;
030import org.apache.ibatis.exceptions.TooManyResultsException;
031import org.apache.ibatis.session.defaults.DefaultSqlSession;
032
033import java.util.*;
034import java.util.function.Consumer;
035
036import static com.mybatisflex.core.query.QueryMethods.count;
037
038public class MapperUtil {
039
    /**
     * Not instantiable: this class only exposes static helper methods.
     */
    private MapperUtil() {
    }
042
043
044    /**
045     * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
046     * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
047     *
048     * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
049     *
050     * <p><pre>
051     * {@code
052     * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`;
053     * }
054     * </pre>
055     *
056     * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
057     */
058    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
059        return QueryWrapper.create()
060            .select(count().as("total"))
061            .from(queryWrapper).as("t");
062    }
063
064    /**
065     * 优化 COUNT 查询语句。
066     */
067    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
068        // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象
069        QueryWrapper clone = queryWrapper.clone();
070        // 将最后面的 order by 移除掉
071        CPI.setOrderBys(clone, null);
072        // 获取查询列和分组列,用于判断是否进行优化
073        List<QueryColumn> selectColumns = CPI.getSelectColumns(clone);
074        List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone);
075        // 如果有 distinct 语句或者 group by 语句则不优化
076        // 这种一旦优化了就会造成 count 语句查询出来的值不对
077        if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) {
078            return rawCountQueryWrapper(clone);
079        }
080        // 判断能不能清除 join 语句
081        if (canClearJoins(clone)) {
082            CPI.setJoins(clone, null);
083        }
084        // 将 select 里面的列换成 COUNT(*) AS `total`
085        CPI.setSelectColumns(clone, Collections.singletonList(count().as("total")));
086        return clone;
087    }
088
089    public static boolean hasDistinct(List<QueryColumn> selectColumns) {
090        if (CollectionUtil.isEmpty(selectColumns)) {
091            return false;
092        }
093        for (QueryColumn selectColumn : selectColumns) {
094            if (selectColumn instanceof DistinctQueryColumn) {
095                return true;
096            }
097        }
098        return false;
099    }
100
101    private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
102        return CollectionUtil.isNotEmpty(groupByColumns);
103    }
104
105    private static boolean canClearJoins(QueryWrapper queryWrapper) {
106        List<Join> joins = CPI.getJoins(queryWrapper);
107        if (CollectionUtil.isEmpty(joins)) {
108            return false;
109        }
110
111        // 只有全是 left join 语句才会清除 join
112        // 因为如果是 inner join 或 right join 往往都会放大记录数
113        for (Join join : joins) {
114            if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) {
115                return false;
116            }
117        }
118
119        // 获取 join 语句中使用到的表名
120        List<String> joinTables = new ArrayList<>();
121        joins.forEach(join -> {
122            QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
123            if (joinQueryTable != null) {
124                String tableName = joinQueryTable.getName();
125                if (StringUtil.isNotBlank(joinQueryTable.getAlias())) {
126                    joinTables.add(tableName + "." + joinQueryTable.getAlias());
127                } else {
128                    joinTables.add(tableName);
129                }
130            }
131        });
132
133        // 获取 where 语句中的条件
134        QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
135
136        // 最后判断一下 where 中是否用到了 join 的表
137        return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
138    }
139
140    @SafeVarargs
141    public static <T, R> Page<R> doPaginate(BaseMapper<T> mapper, Page<R> page, QueryWrapper queryWrapper, Class<R> asType, boolean withRelations, Consumer<FieldQueryBuilder<R>>... consumers) {
142        try {
143            // 只有 totalRow 小于 0 的时候才会去查询总量
144            // 这样方便用户做总数缓存,而非每次都要去查询总量
145            // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的
146            if (page.getTotalRow() < 0) {
147                QueryWrapper countQueryWrapper;
148                if (page.needOptimizeCountQuery()) {
149                    countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper);
150                } else {
151                    countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper);
152                }
153                page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper));
154            }
155
156            if (!page.hasRecords()) {
157                if (withRelations) {
158                    RelationManager.clearConfigIfNecessary();
159                }
160                return page;
161            }
162
163            queryWrapper.limit(page.offset(), page.getPageSize());
164
165            List<R> records;
166            if (asType != null) {
167                records = mapper.selectListByQueryAs(queryWrapper, asType);
168            } else {
169                // noinspection unchecked
170                records = (List<R>) mapper.selectListByQuery(queryWrapper);
171            }
172
173            if (withRelations) {
174                queryRelations(mapper, records);
175            }
176
177            queryFields(mapper, records, consumers);
178            page.setRecords(records);
179
180            return page;
181
182        } finally {
183            // 将之前设置的 limit 清除掉
184            // 保险起见把重置代码放到 finally 代码块中
185            CPI.setLimitRows(queryWrapper, null);
186            CPI.setLimitOffset(queryWrapper, null);
187        }
188    }
189
190
191    public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
192        if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {
193            return;
194        }
195
196        Map<String, FieldQuery> fieldQueryMap = new HashMap<>();
197        for (Consumer<FieldQueryBuilder<R>> consumer : consumers) {
198            FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>();
199            consumer.accept(fieldQueryBuilder);
200
201            FieldQuery fieldQuery = fieldQueryBuilder.build();
202
203            String className = fieldQuery.getEntityClass().getName();
204            String fieldName = fieldQuery.getFieldName();
205            String mapKey = className + '#' + fieldName;
206
207            fieldQueryMap.put(mapKey, fieldQuery);
208        }
209
210        FieldQueryManager.queryFields(mapper, list, fieldQueryMap);
211    }
212
213
214    public static <E> E queryRelations(BaseMapper<?> mapper, E entity) {
215        if (entity != null) {
216            queryRelations(mapper, Collections.singletonList(entity));
217        } else {
218            RelationManager.clearConfigIfNecessary();
219        }
220        return entity;
221    }
222
223    public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) {
224        RelationManager.queryRelations(mapper, entities);
225        return entities;
226    }
227
228
229    public static Class<? extends Collection> getCollectionWrapType(Class<?> type) {
230        if (ClassUtil.canInstance(type.getModifiers())) {
231            return (Class<? extends Collection>) type;
232        }
233
234        if (List.class.isAssignableFrom(type)) {
235            return ArrayList.class;
236        }
237
238        if (Set.class.isAssignableFrom(type)) {
239            return HashSet.class;
240        }
241
242        throw new IllegalStateException("Field query can not support type: " + type.getName());
243    }
244
245
246    /**
247     * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)}
248     */
249    public static <T> T getSelectOneResult(List<T> list) {
250        if (list == null || list.isEmpty()) {
251            return null;
252        }
253        int size = list.size();
254        if (size == 1) {
255            return list.get(0);
256        }
257        throw new TooManyResultsException(
258            "Expected one result (or null) to be returned by selectOne(), but found: " + size);
259    }
260
261    public static long getLongNumber(List<Object> objects) {
262        Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
263        if (object == null) {
264            return 0;
265        } else if (object instanceof Number) {
266            return ((Number) object).longValue();
267        } else {
268            throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\"");
269        }
270    }
271
272
273    public static Map<String, Object> preparedParams(Page<?> page, QueryWrapper queryWrapper, Map<String, Object> params) {
274        Map<String, Object> newParams = new HashMap<>();
275
276        if (params != null) {
277            newParams.putAll(params);
278        }
279
280        newParams.put("pageOffset", page.offset());
281        newParams.put("pageNumber", page.getPageNumber());
282        newParams.put("pageSize", page.getPageSize());
283
284        DbType dbType = DialectFactory.getHintDbType();
285        newParams.put("dbType", dbType != null ? dbType : FlexGlobalConfig.getDefaultConfig().getDbType());
286
287        if (queryWrapper != null) {
288            preparedQueryWrapper(newParams, queryWrapper);
289        }
290
291        return newParams;
292    }
293
294
295    private static void preparedQueryWrapper(Map<String, Object> params, QueryWrapper queryWrapper) {
296        String sql = DialectFactory.getDialect().buildNoSelectSql(queryWrapper);
297        StringBuilder sqlBuilder = new StringBuilder();
298        char quote = 0;
299        int index = 0;
300        for (int i = 0; i < sql.length(); ++i) {
301            char ch = sql.charAt(i);
302            if (ch == '\'') {
303                if (quote == 0) {
304                    quote = ch;
305                } else if (quote == '\'') {
306                    quote = 0;
307                }
308            } else if (ch == '"') {
309                if (quote == 0) {
310                    quote = ch;
311                } else if (quote == '"') {
312                    quote = 0;
313                }
314            }
315            if (quote == 0 && ch == '?') {
316                sqlBuilder.append("#{qwParams_").append(index++).append("}");
317            } else {
318                sqlBuilder.append(ch);
319            }
320        }
321        params.put("qwSql", sqlBuilder.toString());
322        Object[] valueArray = CPI.getValueArray(queryWrapper);
323        for (int i = 0; i < valueArray.length; i++) {
324            params.put("qwParams_" + i, valueArray[i]);
325        }
326    }
327
328}