001/* 002 * Copyright (c) 2022-2025, Mybatis-Flex (fuhai999@gmail.com). 003 * <p> 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * <p> 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * <p> 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package com.mybatisflex.core.datasource; 017 018import com.mybatisflex.core.dialect.DbType; 019import com.mybatisflex.core.dialect.DbTypeUtil; 020import com.mybatisflex.core.transaction.TransactionContext; 021import com.mybatisflex.core.transaction.TransactionalManager; 022import com.mybatisflex.core.util.ArrayUtil; 023import com.mybatisflex.core.util.StringUtil; 024import org.apache.ibatis.logging.Log; 025import org.apache.ibatis.logging.LogFactory; 026 027import javax.sql.DataSource; 028import java.lang.reflect.InvocationHandler; 029import java.lang.reflect.Method; 030import java.lang.reflect.Proxy; 031import java.sql.Connection; 032import java.sql.SQLException; 033import java.util.*; 034import java.util.concurrent.ThreadLocalRandom; 035 036/** 037 * @author michael 038 */ 039public class FlexDataSource extends AbstractDataSource { 040 041 private static final char LOAD_BALANCE_KEY_SUFFIX = '*'; 042 private static final Log log = LogFactory.getLog(FlexDataSource.class); 043 044 private final Map<String, DataSource> dataSourceMap = new HashMap<>(); 045 private final Map<String, DbType> dbTypeHashMap = new HashMap<>(); 046 047 private DbType defaultDbType; 048 private String defaultDataSourceKey; 049 private DataSource defaultDataSource; 050 051 public FlexDataSource(String dataSourceKey, DataSource dataSource) { 052 this(dataSourceKey, dataSource, true); 053 } 054 055 public FlexDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) { 056 if (needDecryptDataSource) { 057 DataSourceManager.decryptDataSource(dataSource); 058 } 059 060 this.defaultDataSourceKey = dataSourceKey; 061 this.defaultDataSource = dataSource; 062 this.defaultDbType = DbTypeUtil.getDbType(dataSource); 063 064 dataSourceMap.put(dataSourceKey, dataSource); 065 dbTypeHashMap.put(dataSourceKey, defaultDbType); 066 } 067 068 /** 069 * 设置默认数据源(提供动态可控性) 070 */ 071 public void setDefaultDataSource(String dataSourceKey) { 072 DataSource ds = dataSourceMap.get(dataSourceKey); 073 074 if (ds != null) { 075 this.defaultDataSourceKey = dataSourceKey; 076 this.defaultDataSource = ds; 077 this.defaultDbType = DbTypeUtil.getDbType(ds); 078 } else { 079 throw new IllegalStateException("DataSource not found by key: \"" + dataSourceKey + "\""); 080 } 081 } 082 083 public void addDataSource(String dataSourceKey, DataSource dataSource) { 084 addDataSource(dataSourceKey, dataSource, true); 085 } 086 087 088 public void addDataSource(String dataSourceKey, DataSource dataSource, boolean needDecryptDataSource) { 089 if (needDecryptDataSource) { 090 DataSourceManager.decryptDataSource(dataSource); 091 } 092 dataSourceMap.put(dataSourceKey, dataSource); 093 dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource)); 094 } 095 096 097 public void removeDatasource(String dataSourceKey) { 098 dataSourceMap.remove(dataSourceKey); 099 dbTypeHashMap.remove(dataSourceKey); 100 } 101 102 public Map<String, DataSource> getDataSourceMap() { 103 return dataSourceMap; 104 } 105 106 public Map<String, DbType> getDbTypeHashMap() { 107 return dbTypeHashMap; 108 } 109 110 public String getDefaultDataSourceKey() { 111 return defaultDataSourceKey; 112 } 113 114 public DataSource getDefaultDataSource() { 115 return defaultDataSource; 116 } 117 118 public DbType getDefaultDbType() { 119 return defaultDbType; 120 } 121 122 public DbType getDbType(String dataSourceKey) { 123 return dbTypeHashMap.get(dataSourceKey); 124 } 125 126 127 @Override 128 public Connection getConnection() throws SQLException { 129 String xid = TransactionContext.getXID(); 130 if (StringUtil.hasText(xid)) { 131 String dataSourceKey = DataSourceKey.get(); 132 if (StringUtil.noText(dataSourceKey)) { 133 dataSourceKey = defaultDataSourceKey; 134 } 135 136 Connection connection = TransactionalManager.getConnection(xid, dataSourceKey); 137 if (connection == null) { 138 connection = proxy(getDataSource().getConnection(), xid); 139 TransactionalManager.hold(xid, dataSourceKey, connection); 140 } 141 return connection; 142 } else { 143 return getDataSource().getConnection(); 144 } 145 } 146 147 148 @Override 149 public Connection getConnection(String username, String password) throws SQLException { 150 String xid = TransactionContext.getXID(); 151 if (StringUtil.hasText(xid)) { 152 String dataSourceKey = DataSourceKey.get(); 153 if (StringUtil.noText(dataSourceKey)) { 154 dataSourceKey = defaultDataSourceKey; 155 } 156 Connection connection = TransactionalManager.getConnection(xid, dataSourceKey); 157 if (connection == null) { 158 connection = proxy(getDataSource().getConnection(username, password), xid); 159 TransactionalManager.hold(xid, dataSourceKey, connection); 160 } 161 return connection; 162 } else { 163 return getDataSource().getConnection(username, password); 164 } 165 } 166 167 static void closeAutoCommit(Connection connection) { 168 try { 169 connection.setAutoCommit(false); 170 } catch (SQLException e) { 171 if (log.isDebugEnabled()) { 172 log.debug("Error set autoCommit to false. Cause: " + e); 173 } 174 } 175 } 176 177 static void resetAutoCommit(Connection connection) { 178 try { 179 if (!connection.getAutoCommit()) { 180 connection.setAutoCommit(true); 181 } 182 } catch (SQLException e) { 183 if (log.isDebugEnabled()) { 184 log.debug("Error resetting autoCommit to true before closing the connection. " + 185 "Cause: " + e); 186 } 187 } 188 } 189 190 191 public Connection proxy(Connection connection, String xid) { 192 return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader() 193 , new Class[]{Connection.class} 194 , new ConnectionHandler(connection, xid) 195 ); 196 } 197 198 /** 199 * 方便用于 {@link DbTypeUtil#getDbType(DataSource)} 200 */ 201 public String getUrl() { 202 return DbTypeUtil.getJdbcUrl(defaultDataSource); 203 } 204 205 206 @Override 207 @SuppressWarnings("unchecked") 208 public <T> T unwrap(Class<T> iface) throws SQLException { 209 if (iface.isInstance(this)) { 210 return (T) this; 211 } 212 return getDataSource().unwrap(iface); 213 } 214 215 @Override 216 public boolean isWrapperFor(Class<?> iface) throws SQLException { 217 return (iface.isInstance(this) || getDataSource().isWrapperFor(iface)); 218 } 219 220 221 protected DataSource getDataSource() { 222 DataSource dataSource = defaultDataSource; 223 if (dataSourceMap.size() > 1) { 224 String dataSourceKey = DataSourceKey.get(); 225 if (StringUtil.hasText(dataSourceKey)) { 226 //负载均衡 key 227 if (dataSourceKey.charAt(dataSourceKey.length() - 1) == LOAD_BALANCE_KEY_SUFFIX) { 228 String prefix = dataSourceKey.substring(0, dataSourceKey.length() - 1); 229 List<String> matchedKeys = new ArrayList<>(); 230 for (String key : dataSourceMap.keySet()) { 231 if (key.startsWith(prefix)) { 232 matchedKeys.add(key); 233 } 234 } 235 236 if (matchedKeys.isEmpty()) { 237 throw new IllegalStateException("Can not matched dataSource by key: \"" + dataSourceKey + "\""); 238 } 239 240 String randomKey = matchedKeys.get(ThreadLocalRandom.current().nextInt(matchedKeys.size())); 241 return dataSourceMap.get(randomKey); 242 } 243 //非负载均衡 key 244 else { 245 dataSource = dataSourceMap.get(dataSourceKey); 246 if (dataSource == null) { 247 throw new IllegalStateException("Cannot get target dataSource by key: \"" + dataSourceKey + "\""); 248 } 249 } 250 } 251 } 252 return dataSource; 253 } 254 255 private static class ConnectionHandler implements InvocationHandler { 256 private static final String[] proxyMethods = new String[]{"commit", "rollback", "close", "setAutoCommit"}; 257 private final Connection original; 258 private final String xid; 259 260 public ConnectionHandler(Connection original, String xid) { 261 closeAutoCommit(original); 262 this.original = original; 263 this.xid = xid; 264 } 265 266 @Override 267 public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { 268 if (ArrayUtil.contains(proxyMethods, method.getName()) 269 && isTransactional()) { 270 //do nothing 271 return null; 272 } 273 274 //setAutoCommit: true 275 if ("close".equalsIgnoreCase(method.getName())) { 276 resetAutoCommit(original); 277 } 278 279 return method.invoke(original, args); 280 } 281 282 private boolean isTransactional() { 283 return Objects.equals(xid, TransactionContext.getXID()); 284 } 285 286 } 287}