/*
 * Copyright (c) 2017 MuleSoft, Inc. This software is protected under international
 * copyright law. All use of this software is subject to MuleSoft's Master Subscription
 * Agreement (or other master license agreement) separately entered into in writing between
 * you and MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package org.mule.munit;

import org.mule.munit.exception.DatabaseServerException;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringWriter;
import java.io.Writer;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import au.com.bytecode.opencsv.CSVWriter;
import org.apache.commons.lang3.StringUtils;
import org.h2.tools.RunScript;
import org.junit.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * <p>
 * In-memory H2 database that can be pre-populated from an SQL script and/or CSV files found on the classpath.
 * </p>
 * <p>
 * A single JDBC {@link Connection} is opened by {@link #start()}, shared by every query method and closed by
 * {@link #stop()}. The query methods require {@link #start()} to have been called first.
 * </p>
 */
public class DatabaseServer {

  private static final Logger logger = LoggerFactory.getLogger(DatabaseServer.class);

  /**
   * File extension removed from a CSV resource name to obtain the table name
   */
  private static final String CSV_EXTENSION = ".csv";

  /**
   * <p>
   * H2 Database name
   * </p>
   */
  private String database;

  /**
   * <p>
   * Name of (or path to) the SQL file whose statements will be executed when the database is started
   * </p>
   */
  private final String sqlFile;

  /**
   * <p>
   * CSV files (separated by semicolon) that creates tables in the database using the file name (without the termination, ".csv")
   * as the table name and its columns as the table columns
   * </p>
   */
  private final String csv;

  /**
   * <p>
   * The database connection
   * </p>
   */
  private Connection connection;

  /**
   * <p>
   * Partial database connection string
   * </p>
   */
  private final String connectionStringParameters;

  /**
   * @param database name of the in-memory database
   * @param sqlFile classpath resource with the SQL script to run on start, or {@code null} for none
   * @param csv semicolon separated list of CSV classpath resources to load as tables, or {@code null} for none
   * @param connectionStringParameters extra parameters appended to the JDBC URL, may be blank
   */
  public DatabaseServer(String database, String sqlFile, String csv, String connectionStringParameters) {
    this.database = database;
    this.sqlFile = sqlFile;
    this.csv = csv;
    this.connectionStringParameters = connectionStringParameters;
  }

  /**
   * <p>
   * Starts the server
   * </p>
   * <p>
   * Executes the correspondent queries if an SQL file has been included in the dbserver configuration
   * </p>
   * <p>
   * Creates the correspondent tables in the database if a CSV file has been included in the dbserver configuration
   * </p>
   *
   * @throws DatabaseServerException if any step of the start up fails; the original failure is kept as the cause
   */
  public void start() {
    try {
      logger.info("Starting database server...");
      addJdbcToClassLoader();
      sanitizeDatabaseName();
      connection = DriverManager.getConnection(getConnectionString());
      executeQueriesFromSQLFile(connection);
      // The statement is only needed while the CSV tables are created, so release it right after.
      try (Statement stmt = connection.createStatement()) {
        createTablesFromCsv(stmt);
      }
    } catch (Exception e) {
      throw new DatabaseServerException("Could not start the database server", e);
    }
  }

  /**
   * <p>
   * Trims the database name and drops everything from the first ';' onwards, so that the name cannot be used to
   * append extra parameters to the connection string
   * </p>
   */
  private void sanitizeDatabaseName() {
    database = StringUtils.substringBefore(StringUtils.trim(database), ";");
  }

  /**
   * <p>
   * Executes the SQL query received as parameter
   * </p>
   *
   * @param sql query to be executed
   * @return the value returned by {@link Statement#execute(String)} for the given query
   * @throws DatabaseServerException if the statement fails
   */
  public Boolean execute(String sql) {
    try (Statement statement = connection.createStatement()) {
      return statement.execute(sql);
    } catch (SQLException e) {
      logger.error("There has been a problem while executing the SQL statement", e);
      throw new DatabaseServerException("There has been a problem while executing the SQL statement", e);
    }
  }

  /**
   * <p>
   * Executes a SQL query
   * </p>
   *
   * @param sql query to be executed
   * @return one map per returned row, keyed by column name, with every value converted to its String form
   * @throws DatabaseServerException if the query fails
   */
  public List<Map<String, String>> executeQuery(String sql) {
    try {
      return getMap(sql);
    } catch (SQLException e) {
      logger.error("There has been a problem while executing the SQL statement", e);
      throw new DatabaseServerException("There has been a problem while executing the SQL statement", e);
    }
  }

  /**
   * <p>
   * Executes a SQL query and asserts that its result, rendered as CSV, equals the expected value
   * </p>
   *
   * @param query query to be executed
   * @param returns Expected value; the two character sequence backslash + 'n' is interpreted as a line break
   * @throws AssertionError if the actual result differs from the expected one
   * @throws DatabaseServerException if the query cannot be executed or its result cannot be read
   */
  public void validateThat(String query, String returns) {
    try {
      Writer writerQueryResult = getResults(query);
      String expected = returns.replace("\\n", "\n");
      String actual = writerQueryResult.toString().trim();
      Assert.assertEquals(expected, actual);
    } catch (org.junit.ComparisonFailure e) {
      // Re-thrown as a plain AssertionError carrying only the comparison message.
      throw new AssertionError(e.getMessage());
    } catch (ClassCastException ccException) {
      throw new DatabaseServerException("The JSON String must always be an array", ccException);
    } catch (SQLException e) {
      throw new DatabaseServerException("Invalid Query", e);
    } catch (IOException e) {
      throw new DatabaseServerException("Could no access to query results", e);
    }
  }

