package org.mariadb.jdbc;

/*
MariaDB Client for Java

Copyright (c) 2012 Monty Program Ab.

This library is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License as published by the Free
Software Foundation; either version 2.1 of the License, or (at your option)
any later version.

This library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
for more details.

You should have received a copy of the GNU Lesser General Public License along
with this library; if not, write to Monty Program Ab info@montyprogram.com.

This particular MariaDB Client for Java file is work
derived from a Drizzle-JDBC. Drizzle-JDBC file which is covered by subject to
the following copyright and notice provisions:

Copyright (c) 2009-2011, Marcus Eriksson, Trond Norbye, Stephane Giron

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list
of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or
other materials provided with the distribution.

Neither the name of the driver nor the names of its contributors may not be
used to endorse or promote products derived from this software without specific
prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS  AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
OF SUCH DAMAGE.
*/

import org.mariadb.jdbc.internal.queryresults.ResultSetType;
import org.mariadb.jdbc.internal.packet.dao.parameters.ParameterHolder;
import org.mariadb.jdbc.internal.util.ExceptionMapper;
import org.mariadb.jdbc.internal.util.dao.QueryException;
import java.sql.*;
import java.util.*;


public class MariaDbClientPreparedStatement extends AbstractMariaDbPrepareStatement {
    /** Original SQL text; returned by toString() and re-prepared server-side to obtain metadata. */
    private final String sqlQuery;
    /** Query split by createRewritableParts(); paramCount is derived as queryParts.size() - 3. */
    private final List<String> queryParts;
    /** Parameter values for the next execution (index 0 holds parameter 1). */
    private ParameterHolder[] parameters;
    /** Parameter sets accumulated by addBatch() and consumed by executeBatch(). */
    private final List<ParameterHolder[]> parameterList = new ArrayList<>();
    private final int paramCount;
    private ResultSetMetaData resultSetMetaData = null;
    private ParameterMetaData parameterMetaData = null;

    /**
     * Constructor.
     * @param connection connection
     * @param sql sql query
     * @param autoGeneratedKeys auto generated keys (after insert).
     * @throws SQLException exception
     */
    public MariaDbClientPreparedStatement(MariaDbConnection connection,
                                          String sql, int autoGeneratedKeys) throws SQLException {
        super(connection, autoGeneratedKeys);
        this.sqlQuery = sql;

        useFractionalSeconds = connection.getProtocol().getOptions().useFractionalSeconds;
        // also sets isRewriteable as a side effect
        queryParts = createRewritableParts(sql, connection.noBackslashEscapes);
        paramCount = queryParts.size() - 3;
        parameters = new ParameterHolder[paramCount];
    }

    @Override
    protected boolean isNoBackslashEscapes() {
        return connection.noBackslashEscapes;
    }

    @Override
    protected boolean useFractionalSeconds() {
        return useFractionalSeconds;
    }

    @Override
    protected Calendar cal() {
        return protocol.getCalendar();
    }


    /**
     * Executes the SQL statement in this <code>PreparedStatement</code> object,
     * which may be any kind of SQL statement.
     * Some prepared statements return multiple results; the <code>execute</code>
     * method handles these complex statements as well as the simpler
     * form of statements handled by the methods <code>executeQuery</code>
     * and <code>executeUpdate</code>.
     * <br>
     * The <code>execute</code> method returns a <code>boolean</code> to
     * indicate the form of the first result.  You must call either the method
     * <code>getResultSet</code> or <code>getUpdateCount</code>
     * to retrieve the result; you must call <code>getMoreResults</code> to
     * move to any subsequent result(s).
     *
     * @return <code>true</code> if the first result is a <code>ResultSet</code>
     * object; <code>false</code> if the first result is an update
     * count or there is no result
     * @throws java.sql.SQLException if a database access error occurs;
     *                               this method is called on a closed <code>PreparedStatement</code>
     *                               or an argument is supplied to this method
     * @see java.sql.Statement#execute
     * @see java.sql.Statement#getResultSet
     * @see java.sql.Statement#getUpdateCount
     * @see java.sql.Statement#getMoreResults
     */
    public boolean execute() throws SQLException {
        return executeInternal();
    }

