/*
 * Decompiled with CFR 0.152.
 */
package water.jdbc;

import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.Date;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Time;
import java.sql.Timestamp;
import java.sql.Types;
import java.util.concurrent.ArrayBlockingQueue;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.fvec.AppendableVec;
import water.fvec.Chunk;
import water.fvec.FileVec;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.jdbc.SqlFetchMode;
import water.parser.BufferedString;
import water.parser.ParseDataset;
import water.util.Log;

public class SQLManager {
    // ---- Scratch table used when importing the result of a SELECT query ----

    /** Name of the table a query result is materialized into before import. */
    private static final String TEMP_TABLE_NAME = "table_for_h2o_import";
    /** System property switching the scratch-table mode for query imports (read with default {@code "true"}). */
    private static final String TMP_TABLE_ENABLED = "sys.ai.h2o.sql.tmp_table.enabled";

    // ---- Connection limits ----

    /** System property carrying a user-defined cap on the total number of database connections. */
    private static final String MAX_USR_CONNECTIONS_KEY = "sys.ai.h2o.sql.connections.max";
    /** Default, and upper bound, for the total number of database connections. */
    private static final int MAX_CONNECTIONS = 100;
    /** Lower bound for the number of connections opened on a single node. */
    private static final int MIN_CONNECTIONS_PER_NODE = 1;

    // ---- JDBC drivers ----

    /** Prefix of the system property naming the JDBC driver class for a database type (type is appended). */
    private static final String JDBC_DRIVER_CLASS_KEY_PREFIX = "sys.ai.h2o.sql.jdbc.driver.";
    /** Default driver for URLs of the form {@code jdbc:netezza:...}. */
    private static final String NETEZZA_JDBC_DRIVER_CLASS = "org.netezza.Driver";
    /** Default driver for URLs of the form {@code jdbc:hive2:...}. */
    private static final String HIVE_JDBC_DRIVER_CLASS = "org.apache.hive.jdbc.HiveDriver";

    // ---- Database types: the second ':'-separated token of the JDBC URL ----

    private static final String NETEZZA_DB_TYPE = "netezza";
    private static final String HIVE_DB_TYPE = "hive2";
    private static final String ORACLE_DB_TYPE = "oracle";
    private static final String SQL_SERVER_DB_TYPE = "sqlserver";
    private static final String TERADATA_DB_TYPE = "teradata";

