001/*
002 *  Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.datasource;
017
import com.mybatisflex.core.dialect.DbType;
import com.mybatisflex.core.dialect.DbTypeUtil;
import com.mybatisflex.core.transaction.TransactionContext;
import com.mybatisflex.core.transaction.TransactionalManager;
import com.mybatisflex.core.util.ArrayUtil;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;

import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
035
036/**
037 * @author michael
038 */
039public class FlexDataSource extends AbstractDataSource {
040
041    private static final char LOAD_BALANCE_KEY_SUFFIX = '*';
042    private static final Log log = LogFactory.getLog(FlexDataSource.class);
043
044    private final Map<String, DataSource> dataSourceMap = new HashMap<>();
045    private final Map<String, DbType> dbTypeHashMap = new HashMap<>();
046
047    private DbType defaultDbType;
048    private String defaultDataSourceKey;
049    private DataSource defaultDataSource;
050
051    public FlexDataSource(String dataSourceKey, DataSource dataSource) {
052        this(dataSourceKey, dataSource, true);
053    }
054
055    public FlexDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) {
056        if (needDecryptDataSource) {
057            DataSourceManager.decryptDataSource(dataSource);
058        }
059
060        this.defaultDataSourceKey = dataSourceKey;
061        this.defaultDataSource = dataSource;
062        this.defaultDbType = DbTypeUtil.getDbType(dataSource);
063
064        dataSourceMap.put(dataSourceKey, dataSource);
065        dbTypeHashMap.put(dataSourceKey, defaultDbType);
066    }
067
068    /**
069     * 设置默认数据源(提供动态可控性)
070     */
071    public void setDefaultDataSource(String dataSourceKey) {
072        DataSource ds = dataSourceMap.get(dataSourceKey);
073
074        if (ds != null) {
075            this.defaultDataSourceKey = dataSourceKey;
076            this.defaultDataSource = ds;
077            this.defaultDbType = DbTypeUtil.getDbType(ds);
078        } else {
079            throw new IllegalStateException("DataSource not found by key: \"" + dataSourceKey + "\"");
080        }
081    }
082
083    public void addDataSource(String dataSourceKey, DataSource dataSource) {
084        addDataSource(dataSourceKey, dataSource, true);
085    }
086
087
088    public void addDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) {
089        if (needDecryptDataSource) {
090            DataSourceManager.decryptDataSource(dataSource);
091        }
092        dataSourceMap.put(dataSourceKey, dataSource);
093        dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource));
094    }
095
096
097    public void removeDatasource(String dataSourceKey) {
098        dataSourceMap.remove(dataSourceKey);
099        dbTypeHashMap.remove(dataSourceKey);
100    }
101
102    public Map<String, DataSource> getDataSourceMap() {
103        return dataSourceMap;
104    }
105
106    public Map<String, DbType> getDbTypeHashMap() {
107        return dbTypeHashMap;
108    }
109
110    public String getDefaultDataSourceKey() {
111        return defaultDataSourceKey;
112    }
113
114    public DataSource getDefaultDataSource() {
115        return defaultDataSource;
116    }
117
118    public DbType getDefaultDbType() {
119        return defaultDbType;
120    }
121
122    public DbType getDbType(String dataSourceKey) {
123        return dbTypeHashMap.get(dataSourceKey);
124    }
125
126
127    @Override
128    public Connection getConnection() throws SQLException {
129        String xid = TransactionContext.getXID();
130        if (StringUtil.hasText(xid)) {
131            String dataSourceKey = DataSourceKey.get();
132            if (StringUtil.noText(dataSourceKey)) {
133                dataSourceKey = defaultDataSourceKey;
134            }
135
136            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
137            if (connection == null) {
138                connection = proxy(getDataSource().getConnection(), xid);
139                TransactionalManager.hold(xid, dataSourceKey, connection);
140            }
141            return connection;
142        } else {
143            return getDataSource().getConnection();
144        }
145    }
146
147
148    @Override
149    public Connection getConnection(String username, String password) throws SQLException {
150        String xid = TransactionContext.getXID();
151        if (StringUtil.hasText(xid)) {
152            String dataSourceKey = DataSourceKey.get();
153            if (StringUtil.noText(dataSourceKey)) {
154                dataSourceKey = defaultDataSourceKey;
155            }
156            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
157            if (connection == null) {
158                connection = proxy(getDataSource().getConnection(username, password), xid);
159                TransactionalManager.hold(xid, dataSourceKey, connection);
160            }
161            return connection;
162        } else {
163            return getDataSource().getConnection(username, password);
164        }
165    }
166
167    static void closeAutoCommit(Connection connection) {
168        try {
169            connection.setAutoCommit(false);
170        } catch (SQLException e) {
171            if (log.isDebugEnabled()) {
172                log.debug("Error set autoCommit to false. Cause: " + e);
173            }
174        }
175    }
176
177    static void resetAutoCommit(Connection connection) {
178        try {
179            if (!connection.getAutoCommit()) {
180                connection.setAutoCommit(true);
181            }
182        } catch (SQLException e) {
183            if (log.isDebugEnabled()) {
184                log.debug("Error resetting autoCommit to true before closing the connection. " +
185                    "Cause: " + e);
186            }
187        }
188    }
189
190
191    public Connection proxy(Connection connection, String xid) {
192        return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader()
193            , new Class[]{Connection.class}
194            , new ConnectionHandler(connection, xid)
195        );
196    }
197
198    /**
199     * 方便用于 {@link DbTypeUtil#getDbType(DataSource)}
200     */
201    public String getUrl() {
202        return DbTypeUtil.getJdbcUrl(defaultDataSource);
203    }
204
205
206    @Override
207    @SuppressWarnings("unchecked")
208    public <T> T unwrap(Class<T> iface) throws SQLException {
209        if (iface.isInstance(this)) {
210            return (T) this;
211        }
212        return getDataSource().unwrap(iface);
213    }
214
215    @Override
216    public boolean isWrapperFor(Class<?> iface) throws SQLException {
217        return (iface.isInstance(this) || getDataSource().isWrapperFor(iface));
218    }
219
220
221    protected DataSource getDataSource() {
222        DataSource dataSource = defaultDataSource;
223        if (dataSourceMap.size() > 1) {
224            String dataSourceKey = DataSourceKey.get();
225            if (StringUtil.hasText(dataSourceKey)) {
226                //负载均衡 key
227                if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) {
228                    String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1);
229                    List<String> matchedKeys = new ArrayList<>();
230                    for (String key : dataSourceMap.keySet()) {
231                        if (key.startsWith(prefix)) {
232                            matchedKeys.add(key);
233                        }
234                    }
235
236                    if (matchedKeys.isEmpty()) {
237                        throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\"");
238                    }
239
240                    String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size()));
241                    return dataSourceMap.get(randomKey);
242                }
243                //非负载均衡 key
244                else {
245                    dataSource = dataSourceMap.get(dataSourceKey);
246                    if (dataSource == null) {
247                        throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\"");
248                    }
249                }
250            }
251        }
252        return dataSource;
253    }
254
255    private static class ConnectionHandler implements InvocationHandler {
256        private static final String[] proxyMethods = new String[]{"commit", "rollback", "close", "setAutoCommit"};
257        private final Connection original;
258        private final String xid;
259
260        public ConnectionHandler(Connection original, String xid) {
261            closeAutoCommit(original);
262            this.original = original;
263            this.xid = xid;
264        }
265
266        @Override
267        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
268            if (ArrayUtil.contains(proxyMethods, method.getName())
269                && isTransactional()) {
270                //do nothing
271                return null;
272            }
273
274            //setAutoCommit: true
275            if ("close".equalsIgnoreCase(method.getName())) {
276                resetAutoCommit(original);
277            }
278
279            return method.invoke(original, args);
280        }
281
282        private boolean isTransactional() {
283            return Objects.equals(xid, TransactionContext.getXID());
284        }
285
286    }
287}