/*
 * Copyright © MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.jms.commons.internal.source.polling;

import static java.lang.System.getProperty;
import static org.mule.jms.commons.internal.source.polling.JmsXaPollingMessageConsumerDelegate.CONSUMER_CONTEXT_VAR;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.jms.commons.api.connection.JmsSpecification;
import org.mule.jms.commons.internal.common.JmsCommons;
import org.mule.jms.commons.internal.config.InternalAckMode;
import org.mule.jms.commons.internal.config.JmsConfig;
import org.mule.jms.commons.internal.connection.XaJmsTransactionalConnection;
import org.mule.jms.commons.internal.connection.session.JmsSession;
import org.mule.jms.commons.internal.connection.session.JmsSessionManager;
import org.mule.jms.commons.internal.source.JmsConnectionExceptionResolver;
import org.mule.jms.commons.internal.source.JmsMessageDispatcher;
import org.mule.jms.commons.internal.source.NullJmsListenerLock;
import org.mule.runtime.api.connection.ConnectionException;
import org.mule.runtime.api.connection.ConnectionProvider;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.tx.TransactionException;
import org.mule.runtime.api.util.Pair;
import org.mule.runtime.core.api.transaction.TransactionCoordination;
import org.mule.runtime.extension.api.connectivity.XATransactionalConnection;
import org.mule.runtime.extension.api.runtime.source.SourceCallback;
import org.mule.runtime.extension.api.runtime.source.SourceCallbackContext;

import java.lang.reflect.Field;
import java.util.concurrent.CountDownLatch;

import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.MessageConsumer;

import org.slf4j.Logger;

/**
 * Jms Message consumer which polls a destination inside a XA Transaction.
 * <p>
 * Each poll iteration obtains a connection, binds it to the source callback context's transaction, creates a transacted
 * {@link JmsSession} from that same connection and waits up to {@code POLLING_TIMEOUT} milliseconds for a message. When a
 * message is successfully handed to the flow the loop ends; otherwise the iteration's transaction is rolled back (or, for an
 * empty poll, committed when the {@code enableTxCommitOnEmptyMessage} system property is {@code true}) and a new iteration
 * starts.
 * <p>
 * {@link #stop()} is not synchronized, so it can be called while {@link #run()} holds this object's monitor; the stop flag is
 * {@code volatile} so the polling thread observes the request.
 *
 * @since 1.0
 */
public class JmsXaMessageConsumer implements Runnable {

  protected static final int MAX_TRANSACTION_BINDING_RETRIES = 10;
  private static final boolean TX_COMMIT_ON_EMPTY_MESSAGE_ENABLED =
      Boolean.parseBoolean(getProperty("enableTxCommitOnEmptyMessage", "false"));
  private static final Logger LOGGER = getLogger(JmsXaMessageConsumer.class);
  // Milliseconds passed to MessageConsumer.receive(long) on each poll iteration.
  private static final long POLLING_TIMEOUT = 2000L;
  private final int id;
  private final MessageConsumerFactory consumer;
  private final SourceCallback sourceCallback;
  private final JmsSessionManager sessionManager;
  private final ConnectionProvider connectionProvider;
  private final JmsMessageDispatcher dispatcher;
  // Counted down once after the first poll initialization attempt, then cleared. May be null.
  private CountDownLatch initializationCountDownLatch;
  // Session of the current poll iteration; read lazily by the dispatcher through its session supplier.
  private JmsSession session;
  // volatile: written by stop() (unsynchronized) while run() loops holding this object's monitor.
  private volatile boolean stopRequested = false;

