001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.audit;
017
018import com.mybatisflex.core.FlexConsts;
019import org.apache.ibatis.mapping.BoundSql;
020import org.apache.ibatis.mapping.ParameterMapping;
021import org.apache.ibatis.mapping.ParameterMode;
022import org.apache.ibatis.reflection.MetaObject;
023import org.apache.ibatis.reflection.ParamNameResolver;
024import org.apache.ibatis.session.Configuration;
025import org.apache.ibatis.type.TypeHandlerRegistry;
026
027import java.sql.SQLException;
028import java.sql.Statement;
029import java.util.Collection;
030import java.util.List;
031import java.util.Map;
032
033/**
034 * 审计管理器,统一执行如何和配置入口
035 */
036public class AuditManager {
037
038
039    private AuditManager() {
040    }
041
042
043    private static MessageFactory messageFactory = new DefaultMessageFactory();
044
045    private static boolean auditEnable = false;
046    private static Clock clock = System::currentTimeMillis;
047    private static MessageCollector messageCollector = new ScheduledMessageCollector();
048
049    public static boolean isAuditEnable() {
050        return auditEnable;
051    }
052
053    public static void setAuditEnable(boolean auditEnable) {
054        AuditManager.auditEnable = auditEnable;
055    }
056
057    public static Clock getClock() {
058        return clock;
059    }
060
061    public static void setClock(Clock clock) {
062        AuditManager.clock = clock;
063    }
064
065    public static MessageFactory getMessageFactory() {
066        return messageFactory;
067    }
068
069    public static void setMessageFactory(MessageFactory messageFactory) {
070        AuditManager.messageFactory = messageFactory;
071    }
072
073    public static MessageCollector getMessageCollector() {
074        return messageCollector;
075    }
076
077
078    public static void setMessageReporter(MessageReporter messageReporter) {
079        MessageCollector newMessageCollector = new ScheduledMessageCollector(10, messageReporter);
080        setMessageCollector(newMessageCollector);
081    }
082
083    public static void setMessageCollector(MessageCollector messageCollector) {
084        MessageCollector temp = AuditManager.messageCollector;
085        AuditManager.messageCollector = messageCollector;
086        releaseScheduledMessageCollector(temp);
087
088    }
089
090    private static void releaseScheduledMessageCollector(MessageCollector messageCollector) {
091        if (messageCollector instanceof ScheduledMessageCollector) {
092            ((ScheduledMessageCollector) messageCollector).release();
093        }
094    }
095
096    @SuppressWarnings("rawtypes")
097    public static <T> T startAudit(AuditRunnable<T> supplier, Statement statement, BoundSql boundSql, Configuration configuration) throws SQLException {
098        AuditMessage auditMessage = messageFactory.create();
099        if (auditMessage == null) {
100            return supplier.execute();
101        }
102        auditMessage.setQueryTime(clock.getTick());
103        try {
104            T result = supplier.execute();
105            if (result instanceof Collection) {
106                auditMessage.setQueryCount(((Collection) result).size());
107            } else if (result != null) {
108                auditMessage.setQueryCount(1);
109            }
110            return result;
111        } finally {
112            auditMessage.setElapsedTime(clock.getTick() - auditMessage.getQueryTime());
113            auditMessage.setQuery(boundSql.getSql());
114            Object parameter = boundSql.getParameterObject();
115
116            /** parameter 的组装请查看 getNamedParams 方法
117             * @see ParamNameResolver#getNamedParams(Object[])
118             */
119            if (parameter instanceof Map) {
120                TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
121                if (((Map<?, ?>) parameter).containsKey(FlexConsts.SQL_ARGS)) {
122                    auditMessage.addParams(statement, ((Map<?, ?>) parameter).get(FlexConsts.SQL_ARGS));
123                } else if (((Map<?, ?>) parameter).containsKey("collection")) {
124                    Collection collection = (Collection) ((Map<?, ?>) parameter).get("collection");
125                    auditMessage.addParams(statement, collection.toArray());
126                } else if (((Map<?, ?>) parameter).containsKey("array")) {
127                    auditMessage.addParams(statement, ((Map<?, ?>) parameter).get("array"));
128                } else {
129                    List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
130                    for (ParameterMapping parameterMapping : parameterMappings) {
131                        if (parameterMapping.getMode() != ParameterMode.OUT) {
132                            Object value;
133                            String propertyName = parameterMapping.getProperty();
134                            if (boundSql.hasAdditionalParameter(propertyName)) {
135                                value = boundSql.getAdditionalParameter(propertyName);
136                            } else if (typeHandlerRegistry.hasTypeHandler(parameter.getClass())) {
137                                value = parameter;
138                            } else {
139                                MetaObject metaObject = configuration.newMetaObject(parameter);
140                                value = metaObject.getValue(propertyName);
141                            }
142                            auditMessage.addParams(statement, value);
143                        }
144                    }
145                }
146            }
147            messageCollector.collect(auditMessage);
148        }
149    }
150
151
152    @FunctionalInterface
153    public interface AuditRunnable<T> {
154
155        T execute() throws SQLException;
156
157    }
158
159}