/*
 * (c) 2003-2021 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.connectivity.rest.commons.api.source;

import static com.mulesoft.connectivity.rest.commons.internal.RestConstants.CONFIG;
import static com.mulesoft.connectivity.rest.commons.internal.RestConstants.CONNECTION;
import static com.mulesoft.connectivity.rest.commons.internal.RestConstants.CONTEXT_KEY_CUSTOM_PARAMETERS;
import static com.mulesoft.connectivity.rest.commons.internal.RestConstants.CONTEXT_KEY_PARAMETERS;
import static com.mulesoft.connectivity.rest.commons.internal.util.RequestStreamingUtils.doRequestAndConsumeString;
import static com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils.closeStream;
import static com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils.consumeStringAndClose;
import static com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils.getTypedValueOrNull;
import static com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils.isBlank;
import static com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils.isNotBlank;
import static com.mulesoft.connectivity.rest.commons.internal.util.RestSdkUtils.resolveCharset;
import static com.mulesoft.connectivity.rest.commons.internal.util.StreamUtils.resolveCursorProvider;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.Optional.empty;
import static java.util.stream.Collectors.toList;
import static java.util.stream.StreamSupport.stream;
import static org.mule.runtime.api.el.BindingContext.builder;
import static org.mule.runtime.api.metadata.DataType.OBJECT;
import static org.mule.runtime.api.metadata.DataType.STRING;

import org.mule.runtime.api.connection.ConnectionException;
import org.mule.runtime.api.connection.ConnectionProvider;
import org.mule.runtime.api.el.BindingContext;
import org.mule.runtime.api.el.ExpressionLanguage;
import org.mule.runtime.api.el.ValidationResult;
import org.mule.runtime.api.exception.MuleException;
import org.mule.runtime.api.exception.MuleRuntimeException;
import org.mule.runtime.api.metadata.DataType;
import org.mule.runtime.api.metadata.MediaType;
import org.mule.runtime.api.metadata.TypedValue;
import org.mule.runtime.api.streaming.bytes.CursorStreamProvider;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.extension.api.annotation.param.Config;
import org.mule.runtime.extension.api.annotation.param.Connection;
import org.mule.runtime.extension.api.connectivity.oauth.AccessTokenExpiredException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.source.PollContext;
import org.mule.runtime.extension.api.runtime.source.PollingSource;
import org.mule.runtime.extension.api.runtime.source.SourceCallbackContext;

import com.mulesoft.connectivity.rest.commons.api.binding.HttpRequestBinding;
import com.mulesoft.connectivity.rest.commons.api.binding.ParameterBinding;
import com.mulesoft.connectivity.rest.commons.api.configuration.RestConfiguration;
import com.mulesoft.connectivity.rest.commons.api.configuration.StreamingType;
import com.mulesoft.connectivity.rest.commons.api.connection.RestConnection;
import com.mulesoft.connectivity.rest.commons.api.error.SourceStartingException;
import com.mulesoft.connectivity.rest.commons.api.operation.HttpResponseAttributes;
import com.mulesoft.connectivity.rest.commons.internal.util.DwUtils;
import com.mulesoft.connectivity.rest.commons.internal.util.FromCursorProviderInputStream;
import com.mulesoft.connectivity.rest.commons.internal.util.RequestStreamingUtils;
import com.mulesoft.connectivity.rest.commons.internal.util.RestRequestBuilder;
import com.mulesoft.connectivity.rest.commons.internal.util.SplitPayloadUtils;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import javax.inject.Inject;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for defining a polling source that consumes a remote REST endpoint in a connector.
 *
 * @since 1.0
 */
