/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.common.client.sse;

import static org.mule.runtime.http.api.HttpConstants.HttpStatus.OK;
import static org.mule.runtime.http.api.HttpConstants.Method.GET;
import static org.mule.runtime.http.api.HttpHeaders.Names.ACCEPT;
import static org.mule.runtime.http.api.HttpHeaders.Names.CACHE_CONTROL;
import static org.mule.runtime.http.api.HttpHeaders.Names.CONTENT_TYPE;
import static org.mule.runtime.http.api.HttpHeaders.Names.LAST_EVENT_ID;
import static org.mule.runtime.http.api.HttpHeaders.Values.NO_CACHE;
import static org.mule.runtime.http.api.HttpHeaders.Values.TEXT_EVENT_STREAM;

import static org.slf4j.LoggerFactory.getLogger;

import org.mule.runtime.http.api.client.HttpRequestOptions;
import org.mule.runtime.http.api.domain.entity.EmptyHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.request.HttpRequestBuilder;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.mule.runtime.http.api.sse.ServerSentEvent;
import org.mule.runtime.http.api.sse.client.SseFailureContext;
import org.mule.runtime.http.api.sse.client.SseListener;
import org.mule.runtime.http.api.sse.client.SseSource;
import org.mule.runtime.http.api.sse.client.SseSourceConfig;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import org.slf4j.Logger;

/**
 * Implementation of {@link SseSource} based on
 * <a href="https://html.spec.whatwg.org/multipage/server-sent-events.html">server-sent-events spec</a>.
 * <p>
 * The source either consumes an already received {@link HttpResponse} (when the configuration provides one) or issues its own
 * {@code GET} request through the given {@link InternalClient}. Incoming events are dispatched to the listener registered for
 * the event name, or to the fallback listener when there is none.
 * <p>
 * State transitions ({@link #open()}, {@link #close()}, {@link #onEvent(ServerSentEvent)}, {@link #onClose()} and
 * {@link #internalConnect()}) are serialized on this instance's monitor. Listener and callback registration may happen from any
 * thread.
 */
public class DefaultSseSource implements SseSource, SseListener, InternalConnectable {

  private static final Logger LOGGER = getLogger(DefaultSseSource.class);

  // Optional, already established response to read the events from. May be null.
  private final HttpResponse httpResponse;

  private final String uri;
  private final InternalClient httpClient;
  private final Consumer<HttpRequestBuilder> requestCustomizer;
  private final HttpRequestOptions requestOptions;
  private final RetryHelper retryHelper;
  private final ScheduledExecutorService ioScheduler;

  private final AtomicInteger readyState;

  private final Map<String, SseListener> eventListenersByTopic = new ConcurrentHashMap<>();

  // Volatile because it is replaced by register(SseListener) without holding the monitor and read while dispatching events.
  private volatile SseListener fallbackListener = event -> {
    // Do nothing by default.
  };

  // Only accessed from synchronized methods.
  private String lastEventId = null;

  // Guarded by its own monitor: callbacks are registered from arbitrary threads and iterated when a connection fails.
  private final List<Consumer<SseFailureContext>> onConnectionFailureCallbacks = new ArrayList<>();
  private CompletableFuture<HttpResponse> responseFuture;

  private final boolean preserveHeaderCase;

  /**
   * Creates a source in the closed state. Nothing is sent or consumed until {@link #open()} is called.
   *
   * @param config      provides the URL, request customizer, request options, retry configuration, header case handling and an
   *                    optional already received response.
   * @param httpClient  client used to send the initiator request.
   * @param ioScheduler scheduler used for reconnections and for consuming the stream of an already received response.
   */
  public DefaultSseSource(SseSourceConfig config,
                          InternalClient httpClient,
                          ScheduledExecutorService ioScheduler) {
    this.uri = config.getUrl();
    this.httpClient = httpClient;
    this.requestCustomizer = config.getRequestCustomizer();
    this.requestOptions = config.getRequestOptions();
    this.readyState = new AtomicInteger(READY_STATUS_CLOSED);
    this.retryHelper = new RetryHelper(ioScheduler, config.getRetryConfig(), this);
    this.preserveHeaderCase = config.isPreserveHeaderCase();
    this.ioScheduler = ioScheduler;
    this.httpResponse = config.getResponse();
  }

