001/*
002 *  Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
018import com.mybatisflex.core.BaseMapper;
019import com.mybatisflex.core.FlexGlobalConfig;
020import com.mybatisflex.core.constant.SqlConsts;
021import com.mybatisflex.core.dialect.DbType;
022import com.mybatisflex.core.dialect.DialectFactory;
023import com.mybatisflex.core.exception.FlexExceptions;
024import com.mybatisflex.core.field.FieldQuery;
025import com.mybatisflex.core.field.FieldQueryBuilder;
026import com.mybatisflex.core.field.FieldQueryManager;
027import com.mybatisflex.core.paginate.Page;
028import com.mybatisflex.core.query.*;
029import com.mybatisflex.core.relation.RelationManager;
030import com.mybatisflex.core.table.TableInfo;
031import com.mybatisflex.core.table.TableInfoFactory;
032import org.apache.ibatis.exceptions.TooManyResultsException;
033import org.apache.ibatis.session.defaults.DefaultSqlSession;
034
035import java.util.ArrayList;
036import java.util.Collection;
037import java.util.Collections;
038import java.util.HashMap;
039import java.util.HashSet;
040import java.util.List;
041import java.util.Map;
042import java.util.Set;
043import java.util.function.Consumer;
044
045import static com.mybatisflex.core.query.QueryMethods.count;
046
047public class MapperUtil {
048
049    private MapperUtil() {
050    }
051
052
053    /**
054     * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
055     * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
056     *
057     * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
058     *
059     * <p><pre>
060     * {@code
061     * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`;
062     * }
063     * </pre>
064     *
065     * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
066     */
067    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
068        return QueryWrapper.create()
069            .select(count().as("total"))
070            .from(queryWrapper).as("t");
071    }
072
073    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper, List<QueryColumn> customCountColumns) {
074        return customCountColumns != null ? QueryWrapper.create()
075            .select(customCountColumns)
076            .from(queryWrapper).as("t") : rawCountQueryWrapper(queryWrapper);
077    }
078
079    /**
080     * 优化 COUNT 查询语句。
081     */
082    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
083        return optimizeCountQueryWrapper(queryWrapper, Collections.singletonList(count().as("total")));
084    }
085
086    /**
087     * 优化 COUNT 查询语句。
088     */
089    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper, List<QueryColumn> customCountColumns) {
090        // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象
091        QueryWrapper clone = queryWrapper.clone();
092
093        List<UnionWrapper> unions = CPI.getUnions(clone);
094        if (!CollectionUtil.isEmpty(unions)) {
095            List<UnionWrapper> newUnions = new ArrayList<>(unions.size());
096            for (UnionWrapper union : unions) {
097                QueryWrapper unionQuery = optimizeCountQueryWrapper(union.getQueryWrapper().clone(), null);
098                UnionWrapper clone1 = union.clone();
099                clone1.setQueryWrapper(unionQuery);
100                newUnions.add(clone1);
101            }
102            CPI.setUnions(clone, newUnions);
103        }
104
105        // 将最后面的 order by 移除掉
106        CPI.setOrderBys(clone, null);
107        // 获取查询列和分组列,用于判断是否进行优化
108        List<QueryColumn> selectColumns = CPI.getSelectColumns(clone);
109        List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone);
110        QueryCondition havingCondition = CPI.getHavingQueryCondition(clone);
111        // 如果有 distinct、group by、having 等语句则不优化
112        // 这种一旦优化了就会造成 count 语句查询出来的值不对
113        if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns) || havingCondition != null) {
114            return rawCountQueryWrapper(clone);
115        }
116        // 判断能不能清除 join 语句
117        if (canClearJoins(clone)) {
118            CPI.setJoins(clone, null);
119        }
120        // 将 select 里面的列换成 COUNT(*) AS `total`
121        if (customCountColumns != null) {
122            if (hasUnion(clone)) {
123                return rawCountQueryWrapper(clone, customCountColumns);
124            } else {
125                CPI.setSelectColumns(clone, customCountColumns);
126            }
127        }
128        return clone;
129    }
130
131    public static boolean hasDistinct(List<QueryColumn> selectColumns) {
132        if (CollectionUtil.isEmpty(selectColumns)) {
133            return false;
134        }
135        for (QueryColumn selectColumn : selectColumns) {
136            if (selectColumn instanceof DistinctQueryColumn) {
137                return true;
138            }
139        }
140        return false;
141    }
142
143    private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
144        return CollectionUtil.isNotEmpty(groupByColumns);
145    }
146
147    private static boolean hasUnion(QueryWrapper countQueryWrapper) {
148        return CollectionUtil.isNotEmpty(CPI.getUnions(countQueryWrapper));
149    }
150
151    private static boolean canClearJoins(QueryWrapper queryWrapper) {
152        List<Join> joins = CPI.getJoins(queryWrapper);
153        if (CollectionUtil.isEmpty(joins)) {
154            return false;
155        }
156
157        // 只有全是 left join 语句才会清除 join
158        // 因为如果是 inner join 或 right join 往往都会放大记录数
159        for (Join join : joins) {
160            if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) {
161                return false;
162            }
163        }
164
165        // 获取 join 语句中使用到的表名
166        List<String> joinTables = new ArrayList<>();
167        joins.forEach(join -> {
168            QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
169            if (joinQueryTable != null) {
170                String tableName = joinQueryTable.getName();
171                if (StringUtil.hasText(joinQueryTable.getAlias())) {
172                    joinTables.add(tableName + "." + joinQueryTable.getAlias());
173                } else {
174                    joinTables.add(tableName);
175                }
176            }
177        });
178
179        // 获取 where 语句中的条件
180        QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
181
182        // 最后判断一下 where 中是否用到了 join 的表
183        return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
184    }
185
186    @SafeVarargs
187    public static <T, R> Page<R> doPaginate(
188        BaseMapper<T> mapper,
189        Page<R> page,
190        QueryWrapper queryWrapper,
191        Class<R> asType,
192        boolean withRelations,
193        Consumer<FieldQueryBuilder<R>>... consumers
194    ) {
195        Long limitRows = CPI.getLimitRows(queryWrapper);
196        Long limitOffset = CPI.getLimitOffset(queryWrapper);
197        try {
198            // 只有 totalRow 小于 0 的时候才会去查询总量
199            // 这样方便用户做总数缓存,而非每次都要去查询总量
200            // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的
201
202            if (page.getTotalRow() < 0) {
203
204                QueryWrapper countQueryWrapper;
205
206                if (page.needOptimizeCountQuery()) {
207                    countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper);
208                } else {
209                    countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper);
210                }
211
212                // optimize: 在 count 之前先去掉 limit 参数,避免 count 查询错误
213                CPI.setLimitRows(countQueryWrapper, null);
214                CPI.setLimitOffset(countQueryWrapper, null);
215
216                page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper));
217            }
218
219            if (!page.hasRecords()) {
220                if (withRelations) {
221                    RelationManager.clearConfigIfNecessary();
222                }
223                return page;
224            }
225
226            queryWrapper.limit(page.offset(), page.getPageSize());
227
228            List<R> records;
229            if (asType != null) {
230                records = mapper.selectListByQueryAs(queryWrapper, asType);
231            } else {
232                // noinspection unchecked
233                records = (List<R>) mapper.selectListByQuery(queryWrapper);
234            }
235
236            if (withRelations) {
237                queryRelations(mapper, records);
238            }
239
240            queryFields(mapper, records, consumers);
241            page.setRecords(records);
242
243            return page;
244
245        } finally {
246            // 将之前设置的 limit 清除掉
247            // 保险起见把重置代码放到 finally 代码块中
248            CPI.setLimitRows(queryWrapper, limitRows);
249            CPI.setLimitOffset(queryWrapper, limitOffset);
250        }
251    }
252
253
254    public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
255        if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {
256            return;
257        }
258
259        Map<String, FieldQuery> fieldQueryMap = new HashMap<>();
260        for (Consumer<FieldQueryBuilder<R>> consumer : consumers) {
261            FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>();
262            consumer.accept(fieldQueryBuilder);
263
264            FieldQuery fieldQuery = fieldQueryBuilder.build();
265
266            String className = fieldQuery.getEntityClass().getName();
267            String fieldName = fieldQuery.getFieldName();
268            String mapKey = className + '#' + fieldName;
269
270            fieldQueryMap.put(mapKey, fieldQuery);
271        }
272
273        FieldQueryManager.queryFields(mapper, list, fieldQueryMap);
274    }
275
276
277    public static <E> E queryRelations(BaseMapper<?> mapper, E entity) {
278        if (entity != null) {
279            queryRelations(mapper, Collections.singletonList(entity));
280        } else {
281            RelationManager.clearConfigIfNecessary();
282        }
283        return entity;
284    }
285
286    public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) {
287        RelationManager.queryRelations(mapper, entities);
288        return entities;
289    }
290
291
292    public static Class<? extends Collection> getCollectionWrapType(Class<?> type) {
293        if (ClassUtil.canInstance(type.getModifiers())) {
294            return (Class<? extends Collection>) type;
295        }
296
297        if (List.class.isAssignableFrom(type)) {
298            return ArrayList.class;
299        }
300
301        if (Set.class.isAssignableFrom(type)) {
302            return HashSet.class;
303        }
304
305        throw new IllegalStateException("Field query can not support type: " + type.getName());
306    }
307
308
309    /**
310     * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)}
311     */
312    public static <T> T getSelectOneResult(List<T> list) {
313        if (list == null || list.isEmpty()) {
314            return null;
315        }
316        int size = list.size();
317        if (size == 1) {
318            return list.get(0);
319        }
320        throw new TooManyResultsException(
321            "Expected one result (or null) to be returned by selectOne(), but found: " + size);
322    }
323
324    public static long getLongNumber(List<Object> objects) {
325        Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
326        if (object == null) {
327            return 0;
328        } else if (object instanceof Number) {
329            return ((Number) object).longValue();
330        } else {
331            throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\"");
332        }
333    }
334
335
336    public static Map<String, Object> preparedParams(BaseMapper<?> baseMapper, Page<?> page, QueryWrapper queryWrapper, Map<String, Object> params) {
337        Map<String, Object> newParams = new HashMap<>();
338
339        if (params != null) {
340            newParams.putAll(params);
341        }
342
343        newParams.put("pageOffset", page.offset());
344        newParams.put("pageNumber", page.getPageNumber());
345        newParams.put("pageSize", page.getPageSize());
346
347        DbType dbType = DialectFactory.getHintDbType();
348        newParams.put("dbType", dbType != null ? dbType : FlexGlobalConfig.getDefaultConfig().getDbType());
349
350        if (queryWrapper != null) {
351            TableInfo tableInfo = TableInfoFactory.ofMapperClass(baseMapper.getClass());
352            tableInfo.appendConditions(null, queryWrapper);
353            preparedQueryWrapper(newParams, queryWrapper);
354        }
355
356        return newParams;
357    }
358
359
360    private static void preparedQueryWrapper(Map<String, Object> params, QueryWrapper queryWrapper) {
361        String sql = DialectFactory.getDialect().buildNoSelectSql(queryWrapper);
362        StringBuilder sqlBuilder = new StringBuilder();
363        char quote = 0;
364        int index = 0;
365        for (int i = 0; i < sql.length(); ++i) {
366            char ch = sql.charAt(i);
367            if (ch == '\'') {
368                if (quote == 0) {
369                    quote = ch;
370                } else if (quote == '\'') {
371                    quote = 0;
372                }
373            } else if (ch == '"') {
374                if (quote == 0) {
375                    quote = ch;
376                } else if (quote == '"') {
377                    quote = 0;
378                }
379            }
380            if (quote == 0 && ch == '?') {
381                sqlBuilder.append("#{qwParams_").append(index++).append("}");
382            } else {
383                sqlBuilder.append(ch);
384            }
385        }
386        params.put("qwSql", sqlBuilder.toString());
387        Object[] valueArray = CPI.getValueArray(queryWrapper);
388        for (int i = 0; i < valueArray.length; i++) {
389            params.put("qwParams_" + i, valueArray[i]);
390        }
391    }
392
393}