    /**
     * Starts an H2O job that imports a database table, or the result of a SELECT query, into a {@link Frame}.
     *
     * <p>The source is inspected synchronously: row count and column metadata. The rows themselves are copied
     * inside the returned job. With {@code SqlFetchMode.DISTRIBUTED}, every chunk is read by its own paged query
     * ({@link #buildSelectChunkSql}). Otherwise the whole result set is streamed through a single connection.
     *
     * <p>When {@code table} is empty and the {@code sys.ai.h2o.sql.tmp_table.enabled} system property is unset
     * or {@code true}, the query result is first materialized into the scratch table
     * {@code table_for_h2o_import}. That table is dropped when the job ends, whether it succeeds or fails. It is
     * also dropped if inspecting it fails.
     *
     * <p>NOTE(review): {@code table}, {@code select_query} and {@code columns} are concatenated verbatim into SQL
     * text. They must come from a trusted caller.
     *
     * @param connection_url JDBC URL. Its second {@code :}-separated token (e.g. {@code mysql} in
     *                       {@code jdbc:mysql://host/db}) selects the SQL dialect and the default driver.
     * @param table          table to import, or empty/{@code null} to import {@code select_query} instead
     * @param select_query   query imported when {@code table} is empty; must start with {@code SELECT}
     * @param username       database user
     * @param password       database password
     * @param columns        column list inserted verbatim into the generated SELECT statements (e.g. {@code *})
     * @param sqlFetchMode   how rows are fetched (see above)
     * @return the started job; its frame is stored under the key {@code <table>_sql_to_hex}, with non-word
     *         characters replaced by {@code _}
     * @throws IllegalArgumentException if {@code select_query} does not start with SELECT, or {@code table} is
     *                                  the reserved scratch-table name
     * @throws RuntimeException         wrapping any {@link SQLException} raised while inspecting the source
     */
    public static Job<Frame> importSqlTable(final String connection_url, String table, String select_query, final String username, final String password, final String columns, final SqlFetchMode sqlFetchMode) {
        // jdbc:<databaseType>:<rest> -- the database type selects the SQL dialect and the default driver.
        final String databaseType = connection_url.split(":", 3)[1];
        SQLManager.initializeDatabaseDriver(databaseType);
        final boolean distributed = SqlFetchMode.DISTRIBUTED.equals(sqlFetchMode);

        final int numCol;
        final String[] columnNames;
        final byte[] columnH2OTypes;
        // Per-kind column counts, used only to estimate the in-memory size of the frame.
        // No SQL type is currently mapped to a categorical, so catcols stays 0.
        int catcols = 0;
        int intcols = 0;
        int bincols = 0;
        int realcols = 0;
        int timecols = 0;
        int stringcols = 0;
        long numRow = 0L;
        boolean tmpTableCreated = false;
        boolean inspected = false;
        Connection conn = null;
        Statement stmt = null;
        ResultSet rs = null;
        try {
            conn = DriverManager.getConnection(connection_url, username, password);
            stmt = conn.createStatement();
            stmt.setFetchSize(1);
            if (table == null || table.isEmpty()) {
                // Case-insensitive prefix check that ignores leading whitespace and does not depend on the locale.
                if (select_query == null || !select_query.trim().regionMatches(true, 0, "select", 0, 6)) {
                    throw new IllegalArgumentException("The select_query must start with `SELECT`, but instead is: " + select_query);
                }
                boolean createTmpTable = Boolean.parseBoolean(System.getProperty(TMP_TABLE_ENABLED, "true"));
                if (createTmpTable) {
                    table = TEMP_TABLE_NAME;
                    numRow = stmt.executeUpdate("CREATE TABLE " + table + " AS " + select_query);
                    tmpTableCreated = true;
                } else {
                    table = "(" + select_query + ") sub_h2o_import";
                }
            } else if (table.equals(TEMP_TABLE_NAME)) {
                throw new IllegalArgumentException("The specified table cannot be named: table_for_h2o_import");
            }
            if (numRow <= 0L) {
                // No positive row count known yet (plain table or subquery, or the CREATE TABLE update count
                // was not positive), so count the rows explicitly.
                rs = stmt.executeQuery("SELECT COUNT(1) FROM " + table);
                rs.next();
                numRow = rs.getLong(1);
                rs.close();
            }
            // Fetch at most one row: only the column metadata is needed here.
            if (distributed) {
                rs = stmt.executeQuery(SQLManager.buildSelectSingleRowSql(databaseType, table, columns));
            } else {
                stmt.setMaxRows(1);
                rs = stmt.executeQuery("SELECT " + columns + " FROM " + table);
            }
            ResultSetMetaData rsmd = rs.getMetaData();
            numCol = rsmd.getColumnCount();
            columnNames = new String[numCol];
            columnH2OTypes = new byte[numCol];
            rs.next();
            for (int i = 0; i < numCol; ++i) {
                columnNames[i] = rsmd.getColumnName(i + 1);
                switch (rsmd.getColumnType(i + 1)) {
                    case Types.NUMERIC:
                    case Types.DECIMAL:
                    case Types.FLOAT:
                    case Types.REAL:
                    case Types.DOUBLE:
                        columnH2OTypes[i] = Vec.T_NUM;
                        ++realcols;
                        break;
                    case Types.TINYINT:
                    case Types.BIGINT:
                    case Types.INTEGER:
                    case Types.SMALLINT:
                        columnH2OTypes[i] = Vec.T_NUM;
                        ++intcols;
                        break;
                    case Types.BIT:
                    case Types.BOOLEAN:
                        columnH2OTypes[i] = Vec.T_NUM;
                        ++bincols;
                        break;
                    case Types.LONGNVARCHAR:
                    case Types.NCHAR:
                    case Types.NVARCHAR:
                    case Types.LONGVARCHAR:
                    case Types.CHAR:
                    case Types.VARCHAR:
                        columnH2OTypes[i] = Vec.T_STR;
                        ++stringcols;
                        break;
                    case Types.DATE:
                    case Types.TIME:
                    case Types.TIMESTAMP:
                        columnH2OTypes[i] = Vec.T_TIME;
                        ++timecols;
                        break;
                    default:
                        Log.warn("Unsupported column type: " + rsmd.getColumnTypeName(i + 1));
                        columnH2OTypes[i] = Vec.T_BAD;
                }
            }
            inspected = true;
        }
        catch (SQLException ex) {
            throw new RuntimeException("SQLException: " + ex.getMessage() + "\nFailed to connect and read from SQL database with connection_url: " + connection_url, ex);
        }
        finally {
            if (rs != null) {
                try {
                    rs.close();
                }
                catch (SQLException ignored) {
                    // best effort: nothing useful can be done if closing fails
                }
            }
            if (stmt != null) {
                try {
                    stmt.close();
                }
                catch (SQLException ignored) {
                    // best effort
                }
            }
            if (conn != null) {
                try {
                    conn.close();
                }
                catch (SQLException ignored) {
                    // best effort
                }
            }
            if (tmpTableCreated && !inspected) {
                // Do not leave the scratch table behind. Otherwise the next query import would likely fail on
                // CREATE TABLE. Dropping happens only after this inspection connection is closed.
                dropTempTableQuietly(connection_url, username, password);
            }
        }

        // Rough in-memory size estimate used to pick the chunk size.
        double binary_ones_fraction = 0.5; // assumed share of non-zero values in binary columns
        long totSize = (long) ((catcols + intcols) * (double) numRow * 4.0                       // 4 bytes per categorical/integer
                + bincols * (double) numRow * 1.0 * binary_ones_fraction                          // sparse binary columns
                + (realcols + timecols + stringcols) * (double) numRow * 8.0);                    // 8 bytes per real/time/string
        int chunk_size = FileVec.calcOptimalChunkSize(totSize, numCol, numCol * 4, H2O.ARGS.nthreads, H2O.getCloudSize(), false, false);
        // NOTE(review): the second argument of nChunksFor looks like a log2 chunk size, while log1p is the natural
        // log of (1 + chunk size in bytes). Kept unchanged to preserve the existing chunking -- confirm intent.
        int num_chunks = Vec.nChunksFor(numRow, (int) Math.ceil(Math.log1p(chunk_size)), false);
        if (distributed) {
            // Each chunk is read by its own query, so cap the chunk count at the number of usable connections.
            num_chunks = Math.min(num_chunks, ConnectionPoolProvider.estimateConcurrentConnections(H2O.getCloudSize(), H2O.ARGS.nthreads));
        }
        final Vec vec = Vec.makeConN(numRow, num_chunks);
        Log.info("Number of chunks for data retrieval: " + vec.nChunks() + ", number of rows:" + numRow);

        final Key<Frame> destination_key = Key.make((table + "_sql_to_hex").replaceAll("\\W", "_"));
        final Job<Frame> j = new Job<Frame>(destination_key, Frame.class.getName(), "Import SQL Table");
        final String finalTable = table;
        H2O.H2OCountedCompleter work = new H2O.H2OCountedCompleter() {

            @Override
            public void compute2() {
                boolean imported = false;
                try {
                    ConnectionPoolProvider provider = new ConnectionPoolProvider(connection_url, username, password, vec.nChunks());
                    final Frame fr;
                    if (distributed) {
                        fr = new SqlTableToH2OFrame(finalTable, databaseType, columns, columnNames, numCol, j, provider)
                                .doAll(columnH2OTypes, vec)
                                .outputFrame(destination_key, columnNames, null);
                    } else {
                        fr = new SqlTableToH2OFrameStreaming(finalTable, databaseType, columns, columnNames, numCol, j, provider)
                                .readTable(vec, columnH2OTypes, destination_key);
                    }
                    DKV.put(fr);
                    ParseDataset.logParseResults(fr);
                    imported = true;
                } finally {
                    // The template vec and the scratch table must go away on success and on failure alike.
                    vec.remove();
                    if (finalTable.equals(SQLManager.TEMP_TABLE_NAME)) {
                        if (imported) {
                            SQLManager.dropTempTable(connection_url, username, password);
                        } else {
                            // Do not let a failing DROP mask the exception that aborted the import.
                            SQLManager.dropTempTableQuietly(connection_url, username, password);
                        }
                    }
                }
                this.tryComplete();
            }
        };
        j.start(work, vec.nChunks());
        return j;
    }

