/*
 * (c) 2003-2021 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.connectors.mqtt3.internal.connection;

import com.mulesoft.connectors.mqtt3.api.MQTT3MessageAttributes;
import com.mulesoft.connectors.mqtt3.api.QoS;
import com.mulesoft.connectors.mqtt3.api.Topic;
import com.mulesoft.connectors.mqtt3.internal.exceptions.MQTT3InvalidTopicException;
import com.mulesoft.connectors.mqtt3.internal.exceptions.MQTT3PersistenceException;
import com.mulesoft.connectors.mqtt3.internal.exceptions.MQTT3PublishException;
import com.mulesoft.connectors.mqtt3.internal.routing.DefaultMQTT3Message;
import com.mulesoft.connectors.mqtt3.internal.routing.LWTMessage;
import com.mulesoft.connectors.mqtt3.internal.routing.MQTT3MessageHandler;
import com.mulesoft.connectors.mqtt3.internal.routing.MQTT3TopicRouter;
import com.mulesoft.connectors.mqtt3.internal.source.MQTT3ConnectionLostHandler;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.eclipse.paho.client.mqttv3.IMqttActionListener;
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken;
import org.eclipse.paho.client.mqttv3.IMqttToken;
import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
import org.eclipse.paho.client.mqttv3.MqttCallbackExtended;
import org.eclipse.paho.client.mqttv3.MqttClient;
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.MqttMessage;
import org.eclipse.paho.client.mqttv3.MqttPersistenceException;
import org.eclipse.paho.client.mqttv3.MqttSecurityException;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence;
import org.mule.runtime.api.connection.ConnectionException;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.tls.TlsContextFactory;
import org.mule.runtime.api.util.Reference;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.source.SourceCallback;
import org.slf4j.Logger;

import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.ConcurrentHashMap;
import java.util.Set;

import static java.util.stream.Collectors.toList;

import static com.mulesoft.connectors.mqtt3.internal.exceptions.MQTT3ConnectionExceptionResolver.resolveMQTT3ConnectionException;
import static com.mulesoft.connectors.mqtt3.internal.exceptions.MQTT3Error.UNAUTHORIZED;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.eclipse.paho.client.mqttv3.MqttTopic.isMatched;
import static org.slf4j.LoggerFactory.getLogger;

/**
 * The default implementation of an {@link MQTT3Connection}.
 * <p>
 * Wraps a Paho {@link MqttAsyncClient} together with the {@link MqttConnectOptions} used to connect it, a
 * {@link MQTT3TopicRouter} that dispatches arrived messages to the registered {@link MQTT3MessageHandler}s and a
 * {@link MQTT3ConnectionLostHandler} that is invoked when the client reports a lost connection.
 */
public class DefaultMQTT3Connection implements MQTT3Connection {

  private static final String ROOT_TOPIC = "#";
  private static final Logger LOGGER = getLogger(DefaultMQTT3Connection.class);
  private static final int COMPLETION_WAIT_TIMEOUT_MILLIS = 10000;
  /** Extra seconds added to the configured connection timeout while waiting for the connect callback. */
  public static final int ADDITIONAL_PROCESSING_TIME = 3;

  private final MqttConnectOptions mqttConnectOptions = new MqttConnectOptions();

  private MQTT3TopicRouter topicRouter;
  private MqttAsyncClient mqttClient;
  /** Token of the last connect operation; it backs {@link #isSessionPresent()}. */
  private IMqttToken mqttToken;
  /** Set once a connection failure/loss has been reported, reset when a connect completes. */
  private final AtomicBoolean notifiedReconnect = new AtomicBoolean(false);
  /** Multiplies the connect wait time, one slot per fail over server (never less than 1). */
  private int connectionTimeoutMultiplier = 1;
  /** Topic filters for which a subscription request has been issued to the broker. */
  private final Set<String> brokerSubscriptions = ConcurrentHashMap.newKeySet();

