001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
018import com.mybatisflex.core.BaseMapper;
019import com.mybatisflex.core.FlexGlobalConfig;
020import com.mybatisflex.core.constant.SqlConsts;
021import com.mybatisflex.core.dialect.DbType;
022import com.mybatisflex.core.dialect.DialectFactory;
023import com.mybatisflex.core.exception.FlexExceptions;
024import com.mybatisflex.core.field.FieldQuery;
025import com.mybatisflex.core.field.FieldQueryBuilder;
026import com.mybatisflex.core.field.FieldQueryManager;
027import com.mybatisflex.core.paginate.Page;
028import com.mybatisflex.core.query.CPI;
029import com.mybatisflex.core.query.DistinctQueryColumn;
030import com.mybatisflex.core.query.Join;
031import com.mybatisflex.core.query.QueryColumn;
032import com.mybatisflex.core.query.QueryCondition;
033import com.mybatisflex.core.query.QueryTable;
034import com.mybatisflex.core.query.QueryWrapper;
035import com.mybatisflex.core.relation.RelationManager;
036import com.mybatisflex.core.table.TableInfo;
037import com.mybatisflex.core.table.TableInfoFactory;
038import org.apache.ibatis.exceptions.TooManyResultsException;
039import org.apache.ibatis.session.defaults.DefaultSqlSession;
040
041import java.util.ArrayList;
042import java.util.Collection;
043import java.util.Collections;
044import java.util.HashMap;
045import java.util.HashSet;
046import java.util.List;
047import java.util.Map;
048import java.util.Set;
049import java.util.function.Consumer;
050
051import static com.mybatisflex.core.query.QueryMethods.count;
052
053public class MapperUtil {
054
055    private MapperUtil() {
056    }
057
058
059    /**
060     * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
061     * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
062     *
063     * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
064     *
065     * <p><pre>
066     * {@code
067     * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`;
068     * }
069     * </pre>
070     *
071     * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
072     */
073    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
074        return QueryWrapper.create()
075            .select(count().as("total"))
076            .from(queryWrapper).as("t");
077    }
078
079    /**
080     * 优化 COUNT 查询语句。
081     */
082    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
083        // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象
084        QueryWrapper clone = queryWrapper.clone();
085        // 将最后面的 order by 移除掉
086        CPI.setOrderBys(clone, null);
087        // 获取查询列和分组列,用于判断是否进行优化
088        List<QueryColumn> selectColumns = CPI.getSelectColumns(clone);
089        List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone);
090        // 如果有 distinct 语句或者 group by 语句则不优化
091        // 这种一旦优化了就会造成 count 语句查询出来的值不对
092        if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) {
093            return rawCountQueryWrapper(clone);
094        }
095        // 判断能不能清除 join 语句
096        if (canClearJoins(clone)) {
097            CPI.setJoins(clone, null);
098        }
099        // 将 select 里面的列换成 COUNT(*) AS `total`
100        CPI.setSelectColumns(clone, Collections.singletonList(count().as("total")));
101        return clone;
102    }
103
104    public static boolean hasDistinct(List<QueryColumn> selectColumns) {
105        if (CollectionUtil.isEmpty(selectColumns)) {
106            return false;
107        }
108        for (QueryColumn selectColumn : selectColumns) {
109            if (selectColumn instanceof DistinctQueryColumn) {
110                return true;
111            }
112        }
113        return false;
114    }
115
116    private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
117        return CollectionUtil.isNotEmpty(groupByColumns);
118    }
119
120    private static boolean canClearJoins(QueryWrapper queryWrapper) {
121        List<Join> joins = CPI.getJoins(queryWrapper);
122        if (CollectionUtil.isEmpty(joins)) {
123            return false;
124        }
125
126        // 只有全是 left join 语句才会清除 join
127        // 因为如果是 inner join 或 right join 往往都会放大记录数
128        for (Join join : joins) {
129            if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) {
130                return false;
131            }
132        }
133
134        // 获取 join 语句中使用到的表名
135        List<String> joinTables = new ArrayList<>();
136        joins.forEach(join -> {
137            QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
138            if (joinQueryTable != null) {
139                String tableName = joinQueryTable.getName();
140                if (StringUtil.isNotBlank(joinQueryTable.getAlias())) {
141                    joinTables.add(tableName + "." + joinQueryTable.getAlias());
142                } else {
143                    joinTables.add(tableName);
144                }
145            }
146        });
147
148        // 获取 where 语句中的条件
149        QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
150
151        // 最后判断一下 where 中是否用到了 join 的表
152        return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
153    }
154
155    @SafeVarargs
156    public static <T, R> Page<R> doPaginate(
157        BaseMapper<T> mapper,
158        Page<R> page,
159        QueryWrapper queryWrapper,
160        Class<R> asType,
161        boolean withRelations,
162        Consumer<FieldQueryBuilder<R>>... consumers
163    ) {
164        Long limitRows = CPI.getLimitRows(queryWrapper);
165        Long limitOffset = CPI.getLimitOffset(queryWrapper);
166        try {
167            // 只有 totalRow 小于 0 的时候才会去查询总量
168            // 这样方便用户做总数缓存,而非每次都要去查询总量
169            // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的
170
171            if (page.getTotalRow() < 0) {
172
173                QueryWrapper countQueryWrapper;
174
175                if (page.needOptimizeCountQuery()) {
176                    countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper);
177                } else {
178                    countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper);
179                }
180
181                // optimize: 在 count 之前先去掉 limit 参数,避免 count 查询错误
182                CPI.setLimitRows(countQueryWrapper, null);
183                CPI.setLimitOffset(countQueryWrapper, null);
184
185                page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper));
186            }
187
188            if (!page.hasRecords()) {
189                if (withRelations) {
190                    RelationManager.clearConfigIfNecessary();
191                }
192                return page;
193            }
194
195            queryWrapper.limit(page.offset(), page.getPageSize());
196
197            List<R> records;
198            if (asType != null) {
199                records = mapper.selectListByQueryAs(queryWrapper, asType);
200            } else {
201                // noinspection unchecked
202                records = (List<R>) mapper.selectListByQuery(queryWrapper);
203            }
204
205            if (withRelations) {
206                queryRelations(mapper, records);
207            }
208
209            queryFields(mapper, records, consumers);
210            page.setRecords(records);
211
212            return page;
213
214        } finally {
215            // 将之前设置的 limit 清除掉
216            // 保险起见把重置代码放到 finally 代码块中
217            CPI.setLimitRows(queryWrapper, limitRows);
218            CPI.setLimitOffset(queryWrapper, limitOffset);
219        }
220    }
221
222
223    public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
224        if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {
225            return;
226        }
227
228        Map<String, FieldQuery> fieldQueryMap = new HashMap<>();
229        for (Consumer<FieldQueryBuilder<R>> consumer : consumers) {
230            FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>();
231            consumer.accept(fieldQueryBuilder);
232
233            FieldQuery fieldQuery = fieldQueryBuilder.build();
234
235            String className = fieldQuery.getEntityClass().getName();
236            String fieldName = fieldQuery.getFieldName();
237            String mapKey = className + '#' + fieldName;
238
239            fieldQueryMap.put(mapKey, fieldQuery);
240        }
241
242        FieldQueryManager.queryFields(mapper, list, fieldQueryMap);
243    }
244
245
246    public static <E> E queryRelations(BaseMapper<?> mapper, E entity) {
247        if (entity != null) {
248            queryRelations(mapper, Collections.singletonList(entity));
249        } else {
250            RelationManager.clearConfigIfNecessary();
251        }
252        return entity;
253    }
254
255    public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) {
256        RelationManager.queryRelations(mapper, entities);
257        return entities;
258    }
259
260
261    public static Class<? extends Collection> getCollectionWrapType(Class<?> type) {
262        if (ClassUtil.canInstance(type.getModifiers())) {
263            return (Class<? extends Collection>) type;
264        }
265
266        if (List.class.isAssignableFrom(type)) {
267            return ArrayList.class;
268        }
269
270        if (Set.class.isAssignableFrom(type)) {
271            return HashSet.class;
272        }
273
274        throw new IllegalStateException("Field query can not support type: " + type.getName());
275    }
276
277
278    /**
279     * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)}
280     */
281    public static <T> T getSelectOneResult(List<T> list) {
282        if (list == null || list.isEmpty()) {
283            return null;
284        }
285        int size = list.size();
286        if (size == 1) {
287            return list.get(0);
288        }
289        throw new TooManyResultsException(
290            "Expected one result (or null) to be returned by selectOne(), but found: " + size);
291    }
292
293    public static long getLongNumber(List<Object> objects) {
294        Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
295        if (object == null) {
296            return 0;
297        } else if (object instanceof Number) {
298            return ((Number) object).longValue();
299        } else {
300            throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\"");
301        }
302    }
303
304
305    public static Map<String, Object> preparedParams(BaseMapper<?> baseMapper, Page<?> page, QueryWrapper queryWrapper, Map<String, Object> params) {
306        Map<String, Object> newParams = new HashMap<>();
307
308        if (params != null) {
309            newParams.putAll(params);
310        }
311
312        newParams.put("pageOffset", page.offset());
313        newParams.put("pageNumber", page.getPageNumber());
314        newParams.put("pageSize", page.getPageSize());
315
316        DbType dbType = DialectFactory.getHintDbType();
317        newParams.put("dbType", dbType != null ? dbType : FlexGlobalConfig.getDefaultConfig().getDbType());
318
319        if (queryWrapper != null) {
320            TableInfo tableInfo = TableInfoFactory.ofMapperClass(baseMapper.getClass());
321            tableInfo.appendConditions(null, queryWrapper);
322            preparedQueryWrapper(newParams, queryWrapper);
323        }
324
325        return newParams;
326    }
327
328
329    private static void preparedQueryWrapper(Map<String, Object> params, QueryWrapper queryWrapper) {
330        String sql = DialectFactory.getDialect().buildNoSelectSql(queryWrapper);
331        StringBuilder sqlBuilder = new StringBuilder();
332        char quote = 0;
333        int index = 0;
334        for (int i = 0; i < sql.length(); ++i) {
335            char ch = sql.charAt(i);
336            if (ch == '\'') {
337                if (quote == 0) {
338                    quote = ch;
339                } else if (quote == '\'') {
340                    quote = 0;
341                }
342            } else if (ch == '"') {
343                if (quote == 0) {
344                    quote = ch;
345                } else if (quote == '"') {
346                    quote = 0;
347                }
348            }
349            if (quote == 0 && ch == '?') {
350                sqlBuilder.append("#{qwParams_").append(index++).append("}");
351            } else {
352                sqlBuilder.append(ch);
353            }
354        }
355        params.put("qwSql", sqlBuilder.toString());
356        Object[] valueArray = CPI.getValueArray(queryWrapper);
357        for (int i = 0; i < valueArray.length; i++) {
358            params.put("qwParams_" + i, valueArray[i]);
359        }
360    }
361
362}