001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.util;
017
018import com.mybatisflex.core.BaseMapper;
019import com.mybatisflex.core.FlexGlobalConfig;
020import com.mybatisflex.core.constant.SqlConsts;
021import com.mybatisflex.core.dialect.DbType;
022import com.mybatisflex.core.dialect.DialectFactory;
023import com.mybatisflex.core.exception.FlexExceptions;
024import com.mybatisflex.core.field.FieldQuery;
025import com.mybatisflex.core.field.FieldQueryBuilder;
026import com.mybatisflex.core.field.FieldQueryManager;
027import com.mybatisflex.core.paginate.Page;
028import com.mybatisflex.core.query.*;
029import com.mybatisflex.core.relation.RelationManager;
030import org.apache.ibatis.exceptions.TooManyResultsException;
031import org.apache.ibatis.session.defaults.DefaultSqlSession;
032
033import java.util.*;
034import java.util.function.Consumer;
035
036import static com.mybatisflex.core.query.QueryMethods.count;
037
038public class MapperUtil {
039
040    private MapperUtil() {
041    }
042
043
044    /**
045     * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性,
046     * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。
047     *
048     * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下:
049     *
050     * <p><pre>
051     * {@code
052     * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`;
053     * }
054     * </pre>
055     *
056     * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。
057     */
058    public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) {
059        return QueryWrapper.create()
060            .select(count().as("total"))
061            .from(queryWrapper).as("t");
062    }
063
064    /**
065     * 优化 COUNT 查询语句。
066     */
067    public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) {
068        // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象
069        QueryWrapper clone = queryWrapper.clone();
070        // 将最后面的 order by 移除掉
071        CPI.setOrderBys(clone, null);
072        // 获取查询列和分组列,用于判断是否进行优化
073        List<QueryColumn> selectColumns = CPI.getSelectColumns(clone);
074        List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone);
075        // 如果有 distinct 语句或者 group by 语句则不优化
076        // 这种一旦优化了就会造成 count 语句查询出来的值不对
077        if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) {
078            return rawCountQueryWrapper(clone);
079        }
080        // 判断能不能清除 join 语句
081        if (canClearJoins(clone)) {
082            CPI.setJoins(clone, null);
083        }
084        // 将 select 里面的列换成 COUNT(*) AS `total`
085        CPI.setSelectColumns(clone, Collections.singletonList(count().as("total")));
086        return clone;
087    }
088
089    public static boolean hasDistinct(List<QueryColumn> selectColumns) {
090        if (CollectionUtil.isEmpty(selectColumns)) {
091            return false;
092        }
093        for (QueryColumn selectColumn : selectColumns) {
094            if (selectColumn instanceof DistinctQueryColumn) {
095                return true;
096            }
097        }
098        return false;
099    }
100
101    private static boolean hasGroupBy(List<QueryColumn> groupByColumns) {
102        return CollectionUtil.isNotEmpty(groupByColumns);
103    }
104
105    private static boolean canClearJoins(QueryWrapper queryWrapper) {
106        List<Join> joins = CPI.getJoins(queryWrapper);
107        if (CollectionUtil.isEmpty(joins)) {
108            return false;
109        }
110
111        // 只有全是 left join 语句才会清除 join
112        // 因为如果是 inner join 或 right join 往往都会放大记录数
113        for (Join join : joins) {
114            if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) {
115                return false;
116            }
117        }
118
119        // 获取 join 语句中使用到的表名
120        List<String> joinTables = new ArrayList<>();
121        joins.forEach(join -> {
122            QueryTable joinQueryTable = CPI.getJoinQueryTable(join);
123            if (joinQueryTable != null) {
124                String tableName = joinQueryTable.getName();
125                if (StringUtil.isNotBlank(joinQueryTable.getAlias())) {
126                    joinTables.add(tableName + "." + joinQueryTable.getAlias());
127                } else {
128                    joinTables.add(tableName);
129                }
130            }
131        });
132
133        // 获取 where 语句中的条件
134        QueryCondition where = CPI.getWhereQueryCondition(queryWrapper);
135
136        // 最后判断一下 where 中是否用到了 join 的表
137        return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables));
138    }
139
140    @SafeVarargs
141    public static <T, R> Page<R> doPaginate(BaseMapper<T> mapper, Page<R> page, QueryWrapper queryWrapper, Class<R> asType, boolean withRelations, Consumer<FieldQueryBuilder<R>>... consumers) {
142        try {
143            // 只有 totalRow 小于 0 的时候才会去查询总量
144            // 这样方便用户做总数缓存,而非每次都要去查询总量
145            // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的
146            if (page.getTotalRow() < 0) {
147                QueryWrapper countQueryWrapper;
148                if (page.needOptimizeCountQuery()) {
149                    countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper);
150                } else {
151                    countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper);
152                }
153                page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper));
154            }
155
156            if (!page.hasRecords()) {
157                return page;
158            }
159
160            queryWrapper.limit(page.offset(), page.getPageSize());
161
162            List<R> records;
163            if (asType != null) {
164                records = mapper.selectListByQueryAs(queryWrapper, asType);
165            } else {
166                // noinspection unchecked
167                records = (List<R>) mapper.selectListByQuery(queryWrapper);
168            }
169
170            if (withRelations) {
171                queryRelations(mapper, records);
172            }
173
174            queryFields(mapper, records, consumers);
175            page.setRecords(records);
176
177            return page;
178
179        } finally {
180            // 将之前设置的 limit 清除掉
181            // 保险起见把重置代码放到 finally 代码块中
182            CPI.setLimitRows(queryWrapper, null);
183            CPI.setLimitOffset(queryWrapper, null);
184        }
185    }
186
187
188    public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) {
189        if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) {
190            return;
191        }
192
193        Map<String, FieldQuery> fieldQueryMap = new HashMap<>();
194        for (Consumer<FieldQueryBuilder<R>> consumer : consumers) {
195            FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>();
196            consumer.accept(fieldQueryBuilder);
197
198            FieldQuery fieldQuery = fieldQueryBuilder.build();
199
200            String className = fieldQuery.getEntityClass().getName();
201            String fieldName = fieldQuery.getFieldName();
202            String mapKey = className + '#' + fieldName;
203
204            fieldQueryMap.put(mapKey, fieldQuery);
205        }
206
207        FieldQueryManager.queryFields(mapper, list, fieldQueryMap);
208    }
209
210
211    public static <E> E queryRelations(BaseMapper<?> mapper, E entity) {
212        if (entity != null) {
213            queryRelations(mapper, Collections.singletonList(entity));
214        }
215        return entity;
216    }
217
218    public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) {
219        RelationManager.queryRelations(mapper, entities);
220        return entities;
221    }
222
223
224    public static Class<? extends Collection> getCollectionWrapType(Class<?> type) {
225        if (ClassUtil.canInstance(type.getModifiers())) {
226            return (Class<? extends Collection>) type;
227        }
228
229        if (List.class.isAssignableFrom(type)) {
230            return ArrayList.class;
231        }
232
233        if (Set.class.isAssignableFrom(type)) {
234            return HashSet.class;
235        }
236
237        throw new IllegalStateException("Field query can not support type: " + type.getName());
238    }
239
240
241    /**
242     * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)}
243     */
244    public static <T> T getSelectOneResult(List<T> list) {
245        if (list == null || list.isEmpty()) {
246            return null;
247        }
248        int size = list.size();
249        if (size == 1) {
250            return list.get(0);
251        }
252        throw new TooManyResultsException(
253            "Expected one result (or null) to be returned by selectOne(), but found: " + size);
254    }
255
256    public static long getLongNumber(List<Object> objects) {
257        Object object = objects == null || objects.isEmpty() ? null : objects.get(0);
258        if (object == null) {
259            return 0;
260        } else if (object instanceof Number) {
261            return ((Number) object).longValue();
262        } else {
263            throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\"");
264        }
265    }
266
267
268    public static Map<String, Object> preparedParams(Page<?> page, QueryWrapper queryWrapper, Map<String, Object> params) {
269        Map<String, Object> newParams = new HashMap<>();
270
271        if (params != null) {
272            newParams.putAll(params);
273        }
274
275        newParams.put("pageOffset", page.offset());
276        newParams.put("pageNumber", page.getPageNumber());
277        newParams.put("pageSize", page.getPageSize());
278
279        DbType dbType = DialectFactory.getHintDbType();
280        newParams.put("dbType", dbType != null ? dbType : FlexGlobalConfig.getDefaultConfig().getDbType());
281
282        if (queryWrapper != null) {
283            preparedQueryWrapper(newParams, queryWrapper);
284        }
285
286        return newParams;
287    }
288
289
290    private static void preparedQueryWrapper(Map<String, Object> params, QueryWrapper queryWrapper) {
291        String sql = DialectFactory.getDialect().buildNoSelectSql(queryWrapper);
292        StringBuilder sqlBuilder = new StringBuilder();
293        char quote = 0;
294        int index = 0;
295        for (int i = 0; i < sql.length(); ++i) {
296            char ch = sql.charAt(i);
297            if (ch == '\'') {
298                if (quote == 0) {
299                    quote = ch;
300                } else if (quote == '\'') {
301                    quote = 0;
302                }
303            } else if (ch == '"') {
304                if (quote == 0) {
305                    quote = ch;
306                } else if (quote == '"') {
307                    quote = 0;
308                }
309            }
310            if (quote == 0 && ch == '?') {
311                sqlBuilder.append("#{qwParams_").append(index++).append("}");
312            } else {
313                sqlBuilder.append(ch);
314            }
315        }
316        params.put("qwSql", sqlBuilder.toString());
317        Object[] valueArray = CPI.getValueArray(queryWrapper);
318        for (int i = 0; i < valueArray.length; i++) {
319            params.put("qwParams_" + i, valueArray[i]);
320        }
321    }
322
323}