/*
 * Copyright © MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.jms.commons.internal.source.polling;

import static java.lang.String.format;
import static org.mule.jms.commons.internal.common.JmsCommons.getDestinationType;
import static org.slf4j.LoggerFactory.getLogger;

import org.mule.jms.commons.api.destination.ConsumerType;
import org.mule.jms.commons.api.exception.JmsExtensionException;
import org.mule.jms.commons.internal.config.JmsConfig;
import org.mule.jms.commons.internal.connection.JmsTransactionalConnection;
import org.mule.jms.commons.internal.connection.JmsXaContext;
import org.mule.jms.commons.internal.connection.session.JmsSessionManager;
import org.mule.jms.commons.internal.source.JmsConnectionExceptionResolver;
import org.mule.jms.commons.internal.source.MessageConsumerDelegate;
import org.mule.jms.commons.internal.support.JmsSupport;
import org.mule.runtime.api.connection.ConnectionException;
import org.mule.runtime.api.connection.ConnectionProvider;
import org.mule.runtime.api.message.Error;
import org.mule.runtime.api.meta.MuleVersion;
import org.mule.runtime.api.scheduler.Scheduler;
import org.mule.runtime.api.util.concurrent.Latch;
import org.mule.runtime.core.api.config.MuleManifest;
import org.mule.runtime.extension.api.runtime.source.SourceCallback;
import org.mule.runtime.extension.api.runtime.source.SourceCallbackContext;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Future;

import org.slf4j.Logger;

/**
 * {@link MessageConsumerDelegate} implementation which consumes messages inside XA Transactions doing polling.
 *
 * @since 1.0.0
 */
public class JmsXaPollingMessageConsumerDelegate implements MessageConsumerDelegate {

  private static final Logger LOGGER = getLogger(JmsXaPollingMessageConsumerDelegate.class);

  private final String destination;
  private final ConsumerType consumerType;
  private final JmsConfig config;
  private final String selector;
  private final JmsSessionManager sessionManager;
  private final ConnectionProvider connectionProvider;
  private final Scheduler scheduler;
  private final String inboundContentType;
  private final String inboundEncoding;
  private final JmsConnectionExceptionResolver exceptionResolver;
  private final MessageConsumerFactory messageConsumerFactory;
  private final JmsTransactionalConnection connection;
  private final JmsSupport jmsSupport;
  private final SourceCallback sourceCallback;

  /**
   * Name of the {@link SourceCallbackContext} variable that {@link #onSuccess} and {@link #onError} read to obtain the
   * polling {@link Runnable} to resubmit to the scheduler.
   */
  static final String CONSUMER_CONTEXT_VAR = "CONSUMER";

  /**
   * Every consumer created by {@link #createConsumers(int)}; iterated by {@link #stop()} and {@link #disableConsumers()}.
   */
  private final List<JmsXaMessageConsumer> xaMessageConsumers = new ArrayList<>();

  /**
   * Creates a delegate whose consumers share the given connection and configuration. A {@link MessageConsumerFactory} is
   * built eagerly from {@code connection}, {@code destination}, {@code selector}, {@code consumerType} and {@code config}.
   *
   * @param connection         transactional connection used to create consumers and to look up the XA context
   * @param jmsSupport         JMS support whose specification is handed to each consumer
   * @param destination        name of the destination to consume from
   * @param consumerType       type of consumer (used for the factory and error messages)
   * @param config             JMS configuration
   * @param selector           JMS message selector
   * @param sessionManager     session manager, also queried for the {@link JmsXaContext} in {@link #onSuccess}
   * @param connectionProvider connection provider handed to each consumer
   * @param scheduler          scheduler on which consumers are submitted and rescheduled
   * @param inboundContentType content type handed to each consumer
   * @param inboundEncoding    encoding handed to each consumer
   * @param sourceCallback     source callback handed to each consumer
   * @param exceptionResolver  connection exception resolver handed to each consumer
   */
  public JmsXaPollingMessageConsumerDelegate(JmsTransactionalConnection connection, JmsSupport jmsSupport, String destination,
                                             ConsumerType consumerType, JmsConfig config, String selector,
                                             JmsSessionManager sessionManager, ConnectionProvider connectionProvider,
                                             Scheduler scheduler, String inboundContentType, String inboundEncoding,
                                             SourceCallback sourceCallback, JmsConnectionExceptionResolver exceptionResolver) {
    this.connection = connection;
    this.jmsSupport = jmsSupport;
    this.destination = destination;
    this.consumerType = consumerType;
    this.config = config;
    this.selector = selector;
    this.sessionManager = sessionManager;
    this.connectionProvider = connectionProvider;
    this.scheduler = scheduler;
    this.inboundContentType = inboundContentType;
    this.inboundEncoding = inboundEncoding;
    this.sourceCallback = sourceCallback;
    this.exceptionResolver = exceptionResolver;
    this.messageConsumerFactory = new MessageConsumerFactory(connection, destination, selector, consumerType, config);
  }

