001/*
002 *  Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.datasource;
017
import com.mybatisflex.core.dialect.DbType;
import com.mybatisflex.core.dialect.DbTypeUtil;
import com.mybatisflex.core.transaction.TransactionContext;
import com.mybatisflex.core.transaction.TransactionalManager;
import com.mybatisflex.core.util.ArrayUtil;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;

import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
035
036/**
037 * @author michael
038 */
039public class FlexDataSource extends AbstractDataSource {
040
041    private static final char LOAD_BALANCE_KEY_SUFFIX = '*';
042    private static final Log log = LogFactory.getLog(FlexDataSource.class);
043
044    private final Map<String, DataSource> dataSourceMap = new HashMap<>();
045    private final Map<String, DbType> dbTypeHashMap = new HashMap<>();
046
047    private final DbType defaultDbType;
048    private final String defaultDataSourceKey;
049    private final DataSource defaultDataSource;
050
051    public FlexDataSource(String dataSourceKey, DataSource dataSource) {
052        this(dataSourceKey, dataSource, true);
053    }
054
055    public FlexDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) {
056        if (needDecryptDataSource) {
057            DataSourceManager.decryptDataSource(dataSource);
058        }
059
060        this.defaultDataSourceKey = dataSourceKey;
061        this.defaultDataSource = dataSource;
062        this.defaultDbType = DbTypeUtil.getDbType(dataSource);
063
064        dataSourceMap.put(dataSourceKey, dataSource);
065        dbTypeHashMap.put(dataSourceKey, defaultDbType);
066    }
067
068    public void addDataSource(String dataSourceKey, DataSource dataSource) {
069        addDataSource(dataSourceKey, dataSource, true);
070    }
071
072
073    public void addDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) {
074        if (needDecryptDataSource) {
075            DataSourceManager.decryptDataSource(dataSource);
076        }
077        dataSourceMap.put(dataSourceKey, dataSource);
078        dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource));
079    }
080
081
082    public void removeDatasource(String dataSourceKey) {
083        dataSourceMap.remove(dataSourceKey);
084        dbTypeHashMap.remove(dataSourceKey);
085    }
086
087    public Map<String, DataSource> getDataSourceMap() {
088        return dataSourceMap;
089    }
090
091    public Map<String, DbType> getDbTypeHashMap() {
092        return dbTypeHashMap;
093    }
094
095    public String getDefaultDataSourceKey() {
096        return defaultDataSourceKey;
097    }
098
099    public DataSource getDefaultDataSource() {
100        return defaultDataSource;
101    }
102
103    public DbType getDefaultDbType() {
104        return defaultDbType;
105    }
106
107    public DbType getDbType(String dataSourceKey) {
108        return dbTypeHashMap.get(dataSourceKey);
109    }
110
111
112    @Override
113    public Connection getConnection() throws SQLException {
114        String xid = TransactionContext.getXID();
115        if (StringUtil.isNotBlank(xid)) {
116            String dataSourceKey = DataSourceKey.get();
117            if (StringUtil.isBlank(dataSourceKey)) {
118                dataSourceKey = defaultDataSourceKey;
119            }
120
121            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
122            if (connection == null) {
123                connection = proxy(getDataSource().getConnection(), xid);
124                TransactionalManager.hold(xid, dataSourceKey, connection);
125            }
126            return connection;
127        } else {
128            return getDataSource().getConnection();
129        }
130    }
131
132
133    @Override
134    public Connection getConnection(String username, String password) throws SQLException {
135        String xid = TransactionContext.getXID();
136        if (StringUtil.isNotBlank(xid)) {
137            String dataSourceKey = DataSourceKey.get();
138            if (StringUtil.isBlank(dataSourceKey)) {
139                dataSourceKey = defaultDataSourceKey;
140            }
141            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
142            if (connection == null) {
143                connection = proxy(getDataSource().getConnection(username, password), xid);
144                TransactionalManager.hold(xid, dataSourceKey, connection);
145            }
146            return connection;
147        } else {
148            return getDataSource().getConnection(username, password);
149        }
150    }
151
152    static void closeAutoCommit(Connection connection) {
153        try {
154            connection.setAutoCommit(false);
155        } catch (SQLException e) {
156            if (log.isDebugEnabled()) {
157                log.debug("Error set autoCommit to false. Cause: " + e);
158            }
159        }
160    }
161
162    static void resetAutoCommit(Connection connection) {
163        try {
164            if (!connection.getAutoCommit()) {
165                connection.setAutoCommit(true);
166            }
167        } catch (SQLException e) {
168            if (log.isDebugEnabled()) {
169                log.debug("Error resetting autoCommit to true before closing the connection. " +
170                    "Cause: " + e);
171            }
172        }
173    }
174
175
176    public Connection proxy(Connection connection, String xid) {
177        return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader()
178            , new Class[]{Connection.class}
179            , new ConnectionHandler(connection, xid));
180    }
181
182    /**
183     * 方便用于 {@link DbTypeUtil#getDbType(DataSource)}
184     */
185    public String getUrl() {
186        return DbTypeUtil.getJdbcUrl(defaultDataSource);
187    }
188
189
190    @Override
191    @SuppressWarnings("unchecked")
192    public <T> T unwrap(Class<T> iface) throws SQLException {
193        if (iface.isInstance(this)) {
194            return (T) this;
195        }
196        return getDataSource().unwrap(iface);
197    }
198
199    @Override
200    public boolean isWrapperFor(Class<?> iface) throws SQLException {
201        return (iface.isInstance(this) || getDataSource().isWrapperFor(iface));
202    }
203
204
205    private DataSource getDataSource() {
206        DataSource dataSource = defaultDataSource;
207        if (dataSourceMap.size() > 1) {
208            String dataSourceKey = DataSourceKey.get();
209            if (StringUtil.isNotBlank(dataSourceKey)) {
210                //负载均衡 key
211                if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) {
212                    String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1);
213                    List<String> matchedKeys = new ArrayList<>();
214                    for (String key : dataSourceMap.keySet()) {
215                        if (key.startsWith(prefix)) {
216                            matchedKeys.add(key);
217                        }
218                    }
219
220                    if (matchedKeys.isEmpty()) {
221                        throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\"");
222                    }
223
224                    String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size()));
225                    return dataSourceMap.get(randomKey);
226                }
227                //非负载均衡 key
228                else {
229                    dataSource = dataSourceMap.get(dataSourceKey);
230                    if (dataSource == null) {
231                        throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\"");
232                    }
233                }
234            }
235        }
236        return dataSource;
237    }
238
239    private static class ConnectionHandler implements InvocationHandler {
240        private static final String[] proxyMethods = new String[]{"commit", "rollback", "close", "setAutoCommit"};
241        private final Connection original;
242        private final String xid;
243
244        public ConnectionHandler(Connection original, String xid) {
245            closeAutoCommit(original);
246            this.original = original;
247            this.xid = xid;
248        }
249
250        @Override
251        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
252            if (ArrayUtil.contains(proxyMethods, method.getName())
253                && isTransactional()) {
254                //do nothing
255                return null;
256            }
257
258            //setAutoCommit: true
259            if ("close".equalsIgnoreCase(method.getName())) {
260                resetAutoCommit(original);
261            }
262
263            return method.invoke(original, args);
264        }
265
266        private boolean isTransactional() {
267            return Objects.equals(xid, TransactionContext.getXID());
268        }
269
270    }
271
272
273}