/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.service.http.impl.service.client.async;

import static java.lang.Thread.currentThread;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
import static org.mockito.junit.MockitoJUnit.rule;

import org.mule.runtime.api.util.Reference;
import org.mule.tck.junit4.AbstractMuleTestCase;

import com.ning.http.client.AsyncHandler;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mock;
import org.mockito.junit.MockitoRule;

/**
 * Verifies that {@link PreservingClassLoaderAsyncHandler} invokes each delegate callback with the thread context classloader
 * (TCCL) that was current when the handler was created, even when the callback is triggered while a different classloader is
 * set as the TCCL.
 */
public class PreservingClassLoaderAsyncHandlerTestCase extends AbstractMuleTestCase {

  @Rule
  public MockitoRule rule = rule();

  @Mock
  private AsyncHandler<Integer> delegate;

  /** Classloader installed as the TCCL while the handler under test is being invoked. */
  @Mock
  private ClassLoader mockClassLoader;

  private PreservingClassLoaderAsyncHandler<Integer> asyncHandler;

  /** The TCCL captured right before the handler under test was created; the delegate is expected to observe it. */
  private ClassLoader classLoaderOnCreation;

  @Before
  public void setup() {
    classLoaderOnCreation = currentThread().getContextClassLoader();
    asyncHandler = new PreservingClassLoaderAsyncHandler<>(delegate);
  }

  @Test
  public void creationClassLoaderIsPreservedOnCompleted() throws Exception {
    Reference<ClassLoader> classLoaderOnCompleted = new Reference<>();
    when(delegate.onCompleted()).then(invocation -> {
      classLoaderOnCompleted.set(currentThread().getContextClassLoader());
      // The delegate is an AsyncHandler<Integer>: answer with an Integer, not an unrelated type that would only
      // fail with a ClassCastException once somebody reads the result.
      return 0;
    });

    invokeWithMockContextClassLoader(asyncHandler::onCompleted);

    assertThat(classLoaderOnCompleted.get(), is(classLoaderOnCreation));
  }

  @Test
  public void creationClassLoaderIsPreservedOnThrowable() throws Exception {
    Reference<ClassLoader> classLoaderOnThrowable = new Reference<>();
    doAnswer(invocation -> {
      classLoaderOnThrowable.set(currentThread().getContextClassLoader());
      // onThrowable is void, so the answer has no meaningful value to return.
      return null;
    }).when(delegate).onThrowable(any(Throwable.class));

    invokeWithMockContextClassLoader(() -> asyncHandler.onThrowable(new Throwable()));

    assertThat(classLoaderOnThrowable.get(), is(classLoaderOnCreation));
  }

  @Test
  public void creationClassLoaderIsPreservedOnBodyPartReceived() throws Exception {
    Reference<ClassLoader> classLoaderOnBodyPartReceived = new Reference<>();
    when(delegate.onBodyPartReceived(any())).then(invocation -> {
      classLoaderOnBodyPartReceived.set(currentThread().getContextClassLoader());
      return null;
    });

    invokeWithMockContextClassLoader(() -> asyncHandler.onBodyPartReceived(null));

    assertThat(classLoaderOnBodyPartReceived.get(), is(classLoaderOnCreation));
  }

  @Test
  public void creationClassLoaderIsPreservedOnStatusReceived() throws Exception {
    Reference<ClassLoader> classLoaderOnStatusReceived = new Reference<>();
    when(delegate.onStatusReceived(any())).then(invocation -> {
      classLoaderOnStatusReceived.set(currentThread().getContextClassLoader());
      return null;
    });

    invokeWithMockContextClassLoader(() -> asyncHandler.onStatusReceived(null));

    assertThat(classLoaderOnStatusReceived.get(), is(classLoaderOnCreation));
  }

  @Test
  public void creationClassLoaderIsPreservedOnHeadersReceived() throws Exception {
    Reference<ClassLoader> classLoaderOnHeadersReceived = new Reference<>();
    when(delegate.onHeadersReceived(any())).then(invocation -> {
      classLoaderOnHeadersReceived.set(currentThread().getContextClassLoader());
      return null;
    });

    invokeWithMockContextClassLoader(() -> asyncHandler.onHeadersReceived(null));

    assertThat(classLoaderOnHeadersReceived.get(), is(classLoaderOnCreation));
  }

  /**
   * Runs {@code handlerCall} with {@link #mockClassLoader} installed as the current thread's context classloader, i.e. with a
   * TCCL other than the one captured when the handler under test was created. The previous TCCL is restored afterwards, even
   * if the call throws, so one test cannot leak its classloader into the next.
   *
   * @param handlerCall the invocation on the handler under test
   * @throws Exception whatever {@code handlerCall} throws
   */
  private void invokeWithMockContextClassLoader(HandlerCall handlerCall) throws Exception {
    final var thread = currentThread();
    final var originalClassLoader = thread.getContextClassLoader();
    try {
      thread.setContextClassLoader(mockClassLoader);
      handlerCall.invoke();
    } finally {
      thread.setContextClassLoader(originalClassLoader);
    }
  }

  /** An invocation on the handler under test that may throw a checked exception. */
  @FunctionalInterface
  private interface HandlerCall {

    void invoke() throws Exception;
  }

}