  /**
   * Creates {@code numberOfConsumers} {@link JmsXaMessageConsumer}s one at a time. For each one, a fresh latch is created
   * and passed to the consumer. The consumer is then registered and submitted to the scheduler. This thread blocks on the
   * latch before creating the next consumer; presumably the consumer releases it once its XA transaction has been
   * initialized (TODO confirm in {@code JmsXaMessageConsumer}).
   *
   * @param numberOfConsumers how many consumers to create
   * @throws ConnectionException   if a {@link JmsExtensionException} is raised while creating consumers; every consumer
   *                               registered so far is stopped first
   * @throws JmsExtensionException if the calling thread is interrupted while waiting on a consumer's latch; the thread's
   *                               interrupt status is restored before throwing
   */
  @Override
  public void createConsumers(int numberOfConsumers) throws ConnectionException {
    try {
      MuleVersion muleVersion = new MuleVersion(MuleManifest.getProductVersion());
      for (int i = 0; i < numberOfConsumers; i++) {
        CountDownLatch xaTransactionInitialization = new Latch();
        JmsXaMessageConsumer messageConsumer = new JmsXaMessageConsumer(messageConsumerFactory, sourceCallback, sessionManager,
                                                                        connectionProvider, config,
                                                                        inboundContentType, inboundEncoding,
                                                                        jmsSupport.getSpecification(), i,
                                                                        xaTransactionInitialization, exceptionResolver,
                                                                        muleVersion);
        xaMessageConsumers.add(messageConsumer);
        scheduler.submit(messageConsumer);
        // Do not create the next consumer until this one has released its latch.
        xaTransactionInitialization.await();
      }
    } catch (JmsExtensionException e) {
      String msg = format("An error occurred while creating the consumers for destination [%s:%s]: %s",
                          getDestinationType(consumerType), destination, e.getMessage());
      LOGGER.error(msg, e);
      stop();

      throw new ConnectionException(msg, e, null, connection);
    } catch (InterruptedException e) {
      // Catching InterruptedException clears the thread's interrupt flag; restore it so callers can still observe it.
      Thread.currentThread().interrupt();
      throw new JmsExtensionException("The JMS Consumer creation has been interrupted, probably the Listener is being stopped. ",
                                      e);
    }
  }

  /**
   * Reads the polling {@link Runnable} from the {@value #CONSUMER_CONTEXT_VAR} context variable and the
   * {@link JmsXaContext} for this delegate's connection. When both are present, it registers an {@code afterEnds} callback
   * on the XA context that resubmits the runnable to the scheduler. Presumably that callback runs once the XA transaction
   * has ended (TODO confirm in {@code JmsXaContext}). If either value is absent, nothing is rescheduled.
   *
   * @param callbackContext context of the successfully processed message
   */
  @Override
  public void onSuccess(SourceCallbackContext callbackContext) {
    final Optional<Runnable> consumer = callbackContext.<Runnable>getVariable(CONSUMER_CONTEXT_VAR);
    final Optional<JmsXaContext> jmsXaContext = sessionManager.getJmsXaContext(connection);

    if (LOGGER.isTraceEnabled()) {
      LOGGER.trace("onSuccess called for '{}'", this);
      LOGGER.trace("onSuccess called, CONSUMER_CONTEXT_VAR={}", consumer);
      LOGGER.trace("onSuccess called, sessionManager.jmsXaContext={}", jmsXaContext);
    }

    consumer.ifPresent(runnable -> jmsXaContext
        .ifPresent(xaContext -> xaContext.afterEnds(() -> reschedulePolling(runnable))));
  }

  /**
   * Resubmits the polling {@link Runnable} stored in the {@value #CONSUMER_CONTEXT_VAR} context variable, if present,
   * directly to the scheduler. Unlike {@link #onSuccess}, it does not consult the XA context.
   *
   * @param callbackContext context of the message whose processing failed
   * @param error           the error that occurred (only logged at trace level here)
   */
  @Override
  public void onError(SourceCallbackContext callbackContext, Error error) {
    final Optional<Runnable> consumer = callbackContext.<Runnable>getVariable(CONSUMER_CONTEXT_VAR);

    if (LOGGER.isTraceEnabled()) {
      LOGGER.trace("onError called for '{}' ({})", this, error);
      LOGGER.trace("onError called, CONSUMER_CONTEXT_VAR={}", consumer);
    }

    consumer.ifPresent(this::reschedulePolling);
  }

  /**
   * Submits {@code runnable} to the scheduler again so that it performs another poll.
   *
   * @param runnable the polling task to resubmit
   */
  protected void reschedulePolling(Runnable runnable) {
    LOGGER.trace("Rescheduling poller {}", runnable);
    final Future<?> reschedulePolling = scheduler.submit(runnable);
    LOGGER.trace("Poller {} rescheduled: {}", runnable, reschedulePolling);
  }

  /**
   * Calls {@link JmsXaMessageConsumer#stop()} on every consumer created by this delegate.
   */
  @Override
  public void stop() {
    xaMessageConsumers.forEach(JmsXaMessageConsumer::stop);
  }

  /**
   * Calls {@link JmsXaMessageConsumer#stop()} on every consumer created by this delegate (same effect as {@link #stop()}
   * in this implementation).
   */
  @Override
  public void disableConsumers() {
    xaMessageConsumers.forEach(JmsXaMessageConsumer::stop);
  }
}
