package com.alterioncorp.requestlogger;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.management.ManagementFactory;
import java.net.InetAddress;
import java.sql.Timestamp;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

import javax.management.JMException;
import javax.management.ObjectName;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RequestLoggerFilter implements Filter, Constants, RequestLoggerFilterMXBean {

	static final String SANITIZED_PARAM_VALUE = "******";
	static final int MAX_REQUEST_SIZE_TO_LOG = 1024 * 1024;
	
	private Persister persister;
	private boolean enabled;
	private long persistPeriodInMillis;
	private int queueMaxSize;    
	private Set<String> paramNamesToSanitize;
	private PropertyRegistry propertyRegistry;
    
	private final Logger log = LoggerFactory.getLogger(this.getClass().getName());
	private volatile List<Request> queue;
	private ScheduledExecutorService executorService;
	private final ThreadFactory threadFactory;
	
	public RequestLoggerFilter() {
		super();
		this.queue = new LinkedList<Request>();
		this.threadFactory = new ThreadFactoryImpl("request-logger", true);
	}
	
	protected RequestLoggerFilter(
			Persister requestLogger, PropertyRegistry propertyRegistry,
			boolean enabled, long persistPeriodInMillis, int queueMaxSize,
			Set<String> paramNamesToSanitize) {
		this();
		this.persister = requestLogger;
		this.propertyRegistry = propertyRegistry;
		this.enabled = enabled;
		this.persistPeriodInMillis = persistPeriodInMillis;
		this.queueMaxSize = queueMaxSize;
		this.paramNamesToSanitize = paramNamesToSanitize;
	}

	@Override
	public final synchronized void init(FilterConfig config) throws ServletException {
		
		try {
		
			enabled = Boolean.parseBoolean(System.getProperty(
					PROPERTY_NAME_ENABLED, PROPERTY_DEFAULT_ENABLED));
			
			persistPeriodInMillis = Long.parseLong(System.getProperty(
					PROPERTY_NAME_PERSIST_PERIOD_IN_MILLIS, PROPERTY_DEFAULT_PERSIST_PERIOD_IN_MILLIS));
			
			queueMaxSize = Integer.parseInt(System.getProperty(
					PROPERTY_NAME_QUEUE_MAX_SIZE, PROPERTY_DEFAULT_QUEUE_MAX_SIZE));
			
			String paramNamesToSanitizeValue = System.getProperty(
					PROPERTY_NAME_PARAMS_TO_SANITIZE, PROPERTY_DEFAULT_PARAMS_TO_SANITIZE);
			
			paramNamesToSanitize = new HashSet<String>(
					Arrays.asList(paramNamesToSanitizeValue.split(",")));
			
			propertyRegistry = PropertyRegistryImpl.getInstance();
			this.modifyPropertyRegistry(propertyRegistry);

			String loggerImplClassName = System.getProperty(
					PROPERTY_NAME_PERSISTER_IMPL, PROPERTY_DEFAULT_PERSISTER_IMPL);

			persister = (Persister)Class.forName(loggerImplClassName).newInstance();
			
			this.onInit();

			ManagementFactory.getPlatformMBeanServer().registerMBean(this, new ObjectName(MBEAN_NAME_FILTER));

			this.startScheduler();
		}
		catch (ClassNotFoundException e) {
			throw new ServletException(e);
		}
		catch (InstantiationException e) {
			throw new ServletException(e);
		}
		catch (IllegalAccessException e) {
			throw new ServletException(e);
		}
		catch (JMException e) {
			throw new ServletException(e);
		}

	}
	
	protected void onInit() {
		
	}
	
	@Override
	public final synchronized void destroy() {
		
		this.stopScheduler();
		
		try {
			ManagementFactory.getPlatformMBeanServer().unregisterMBean(new ObjectName(MBEAN_NAME_FILTER));
		}
		catch (JMException e) {
		}
		
		this.onDestroy();
	}
	
	protected void onDestroy() {
		
	}

	@Override
	public final void doFilter(ServletRequest request, ServletResponse response,
			FilterChain chain) throws IOException, ServletException {
        
    	HttpServletRequest httpServletRequest = (HttpServletRequest)request;
    	HttpServletResponse httpServletResponse = (HttpServletResponse)response;
    	CachedPayloadRequest cachedRequest = new CachedPayloadRequest(httpServletRequest);

    	long startTime = System.currentTimeMillis();
    	
    	// grab it now since it can be changed by a forward
    	final String path = httpServletRequest.getRequestURI();
    	
    	Exception error = null;
    	
        try {
        	
        	this.beforeFilter(httpServletRequest, httpServletResponse);
        	
        	this.log(false, httpServletRequest, httpServletResponse, path);
       		chain.doFilter(cachedRequest, response);
        	this.log(true, httpServletRequest, httpServletResponse, path);
        }
        catch (IOException e) {
        	error = e;
        	throw e;
        }
        catch (ServletException e) {
        	error = e;
        	throw e;
        }
        catch (RuntimeException e) {
        	error = e;
        	throw e;
        }
        finally {

        	if (enabled) {
        		
        		int responseStatus = httpServletResponse.getStatus();
	        	long duration = System.currentTimeMillis() - startTime;
	
	            Request requestEntity = new Request();            
	        	
	            requestEntity.getData().put(REQUEST_PROPERTY_REQUEST_IP, getRequestIP(httpServletRequest));
	            requestEntity.getData().put(REQUEST_PROPERTY_SERVER_IP, InetAddress.getLocalHost().getHostAddress());
	            requestEntity.getData().put(REQUEST_PROPERTY_SESSION_ID, httpServletRequest.getSession(false) == null ? null : httpServletRequest.getSession(false).getId());
	            requestEntity.getData().put(REQUEST_PROPERTY_START_TIME, new Timestamp(startTime));
	            requestEntity.getData().put(REQUEST_PROPERTY_METHOD, httpServletRequest.getMethod());
	            requestEntity.getData().put(REQUEST_PROPERTY_PATH, path);
	            requestEntity.getData().put(REQUEST_PROPERTY_REQUEST_CONTENT_TYPE, httpServletRequest.getContentType());
	            requestEntity.getData().put(REQUEST_PROPERTY_REQUEST_BODY, this.requestBodyToString(cachedRequest));
	            requestEntity.getData().put(REQUEST_PROPERTY_DURATION, duration);
	            requestEntity.getData().put(REQUEST_PROPERTY_RESPONSE_STATUS, responseStatus);
	            
	            if (error != null) {
	            	StringWriter stringWriter = new StringWriter();
	            	error.printStackTrace(new PrintWriter(stringWriter, true));
	            	requestEntity.getData().put(REQUEST_PROPERTY_ERROR, stringWriter.toString());
	            }
	            
	            this.modifyRequestLogData(requestEntity, httpServletRequest, httpServletResponse);

	            synchronized(this) {
	        	
	            	if (queue.size() >= queueMaxSize) {
	        			throw new IllegalStateException("queue is too large");
	        		}
	
		            queue.add(requestEntity);
	            }
        	}
        	
        	this.afterFilter(httpServletRequest, httpServletResponse);
        }
    }
	
	protected void beforeFilter(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
		
	}

	protected void afterFilter(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
		
	}

	void saveRequests() {
		List<Request> requestsToSave = null;
		if (! queue.isEmpty()) {
			synchronized(this) {
				if (! queue.isEmpty()) {
					requestsToSave = queue;
					queue = new LinkedList<Request>();
				}
			}
		}
		if (requestsToSave != null) {
			persister.saveRequests(requestsToSave, propertyRegistry);
		}
	}
	
	private void startScheduler() {
		
		executorService = Executors.newSingleThreadScheduledExecutor(threadFactory);
		
		executorService.scheduleWithFixedDelay(
				new Runnable() {
					@Override
					public void run() {
						try {
							saveRequests();
						}
						catch (Exception e) {
							log.error("Error persisting requests", e);
						}
						catch (Error e) {
							log.error("Error persisting requests", e);
							throw e;
						}
					}
				},
				persistPeriodInMillis, persistPeriodInMillis, TimeUnit.MILLISECONDS);

	}
	
	private void stopScheduler() {
		if (executorService != null) {
			executorService.shutdown();
			try {
				executorService.awaitTermination(10, TimeUnit.SECONDS);
			}
			catch (InterruptedException e) {
			}
		}
	}
	
	private String requestBodyToString(CachedPayloadRequest httpServletRequest) throws IOException {
		
		httpServletRequest.resetInputStream();

		String requestBodyAsText = null;
        String queryString = this.paramsToString(httpServletRequest);

        if (queryString != null && queryString.trim().length() > 0) {
        	requestBodyAsText = queryString;
        }
        else {
    		if (httpServletRequest.getContentLength() <= MAX_REQUEST_SIZE_TO_LOG) {
    			if (
    					httpServletRequest.getContentType() != null &&
    					! httpServletRequest.getContentType().equals("application/octet-stream") &&
    					! httpServletRequest.getContentType().contains("multipart")
    			) {
    				requestBodyAsText = IOUtils.toString(httpServletRequest.getReader());
    			}
    		}
    	}

    	return requestBodyAsText;
	}
		
	private String paramsToString(HttpServletRequest httpServletRequest) {

		StringBuilder result = new StringBuilder();
		
		for (String paramName : httpServletRequest.getParameterMap().keySet()) {
			for (String paramValue : httpServletRequest.getParameterMap().get(paramName)) {
				if (result.length() > 0) {
					result.append("&");
				}
				result.append(paramName);
				result.append("=");
				if (paramNamesToSanitize.contains(paramName)) {
					result.append(SANITIZED_PARAM_VALUE);
				}
				else {
					result.append(paramValue);
				}
			}
		}
		
		return result.toString();
	}
	
	static final String getRequestIP(HttpServletRequest request) {
		
		String requestIP = null;
		
		// X-Forwarded-For header set by a proxy
		Enumeration<String> forwardIPs = request.getHeaders("X-Forwarded-For");
		if (forwardIPs != null && forwardIPs.hasMoreElements()) {
			
			// get the first one
			requestIP = forwardIPs.nextElement();

			// if it's multiple addresses separated by a comma
			int separatorPosition = requestIP.indexOf(',');
			if (separatorPosition > 0) {
				// get the first one
				requestIP = requestIP.substring(0, separatorPosition);
			}
		}
		
		// if no X-Forwarded-For header
		if (requestIP == null) {
			requestIP = request.getRemoteAddr();
		}
		
		return requestIP;
	}
	
	private void log(boolean end, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, final String path) throws IOException {
		if (log.isInfoEnabled()) {
			StringBuilder message = new StringBuilder();
			message.append(end ? "END" : "START");
			message.append("\t");
			message.append(getRequestIP(httpServletRequest));
			message.append("\t");
			message.append(end ? httpServletResponse.getStatus() : "");
			message.append("\t");
			message.append(httpServletRequest.getMethod());
			message.append("\t");
			message.append(path);
			if (httpServletRequest.getQueryString() != null && httpServletRequest.getQueryString().trim().length() > 0) {
				message.append("?");
				message.append(httpServletRequest.getQueryString());				
			}
			log.info(message.toString());
		}
	}
	
	protected void modifyPropertyRegistry(PropertyRegistry dataTypeRegistry) {
	}
	
	protected void modifyRequestLogData(Request requestEntity,
			HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {		
	}

	@Override
	public final boolean isEnabled() {
		return enabled;
	}

	@Override
	public final void setEnabled(boolean enabled) {
		this.enabled = enabled;
	}

	@Override
	public final long getPeriodInMillis() {
		return persistPeriodInMillis;
	}

	@Override
	public final synchronized void setPeriodInMillis(long periodInMillis) {
		if (this.persistPeriodInMillis != periodInMillis) {
			this.persistPeriodInMillis = periodInMillis;
			if (executorService != null && ! executorService.isShutdown()) {
				this.stopScheduler();
				this.startScheduler();
			}
		}
	}

	@Override
	public final int getQueueMaxSize() {
		return queueMaxSize;
	}

	@Override
	public final void setQueueMaxSize(int queueMaxSize) {
		this.queueMaxSize = queueMaxSize;
	}

	@Override
	public final int getQueueSize() {
		return queue.size();
	}

	final Set<String> getParamNamesToSanitize() {
		return paramNamesToSanitize;
	}

	final Persister getPersister() {
		return persister;
	}

	final List<Request> getQueue() {
		return queue;
	}

	final ScheduledExecutorService getExecutorService() {
		return executorService;
	}
}
