/*
 * Copyright 2015, The Querydsl Team (http://www.querydsl.com/team)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.querydsl.sql.dml;

import com.querydsl.core.DefaultQueryMetadata;
import com.querydsl.core.FilteredClause;
import com.querydsl.core.JoinType;
import com.querydsl.core.QueryFlag;
import com.querydsl.core.QueryFlag.Position;
import com.querydsl.core.QueryMetadata;
import com.querydsl.core.dml.StoreClause;
import com.querydsl.core.types.ConstantImpl;
import com.querydsl.core.types.Expression;
import com.querydsl.core.types.ExpressionUtils;
import com.querydsl.core.types.NullExpression;
import com.querydsl.core.types.Path;
import com.querydsl.core.types.SubQueryExpression;
import com.querydsl.core.types.dsl.Expressions;
import com.querydsl.core.types.dsl.SimpleExpression;
import com.querydsl.core.util.CollectionUtils;
import com.querydsl.core.util.ResultSetAdapter;
import com.querydsl.sql.ColumnMetadata;
import com.querydsl.sql.Configuration;
import com.querydsl.sql.RelationalPath;
import com.querydsl.sql.SQLBindings;
import com.querydsl.sql.SQLListener;
import com.querydsl.sql.SQLNoCloseListener;
import com.querydsl.sql.SQLQuery;
import com.querydsl.sql.SQLSerializer;
import com.querydsl.sql.SQLTemplates;
import com.querydsl.sql.types.Null;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.logging.Logger;
import org.jetbrains.annotations.Nullable;

/**
 * {@code SQLMergeClause} defines an MERGE INTO clause
 *
 * @author tiwe
 */