    /**
     * Best-effort variant of {@code dropTempTable} used on failure paths. A failure to drop is logged instead of
     * thrown, so it cannot hide the original error.
     */
    private static void dropTempTableQuietly(String connection_url, String username, String password) {
        try {
            SQLManager.dropTempTable(connection_url, username, password);
        } catch (RuntimeException e) {
            Log.warn("Failed to drop temporary table " + TEMP_TABLE_NAME + ": " + e.getMessage());
        }
    }

    /**
     * Builds a dialect-specific SELECT that returns at most one row of {@code table}. It is used to obtain column
     * metadata cheaply.
     *
     * @param databaseType database type from the JDBC URL. {@code sqlserver}, {@code oracle} and {@code teradata}
     *                     get their own syntax; every other type uses {@code LIMIT 1}.
     * @param table        table name or parenthesized subquery, inserted verbatim
     * @param columns      column list, inserted verbatim
     * @return the SQL text
     */
    static String buildSelectSingleRowSql(String databaseType, String table, String columns) {
        final String projection = columns + " FROM " + table;
        if (databaseType.equals("sqlserver")) {
            return "SELECT TOP(1) " + projection;
        }
        if (databaseType.equals("oracle")) {
            return "SELECT " + projection + " FETCH NEXT 1 ROWS ONLY";
        }
        if (databaseType.equals("teradata")) {
            return "SELECT TOP 1 " + projection;
        }
        return "SELECT " + projection + " LIMIT 1";
    }

