/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.service.http.impl.functional.client;

import static java.net.URI.create;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.nio.file.Files.newInputStream;
import static java.nio.file.Paths.get;

import static org.apache.commons.io.FileUtils.readFileToString;

import org.mule.service.http.impl.service.HttpServiceImplementation;
import org.mule.service.http.impl.service.client.GrizzlyHttpClient;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Path;
import java.security.CodeSource;

import org.junit.After;
import org.junit.Before;

/**
 * Runs the shared large-payload request streaming test case against {@link HttpServiceImplementation},
 * with the {@code requestStreamingEnabled} flag of {@link GrizzlyHttpClient} switched on for the duration
 * of each test and restored to its previous value afterwards.
 */
public class RequestStreamingLargePayloadTestCase
    extends org.mule.service.http.test.common.client.RequestStreamingLargePayloadTestCase {

  /** Name of the payload file, resolved relative to the code source location of this class. */
  private static final String LARGE_PAYLOAD_FILE_NAME = "largePayload";

  /** Value of the streaming flag observed in {@link #before()}, restored in {@link #after()}. */
  private boolean previousRequestStreaming;

  public RequestStreamingLargePayloadTestCase() {
    super(HttpServiceImplementation.class.getName());
  }

  /**
   * Remembers the current streaming flag and forces request streaming on.
   *
   * @throws Exception if the flag cannot be read or written reflectively
   */
  @Before
  public void before() throws Exception {
    previousRequestStreaming = isRequestStreaming();
    setRequestStreaming(true);
  }

  /**
   * Restores the streaming flag to the value it had before the test, rather than assuming it was
   * {@code false}, so this test does not leak state into other tests.
   *
   * @throws Exception if the flag cannot be written reflectively
   */
  @After
  public void after() throws Exception {
    setRequestStreaming(previousRequestStreaming);
  }

  /**
   * Opens a fresh stream over the large payload file to be sent as the request body.
   *
   * @return a new input stream over the payload file; the caller is responsible for closing it
   * @throws AssertionError if the file cannot be located or opened (the cause is attached)
   */
  @Override
  protected InputStream getInputStream() {
    try {
      return newInputStream(largePayloadPath());
    } catch (IOException | URISyntaxException e) {
      throw new AssertionError("Error on loading the large payload file", e);
    }
  }

  /**
   * Reads the same large payload file as {@link #getInputStream()} and decodes it as UTF-8.
   *
   * @return the full payload contents
   * @throws AssertionError if the file cannot be located or read (the cause is attached)
   */
  @Override
  protected String expectedPayload() {
    try {
      return readFileToString(largePayloadPath().toFile(), UTF_8);
    } catch (IOException | URISyntaxException e) {
      throw new AssertionError("Error on loading the large payload file", e);
    }
  }

  /**
   * Sets the {@code requestStreamingEnabled} static field of {@link GrizzlyHttpClient} via reflection.
   *
   * @param requestStreaming the value to store in the field
   * @throws Exception if the field does not exist or cannot be written
   */
  public static void setRequestStreaming(boolean requestStreaming) throws Exception {
    requestStreamingEnabledField().setBoolean(null, requestStreaming);
  }

  /** Reads the current value of the {@code requestStreamingEnabled} static field of {@link GrizzlyHttpClient}. */
  private static boolean isRequestStreaming() throws Exception {
    return requestStreamingEnabledField().getBoolean(null);
  }

  /** Looks up the {@code requestStreamingEnabled} field of {@link GrizzlyHttpClient} and makes it accessible. */
  private static Field requestStreamingEnabledField() throws NoSuchFieldException {
    Field field = GrizzlyHttpClient.class.getDeclaredField("requestStreamingEnabled");
    field.setAccessible(true);
    return field;
  }

  /**
   * Resolves the large payload file against this class's code source location.
   * The conversion goes through {@link URL#toURI()} rather than {@link URL#getPath()} so that
   * percent-encoded characters (e.g. {@code %20} for a space) are decoded into a usable path.
   *
   * @return the path of the payload file
   * @throws URISyntaxException if the code source location is not a valid URI
   * @throws AssertionError if this class has no code source
   */
  private static Path largePayloadPath() throws URISyntaxException {
    URL root = getClassPathRoot(RequestStreamingLargePayloadTestCase.class);
    if (root == null) {
      throw new AssertionError("No code source available for "
          + RequestStreamingLargePayloadTestCase.class.getName());
    }
    return get(root.toURI()).resolve(LARGE_PAYLOAD_FILE_NAME);
  }

  /**
   * Returns the location of the {@link CodeSource} in {@code clazz}'s protection domain.
   * Adapted from a snippet by David Flanagan (http://www.davidflanagan.com/blog/2005_06.html#000060).
   *
   * @param clazz the class whose code source location is wanted
   * @return the code source location, or {@code null} if the protection domain has no code source
   */
  public static URL getClassPathRoot(Class<?> clazz) {
    CodeSource cs = clazz.getProtectionDomain().getCodeSource();
    return cs != null ? cs.getLocation() : null;
  }

}