  private MQTT3ConnectionLostHandler connectionLostHandler;

  /**
   * Returns an instance of a DefaultMQTT3Connection. The client is created but not connected; call {@link #connect()}
   * to establish the connection.
   * @param url the URL to connect to.
   * @param clientId the client id that identifies this connection.
   * @param connectionOptions an instance of {@link MQTT3ConnectionOptions}
   * @param filePersistenceOptions an instance of {@link MQTT3FilePersistenceOptions}. When file persistence is enabled
   *        the configured data store path is used if it is not empty, otherwise the Paho default location is used.
   *        When it is disabled an in-memory persistence is used.
   * @param lwtMessage an instance of {@link LWTMessage}
   * @throws ConnectionException if the underlying mqtt client could not be created.
   */
  public DefaultMQTT3Connection(String url, String clientId,
                                MQTT3ConnectionOptions connectionOptions, MQTT3FilePersistenceOptions filePersistenceOptions,
                                LWTMessage lwtMessage)
      throws ConnectionException {
    this.topicRouter = new MQTT3TopicRouter((topicFilter, topic) -> isMatched(topicFilter, topic));
    this.connectionLostHandler = new MQTT3ConnectionLostHandler();
    try {
      if (filePersistenceOptions.getEnableFilePersistence()) {
        if (filePersistenceOptions.getDataStorePath() != null && !filePersistenceOptions.getDataStorePath().isEmpty()) {
          this.mqttClient =
              new MqttAsyncClient(url, clientId, new MqttDefaultFilePersistence(filePersistenceOptions.getDataStorePath()));
        } else {
          this.mqttClient = new MqttAsyncClient(url, clientId, new MqttDefaultFilePersistence());
        }
      } else {
        this.mqttClient = new MqttAsyncClient(url, clientId, new MemoryPersistence());
      }
    } catch (MqttException mqttException) {
      LOGGER.error("Failed to initialize mqttConnection, check that your connection parameters are correct."
          + mqttException.getMessage(), mqttException);
      throw new ConnectionException(mqttException, this);
    }

    long keepAliveIntervalSeconds =
        SECONDS.convert(connectionOptions.getKeepAliveInterval(), connectionOptions.getKeepAliveIntervalUnit());

    mqttConnectOptions.setConnectionTimeout((int) (SECONDS.convert(connectionOptions.getConnectionTimeout(),
                                                                   connectionOptions.getConnectionTimeoutUnit())));
    mqttConnectOptions.setCleanSession(connectionOptions.getCleanSession());
    mqttConnectOptions.setKeepAliveInterval((int) keepAliveIntervalSeconds);
    mqttConnectOptions.setMaxInflight(connectionOptions.getMaxInFlight());

    setLastWillAndTestamentMessage(lwtMessage);
    setMqttClientCallback();
  }

  private DefaultMQTT3Connection() {}

  /**
   * @return true if the clean session flag is set to true.
   */
  public boolean isCleanSessionEnabled() {
    return mqttConnectOptions.isCleanSession();
  }

  /**
   * Sets a username and a password for authentication.
   * @param username the authentication username.
   * @param password the authentication password. When {@code null} no password is configured.
   */
  public void setUsernamePassword(String username, String password) {
    mqttConnectOptions.setUserName(username);
    if (password != null) {
      mqttConnectOptions.setPassword(password.toCharArray());
    }
  }

  /**
   * Sets a list of fail over servers to be iterated over until a connection is successfully established.
   * The time {@link #connect()} waits for the connection is multiplied by the number of servers.
   * @param failOverServerArray the list of servers to iterate over.
   */
  public void setFailOverServers(String[] failOverServerArray) {
    mqttConnectOptions.setServerURIs(failOverServerArray);
    // Never let the multiplier drop to 0: an empty array would otherwise make connect() wait 0 seconds
    // and report a timeout immediately.
    this.connectionTimeoutMultiplier = Math.max(1, failOverServerArray.length);
  }