public abstract class RestPollingSource
    extends PollingSource<InputStream, Object> {

  private static final Logger LOGGER = LoggerFactory.getLogger(RestPollingSource.class);
  private static final String ITEM_BINDING = "item";

  @Config
  protected RestConfiguration config;

  @Connection
  private ConnectionProvider<RestConnection> connectionProvider;

  @Inject
  private ExpressionLanguage expressionLanguage;

  protected RestConnection connection;

  protected final String watermarkExpression;
  protected final String identityExpression;
  protected final String itemsExpression;
  protected final String requestBodyExpression;
  protected final String eventExpression;
  protected final String startValueExpression;

  private TypedValue<?> startValue;
  private MultiMap<String, TypedValue<?>> parameterBindings;
  private MultiMap<String, TypedValue<?>> customParameterBindings;

  public RestPollingSource(String itemsExpression, String watermarkExpression, String identityExpression,
                           String requestBodyExpression, String eventExpression, String startValueExpression) {
    requireNonNull(watermarkExpression);
    requireNonNull(itemsExpression);

    this.watermarkExpression = watermarkExpression;
    this.identityExpression = identityExpression;
    this.itemsExpression = itemsExpression;
    this.requestBodyExpression = requestBodyExpression;
    this.eventExpression = eventExpression;
    this.startValueExpression = startValueExpression;
  }

  @Override
  protected void doStart() throws MuleException {
    validateExpression(watermarkExpression);
    validateExpression(identityExpression);
    validateExpression(itemsExpression);
    validateExpression(requestBodyExpression);
    validateExpression(startValueExpression);

    for (String bindingExpression : getParameterBinding().getAllBindingExpressions()) {
      if (DwUtils.isExpression(bindingExpression)) {
        validateExpression(bindingExpression);
      }
    }

    connection = connectionProvider.connect();
    evaluateStartValue();

    parameterBindings = (MultiMap<String, TypedValue<?>>) resolveCursorProvider((Map) getParameterValues());
    customParameterBindings = (MultiMap<String, TypedValue<?>>) resolveCursorProvider((Map) getCustomParameterValues());
  }

  private void evaluateStartValue() {
    BindingContext.Builder builder = builder();
    addParametersBinding(builder);
    addConfigAndConnectionBinding(builder);
    BindingContext bindingContext = builder.build();
    startValue = getStartValue(bindingContext);
  }

  /**
   * Returns the initial value start value will hold. Defaults to the expression parameterized as the startValueExpression, but it
   * can execute any code (such as a http call to obtain data).
   *
   * @param bindingContext the context to access the parameters, config or connection in the DW engine
   * @return a TypedValue with the information that will be available later on in the `startValue` binding in the trigger later on
   */
  protected TypedValue<?> getStartValue(BindingContext bindingContext) {
    return getExpressionLanguage().evaluate(startValueExpression, bindingContext);
  }

  @Override
  protected void doStop() {
    connectionProvider.disconnect(connection);
  }

  /**
   * Returns the parameter binding configuration of this RestPollingSource.
   */
  protected abstract HttpRequestBinding getParameterBinding();

  /**
   * Returns a MultiMap containing all the parameters this source exposes to the user. Each entry must contain the parameter name
   * as key and its value represented as a TypedValue.
   */
  protected abstract MultiMap<String, TypedValue<?>> getParameterValues();

  /**
   * Returns a MultiMap containing all the custom parameters this source exposes to the user. Each entry must contain the custom
   * parameter name as key and its value represented as a TypedValue.
   */
  protected abstract MultiMap<String, TypedValue<?>> getCustomParameterValues();

  /**
   * Returns the request path for this RestPollingSource with placeholders for its uri parameters. i.e: /user/{username}/events
   */
  protected abstract String getPathTemplate();

  /**
   * Return a RequestBuilder configured to do the request to the endpoint this source must poll. BaseUri, Path and Method must be
   * configured.
   *
   * @param path The request path with the placeholders replaced with its corresponding values.
   */
  protected abstract RestRequestBuilder getRequestBuilder(String path);

  protected String getId() {
    return getClass().getSimpleName();
  }

  /**
   * @return The injected config
   */
  protected RestConfiguration getConfig() {
    return config;
  }

  /**
   * @return The injected expression language
   */
  protected ExpressionLanguage getExpressionLanguage() {
    return expressionLanguage;
  }

  /**
   * @return The injected connection
   */
  protected RestConnection getConnection() {
    return connection;
  }

  @Override
  public void poll(PollContext<InputStream, Object> pollContext) {
    final Serializable watermark = pollContext.getWatermark().orElse(null);

    Result<TypedValue<String>, HttpResponseAttributes> result = null;
    try {
      result = doRequestAndConsumeString(connection,
                                         config,
                                         getRestRequestBuilder(watermark),
                                         getDefaultResponseMediaType(),
                                         expressionLanguage);
    } catch (AccessTokenExpiredException e) {
      LOGGER.info(format("Trigger '%s': about to notify access token expiration to runtime...", getId()), e);
      pollContext.onConnectionException(new ConnectionException(e));
      LOGGER.info(format("Trigger '%s': access token expiration notified to runtime.", getId()), e);
    } catch (MuleRuntimeException e) {
      LOGGER.warn(format("Trigger '%s': Mule runtime exception found while executing poll: '%s'", getId(), e.getMessage()), e);
    }
    if (result != null) {
      TypedValue<String> fullResponse = result.getOutput();
      HttpResponseAttributes attributes = result.getAttributes().orElse(null);
      for (TypedValue<CursorStreamProvider> item : getItems(fullResponse, watermark)) {
        pollContext.accept(getPollItemConsumer(watermark, fullResponse, attributes, item));
      }
    }
  }

  private RestRequestBuilder getRestRequestBuilder(Serializable watermark) {
    HttpRequestBinding parameterBinding = getParameterBinding();
    RestRequestBuilder requestBuilder = getRequestBuilder(getPathTemplate());
    addUriParams(parameterBinding.getUriParams(), watermark, requestBuilder);
    TypedValue<InputStream> requestBody = getRequestBody(watermark);
    if (requestBody != null) {
      requestBuilder.setBody(requestBody, StreamingType.AUTO);
    }

    parameterBinding
        .getHeaders()
        .forEach(i -> requestBuilder.addHeader(i.getKey(), getParameterValue(i.getValue(), watermark)));
    parameterBinding
        .getQueryParams()
        .forEach(i -> requestBuilder.addQueryParam(i.getKey(), getParameterValue(i.getValue(), watermark)));
    return requestBuilder;
  }

  private Consumer<PollContext.PollItem<InputStream, Object>> getPollItemConsumer(Serializable watermark,
                                                                                  TypedValue<String> fullResponse,
                                                                                  HttpResponseAttributes attributes,
                                                                                  TypedValue<CursorStreamProvider> item) {
    return pollItem -> {
      TypedValue<InputStream> inputStreamTypedValue = getEvent(item);

      Result<InputStream, Object> itemResult =
          Result.<InputStream, Object>builder()
              .output(inputStreamTypedValue.getValue())
              .mediaType(inputStreamTypedValue.getDataType().getMediaType())
              .attributes(attributes)
              .build();

      pollItem.setResult(itemResult);

      if (isNotBlank(watermarkExpression)) {
        pollItem.setWatermark(getItemWatermark(fullResponse, watermark, item));
      }

      if (isNotBlank(identityExpression)) {
        pollItem.setId(getIdentity(fullResponse, watermark, item));
      }
    };
  }

  private void addUriParams(List<ParameterBinding> uriParams, Serializable watermark, RestRequestBuilder restRequestBuilder) {
    uriParams.forEach(i -> restRequestBuilder.addUriParam(i.getKey(), getParameterValue(i.getValue(), watermark)));
  }

  private Object getParameterValue(String expression, Serializable watermark) {
    if (!DwUtils.isExpression(expression)) {
      return expression;
    }

    return expressionLanguage
        .evaluate(expression, buildContext(null, watermark, null))
        .getValue();
  }

  protected DataType getWatermarkDataType() {
    return STRING;
  }

  private Serializable getItemWatermark(TypedValue<?> payload, Serializable currentWatermark,
                                        TypedValue<CursorStreamProvider> item) {
    return (Serializable) expressionLanguage
        .evaluate(watermarkExpression, getWatermarkDataType(), buildContext(payload, currentWatermark, item))
        .getValue();
  }

  private String getIdentity(TypedValue<?> payload, Serializable currentWatermark, TypedValue<CursorStreamProvider> item) {
    return (String) expressionLanguage
        .evaluate(identityExpression, STRING, buildContext(payload, currentWatermark, item))
        .getValue();
  }

  private DataType getRequestBodyMediaType() {
    return org.mule.runtime.api.metadata.DataType.builder().type(String.class).mediaType(getRequestBodyDataType()).build();
  }

  protected String getRequestBodyDataType() {
    return MediaType.APPLICATION_JSON.toString();
  }

  private TypedValue<InputStream> getRequestBody(Serializable currentWatermark) {
    if (isNotBlank(requestBodyExpression)) {
      TypedValue<?> body =
          expressionLanguage.evaluate(requestBodyExpression, getRequestBodyMediaType(),
                                      buildContext(null, currentWatermark, null));

      TypedValue<String> stringTypedValue = consumeStringAndClose(body.getValue(), getDefaultResponseMediaType(),
                                                                  resolveCharset(empty(), getDefaultResponseMediaType()));

      return new TypedValue<>(new ByteArrayInputStream(stringTypedValue.getValue().getBytes()), body.getDataType());
    }

    return null;
  }

  private List<TypedValue<CursorStreamProvider>> getItems(TypedValue<String> fullResponse, Serializable currentWatermark) {
    TypedValue<?> result =
        expressionLanguage.evaluate(itemsExpression, buildContext(fullResponse, currentWatermark, null));

    Iterator<TypedValue<?>> splitResult = SplitPayloadUtils.split(expressionLanguage, result, itemsExpression);

    Iterable<TypedValue<?>> iterable = () -> splitResult;
    return stream(iterable.spliterator(), false)
        .map(RequestStreamingUtils::getCursorStreamProviderValueFromSplitResult)
        .collect(toList());
  }

  private TypedValue<InputStream> getEvent(TypedValue<CursorStreamProvider> item) {
    if (eventExpression != null) {
      item = (TypedValue<CursorStreamProvider>) expressionLanguage.evaluate(eventExpression, item.getDataType(),
                                                                            builder().addBinding(ITEM_BINDING, item).build());
    }

    return new TypedValue<>(FromCursorProviderInputStream.of(item.getValue()), item.getDataType());
  }

  private void validateExpression(String expression) throws SourceStartingException {
    if (isBlank(expression)) {
      return;
    }

    ValidationResult validationResult = expressionLanguage.validate(expression);

    if (!validationResult.isSuccess()) {
      throw new SourceStartingException(format("Expression is not valid: %s", expression));
    }
  }

  private BindingContext buildContext(
                                      TypedValue<?> payload, Serializable currentWatermark,
                                      TypedValue<CursorStreamProvider> item) {
    BindingContext.Builder builder =
        builder()
            .addBinding("payload", payload)
            .addBinding("watermark", TypedValue.of(currentWatermark))
            .addBinding("startValue", startValue);
    addParametersBinding(builder);
    addConfigAndConnectionBinding(builder);

    if (item != null) {
      builder.addBinding(ITEM_BINDING, item);
    }

    return builder.build();
  }

  private void addParametersBinding(BindingContext.Builder builder) {
    builder
        .addBinding(CONTEXT_KEY_PARAMETERS, TypedValue.of(parameterBindings))
        .addBinding(CONTEXT_KEY_CUSTOM_PARAMETERS, TypedValue.of(customParameterBindings));
  }

  private void addConfigAndConnectionBinding(BindingContext.Builder builder) {
    builder
        .addBinding(CONFIG, getTypedValueOrNull(config.getBindings()))
        .addBinding(CONNECTION, getTypedValueOrNull(connection.getBindings()));
  }

  @Override
  public void onRejectedItem(Result<InputStream, Object> result,
                             SourceCallbackContext callbackContext) {
    if (result.getOutput() != null) {
      closeStream(result.getOutput());
    }
    LOGGER.debug("Item Rejected");
  }

  protected MediaType getDefaultResponseMediaType() {
    return MediaType.APPLICATION_JSON;
  }
}