  /**
   * Creates a polling consumer.
   *
   * @param consumerFactory              creates the {@link MessageConsumer} used on each poll session
   * @param sourceCallback               callback used to create contexts and report connection errors
   * @param sessionManager               session manager, unbound when polling finishes
   * @param connectionProvider           provider of the XA connections used on each poll iteration
   * @param config                       forwarded to the {@link JmsMessageDispatcher}
   * @param inboundContentType           forwarded to the {@link JmsMessageDispatcher}
   * @param inboundEncoding              forwarded to the {@link JmsMessageDispatcher}
   * @param specification                forwarded to the {@link JmsMessageDispatcher}
   * @param id                           identifier of this consumer, used in log messages
   * @param initializationCountDownLatch latch counted down after the first poll initialization; may be {@code null}
   * @param exceptionResolver            forwarded to the {@link JmsMessageDispatcher}
   */
  JmsXaMessageConsumer(MessageConsumerFactory consumerFactory, SourceCallback sourceCallback,
                       JmsSessionManager sessionManager, ConnectionProvider connectionProvider, JmsConfig config,
                       String inboundContentType, String inboundEncoding, JmsSpecification specification, int id,
                       CountDownLatch initializationCountDownLatch, JmsConnectionExceptionResolver exceptionResolver) {
    this.consumer = consumerFactory;
    this.sourceCallback = sourceCallback;
    this.sessionManager = sessionManager;
    this.connectionProvider = connectionProvider;
    this.id = id;
    this.initializationCountDownLatch = initializationCountDownLatch;
    // In this case, the exceptionResolver is just a pass-through argument used to build the JmsMessageDispatcher
    this.dispatcher =
        new JmsMessageDispatcher(config, inboundContentType, inboundEncoding, specification, () -> session,
                                 InternalAckMode.TRANSACTED,
                                 sessionManager, sourceCallback, new NullJmsListenerLock(), exceptionResolver);
  }

  /**
   * Polls the destination until a message is successfully dispatched to the flow, {@link #stop()} is requested, or an error
   * other than a {@link JMSException} occurs. Such errors end the poll and are reported through
   * {@code sourceCallback.onConnectionException(..)}.
   */
  @Override
  public synchronized void run() {
    MessageConsumer messageConsumer = null;
    try {
      boolean shouldKeepIterating = true;

      LOGGER.debug("[{}] : Starting to poll", id);

      while (shouldKeepIterating && !stopRequested) {
        try {
          // Created session will be closed at XAResource.end(..)
          Pair<SourceCallbackContext, JmsSession> jmsContext = initializePoll();
          session = jmsContext.getSecond();
          messageConsumer = consumer.createConsumer(session).get();
          Message message = messageConsumer.receive(POLLING_TIMEOUT);
          shouldKeepIterating = dispatchMessage(message, jmsContext.getFirst());
        } catch (JMSException e) {
          // Not fatal: the iteration's transaction is rolled back below and the loop polls again.
          LOGGER.error("[{}] : Unknown error when trying to poll message", id, e);
        } finally {
          // Unless the message was handed to the flow, do not leave this iteration's transaction bound to the thread.
          if (shouldKeepIterating) {
            TransactionCoordination.getInstance().rollbackCurrentTransaction();
          }
        }
      }
      if (stopRequested) {
        LOGGER.debug("[{}] : Stopping poll", id);
      }
      LOGGER.debug("[{}] : Finishing poll", id);
    } catch (ConnectionException e) {
      LOGGER.debug("[{}] : Finishing poll due to {}:{}", id, e.getClass(), e.getMessage());
      sourceCallback.onConnectionException(e);
    } catch (Exception e) {
      LOGGER.debug("[{}] : Finishing poll due to {}:{}", id, e.getClass(), e.getMessage());
      sourceCallback.onConnectionException(new ConnectionException(e, "Unexpected error occurred trying to poll a message"));
    } finally {
      sessionManager.unbindSession();
      JmsCommons.closeQuietly(messageConsumer);
    }
  }

  /**
   * Initializes the poll context. It obtains a connection, binds it to the source callback context's transaction, and
   * creates a transacted session on that same connection.
   * <p>
   * The first invocation also counts down and clears {@code initializationCountDownLatch}, whether or not it succeeds.
   *
   * @return the callback context where the message should be dispatched, paired with the session to poll from
   * @throws ConnectionException  if an error occurs trying to obtain the JMS Connection
   * @throws TransactionException if an error occurs when trying to bind the connection to the transaction
   * @throws MuleRuntimeException wrapping the {@link JMSException} raised if the session cannot be created
   */
  private Pair<SourceCallbackContext, JmsSession> initializePoll() throws ConnectionException, TransactionException {
    try {
      LOGGER.trace("[{}] : initializing poll ", id);
      SourceCallbackContext context = sourceCallback.createContext();
      XaJmsTransactionalConnection connect = (XaJmsTransactionalConnection) connectionProvider.connect();
      bindTransaction(context, connect);
      try {
        return new Pair<>(context, connect.getSession(InternalAckMode.TRANSACTED, false));
      } catch (JMSException e) {
        throw new MuleRuntimeException(e);
      }
    } finally {
      if (initializationCountDownLatch != null) {
        initializationCountDownLatch.countDown();
        this.initializationCountDownLatch = null;
      }
    }
  }

