package ai.apiverse.apisuite.mirror.agent;

import ai.apiverse.apisuite.mirror.agent.buffer.ApiBufferKey;
import ai.apiverse.apisuite.mirror.agent.buffer.BufferManagerWorker;
import ai.apiverse.apisuite.mirror.agent.buffer.DiscoveredApiBufferManager;
import ai.apiverse.apisuite.mirror.agent.buffer.RegisteredApiBufferManager;
import ai.apiverse.apisuite.mirror.models.data.APISample;
import lombok.RequiredArgsConstructor;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Runs the servlet filter chain for a monitored API and hands an {@link APISample} describing the
 * exchange to the matching buffer worker (registered or discovered APIs).
 *
 * <p>Sample collection is best-effort: failures while building, offering, or flushing a sample are
 * logged through {@link SDKLogger} and are not propagated. Exceptions thrown by the downstream
 * filter chain itself are always rethrown unchanged.
 *
 * <p>The constructor is generated by Lombok from the {@code final} fields, in declaration order.
 */
@RequiredArgsConstructor
public class ApimonitorApiProcessor {
    private final RegisteredApiBufferManager registeredApiBufferManager;
    private final DiscoveredApiBufferManager discoveredApiBufferManager;
    /** Header names passed to {@code MaskingUtils.getReplacedHeaders} for request and response headers. */
    private final List<String> maskHeaders;
    private final SDKLogger logger;

    /** Returns {@code request} as a caching wrapper, reusing it if it is already one (avoids double-wrapping). */
    private static ContentCachingRequestWrapper wrapRequest(HttpServletRequest request) {
        if (request instanceof ContentCachingRequestWrapper) {
            return (ContentCachingRequestWrapper) request;
        } else {
            return new ContentCachingRequestWrapper(request);
        }
    }

    /** Returns {@code response} as a caching wrapper, reusing it if it is already one (avoids double-wrapping). */
    private static ContentCachingResponseWrapper wrapResponse(HttpServletResponse response) {
        if (response instanceof ContentCachingResponseWrapper) {
            return (ContentCachingResponseWrapper) response;
        } else {
            return new ContentCachingResponseWrapper(response);
        }
    }

    /**
     * Runs the filter chain for a discovered (non-registered) API and, if the discovered-API worker
     * accepts this buffer key, offers a sample. Request/response are not wrapped for payload capture
     * here and latency is not measured.
     *
     * @param context     per-request state shared with {@code ApimonitorSpringFilter}
     * @param filterChain the remaining filter chain
     * @throws ServletException if the downstream chain throws it
     * @throws IOException      if the downstream chain throws it
     */
    public void processDiscoveredApi(ApimonitorSpringFilter.RequestResponseContext context, final FilterChain filterChain) throws ServletException, IOException {
        BufferManagerWorker<ApiBufferKey> worker = this.discoveredApiBufferManager.getWorker();
        if (null == worker) {
            logger.error("BufferManagerWorker is NULL inside DiscoveredApiProcessor");
            doFilter(filterChain, context);
            return;
        }

        boolean canOffer = worker.canOffer(context.getApiBufferKey());
        context.setPayloadCaptureAttempted(false);
        filterAndOffer(context, filterChain, worker, canOffer, false);
    }

    /**
     * Runs the filter chain for a registered API. When the registered-API worker accepts this buffer
     * key, the request and/or response are wrapped for payload capture (subject to the API config),
     * latency is measured, and a sample is offered.
     *
     * @param context     per-request state shared with {@code ApimonitorSpringFilter}
     * @param filterChain the remaining filter chain
     * @throws ServletException if the downstream chain throws it
     * @throws IOException      if the downstream chain throws it
     */
    public void processRegisteredApi(ApimonitorSpringFilter.RequestResponseContext context, final FilterChain filterChain) throws ServletException, IOException {
        BufferManagerWorker<ApiBufferKey> worker = this.registeredApiBufferManager.getWorker();
        if (null == worker) {
            logger.error("BufferManagerWorker is NULL inside RegisteredApiProcessor");
            doFilter(filterChain, context);
            return;
        }
        boolean canOffer = worker.canOffer(context.getApiBufferKey());
        context.setPayloadCaptureAttempted(true);
        if (canOffer) {
            preparePayloadCapture(context);
        }
        filterAndOffer(context, filterChain, worker, canOffer, true);
    }

