/*
 * Copyright DataStax, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Copyright (C) 2020 ScyllaDB
 *
 * Modified by ScyllaDB
 */
package com.datastax.driver.core;

import static com.datastax.driver.core.ProtocolVersion.V4;

import com.datastax.driver.core.policies.RetryPolicy;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableMap;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link PreparedStatement} implementation, built from a server PREPARED response by
 * {@link #fromMessage}.
 *
 * <p>The prepared id, query string, query keyspace, incoming payload, LWT flag and partitioner are
 * fixed at construction. Per-statement settings (routing key, consistency levels, tracing, retry
 * policy, outgoing payload, idempotence) live in {@code volatile} fields and may be changed after
 * construction through the corresponding setters.
 */
public class DefaultPreparedStatement implements PreparedStatement {
  private static final Logger LOGGER = LoggerFactory.getLogger(DefaultPreparedStatement.class);

  /** Table-name suffix identifying a Scylla CDC log table; stripping it yields the base table. */
  private static final String SCYLLA_CDC_LOG_SUFFIX = "_scylla_cdc_log";

  private static final Splitter SPACE_SPLITTER = Splitter.onPattern("\\s+");
  private static final Splitter COMMA_SPLITTER = Splitter.on(',');

  /** Common tail of the warnings logged when skip-metadata is disabled for a CQL4 statement. */
  private static final String CQL4_SKIP_METADATA_WARNING_TAIL =
      "These issues may lead to broken deserialization or data corruption. "
          + "To mitigate this, the driver ensures that the server returns metadata with each query for such statements, "
          + "though this negatively impacts performance. "
          + "To avoid this, consider using a targeted select instead. "
          + "Alternatively, you can enable the skip-cql4-metadata-resolve-method option in the execution profile by setting it to `always-on`, "
          + "allowing the driver to ignore this issue and proceed regardless, risking broken deserialization or data corruption.";

  /** Bound-variable and result-set metadata plus partition-key indices for this statement. */
  final PreparedId preparedId;

  /** The query string that was prepared. */
  final String query;
  /** Keyspace returned by {@link #getQueryKeyspace()}. */
  final String queryKeyspace;
  /** Custom payload carried by the PREPARED response. */
  final Map<String, ByteBuffer> incomingPayload;
  final Cluster cluster;
  /** Whether the PREPARED response flags marked this statement as a lightweight transaction. */
  final boolean isLWT;
  /** Statement-specific token factory (CDC log tables), or {@code null} to use the default. */
  final Token.Factory partitioner;

  volatile ByteBuffer routingKey;

  volatile ConsistencyLevel consistency;
  volatile ConsistencyLevel serialConsistency;
  volatile boolean traceQuery;
  volatile RetryPolicy retryPolicy;
  volatile ImmutableMap<String, ByteBuffer> outgoingPayload;
  volatile Boolean idempotent;
  /** Whether executions may ask the server to omit result metadata; fixed at construction. */
  volatile boolean skipMetadata;

  private DefaultPreparedStatement(
      PreparedId id,
      String query,
      String queryKeyspace,
      Map<String, ByteBuffer> incomingPayload,
      Cluster cluster,
      boolean isLWT,
      Token.Factory partitioner) {
    this.preparedId = id;
    this.query = query;
    this.queryKeyspace = queryKeyspace;
    this.incomingPayload = incomingPayload;
    this.cluster = cluster;
    this.isLWT = isLWT;
    this.partitioner = partitioner;
    // Must run last: it reads cluster, preparedId and query assigned above.
    this.skipMetadata = this.calculateSkipMetadata();
  }

  /**
   * Builds a prepared statement from a PREPARED response.
   *
   * <p>On protocol V4 and above the partition-key indices come from the response itself; on older
   * protocol versions they are derived from the cluster's schema metadata.
   *
   * @param msg the PREPARED response; its bound-variable column metadata must be non-null
   * @param cluster the cluster the statement was prepared on
   * @param query the query string that was prepared
   * @param queryKeyspace value to expose through {@link #getQueryKeyspace()}
   * @param lwtInfo used to test the response flags for LWT; when {@code null} the statement is
   *     never marked as LWT
   * @return the new prepared statement
   */
  static DefaultPreparedStatement fromMessage(
      Responses.Result.Prepared msg,
      Cluster cluster,
      String query,
      String queryKeyspace,
      LwtInfo lwtInfo) {
    assert msg.metadata.columns != null;

    ColumnDefinitions defs = msg.metadata.columns;

    ProtocolVersion protocolVersion =
        cluster.getConfiguration().getProtocolOptions().getProtocolVersion();
    PreparedId.PreparedMetadata boundValuesMetadata =
        new PreparedId.PreparedMetadata(msg.statementId, defs);
    PreparedId.PreparedMetadata resultSetMetadata =
        new PreparedId.PreparedMetadata(msg.resultMetadataId, msg.resultMetadata.columns);

    int[] pkIndices = null;
    if (defs.size() > 0) {
      pkIndices =
          (protocolVersion.compareTo(V4) >= 0)
              ? msg.metadata.pkIndices
              : computePkIndices(cluster.getMetadata(), defs);
    }

    PreparedId preparedId =
        new PreparedId(boundValuesMetadata, resultSetMetadata, pkIndices, protocolVersion);

    Token.Factory tokenFactory = partitioner(defs, cluster);
    boolean lwt = lwtInfo != null && lwtInfo.isLwt(msg.metadata.flags);

    return new DefaultPreparedStatement(
        preparedId,
        query,
        queryKeyspace,
        msg.getCustomPayload(),
        cluster,
        lwt,
        tokenFactory);
  }

  /**
   * Derives, from schema metadata, the bind index of every partition-key column.
   *
   * @param clusterMetadata the cluster schema metadata
   * @param boundColumns the statement's bound variables; must be non-empty
   * @return one bind index per partition-key column, in partition-key order, or {@code null} if
   *     the table is unknown or some partition-key column is not bound
   */
  private static int[] computePkIndices(Metadata clusterMetadata, ColumnDefinitions boundColumns) {
    KeyspaceMetadata km = clusterMetadata.getKeyspace(Metadata.quote(boundColumns.getKeyspace(0)));
    if (km == null) {
      return null;
    }
    TableMetadata tm = km.getTable(Metadata.quote(boundColumns.getTable(0)));
    if (tm == null) {
      return null;
    }

    List<ColumnMetadata> partitionKeyColumns = tm.getPartitionKey();
    int[] pkIndexes = new int[partitionKeyColumns.size()];
    for (int i = 0; i < pkIndexes.length; ++i) {
      pkIndexes[i] = -1; // -1 marks "not bound yet"
    }

    // Note: we rely on the fact CQL queries cannot span multiple tables. If that changes, we'll
    // have to get smarter.
    for (int i = 0; i < boundColumns.size(); i++) {
      maybeGetIndex(boundColumns.getName(i), i, partitionKeyColumns, pkIndexes);
    }

    return allSet(pkIndexes) ? pkIndexes : null;
  }

  /**
   * Records bind index {@code j} for the partition-key column named {@code name}, if any.
   *
   * @param name bound column name
   * @param j bind index of that column
   * @param pkColumns partition-key columns, or {@code null} (no-op)
   * @param pkIndexes per-partition-key-column bind indices, {@code -1} when not yet bound
   */
  private static void maybeGetIndex(
      String name, int j, List<ColumnMetadata> pkColumns, int[] pkIndexes) {
    if (pkColumns == null) {
      return;
    }

    for (int i = 0; i < pkColumns.size(); ++i) {
      if (name.equals(pkColumns.get(i).getName())) {
        // The same column may be bound more than once; keep only the first bind index.
        if (pkIndexes[i] < 0) {
          pkIndexes[i] = j;
        }
        return;
      }
    }
  }

  /** Returns {@code true} if {@code pkColumns} is non-null and has no negative (unset) entry. */
  private static boolean allSet(int[] pkColumns) {
    if (pkColumns == null) {
      return false;
    }
    for (int index : pkColumns) {
      if (index < 0) {
        return false;
      }
    }
    return true;
  }

  /**
   * Picks a statement-specific token factory.
   *
   * <p>Returns the CDC token factory when the statement targets a Scylla CDC log table whose base
   * table has Scylla CDC enabled; otherwise returns {@code null}.
   */
  private static Token.Factory partitioner(ColumnDefinitions defs, Cluster cluster) {
    if (defs == null || defs.size() == 0) {
      return null;
    }

    String table = defs.getTable(0);
    if (!table.endsWith(SCYLLA_CDC_LOG_SUFFIX)) {
      return null;
    }
    String baseTableName = table.substring(0, table.length() - SCYLLA_CDC_LOG_SUFFIX.length());

    // Names from the response metadata are exact; quote them (as computePkIndices does) so that
    // case-sensitive identifiers are not lowercased by the metadata lookup.
    KeyspaceMetadata keyspaceMetadata =
        cluster.getMetadata().getKeyspace(Metadata.quote(defs.getKeyspace(0)));
    if (keyspaceMetadata == null) {
      return null;
    }
    TableMetadata tableMetadata = keyspaceMetadata.getTable(Metadata.quote(baseTableName));
    if (tableMetadata != null && tableMetadata.options.isScyllaCDC()) {
      return Token.CDCToken.FACTORY;
    }
    return null;
  }

  /**
   * Decides whether executions of this statement may skip result metadata.
   *
   * <p>Order of checks: protocol V1 or no result columns → never skip; a non-empty result
   * metadata id → always skip; otherwise the configured CQL4 resolve method decides, and for any
   * method other than ENABLED/DISABLED the wildcard-select and UDT heuristics below apply.
   */
  private boolean calculateSkipMetadata() {
    if (cluster.manager.protocolVersion() == ProtocolVersion.V1
        || preparedId.resultSetMetadata.variables == null) {
      // CQL1 does not support it.
      // If no rows returned there is no reason to send this flag, consequently, no metadata.
      return false;
    }

    if (preparedId.resultSetMetadata.id != null
        && preparedId.resultSetMetadata.id.bytes.length > 0) {
      // It is CQL 5 or higher: prepared statement invalidation works through the result
      // metadata id, so there is no need to disable skip metadata.
      return true;
    }

    switch (cluster.getConfiguration().getQueryOptions().getSkipCQL4MetadataResolveMethod()) {
      case ENABLED:
        return true;
      case DISABLED:
        return false;
      default:
        // Any other resolve method falls through to the heuristics below.
        break;
    }

    if (isWildcardSelect(query)) {
      LOGGER.warn(
          "Prepared statement {} is a wildcard select, which can cause prepared statement invalidation issues when executed on CQL4. "
              + CQL4_SKIP_METADATA_WARNING_TAIL,
          query);
      return false;
    }

    // Disable skipping metadata if any result column contains a UDT, either directly or nested
    // inside a collection or tuple.
    for (ColumnDefinitions.Definition columnDefinition : preparedId.resultSetMetadata.variables) {
      if (containsUDT(columnDefinition.getType())) {
        LOGGER.warn(
            "Prepared statement {} contains UDT in result, which can cause prepared statement invalidation issues when executed on CQL4. "
                + CQL4_SKIP_METADATA_WARNING_TAIL,
            query);
        return false;
      }
    }
    return true;
  }

  /** Returns whether executions of this statement may ask the server to skip result metadata. */
  public boolean isSkipMetadata() {
    return skipMetadata;
  }

  @Override
  public ColumnDefinitions getVariables() {
    return preparedId.boundValuesMetadata.variables;
  }

  @Override
  public BoundStatement bind(Object... values) {
    BoundStatement bs = new BoundStatement(this);
    return bs.bind(values);
  }

  @Override
  public BoundStatement bind() {
    return new BoundStatement(this);
  }

  @Override
  public PreparedStatement setRoutingKey(ByteBuffer routingKey) {
    this.routingKey = routingKey;
    return this;
  }

  @Override
  public PreparedStatement setRoutingKey(ByteBuffer... routingKeyComponents) {
    this.routingKey = SimpleStatement.compose(routingKeyComponents);
    return this;
  }

  @Override
  public ByteBuffer getRoutingKey() {
    return routingKey;
  }

  @Override
  public PreparedStatement setConsistencyLevel(ConsistencyLevel consistency) {
    this.consistency = consistency;
    return this;
  }

  @Override
  public ConsistencyLevel getConsistencyLevel() {
    return consistency;
  }

  /**
   * {@inheritDoc}
   *
   * @throws IllegalArgumentException if {@code serialConsistency} is not a serial level
   */
  @Override
  public PreparedStatement setSerialConsistencyLevel(ConsistencyLevel serialConsistency) {
    if (!serialConsistency.isSerial()) {
      throw new IllegalArgumentException(
          "Supplied consistency level is not serial: " + serialConsistency);
    }
    this.serialConsistency = serialConsistency;
    return this;
  }

  @Override
  public ConsistencyLevel getSerialConsistencyLevel() {
    return serialConsistency;
  }

  @Override
  public String getQueryString() {
    return query;
  }

  @Override
  public String getQueryKeyspace() {
    return queryKeyspace;
  }

  @Override
  public PreparedStatement enableTracing() {
    this.traceQuery = true;
    return this;
  }

  @Override
  public PreparedStatement disableTracing() {
    this.traceQuery = false;
    return this;
  }

  @Override
  public boolean isTracing() {
    return traceQuery;
  }

  @Override
  public PreparedStatement setRetryPolicy(RetryPolicy policy) {
    this.retryPolicy = policy;
    return this;
  }

  @Override
  public RetryPolicy getRetryPolicy() {
    return retryPolicy;
  }

  @Override
  public Token.Factory getPartitioner() {
    return partitioner;
  }

  @Override
  public PreparedId getPreparedId() {
    return preparedId;
  }

  @Override
  public Map<String, ByteBuffer> getIncomingPayload() {
    return incomingPayload;
  }

  @Override
  public Map<String, ByteBuffer> getOutgoingPayload() {
    return outgoingPayload;
  }

  /** {@inheritDoc} The payload is defensively copied; {@code null} clears it. */
  @Override
  public PreparedStatement setOutgoingPayload(Map<String, ByteBuffer> payload) {
    this.outgoingPayload = payload == null ? null : ImmutableMap.copyOf(payload);
    return this;
  }

  @Override
  public CodecRegistry getCodecRegistry() {
    return cluster.getConfiguration().getCodecRegistry();
  }

  /** {@inheritDoc} */
  @Override
  public PreparedStatement setIdempotent(Boolean idempotent) {
    this.idempotent = idempotent;
    return this;
  }

  /** {@inheritDoc} */
  @Override
  public Boolean isIdempotent() {
    return this.idempotent;
  }

  /** {@inheritDoc} */
  @Override
  public boolean isLWT() {
    return isLWT;
  }

  /**
   * Returns {@code true} if {@code dataType} is a UDT or contains one.
   *
   * <p>Collection type arguments and tuple component types are inspected recursively.
   */
  private static boolean containsUDT(DataType dataType) {
    if (dataType instanceof UserType) {
      return true;
    }

    List<DataType> nestedTypes;
    if (dataType.isCollection()) {
      nestedTypes = dataType.getTypeArguments();
    } else if (dataType instanceof TupleType) {
      nestedTypes = ((TupleType) dataType).getComponentTypes();
    } else {
      return false;
    }

    for (DataType nestedType : nestedTypes) {
      if (containsUDT(nestedType)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Heuristically detects {@code SELECT ... * ... FROM} queries.
   *
   * <p>The query is split on whitespace; tokens between the leading {@code select} keyword and the
   * first {@code from} keyword are split on commas and checked for a bare {@code *}. Keyword
   * matching is case-insensitive and independent of the default locale.
   *
   * @param query the CQL query string
   * @return {@code true} if a bare {@code *} appears in the selection clause
   */
  private static boolean isWildcardSelect(String query) {
    List<String> chunks = SPACE_SPLITTER.splitToList(query.trim());
    if (chunks.size() < 2) {
      // Weird query, assuming no result expected
      return false;
    }

    if (!chunks.get(0).equalsIgnoreCase("select")) {
      // Not a select: no result set to worry about.
      return false;
    }

    for (String chunk : chunks.subList(1, chunks.size())) {
      if (chunk.equalsIgnoreCase("from")) {
        return false;
      }
      // Splitting also covers a lone "*" (it yields the single part "*").
      for (String part : COMMA_SPLITTER.split(chunk)) {
        if (part.equals("*")) {
          return true;
        }
      }
    }
    return false;
  }
}