  /**
   * Attempts to establish a connection to the mqtt broker. Does nothing if the client is already connected.
   * <p>
   * The connect request is asynchronous; this method blocks until the client reports success or failure, or until
   * {@code (connectionTimeout + ADDITIONAL_PROCESSING_TIME) * connectionTimeoutMultiplier} seconds have elapsed.
   * A connection failure is only raised if no failure has been notified since the last completed connect
   * (see the {@code notifiedReconnect} flag); otherwise it is not rethrown.
   * @throws ConnectionException if connection attempt was unsuccessful.
   * @throws ModuleException with type {@code UNAUTHORIZED} if the broker rejected the credentials.
   */
  public void connect() throws ConnectionException {
    if (isConnected()) {
      LOGGER.debug("Client is already connected to {}", mqttClient.getCurrentServerURI());
      return;
    }
    Reference<Throwable> throwableReference = new Reference<>();
    try {
      CountDownLatch latch = new CountDownLatch(1);

      this.mqttToken = mqttClient.connect(mqttConnectOptions, null, new IMqttActionListener() {

        @Override
        public void onSuccess(IMqttToken asyncActionToken) {
          throwableReference.set(null);
          LOGGER.debug("Successfully connected to " + mqttClient.getCurrentServerURI());
          latch.countDown();
        }

        @Override
        public void onFailure(IMqttToken asyncActionToken, Throwable exception) {
          LOGGER.error("Error occurred establishing connection to " + mqttClient.getCurrentServerURI() + ":" +
              exception.getMessage(), exception);
          throwableReference.set(exception);
          latch.countDown();
        }
      });

      if (!latch.await(((mqttConnectOptions.getConnectionTimeout() + ADDITIONAL_PROCESSING_TIME) * connectionTimeoutMultiplier),
                       SECONDS)) {
        throwableReference.set(new ConnectionException("Error occurred attempting to establish connection"));
      }
    } catch (MqttSecurityException securityException) {
      throw new ModuleException("Error connecting to mqtt broker: not authorized to connect", UNAUTHORIZED, securityException);
    } catch (InterruptedException | MqttException mqttException) {
      if (mqttException instanceof InterruptedException) {
        // Restore the interrupt flag so callers further up the stack can observe the interruption.
        Thread.currentThread().interrupt();
      }
      if (notifiedReconnect.compareAndSet(false, true)) {
        LOGGER.error("Error occurred attempting to establish connection to mqtt broker " + mqttException, mqttException);
        throw new ConnectionException(mqttException, this);
      }
    }

    Throwable throwable = throwableReference.get();
    if (throwable != null) {
      Optional<ConnectionException> mqtt3ConnectionException = resolveMQTT3ConnectionException(throwable, this);
      if (!this.isConnected() && mqtt3ConnectionException.isPresent()) {
        if (notifiedReconnect.compareAndSet(false, true)) {
          LOGGER.error("Error occurred attempting to establish connection to mqtt broker " + throwableReference.get(), throwable);
          throw mqtt3ConnectionException.get();
        }
      } else if (throwable instanceof MqttSecurityException) {
        throw new ModuleException("Error connecting to mqtt broker: not authorized to connect", UNAUTHORIZED, throwable);
      } else {
        throw new MuleRuntimeException(throwable);
      }
    }
  }

  /**
   * Registers a {@code MQTT3MessageHandler} locally for the provided topics without subscribing to the broker.
   * This allows handlers to be ready before connecting to prevent race conditions.
   * @param topics the list of {@code Topic} filters for which the handler will be registered.
   * @param messageHandler a {@code MQTT3MessageHandler} to be called when messages for these topics are received.
   */
  @Override
  public void registerMessageHandler(List<Topic> topics, MQTT3MessageHandler messageHandler) {
    try {
      topicRouter.registerCallbackForTopics(topics, messageHandler);
      if (LOGGER.isDebugEnabled()) {
        // Log a List rather than an Object[]: an array passed as the only argument is expanded as varargs,
        // so only its first element would be rendered.
        LOGGER.debug("Registered message handler for topics: {}",
                     topics.stream().map(Topic::getTopicFilter).collect(toList()));
      }
    } catch (Exception exception) {
      LOGGER.error("Exception occurred during message handler registration for topics " + topics, exception);
      throw exception;
    }
  }