    /**
     * Builds the dialect-specific query that reads {@code length} rows of {@code table}, starting at the
     * zero-based row {@code start}.
     *
     * @param databaseType database type from the JDBC URL ({@code sqlserver}, {@code oracle}, {@code teradata};
     *                     any other type gets {@code LIMIT ... OFFSET ...})
     * @param table        table name or parenthesized subquery, inserted verbatim
     * @param start        zero-based index of the first row to read
     * @param length       number of rows to read
     * @param columns      column list, inserted verbatim
     * @param columnNames  result column names; only element 0 is read, and only for {@code teradata}
     * @return the SQL text
     */
    static String buildSelectChunkSql(String databaseType, String table, long start, int length, String columns, String[] columnNames) {
        final StringBuilder sql = new StringBuilder("SELECT ").append(columns).append(" FROM ").append(table);
        if (databaseType.equals("sqlserver")) {
            // SQL Server accepts OFFSET/FETCH only after an ORDER BY clause, so a dummy ordering is supplied.
            sql.append(" ORDER BY ROW_NUMBER() OVER (ORDER BY (SELECT 0))")
               .append(" OFFSET ").append(start)
               .append(" ROWS FETCH NEXT ").append(length).append(" ROWS ONLY");
        } else if (databaseType.equals("oracle")) {
            sql.append(" OFFSET ").append(start)
               .append(" ROWS FETCH NEXT ").append(length).append(" ROWS ONLY");
        } else if (databaseType.equals("teradata")) {
            // Page on a 1-based row number ordered by the first result column.
            sql.append(" QUALIFY ROW_NUMBER() OVER (ORDER BY ").append(columnNames[0]).append(") BETWEEN ")
               .append(start + 1L)
               .append(" AND ")
               .append(start + (long) length);
        } else {
            sql.append(" LIMIT ").append(length).append(" OFFSET ").append(start);
        }
        return sql.toString();
    }

    /**
     * Loads the JDBC driver class to use for {@code databaseType}.
     *
     * <p>A class named by the system property {@code sys.ai.h2o.sql.jdbc.driver.<databaseType>} takes
     * precedence. Without it, built-in defaults exist only for {@code hive2} and {@code netezza}. Any other
     * database type loads nothing here.
     *
     * @param databaseType database type from the JDBC URL
     * @throws RuntimeException if the selected driver class cannot be found
     */
    static void initializeDatabaseDriver(String databaseType) {
        final String userDriverClass = System.getProperty(JDBC_DRIVER_CLASS_KEY_PREFIX + databaseType);
        if (userDriverClass != null) {
            Log.debug("Loading " + userDriverClass + " to initialize database of type " + databaseType);
            try {
                Class.forName(userDriverClass);
            }
            catch (ClassNotFoundException missing) {
                throw new RuntimeException("Connection to '" + databaseType + "' database is not possible due to missing JDBC driver. User specified driver class: " + userDriverClass, missing);
            }
            return;
        }

        final String defaultDriverClass;
        final String missingDriverMessage;
        switch (databaseType) {
            case HIVE_DB_TYPE:
                defaultDriverClass = HIVE_JDBC_DRIVER_CLASS;
                missingDriverMessage = "Connection to HIVE database is not possible due to missing JDBC driver.";
                break;
            case NETEZZA_DB_TYPE:
                defaultDriverClass = NETEZZA_JDBC_DRIVER_CLASS;
                missingDriverMessage = "Connection to Netezza database is not possible due to missing JDBC driver.";
                break;
            default:
                // No built-in default driver for this database type.
                return;
        }
        try {
            Class.forName(defaultDriverClass);
        }
        catch (ClassNotFoundException missing) {
            throw new RuntimeException(missingDriverMessage, missing);
        }
    }

