001/*
002 *  Copyright (c) 2022-2023, Mybatis-Flex (fuhai999@gmail.com).
003 *  <p>
004 *  Licensed under the Apache License, Version 2.0 (the "License");
005 *  you may not use this file except in compliance with the License.
006 *  You may obtain a copy of the License at
007 *  <p>
008 *  http://www.apache.org/licenses/LICENSE-2.0
009 *  <p>
010 *  Unless required by applicable law or agreed to in writing, software
011 *  distributed under the License is distributed on an "AS IS" BASIS,
012 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 *  See the License for the specific language governing permissions and
014 *  limitations under the License.
015 */
016package com.mybatisflex.core.dialect;
017
018
import com.mybatisflex.core.exception.FlexExceptions;
import com.mybatisflex.core.util.StringUtil;
import org.apache.ibatis.datasource.unpooled.UnpooledDataSource;

import javax.sql.DataSource;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Locale;
import java.util.regex.Pattern;
028
029/**
030 * DbType 解析 工具类
031 */
032public class DbTypeUtil {
033
034    private DbTypeUtil() {
035    }
036
037    /**
038     * 获取当前配置的 DbType
039     */
040    public static DbType getDbType(DataSource dataSource) {
041        String jdbcUrl = getJdbcUrl(dataSource);
042
043        if (StringUtil.isNotBlank(jdbcUrl)) {
044            return parseDbType(jdbcUrl);
045        }
046
047        throw new IllegalStateException("Can not get dataSource jdbcUrl: " + dataSource.getClass().getName());
048    }
049
050    /**
051     * 通过数据源中获取 jdbc 的 url 配置
052     * 符合 HikariCP, druid, c3p0, DBCP, beecp 数据源框架 以及 MyBatis UnpooledDataSource 的获取规则
053     * UnpooledDataSource 参考 @{@link UnpooledDataSource#getUrl()}
054     *
055     * @return jdbc url 配置
056     */
057    public static String getJdbcUrl(DataSource dataSource) {
058        String[] methodNames = new String[]{"getUrl", "getJdbcUrl"};
059        for (String methodName : methodNames) {
060            try {
061                Method method = dataSource.getClass().getMethod(methodName);
062                return (String) method.invoke(dataSource);
063            } catch (Exception e) {
064                //ignore
065            }
066        }
067
068        Connection connection = null;
069        try {
070            connection = dataSource.getConnection();
071            return connection.getMetaData().getURL();
072        } catch (Exception e) {
073            throw FlexExceptions.wrap("Can not get the dataSource jdbcUrl", e);
074        } finally {
075            if (connection != null) {
076                try {
077                    connection.close();
078                } catch (SQLException e) { //ignore
079                }
080            }
081        }
082    }
083
084
085    /**
086     * 参考 druid  和 MyBatis-plus 的 JdbcUtils
087     * {@link com.alibaba.druid.util.JdbcUtils#getDbType(String, String)}
088     * {@link com.baomidou.mybatisplus.extension.toolkit.JdbcUtils#getDbType(String)}
089     *
090     * @param jdbcUrl jdbcURL
091     * @return 返回数据库类型
092     */
093    public static DbType parseDbType(String jdbcUrl) {
094        jdbcUrl = jdbcUrl.toLowerCase();
095        if (jdbcUrl.contains(":mysql:") || jdbcUrl.contains(":cobar:")) {
096            return DbType.MYSQL;
097        } else if (jdbcUrl.contains(":mariadb:")) {
098            return DbType.MARIADB;
099        } else if (jdbcUrl.contains(":oracle:")) {
100            return DbType.ORACLE;
101        } else if (jdbcUrl.contains(":sqlserver2012:")) {
102            return DbType.SQLSERVER;
103        } else if (jdbcUrl.contains(":sqlserver:") || jdbcUrl.contains(":microsoft:")) {
104            return DbType.SQLSERVER_2005;
105        } else if (jdbcUrl.contains(":postgresql:")) {
106            return DbType.POSTGRE_SQL;
107        } else if (jdbcUrl.contains(":hsqldb:")) {
108            return DbType.HSQL;
109        } else if (jdbcUrl.contains(":db2:")) {
110            return DbType.DB2;
111        } else if (jdbcUrl.contains(":sqlite:")) {
112            return DbType.SQLITE;
113        } else if (jdbcUrl.contains(":h2:")) {
114            return DbType.H2;
115        } else if (isMatchedRegex(":dm\\d*:", jdbcUrl)) {
116            return DbType.DM;
117        } else if (jdbcUrl.contains(":xugu:")) {
118            return DbType.XUGU;
119        } else if (isMatchedRegex(":kingbase\\d*:", jdbcUrl)) {
120            return DbType.KINGBASE_ES;
121        } else if (jdbcUrl.contains(":phoenix:")) {
122            return DbType.PHOENIX;
123        } else if (jdbcUrl.contains(":zenith:")) {
124            return DbType.GAUSS;
125        } else if (jdbcUrl.contains(":gbase:")) {
126            return DbType.GBASE;
127        } else if (jdbcUrl.contains(":gbasedbt-sqli:") || jdbcUrl.contains(":informix-sqli:")) {
128            return DbType.GBASE_8S;
129        } else if (jdbcUrl.contains(":ch:") || jdbcUrl.contains(":clickhouse:")) {
130            return DbType.CLICK_HOUSE;
131        } else if (jdbcUrl.contains(":oscar:")) {
132            return DbType.OSCAR;
133        } else if (jdbcUrl.contains(":sybase:")) {
134            return DbType.SYBASE;
135        } else if (jdbcUrl.contains(":oceanbase:")) {
136            return DbType.OCEAN_BASE;
137        } else if (jdbcUrl.contains(":highgo:")) {
138            return DbType.HIGH_GO;
139        } else if (jdbcUrl.contains(":cubrid:")) {
140            return DbType.CUBRID;
141        } else if (jdbcUrl.contains(":goldilocks:")) {
142            return DbType.GOLDILOCKS;
143        } else if (jdbcUrl.contains(":csiidb:")) {
144            return DbType.CSIIDB;
145        } else if (jdbcUrl.contains(":sap:")) {
146            return DbType.SAP_HANA;
147        } else if (jdbcUrl.contains(":impala:")) {
148            return DbType.IMPALA;
149        } else if (jdbcUrl.contains(":vertica:")) {
150            return DbType.VERTICA;
151        } else if (jdbcUrl.contains(":xcloud:")) {
152            return DbType.XCloud;
153        } else if (jdbcUrl.contains(":firebirdsql:")) {
154            return DbType.FIREBIRD;
155        } else if (jdbcUrl.contains(":redshift:")) {
156            return DbType.REDSHIFT;
157        } else if (jdbcUrl.contains(":opengauss:")) {
158            return DbType.OPENGAUSS;
159        } else if (jdbcUrl.contains(":taos:") || jdbcUrl.contains(":taos-rs:")) {
160            return DbType.TDENGINE;
161        } else if (jdbcUrl.contains(":informix")) {
162            return DbType.INFORMIX;
163        } else if (jdbcUrl.contains(":sinodb")) {
164            return DbType.SINODB;
165        } else if (jdbcUrl.contains(":uxdb:")) {
166            return DbType.UXDB;
167        } else if (jdbcUrl.contains(":greenplum:")) {
168            return DbType.GREENPLUM;
169        } else {
170            return DbType.OTHER;
171        }
172    }
173
174    /**
175     * 正则匹配,验证成功返回 true,验证失败返回 false
176     */
177    public static boolean isMatchedRegex(String regex, String jdbcUrl) {
178        if (null == jdbcUrl) {
179            return false;
180        }
181        return Pattern.compile(regex).matcher(jdbcUrl).find();
182    }
183
184}