001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.table;
017
018import com.mybatisflex.annotation.*;
019import com.mybatisflex.core.BaseMapper;
020import com.mybatisflex.core.FlexConsts;
021import com.mybatisflex.core.FlexGlobalConfig;
022import com.mybatisflex.core.exception.FlexExceptions;
023import com.mybatisflex.core.util.ClassUtil;
024import com.mybatisflex.core.util.CollectionUtil;
025import com.mybatisflex.core.util.StringUtil;
026import org.apache.ibatis.io.ResolverUtil;
027import org.apache.ibatis.reflection.Reflector;
028import org.apache.ibatis.session.Configuration;
029import org.apache.ibatis.type.JdbcType;
030import org.apache.ibatis.type.TypeHandler;
031import org.apache.ibatis.type.TypeHandlerRegistry;
032import org.apache.ibatis.type.UnknownTypeHandler;
033import org.apache.ibatis.util.MapUtil;
034
035import java.lang.reflect.Field;
036import java.lang.reflect.Modifier;
037import java.lang.reflect.ParameterizedType;
038import java.lang.reflect.Type;
039import java.math.BigDecimal;
040import java.math.BigInteger;
041import java.sql.Time;
042import java.sql.Timestamp;
043import java.time.*;
044import java.time.chrono.JapaneseDate;
045import java.util.*;
046import java.util.concurrent.ConcurrentHashMap;
047import java.util.stream.Collectors;
048
049public class TableInfoFactory {
050
051    private TableInfoFactory() {}
052
053    private static final Set<Class<?>> defaultSupportColumnTypes = CollectionUtil.newHashSet(
054            int.class, Integer.class,
055            short.class, Short.class,
056            long.class, Long.class,
057            float.class, Float.class,
058            double.class, Double.class,
059            boolean.class, Boolean.class,
060            Date.class, java.sql.Date.class, Time.class, Timestamp.class,
061            Instant.class, LocalDate.class, LocalDateTime.class, LocalTime.class, OffsetDateTime.class, OffsetTime.class, ZonedDateTime.class,
062            Year.class, Month.class, YearMonth.class, JapaneseDate.class,
063            byte[].class, Byte[].class, Byte.class,
064            BigInteger.class, BigDecimal.class,
065            char.class, String.class, Character.class
066    );
067
068
069    private static final Map<Class<?>, TableInfo> mapperTableInfoMap = new ConcurrentHashMap<>();
070    private static final Map<Class<?>, TableInfo> entityTableMap = new ConcurrentHashMap<>();
071    private static final Map<String, TableInfo> tableInfoMap = new ConcurrentHashMap<>();
072    private static final Set<String> initedPackageNames = new HashSet<>();
073
074
075    public synchronized static void init(String mapperPackageName) {
076        if (!initedPackageNames.contains(mapperPackageName)) {
077            ResolverUtil<Class<?>> resolverUtil = new ResolverUtil<>();
078            resolverUtil.find(new ResolverUtil.IsA(BaseMapper.class), mapperPackageName);
079            Set<Class<? extends Class<?>>> mapperSet = resolverUtil.getClasses();
080            for (Class<? extends Class<?>> mapperClass : mapperSet) {
081                ofMapperClass(mapperClass);
082            }
083            initedPackageNames.add(mapperPackageName);
084        }
085    }
086
087
088    public static TableInfo ofMapperClass(Class<?> mapperClass) {
089        return MapUtil.computeIfAbsent(mapperTableInfoMap, mapperClass, key -> {
090            Class<?> entityClass = getEntityClass(mapperClass);
091            if (entityClass == null) {
092                return null;
093            }
094            return ofEntityClass(entityClass);
095        });
096    }
097
098
099    public static TableInfo ofEntityClass(Class<?> entityClass) {
100        return MapUtil.computeIfAbsent(entityTableMap, entityClass, aClass -> {
101            TableInfo tableInfo = createTableInfo(entityClass);
102            tableInfoMap.put(tableInfo.getTableName(), tableInfo);
103            return tableInfo;
104        });
105    }
106
107
108    public static TableInfo ofTableName(String tableName) {
109        return StringUtil.isNotBlank(tableName) ? tableInfoMap.get(tableName) : null;
110    }
111
112
113    private static Class<?> getEntityClass(Class<?> mapperClass) {
114        if (mapperClass == null || mapperClass == Object.class) {
115            return null;
116        }
117        Type[] genericInterfaces = mapperClass.getGenericInterfaces();
118        if (genericInterfaces.length == 1) {
119            Type type = genericInterfaces[0];
120            if (type instanceof ParameterizedType) {
121                Type actualTypeArgument = ((ParameterizedType) type).getActualTypeArguments()[0];
122                return actualTypeArgument instanceof Class ? (Class<?>) actualTypeArgument : null;
123            } else if (type instanceof Class) {
124                return getEntityClass((Class<?>) type);
125            }
126        }
127        return getEntityClass(mapperClass.getSuperclass());
128    }
129
130
131    private static TableInfo createTableInfo(Class<?> entityClass) {
132
133        TableInfo tableInfo = new TableInfo();
134        tableInfo.setEntityClass(entityClass);
135        tableInfo.setReflector(new Reflector(entityClass));
136
137        //初始化表名
138        Table table = entityClass.getAnnotation(Table.class);
139        if (table != null) {
140            tableInfo.setTableName(table.value());
141            tableInfo.setSchema(table.schema());
142            tableInfo.setCamelToUnderline(table.camelToUnderline());
143
144            if (table.onInsert().length > 0) {
145                List<InsertListener> insertListeners = Arrays.stream(table.onInsert())
146                        .filter(listener -> listener != NoneListener.class)
147                        .map(ClassUtil::newInstance)
148                        .collect(Collectors.toList());
149                tableInfo.setOnInsertListeners(insertListeners);
150            }
151
152            if (table.onUpdate().length > 0) {
153                List<UpdateListener> updateListeners = Arrays.stream(table.onUpdate())
154                        .filter(listener -> listener != NoneListener.class)
155                        .map(ClassUtil::newInstance)
156                        .collect(Collectors.toList());
157                tableInfo.setOnUpdateListeners(updateListeners);
158            }
159
160            if (table.onSet().length > 0) {
161                List<SetListener> setListeners = Arrays.stream(table.onSet())
162                        .filter(listener -> listener != NoneListener.class)
163                        .map(ClassUtil::newInstance)
164                        .collect(Collectors.toList());
165                tableInfo.setOnSetListeners(setListeners);
166            }
167
168            if (StringUtil.isNotBlank(table.dataSource())) {
169                tableInfo.setDataSource(table.dataSource());
170            }
171        } else {
172            //默认为类名转驼峰下划线
173            String tableName = StringUtil.camelToUnderline(entityClass.getSimpleName());
174            tableInfo.setTableName(tableName);
175        }
176
177        //初始化字段相关
178        List<ColumnInfo> columnInfoList = new ArrayList<>();
179        List<IdInfo> idInfos = new ArrayList<>();
180
181        Field idField = null;
182
183        String logicDeleteColumn = null;
184        String versionColumn = null;
185        String tenantIdColumn = null;
186
187        //数据插入时,默认插入数据字段
188        Map<String, String> onInsertColumns = new HashMap<>();
189
190        //数据更新时,默认更新内容的字段
191        Map<String, String> onUpdateColumns = new HashMap<>();
192
193        //大字段列
194        Set<String> largeColumns = new LinkedHashSet<>();
195        // 默认查询列
196        Set<String> defaultColumns = new LinkedHashSet<>();
197
198
199        List<Field> entityFields = getColumnFields(entityClass);
200
201        for (Field field : entityFields) {
202
203            Column column = field.getAnnotation(Column.class);
204            if (column != null && column.ignore()) {
205                continue; // ignore
206            }
207
208            Class<?> fieldType = field.getType();
209
210            //满足一下 3 中情况,不支持该类型
211            if ((column == null || column.typeHandler() == UnknownTypeHandler.class) // 未配置 typeHandler
212                    && !fieldType.isEnum()   // 类型不是枚举
213                    && !defaultSupportColumnTypes.contains(fieldType) //默认的自动类型不包含该类型
214            ) {
215                // 集合嵌套
216                if (Collection.class.isAssignableFrom(fieldType)) {
217                    ParameterizedType genericType = (ParameterizedType) field.getGenericType();
218                    Type actualTypeArgument = genericType.getActualTypeArguments()[0];
219                    tableInfo.addCollectionType(field, (Class<?>) actualTypeArgument);
220                }
221                // 实体类嵌套
222                else if (!Map.class.isAssignableFrom(fieldType)
223                        && !fieldType.isArray()) {
224                    // tableInfo.addJoinType(field.getName(), fieldType);
225                    tableInfo.addAssociationType(field.getName(), fieldType);
226                }
227                // 不支持的类型直接跳过
228                continue;
229            }
230
231            //列名
232            String columnName = column != null && StringUtil.isNotBlank(column.value())
233                    ? column.value()
234                    : (tableInfo.isCamelToUnderline() ? StringUtil.camelToUnderline(field.getName()) : field.getName());
235
236            //逻辑删除字段
237            if (column != null && column.isLogicDelete()) {
238                if (logicDeleteColumn == null) {
239                    logicDeleteColumn = columnName;
240                } else {
241                    throw FlexExceptions.wrap("The logic delete column of entity[%s] must be less then 2.", entityClass.getName());
242                }
243            }
244
245            //乐观锁版本字段
246            if (column != null && column.version()) {
247                if (versionColumn == null) {
248                    versionColumn = columnName;
249                } else {
250                    throw FlexExceptions.wrap("The version column of entity[%s] must be less then 2.", entityClass.getName());
251                }
252            }
253
254            //租户ID 字段
255            if (column != null && column.tenantId()) {
256                if (tenantIdColumn == null) {
257                    tenantIdColumn = columnName;
258                } else {
259                    throw FlexExceptions.wrap("The tenantId column of entity[%s] must be less then 2.", entityClass.getName());
260                }
261            }
262
263            if (column != null && StringUtil.isNotBlank(column.onInsertValue())) {
264                onInsertColumns.put(columnName, column.onInsertValue().trim());
265            }
266
267
268            if (column != null && StringUtil.isNotBlank(column.onUpdateValue())) {
269                onUpdateColumns.put(columnName, column.onUpdateValue().trim());
270            }
271
272
273            if (column != null && column.isLarge()) {
274                largeColumns.add(columnName);
275            }
276
277            if (column == null || !column.isLarge()) {
278                defaultColumns.add(columnName);
279            }
280
281            Id id = field.getAnnotation(Id.class);
282            ColumnInfo columnInfo;
283            if (id != null) {
284                columnInfo = new IdInfo(columnName, field.getName(), field.getType(), id);
285                idInfos.add((IdInfo) columnInfo);
286            } else {
287                columnInfo = new ColumnInfo();
288                columnInfoList.add(columnInfo);
289            }
290
291            columnInfo.setColumn(columnName);
292            columnInfo.setProperty(field.getName());
293            columnInfo.setPropertyType(field.getType());
294
295            if (column != null && column.typeHandler() != UnknownTypeHandler.class) {
296                Class<?> typeHandlerClass = column.typeHandler();
297                Configuration configuration = FlexGlobalConfig.getDefaultConfig().getConfiguration();
298                TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
299                TypeHandler<?> typeHandler = typeHandlerRegistry.getInstance(columnInfo.getPropertyType(), typeHandlerClass);
300                columnInfo.setTypeHandler(typeHandler);
301            }
302
303            ColumnMask columnMask = field.getAnnotation(ColumnMask.class);
304            if (columnMask != null && StringUtil.isNotBlank(columnMask.value())) {
305                if (String.class != field.getType()) {
306                    throw new IllegalStateException("@ColumnMask() only support for string type field. error: " + entityClass.getName() + "." + field.getName());
307                }
308                columnInfo.setMaskType(columnMask.value().trim());
309            }
310
311            if (column != null && column.jdbcType() != JdbcType.UNDEFINED) {
312                columnInfo.setJdbcType(column.jdbcType());
313            }
314
315            if (FlexConsts.DEFAULT_PRIMARY_FIELD.equals(field.getName())) {
316                idField = field;
317            }
318        }
319
320
321        if (idInfos.isEmpty() && idField != null) {
322            int index = -1;
323            for (int i = 0; i < columnInfoList.size(); i++) {
324                ColumnInfo columnInfo = columnInfoList.get(i);
325                if (FlexConsts.DEFAULT_PRIMARY_FIELD.equals(columnInfo.getProperty())) {
326                    index = i;
327                    break;
328                }
329            }
330            if (index >= 0) {
331                ColumnInfo removedColumnInfo = columnInfoList.remove(index);
332                idInfos.add(new IdInfo(removedColumnInfo));
333            }
334        }
335
336        tableInfo.setLogicDeleteColumn(logicDeleteColumn);
337        tableInfo.setVersionColumn(versionColumn);
338        tableInfo.setTenantIdColumn(tenantIdColumn);
339
340        if (!onInsertColumns.isEmpty()) {
341            tableInfo.setOnInsertColumns(onInsertColumns);
342        }
343
344        if (!onUpdateColumns.isEmpty()) {
345            tableInfo.setOnUpdateColumns(onUpdateColumns);
346        }
347
348        if (!largeColumns.isEmpty()) {
349            tableInfo.setLargeColumns(largeColumns.toArray(new String[0]));
350        }
351        if (!defaultColumns.isEmpty()) {
352            tableInfo.setDefaultColumns(defaultColumns.toArray(new String[0]));
353        }
354
355        tableInfo.setColumnInfoList(columnInfoList);
356        tableInfo.setPrimaryKeyList(idInfos);
357
358
359        return tableInfo;
360    }
361
362
363    public static List<Field> getColumnFields(Class<?> entityClass) {
364        List<Field> fields = new ArrayList<>();
365        doGetFields(entityClass, fields);
366        return fields;
367    }
368
369
370    private static void doGetFields(Class<?> entityClass, List<Field> fields) {
371        if (entityClass == null || entityClass == Object.class) {
372            return;
373        }
374
375        Field[] declaredFields = entityClass.getDeclaredFields();
376        for (Field declaredField : declaredFields) {
377            if (Modifier.isStatic(declaredField.getModifiers())
378                    || existName(fields, declaredField)) {
379                continue;
380            }
381            fields.add(declaredField);
382        }
383
384        doGetFields(entityClass.getSuperclass(), fields);
385    }
386
387
388    private static boolean existName(List<Field> fields, Field field) {
389        for (Field f : fields) {
390            if (f.getName().equalsIgnoreCase(field.getName())) {
391                return true;
392            }
393        }
394        return false;
395    }
396}