  /**
   * Subscribes to the provided topics on the broker. The message handlers should already be registered.
   * Topics for which a subscription request was already issued are skipped.
   * @param topics the list of {@code Topic} filters to subscribe to on the broker.
   * @throws ConnectionException if an error occurs during subscription.
   */
  @Override
  public void subscribeToTopics(List<Topic> topics) throws ConnectionException {
    try {
      List<Topic> newTopicsSubscriptionList = topics.stream()
          .filter(topic -> !isAlreadySubscribedToBroker(topic))
          .collect(toList());

      if (newTopicsSubscriptionList.isEmpty()) {
        LOGGER.debug("All topics are already subscribed to the broker");
        return;
      }

      String[] topicsArray = newTopicsSubscriptionList.stream().map(Topic::getTopicFilter).toArray(String[]::new);
      int[] qosArray = newTopicsSubscriptionList.stream()
          .map(Topic::getQos)
          .map(QoS::getValue)
          .mapToInt(Integer::intValue)
          .toArray();

      LOGGER.debug("Subscribing to topics: {}, with QOS {}", topicsArray, qosArray);
      // Only warn when the root topic is part of the request that is actually being issued.
      if (newTopicsSubscriptionList.stream().anyMatch(topic -> topic.getTopicFilter().equals(ROOT_TOPIC))) {
        LOGGER.warn(
                    "Issuing subscription request for the root topic " + ROOT_TOPIC
                        + ". This is not advisable, you will receive all messages issued to all topics.");
      }

      this.subscribe(topicsArray, qosArray);
    } catch (Exception exception) {
      LOGGER.error("Exception occurred during subscription to topics {}", topics, exception);
      throw exception;
    }
  }

  /**
   * Checks if a topic is already subscribed to the broker.
   * @param topic the topic to check
   * @return true if already subscribed to the broker
   */
  private boolean isAlreadySubscribedToBroker(Topic topic) {
    return brokerSubscriptions.contains(topic.getTopicFilter());
  }

  /**
   * Unsubscribes the client with the connection's client id from the provided list of {@link Topic}s.
   * Removes the provided {@link MQTT3MessageHandler} from each of the {@link Topic}'s callback list.
   * The broker is only asked to unsubscribe when clean session is enabled and the router reports topics that were
   * removed. Failures are logged and never propagated to the caller.
   * @param topics the list of {@code Topic} filters from which the connection will be unsubscribed.
   * @param messageHandler a {@code MQTT3MessageHandler} to be removed from each of the {@link Topic}'s callback list.
   */
  public void unsubscribeListenerFromTopics(List<Topic> topics, MQTT3MessageHandler messageHandler) {
    try {
      List<Topic> deletedTopics = topicRouter.deregisterCallbackForTopics(topics, messageHandler);
      if (this.isCleanSessionEnabled() && !deletedTopics.isEmpty()) {
        String[] deletedTopicsArray = deletedTopics.stream().map(Topic::getTopicFilter).toArray(String[]::new);
        try {
          // The cast keeps the array as a single log argument instead of being expanded as varargs.
          LOGGER.debug("Unsubscribing from topics: {}", (Object) deletedTopicsArray);
          mqttClient.unsubscribe(deletedTopicsArray);
          for (String topic : deletedTopicsArray) {
            brokerSubscriptions.remove(topic);
          }
        } catch (MqttException exception) {
          LOGGER.error("Error occurred unsubscribing from topics {}: {}", deletedTopicsArray, exception.getMessage(),
                       exception);
        }
      }
    } catch (Exception exception) {
      // Best effort: keep the failure from reaching the caller, but do not lose its cause.
      LOGGER.error("Error unsubscribing callbacks for topics " + topics, exception);
    }
  }

