/*
 * (c) 2003-2021 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.connectivity.rest.sdk.internal.connectormodel.dw;

import static com.mulesoft.connectivity.rest.sdk.internal.connectormodel.parameter.ParameterType.BODY;
import static java.lang.String.format;
import static java.util.Optional.empty;
import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;

import org.mule.runtime.api.el.ExpressionExecutionException;
import org.mule.weave.v2.grammar.DynamicSelectorOpId$;
import org.mule.weave.v2.grammar.ObjectKeyValueSelectorOpId$;
import org.mule.weave.v2.grammar.ValueSelectorOpId$;
import org.mule.weave.v2.parser.MappingParser;
import org.mule.weave.v2.parser.ast.AstNode;
import org.mule.weave.v2.parser.ast.AstNodeHelper;
import org.mule.weave.v2.parser.ast.header.directives.ContentType;
import org.mule.weave.v2.parser.ast.header.directives.OutputDirective;
import org.mule.weave.v2.parser.ast.operators.BinaryOpNode;
import org.mule.weave.v2.parser.ast.structure.DocumentNode;
import org.mule.weave.v2.parser.ast.structure.NameNode;
import org.mule.weave.v2.parser.ast.structure.StringNode;
import org.mule.weave.v2.parser.ast.variables.VariableReferenceNode;
import org.mule.weave.v2.parser.phase.ParsingResult;
import org.mule.weave.v2.parser.phase.PhaseResult;
import org.mule.weave.v2.sdk.ParsingContextFactory;
import org.mule.weave.v2.sdk.WeaveResource;

import com.mulesoft.connectivity.rest.sdk.internal.connectormodel.generic.Argument;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import javax.ws.rs.core.MediaType;

import scala.Option;
import scala.collection.JavaConverters;
import scala.collection.Seq;

public class DataWeaveExpressionParser {

  private static final String DEFAULT_EXPRESSION_PREFIX = "#[";
  private static final String DEFAULT_EXPRESSION_POSTFIX = "]";
  private static final String DW_PREFIX = "dw";
  private static final String PREFIX_EXPR_SEPARATOR = ":";
  private static final int DW_PREFIX_LENGTH = (DW_PREFIX + PREFIX_EXPR_SEPARATOR).length();

  public static boolean isBodyBindingUsed(Argument argument) {
    return isBindingUsed(argument.getValue().getValue(), BODY.getBinding());
  }

  public static boolean isBindingUsed(String script, String variableName) {
    PhaseResult<ParsingResult<DocumentNode>> parse = parseDWScript(script);
    if (parse.hasResult()) {
      DocumentNode documentNode = parse.getResult().astNode();
      Seq<VariableReferenceNode> binaryOpNodeSeq = AstNodeHelper.collectChildrenWith(documentNode, VariableReferenceNode.class);
      Collection<VariableReferenceNode> binaryOpNodes = JavaConverters.asJavaCollection(binaryOpNodeSeq);
      return binaryOpNodes.stream().anyMatch((n) -> n.variable().name().equals(variableName));
    } else {
      return false;
    }
  }

  /**
   * The way this method works is by getting the AST from DW to navigate through the tree, looking for the usages of the binding
   * and it's fields. So, for a script='parameters.someParameter', and a variableName='parameters', the AST will look like:
   * 
   * <pre>
   *   DocumentNode
   *     NullSafeNode
   *       BinaryOpNode(ValueSelector)
   *         VariableReferenceNode(parameters)
   *           NameIdentifier(parameters)
   *         NameNode
   *           StringNode(someParameter)
   * </pre>
   * 
   * @param script to generate the AST from
   * @param variableName root binding to look for its fields accessing
   * @return the collection of fields being accessed by the root binding
   */
  public static String[] selectionsFromBinding(String script, String variableName) {
    PhaseResult<ParsingResult<DocumentNode>> parse = parseDWScript(script);
    if (parse.hasResult()) {
      DocumentNode documentNode = parse.getResult().astNode();
      Seq<BinaryOpNode> binaryOpNodeSeq = AstNodeHelper.collectChildrenWith(documentNode, BinaryOpNode.class);

      Collection<BinaryOpNode> binaryOpNodes = JavaConverters.asJavaCollection(binaryOpNodeSeq);
      return binaryOpNodes.stream()
          .filter((binaryOpNode) -> {
            boolean isSelector = isSelectorOperation(binaryOpNode);
            if (isSelector) {
              if (binaryOpNode.rhs() instanceof NameNode || binaryOpNode.rhs() instanceof StringNode) {
                AstNode lhs = binaryOpNode.lhs();
                if (lhs instanceof VariableReferenceNode) {
                  return ((VariableReferenceNode) lhs).variable().name().equals(variableName);
                }
              }
            }
            return false;
          })
          .map((binaryOpNode) -> {
            if (binaryOpNode.rhs() instanceof NameNode) {
              AstNode astNode = ((NameNode) binaryOpNode.rhs()).keyName();
              if (astNode instanceof StringNode) {
                return ((StringNode) astNode).value();
              }
            } else if (binaryOpNode.rhs() instanceof StringNode) {
              return ((StringNode) binaryOpNode.rhs()).value();
            }
            throw new IllegalArgumentException("Script: `" + script
                + "` produced an invalid script. This is a bug.");
          }).toArray(String[]::new);

    } else {
      return new String[0];
    }

  }

  private static PhaseResult<ParsingResult<DocumentNode>> parseDWScript(String script) {
    String sanitizedScript = sanitize(script);
    return MappingParser
        .parse(MappingParser.parsingPhase(), WeaveResource.anonymous(sanitizedScript),
               ParsingContextFactory.createParsingContext());
  }

  private static boolean isSelectorOperation(BinaryOpNode binaryOpNode) {
    return binaryOpNode.opId().equals(ValueSelectorOpId$.MODULE$) ||
        binaryOpNode.opId().equals(DynamicSelectorOpId$.MODULE$) ||
        binaryOpNode.opId().equals(ObjectKeyValueSelectorOpId$.MODULE$);
  }

  /**
   * Stolen from org.mule.runtime.core.internal.el.ExpressionLanguageUtils#sanitize(java.lang.String), for further references.
   * 
   * @param expression expression to remove the #[], or the dw: elements that have no value to the actual script
   * @return a sanitized script
   */
  private static String sanitize(String expression) {
    // TODO: (RSDK-728) hack to fix org.mule.runtime.api.el.ExpressionExecutionException: Unbalanced brackets in expression
    expression = expression.trim();

    String sanitizedExpression;
    if (expression.startsWith(DEFAULT_EXPRESSION_PREFIX)) {
      if (!expression.endsWith(DEFAULT_EXPRESSION_POSTFIX)) {
        throw new ExpressionExecutionException(createStaticMessage(format("Unbalanced brackets in expression '%s'", expression)));
      }
      sanitizedExpression =
          expression.substring(DEFAULT_EXPRESSION_PREFIX.length(), expression.length() - DEFAULT_EXPRESSION_POSTFIX.length());
    } else {
      sanitizedExpression = expression;
    }

    if (sanitizedExpression.startsWith(DW_PREFIX + PREFIX_EXPR_SEPARATOR)
        // Handle DW functions that start with dw:: without removing dw:
        && !sanitizedExpression.substring(DW_PREFIX_LENGTH, DW_PREFIX_LENGTH + 1).equals(PREFIX_EXPR_SEPARATOR)) {
      sanitizedExpression = sanitizedExpression.substring(DW_PREFIX_LENGTH);
    }
    return sanitizedExpression;
  }

  public static Optional<MediaType> getOutputMediaType(String script) {
    PhaseResult<ParsingResult<DocumentNode>> parse = parseDWScript(script);
    if (parse.hasResult()) {
      DocumentNode documentNode = parse.getResult().astNode();
      Seq<OutputDirective> outputDirectiveSeq = AstNodeHelper.collectChildrenWith(documentNode, OutputDirective.class);

      Collection<OutputDirective> outputDirectiveNodes = JavaConverters.asJavaCollection(outputDirectiveSeq);
      return outputDirectiveNodes.stream()
          .map(outputDirective -> {
            Option<ContentType> mime = outputDirective.mime();
            if (mime.isDefined()) {
              MediaType mediaType = MediaType.valueOf(mime.get().mime());
              if (outputDirective.options().isDefined()) {
                Map<String, String> params = new HashMap<>();
                JavaConverters.asJavaCollection(outputDirective.options().get())
                    .stream()
                    .filter(directiveOption -> directiveOption.value() instanceof StringNode)
                    .forEach(directiveOption -> params.put(directiveOption.name().name(),
                                                           ((StringNode) directiveOption.value()).value()));
                mediaType = new MediaType(mediaType.getType(), mediaType.getSubtype(), params);
              }
              return mediaType;
            }
            return null;
          })
          .filter(value -> value != null)
          .findFirst();
    }
    return empty();
  }
}
