package junitparams;

import java.lang.reflect.*;
import java.util.*;

import javax.lang.model.type.*;

import org.junit.runner.*;
import org.junit.runner.notification.*;
import org.junit.runners.model.*;

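/**
 * Test-method-level functionality for parameterised tests: resolves the
 * parameter sets of a single test method, describes them to JUnit and runs
 * individual invocations.
 * 
 * @author Pawel Lipinski
 * 
 */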
public class ParameterisedTestMethodRunner {

    private int count;
    private final TestMethod method;
    private Parameters parametersAnnotation;

    public ParameterisedTestMethodRunner(TestMethod testMethod) {
        this.method = testMethod;
        parametersAnnotation = testMethod.frameworkMethod.getAnnotation(Parameters.class);
    }

    public int nextCount() {
        return count++;
    }

    public int count() {
        return count;
    }

    Object[] paramsFromAnnotation() {
        Object[] params = paramsFromValue();

        if (params.length == 0)
            params = paramsFromSource();

        if (params.length == 0)
            params = paramsFromMethod();

        return params;
    }

    private Object[] paramsFromValue() {
        Object[] params = parametersAnnotation.value();
        return params;
    }

    private Object[] paramsFromSource() {
        if (sourceClassUndefined())
            return new Object[] {};

        Class<?> sourceClass = parametersAnnotation.source();

        return fillResultWithAllParamProviderMethods(sourceClass);
    }

    private Object[] fillResultWithAllParamProviderMethods(Class<?> sourceClass) {
        List<Object> result = getParamsFromSourceHierarchy(sourceClass);
        if (result.isEmpty())
            throw new RuntimeException(
                    "No methods starting with provide or they return no result in the parameters source class: "
                            + sourceClass.getName());

        return result.toArray(new Object[] {});
    }

    public List<Object> getParamsFromSourceHierarchy(Class<?> sourceClass) {
        List<Object> result = new ArrayList<Object>();
        while (sourceClass.getSuperclass() != null) {
            result.addAll(gatherParamsFromAllMethodsFrom(sourceClass));
            sourceClass = sourceClass.getSuperclass();
        }

        return result;
    }

    private List<Object> gatherParamsFromAllMethodsFrom(Class<?> sourceClass) {
        List<Object> result = new ArrayList<Object>();
        Method[] methods = sourceClass.getDeclaredMethods();
        for (Method method : methods) {
            if (method.getName().startsWith("provide")) {
                if (!Modifier.isStatic(method.getModifiers()))
                    throw new RuntimeException("Parameters source method " +
                            method.getName() +
                            " is not declared as static. Modify it to a static method.");
                try {
                    result.addAll(Arrays.asList(processParamsIfSingle((Object[]) method.invoke(null))));
                } catch (Exception e) {
                    throw new RuntimeException("Cannot invoke parameters source method: " + method.getName(), e);
                }
            }
        }
        return result;
    }

    private boolean sourceClassUndefined() {
        return parametersAnnotation.source().isAssignableFrom(NullType.class);
    }

    private Object[] paramsFromMethod() {
        String methodAnnotation = parametersAnnotation.method();

        if ("".equals(methodAnnotation))
            return invokeMethodWithParams(defaultMethodName());

        List<Object> result = new ArrayList<Object>();
        for (String methodName : methodAnnotation.split(",")) {
            for (Object param : invokeMethodWithParams(methodName.trim()))
                result.add(param);
        }

        return result.toArray();

    }

    private Object[] invokeMethodWithParams(String methodName) {
        Class<?> testClass = method.testClass();

        Method provideMethod = findParamsProvidingMethodInTestclassHierarchy(methodName, testClass);

        return invokeParamsProvidingMethod(testClass, provideMethod);
    }

    private Object[] invokeParamsProvidingMethod(Class<?> testClass, Method provideMethod) {
        try {
            Object testObject = testClass.newInstance();
            provideMethod.setAccessible(true);
            Object[] params = (Object[]) provideMethod.invoke(testObject);
            return processParamsIfSingle(params);
        } catch (ClassCastException e) {
            throw new RuntimeException("The return type of: " + provideMethod.getName() + " defined in class " + testClass
                    + " is not Object[]. Fix it!", e);
        } catch (Exception e) {
            throw new RuntimeException("Could not invoke method: " + provideMethod.getName() + " defined in class " + testClass
                    + " so no params were used.", e);
        }
    }

    private Object[] processParamsIfSingle(Object[] params) {
        if (method.frameworkMethod.getMethod().getParameterTypes().length != params.length)
            return params;

        if (params.length == 0)
            return params;

        Object param = params[0];
        if (param == null || !param.getClass().isArray())
            return new Object[] { params };

        return params;
    }

    private Method findParamsProvidingMethodInTestclassHierarchy(String methodName, Class<?> testClass) {
        Method provideMethod = null;
        Class<?> declaringClass = testClass;
        while (declaringClass.getSuperclass() != null) {
            try {
                provideMethod = declaringClass.getDeclaredMethod(methodName);
                break;
            } catch (Exception e) {
            }
            declaringClass = declaringClass.getSuperclass();
        }
        if (provideMethod == null)
            throw new RuntimeException("Could not find method: " + methodName + " so no params were used.");
        return provideMethod;
    }

    private String defaultMethodName() {
        String methodName;
        methodName = "parametersFor" + method.frameworkMethod.getName().substring(0, 1).toUpperCase()
                + method.frameworkMethod.getName().substring(1);
        return methodName;
    }

    Object currentParamsFromAnnotation() {
        return paramsFromAnnotation()[nextCount()];
    }

    void runTestMethod(Statement methodInvoker, RunNotifier notifier) {
        Description methodDescription = describeMethod();
        Description methodWithParams = findChildForParams(methodInvoker, methodDescription);

        notifier.fireTestStarted(methodWithParams);
        runMethodInvoker(notifier, methodDescription, methodInvoker, methodWithParams);
        notifier.fireTestFinished(methodWithParams);
    }

    Description describeMethod() {
        Object[] params = paramsFromAnnotation();
        Description parametrised = Description.createSuiteDescription(method.name());
        for (int i = 0; i < params.length; i++) {
            Object paramSet = params[i];
            parametrised.addChild(Description.createTestDescription(method.testClass(),
                    Utils.stringify(paramSet, i) + " (" + method.name() + ")", method.frameworkMethod.getAnnotations()));
        }
        return parametrised;
    }

    private void runMethodInvoker(RunNotifier notifier, Description description, Statement methodInvoker, Description methodWithParams) {
        try {
            methodInvoker.evaluate();
        } catch (Throwable e) {
            notifier.fireTestFailure(new Failure(methodWithParams, e));
        }
    }

    private Description findChildForParams(Statement methodInvoker, Description methodDescription) {
        for (Description child : methodDescription.getChildren()) {
            InvokeParameterisedMethod parameterisedInvoker = findParameterisedMethodInvokerInChain(methodInvoker);

            if (child.getMethodName().startsWith(parameterisedInvoker.getParamsAsString()))
                return child;
        }
        return null;
    }

    private InvokeParameterisedMethod findParameterisedMethodInvokerInChain(Statement methodInvoker) {
        while (methodInvoker != null && !(methodInvoker instanceof InvokeParameterisedMethod))
            methodInvoker = nextChainedInvoker(methodInvoker);

        if (methodInvoker == null)
            throw new RuntimeException("Cannot find invoker for the parameterised method. Using wrong JUnit version?");

        return (InvokeParameterisedMethod) methodInvoker;
    }

    private Statement nextChainedInvoker(Statement methodInvoker) {
        Field[] declaredFields = methodInvoker.getClass().getDeclaredFields();

        for (Field field : declaredFields) {
            Statement statement = statementOrNull(methodInvoker, field);
            if (statement != null)
                return statement;
        }

        return null;
    }

    private Statement statementOrNull(Statement methodInvoker, Field field) {
        if (field.getType().isAssignableFrom(Statement.class))
            return getOriginalStatement(methodInvoker, field);

        return null;
    }

    private Statement getOriginalStatement(Statement methodInvoker, Field field) {
        field.setAccessible(true);
        try {
            return (Statement) field.get(methodInvoker);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }
}