  /**
   * Adds a {@link SourceCallback} to the {@link MQTT3ConnectionLostHandler}.
   * @param callback the callback to be added to the connection lost handler.
   */
  @Override
  public void addSourceCallbackToConnectionLostHandler(SourceCallback<byte[], MQTT3MessageAttributes> callback) {
    connectionLostHandler.addCallback(callback);
  }

  /**
   * Installs the client callback that recovers subscriptions after a reconnect, invokes the
   * {@link MQTT3ConnectionLostHandler} when the connection is lost and routes arrived messages through the
   * {@link MQTT3TopicRouter}.
   */
  private void setMqttClientCallback() {
    MqttCallbackExtended callbackExtended = new MqttCallbackExtended() {

      @Override
      public void connectComplete(boolean isReconnecting, String serverURI) {
        notifiedReconnect.set(false);
        // Subscriptions are only recovered here for clean sessions.
        if (isReconnecting && mqttConnectOptions.isCleanSession()) {
          List<Topic> subscriptions = topicRouter.getDistinctTopicFilters();
          if (subscriptions.isEmpty()) {
            return;
          }

          String[] topicFilters = new String[subscriptions.size()];
          int[] qosForTopics = new int[subscriptions.size()];
          for (int i = 0; i < subscriptions.size(); i++) {
            topicFilters[i] = subscriptions.get(i).getTopicFilter();
            qosForTopics[i] = subscriptions.get(i).getQos().getValue();
          }

          if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Reconnect to {} complete.", serverURI);
            // The cast keeps the array as a single log argument instead of being expanded as varargs.
            LOGGER.debug("Recovering subscriptions to topics {}", (Object) topicFilters);
          }

          try {
            IMqttToken token = mqttClient.subscribe(topicFilters, qosForTopics, null, new MQTT3SubscriptionSuccessListener());
            token.waitForCompletion(COMPLETION_WAIT_TIMEOUT_MILLIS);
          } catch (MqttException ex) {
            LOGGER.error("Re-subscribe after reconnection failed for " + serverURI + " with error " + ex, ex);
          }
        }
      }

      @Override
      public void connectionLost(Throwable throwable) {
        // Notify only once until the next successful connect resets the flag.
        if (notifiedReconnect.compareAndSet(false, true)) {
          // NOTE(review): `this` is the anonymous callback instance, not the enclosing DefaultMQTT3Connection;
          // confirm that this is what the handler expects.
          connectionLostHandler.onConnectionLost(throwable, this);
          brokerSubscriptions.clear();
        }
      }

      @Override
      public void messageArrived(String topic, MqttMessage mqttMessage) throws Exception {
        LOGGER
            .info("Message received from broker - Topic: {}, MessageId: {}, QoS: {}, Payload size: {} bytes, Retained: {}, Duplicate: {}",
                  topic, mqttMessage.getId(), mqttMessage.getQos(),
                  mqttMessage.getPayload().length, mqttMessage.isRetained(), mqttMessage.isDuplicate());

        topicRouter.handleMessageArrived(new DefaultMQTT3Message(mqttMessage.getId(), topic, mqttMessage.getPayload(),
                                                                 mqttMessage.getQos(), mqttMessage.isDuplicate(),
                                                                 mqttMessage.isRetained()));
      }

      @Override
      public void deliveryComplete(IMqttDeliveryToken iMqttDeliveryToken) {
        // Nothing to do: publish results are reported through the action listener passed to publish().
      }

    };

