/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 */
package org.mule.service.http.test.netty.utils.server;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.stream.Collectors.toList;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

import org.junit.Rule;
import org.junit.rules.ExternalResource;

/**
 * A TCP server exposed as a JUnit {@link Rule} that answers every chunk of data read from a client connection with the
 * same hardcoded response.
 * <p>
 * The server socket is bound in {@link #before()}; each accepted connection is served by its own thread. The acceptor,
 * the handler threads and their sockets are closed in {@link #after()}.
 */
public class HardcodedResponseTcpServer extends ExternalResource {

  private final int port;
  private AcceptorThread acceptorThread;

  // Default response is a HTTP 200 OK, because it was created as an HTTP test utility.
  // Both settings may be changed from the test thread after before() started the acceptor thread, which reads them in
  // onAccepted(), hence volatile.
  private volatile String response = "HTTP/1.1 200 OK\ncontent-length: 0\n\n";
  private volatile boolean isCloseOutputAfterResponse = false;

  // Incremented on accept, decremented when the connection's handler thread finishes.
  private final AtomicInteger acceptedCount = new AtomicInteger(0);

  // Only mutated by the acceptor thread; after() iterates it only once that thread has been joined.
  private final List<RequestHandlerThread> requestHandlerThreads = new ArrayList<>();

  // One buffer per accepted connection. StringBuffer (synchronized) because a handler thread appends to it while the
  // test thread may be reading it through getReceivedRawRequests().
  private final ConcurrentLinkedQueue<StringBuffer> receivedRawRequests = new ConcurrentLinkedQueue<>();

  /**
   * @param port the local port the server socket is bound to in {@link #before()}.
   */
  public HardcodedResponseTcpServer(int port) {
    this.port = port;
  }

  /**
   * Sets the raw response written back to clients. Connections already accepted keep the value they were created with;
   * only connections accepted after this call use the new one.
   *
   * @param response the raw response, encoded as UTF-8 when written to the socket.
   */
  public void setResponse(String response) {
    this.response = response;
  }

  /**
   * @param closeOutputAfterResponse whether a connection is closed right after its first response is written. Note that
   *                                 closing a socket's output stream closes the whole socket. Only affects connections
   *                                 accepted after this call.
   */
  public void setCloseOutputAfterResponse(boolean closeOutputAfterResponse) {
    this.isCloseOutputAfterResponse = closeOutputAfterResponse;
  }

  @Override
  protected void before() throws Throwable {
    ServerSocket serverSocket = new ServerSocket(port);
    acceptorThread = new AcceptorThread(serverSocket, this::onAccepted);
    acceptorThread.start();
  }

  // Runs on the acceptor thread: spawns one handler per connection, capturing the current response/close settings.
  private void onAccepted(Socket socket) {
    acceptedCount.incrementAndGet();

    RequestHandlerThread handlerThread =
        new RequestHandlerThread(socket, response, acceptedCount::decrementAndGet, receivedRawRequests,
                                 isCloseOutputAfterResponse);
    handlerThread.start();
    requestHandlerThreads.add(handlerThread);
  }

  /**
   * @return the number of accepted connections whose handler thread has not finished yet (each handler decrements the
   *         counter when it exits).
   */
  public int acceptedCount() {
    return acceptedCount.get();
  }

  /**
   * @return a snapshot of the raw data received so far, one entry per accepted connection, in acceptance order.
   */
  public List<String> getReceivedRawRequests() {
    return receivedRawRequests.stream().map(StringBuffer::toString).collect(toList());
  }

  @Override
  protected void after() {
    try {
      acceptorThread.close();
      // Joining the acceptor guarantees no more handlers are added to requestHandlerThreads while we iterate it.
      acceptorThread.join();
      for (RequestHandlerThread requestHandlerThread : requestHandlerThreads) {
        requestHandlerThread.close();
        requestHandlerThread.join();
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    } catch (InterruptedException e) {
      // Restore the interrupt flag so callers up the stack can still observe the interruption.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    }
  }

  private static class AcceptorThread extends Thread {

    private final ServerSocket serverSocket;
    private final Consumer<Socket> onAccept;
    // Written by the thread calling close() and read by this thread's loop, hence volatile.
    private volatile boolean isClosed;

    public AcceptorThread(ServerSocket serverSocket, Consumer<Socket> onAccept) {
      this.serverSocket = serverSocket;
      this.onAccept = onAccept;
    }

    @Override
    public void run() {
      try {
        while (!isClosed) {
          Socket accepted = serverSocket.accept();
          onAccept.accept(accepted);
        }
      } catch (IOException e) {
        // accept() fails once close() has closed the server socket; that is the normal way this loop ends.
        if (!isClosed) {
          throw new RuntimeException(e);
        }
      }
    }

    public void close() throws IOException {
      isClosed = true;
      serverSocket.close();
    }
  }

  private static class RequestHandlerThread extends Thread {

    private static final int READ_BUFFER_SIZE = 2048;

    private final Socket socket;
    private final String response;
    private final Runnable onClosed;
    private final StringBuffer receivedRawRequest;
    private final boolean isCloseOutputAfterResponse;

    private RequestHandlerThread(Socket socket, String response, Runnable onClosed,
                                 ConcurrentLinkedQueue<StringBuffer> receivedRawRequests,
                                 boolean isCloseOutputAfterResponse) {
      this.socket = socket;
      this.response = response;
      this.onClosed = onClosed;
      this.isCloseOutputAfterResponse = isCloseOutputAfterResponse;
      this.receivedRawRequest = new StringBuffer();
      receivedRawRequests.add(receivedRawRequest);
    }

    @Override
    public void run() {
      boolean responseIsSent = false;
      try {
        InputStream in = socket.getInputStream();
        OutputStream out = socket.getOutputStream();
        byte[] responseBytes = response.getBytes(UTF_8);
        byte[] buf = new byte[READ_BUFFER_SIZE];

        int read;
        // We don't want to start writing before reading something. Every successful read, whatever its size, is
        // answered with a full copy of the response; request boundaries are not parsed.
        while ((read = in.read(buf)) != -1) {
          // Only the first 'read' bytes of the buffer were filled by this read.
          receivedRawRequest.append(new String(buf, 0, read, UTF_8));
          out.write(responseBytes);
          out.flush();
          responseIsSent = true;
          if (isCloseOutputAfterResponse) {
            // Closing a socket's output stream closes the socket itself, so there is nothing left to read.
            out.close();
            break;
          }
        }
      } catch (SocketException e) {
        // Once the response is sent, the connection being closed under us (e.g. by close()) is expected.
        if (!responseIsSent) {
          throw new RuntimeException(e);
        }
      } catch (IOException e) {
        throw new RuntimeException(e);
      } finally {
        onClosed.run();
      }
    }

    public void close() throws IOException {
      socket.close();
    }
  }
}
