/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.jms.commons.internal.connection.session;

import static java.util.Optional.ofNullable;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.jms.commons.internal.config.InternalAckMode;
import org.mule.jms.commons.internal.connection.JmsConnection;
import org.mule.jms.commons.internal.connection.JmsTransactionalConnection;
import org.mule.jms.commons.internal.connection.JmsXaContext;
import org.mule.jms.commons.internal.source.JmsListener;
import org.mule.jms.commons.internal.source.JmsListenerLock;
import org.mule.runtime.extension.api.connectivity.XATransactionalConnection;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

import javax.jms.JMSException;
import javax.jms.Message;
import javax.jms.Session;
import javax.transaction.xa.XAResource;

import org.slf4j.Logger;

/**
 * Manager responsible for registering session information so that a manual acknowledgement or a recover can later be executed
 * over a {@link Session}. This is used when the {@link InternalAckMode} is configured as {@link InternalAckMode#MANUAL}.
 * <p>
 * It also keeps, per {@link Thread}, the transactional state (bound {@link JmsSession}, {@link TransactionStatus} and XA
 * contexts) in thread-local storage.
 * <p>
 * The registry of sessions pending acknowledgement is a synchronized map, since messages are registered and acknowledged
 * through independent calls that may be issued from different threads.
 *
 * @since 1.0
 */
// TODO MULE-16989 split this
public class JmsSessionManager {

  private static final Logger LOGGER = getLogger(JmsSessionManager.class);

  // Wrapped in a synchronized map so that register / ack / recover / isPendingAck never race on the underlying HashMap.
  private final Map<String, SessionInformation> pendingSessions = Collections.synchronizedMap(new HashMap<>());
  private final ThreadLocal<TransactionInformation> transactionInformation = new ThreadLocal<>();
  private final ThreadLocal<Map<XATransactionalConnection, TransactionInformation>> xaTransactions = new ThreadLocal<>();

  /**
   * Registers the {@link Message} and its {@link Session} under the given {@code ackId} in order to be able to perform a
   * {@link InternalAckMode#MANUAL} ACK (or a recover) later. No acked-monitor is associated to the registration.
   *
   * @param ackId   the id under which the {@link Message} and {@link Session} are registered
   * @param message the {@link Message} to use for executing the {@link Message#acknowledge}
   * @param session the {@link Session} to use for executing the {@link Session#recover()}
   * @param jmsLock the optional {@link JmsListenerLock} to be able to unlock the {@link JmsListener}
   */
  public void registerMessageForAck(String ackId, Message message, Session session, JmsListenerLock jmsLock) {
    registerMessageForAck(ackId, message, session, jmsLock, null);
  }

