/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */

package org.mule.tooling.client.bootstrap.internal.reflection;

import static com.google.common.base.Throwables.propagate;
import static java.lang.String.format;
import static java.lang.Thread.currentThread;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyMap;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.apache.commons.lang3.ObjectUtils.defaultIfNull;
import org.mule.tooling.client.api.exception.TimeoutException;
import org.mule.tooling.client.api.exception.ToolingException;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Function;

import org.apache.commons.io.input.ClassLoaderObjectInputStream;
import org.apache.commons.lang3.SerializationException;
import org.apache.commons.lang3.SerializationUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

/**
 * Dispatcher for calling methods using reflection.
 *
 * @since 1.0
 */
public final class Dispatcher {

  private static final String DISPATCH_METHOD_NAME = "invokeMethod";
  private static final String IS_FEATURE_ENABLED_NAME = "isFeatureEnabled";
  private static final List<String> DESERIALIZED_PACKAGES =
      asList(new String[] {"org.mule.tooling", "com.thoughtworks.xstream", "org.mule.maven"});
  private static Logger LOGGER = LoggerFactory.getLogger(Dispatcher.class);

  private Object target;
  private Method dispatchMethod;
  private ClassLoader classLoader;
  private ExecutorService executorService;

  public Dispatcher(Object target, ClassLoader classLoader, ExecutorService executorService) {
    requireNonNull(target, "target cannot be null");
    requireNonNull(classLoader, "classLoader cannot be null");
    requireNonNull(executorService, "executorService cannot be null");

    this.target = target;
    this.dispatchMethod = findMethod(DISPATCH_METHOD_NAME);
    this.dispatchMethod.setAccessible(true);
    this.classLoader = classLoader;
    this.executorService = executorService;
  }

  public Dispatcher newReflectionInvoker(Object target) {
    return new Dispatcher(target, classLoader, executorService);
  }

  public Object invoke(Method method, long timeout, Object[] args) {
    if (LOGGER.isTraceEnabled()) {
      LOGGER.trace(format("Dispatching method '%s' on '%s' (timeout=%ss)", method, target.getClass(), timeout));
    }
    final Map<String, String> copyOfContextMap = defaultIfNull(MDC.getCopyOfContextMap(), emptyMap());
    Future<?> future = executorService.submit(new ToolingCallable(method, args, copyOfContextMap));
    try {
      if (timeout != -1) {
        return future.get(timeout, MILLISECONDS);
      }
      return future.get();
    } catch (java.util.concurrent.TimeoutException e) {
      throw new TimeoutException(format("Couldn't resolve the operation in the the expected time frame (%sms)", timeout), e);
    } catch (ExecutionException e) {
      if (LOGGER.isTraceEnabled()) {
        LOGGER.trace(format("Error response from method %s dispatched on Tooling Client was: %s", method,
                            e.getCause().getMessage()),
                     e);
      }
      // If it not a ToolingException or subclass it will not try to serialize/deserialize so it can be acceded from this class loader.
      if (!isDeserializedException(e.getCause().getClass().getPackage().getName())) {
        throw propagate(e.getCause());
      }
      // Exceptions are supposed to be serializable, so use Object serialization instead...
      Throwable deserializedException;
      try {
        deserializedException = (Throwable) deserialize(new ByteArrayInputStream(SerializationUtils.serialize(e.getCause())),
                                                        this.getClass().getClassLoader());
      } catch (Exception serializationException) {
        LOGGER.error("Error while trying to deserialize exception", serializationException);
        throw new ToolingException("Internal error", e.getCause());
      }
      throw propagate(deserializedException);
    } catch (InterruptedException e) {
      throw new RuntimeException(e);
    }
  }

  private static boolean isDeserializedException(String packageName) {
    return DESERIALIZED_PACKAGES.stream().anyMatch(deserializedPackage -> packageName.startsWith(deserializedPackage));
  }

  public Object dispatchRemoteMethod(String targetMethodName) {
    return dispatchRemoteMethod(targetMethodName, -1, null, null);
  }