    /**
     * Executes the SQL query in this <code>PreparedStatement</code> object
     * and returns the <code>ResultSet</code> object generated by the query.
     *
     * @return a <code>ResultSet</code> object that contains the data produced by the
     * query; never <code>null</code> (an empty result set when the statement did not produce one)
     * @throws java.sql.SQLException if a database access error occurs or
     *                               this method is called on a closed <code>PreparedStatement</code>
     */
    public ResultSet executeQuery() throws SQLException {
        if (executeInternal()) {
            return getResultSet();
        }
        return MariaDbResultSet.EMPTY;
    }



    /**
     * Executes the SQL statement in this <code>PreparedStatement</code> object, which must be an SQL Data Manipulation
     * Language (DML) statement, such as <code>INSERT</code>, <code>UPDATE</code> or <code>DELETE</code>; or an SQL
     * statement that returns nothing, such as a DDL statement.
     *
     * @return either (1) the row count for SQL Data Manipulation Language (DML) statements, (2) 0 for SQL statements
     * that return nothing, or 0 when the statement produced a result set
     * @throws java.sql.SQLException if a database access error occurs or this method is called on a closed
     *                               <code>PreparedStatement</code>
     */
    public int executeUpdate() throws SQLException {
        if (executeInternal()) {
            return 0;
        }
        return getUpdateCount();
    }


    /**
     * Executes the query once with the current parameter values.
     *
     * @return true if the first result is a result set (SELECT), false otherwise
     * @throws SQLException if execution fails
     */
    protected boolean executeInternal() throws SQLException {
        return executeParameterSets(Collections.singletonList(parameters), false);
    }

    /**
     * Sends the query to the server for the given parameter sets and caches the results.
     *
     * <p>With several parameter sets, the protocol aggregates the queries. When {@code rewrite} is false, they are
     * sent as separate statements joined with ';', for example
     * {@code INSERT INTO jdbc (`name`) VALUES ('Line 1');INSERT INTO jdbc (`name`) VALUES ('Line 2')}.
     * When {@code rewrite} is true, they are sent as a single multi-row insert, for example
     * {@code INSERT INTO jdbc (`name`) VALUES ('Line 1'),('Line 2')}.
     *
     * @param parameterSets parameter values, one array per execution
     * @param rewrite       whether the parameter sets are to be merged into one multi-row statement
     * @return true if the first result is a result set (SELECT), false otherwise
     * @throws SQLException if execution fails (raised from executeQueryEpilog when a QueryException occurred)
     */
    private boolean executeParameterSets(List<ParameterHolder[]> parameterSets, boolean rewrite) throws SQLException {
        executing = true;
        QueryException exception = null;
        lock.lock();
        try {
            executeQueryProlog();
            batchResultSet = null;
            queryResult = protocol.executeQueries(queryParts, parameterSets, isStreaming(), rewrite);
            cacheMoreResults();
            return queryResult.getResultSetType() == ResultSetType.SELECT;
        } catch (QueryException e) {
            // handed to executeQueryEpilog in the finally block
            exception = e;
            return false;
        } finally {
            lock.unlock();
            executeQueryEpilog(exception);
            executing = false;
        }
    }

    /**
     * Adds the current set of parameters to this <code>PreparedStatement</code> object's batch, then clears the
     * current parameters.
     *
     * @throws java.sql.SQLException if a database access error occurs or this method is called on a closed
     *                               <code>PreparedStatement</code>
     * @see java.sql.Statement#addBatch
     * @since 1.2
     */
    public void addBatch() throws SQLException {
        parameterList.add(parameters);
        // a fresh array is allocated, so the batched array is never mutated afterwards
        clearParameters();
    }