  /**
   * @return the current ready state, one of the {@code READY_STATUS_*} values.
   */
  @Override
  public int getReadyState() {
    return readyState.get();
  }

  /**
   * Opens the source if it is currently closed; otherwise does nothing.
   * <p>
   * When the configuration supplied a response, the events are read from it. Otherwise a new request is sent.
   *
   * @throws IllegalArgumentException if the configured response is not a successfully established SSE connection.
   */
  @Override
  public synchronized void open() {
    if (READY_STATUS_CLOSED != readyState.get()) {
      LOGGER.debug("Connection is not needed when trying to open SSE source, skipping.");
      return;
    }

    if (httpResponse != null) {
      openFromResponse(httpResponse);
    } else {
      internalConnect();
    }
  }

  /**
   * Marks the source as open and schedules the consumption of the entity of the given response.
   *
   * @param response the response to read the events from.
   * @throws IllegalArgumentException if the response is {@code null} or is not a {@code 200} response with a
   *                                  {@code text/event-stream} content type.
   */
  private void openFromResponse(HttpResponse response) {
    if (response == null) {
      throw new IllegalArgumentException("Response cannot be null when trying to open an SseSource");
    }

    if (!isSuccessfullyConnected(response)) {
      throw new IllegalArgumentException("Response is not a successfully established SSE connection. Status code: '%d', Content-Type: '%s'"
          .formatted(response.getStatusCode(), response.getHeaderValue(CONTENT_TYPE)));
    }

    readyState.set(READY_STATUS_OPEN);
    ioScheduler.submit(new SseStreamConsumer(response.getEntity(), this, httpClient.getName()));
  }

  /**
   * Registers a callback to be invoked each time a connection attempt does not result in an established SSE connection.
   *
   * @param onConnectionFailure the callback to add. Callbacks are invoked in registration order.
   */
  @Override
  public void doOnConnectionFailure(Consumer<SseFailureContext> onConnectionFailure) {
    synchronized (onConnectionFailureCallbacks) {
      onConnectionFailureCallbacks.add(onConnectionFailure);
    }
  }

  /**
   * Sets the listener that receives the events whose name has no specific listener, replacing the previous one.
   *
   * @param listener the new fallback listener.
   */
  @Override
  public void register(SseListener listener) {
    fallbackListener = listener;
  }

  /**
   * Sets the listener for the events with the given name, replacing any previous listener for that name.
   *
   * @param eventName the event name to listen for.
   * @param listener  the listener to invoke.
   */
  @Override
  public void register(String eventName, SseListener listener) {
    eventListenersByTopic.put(eventName, listener);
  }

  /**
   * Cancels the in-flight request (if any), aborts the retries and runs the close handling.
   */
  @Override
  public synchronized void close() {
    if (null != responseFuture) {
      // NOTE(review): if the request is still in flight, cancelling presumably triggers handleResponseOrError with a
      // cancellation error before the retries are aborted below, so failure callbacks may observe an intentional close.
      // Confirm that is intended.
      responseFuture.cancel(true);
    }
    retryHelper.abortReties();
    this.onClose();
  }

  /**
   * Records the event id and retry delay (when present) and dispatches the event to the listener registered for its name, or
   * to the fallback listener otherwise.
   *
   * @param event the received event.
   * @throws IllegalStateException if the source is closed.
   */
  @Override
  public synchronized void onEvent(ServerSentEvent event) {
    if (READY_STATUS_CLOSED == readyState.get()) {
      throw new IllegalStateException("SSE source is already closed");
    }
    // The last id is sent back on reconnection, see createInitiatorRequest().
    event.getId().ifPresent(id -> lastEventId = id);
    event.getRetryDelay().ifPresent(retryHelper::setDelayIfAllowed);
    eventListenersByTopic.getOrDefault(event.getName(), fallbackListener).onEvent(event);
  }

  /**
   * Handles the end of the stream: schedules a reconnection when the retry helper asks for it, otherwise notifies every
   * listener and marks the source as closed.
   */
  @Override
  public synchronized void onClose() {
    if (retryHelper.shouldRetryOnStreamEnd()) {
      readyState.set(READY_STATUS_CONNECTING);
      retryHelper.scheduleReconnection();
    } else {
      eventListenersByTopic.values().forEach(SseListener::onClose);
      fallbackListener.onClose();
      readyState.set(READY_STATUS_CLOSED);
    }
  }