  /**
   * Tries to dispatch the given message to the current context.
   * <p>
   * If a stop was requested, the message is not dispatched and the transaction is rolled back. An empty poll rolls back the
   * transaction, or commits it when {@code enableTxCommitOnEmptyMessage} is enabled. A failed dispatch rolls back the
   * transaction.
   *
   * @param message Message to dispatch, or {@code null} if the poll timed out
   * @param context Context where the message should be dispatched
   * @return {@code false} only when the message was dispatched to the flow; {@code true} if the poll should keep looking for
   *         messages
   */
  private boolean dispatchMessage(Message message, SourceCallbackContext context) {
    if (stopRequested) {
      LOGGER.trace("[{}] : Stop has been requested, rolling back current transaction.", id);
      TransactionCoordination.getInstance().rollbackCurrentTransaction();
    } else if (message == null) {
      if (TX_COMMIT_ON_EMPTY_MESSAGE_ENABLED) {
        LOGGER.debug("[{}] : No message found, committing transaction.", id);
        TransactionCoordination.getInstance().commitCurrentTransaction();
      } else {
        LOGGER.debug("[{}] : No message found, rolling back transaction.", id);
        TransactionCoordination.getInstance().rollbackCurrentTransaction();
      }
    } else {
      LOGGER.trace("[{}] : received message, handling to the flow.", id);
      context.addVariable(CONSUMER_CONTEXT_VAR, this);
      try {
        dispatcher.dispatchMessage(message, context);
        return false;
      } catch (Exception e) {
        // Keep the failure cause instead of silently discarding it.
        LOGGER.trace("[{}] : Message dispatch failed, rolling back transaction.", id, e);
        TransactionCoordination.getInstance().rollbackCurrentTransaction();
      }
    }
    return true;
  }

  /**
   * Binds {@code connect} to {@code context}, retrying up to {@link #MAX_TRANSACTION_BINDING_RETRIES} times when binding
   * fails with a {@link TransactionException}.
   *
   * @param context context the connection is bound to
   * @param connect connection to bind. It must be the same connection the poll session is created from, so that the session
   *                belongs to the connection that is bound to the transaction.
   * @throws ConnectionException  if binding fails with a connection error
   * @throws TransactionException if binding still fails on the last attempt
   */
  private void bindTransaction(SourceCallbackContext context, XATransactionalConnection connect)
      throws ConnectionException, TransactionException {
    // TODO MULE-15194: This is needed because of a race condition inside the SDK, so we need to retry again.
    // One time is always enough, but who knows?
    int timesToTry = MAX_TRANSACTION_BINDING_RETRIES;
    for (int i = 1; i <= timesToTry; i++) {
      try {
        LOGGER.trace("[{}] : about to bind connection [{}] into context: [{}]", id, connect, context);
        // Bind the very connection the session will be created from. Calling connectionProvider.connect() here would
        // bind a different, freshly opened connection instead.
        context.bindConnection(connect);
        return;
      } catch (TransactionException e) {
        LOGGER.debug("Internal error, unable to bind connection to transaction. Attempt {}/{}", i, timesToTry, e);
        if (i == timesToTry) {
          throw e;
        }
        try {
          TransactionCoordination.getInstance().rollbackCurrentTransaction();
        } catch (Exception ex) {
          LOGGER.debug("Failure on transaction rollback", ex);
        }
        resetBoundConnection(context);
      }
    }
  }

  /**
   * Clears the {@code connection} field of the given context through reflection so that {@code bindConnection(..)} can be
   * invoked again.
   * <p>
   * TODO MULE-15194: Also removed since this is only necessary because sometimes the bindConnection(..) fails but the field
   * value keeps being assigned preventing us to invoke it again.
   *
   * @param context context whose bound connection is reset
   * @throws MuleRuntimeException if the field cannot be found or written
   */
  private static void resetBoundConnection(SourceCallbackContext context) {
    try {
      Field connectionField = context.getClass().getDeclaredField("connection");
      connectionField.setAccessible(true);
      connectionField.set(context, null);
    } catch (Exception e) {
      throw new MuleRuntimeException(e);
    }
  }

  /**
   * Requests the poll loop to end. The loop checks the flag before each iteration. If the current iteration has already
   * received a message, that message is not dispatched and its transaction is rolled back.
   */
  public void stop() {
    stopRequested = true;
  }
}