  /**
   * Registers the {@link Message} and its {@link Session} under the given {@code ackId} in order to be able to perform a
   * {@link InternalAckMode#MANUAL} ACK (or a recover) later. If the {@code ackId} is already registered, the existing
   * registration is kept and this call has no effect besides logging.
   *
   * @param ackId                  the id under which the {@link Message} and {@link Session} are registered
   * @param message                the {@link Message} to use for executing the {@link Message#acknowledge}
   * @param session                the {@link Session} to use for executing the {@link Session#recover()}
   * @param jmsLock                the optional {@link JmsListenerLock} to be able to unlock the {@link JmsListener}
   * @param jmsMessageAckedMonitor monitor handed over to the created {@link SessionInformation}, may be {@code null}
   */
  public void registerMessageForAck(String ackId, Message message, Session session, JmsListenerLock jmsLock,
                                    JmsMessageAckedMonitor jmsMessageAckedMonitor) {
    if (LOGGER.isTraceEnabled()) {
      try {
        LOGGER.trace("Registering message for Ack. AckId: [{}], JmsMessageId: [{}]", ackId, message.getJMSMessageID());
      } catch (Exception e) {
        LOGGER.error("Caught exception while attempting to trace the JmsMessageID: {}", e.getMessage(), e);
      }
    }

    // Atomic "register only if absent": the check and the insertion happen under the map's lock, and the
    // SessionInformation is only created when it is actually going to be stored.
    pendingSessions.computeIfAbsent(ackId, id -> new SessionInformation(message, session, jmsLock, jmsMessageAckedMonitor));

    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("Registered Message for Session AckId [" + ackId + "]");
    }
  }

  /**
   * @param ackId the id to look for
   * @return {@code true} if there is a registration pending of ACK or recover for the given {@code ackId}
   */
  public synchronized boolean isPendingAck(String ackId) {
    return pendingSessions.get(ackId) != null;
  }

  /**
   * Executes the {@link Message#acknowledge} on the {@link Message} registered with the given {@code ackId} and removes the
   * registration.
   * <p>
   * The outcome is reported through the {@code ackCallback}: a {@link JMSException} raised by the acknowledgement is passed to
   * {@link AckCallback#onError}, it is not thrown by this method. If no registration exists for the {@code ackId},
   * {@link AckCallback#onSuccess} is invoked and the situation is only logged.
   *
   * @param ackId       the id associated to the {@link Session} that should be ACKed
   * @param ackCallback callback notified with the result of the acknowledgement
   */
  public void ack(String ackId, AckCallback ackCallback) {
    Optional<SessionInformation> optionalSession = getSessionInformation(ackId);

    if (!optionalSession.isPresent()) {
      ackCallback.onSuccess();
      // TODO - MULE-11963 : Improve error message for JmsAcknowledgement operations when the SessionInformation doesn't exist
      // anymore
      if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("The session could not be acknowledged. This may be due to: \n " +
            "- The session has been already acknowledged\n" +
            "- The session has been recovered\n " +
            "- The given 'ackId' :  [" + ackId + "] is invalid.");
      }
      return;
    }

    SessionInformation sessionInformation = optionalSession.get();
    Optional<JmsListenerLock> jmsListenerLock = sessionInformation.getJmsListenerLock();

    // When consuming a message from a Message Listener, the ACK and Recover Actions must be done
    // in the same thread that onMessage() has been called.
    if (jmsListenerLock.isPresent()) {
      jmsListenerLock.get().executeOnListenerThread(() -> {
        try {
          Message message = sessionInformation.getMessage();
          if (LOGGER.isTraceEnabled()) {
            try {
              LOGGER.trace("Attempting to acknowledge message with ackId: [{}], JmsMessageID:[{}]", ackId,
                           message.getJMSMessageID());
            } catch (Exception e) {
              LOGGER.error("Caught exception while attempting to trace the JmsMessageID: {}", e.getMessage(), e);
            }
          }
          message.acknowledge();
          ackCallback.onSuccess();
        } catch (JMSException e) {
          ackCallback.onError(e);
        }
      });
    } else {
      // If we don't have a Listener Lock, is because we are not actioning over Message received by a Message Listener
      try {
        sessionInformation.getMessage().acknowledge();
        ackCallback.onSuccess();
      } catch (JMSException e) {
        ackCallback.onError(e);
      } finally {
        // The monitor is notified whether the acknowledgement succeeded or failed
        notifyMessageAcknowledged(sessionInformation);
      }
    }

    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("Acknowledged Message for Session with AckId [" + ackId + "]");
    }
  }

  /**
   * Signals the given {@link SessionInformation} that its message has been processed (acknowledged or recovered).
   */
  private synchronized void notifyMessageAcknowledged(SessionInformation sessionInformation) {
    sessionInformation.notifyMessageAcked();
  }

  /**
   * Executes the {@link Session#recover()} over the {@link Session} registered with the given {@code ackId} and removes the
   * registration.
   * <p>
   * The outcome is reported through the {@code ackCallback}: a {@link JMSException} raised while recovering is passed to
   * {@link AckCallback#onError}, it is not thrown by this method. If no registration exists for the {@code ackId},
   * {@link AckCallback#onSuccess} is invoked and the situation is only logged.
   *
   * @param ackId       the id associated to the {@link Session} used to create the {@link Message}
   * @param ackCallback callback notified with the result of the recover
   */
  public void recoverSession(String ackId, AckCallback ackCallback) {
    Optional<SessionInformation> optionalSession = getSessionInformation(ackId);

    if (!optionalSession.isPresent()) {
      ackCallback.onSuccess();
      if (LOGGER.isDebugEnabled()) {
        // TODO - MULE-11963 : Improve error message for JmsAcknowledgement operations when the SessionInformation doesn't exist
        // anymore
        LOGGER.debug("The session could not be recovered, this could be due to: \n" +
            "- The session has been already recovered\n" +
            "- The all session messages has been already acknowledged\n" +
            "- The given 'ackId' : [" + ackId + "] is invalid");
      }
      return;
    }

    SessionInformation sessionInformation = optionalSession.get();
    Optional<JmsListenerLock> jmsListenerLock = sessionInformation.getJmsListenerLock();

    // Same threading rule as for the ACK: with a listener lock, the recover is delegated to the listener thread.
    if (jmsListenerLock.isPresent()) {
      jmsListenerLock.get().executeOnListenerThread(() -> {
        try {
          sessionInformation.getSession().recover();
          ackCallback.onSuccess();
        } catch (JMSException e) {
          ackCallback.onError(e);
        }
      });
    } else {
      // If we don't have a Listener Lock, is because we are not actioning over Message received by a Message Listener
      try {
        sessionInformation.getSession().recover();
        ackCallback.onSuccess();
      } catch (JMSException e) {
        ackCallback.onError(e);
      } finally {
        // The monitor is notified whether the recover succeeded or failed
        notifyMessageAcknowledged(sessionInformation);
      }
    }

    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("Recovered session for AckId [ " + ackId + "]");
    }
  }

  /**
   * Removes and returns the registration associated to the {@code ackId}, so each registration can be consumed at most once.
   */
  private Optional<SessionInformation> getSessionInformation(String ackId) {
    return ofNullable(pendingSessions.remove(ackId));
  }

  /**
   * Binds the given {@link JmsSession} to the current {@link Thread}
   *
   * @param session session to bind
   */
  public void bindToTransaction(JmsSession session) {
    LOGGER.debug("Binding transaction to current thread...");
    LOGGER.debug("Tx to bind: session: '{}'", session);

    getTransactionInformation().setJmsSession(session);
  }

  /**
   * Binds the given {@link JmsSession} and the correspondent {@link XAResource} to the current {@link Thread}, keyed by the
   * given connection. If the connection already had transaction information bound in this thread, it is replaced and the
   * previous XA context is scheduled to be ended after the new one ends.
   *
   * @param jmsXaTransactionalConnection the connection the session and XA context belong to
   * @param session                      session to bind
   * @param xaResource                   the XA context to bind
   */
  public void bindToTransaction(XATransactionalConnection jmsXaTransactionalConnection, JmsSession session,
                                JmsXaContext xaResource) {
    LOGGER.debug("Binding XA transaction to current thread...");
    LOGGER.trace("Tx to bind: connection: '{}', session: '{}', xaResource: '{}'", jmsXaTransactionalConnection, session,
                 xaResource);

    TransactionInformation value = new TransactionInformation();
    value.setJmsXaContext(xaResource);
    value.setJmsSession(session);
    final TransactionInformation previousTxInfo =
        onXATransactions(transactions -> transactions.put(jmsXaTransactionalConnection, value));

    if (previousTxInfo != null) {
      LOGGER.trace("Tx info for connection '{}' replaced. Old value was '{}'", jmsXaTransactionalConnection, previousTxInfo);
      xaResource.afterEnds(() -> previousTxInfo.getJmsXaContext().end());
    }
  }

  /**
   * Unbinds all the transactional state (single resource and XA) of the current {@link Thread}. A warning is logged if XA
   * connections were still bound.
   */
  public void unbindSession() {
    LOGGER.debug("Unbinding transaction from current thread...");

    transactionInformation.remove();

    Map<XATransactionalConnection, TransactionInformation> transactions = xaTransactions.get();
    if (transactions != null && !transactions.isEmpty()) {
      LOGGER.warn("XA connections still bound after transaction finished: {}", transactions);
    }
    xaTransactions.remove();
  }

  /**
   * Unbinds the XA transaction information of the given connection, if there is any, from the current {@link Thread}. Once no
   * XA connection remains bound, the thread-local storage is released.
   *
   * @param jmsXaTransactionalConnection the connection to unbind
   */
  public void unbindSession(XATransactionalConnection jmsXaTransactionalConnection) {
    LOGGER.debug("Unbinding connection {} from XA transaction from current thread...", jmsXaTransactionalConnection);

    Map<XATransactionalConnection, TransactionInformation> transactions = xaTransactions.get();
    if (transactions == null) {
      // Nothing was ever bound in this thread
      return;
    }

    transactions.remove(jmsXaTransactionalConnection);
    if (transactions.isEmpty()) {
      xaTransactions.remove();
    }
  }

  /**
   * @param connection the connection for which the session is requested
   * @return the {@link Optional} {@link JmsSession} bound to the current {@link Thread}. For an
   *         {@link XATransactionalConnection} the session bound for that specific connection is returned.
   */
  public Optional<JmsSession> getTransactedSession(JmsConnection connection) {
    if (connection instanceof XATransactionalConnection) {
      return getXaTransactedSession((XATransactionalConnection) connection);
    }
    return ofNullable(getTransactionInformation().getJmsSession());
  }

  /**
   * @param jmsXaTransactionalConnection the XA connection for which the session is requested
   * @return the {@link Optional} {@link JmsSession} bound to the current {@link Thread} for the given connection
   */
  public Optional<JmsSession> getXaTransactedSession(XATransactionalConnection jmsXaTransactionalConnection) {
    return findXaTransactionInformation(jmsXaTransactionalConnection).map(TransactionInformation::getJmsSession);
  }

  /**
   * @return The status of the transaction. - {@link TransactionStatus#NONE} means that there is no started transaction for the
   *         current {@link Thread} - {@link TransactionStatus#STARTED} means that there is a transaction being executed in the
   *         current {@link Thread}
   */
  public TransactionStatus getTransactionStatus() {
    TransactionStatus transactionStatus = getTransactionInformation().getTransactionStatus();
    return transactionStatus != null ? transactionStatus : TransactionStatus.NONE;
  }

  /**
   * @param transactionStatus The new {@link TransactionStatus}
   */
  public void changeTransactionStatus(TransactionStatus transactionStatus) {
    getTransactionInformation().setTransactionStatus(transactionStatus);
  }

  /**
   * Returns the single resource {@link TransactionInformation} of the current {@link Thread}, creating it on first access.
   */
  private TransactionInformation getTransactionInformation() {
    TransactionInformation current = this.transactionInformation.get();
    if (current == null) {
      if (LOGGER.isTraceEnabled()) {
        LOGGER.trace("Initializing single resource transaction information.");
      }
      current = new TransactionInformation();
      this.transactionInformation.set(current);
    }
    return current;
  }

  /**
   * @param xaTransactionalConnection the connection for which the XA context is requested
   * @return the {@link Optional} {@link JmsXaContext} bound to the current {@link Thread} for the given connection
   */
  public Optional<JmsXaContext> getJmsXaContext(JmsTransactionalConnection xaTransactionalConnection) {
    return findXaTransactionInformation(xaTransactionalConnection).map(TransactionInformation::getJmsXaContext);
  }

  /**
   * Read-only lookup of the XA transaction information bound to the current {@link Thread} for the given connection. Unlike
   * {@link #onXATransactions(Function)}, it does not initialize the thread-local map when nothing has been bound yet.
   */
  private Optional<TransactionInformation> findXaTransactionInformation(Object connection) {
    Map<XATransactionalConnection, TransactionInformation> transactions = xaTransactions.get();
    if (transactions == null) {
      return Optional.empty();
    }
    return ofNullable(transactions.get(connection));
  }

  /**
   * Applies the given function over the XA transactions map of the current {@link Thread}, creating the map on first access.
   */
  private <I> I onXATransactions(Function<Map<XATransactionalConnection, TransactionInformation>, I> function) {
    Map<XATransactionalConnection, TransactionInformation> map = xaTransactions.get();
    if (map == null) {
      map = new HashMap<>();
      xaTransactions.set(map);
    }
    return function.apply(map);
  }

  /**
   * Returns the set of pending acknowledgements.
   * <p>
   * The returned set is a live view over the internal registry: it reflects later registrations and acknowledgements. Point
   * queries such as {@link Set#size()}, {@link Set#isEmpty()} or {@link Set#contains(Object)} are synchronized with the
   * registry; iterating it while other threads register or acknowledge messages is not.
   *
   * @return the set of ackIds pending of acknowledgement
   */
  public Set<String> getPendingAcknowledgements() {
    return pendingSessions.keySet();
  }
}
