001/**
002 * Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 * <p>
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 * <p>
008 * http://www.apache.org/licenses/LICENSE-2.0
009 * <p>
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package com.mybatisflex.core.datasource;
017
018import com.mybatisflex.core.dialect.DbType;
019import com.mybatisflex.core.dialect.DbTypeUtil;
020import com.mybatisflex.core.transaction.TransactionContext;
021import com.mybatisflex.core.transaction.TransactionalManager;
022import com.mybatisflex.core.util.ArrayUtil;
023import com.mybatisflex.core.util.StringUtil;
024
import javax.sql.DataSource;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
034
035public class FlexDataSource extends AbstractDataSource {
036
037    private final Map<String, DataSource> dataSourceMap = new HashMap<>();
038    private final Map<String, DbType> dbTypeHashMap = new HashMap<>();
039
040    private final String defaultDataSourceKey;
041    private final DataSource defaultDataSource;
042
043    public FlexDataSource(String dataSourceKey, DataSource dataSource) {
044        this.defaultDataSourceKey = dataSourceKey;
045        this.defaultDataSource = dataSource;
046        dataSourceMap.put(dataSourceKey, dataSource);
047        dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource));
048    }
049
050    public void addDataSource(String dataSourceKey, DataSource dataSource) {
051        dataSourceMap.put(dataSourceKey, dataSource);
052        dbTypeHashMap.put(dataSourceKey, DbTypeUtil.getDbType(dataSource));
053    }
054
055    public DbType getDbType(String dataSourceKey) {
056        return dbTypeHashMap.get(dataSourceKey);
057    }
058
059    @Override
060    public Connection getConnection() throws SQLException {
061        String xid = TransactionContext.getXID();
062        if (StringUtil.isNotBlank(xid)) {
063            String dataSourceKey = DataSourceKey.get();
064            if (StringUtil.isBlank(dataSourceKey)) {
065                dataSourceKey = defaultDataSourceKey;
066            }
067
068            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
069            if (connection == null) {
070                connection = proxy(getDataSource().getConnection(), xid);
071                TransactionalManager.hold(xid, dataSourceKey, connection);
072            }
073            return connection;
074        } else {
075            return getDataSource().getConnection();
076        }
077    }
078
079
080    @Override
081    public Connection getConnection(String username, String password) throws SQLException {
082        String xid = TransactionContext.getXID();
083        if (StringUtil.isNotBlank(xid)) {
084            String dataSourceKey = DataSourceKey.get();
085            if (StringUtil.isBlank(dataSourceKey)) {
086                dataSourceKey = defaultDataSourceKey;
087            }
088            Connection connection = TransactionalManager.getConnection(xid, dataSourceKey);
089            if (connection == null) {
090                connection = proxy(getDataSource().getConnection(username, password), xid);
091                TransactionalManager.hold(xid, dataSourceKey, connection);
092            }
093            return connection;
094        } else {
095            return getDataSource().getConnection(username, password);
096        }
097    }
098
099    public Connection proxy(Connection connection, String xid) {
100        return (Connection) Proxy.newProxyInstance(FlexDataSource.class.getClassLoader()
101                , new Class[]{Connection.class}
102                , new ConnectionHandler(connection, xid));
103    }
104
105    /**
106     * 方便用于 {@link DbTypeUtil#getDbType(DataSource)}
107     */
108    public String getUrl(){
109        return DbTypeUtil.getJdbcUrl(defaultDataSource);
110    }
111
112    @Override
113    @SuppressWarnings("unchecked")
114    public <T> T unwrap(Class<T> iface) throws SQLException {
115        if (iface.isInstance(this)) {
116            return (T) this;
117        }
118        return getDataSource().unwrap(iface);
119    }
120
121    @Override
122    public boolean isWrapperFor(Class<?> iface) throws SQLException {
123        return (iface.isInstance(this) || getDataSource().isWrapperFor(iface));
124    }
125
126
127    private DataSource getDataSource() {
128        DataSource dataSource = defaultDataSource;
129        if (dataSourceMap.size() > 1) {
130            String dataSourceKey = DataSourceKey.get();
131            if (StringUtil.isNotBlank(dataSourceKey)) {
132                dataSource = dataSourceMap.get(dataSourceKey);
133                if (dataSource == null) {
134                    throw new IllegalStateException("Cannot get target DataSource for dataSourceKey [" + dataSourceKey + "]");
135                }
136            }
137        }
138        return dataSource;
139    }
140
141    private static class ConnectionHandler implements InvocationHandler {
142        private static final String[] proxyMethods = new String[]{"commit", "rollback", "close",};
143        private final Connection original;
144        private final String xid;
145
146        public ConnectionHandler(Connection original, String xid) {
147            this.original = original;
148            this.xid = xid;
149        }
150
151        @Override
152        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
153            if (ArrayUtil.contains(proxyMethods, method.getName())
154                    && isTransactional()) {
155                return null;  //do nothing
156            }
157            return method.invoke(original, args);
158        }
159
160        private boolean isTransactional() {
161            return Objects.equals(xid, TransactionContext.getXID());
162        }
163    }
164
165
166}