    /**
     * Add batch.
     * @param sql typically this is a SQL <code>INSERT</code> or <code>UPDATE</code> statement
     * @throws java.sql.SQLException every time since that method is forbidden on prepareStatement
     */
    @Override
    public void addBatch(final String sql) throws SQLException {
        throw new SQLException("Cannot do addBatch(String) on preparedStatement");
    }

    /**
     * Clears the batched parameter sets and resets the current parameters.
     */
    @Override
    public void clearBatch() {
        parameterList.clear();
        this.parameters = new ParameterHolder[paramCount];
    }

    /**
     * Executes all batched parameter sets.
     *
     * <p>When the query is rewritable and either {@code allowMultiQueries} or {@code rewriteBatchedStatements} is
     * enabled, the whole batch is sent in one round trip. Otherwise each parameter set is executed individually and
     * the generated keys are joined into one result set. The batch is cleared in every case.
     *
     * @return update counts, one per batched parameter set
     * @throws SQLException if the statement is closed
     * @throws BatchUpdateException if an execution fails; carries the update counts of the sets that succeeded
     */
    public int[] executeBatch() throws SQLException {
        checkClose();
        int size = parameterList.size();
        if (size == 0) {
            return new int[0];
        }

        int[] ret = new int[size];
        int batchQueriesCount = 0;
        MariaDbResultSet rs = null;
        cachedResultSets.clear();
        lock.lock();
        try {
            if (isRewriteable && (protocol.getOptions().allowMultiQueries || protocol.getOptions().rewriteBatchedStatements)) {
                boolean rewrittenBatch = protocol.getOptions().rewriteBatchedStatements;
                executeParameterSets(parameterList, rewrittenBatch);
                return rewrittenBatch ? getUpdateCountsForReWrittenBatch(size) : getUpdateRewrittenCounts();
            } else {
                for (; batchQueriesCount < size; batchQueriesCount++) {
                    this.parameters = parameterList.get(batchQueriesCount);
                    executeInternal();
                    int updateCount = getUpdateCount();
                    ret[batchQueriesCount] = (updateCount == -1) ? SUCCESS_NO_INFO : updateCount;
                    if (batchQueriesCount == 0) {
                        rs = (MariaDbResultSet) getInternalGeneratedKeys();
                    } else {
                        rs = rs.joinResultSets((MariaDbResultSet) getInternalGeneratedKeys());
                    }
                }
            }
        } catch (SQLException sqle) {
            throw new BatchUpdateException(sqle.getMessage(), sqle.getSQLState(), sqle.getErrorCode(),
                    Arrays.copyOf(ret, batchQueriesCount), sqle);
        } finally {
            lock.unlock();
            clearBatch();
        }
        batchResultSet = rs;
        return ret;
    }

    /**
     * Retrieves a <code>ResultSetMetaData</code> object that contains information about the columns of the
     * <code>ResultSet</code> object that will be returned when this <code>PreparedStatement</code> object is executed.
     * <br>
     * If a result set is currently available its metadata is returned. Otherwise the query is prepared once on the
     * server to obtain the metadata, which is then cached on this statement.
     *
     * @return the description of a <code>ResultSet</code> object's columns or <code>null</code> if the driver cannot
     * return a <code>ResultSetMetaData</code> object
     * @throws java.sql.SQLException                    if a database access error occurs or this method is called on a closed
     *                                                  <code>PreparedStatement</code>
     * @throws java.sql.SQLFeatureNotSupportedException if the JDBC driver does not support this method
     * @since 1.2
     */
    public ResultSetMetaData getMetaData() throws SQLException {
        checkClose();
        ResultSet rs = getResultSet();
        if (rs != null) {
            return rs.getMetaData();
        }
        if (resultSetMetaData == null) {
            loadMetadata();
        }
        return resultSetMetaData;
    }