  /**
   * <p>
   * Stops the server.
   * </p>
   *
   * @throws RuntimeException if the connection cannot be closed
   */
  public void stop() {
    logger.info("Stopping database server ...");
    try {
      if (connection != null) {
        connection.close();
      }
    } catch (SQLException e) {
      throw new RuntimeException("Could not stop the database server", e);
    }
  }

  /**
   * Loads the H2 JDBC driver class. No driver instance is created here: the code only needs the class to be
   * loaded, which is assumed to be what registers the driver with {@link DriverManager}.
   */
  private void addJdbcToClassLoader() throws ClassNotFoundException {
    Class.forName("org.h2.Driver");
  }

  /**
   * Runs the configured SQL script, if any, against the given connection.
   *
   * @param conn connection on which the script is executed
   * @throws RuntimeException if the configured SQL file is not found on the classpath
   */
  private void executeQueriesFromSQLFile(Connection conn) throws SQLException, IOException {
    if (sqlFile == null) {
      return;
    }
    logger.info("Loading {} ...", sqlFile);
    InputStream scriptStream = getClass().getClassLoader().getResourceAsStream(sqlFile);
    if (scriptStream == null) {
      throw new RuntimeException("The SQL file " + sqlFile + " could not be found");
    }
    // The script is decoded with the platform default charset; closing the reader also closes the stream.
    try (InputStreamReader scriptReader = new InputStreamReader(scriptStream)) {
      RunScript.execute(conn, scriptReader);
    }
  }

  /**
   * Creates one table per configured CSV resource. The table is named after the resource with its trailing
   * ".csv" removed and is filled through H2's CSVREAD function.
   *
   * @param stmt statement used to run the CREATE TABLE commands
   * @throws RuntimeException if a table cannot be created
   */
  private void createTablesFromCsv(Statement stmt) {
    if (csv == null) {
      return;
    }
    logger.info("Loading {} ...", csv);
    boolean isCaseSensitive = isDatabaseToUpperParameterSet(connectionStringParameters);
    for (String entry : csv.split(";")) {
      String table = StringUtils.trim(entry);
      if (StringUtils.isBlank(table)) {
        // Tolerate stray separators such as "a.csv;;b.csv".
        continue;
      }
      // Only the trailing extension is stripped; a regex like ".csv" would also eat e.g. the "ycsv" in "mycsv.csv".
      String tableName = StringUtils.removeEnd(table, CSV_EXTENSION);
      StringBuilder command = new StringBuilder();
      command.append("CREATE TABLE `").append(tableName)
          .append("` AS SELECT * FROM CSVREAD('classpath:").append(table).append("'");
      if (isCaseSensitive) {
        command.append(", null, 'caseSensitiveColumnNames=true'");
      }
      command.append(");");
      try {
        stmt.execute(command.toString());
      } catch (SQLException e) {
        throw new RuntimeException("Could not create table " + tableName + " from " + table + " : " + e.getMessage(), e);
      }
    }
  }

  /**
   * Tells whether the connection parameters contain a DATABASE_TO_UPPER entry set to FALSE.
   *
   * @param connectionStringParameters semicolon separated connection parameters, may be {@code null}
   * @return {@code true} if one parameter mentions both DATABASE_TO_UPPER and FALSE (ignoring case)
   */
  private boolean isDatabaseToUpperParameterSet(String connectionStringParameters) {
    String[] parameters = StringUtils.split(connectionStringParameters, ';');
    if (parameters == null) {
      return false;
    }
    for (String parameter : parameters) {
      if (StringUtils.containsIgnoreCase(parameter, "DATABASE_TO_UPPER")
          && StringUtils.containsIgnoreCase(parameter, "FALSE")) {
        return true;
      }
    }
    return false;
  }

  /**
   * Runs the query and copies every row into a map of column name to the String form of its value.
   *
   * @param sql query to be executed
   * @return the rows returned by the query, in result set order
   */
  private List<Map<String, String>> getMap(String sql) throws SQLException {
    try (Statement statement = connection.createStatement();
        ResultSet resultSet = statement.executeQuery(sql)) {
      ResultSetMetaData metaData = resultSet.getMetaData();
      int columnCount = metaData.getColumnCount();
      List<Map<String, String>> rows = new ArrayList<>();
      while (resultSet.next()) {
        Map<String, String> row = new HashMap<>();
        for (int i = 1; i <= columnCount; i++) {
          // Read by index so the value always belongs to the column whose name is used as key.
          row.put(metaData.getColumnName(i), String.valueOf(resultSet.getObject(i)));
        }
        rows.add(row);
      }
      return rows;
    }
  }

  /**
   * Runs the query and renders its result, including the column names, as CSV.
   *
   * @param sql query to be executed
   * @return writer holding the CSV text
   */
  private Writer getResults(String sql) throws SQLException, IOException {
    try (Statement statement = connection.createStatement();
        ResultSet resultSet = statement.executeQuery(sql)) {
      Writer writer = new StringWriter();
      CSVWriter csvwriter = new CSVWriter(writer);
      csvwriter.writeAll(resultSet, true);
      return writer;
    }
  }

  /**
   * Builds the JDBC URL of the in-memory database, appending the extra connection parameters when present.
   *
   * @return the connection string
   */
  private String getConnectionString() {
    String connectionString = "jdbc:h2:mem:" + database;
    if (StringUtils.isNotBlank(connectionStringParameters)) {
      connectionString += ";" + connectionStringParameters;
    }
    logger.debug("Connection string: {}", connectionString);
    return connectionString;
  }
}