  /**
   * Sends the initiator request asynchronously and registers the handling of its outcome.
   */
  @Override
  public synchronized void internalConnect() {
    readyState.set(READY_STATUS_CONNECTING);
    HttpRequest request = createInitiatorRequest();
    ProgressiveBodyDataListener dataListener = new ServerSentEventDataListener(this, httpClient.getName());
    responseFuture = httpClient.doSendAsync(request, requestOptions, dataListener);
    responseFuture.whenComplete(this::handleResponseOrError);
  }

  /**
   * Completion handler of the initiator request. On success the source becomes open. Otherwise the failure callbacks are
   * notified and the source either closes or schedules a reconnection.
   *
   * @param httpResponse the received response, or {@code null} if the request failed.
   * @param error        the failure, or {@code null} if a response was received.
   */
  private synchronized void handleResponseOrError(HttpResponse httpResponse, Throwable error) {
    if (isSuccessfullyConnected(httpResponse)) {
      readyState.set(READY_STATUS_OPEN);
      return;
    }

    SseFailureContext ctx = new SseFailureContextImpl(httpResponse, error, retryHelper);
    for (Consumer<SseFailureContext> callback : connectionFailureCallbacksSnapshot()) {
      try {
        callback.accept(ctx);
      } catch (RuntimeException e) {
        // The future returned by whenComplete is discarded, so an exception escaping from here would be lost and the
        // retry/close decision below would never run.
        LOGGER.warn("A connection failure callback threw an exception, continuing with the remaining callbacks.", e);
      }
    }

    // Some error callback could stop the retry mechanism
    if (!retryHelper.isRetryEnabled()) {
      onClose();
      return;
    }

    readyState.set(READY_STATUS_CONNECTING);
    retryHelper.scheduleReconnection();
  }

  /**
   * @return a copy of the registered failure callbacks, so that they can be iterated while new ones are being registered
   *         (even by one of the callbacks).
   */
  private List<Consumer<SseFailureContext>> connectionFailureCallbacksSnapshot() {
    synchronized (onConnectionFailureCallbacks) {
      return new ArrayList<>(onConnectionFailureCallbacks);
    }
  }

  /**
   * Tells whether the response is an established SSE connection: status {@code 200} and a {@code Content-Type} whose media
   * type is {@code text/event-stream} (case-insensitive, parameters ignored).
   *
   * @param httpResponse the response to check, may be {@code null}.
   * @return {@code true} if the response is an established SSE connection, {@code false} otherwise.
   */
  private boolean isSuccessfullyConnected(HttpResponse httpResponse) {
    if (null == httpResponse) {
      return false;
    }

    if (OK.getStatusCode() != httpResponse.getStatusCode()) {
      return false;
    }

    var contentTypeHeaderValue = httpResponse.getHeaderValue(CONTENT_TYPE);
    if (contentTypeHeaderValue == null) {
      return false;
    }

    // Not using split(";")[0]: split drops trailing empty strings, so a value such as ";" yields an empty array.
    int parametersStart = contentTypeHeaderValue.indexOf(';');
    var mediaType = parametersStart < 0 ? contentTypeHeaderValue : contentTypeHeaderValue.substring(0, parametersStart);
    // Whitespace is allowed around the ';' that separates the media type from its parameters.
    return TEXT_EVENT_STREAM.equalsIgnoreCase(mediaType.strip());
  }

  /**
   * Builds the {@code GET} request that initiates the stream, including the last received event id (if any), and lets the
   * configured customizer modify it before it is built.
   *
   * @return the request to send.
   */
  private HttpRequest createInitiatorRequest() {
    var builder = HttpRequest.builder(preserveHeaderCase).method(GET)
        .uri(uri)
        .addHeader(ACCEPT, TEXT_EVENT_STREAM)
        .addHeader(CACHE_CONTROL, NO_CACHE)
        .entity(new EmptyHttpEntity());
    if (null != lastEventId) {
      builder.addHeader(LAST_EVENT_ID, lastEventId);
    }
    requestCustomizer.accept(builder);
    return builder.build();
  }
}