    /**
     * Stores a parameter value.
     *
     * @param parameterIndex 1-based parameter position
     * @param holder         value holder
     * @throws SQLException if the position is outside 1..paramCount
     */
    protected void setParameter(final int parameterIndex, final ParameterHolder holder) throws SQLException {
        if (parameterIndex < 1 || parameterIndex > paramCount) {
            throw ExceptionMapper.getSqlException("Could not set parameter at position " + parameterIndex
                    + " (value was " + holder + ")");
        }
        parameters[parameterIndex - 1] = holder;
    }


    /**
     * Retrieves the number, types and properties of this <code>PreparedStatement</code> object's parameters.
     *
     * @return a <code>ParameterMetaData</code> object that contains information about the number, types and properties
     * for each parameter marker of this <code>PreparedStatement</code> object
     * @throws java.sql.SQLException if a database access error occurs or this method is called on a closed
     *                               <code>PreparedStatement</code>
     * @see java.sql.ParameterMetaData
     * @since 1.4
     */
    public ParameterMetaData getParameterMetaData() throws SQLException {
        checkClose();
        if (parameterMetaData == null) {
            loadMetadata();
        }
        return parameterMetaData;
    }

    /**
     * Prepares the query server-side once to obtain result-set and parameter metadata, then closes that statement.
     * The metadata is read while the server statement is still open; closing happens in a finally block.
     *
     * @throws SQLException if preparing the query, reading its metadata, or closing it fails
     */
    private void loadMetadata() throws SQLException {
        MariaDbServerPreparedStatement serverPreparedStatement = new MariaDbServerPreparedStatement(connection, this.sqlQuery,
                Statement.NO_GENERATED_KEYS);
        try {
            resultSetMetaData = serverPreparedStatement.getMetaData();
            parameterMetaData = serverPreparedStatement.getParameterMetaData();
        } finally {
            serverPreparedStatement.close();
        }
    }

    /**
     * Clears the current parameter values immediately. <P>In general, parameter values remain in force for repeated use
     * of a statement. Setting a parameter value automatically clears its previous value.  However, in some cases it is
     * useful to immediately release the resources used by the current parameter values; this can be done by calling the
     * method <code>clearParameters</code>.
     */
    public void clearParameters() {
        parameters = new ParameterHolder[paramCount];
    }


    /**
     * Closes this statement.
     * NOTE(review): no statement-closed event is sent to listeners registered on the pooled connection here;
     * confirm whether one should be fired.
     *
     * @throws SQLException if closing fails
     */
    @Override
    public void close() throws SQLException {
        super.close();
    }

    @Override
    public String toString() {
        return sqlQuery;
    }

