001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
018import com.mybatisflex.core.BaseMapper;
019import com.mybatisflex.core.FlexGlobalConfig;
020import com.mybatisflex.core.constant.SqlConsts;
021import com.mybatisflex.core.dialect.DbType;
022import com.mybatisflex.core.dialect.DialectFactory;
023import com.mybatisflex.core.exception.FlexExceptions;
024import com.mybatisflex.core.field.FieldQuery;
025import com.mybatisflex.core.field.FieldQueryBuilder;
026import com.mybatisflex.core.field.FieldQueryManager;
027import com.mybatisflex.core.paginate.Page;
028import com.mybatisflex.core.query.CPI;
029import com.mybatisflex.core.query.DistinctQueryColumn;
030import com.mybatisflex.core.query.Join;
031import com.mybatisflex.core.query.QueryColumn;
032import com.mybatisflex.core.query.QueryCondition;
033import com.mybatisflex.core.query.QueryTable;
034import com.mybatisflex.core.query.QueryWrapper;
035import com.mybatisflex.core.relation.RelationManager;
036import org.apache.ibatis.exceptions.TooManyResultsException;
037import org.apache.ibatis.session.defaults.DefaultSqlSession;
038
039import java.util.ArrayList;
040import java.util.Collection;
041import java.util.Collections;
042import java.util.HashMap;
043import java.util.HashSet;
044import java.util.List;
045import java.util.Map;
046import java.util.Set;
047import java.util.function.Consumer;
048
049import static com.mybatisflex.core.query.QueryMethods.count;
050
051public class MapperUtil {
052
053    private MapperUtil() {
054    }
055
056
057    /**
058     * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
059     * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
060     *
061     * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
062     *
063     * <p><pre>
064     * {@code
065     * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`;
066     * }
067     * </pre>
068     *
069     * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
070     */
071    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
072        return QueryWrapper.create()
073            .select(count().as("total"))
074            .from(queryWrapper).as("t");
075    }
076
077    /**
078     * 优化 COUNT 查询语句。
079     */
080    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
081        // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象
082        QueryWrapper clone = queryWrapper.clone();
083        // 将最后面的 order by 移除掉
084        CPI.setOrderBys(clone, null);
085        // 获取查询列和分组列,用于判断是否进行优化
086        List<QueryColumn> selectColumns = CPI.getSelectColumns(clone);
087        List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone);
088        // 如果有 distinct 语句或者 group by 语句则不优化
089        // 这种一旦优化了就会造成 count 语句查询出来的值不对
090        if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) {
091            return rawCountQueryWrapper(clone);
092        }
093        // 判断能不能清除 join 语句
094        if (canClearJoins(clone)) {
095            CPI.setJoins(clone, null);
096        }
097        // 将 select 里面的列换成 COUNT(*) AS `total`
098        CPI.setSelectColumns(clone, Collections.singletonList(count().as("total")));
099        return clone;
100    }
101
102    public static boolean hasDistinct(List<QueryColumn> selectColumns) {
103        if (CollectionUtil.isEmpty(selectColumns)) {
104            return false;
105        }
106        for (QueryColumn selectColumn : selectColumns) {
107            if (selectColumn instanceof DistinctQueryColumn) {
108                return true;
109            }
110        }
111        return false;
112    }
113
114    private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
115        return CollectionUtil.isNotEmpty(groupByColumns);
116    }
117
118    private static boolean canClearJoins(QueryWrapper queryWrapper) {
119        List<Join> joins = CPI.getJoins(queryWrapper);
120        if (CollectionUtil.isEmpty(joins)) {
121            return false;
122        }
123
124        // 只有全是 left join 语句才会清除 join
125        // 因为如果是 inner join 或 right join 往往都会放大记录数
126        for (Join join : joins) {
127            if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) {
128                return false;
129            }
130        }
131
132        // 获取 join 语句中使用到的表名
133        List<String> joinTables = new ArrayList<>();
134        joins.forEach(join -> {
135            QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
136            if (joinQueryTable != null) {
137                String tableName = joinQueryTable.getName();
138                if (StringUtil.isNotBlank(joinQueryTable.getAlias())) {
139                    joinTables.add(tableName + "." + joinQueryTable.getAlias());
140                } else {
141                    joinTables.add(tableName);
142                }
143            }
144        });
145
146        // 获取 where 语句中的条件
147        QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
148
149        // 最后判断一下 where 中是否用到了 join 的表
150        return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
151    }
152
153    @SafeVarargs
154    public static <T, R> Page<R> doPaginate(
155        BaseMapper<T> mapper,
156        Page<R> page,
157        QueryWrapper queryWrapper,
158        Class<R> asType,
159        boolean withRelations,
160        Consumer<FieldQueryBuilder<R>>... consumers
161    ) {
162        Long limitRows = CPI.getLimitRows(queryWrapper);
163        Long limitOffset = CPI.getLimitOffset(queryWrapper);
164        try {
165            // 只有 totalRow 小于 0 的时候才会去查询总量
166            // 这样方便用户做总数缓存,而非每次都要去查询总量
167            // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的
168
169            if (page.getTotalRow() < 0) {
170
171                QueryWrapper countQueryWrapper;
172
173                if (page.needOptimizeCountQuery()) {
174                    countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper);
175                } else {
176                    countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper);
177                }
178
179                // optimize: 在 count 之前先去掉 limit 参数,避免 count 查询错误
180                CPI.setLimitRows(countQueryWrapper, null);
181                CPI.setLimitOffset(countQueryWrapper, null);
182
183                page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper));
184            }
185
186            if (!page.hasRecords()) {
187                if (withRelations) {
188                    RelationManager.clearConfigIfNecessary();
189                }
190                return page;
191            }
192
193            queryWrapper.limit(page.offset(), page.getPageSize());
194
195            List<R> records;
196            if (asType != null) {
197                records = mapper.selectListByQueryAs(queryWrapper, asType);
198            } else {
199                // noinspection unchecked
200                records = (List<R>) mapper.selectListByQuery(queryWrapper);
201            }
202
203            if (withRelations) {
204                queryRelations(mapper, records);
205            }
206
207            queryFields(mapper, records, consumers);
208            page.setRecords(records);
209
210            return page;
211
212        } finally {
213            // 将之前设置的 limit 清除掉
214            // 保险起见把重置代码放到 finally 代码块中
215            CPI.setLimitRows(queryWrapper, limitRows);
216            CPI.setLimitOffset(queryWrapper, limitOffset);
217        }
218    }
219
220
221    public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
222        if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {
223            return;
224        }
225
226        Map<String, FieldQuery> fieldQueryMap = new HashMap<>();
227        for (Consumer<FieldQueryBuilder<R>> consumer : consumers) {
228            FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>();
229            consumer.accept(fieldQueryBuilder);
230
231            FieldQuery fieldQuery = fieldQueryBuilder.build();
232
233            String className = fieldQuery.getEntityClass().getName();
234            String fieldName = fieldQuery.getFieldName();
235            String mapKey = className + '#' + fieldName;
236
237            fieldQueryMap.put(mapKey, fieldQuery);
238        }
239
240        FieldQueryManager.queryFields(mapper, list, fieldQueryMap);
241    }
242
243
244    public static <E> E queryRelations(BaseMapper<?> mapper, E entity) {
245        if (entity != null) {
246            queryRelations(mapper, Collections.singletonList(entity));
247        } else {
248            RelationManager.clearConfigIfNecessary();
249        }
250        return entity;
251    }
252
253    public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) {
254        RelationManager.queryRelations(mapper, entities);
255        return entities;
256    }
257
258
259    public static Class<? extends Collection> getCollectionWrapType(Class<?> type) {
260        if (ClassUtil.canInstance(type.getModifiers())) {
261            return (Class<? extends Collection>) type;
262        }
263
264        if (List.class.isAssignableFrom(type)) {
265            return ArrayList.class;
266        }
267
268        if (Set.class.isAssignableFrom(type)) {
269            return HashSet.class;
270        }
271
272        throw new IllegalStateException("Field query can not support type: " + type.getName());
273    }
274
275
276    /**
277     * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)}
278     */
279    public static <T> T getSelectOneResult(List<T> list) {
280        if (list == null || list.isEmpty()) {
281            return null;
282        }
283        int size = list.size();
284        if (size == 1) {
285            return list.get(0);
286        }
287        throw new TooManyResultsException(
288            "Expected one result (or null) to be returned by selectOne(), but found: " + size);
289    }
290
291    public static long getLongNumber(List<Object> objects) {
292        Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
293        if (object == null) {
294            return 0;
295        } else if (object instanceof Number) {
296            return ((Number) object).longValue();
297        } else {
298            throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\"");
299        }
300    }
301
302
303    public static Map<String, Object> preparedParams(Page<?> page, QueryWrapper queryWrapper, Map<String, Object> params) {
304        Map<String, Object> newParams = new HashMap<>();
305
306        if (params != null) {
307            newParams.putAll(params);
308        }
309
310        newParams.put("pageOffset", page.offset());
311        newParams.put("pageNumber", page.getPageNumber());
312        newParams.put("pageSize", page.getPageSize());
313
314        DbType dbType = DialectFactory.getHintDbType();
315        newParams.put("dbType", dbType != null ? dbType : FlexGlobalConfig.getDefaultConfig().getDbType());
316
317        if (queryWrapper != null) {
318            preparedQueryWrapper(newParams, queryWrapper);
319        }
320
321        return newParams;
322    }
323
324
325    private static void preparedQueryWrapper(Map<String, Object> params, QueryWrapper queryWrapper) {
326        String sql = DialectFactory.getDialect().buildNoSelectSql(queryWrapper);
327        StringBuilder sqlBuilder = new StringBuilder();
328        char quote = 0;
329        int index = 0;
330        for (int i = 0; i < sql.length(); ++i) {
331            char ch = sql.charAt(i);
332            if (ch == '\'') {
333                if (quote == 0) {
334                    quote = ch;
335                } else if (quote == '\'') {
336                    quote = 0;
337                }
338            } else if (ch == '"') {
339                if (quote == 0) {
340                    quote = ch;
341                } else if (quote == '"') {
342                    quote = 0;
343                }
344            }
345            if (quote == 0 && ch == '?') {
346                sqlBuilder.append("#{qwParams_").append(index++).append("}");
347            } else {
348                sqlBuilder.append(ch);
349            }
350        }
351        params.put("qwSql", sqlBuilder.toString());
352        Object[] valueArray = CPI.getValueArray(queryWrapper);
353        for (int i = 0; i < valueArray.length; i++) {
354            params.put("qwParams_" + i, valueArray[i]);
355        }
356    }
357
358}