001/**
002 * Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 * <p>
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 * <p>
008 * http://www.apache.org/licenses/LICENSE-2.0
009 * <p>
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package com.mybatisflex.core.datasource;
017
import com.mybatisflex.core.dialect.DbType;
import com.mybatisflex.core.dialect.DbTypeUtil;
import com.mybatisflex.core.transaction.TransactionContext;
import com.mybatisflex.core.transaction.TransactionalManager;
import com.mybatisflex.core.util.ArrayUtil;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;

import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
036
037public class FlexDataSource extends AbstractDataSource {
038
039    private static final Log log = LogFactory.getLog(FlexDataSource.class);
040
041    private final Map<String, DataSource> dataSourceMap = new HashMap<>();
042    private final Map<String, DbType> dbTypeHashMap = new HashMap<>();
043
044    private final String defaultDataSourceKey;
045    private final DataSource defaultDataSource;
046
047    public FlexDataSource(String dataSourceKey, DataSource dataSource) {
048        this.defaultDataSourceKey = dataSourceKey;
049        this.defaultDataSource = dataSource;
050        dataSourceMap.put(dataSourceKey, dataSource);
051        dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource));
052    }
053
054    public void addDataSource(String dataSourceKey, DataSource dataSource) {
055        dataSourceMap.put(dataSourceKey, dataSource);
056        dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource));
057    }
058
059    public DbType getDbType(String dataSourceKey) {
060        return dbTypeHashMap.get(dataSourceKey);
061    }
062
063    @Override
064    public Connection getConnection() throws SQLException {
065        String xid = TransactionContext.getXID();
066        if (StringUtil.isNotBlank(xid)) {
067            String dataSourceKey = DataSourceKey.get();
068            if (StringUtil.isBlank(dataSourceKey)) {
069                dataSourceKey = defaultDataSourceKey;
070            }
071
072            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
073            if (connection == null) {
074                connection = proxy(getDataSource().getConnection(), xid);
075                TransactionalManager.hold(xid, dataSourceKey, connection);
076            }
077            return connection;
078        } else {
079            return getDataSource().getConnection();
080        }
081    }
082
083
084    @Override
085    public Connection getConnection(String username, String password) throws SQLException {
086        String xid = TransactionContext.getXID();
087        if (StringUtil.isNotBlank(xid)) {
088            String dataSourceKey = DataSourceKey.get();
089            if (StringUtil.isBlank(dataSourceKey)) {
090                dataSourceKey = defaultDataSourceKey;
091            }
092            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
093            if (connection == null) {
094                connection = proxy(getDataSource().getConnection(username, password), xid);
095                TransactionalManager.hold(xid, dataSourceKey, connection);
096            }
097            return connection;
098        } else {
099            return getDataSource().getConnection(username, password);
100        }
101    }
102
103     static void closeAutoCommit(Connection connection){
104        try {
105            connection.setAutoCommit(false);
106        } catch (SQLException e) {
107            if (log.isDebugEnabled()) {
108                log.debug("Error set AutoCommit to false.  Cause: " + e);
109            }
110        }
111    }
112
113     static void resetAutoCommit(Connection connection){
114        try {
115            if (!connection.getAutoCommit()){
116                connection.setAutoCommit(true);
117            }
118        } catch (SQLException e) {
119            if (log.isDebugEnabled()) {
120                log.debug("Error resetting autocommit to true "
121                        + "before closing the connection.  Cause: " + e);
122            }
123        }
124    }
125
126
127    public Connection proxy(Connection connection, String xid) {
128        return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader()
129                , new Class[]{Connection.class}
130                , new ConnectionHandler(connection, xid));
131    }
132
133    /**
134     * 方便用于 {@link DbTypeUtil#getDbType(DataSource)}
135     */
136    public String getUrl(){
137        return DbTypeUtil.getJdbcUrl(defaultDataSource);
138    }
139
140    @Override
141    @SuppressWarnings("unchecked")
142    public <T> T unwrap(Class<T> iface) throws SQLException {
143        if (iface.isInstance(this)) {
144            return (T) this;
145        }
146        return getDataSource().unwrap(iface);
147    }
148
149    @Override
150    public boolean isWrapperFor(Class<?> iface) throws SQLException {
151        return (iface.isInstance(this) || getDataSource().isWrapperFor(iface));
152    }
153
154
155    private DataSource getDataSource() {
156        DataSource dataSource = defaultDataSource;
157        if (dataSourceMap.size() > 1) {
158            String dataSourceKey = DataSourceKey.get();
159            if (StringUtil.isNotBlank(dataSourceKey)) {
160                dataSource = dataSourceMap.get(dataSourceKey);
161                if (dataSource == null) {
162                    throw new IllegalStateException("Cannot get target DataSource for dataSourceKey [" + dataSourceKey + "]");
163                }
164            }
165        }
166        return dataSource;
167    }
168
169    private static class ConnectionHandler implements InvocationHandler {
170        private static final String[] proxyMethods = new String[]{"commit", "rollback", "close","setAutoCommit"};
171        private final Connection original;
172        private final String xid;
173
174        public ConnectionHandler(Connection original, String xid) {
175
176            closeAutoCommit(original);
177
178            this.original = original;
179            this.xid = xid;
180        }
181
182        @Override
183        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
184            if (ArrayUtil.contains(proxyMethods, method.getName())
185                    && isTransactional()) {
186                //do nothing
187                return null;
188            }
189
190            //setAutoCommit: true
191            if ("close".equalsIgnoreCase(method.getName())){
192                resetAutoCommit(original);
193            }
194
195            return method.invoke(original, args);
196        }
197
198        private boolean isTransactional() {
199            return Objects.equals(xid, TransactionContext.getXID());
200        }
201    }
202
203
204}