package com.mulesoft.extension.mq;

import static com.mulesoft.extension.mq.internal.config.ConsumerAckMode.IMMEDIATE;
import static com.mulesoft.extension.mq.internal.config.ConsumerAckMode.MANUAL;
import static java.lang.String.format;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.fail;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.collection.IsMapContaining.hasEntry;
import org.mule.functional.api.flow.FlowRunner;
import org.mule.functional.junit4.MuleArtifactFunctionalTestCase;
import org.mule.runtime.api.event.Event;
import org.mule.runtime.api.message.Message;
import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.core.api.construct.Flow;
import org.mule.runtime.core.api.event.CoreEvent;
import org.mule.tck.junit4.rule.SystemProperty;
import org.mule.tck.probe.JUnitLambdaProbe;
import org.mule.tck.probe.PollingProber;
import org.mule.tck.probe.Probe;
import org.mule.tck.util.TestConnectivityUtils;
import org.mule.test.runner.ArtifactClassLoaderRunnerConfig;

import com.mulesoft.extension.mq.api.attributes.AnypointMqMessagePublishAttributes;
import com.mulesoft.extension.mq.api.message.AnypointMQMessageContext;
import com.mulesoft.extension.mq.internal.config.ConsumerAckMode;
import com.mulesoft.mq.restclient.api.AnypointMqMessage;