public class SQLMergeClause extends AbstractSQLClause<SQLMergeClause>
    implements StoreClause<SQLMergeClause> {

  protected static final Logger logger = Logger.getLogger(SQLMergeClause.class.getName());

  protected final List<Path<?>> columns = new ArrayList<>();

  protected final RelationalPath<?> entity;

  protected final QueryMetadata metadata = new DefaultQueryMetadata();

  protected final List<Path<?>> keys = new ArrayList<>();

  @Nullable protected SubQueryExpression<?> subQuery;

  protected final List<SQLMergeBatch> batches = new ArrayList<>();

  protected final List<Expression<?>> values = new ArrayList<>();

  protected transient String queryString;

  protected transient List<Object> constants;

  public SQLMergeClause(Connection connection, SQLTemplates templates, RelationalPath<?> entity) {
    this(connection, new Configuration(templates), entity);
  }

  public SQLMergeClause(
      Connection connection, Configuration configuration, RelationalPath<?> entity) {
    super(configuration, connection);
    this.entity = entity;
    metadata.addJoin(JoinType.DEFAULT, entity);
  }

  public SQLMergeClause(
      Supplier<Connection> connection, Configuration configuration, RelationalPath<?> entity) {
    super(configuration, connection);
    this.entity = entity;
    metadata.addJoin(JoinType.DEFAULT, entity);
  }

  /**
   * Add the given String literal at the given position as a query flag
   *
   * @param position position
   * @param flag query flag
   * @return the current object
   */
  public SQLMergeClause addFlag(Position position, String flag) {
    metadata.addFlag(new QueryFlag(position, flag));
    return this;
  }

  /**
   * Add the given Expression at the given position as a query flag
   *
   * @param position position
   * @param flag query flag
   * @return the current object
   */
  public SQLMergeClause addFlag(Position position, Expression<?> flag) {
    metadata.addFlag(new QueryFlag(position, flag));
    return this;
  }

  public SQLMergeUsingClause using(SimpleExpression<?> dataQuery) {
    clear();
    return new SQLMergeUsingClause(connection(), configuration, entity, dataQuery);
  }

  protected List<? extends Path<?>> getKeys() {
    if (!keys.isEmpty()) {
      return keys;
    } else if (entity.getPrimaryKey() != null) {
      return entity.getPrimaryKey().getLocalColumns();
    } else {
      throw new IllegalStateException("No keys were defined, invoke keys(..) to add keys");
    }
  }

  /**
   * Add the current state of bindings as a batch item
   *
   * @return the current object
   */
  public SQLMergeClause addBatch() {
    if (!configuration.getTemplates().isNativeMerge()) {
      throw new IllegalStateException(
          "batch only supported for databases that support native merge");
    }

    batches.add(new SQLMergeBatch(keys, columns, values, subQuery));
    columns.clear();
    values.clear();
    keys.clear();
    subQuery = null;
    return this;
  }

  @Override
  public void clear() {
    batches.clear();
    columns.clear();
    values.clear();
    keys.clear();
    subQuery = null;
  }

  public SQLMergeClause columns(Path<?>... columns) {
    this.columns.addAll(Arrays.asList(columns));
    return this;
  }

  /**
   * Execute the clause and return the generated key with the type of the given path. If no rows
   * were created, null is returned, otherwise the key of the first row is returned.
   *
   * @param <T>
   * @param path path for key
   * @return generated key
   */
  @SuppressWarnings("unchecked")
  @Nullable
  public <T> T executeWithKey(Path<T> path) {
    return executeWithKey((Class<T>) path.getType(), path);
  }

  /**
   * Execute the clause and return the generated key cast to the given type. If no rows were
   * created, null is returned, otherwise the key of the first row is returned.
   *
   * @param <T>
   * @param type type of key
   * @return generated key
   */
  public <T> T executeWithKey(Class<T> type) {
    return executeWithKey(type, null);
  }

  protected <T> T executeWithKey(Class<T> type, @Nullable Path<T> path) {
    var rs = executeWithKeys();
    try {
      if (rs.next()) {
        return configuration.get(rs, path, 1, type);
      } else {
        return null;
      }
    } catch (SQLException e) {
      throw configuration.translate(e);
    } finally {
      close(rs);
    }
  }

  /**
   * Execute the clause and return the generated key with the type of the given path. If no rows
   * were created, or the referenced column is not a generated key, null is returned. Otherwise, the
   * key of the first row is returned.
   *
   * @param <T>
   * @param path path for key
   * @return generated keys
   */
  @SuppressWarnings("unchecked")
  public <T> List<T> executeWithKeys(Path<T> path) {
    return executeWithKeys((Class<T>) path.getType(), path);
  }

  public <T> List<T> executeWithKeys(Class<T> type) {
    return executeWithKeys(type, null);
  }

  protected <T> List<T> executeWithKeys(Class<T> type, @Nullable Path<T> path) {
    ResultSet rs = null;
    try {
      rs = executeWithKeys();
      List<T> rv = new ArrayList<>();
      while (rs.next()) {
        rv.add(configuration.get(rs, path, 1, type));
      }
      return rv;
    } catch (SQLException e) {
      throw configuration.translate(e);
    } finally {
      if (rs != null) {
        close(rs);
      }
      reset();
    }
  }

  /**
   * Execute the clause and return the generated keys as a ResultSet
   *
   * @return result set with generated keys
   */
  public ResultSet executeWithKeys() {
    context = startContext(connection(), metadata, entity);
    try {
      if (configuration.getTemplates().isNativeMerge()) {
        PreparedStatement stmt = null;
        if (batches.isEmpty()) {
          stmt = createStatement(true);
          listeners.notifyMerge(entity, metadata, keys, columns, values, subQuery);

          listeners.preExecute(context);
          stmt.executeUpdate();
          listeners.executed(context);
        } else {
          var stmts = createStatements(true);
          if (stmts != null && stmts.size() > 1) {
            throw new IllegalStateException(
                "executeWithKeys called with batch statement and multiple SQL strings");
          }
          stmt = stmts.iterator().next();
          listeners.notifyMerges(entity, metadata, batches);

          listeners.preExecute(context);
          stmt.executeBatch();
          listeners.executed(context);
        }

        final Statement stmt2 = stmt;
        var rs = stmt.getGeneratedKeys();
        return new ResultSetAdapter(rs) {
          @Override
          public void close() throws SQLException {
            try {
              super.close();
            } finally {
              stmt2.close();
              reset();
              endContext(context);
            }
          }
        };
      } else {
        if (hasRow()) {
          // update
          var update = new SQLUpdateClause(connection(), configuration, entity);
          update.addListener(listeners);
          populate(update);
          addKeyConditions(update);
          reset();
          endContext(context);
          return EmptyResultSet.DEFAULT;
        } else {
          // insert
          var insert = new SQLInsertClause(connection(), configuration, entity);
          insert.addListener(listeners);
          populate(insert);
          return insert.executeWithKeys();
        }
      }
    } catch (SQLException e) {
      onException(context, e);
      reset();
      endContext(context);
      throw configuration.translate(queryString, constants, e);
    }
  }

  @Override
  public long execute() {
    if (configuration.getTemplates().isNativeMerge()) {
      return executeNativeMerge();
    } else {
      return executeCompositeMerge();
    }
  }

  @Override
  public List<SQLBindings> getSQL() {
    if (batches.isEmpty()) {
      var serializer = createSerializer();
      serializer.serializeMerge(metadata, entity, keys, columns, values, subQuery);
      return Collections.singletonList(createBindings(metadata, serializer));
    } else {
      List<SQLBindings> builder = new ArrayList<>();
      for (SQLMergeBatch batch : batches) {
        var serializer = createSerializer();
        serializer.serializeMerge(
            metadata,
            entity,
            batch.getKeys(),
            batch.getColumns(),
            batch.getValues(),
            batch.getSubQuery());
        builder.add(createBindings(metadata, serializer));
      }
      return CollectionUtils.unmodifiableList(builder);
    }
  }

  protected boolean hasRow() {
    SQLQuery<?> query = new SQLQuery<Void>(connection(), configuration).from(entity);
    for (SQLListener listener : listeners.getListeners()) {
      query.addListener(listener);
    }
    query.addListener(SQLNoCloseListener.DEFAULT);
    addKeyConditions(query);
    return query.select(Expressions.ONE).fetchFirst() != null;
  }

  @SuppressWarnings("unchecked")
  protected void addKeyConditions(FilteredClause<?> query) {
    List<? extends Path<?>> keys = getKeys();
    for (var i = 0; i < columns.size(); i++) {
      if (keys.contains(columns.get(i))) {
        if (values.get(i) instanceof NullExpression) {
          query.where(ExpressionUtils.isNull(columns.get(i)));
        } else {
          query.where(ExpressionUtils.eq(columns.get(i), (Expression) values.get(i)));
        }
      }
    }
  }

  @SuppressWarnings("unchecked")
  protected long executeCompositeMerge() {
    if (hasRow()) {
      // update
      var update = new SQLUpdateClause(connection(), configuration, entity);
      populate(update);
      addListeners(update);
      addKeyConditions(update);
      return update.execute();
    } else {
      // insert
      var insert = new SQLInsertClause(connection(), configuration, entity);
      addListeners(insert);
      populate(insert);
      return insert.execute();
    }
  }

  protected void addListeners(AbstractSQLClause<?> clause) {
    for (SQLListener listener : listeners.getListeners()) {
      clause.addListener(listener);
    }
  }

  @SuppressWarnings("unchecked")
  protected void populate(StoreClause<?> clause) {
    for (var i = 0; i < columns.size(); i++) {
      clause.set((Path) columns.get(i), (Object) values.get(i));
    }
  }

  protected PreparedStatement createStatement(boolean withKeys) throws SQLException {
    var addBatches = !configuration.getUseLiterals();
    listeners.preRender(context);
    var serializer = createSerializer();
    PreparedStatement stmt = null;
    if (batches.isEmpty()) {
      serializer.serializeMerge(metadata, entity, keys, columns, values, subQuery);
      context.addSQL(createBindings(metadata, serializer));
      listeners.rendered(context);

      listeners.prePrepare(context);
      stmt = prepareStatementAndSetParameters(serializer, withKeys);
      context.addPreparedStatement(stmt);
      listeners.prepared(context);
    } else {
      serializer.serializeMerge(
          metadata,
          entity,
          batches.get(0).getKeys(),
          batches.get(0).getColumns(),
          batches.get(0).getValues(),
          batches.get(0).getSubQuery());
      context.addSQL(createBindings(metadata, serializer));
      listeners.rendered(context);

      stmt = prepareStatementAndSetParameters(serializer, withKeys);

      // add first batch
      if (addBatches) {
        stmt.addBatch();
      }

      // add other batches
      for (var i = 1; i < batches.size(); i++) {
        var batch = batches.get(i);
        listeners.preRender(context);
        serializer = createSerializer();
        serializer.serializeMerge(
            metadata,
            entity,
            batch.getKeys(),
            batch.getColumns(),
            batch.getValues(),
            batch.getSubQuery());
        context.addSQL(createBindings(metadata, serializer));
        listeners.rendered(context);

        setParameters(
            stmt, serializer.getConstants(), serializer.getConstantPaths(), metadata.getParams());
        if (addBatches) {
          stmt.addBatch();
        }
      }
    }
    return stmt;
  }

  protected Collection<PreparedStatement> createStatements(boolean withKeys) throws SQLException {
    var addBatches = !configuration.getUseLiterals();
    Map<String, PreparedStatement> stmts = new HashMap<>();

    // add first batch
    listeners.preRender(context);
    var serializer = createSerializer();
    serializer.serializeMerge(
        metadata,
        entity,
        batches.get(0).getKeys(),
        batches.get(0).getColumns(),
        batches.get(0).getValues(),
        batches.get(0).getSubQuery());
    context.addSQL(createBindings(metadata, serializer));
    listeners.rendered(context);

    var stmt = prepareStatementAndSetParameters(serializer, withKeys);
    stmts.put(serializer.toString(), stmt);
    if (addBatches) {
      stmt.addBatch();
    }

    // add other batches
    for (var i = 1; i < batches.size(); i++) {
      var batch = batches.get(i);
      serializer = createSerializer();
      serializer.serializeMerge(
          metadata,
          entity,
          batch.getKeys(),
          batch.getColumns(),
          batch.getValues(),
          batch.getSubQuery());
      stmt = stmts.get(serializer.toString());
      if (stmt == null) {
        stmt = prepareStatementAndSetParameters(serializer, withKeys);
        stmts.put(serializer.toString(), stmt);
      } else {
        setParameters(
            stmt, serializer.getConstants(), serializer.getConstantPaths(), metadata.getParams());
      }
      if (addBatches) {
        stmt.addBatch();
      }
    }

    return stmts.values();
  }

  protected PreparedStatement prepareStatementAndSetParameters(
      SQLSerializer serializer, boolean withKeys) throws SQLException {
    listeners.prePrepare(context);

    queryString = serializer.toString();
    constants = serializer.getConstants();
    logQuery(logger, queryString, constants);
    PreparedStatement stmt;
    if (withKeys) {
      var target = new String[keys.size()];
      for (var i = 0; i < target.length; i++) {
        target[i] = ColumnMetadata.getName(getKeys().get(i));
      }
      stmt = connection().prepareStatement(queryString, target);
    } else {
      stmt = connection().prepareStatement(queryString);
    }
    setParameters(
        stmt, serializer.getConstants(), serializer.getConstantPaths(), metadata.getParams());
    context.addPreparedStatement(stmt);
    listeners.prepared(context);

    return stmt;
  }

  protected long executeNativeMerge() {
    context = startContext(connection(), metadata, entity);
    PreparedStatement stmt = null;
    Collection<PreparedStatement> stmts = null;
    try {
      if (batches.isEmpty()) {
        stmt = createStatement(false);
        listeners.notifyMerge(entity, metadata, keys, columns, values, subQuery);

        listeners.preExecute(context);
        var rc = stmt.executeUpdate();
        listeners.executed(context);
        return rc;
      } else {
        stmts = createStatements(false);
        listeners.notifyMerges(entity, metadata, batches);

        listeners.preExecute(context);
        var rc = executeBatch(stmts);
        listeners.executed(context);
        return rc;
      }
    } catch (SQLException e) {
      onException(context, e);
      throw configuration.translate(queryString, constants, e);
    } finally {
      if (stmt != null) {
        close(stmt);
      }
      if (stmts != null) {
        close(stmts);
      }
      reset();
      endContext(context);
    }
  }

  /**
   * Set the keys to be used in the MERGE clause
   *
   * @param paths keys
   * @return the current object
   */
  public SQLMergeClause keys(Path<?>... paths) {
    keys.addAll(Arrays.asList(paths));
    return this;
  }

  public SQLMergeClause select(SubQueryExpression<?> subQuery) {
    this.subQuery = subQuery;
    return this;
  }

  @Override
  public <T> SQLMergeClause set(Path<T> path, @Nullable T value) {
    columns.add(path);
    if (value != null) {
      values.add(ConstantImpl.create(value));
    } else {
      values.add(Null.CONSTANT);
    }
    return this;
  }

  @Override
  public <T> SQLMergeClause set(Path<T> path, Expression<? extends T> expression) {
    columns.add(path);
    values.add(expression);
    return this;
  }

  @Override
  public <T> SQLMergeClause setNull(Path<T> path) {
    columns.add(path);
    values.add(Null.CONSTANT);
    return this;
  }

  @Override
  public String toString() {
    var serializer = createSerializer();
    serializer.serializeMerge(metadata, entity, keys, columns, values, subQuery);
    return serializer.toString();
  }

  public SQLMergeClause values(Object... v) {
    for (Object value : v) {
      if (value instanceof Expression<?>) {
        values.add((Expression<?>) value);
      } else if (value != null) {
        values.add(ConstantImpl.create(value));
      } else {
        values.add(Null.CONSTANT);
      }
    }
    return this;
  }

  @Override
  public boolean isEmpty() {
    return values.isEmpty() && batches.isEmpty();
  }

  @Override
  public int getBatchCount() {
    return batches.size();
  }
}