    /** Wraps the request/response in caching wrappers as allowed by the API config and records what was attempted. */
    private void preparePayloadCapture(ApimonitorSpringFilter.RequestResponseContext context) {
        boolean requestPayloadCaptureAttempted = shouldCaptureRequest(context);
        if (requestPayloadCaptureAttempted) {
            context.setCachedRequest(wrapRequest(context.getServletRequest()));
        }
        boolean responsePayloadCaptureAttempted = shouldCaptureResponse(context);
        if (responsePayloadCaptureAttempted) {
            context.setCachedResponse(wrapResponse(context.getServletResponse()));
        }
        context.setRequestPayloadCaptureAttempted(requestPayloadCaptureAttempted);
        context.setResponsePayloadCaptureAttempted(responsePayloadCaptureAttempted);
    }

    /**
     * Invokes the chain and, when {@code canOffer} is true, offers a sample on both the success and
     * the exception path. The chain's exception is always rethrown.
     *
     * @param measureLatency record elapsed time (ms) on the context after a successful chain call
     */
    private void filterAndOffer(ApimonitorSpringFilter.RequestResponseContext context, FilterChain filterChain,
                                BufferManagerWorker<ApiBufferKey> worker, boolean canOffer, boolean measureLatency)
            throws ServletException, IOException {
        if (!canOffer) {
            doFilter(filterChain, context);
            return;
        }
        // nanoTime is monotonic; currentTimeMillis can jump with wall-clock adjustments.
        long startNanos = System.nanoTime();
        try {
            doFilter(filterChain, context);
        } catch (Exception exception) {
            tryOffering(context, exception, worker);
            // precise rethrow: only ServletException, IOException or unchecked exceptions reach here
            throw exception;
        }
        if (measureLatency) {
            context.setLatency((System.nanoTime() - startNanos) / 1_000_000L);
        }
        tryOffering(context, null, worker);
    }

    /**
     * Builds a sample and offers it to {@code worker}. On the success path ({@code exception == null})
     * the cached response body is always copied to the client, even if building the sample fails.
     * Otherwise the client would receive an empty body. On the exception path the cached body is not
     * copied and the caller rethrows the exception.
     */
    private void tryOffering(ApimonitorSpringFilter.RequestResponseContext context, Exception exception, BufferManagerWorker<ApiBufferKey> worker) {
        APISample apiSample = null;
        try {
            apiSample = getBufferEntryForApiSample(context, exception);
        } catch (Exception e) {
            logger.error("Error building API sample", e);
        } finally {
            if (null == exception) {
                flushCachedResponse(context);
            }
        }
        if (null == apiSample) {
            return;
        }
        try {
            worker.offer(context.getApiBufferKey(), apiSample);
        } catch (Exception e) {
            logger.error("Error offering API sample to buffer", e);
        }
    }

    /** Copies any buffered response body to the underlying response; failures are logged, not propagated. */
    private void flushCachedResponse(ApimonitorSpringFilter.RequestResponseContext context) {
        ContentCachingResponseWrapper cachedResponse = context.getCachedResponse();
        if (null == cachedResponse) {
            return;
        }
        try {
            cachedResponse.copyBodyToResponse();
        } catch (IOException e) {
            logger.error("Error copying cached response body to client", e);
        }
    }

    /** Invokes the chain with the caching wrappers when present, otherwise with the original request/response. */
    private void doFilter(final FilterChain filterChain, ApimonitorSpringFilter.RequestResponseContext context) throws ServletException, IOException {
        final HttpServletRequest servletRequest = null != context.getCachedRequest() ? context.getCachedRequest() : context.getServletRequest();
        final HttpServletResponse servletResponse = null != context.getCachedResponse() ? context.getCachedResponse() : context.getServletResponse();
        filterChain.doFilter(servletRequest, servletResponse);
    }

    /** False only when an API config exists and explicitly sets captureSampleRequest to false. */
    private boolean shouldCaptureRequest(ApimonitorSpringFilter.RequestResponseContext context) {
        return null == context.getApiConfig()
                || !Boolean.FALSE.equals(context.getApiConfig().getCaptureSampleRequest());
    }

    /** False only when an API config exists and explicitly sets captureSampleResponse to false. */
    private boolean shouldCaptureResponse(ApimonitorSpringFilter.RequestResponseContext context) {
        return null == context.getApiConfig()
                || !Boolean.FALSE.equals(context.getApiConfig().getCaptureSampleResponse());
    }