import java.nio.charset.Charset;
import java.time.Instant;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Functional tests for the Anypoint MQ extension, driven by the flows declared in
 * {@code anypoint-mq-extension.xml}.
 *
 * <p>Subscriber flows are expected to hand every delivered message to {@link #onMessage(CoreEvent)}
 * (presumably via a script/invoke component in the XML config — confirm there). That method records the
 * message in a shared stack, which the assertion helpers of this class drain.
 */
@ArtifactClassLoaderRunnerConfig(testInclusions = {"org.codehaus.groovy:groovy-all", "org.apache.commons:commons-lang3", "com.google.code.gson:gson", "org.mule.runtime:mule-service-http-api",
        "com.fasterxml.jackson.core:jackson-core", "com.fasterxml.jackson.core:jackson-databind", "commons-io:commons-io"})
public class AnypointMQExtensionTestCase extends MuleArtifactFunctionalTestCase {

    public static final long PROCESS_DELAY = 500;

    private static final Logger LOGGER = LoggerFactory.getLogger(AnypointMQExtensionTestCase.class);
    private static final String SUBSCRIBER_FLOW = "subscribe";
    private static final String SUBSCRIBER_MANUAL_ACK_FLOW = "subscribe-manual-ack";
    private static final String SUBSCRIBER_IMMEDIATE_ACK_FLOW = "subscribe-immediate-ack";
    private static final String SUBSCRIBER_MAX_PREFETCH = "subscribe-max-prefetch";
    private static final String SUBSCRIBER_SLOW_PREFETCH = "subscribe-slow-prefetch";
    private static final String SUBSCRIBER_POLLING = "subscribe-polling";
    private static final String SUBSCRIBER_SLOW_POLLING = "subscribe-slow-polling";
    private static final String PROPERTIES = "properties";
    private static final String KEY_1 = "key1";
    private static final String VAL_1 = "val1";
    private static final Map<String, String> PROPERTIES_MAP = Collections.singletonMap(KEY_1, VAL_1);
    private static final String SUBSCRIBE_MANUAL_ACK_WITH_REDELIVERY = "subscribe-manual-ack-with-redelivery";
    private static final MediaType APPLICATION_JSON = MediaType.create("application", "json", Charset.forName("UTF-8"));
    private static final String EMPTY_STRING = "";
    private static final String PUBLISH_MESSAGE = "publish-message";
    private static final String PUBLISH_MANY_MESSAGE = "publish-many-concurrent-messages";
    public static final String DESTINATION_VAR = "destination";
    public static final String MESSAGE_COUNT_VAR = "messageCount";
    public static final String ACK_MODE_VAR = "ackMode";
    public static final String CONSUME_MESSAGE = "consume-message";
    public static final String ACK_MESSAGE = "ack-message";
    public static final String NACK_MESSAGE = "nack-message";

    /** Messages recorded by {@link #onMessage(CoreEvent)}; popped (most recent first) by the helpers below. */
    private static final Stack<Message> receivedMessages = new Stack<>();
    /** Number of {@link #onMessage(CoreEvent)} calls since the last {@link #setUp()}; used for logging. */
    private static final AtomicInteger receivedMessagesCount = new AtomicInteger(0);
    // Placeholder only: overwritten in setUp() from the "subscriberQueue" system property.
    private static String SUBSCRIBER_QUEUE = "subscriberQueue";
    private static final String MESSAGE_BODY = "this is a message";

    @Rule
    public AnypointMQExtensionRule rule = new AnypointMQExtensionRule();

    @Rule
    public SystemProperty rule2 = TestConnectivityUtils.disableAutomaticTestConnectivity();

    /** Resolves the target queue and resets the shared received-message state before each test. */
    @Before
    public void setUp() {
        SUBSCRIBER_QUEUE = System.getProperty("subscriberQueue");
        receivedMessages.clear();
        receivedMessagesCount.set(0);
    }

    @Override
    protected String getConfigFile() {
        return "anypoint-mq-extension.xml";
    }

    @Test
    public void subscriberConsumesWithPrefetch() throws Exception {
        final int expectedMessages = 50;
        publishMany(expectedMessages, SUBSCRIBER_QUEUE);

        initializeFlow(SUBSCRIBER_MAX_PREFETCH);

        expectUniqueMessages(2000, 250, expectedMessages, new ConcurrentHashMap<>());
        expectNoMoreMessages(500);
    }

    @Test
    public void publishManyConcurrent() throws Exception {
        int expectedMessages = 400;
        flowRunner(PUBLISH_MANY_MESSAGE)
                .withPayload(MESSAGE_BODY)
                .withVariable(DESTINATION_VAR, SUBSCRIBER_QUEUE)
                .withVariable(MESSAGE_COUNT_VAR, expectedMessages)
                .run();

        initializeFlow(SUBSCRIBER_MAX_PREFETCH);
        expectUniqueMessages(30000, 500, expectedMessages, new ConcurrentHashMap<>());
        expectNoMoreMessages(500);
        // NOTE(review): disposing the shared context inside a test method is unusual; confirm it is intended.
        muleContext.dispose();
    }

    @Test
    public void slowSubscriberConsumesWithPrefetch() throws Exception {
        final int expectedMessages = 50;
        publishMany(expectedMessages, SUBSCRIBER_QUEUE);

        initializeFlow(SUBSCRIBER_SLOW_PREFETCH);

        assertExpectedOnSlowConsumer(5000, 20000, expectedMessages);
        expectNoMoreMessages(500);
    }

    @Test
    public void subscriberConsumesWithPolling() throws Exception {
        final int expectedMessages = 30;
        publishMany(expectedMessages, SUBSCRIBER_QUEUE);

        initializeFlow(SUBSCRIBER_POLLING);

        assertExpectedOnSlowConsumer(2000, 5000, expectedMessages);
        expectNoMoreMessages(500);
    }

    @Test
    public void slowSubscriberConsumesWithPolling() throws Exception {
        final int expectedMessages = 30;
        publishMany(expectedMessages, SUBSCRIBER_QUEUE);

        initializeFlow(SUBSCRIBER_SLOW_POLLING);

        assertExpectedOnSlowConsumer(10000, 10000, expectedMessages);
        expectNoMoreMessages(500);
    }

    /**
     * Asserts that a deliberately slow consumer has NOT received all messages after
     * {@code initialBatchTimeout} ms, but does receive all of them (each exactly once) within a further
     * {@code finalBatchTimeout} ms.
     *
     * <p>The "not yet finished" check is deliberately kept outside any {@code catch (AssertionError)}.
     * {@code fail(...)} itself throws {@link AssertionError}, and the previous try/catch swallowed it
     * (together with any duplicate-message failure), so the check could never fail.
     *
     * @param initialBatchTimeout window (ms) in which the consumer must not yet have drained everything
     * @param finalBatchTimeout   additional window (ms) in which all messages must have arrived
     * @param expectedMessages    number of distinct messages published for the test
     */
    private void assertExpectedOnSlowConsumer(long initialBatchTimeout, long finalBatchTimeout, int expectedMessages)
            throws InterruptedException {
        Map<String, Message> processed = new ConcurrentHashMap<>();

        Thread.sleep(initialBatchTimeout);
        drainReceivedMessages(processed);
        if (processed.size() >= expectedMessages) {
            fail(format("Polling for %sms should not be enough to clear all the %s messages",
                        initialBatchTimeout, expectedMessages));
        }

        expectUniqueMessages(finalBatchTimeout, 500, expectedMessages, processed);
    }

    @Test
    public void publishedMessagesAreConsumedBySubscriber() throws Exception {
        initializeFlow(SUBSCRIBER_FLOW);
        final String messageId = publishMessage(publishMessage(SUBSCRIBER_QUEUE)
                .withPayload(MESSAGE_BODY)
                .withVariable(PROPERTIES, PROPERTIES_MAP));
        final Message muleMessage = checkMessageArrived(messageId);
        assertThat((getAttributes(muleMessage)).getMessage().getProperties(), hasEntry(KEY_1, VAL_1));
    }

    @Test
    public void publishMessageCommunicatesContentMimeTypeUsingSubscriber() throws Exception {
        initializeFlow(SUBSCRIBER_FLOW);
        final String messageId = publishMessage(publishMessage(SUBSCRIBER_QUEUE)
                .withPayload("{}")
                .withMediaType(APPLICATION_JSON)
                .withVariable(PROPERTIES, PROPERTIES_MAP));
        final Message muleMessage = checkMessageArrived(messageId);
        assertThat(muleMessage.getPayload().getDataType().getMediaType(), is(APPLICATION_JSON));
    }

    @Test
    public void publishMessageCommunicatesContentMimeTypeUsingConsumer() throws Exception {
        final String messageId = publishMessage(publishMessage(SUBSCRIBER_QUEUE)
                .withPayload("{}")
                .withMediaType(APPLICATION_JSON)
                .withVariable(PROPERTIES, PROPERTIES_MAP));
        Message message = consumeMessage(messageId, SUBSCRIBER_QUEUE, IMMEDIATE);
        assertThat(message.getPayload().getDataType().getMediaType(), is(APPLICATION_JSON));
    }

    @Test
    public void publishedMessageIsNackAndReceivedAgain() throws Exception {
        // NOTE(review): the purpose of this initial delay is not evident from this file.
        Thread.sleep(2000);
        initializeFlow(SUBSCRIBER_MANUAL_ACK_FLOW);
        final String messageId = publishMessage(MESSAGE_BODY, SUBSCRIBER_QUEUE);
        final Message firstDelivery = checkMessageArrived(messageId);
        nackMessage(firstDelivery);
        final Message redelivery = checkMessageArrived(messageId);
        // Ack the most recent delivery, not the one that was just nacked
        // (as publishedMessageTimeoutsAndIsReceivedAgain does).
        ackMessage(redelivery);
    }

    @Test
    public void publishedMessageIsConsumedNackedAndConsumedAgain() throws Exception {
        final String messageId = publishMessage(MESSAGE_BODY, SUBSCRIBER_QUEUE);
        final Message firstDelivery = consumeMessage(messageId, SUBSCRIBER_QUEUE, MANUAL);
        nackMessage(getAttributes(firstDelivery));
        final Message redelivery = consumeMessage(messageId, SUBSCRIBER_QUEUE, MANUAL);
        // Ack the message context of the latest consume, not the nacked one.
        ackMessage(getAttributes(redelivery));
    }

    @Test
    public void publishedMessageIsAckedAndNotReceivedAgain() throws Exception {
        initializeFlow(SUBSCRIBER_MANUAL_ACK_FLOW);
        final String messageId = publishMessage(MESSAGE_BODY, SUBSCRIBER_QUEUE);
        final Message muleMessage = checkMessageArrived(messageId);
        ackMessage(muleMessage);
        // Only the second lookup is expected to time out. A class-level @Test(expected = ...)
        // would also (wrongly) accept a timeout on the first lookup above.
        assertThatExceptionOfType(AssertionError.class).isThrownBy(() -> checkMessageArrived(messageId));
    }

    @Test
    public void publishedMessageTimeoutsAndIsReceivedAgain() throws Exception {
        final String messageId = publishMessage(MESSAGE_BODY, SUBSCRIBER_QUEUE);
        initializeFlow(SUBSCRIBER_MANUAL_ACK_FLOW);
        checkMessageArrived(messageId);
        Thread.sleep(5000);
        final Message muleMessage = checkMessageArrived(messageId);
        ackMessage(muleMessage);
    }

    @Test
    public void publishedMessageIsReceivedWithImmediateAckAndNotReceivedAgain() throws Exception {
        initializeFlow(SUBSCRIBER_IMMEDIATE_ACK_FLOW);
        final String messageId = publishMessage(MESSAGE_BODY + Instant.now(), SUBSCRIBER_QUEUE);
        checkMessageArrived(messageId);
        assertThatExceptionOfType(AssertionError.class).isThrownBy(() -> checkMessageArrived(messageId));
    }

    @Test
    public void publishMessageWithNullBody() throws Exception {
        initializeFlow(SUBSCRIBER_IMMEDIATE_ACK_FLOW);
        final String messageId = publishMessage(null, SUBSCRIBER_QUEUE);
        final Message muleMessage = checkMessageArrived(messageId);
        assertThat(getPayloadAsString(muleMessage), is(EMPTY_STRING));
    }

    @Test
    public void publishedMessageReachesRedeliveryCount() throws Exception {
        initializeFlow(SUBSCRIBE_MANUAL_ACK_WITH_REDELIVERY);
        final String messageId = publishMessage(MESSAGE_BODY, SUBSCRIBER_QUEUE);
        final int expectedDeliveries = 3;
        // Each iteration fails with an AssertionError if the expected delivery does not show up.
        for (int delivery = 0; delivery < expectedDeliveries; delivery++) {
            checkMessageArrived(messageId);
        }
        // ...and no further delivery may happen once the redelivery count is reached.
        assertThatExceptionOfType(AssertionError.class).isThrownBy(() -> checkMessageArrived(messageId));
    }

    /**
     * Polls until exactly {@code expectedMessages} distinct messages (by message ID) have been drained into
     * {@code existingMessages}. Fails immediately if a message ID is seen twice.
     *
     * <p>Completion is judged on the drained map rather than on {@link #receivedMessagesCount}.
     * {@link #onMessage(CoreEvent)} bumps the counter before pushing the message, so the counter can reach
     * the target while the last message is still unpushed. That message would then be reported as
     * "unexpected" by {@link #expectNoMoreMessages(long)}.
     *
     * @param timeout          total polling time in ms
     * @param delay            delay between polls in ms
     * @param expectedMessages number of distinct messages to wait for
     * @param existingMessages accumulator of already-seen messages by ID; may be pre-populated
     * @throws AssertionError on timeout or on a duplicated message ID
     */
    private void expectUniqueMessages(long timeout, long delay,
                                      int expectedMessages,
                                      Map<String, Message> existingMessages) throws AssertionError {
        new PollingProber(timeout, delay).check(new Probe() {

            @Override
            public boolean isSatisfied() {
                drainReceivedMessages(existingMessages);
                return existingMessages.size() == expectedMessages;
            }

            @Override
            public String describeFailure() {
                return format("Never reached the expected amount of messages. Expected: %s, but got: %s",
                              expectedMessages, existingMessages.size());
            }
        });
    }

    /**
     * Pops every currently recorded message into {@code existingMessages}, keyed by message ID.
     *
     * @throws AssertionError if a popped message's ID is already present in {@code existingMessages}
     */
    private static void drainReceivedMessages(Map<String, Message> existingMessages) {
        while (!receivedMessages.isEmpty()) {
            Message received = receivedMessages.pop();
            String id = getAttributes(received).getMessage().getId();
            if (existingMessages.putIfAbsent(id, received) != null) {
                fail("Duplicated message found with ID: " + id);
            }
        }
    }

    /**
     * Waits {@code timeout} ms and then asserts that no further messages were recorded. The IDs of any
     * unexpected messages are listed in the failure reason.
     */
    private void expectNoMoreMessages(long timeout) throws Exception {
        Thread.sleep(timeout);
        List<Message> messages = new LinkedList<>();
        while (!receivedMessages.isEmpty()) {
            messages.add(receivedMessages.pop());
        }
        assertThat("No more messages were expected but found: " +
                           messages.stream()
                                   .map(AnypointMQExtensionTestCase::getAttributes)
                                   .map(AnypointMQMessageContext::getMessage)
                                   .map(AnypointMqMessage::getId)
                                   .collect(toList()),
                   messages.isEmpty(), is(true));
    }

    /**
     * Repeatedly runs the consume flow against {@code destination} until it returns the message with ID
     * {@code messageId}, and returns that message. Messages with other IDs are discarded without an
     * explicit ack. NOTE(review): there is no upper bound on attempts.
     */
    private Message consumeMessage(String messageId, String destination, ConsumerAckMode ackMode) throws Exception {
        Message message;
        do {
            message = flowRunner(CONSUME_MESSAGE)
                    .withVariable(DESTINATION_VAR, destination)
                    .withVariable(ACK_MODE_VAR, ackMode)
                    .run()
                    .getMessage();
        } while (!getAttributes(message).getMessage().getId().equals(messageId));

        return message;
    }

    /** Publishes {@code amount} copies of {@link #MESSAGE_BODY} (with {@link #PROPERTIES_MAP}) to {@code destination}, one at a time. */
    public void publishMany(int amount, String destination) throws Exception {
        for (int i = 0; i < amount; i++) {
            publishMessage(destination)
                    .withPayload(MESSAGE_BODY)
                    .withVariable(PROPERTIES, PROPERTIES_MAP).run();
        }
    }

    /** Casts a received message's attributes to the extension's {@link AnypointMQMessageContext}. */
    private static AnypointMQMessageContext getAttributes(Message message) {
        return AnypointMQMessageContext.class.cast(message.getAttributes().getValue());
    }

    /**
     * Polls (10 s timeout, 1 s interval) for a recorded message whose ID is {@code messageId} and returns it.
     * Any other message popped along the way is acked so that it does not keep coming back.
     *
     * @throws AssertionError if the message does not arrive in time
     */
    private Message checkMessageArrived(String messageId) {
        PollingProber prober = new PollingProber(10000, 1000);
        final Message[] muleMessage = new Message[1];
        prober.check(new JUnitLambdaProbe(() -> {
            if (!receivedMessages.isEmpty()) {
                final Message pop = receivedMessages.pop();
                final AnypointMQMessageContext attributes = getAttributes(pop);
                if (attributes.getMessage().getId().equals(messageId)) {
                    muleMessage[0] = pop;
                    return true;
                }
                try {
                    ackMessage(attributes);
                } catch (Exception e) {
                    // Best effort: a failed ack of an unrelated message must not abort the probe, but do not hide it.
                    LOGGER.warn("Failed to ack unrelated message {} while waiting for {}",
                                attributes.getMessage().getId(), messageId, e);
                }
            }
            return false;
        }));
        return muleMessage[0];
    }

    /** Publishes {@code message} to {@code destination} via the publish flow and returns the assigned message ID. */
    private String publishMessage(String message, String destination) throws Exception {
        final Event muleEvent = flowRunner(PUBLISH_MESSAGE)
                .withPayload(message)
                .withVariable(DESTINATION_VAR, destination)
                .run();

        final AnypointMqMessagePublishAttributes attributes = AnypointMqMessagePublishAttributes.class.cast(muleEvent.getMessage().getAttributes().getValue());
        return attributes.getMessageId();
    }

    /** Runs a pre-configured publish {@link FlowRunner} and returns the assigned message ID. */
    public String publishMessage(FlowRunner flowRunner) throws Exception {
        final Event muleEvent = flowRunner.run();
        final AnypointMqMessagePublishAttributes attributes = AnypointMqMessagePublishAttributes.class.cast(muleEvent.getMessage().getAttributes().getValue());
        return attributes.getMessageId();
    }

    /** Returns a publish-flow runner targeting {@code destination}, for callers to customize before running. */
    private FlowRunner publishMessage(String destination) {
        return flowRunner(PUBLISH_MESSAGE).withVariable(DESTINATION_VAR, destination);
    }

    /**
     * Callback for subscriber flows: counts the delivery and records the message for the assertion helpers.
     * May be invoked from flow threads concurrently with the test thread.
     */
    public static void onMessage(CoreEvent event) {
        LOGGER.debug("ON MESSAGE: {} - # {}", getAttributes(event.getMessage()).getMessage().getId(),
                     receivedMessagesCount.incrementAndGet());
        receivedMessages.add(event.getMessage());
    }

    /** Starts the named flow; subscriber flows are started on demand by each test. */
    private void initializeFlow(String flowName) throws Exception {
        ((Flow) getFlowConstruct(flowName)).start();
    }

    private void ackMessage(Message muleMessage) throws Exception {
        ackMessage(getAttributes(muleMessage));
    }

    /** Acks the given message via the ack flow. */
    private void ackMessage(AnypointMQMessageContext messageContext) throws Exception {
        flowRunner(ACK_MESSAGE).withPayload(messageContext).run();
    }

    private void nackMessage(Message muleMessage) throws Exception {
        nackMessage(getAttributes(muleMessage));
    }

    /** Nacks the given message via the nack flow. */
    private void nackMessage(AnypointMQMessageContext messageContext) throws Exception {
        flowRunner(NACK_MESSAGE).withPayload(messageContext).run();
    }
}