  public Object dispatchRemoteMethod(String targetMethodName, List<Class> classes, List<String> args) {
    return dispatchRemoteMethod(targetMethodName, -1, classes, args);
  }

  public Object dispatchRemoteMethod(String targetMethodName, long timeout, List<Class> classes, List<String> args) {
    List<Object> arguments = new ArrayList<>();
    arguments.add(targetMethodName);
    addArrayArgument(classes, clazz -> clazz.getName(), arguments);
    addArrayArgument(args, arg -> arg, arguments);

    if (LOGGER.isTraceEnabled()) {
      LOGGER.trace(format("Dispatching method '%s' on '%s' (timeout=%ss)", targetMethodName, target.getClass(), timeout));
    }
    return invoke(dispatchMethod, timeout, arguments.toArray());
  }

  private <T extends Object> void addArrayArgument(List<T> args, Function<T, String> function, List<Object> target) {
    if (args != null && args.size() > 0) {
      target.add(args.stream().map(argument -> function.apply(argument)).toArray(String[]::new));
    } else {
      target.add(new String[0]);
    }

  }

  public Method findMethod(String methodName) {
    List<Method> methodsFound = new ArrayList<>();

    Method[] methods = target.getClass().getMethods();
    for (Method method : methods) {
      if (method.getName().equals(methodName)) {
        methodsFound.add(method);
      }
    }

    if (methodsFound.isEmpty()) {
      throw new IllegalStateException(new NoSuchMethodException(format("Method '%s' not found on %s", methodName,
                                                                       target.getClass())));
    }

    if (methodsFound.size() > 1) {
      throw new IllegalStateException(format("%s has more than one %s method", target.getClass(), methodName));
    }

    return methodsFound.get(0);
  }

  public static Object deserialize(InputStream inputStream, ClassLoader cl) {
    if (inputStream == null) {
      throw new IllegalArgumentException("The InputStream must not be null");
    }
    if (cl == null) {
      throw new IllegalArgumentException("The ClassLoader must not be null");
    }
    ObjectInputStream in = null;
    try {
      // stream closed in the finally
      in = new ClassLoaderObjectInputStream(cl, inputStream);
      return in.readObject();
    } catch (ClassNotFoundException ex) {
      throw new SerializationException(ex);
    } catch (IOException ex) {
      throw new SerializationException(ex);
    } catch (Exception ex) {
      throw new SerializationException(ex);
    } finally {
      try {
        if (in != null) {
          in.close();
        }
      } catch (IOException ex) {
        // ignore close exception
      }
    }
  }

  public boolean isFeatureEnabled(String methodName, String[] classes) {
    Method isFeatureEnabledMethod = findMethod(IS_FEATURE_ENABLED_NAME);
    isFeatureEnabledMethod.setAccessible(true);
    return (boolean) invoke(isFeatureEnabledMethod, -1, new Object[] {methodName, classes});
  }

  private class ToolingCallable implements Callable<Object> {

    private final Method method;
    private final Object[] args;
    private final Map<String, String> copyOfContextMap;

    public ToolingCallable(Method method, Object[] args, Map<String, String> copyOfContextMap) {
      this.method = method;
      this.args = args;
      this.copyOfContextMap = copyOfContextMap;
    }

    @Override
    public Object call() throws Exception {
      ClassLoader contextClassLoader = currentThread().getContextClassLoader();
      try {
        classLoader.loadClass(MDC.class.getName()).getMethod("setContextMap", Map.class).invoke(null, copyOfContextMap);
        currentThread().setContextClassLoader(classLoader);
        return method.invoke(target, args);
      } catch (InvocationTargetException e) {
        if (LOGGER.isTraceEnabled()) {
          LOGGER.trace(format("Error while calling method '%s' on '%s', error: %s", method,
                              target.getClass(), e.getCause().getMessage()));
        }
        Throwable cause = e.getCause();
        if (cause instanceof RuntimeException) {
          throw (RuntimeException) cause;
        } else {
          throw new IllegalStateException(cause);
        }
      } catch (Throwable t) {
        throw new IllegalStateException(t);
      } finally {
        currentThread().setContextClassLoader(contextClassLoader);
        MDC.clear();
      }
    }
  }
}