    /**
     * Drops the scratch table {@code table_for_h2o_import}, using a connection opened just for this purpose.
     *
     * @throws RuntimeException wrapping the {@link SQLException} if connecting or executing the DROP fails.
     *                          Errors while closing the statement or connection are ignored.
     */
    private static void dropTempTable(String connection_url, String username, String password) {
        final String dropSql = "DROP TABLE " + TEMP_TABLE_NAME;
        Connection connection = null;
        Statement statement = null;
        try {
            connection = DriverManager.getConnection(connection_url, username, password);
            statement = connection.createStatement();
            statement.executeUpdate(dropSql);
        }
        catch (SQLException cause) {
            throw new RuntimeException("SQLException: " + cause.getMessage() + "\nFailed to execute SQL query: " + dropSql, cause);
        }
        finally {
            if (statement != null) {
                try {
                    statement.close();
                }
                catch (SQLException ignored) {
                    // best effort: the DROP itself already succeeded or failed
                }
            }
            if (connection != null) {
                try {
                    connection.close();
                }
                catch (SQLException ignored) {
                    // best effort
                }
            }
        }
    }

    /**
     * MRTask that copies a table chunk by chunk. For every chunk of the template vector, {@link #map} runs one
     * paged SELECT (built by {@code buildSelectChunkSql}) on a connection borrowed from the pool created in
     * {@link #setupLocal()}.
     */
    static class SqlTableToH2OFrame
    extends MRTask<SqlTableToH2OFrame> {
        final String _table;
        final String _columns;
        final String _databaseType;
        final int _numCol;
        final Job _job;
        final ConnectionPoolProvider _poolProvider;
        final String[] _columnNames;
        /** Connections shared by the map calls running on this node; not serialized. */
        transient ArrayBlockingQueue<Connection> sqlConn;

        public SqlTableToH2OFrame(String table, String databaseType, String columns, String[] columnNames, int numCol, Job job, ConnectionPoolProvider poolProvider) {
            this._table = table;
            this._databaseType = databaseType;
            this._columns = columns;
            this._columnNames = columnNames;
            this._numCol = numCol;
            this._job = job;
            this._poolProvider = poolProvider;
        }

        @Override
        protected void setupLocal() {
            this.sqlConn = this._poolProvider.createConnectionPool();
        }

        /**
         * Reads the rows of the chunk that starts at {@code cs[0].start()} and appends them to {@code ncs}.
         * The call is skipped if the task is cancelled or the job was asked to stop.
         */
        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            if (this.isCancelled() || this._job != null && this._job.stop_requested()) {
                return;
            }
            Chunk c0 = cs[0];
            String sqlText = SQLManager.buildSelectChunkSql(this._databaseType, this._table, c0.start(), c0._len, this._columns, this._columnNames);
            // Borrow the connection outside the main try. If take() is interrupted there is nothing to give
            // back: offering null to the queue would throw an NPE and mask the real cause.
            final Connection conn;
            try {
                conn = this.sqlConn.take();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted exception when trying to take connection from pool", e);
            }
            Statement stmt = null;
            ResultSet rs = null;
            try {
                stmt = conn.createStatement();
                stmt.setFetchSize(c0._len);
                rs = stmt.executeQuery(sqlText);
                while (rs.next()) {
                    SqlTableToH2OFrame.writeRow(rs, ncs);
                }
            }
            catch (SQLException ex) {
                throw new RuntimeException("SQLException: " + ex.getMessage() + "\nFailed to read SQL data", ex);
            }
            finally {
                if (rs != null) {
                    try {
                        rs.close();
                    }
                    catch (SQLException ignored) {
                        // best effort
                    }
                }
                if (stmt != null) {
                    try {
                        stmt.close();
                    }
                    catch (SQLException ignored) {
                        // best effort
                    }
                }
                // Always return the borrowed connection so other map calls on this node are not starved.
                this.sqlConn.add(conn);
            }
            if (this._job != null) {
                this._job.update(1L);
            }
        }

