package com.openfin.desktop;

import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Fixed size thread pool that logs a warning when a submitted task has not been picked up by a pool
 * thread within a configurable timeout, which usually means every thread in the pool is busy.
 * <p>
 * Configuration is read from system properties once, when the pool is constructed:
 * <ul>
 * <li>{@code com.openfin.desktop.threadpool.pendingtask.timeout} - milliseconds a task may wait before
 * the warning is logged (default {@code 10000}; a malformed or negative value falls back to the
 * default).</li>
 * <li>{@code com.openfin.desktop.threadpool.logstacktrace} - when {@code true}, the warning also logs
 * where the waiting task and every currently running task were submitted from (default
 * {@code false}).</li>
 * </ul>
 * <p>
 * Tasks are tracked by {@link Runnable} instance: submitting the same instance again while an earlier
 * submission of it is still waiting or running is not distinguished and may produce a spurious warning.
 * @author Anthony
 *
 */
public class OpenFinThreadPool extends ThreadPoolExecutor {

	private final static Logger logger = LoggerFactory.getLogger(OpenFinThreadPool.class);

	private static final String PENDING_TASK_TIMEOUT_PROPERTY = "com.openfin.desktop.threadpool.pendingtask.timeout";
	private static final String LOG_STACK_TRACE_PROPERTY = "com.openfin.desktop.threadpool.logstacktrace";
	private static final int DEFAULT_PENDING_TASK_TIMEOUT = 10000;
	private static final StackTraceElement[] NO_STACK_TRACE = new StackTraceElement[0];

	/** Warning timers of tasks that were submitted but have not started yet, keyed by task. */
	private final ConcurrentHashMap<Runnable, ThreadPoolTimerTask> pendingRunnableTimerTaskMap;
	/** Submission stack traces of tasks currently running on a pool thread, keyed by task. */
	private final ConcurrentHashMap<Runnable, String> executingRunnableStackTraceMap;
	private final Timer executionTimer;
	private final int pendingTaskTimeout;
	private final boolean logStackTrace;
	private final String threadPoolName;

	/**
	 * Creates a fixed size thread pool.
	 * @param name name of the thread pool, also used as prefix for its thread names
	 * @param nThreads number of threads in the pool.
	 */
	public OpenFinThreadPool(String name, int nThreads) {
		super(nThreads, nThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(), new ThreadFactory() {
			AtomicInteger threadId = new AtomicInteger(0);

			@Override
			public Thread newThread(Runnable r) {
				return new Thread(r, name + "-" + threadId.getAndIncrement());
			}
		});
		this.threadPoolName = name;
		this.pendingRunnableTimerTaskMap = new ConcurrentHashMap<>();
		this.executingRunnableStackTraceMap = new ConcurrentHashMap<>();
		// Daemon: this timer only produces diagnostics and must never keep the JVM alive on its own.
		this.executionTimer = new Timer(name + "ThreadPoolReportTimer", true);
		this.pendingTaskTimeout = readPendingTaskTimeout();
		this.logStackTrace = Boolean
				.parseBoolean(java.lang.System.getProperty(LOG_STACK_TRACE_PROPERTY, "false"));
	}

	/**
	 * Reads the pending-task timeout system property.
	 * Falls back to the default instead of failing construction (malformed value) or making every
	 * {@link Timer#schedule(TimerTask, long)} call throw (negative value).
	 */
	private static int readPendingTaskTimeout() {
		String value = java.lang.System.getProperty(PENDING_TASK_TIMEOUT_PROPERTY,
				Integer.toString(DEFAULT_PENDING_TASK_TIMEOUT));
		try {
			int timeout = Integer.parseInt(value.trim());
			if (timeout >= 0) {
				return timeout;
			}
		}
		catch (NumberFormatException ex) {
			// handled below by falling back to the default
		}
		logger.warn("invalid value '{}' for {}, using default {}ms", value, PENDING_TASK_TIMEOUT_PROPERTY,
				DEFAULT_PENDING_TASK_TIMEOUT);
		return DEFAULT_PENDING_TASK_TIMEOUT;
	}

	@Override
	public void execute(Runnable task) {
		// A null task is rejected with NullPointerException by super.execute, and after shutdown every
		// task is rejected (and the timer may already be cancelled), so there is nothing to monitor.
		ThreadPoolTimerTask tt = (task == null || isShutdown()) ? null : schedulePendingWarning(task);
		try {
			super.execute(task);
		}
		catch (RuntimeException ex) {
			// A rejected task never reaches beforeExecute: cancel its warning so it does not fire
			// spuriously, and drop its entry so it does not leak.
			if (tt != null) {
				tt.cancel();
				this.pendingRunnableTimerTaskMap.remove(task, tt);
			}
			throw ex;
		}
	}