    mqttClient.setCallback(callbackExtended);
  }

  /**
   * @return true if the connection is still alive.
   */
  public boolean isConnected() {
    return mqttClient.isConnected();
  }

  /**
   * Checks whether the broker can be reached with the configured connection options.
   * If this connection is already connected it is considered reachable. Otherwise a temporary client with a
   * derived, unique client id and in-memory persistence is connected and disconnected, leaving this connection's
   * own client untouched.
   * @return true if already connected or if the temporary client managed to connect and disconnect, false otherwise.
   */
  @Override
  public boolean testConnectivity() {
    if (isConnected()) {
      return true;
    }

    String validationClientId = mqttClient.getClientId() + "_validation_" + UUID.randomUUID();

    try (MqttClient tempClient = new MqttClient(
                                                mqttClient.getServerURI(),
                                                validationClientId,
                                                new MemoryPersistence())) {

      LOGGER.debug("Testing connectivity for validation purposes with client: {}", validationClientId);

      tempClient.connect(mqttConnectOptions);
      tempClient.disconnect();
      return true;

    } catch (Exception e) {
      LOGGER.debug("Connectivity test failed: {}", e.getMessage());
      return false;
    }
  }

  /**
   * Performs connection close, releasing the resources held by the underlying client. Errors are logged.
   */
  private void close() {
    try {
      mqttClient.close(true);
    } catch (MqttException mqttException) {
      LOGGER.error("Error occurred while attempting to close connection.", mqttException);
    }
  }

  /**
   * Publishes a message to a topic, with the specified quality of service and retention flag.
   * The connection is established first if it is not already connected.
   * @param topic the topic to which the message should be published.
   * @param message the message content to be sent.
   * @param qos the quality of service with which the message should be published.
   * @param isRetained whether the message should be retained by the broker for the specified topic.
   * @return a {@link CompletableFuture} that returns the published message id on success.
   * @throws MQTT3PublishException if an error occurred publishing the message to the broker.
   * @throws ConnectionException if a connectivity issue occurred while publishing the message.
   */
  public CompletableFuture<Integer> publish(String topic, byte[] message, int qos, boolean isRetained)
      throws MQTT3PublishException, ConnectionException {
    CompletableFuture<Integer> future = new CompletableFuture<>();
    try {
      this.connect();
      mqttClient.publish(topic, message, qos, isRetained, null, new MQTT3PublishActionListener(future, this));
    } catch (ModuleException moduleException) {
      LOGGER.error("MQTT3 Module exception found performing publish operation: " + moduleException.getMessage());
      throw moduleException;
    } catch (IllegalArgumentException e) {
      LOGGER.error("IllegalArgumentException found performing publish operation: " + e.getMessage());
      throw new MQTT3InvalidTopicException(e);
    } catch (MqttPersistenceException persistenceException) {
      LOGGER.error("MqttPersistenceException found performing publish operation: " + persistenceException.getMessage());
      throw new MQTT3PersistenceException(persistenceException);
    } catch (Throwable mqttException) {
      // Anything that resolves to a connectivity problem is surfaced as a ConnectionException,
      // everything else as a publish failure.
      Optional<ConnectionException> connException = resolveMQTT3ConnectionException(mqttException, this);
      if (connException.isPresent()) {
        LOGGER.error("MqttConnectionException found performing publish operation: " + mqttException.getMessage());
        throw connException.get();
      }
      LOGGER.error("MqttException found performing publish operation: " + mqttException.getMessage());
      throw new MQTT3PublishException(mqttException);
    }
    return future;
  }

  /**
   * Subscribe the connection to the provided list of topics, with the qos specified for each.
   * The quality of service array will be matched positionally with each topic for subscription.
   * Once the request has been issued the topics are recorded as broker subscriptions.
   * @param topics the list of topics that this connection's clientId will be subscribed to.
   * @param subscriptionQoSArray the list of quality of service levels for each of the provided topics.
   * @throws ConnectionException if an error occurs during the subscription.
   */
  private void subscribe(String[] topics, int[] subscriptionQoSArray) throws ConnectionException {
    try {
      // The returned token is deliberately not stored in mqttToken: that field backs isSessionPresent(),
      // which has to keep reporting the outcome of the connect operation.
      mqttClient.subscribe(topics, subscriptionQoSArray, null, new MQTT3SubscriptionSuccessListener());
      for (String topic : topics) {
        brokerSubscriptions.add(topic);
      }
    } catch (MqttException exception) {
      LOGGER.error("Subscription failed for topics " + String.join(", ", topics) + " with error:" + exception.getMessage(),
                   exception);
      throw new ConnectionException(exception, this);
    }
  }

  /**
   * @return true if there is a session present for this connection's client id, as reported by the token of the
   *         last connect operation; false if no connect operation has been issued yet.
   */
  public boolean isSessionPresent() {
    IMqttToken connectToken = this.mqttToken;
    return connectToken != null && connectToken.getSessionPresent();
  }

  /**
   * Sets a {@link LWTMessage} for this connection. Nothing is configured when the message, its body or its topic
   * is {@code null}.
   * @param lwtMessage the last will and testament message to be configured.
   */
  private void setLastWillAndTestamentMessage(LWTMessage lwtMessage) {
    if (lwtMessage == null || lwtMessage.getBody() == null || lwtMessage.getTopic() == null) {
      return;
    }
    mqttConnectOptions.setWill(lwtMessage.getTopic(),
                               lwtMessage.getBody().getBytes(),
                               lwtMessage.getQoS().getValue(),
                               lwtMessage.isRetained());
  }

  /**
   * Sets a {@link TlsContextFactory} for the connection.
   * @param tlsContextFactory the {@link TlsContextFactory} with the required data.
   * @throws ConnectionException if the socket factory could not be created from the provided TLS context.
   */
  public void setTLSOptions(TlsContextFactory tlsContextFactory) throws ConnectionException {
    try {
      mqttConnectOptions.setSocketFactory(tlsContextFactory.createSocketFactory());
    } catch (KeyManagementException | NoSuchAlgorithmException e) {
      throw new ConnectionException(e, this);
    }
  }

  /**
   * Closes the connection to the broker for this connection instance and releases the underlying client.
   * A graceful disconnect is attempted first, falling back to a forcible one if it fails. The client is closed
   * in every case, including when it is not connected, so its resources do not outlive this connection.
   */
  public void disconnect() {
    if (!mqttClient.isConnected()) {
      LOGGER.debug("Connection is already closed for client with id {}", this.mqttClient.getClientId());
      // Nothing to disconnect, but the client itself still has to be closed.
      this.close();
      return;
    }
    try {
      mqttClient.disconnect().waitForCompletion(COMPLETION_WAIT_TIMEOUT_MILLIS);
    } catch (MqttException mqttException) {
      LOGGER.error(mqttException.getMessage(), mqttException);
      try {
        LOGGER.error("Error occurred while attempting to disconnect client with id {}", this.mqttClient.getClientId());
        LOGGER.error("Attempting to forcibly disconnect...");
        mqttClient.disconnectForcibly(COMPLETION_WAIT_TIMEOUT_MILLIS);
      } catch (MqttException mqttException2) {
        LOGGER.error("Error occurred while attempting to forcibly disconnect client " + this.mqttClient.getClientId() + ": " +
            mqttException2, mqttException2);
      }
    } finally {
      this.close();
    }
  }

  /**
   * Two connections are equal when their client, connection timeout multiplier and connect options are equal.
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }

    if (!(o instanceof DefaultMQTT3Connection)) {
      return false;
    }

    DefaultMQTT3Connection that = (DefaultMQTT3Connection) o;

    return new EqualsBuilder().append(mqttClient, that.mqttClient)
        .append(connectionTimeoutMultiplier, that.connectionTimeoutMultiplier)
        .append(mqttConnectOptions, that.mqttConnectOptions)
        .isEquals();
  }

  @Override
  public int hashCode() {
    return Objects.hash(mqttClient, connectionTimeoutMultiplier, mqttConnectOptions);
  }

}