        /**
         * Appends the current row of {@code rs} to {@code ncs}, one column per chunk.
         *
         * <p>SQL NULL and unsupported value types become NA. Date, time and timestamp values (any
         * {@link java.util.Date}) are stored as epoch milliseconds, booleans as 0/1, and strings as string
         * values. Numbers are stored as numeric values.
         */
        static void writeRow(ResultSet rs, NewChunk[] ncs) throws SQLException {
            for (int i = 0; i < ncs.length; ++i) {
                Object res = rs.getObject(i + 1);
                NewChunk nc = ncs[i];
                if (res == null) {
                    nc.addNA();
                } else if (res instanceof Double) {
                    nc.addNum(((Double) res).doubleValue());
                } else if (res instanceof Integer) {
                    nc.addNum(((Integer) res).intValue(), 0);
                } else if (res instanceof Long) {
                    nc.addNum(((Long) res).longValue(), 0);
                } else if (res instanceof Float) {
                    nc.addNum(((Float) res).floatValue());
                } else if (res instanceof Short) {
                    nc.addNum(((Short) res).shortValue(), 0);
                } else if (res instanceof Byte) {
                    nc.addNum(((Byte) res).byteValue(), 0);
                } else if (res instanceof BigDecimal) {
                    nc.addNum(((BigDecimal) res).doubleValue());
                } else if (res instanceof Boolean) {
                    nc.addNum(((Boolean) res) ? 1 : 0, 0);
                } else if (res instanceof String) {
                    nc.addStr(new BufferedString((String) res));
                } else if (res instanceof java.util.Date) {
                    // Covers java.sql.Date, Time and Timestamp (all subclasses), plus a plain java.util.Date.
                    // The old simple-name match cast the latter to java.sql.Date and failed.
                    nc.addNum(((java.util.Date) res).getTime(), 0);
                } else if (res instanceof Number) {
                    // Any other numeric type (e.g. java.math.BigInteger) used to be silently stored as NA.
                    nc.addNum(((Number) res).doubleValue());
                } else {
                    nc.addNA();
                }
            }
        }