    /**
     * Splits the query into the text parts around its '?' placeholders. It also decides whether a batch of this
     * query can be rewritten as one multi-row INSERT and stores that decision in {@code isRewriteable}.
     *
     * <p>List layout:
     * <ul>
     *   <li>pre-value part: the text up to and including the VALUES keyword;</li>
     *   <li>first part: the text between VALUES and the first placeholder;</li>
     *   <li>one part following each placeholder, the last of these ending with the ")" that closes the VALUES
     *       list;</li>
     *   <li>last part: the text after that closing parenthesis.</li>
     * </ul>
     * When no VALUES keyword is found before the first placeholder, everything before that placeholder becomes the
     * pre-value part and the first part is empty. When there is no placeholder and no VALUES, both leading parts are
     * empty and the whole query is the last part.
     *
     * <p>Example: {@code INSERT INTO TABLE(col1,col2,col3,col4, col5) VALUES (9, ?, 5, ?, 8) ON DUPLICATE KEY UPDATE col2=col2+10}
     * gives:
     * <ul>
     *   <li>{@code "INSERT INTO TABLE(col1,col2,col3,col4, col5) VALUES"}</li>
     *   <li>{@code " (9, "}</li>
     *   <li>{@code ", 5, "}</li>
     *   <li>{@code ", 8)"}</li>
     *   <li>{@code " ON DUPLICATE KEY UPDATE col2=col2+10"}</li>
     * </ul>
     *
     * <p>The query is marked non-rewritable when:
     * <ul>
     *   <li>it does not start with INSERT;</li>
     *   <li>it contains SELECT;</li>
     *   <li>it contains a second statement after ';';</li>
     *   <li>a placeholder appears outside the VALUES row (before VALUES was detected, or after its closing
     *       parenthesis).</li>
     * </ul>
     *
     * <p>NOTE(review): for an INSERT ... VALUES (...) with no placeholder, the closing-parenthesis branch adds a part
     * and the end-of-query padding adds two more. The list then has four elements, so the constructor computes
     * paramCount == 1. Confirm how the protocol handles the resulting unset parameter.
     *
     * @param queryString query String
     * @param noBackslashEscapes true when backslash is not an escape character in string literals
     * @return List of query part.
     */
    private List<String> createRewritableParts(String queryString, boolean noBackslashEscapes) {
        isRewriteable = true;
        List<String> partList = new ArrayList<>();
        LexState state = LexState.Normal;
        char lastChar = '\0';

        StringBuilder sb = new StringBuilder();

        boolean singleQuotes = false;
        boolean isParam;
        int valueIndex = -1;          // stays -1 until the VALUES keyword has been consumed
        int isInParenthesis = 0;      // parenthesis depth, counted outside strings and comments
        boolean isAfterValue = false; // true once a further placeholder cannot belong to the VALUES row
        boolean skipChar = false;     // current character was already copied to a part by its switch case
        boolean addPartPreValue = false;
        boolean addPartAfterValue = false;
        boolean isFirstChar = true;
        boolean isInsert = false;
        boolean semicolon = false;

        char[] query = queryString.toCharArray();

        for (int i = 0; i < query.length; i++) {
            isParam = false;

            if (state == LexState.Escape) {
                // character after a backslash inside a string literal: copy verbatim and stay in the string
                sb.append(query[i]);
                state = LexState.String;
                continue;
            }

            char car = query[i];
            switch (car) {
                case '*':
                    if (state == LexState.Normal && lastChar == '/') {
                        state = LexState.SlashStarComment;
                    }
                    break;
                case '/':
                    if (state == LexState.SlashStarComment && lastChar == '*') {
                        state = LexState.Normal;
                    } else if (state == LexState.Normal && lastChar == '/') {
                        state = LexState.EOLComment;
                    }
                    break;

                case '#':
                    if (state == LexState.Normal) {
                        state = LexState.EOLComment;
                    }
                    break;

                case '-':
                    if (state == LexState.Normal && lastChar == '-') {
                        state = LexState.EOLComment;
                    }
                    break;

                case '\n':
                    if (state == LexState.EOLComment) {
                        state = LexState.Normal;
                    }
                    break;

                case '"':
                    if (state == LexState.Normal) {
                        state = LexState.String;
                        singleQuotes = false;
                    } else if (state == LexState.String && !singleQuotes) {
                        state = LexState.Normal;
                    }
                    break;
                case ';':
                    if (state == LexState.Normal) {
                        semicolon = true;
                    }
                    break;
                case '\'':
                    if (state == LexState.Normal) {
                        state = LexState.String;
                        singleQuotes = true;
                    } else if (state == LexState.String && singleQuotes) {
                        state = LexState.Normal;
                    }
                    break;

                case '\\':
                    if (noBackslashEscapes) {
                        break;
                    }
                    if (state == LexState.String) {
                        state = LexState.Escape;
                    }
                    break;

                case '?':
                    if (state == LexState.Normal) {
                        isParam = true;
                        if (isAfterValue || valueIndex == -1) {
                            // A placeholder after the ")" that closes the VALUES list, or one seen before any
                            // VALUES keyword was detected (e.g. INSERT ... SET col = ?), cannot be repeated as an
                            // extra row, so the query is not rewritable.
                            isRewriteable = false;
                        }
                    }
                    break;
                case '`':
                    if (state == LexState.Backtick) {
                        state = LexState.Normal;
                    } else if (state == LexState.Normal) {
                        state = LexState.Backtick;
                    }
                    break;

                case 's':
                case 'S':
                    if (state == LexState.Normal) {
                        if (valueIndex == -1
                                && query.length > i + 6
                                && (query[i + 1] == 'e' || query[i + 1] == 'E')
                                && (query[i + 2] == 'l' || query[i + 2] == 'L')
                                && (query[i + 3] == 'e' || query[i + 3] == 'E')
                                && (query[i + 4] == 'c' || query[i + 4] == 'C')
                                && (query[i + 5] == 't' || query[i + 5] == 'T')) {
                            //SELECT queries, INSERT FROM SELECT not rewritable
                            isRewriteable = false;
                        }
                    }
                    break;
                case 'v':
                case 'V':
                    if (state == LexState.Normal) {
                        // VALUES keyword: preceded by ')' or a char with code <= 40, followed by '(' or such a char
                        if (valueIndex == -1
                                && (lastChar == ')' || ((byte) lastChar <= 40))
                                && query.length > i + 7
                                && (query[i + 1] == 'a' || query[i + 1] == 'A')
                                && (query[i + 2] == 'l' || query[i + 2] == 'L')
                                && (query[i + 3] == 'u' || query[i + 3] == 'U')
                                && (query[i + 4] == 'e' || query[i + 4] == 'E')
                                && (query[i + 5] == 's' || query[i + 5] == 'S')
                                && (query[i + 6] == '(' || ((byte) query[i + 6] <= 40))) {
                            sb.append(car);
                            sb.append(query[i + 1]);
                            sb.append(query[i + 2]);
                            sb.append(query[i + 3]);
                            sb.append(query[i + 4]);
                            sb.append(query[i + 5]);
                            i = i + 5;
                            // everything up to and including VALUES is the pre-value part
                            partList.add(sb.toString());
                            sb.setLength(0);
                            valueIndex = i + 6;
                            skipChar = true;
                        }
                    }
                    break;
                case '(':
                    if (state == LexState.Normal) {
                        isInParenthesis++;
                    }
                    break;
                case ')':
                    if (state == LexState.Normal) {
                        isInParenthesis--;
                        if (isInParenthesis == 0 && valueIndex != -1 && !addPartAfterValue) {
                            //after the values data
                            isAfterValue = true;
                            sb.append(car);
                            partList.add(sb.toString());
                            sb.setLength(0);
                            skipChar = true;
                            addPartAfterValue = true;
                        }
                    }
                    break;
                default:
                    if (state == LexState.Normal && isFirstChar && ((byte) car >= 40)) {
                        if (car == 'I' || car == 'i') {
                            isInsert = true;
                        }
                        isFirstChar = false;
                    }
                    if (state == LexState.Normal && semicolon && !Character.isWhitespace(car)) {
                        // non-blank text after a ';' outside strings/comments: multiple queries
                        isRewriteable = false;
                    }
                    break;
            }

            lastChar = car;
            if (isParam) {
                // close the part preceding this placeholder
                partList.add(sb.toString());
                sb.setLength(0);
                if (valueIndex == -1 && !addPartPreValue) {
                    // no VALUES seen: the text so far is the pre-value part and the first part is empty
                    partList.add("");
                    isAfterValue = true;
                }
                addPartPreValue = true;
            } else {
                if (skipChar) {
                    skipChar = false;
                } else {
                    sb.append(car);
                }
            }
        }

        // pad missing leading / after-value parts so the layout stays positional
        if (!addPartPreValue) {
            partList.add("");
        }
        if (!addPartAfterValue) {
            partList.add("");
        }

        partList.add(sb.toString());
        if (!isInsert) {
            isRewriteable = false;
        }
        return partList;
    }
}
