001/* 002 * Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com). 003 * <p> 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * <p> 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * <p> 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package com.mybatisflex.core.util; 017 018import com.mybatisflex.core.BaseMapper; 019import com.mybatisflex.core.constant.SqlConsts; 020import com.mybatisflex.core.exception.FlexExceptions; 021import com.mybatisflex.core.field.FieldQuery; 022import com.mybatisflex.core.field.FieldQueryBuilder; 023import com.mybatisflex.core.field.FieldQueryManager; 024import com.mybatisflex.core.paginate.Page; 025import com.mybatisflex.core.query.*; 026import com.mybatisflex.core.relation.RelationManager; 027import org.apache.ibatis.exceptions.TooManyResultsException; 028import org.apache.ibatis.session.defaults.DefaultSqlSession; 029 030import java.util.*; 031import java.util.function.Consumer; 032 033import static com.mybatisflex.core.query.QueryMethods.count; 034 035public class MapperUtil { 036 037 private MapperUtil() { 038 } 039 040 041 /** 042 * <p>原生的、未经过优化的 COUNT 查询。抛开效率问题不谈,只关注结果的准确性, 043 * 这个 COUNT 查询查出来的分页总数据是 100% 正确的,不接受任何反驳。 044 * 045 * <p>为什么这么说,因为是用子查询实现的,生成的 SQL 如下: 046 * 047 * <p><pre> 048 * {@code 049 * SELECT COUNT(*) AS `total` FROM ( ...用户构建的 SQL 语句... ) AS `t`; 050 * } 051 * </pre> 052 * 053 * <p>不进行 SQL 优化的时候,返回的就是这样的 COUNT 查询语句。 054 */ 055 public static QueryWrapper rawCountQueryWrapper(QueryWrapper queryWrapper) { 056 return QueryWrapper.create() 057 .select(count().as("total")) 058 .from(queryWrapper).as("t"); 059 } 060 061 /** 062 * 优化 COUNT 查询语句。 063 */ 064 public static QueryWrapper optimizeCountQueryWrapper(QueryWrapper queryWrapper) { 065 // 对克隆对象进行操作,不影响原来的 QueryWrapper 对象 066 QueryWrapper clone = queryWrapper.clone(); 067 // 将最后面的 order by 移除掉 068 CPI.setOrderBys(clone, null); 069 // 获取查询列和分组列,用于判断是否进行优化 070 List<QueryColumn> selectColumns = CPI.getSelectColumns(clone); 071 List<QueryColumn> groupByColumns = CPI.getGroupByColumns(clone); 072 // 如果有 distinct 语句或者 group by 语句则不优化 073 // 这种一旦优化了就会造成 count 语句查询出来的值不对 074 if (hasDistinct(selectColumns) || hasGroupBy(groupByColumns)) { 075 return rawCountQueryWrapper(clone); 076 } 077 // 判断能不能清除 join 语句 078 if (canClearJoins(clone)) { 079 CPI.setJoins(clone, null); 080 } 081 // 将 select 里面的列换成 COUNT(*) AS `total` 082 CPI.setSelectColumns(clone, Collections.singletonList(count().as("total"))); 083 return clone; 084 } 085 086 public static boolean hasDistinct(List<QueryColumn> selectColumns) { 087 if (CollectionUtil.isEmpty(selectColumns)) { 088 return false; 089 } 090 for (QueryColumn selectColumn : selectColumns) { 091 if (selectColumn instanceof DistinctQueryColumn) { 092 return true; 093 } 094 } 095 return false; 096 } 097 098 private static boolean hasGroupBy(List<QueryColumn> groupByColumns) { 099 return CollectionUtil.isNotEmpty(groupByColumns); 100 } 101 102 private static boolean canClearJoins(QueryWrapper queryWrapper) { 103 List<Join> joins = CPI.getJoins(queryWrapper); 104 if (CollectionUtil.isEmpty(joins)) { 105 return false; 106 } 107 108 // 只有全是 left join 语句才会清除 join 109 // 因为如果是 inner join 或 right join 往往都会放大记录数 110 for (Join join : joins) { 111 if (!SqlConsts.LEFT_JOIN.equals(CPI.getJoinType(join))) { 112 return false; 113 } 114 } 115 116 // 获取 join 语句中使用到的表名 117 List<String> joinTables = new ArrayList<>(); 118 joins.forEach(join -> { 119 QueryTable joinQueryTable = CPI.getJoinQueryTable(join); 120 if (joinQueryTable != null && StringUtil.isNotBlank(joinQueryTable.getName())) { 121 joinTables.add(joinQueryTable.getName()); 122 } 123 }); 124 125 // 获取 where 语句中的条件 126 QueryCondition where = CPI.getWhereQueryCondition(queryWrapper); 127 128 // 最后判断一下 where 中是否用到了 join 的表 129 return !CPI.containsTable(where, CollectionUtil.toArrayString(joinTables)); 130 } 131 132 @SafeVarargs 133 public static <T, R> Page<R> doPaginate(BaseMapper<T> mapper, Page<R> page, QueryWrapper queryWrapper, Class<R> asType, boolean withRelations, Consumer<FieldQueryBuilder<R>>... consumers) { 134 try { 135 // 只有 totalRow 小于 0 的时候才会去查询总量 136 // 这样方便用户做总数缓存,而非每次都要去查询总量 137 // 一般的分页场景中,只有第一页的时候有必要去查询总量,第二页以后是不需要的 138 if (page.getTotalRow() < 0) { 139 QueryWrapper countQueryWrapper; 140 if (page.needOptimizeCountQuery()) { 141 countQueryWrapper = MapperUtil.optimizeCountQueryWrapper(queryWrapper); 142 } else { 143 countQueryWrapper = MapperUtil.rawCountQueryWrapper(queryWrapper); 144 } 145 page.setTotalRow(mapper.selectCountByQuery(countQueryWrapper)); 146 } 147 148 if (page.isEmpty()) { 149 return page; 150 } 151 152 queryWrapper.limit(page.offset(), page.getPageSize()); 153 154 List<R> records; 155 if (asType != null) { 156 records = mapper.selectListByQueryAs(queryWrapper, asType); 157 } else { 158 // noinspection unchecked 159 records = (List<R>) mapper.selectListByQuery(queryWrapper); 160 } 161 162 if (withRelations) { 163 queryRelations(mapper, records); 164 } 165 166 queryFields(mapper, records, consumers); 167 page.setRecords(records); 168 169 return page; 170 171 } finally { 172 // 将之前设置的 limit 清除掉 173 // 保险起见把重置代码放到 finally 代码块中 174 CPI.setLimitRows(queryWrapper, null); 175 CPI.setLimitOffset(queryWrapper, null); 176 } 177 } 178 179 180 public static <R> void queryFields(BaseMapper<?> mapper, List<R> list, Consumer<FieldQueryBuilder<R>>[] consumers) { 181 if (CollectionUtil.isEmpty(list) || ArrayUtil.isEmpty(consumers) || consumers[0] == null) { 182 return; 183 } 184 185 Map<String, FieldQuery> fieldQueryMap = new HashMap<>(); 186 for (Consumer<FieldQueryBuilder<R>> consumer : consumers) { 187 FieldQueryBuilder<R> fieldQueryBuilder = new FieldQueryBuilder<>(); 188 consumer.accept(fieldQueryBuilder); 189 190 FieldQuery fieldQuery = fieldQueryBuilder.build(); 191 192 String className = fieldQuery.getEntityClass().getName(); 193 String fieldName = fieldQuery.getFieldName(); 194 String mapKey = className + '#' + fieldName; 195 196 fieldQueryMap.put(mapKey, fieldQuery); 197 } 198 199 FieldQueryManager.queryFields(mapper, list, fieldQueryMap); 200 } 201 202 203 public static <E> E queryRelations(BaseMapper<?> mapper, E entity) { 204 queryRelations(mapper, Collections.singletonList(entity)); 205 return entity; 206 } 207 208 public static <E> List<E> queryRelations(BaseMapper<?> mapper, List<E> entities) { 209 RelationManager.queryRelations(mapper, entities); 210 return entities; 211 } 212 213 214 public static Class<? extends Collection> getCollectionWrapType(Class<?> type) { 215 if (ClassUtil.canInstance(type.getModifiers())) { 216 return (Class<? extends Collection>) type; 217 } 218 219 if (List.class.isAssignableFrom(type)) { 220 return ArrayList.class; 221 } 222 223 if (Set.class.isAssignableFrom(type)) { 224 return HashSet.class; 225 } 226 227 throw new IllegalStateException("Field query can not support type: " + type.getName()); 228 } 229 230 231 /** 232 * 搬运加改造 {@link DefaultSqlSession#selectOne(String, Object)} 233 */ 234 public static <T> T getSelectOneResult(List<T> list) { 235 if (list == null || list.isEmpty()) { 236 return null; 237 } 238 int size = list.size(); 239 if (size == 1) { 240 return list.get(0); 241 } 242 throw new TooManyResultsException( 243 "Expected one result (or null) to be returned by selectOne(), but found: " + size); 244 } 245 246 public static long getLongNumber(List<Object> objects) { 247 Object object = objects == null || objects.isEmpty() ? null : objects.get(0); 248 if (object == null) { 249 return 0; 250 } else if (object instanceof Number) { 251 return ((Number) object).longValue(); 252 } else { 253 throw FlexExceptions.wrap("selectCountByQuery error, can not get number value of result: \"" + object + "\""); 254 } 255 } 256 257}