        /** Closes every pooled connection on this node. A failure on one connection does not skip the others. */
        @Override
        protected void closeLocal() {
            if (this.sqlConn == null) {
                return; // setupLocal never produced a pool
            }
            for (Connection conn : this.sqlConn) {
                try {
                    conn.close();
                }
                catch (SQLException | RuntimeException e) {
                    // best effort cleanup of driver resources
                    Log.trace(e);
                }
            }
        }
    }

    /**
     * Background task that closes the {@link NewChunk}s of a single chunk index (one per output column) and
     * waits for all of them to finish. This lets the streaming reader keep fetching rows while finished chunks
     * are written.
     */
    private static class FinalizeNewChunkTask
    extends H2O.H2OCountedCompleter<FinalizeNewChunkTask> {
        private final int _cidx;
        private transient NewChunk[] _ncs;

        FinalizeNewChunkTask(int cidx, NewChunk[] ncs) {
            this._cidx = cidx;
            this._ncs = ncs;
        }

        @Override
        public void compute2() {
            final NewChunk[] chunks = this._ncs;
            // _ncs is transient, so it is only available on the node that created the task.
            if (chunks == null) {
                throw new IllegalStateException("There are no chunks to work with!");
            }
            final Futures pending = new Futures();
            for (int k = 0; k < chunks.length; ++k) {
                chunks[k].close(this._cidx, pending);
            }
            pending.blockForPending();
            this.tryComplete();
        }
    }

    /**
     * Non-distributed importer. It runs a single {@code SELECT} over one connection and streams the result set
     * into consecutive chunks that follow the chunk layout of a template ("blueprint") vector.
     */
    static class SqlTableToH2OFrameStreaming {
        final String _table;
        final String _columns;
        final String _databaseType;
        final int _numCol;
        final Job _job;
        final ConnectionPoolProvider _poolProvider;
        final String[] _columnNames;

        SqlTableToH2OFrameStreaming(String table, String databaseType, String columns, String[] columnNames, int numCol, Job job, ConnectionPoolProvider poolProvider) {
            this._table = table;
            this._databaseType = databaseType;
            this._columns = columns;
            this._columnNames = columnNames;
            this._numCol = numCol;
            this._job = job;
            this._poolProvider = poolProvider;
        }

        /**
         * Reads the whole table into a new frame.
         *
         * @param blueprint      vector whose chunk layout (count and lengths) the output follows
         * @param columnTypes    H2O type of each output column
         * @param destinationKey key of the frame to create
         * @return the new frame. If the query returns fewer rows than {@code blueprint} has, a warning is logged
         *         and the frame holds only the rows actually read.
         * @throws RuntimeException wrapping any {@link SQLException}
         */
        Frame readTable(Vec blueprint, byte[] columnTypes, Key<Frame> destinationKey) {
            Vec.VectorGroup vg = blueprint.group();
            int vecIdStart = vg.reserveKeys(columnTypes.length);
            AppendableVec[] res = new AppendableVec[columnTypes.length];
            long[] espc = MemoryManager.malloc8(blueprint.nChunks());
            for (int i = 0; i < res.length; ++i) {
                res[i] = new AppendableVec(vg.vecKey(vecIdStart + i), espc, columnTypes[i], 0);
            }
            String query = "SELECT " + this._columns + " FROM " + this._table;
            Futures fs = new Futures();
            try (Connection conn = this._poolProvider.createConnection();
                 Statement stmt = conn.createStatement()) {
                stmt.setFetchSize(Math.min(blueprint.chunkLen(0), 100000));
                // The result set is a resource too, so it is closed before its statement and connection.
                try (ResultSet rs = stmt.executeQuery(query)) {
                    this.copyRows(rs, blueprint, res, query, fs);
                }
            }
            catch (SQLException e) {
                throw new RuntimeException("SQLException: " + e.getMessage() + "\nFailed to read SQL data", e);
            }
            fs.blockForPending();
            Vec[] vecs = AppendableVec.closeAll(res);
            return new Frame(destinationKey, this._columnNames, vecs);
        }

        /**
         * Copies rows from {@code rs} into consecutive chunks of {@code res}. It stops early when the job is
         * asked to stop or the result set runs out.
         */
        private void copyRows(ResultSet rs, Vec blueprint, AppendableVec[] res, String query, Futures fs) throws SQLException {
            int nChunks = blueprint.nChunks();
            for (int cidx = 0; cidx < nChunks && !this.stopRequested(); ++cidx) {
                NewChunk[] ncs = new NewChunk[res.length];
                for (int i = 0; i < res.length; ++i) {
                    ncs[i] = res[i].chunkForChunkIdx(cidx);
                }
                int len = blueprint.chunkLen(cidx);
                int rowsRead = 0;
                // Check the chunk bound first so the cursor is never advanced past this chunk's rows.
                while (rowsRead < len && rs.next()) {
                    SqlTableToH2OFrame.writeRow(rs, ncs);
                    ++rowsRead;
                }
                if (rowsRead < len) {
                    long totalLen = blueprint.espc()[cidx] + (long) rowsRead;
                    Log.warn("Query `" + query + "` returned less rows than expected. Actual: " + totalLen + ", expected: " + blueprint.length());
                    if (rowsRead > 0) {
                        // Finalize the partially filled chunk; previously its rows were silently dropped.
                        fs.add(H2O.submitTask(new FinalizeNewChunkTask(cidx, ncs)));
                    }
                    return;
                }
                fs.add(H2O.submitTask(new FinalizeNewChunkTask(cidx, ncs)));
                if (this._job != null) {
                    this._job.update(1L);
                }
            }
        }

        /** True if a job is attached and it was asked to stop. */
        private boolean stopRequested() {
            return this._job != null && this._job.stop_requested();
        }
    }

    /**
     * Iced (serializable) holder of the JDBC URL and credentials. It creates single connections and per-node
     * connection pools; {@code SqlTableToH2OFrame} calls {@link #createConnectionPool()} from its
     * {@code setupLocal()}.
     */
    static class ConnectionPoolProvider
    extends Iced<ConnectionPoolProvider> {
        private String _url;
        private String _user;
        // NOTE(review): the password is an ordinary (non-transient) field of this Iced object, so it is part of
        // its serialized state.
        private String _password;
        private int _nChunks;

        ConnectionPoolProvider(String url, String user, String password, int nChunks) {
            this._url = url;
            this._user = user;
            this._password = password;
            this._nChunks = nChunks;
        }

        // Public no-arg constructor; presumably required for Iced deserialization -- confirm before removing.
        public ConnectionPoolProvider() {
        }

        /** Creates a pool sized for the current cloud size and thread count. */
        ArrayBlockingQueue<Connection> createConnectionPool() {
            return this.createConnectionPool(H2O.getCloudSize(), H2O.ARGS.nthreads);
        }

        /** Opens one new connection with the stored URL and credentials. */
        Connection createConnection() throws SQLException {
            return DriverManager.getConnection(this._url, this._user, this._password);
        }

        /**
         * Opens {@code getMaxConnectionsPerNode(cloudSize, nThreads, nChunks)} connections and returns them as
         * a pool.
         *
         * @throws RuntimeException wrapping the {@link SQLException} if any connection cannot be opened. The
         *                          connections opened before the failure are closed first.
         */
        ArrayBlockingQueue<Connection> createConnectionPool(int cloudSize, short nThreads) throws RuntimeException {
            int maxConnectionsPerNode = ConnectionPoolProvider.getMaxConnectionsPerNode(cloudSize, nThreads, this._nChunks);
            Log.info("Database connections per node: " + maxConnectionsPerNode);
            ArrayBlockingQueue<Connection> connectionPool = new ArrayBlockingQueue<Connection>(maxConnectionsPerNode);
            try {
                for (int i = 0; i < maxConnectionsPerNode; ++i) {
                    connectionPool.add(this.createConnection());
                }
            }
            catch (SQLException ex) {
                // The partially built pool is never returned, so its connections would otherwise leak.
                ConnectionPoolProvider.closeQuietly(connectionPool);
                throw new RuntimeException("SQLException: " + ex.getMessage() + "\nFailed to connect to SQL database with url: " + this._url, ex);
            }
            return connectionPool;
        }

        /** Closes every connection, logging (at trace level) instead of propagating close failures. */
        private static void closeQuietly(Iterable<Connection> connections) {
            for (Connection conn : connections) {
                try {
                    conn.close();
                }
                catch (SQLException e) {
                    Log.trace(e);
                }
            }
        }

        /**
         * Returns the cap on total connections. The cap comes from the {@code sys.ai.h2o.sql.connections.max}
         * system property when it holds an integer in {@code [1, MAX_CONNECTIONS)}; otherwise
         * {@code MAX_CONNECTIONS} is used.
         */
        private static int getMaxConnectionsTotal() {
            int maxConnections = MAX_CONNECTIONS;
            String userDefinedMaxConnections = System.getProperty(SQLManager.MAX_USR_CONNECTIONS_KEY);
            if (userDefinedMaxConnections == null) {
                // Property not set: use the default silently instead of logging a parse failure for "null".
                return maxConnections;
            }
            try {
                int userMaxConnections = Integer.parseInt(userDefinedMaxConnections.trim());
                if (userMaxConnections > 0 && userMaxConnections < MAX_CONNECTIONS) {
                    maxConnections = userMaxConnections;
                }
            }
            catch (NumberFormatException e) {
                Log.info("Unable to parse maximal number of connections: " + userDefinedMaxConnections + ". Falling back to default settings (" + MAX_CONNECTIONS + ").", e);
            }
            return maxConnections;
        }

        /** Number of connections a single node should open for a table split into {@code nChunks} chunks. */
        static int getMaxConnectionsPerNode(int cloudSize, short nThreads, int nChunks) {
            return ConnectionPoolProvider.calculateLocalConnectionCount(ConnectionPoolProvider.getMaxConnectionsTotal(), cloudSize, nThreads, nChunks);
        }

        /**
         * Computes the per-node connection count. It is the node's share of the chunks, bounded by the thread
         * count and by the node's share of {@code maxTotalConnections}, and never below
         * {@code MIN_CONNECTIONS_PER_NODE}.
         */
        private static int calculateLocalConnectionCount(int maxTotalConnections, int cloudSize, short nThreads, int nChunks) {
            int conPerNode = (int) Math.min(Math.ceil((double) nChunks / (double) cloudSize), (double) nThreads);
            conPerNode = Math.min(conPerNode, maxTotalConnections / cloudSize);
            return Math.max(conPerNode, MIN_CONNECTIONS_PER_NODE);
        }

        /** Estimates how many connections the whole cloud can use at the same time. */
        private static int estimateConcurrentConnections(int cloudSize, short nThreads) {
            return cloudSize * Math.min(nThreads, Math.max(ConnectionPoolProvider.getMaxConnectionsTotal() / cloudSize, MIN_CONNECTIONS_PER_NODE));
        }
    }
}