	/**
	 * Schedules the "unable to get available thread" warning for {@code task} and records it as pending.
	 * @param task the task being submitted, not null
	 * @return the scheduled timer task, or {@code null} if monitoring could not be set up
	 */
	private ThreadPoolTimerTask schedulePendingWarning(Runnable task) {
		try {
			// The submission stack trace is only ever logged when logStackTrace is set; skip the
			// comparatively expensive capture otherwise.
			StackTraceElement[] submittedAt = logStackTrace ? new Exception().getStackTrace() : NO_STACK_TRACE;
			// timer task to detect if it takes longer to get executed.
			ThreadPoolTimerTask tt = new ThreadPoolTimerTask(task, submittedAt) {
				@Override
				public void run() {
					try {
						logger.warn("{}: unable to get available thread after {}ms, thread pool size: {}, running thread count: {}, the threads in the thread pool might have been exhausted.", 
								threadPoolName, pendingTaskTimeout, getCorePoolSize(), executingRunnableStackTraceMap.size());
						if (logStackTrace) {
							logger.warn("waiting task was submitted at: {}", stackTrace);
							for (String runningStackTrace : executingRunnableStackTraceMap.values()) {
								logger.warn("running task was submitted at: {}", runningStackTrace);
							}
						}
					}
					catch (RuntimeException ex) {
						// An exception escaping a TimerTask kills the Timer thread and makes every later
						// schedule() fail, so it must be contained here.
						logger.error("error reporting pending task", ex);
					}
				}
			};
			this.executionTimer.schedule(tt, pendingTaskTimeout);
			this.pendingRunnableTimerTaskMap.put(task, tt);
			return tt;
		}
		catch (Exception ex) {
			logger.error("error execute", ex);
			return null;
		}
	}

	@Override
	protected void beforeExecute(Thread thread, Runnable task) {
		try {
			// remove (not get): once the task has a thread its pending entry is obsolete, and keeping
			// it would leak one timer task per submission.
			ThreadPoolTimerTask waitForExecutionTimerTask = this.pendingRunnableTimerTaskMap.remove(task);
			if (waitForExecutionTimerTask != null) {
				waitForExecutionTimerTask.cancel();
				this.executingRunnableStackTraceMap.put(task, waitForExecutionTimerTask.stackTrace);
			}
		}
		catch (Exception ex) {
			logger.error("error beforeExecute", ex);
		}
		finally {
			super.beforeExecute(thread, task);
		}
	}

	@Override
	protected void afterExecute(Runnable task, Throwable t) {
		try {
			this.executingRunnableStackTraceMap.remove(task);
		}
		catch (Exception ex) {
			logger.error("error afterExecute", ex);
		}
		finally {
			super.afterExecute(task, t);
		}
	}

	/**
	 * Removes {@code task} from the work queue and, if it was removed, cancels its pending warning,
	 * since a removed task will never reach {@link #beforeExecute(Thread, Runnable)}.
	 */
	@Override
	public boolean remove(Runnable task) {
		boolean removed = super.remove(task);
		if (removed) {
			ThreadPoolTimerTask tt = this.pendingRunnableTimerTaskMap.remove(task);
			if (tt != null) {
				tt.cancel();
			}
		}
		return removed;
	}

	/**
	 * Stops the report timer thread and releases all tracking state once the pool has terminated;
	 * no task can start after this point, so no pending warning is still meaningful.
	 */
	@Override
	protected void terminated() {
		try {
			this.executionTimer.cancel();
			this.pendingRunnableTimerTaskMap.clear();
			this.executingRunnableStackTraceMap.clear();
		}
		finally {
			super.terminated();
		}
	}

	/**
	 * Timer task tied to one submitted task; keeps the submission stack trace pre-formatted as text.
	 */
	abstract class ThreadPoolTimerTask extends TimerTask {
		Runnable task;
		String stackTrace;

		ThreadPoolTimerTask(Runnable task, StackTraceElement[] stackTrace) {
			this.task = task;
			StringBuilder sb = new StringBuilder("Stack trace").append(java.lang.System.lineSeparator());
			for (StackTraceElement traceElement : stackTrace) {
				sb.append("\tat ").append(traceElement).append(java.lang.System.lineSeparator());
			}
			this.stackTrace = sb.toString();
		}
	}
}