    /**
     * Builds a sample from the context. Errors while populating fields are logged and the partially
     * populated sample is still returned. This method does not copy the cached response body; see
     * {@code tryOffering}.
     *
     * @param exception the uncaught exception from the chain, or {@code null} on success
     * @return a non-null sample
     */
    private APISample getBufferEntryForApiSample(ApimonitorSpringFilter.RequestResponseContext context, Exception exception) {
        APISample apiSample = new APISample();
        apiSample.setApplicationName(context.getApplicationName());
        try {
            if (null != context.getApiConfig()) {
                apiSample.setMethod(context.getApiConfig().getMethod());
            } else {
                apiSample.setMethod(context.getObservedApi().getMethod());
            }
            apiSample.setRawUri(context.getObservedApi().getUri().getUriPath());
            apiSample.setUri(context.getObservedApi().getUri());
            apiSample.setParameters(getParameters(context.getServletRequest()));
            apiSample.setRequestHeaders(MaskingUtils.getReplacedHeaders(getRequestHeaders(context), maskHeaders));
            apiSample.setResponseHeaders(MaskingUtils.getReplacedHeaders(getResponseHeaders(context), maskHeaders));
            apiSample.setLatency(context.getLatency());
            apiSample.setHostName(context.getServletRequest().getServerName());
            apiSample.setPort(context.getServletRequest().getServerPort());
            apiSample.setScheme(context.getServletRequest().getScheme());
            if (null != context.getCachedRequest()) {
                // Explicit charset: new String(byte[]) would depend on the JVM's platform default.
                apiSample.setRequestPayload(new String(context.getCachedRequest().getContentAsByteArray(), StandardCharsets.UTF_8));
            }
            apiSample.setStatusCode(context.getServletResponse().getStatus());
            if (null == exception) {
                ContentCachingResponseWrapper cachedResponse = context.getCachedResponse();
                if (null != cachedResponse) {
                    apiSample.setResponsePayload(new String(cachedResponse.getContentAsByteArray(), StandardCharsets.UTF_8));
                    apiSample.setStatusCode(cachedResponse.getStatus());
                }
            } else {
                // this is uncaught exception, even after all exception mappers.
                // NOTE(review): the cause is preferred presumably because the framework wraps the
                // handler exception; fall back to the exception itself when there is no cause.
                Throwable reported = null != exception.getCause() ? exception.getCause() : exception;
                apiSample.setUncaughtExceptionMessage(reported.getMessage());
                apiSample.setStatusCode(500);
            }
            apiSample.setRequestPayloadCaptureAttempted(context.getRequestPayloadCaptureAttempted());
            apiSample.setResponsePayloadCaptureAttempted(context.getResponsePayloadCaptureAttempted());
            apiSample.setPayloadCaptureAttempted(context.getPayloadCaptureAttempted());
        } catch (Exception e) {
            logger.error("Error create bufferEntry for API", e);
        }
        return apiSample;
    }

    /** Request headers (first value per name) from the caching wrapper if present, else the original request. */
    private Map<String, String> getRequestHeaders(ApimonitorSpringFilter.RequestResponseContext context) {
        HttpServletRequest httpServletRequest = null != context.getCachedRequest()
                ? context.getCachedRequest()
                : context.getServletRequest();
        Map<String, String> headerMap = new HashMap<>();
        if (null == httpServletRequest) {
            return headerMap;
        }
        Enumeration<String> headerNames = httpServletRequest.getHeaderNames();
        if (headerNames != null) {
            while (headerNames.hasMoreElements()) {
                String headerName = headerNames.nextElement();
                headerMap.put(headerName, httpServletRequest.getHeader(headerName));
            }
        }
        return headerMap;
    }

    /**
     * Response headers (first value per name) from the caching wrapper if present, else the original
     * response. "Content-Type" is always set from {@code getContentType()}, which may be null.
     */
    private Map<String, String> getResponseHeaders(ApimonitorSpringFilter.RequestResponseContext context) {
        HttpServletResponse httpServletResponse = null != context.getCachedResponse()
                ? context.getCachedResponse()
                : context.getServletResponse();
        if (null == httpServletResponse) {
            return new HashMap<>();
        }
        Collection<String> headerNames = httpServletResponse.getHeaderNames();
        if (null == headerNames) {
            return new HashMap<>();
        }
        Map<String, String> headerMap = new HashMap<>();
        for (String header : headerNames) {
            headerMap.put(header, httpServletResponse.getHeader(header));
        }
        headerMap.put("Content-Type", httpServletResponse.getContentType());
        return headerMap;
    }

    /** Returns a mutable copy of the request's parameter map. */
    private Map<String, String[]> getParameters(HttpServletRequest servletRequest) {
        return new HashMap<>(servletRequest.getParameterMap());
    }
}
