package io.embrace.android.embracesdk.network.http;

import android.annotation.TargetApi;
import android.os.Build;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.ProtocolException;
import java.net.URL;
import java.security.Permission;
import java.security.Principal;
import java.security.cert.Certificate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocketFactory;

import io.embrace.android.embracesdk.Embrace;
import io.embrace.android.embracesdk.InternalApi;
import io.embrace.android.embracesdk.logging.InternalStaticEmbraceLogger;
import io.embrace.android.embracesdk.utils.exceptions.Unchecked;
import kotlin.jvm.functions.Function0;


/**
 * Wraps {@link HttpURLConnection} to log network calls to Embrace. The wrapper also wraps the
 * InputStream to get an accurate count of bytes received if a Content-Length is not provided by
 * the server.
 * <p>
 * The wrapper handles gzip decompression itself and strips the {@code Content-Length} and
 * {@code Content-Encoding} headers. Typically this is handled transparently by
 * {@link HttpURLConnection} but would prevent us from accessing the {@code Content-Length}.
 * <p>
 * Network logging currently does not follow redirects. The duration is logged from initiation of
 * the network call (upon invocation of any method which would initiate the network call), to the
 * retrieval of the response.
 * <p>
 * As network calls are initiated lazily, we log the network call prior to the calling of any
 * wrapped method which would result in the network call actually being executed, and store a
 * flag to prevent duplication of calls.
 */
@InternalApi
class EmbraceUrlConnectionOverride<T extends HttpURLConnection>
    implements EmbraceUrlConnectionService, EmbraceSslUrlConnectionService {

    /**
     * The content encoding HTTP header.
     */
    private static final String CONTENT_ENCODING = "Content-Encoding";

    /**
     * The content length HTTP header.
     */
    private static final String CONTENT_LENGTH = "Content-Length";

    /**
     * Reference to the wrapped connection.
     */
    private final T connection;

    /**
     * The time at which the connection was created.
     */
    private final long createdTime;

    /**
     * Whether the I/O streams are wrapped (for byte counting) and gzipped responses are
     * transparently decompressed by this wrapper.
     */
    private final boolean enableWrapIoStreams;

    /**
     * A reference to the input stream wrapped in a counter, so we can determine the bytes received.
     * NOTE(review): this field is never assigned or read anywhere in this class.
     */
    private volatile InputStream inputStream;

    /**
     * A reference to the output stream wrapped in a counter, so we can determine the bytes sent.
     */
    private volatile CountingOutputStream outputStream;

    /**
     * Whether the network call has already been logged, to prevent duplication.
     */
    private volatile boolean didLogNetworkCall = false;

    /**
     * The time at which the network call ended.
     */
    private volatile Long endTime;

    /**
     * The time at which the network call was initiated.
     */
    private volatile Long startTime;

    /**
     * The trace id specified for the request
     */
    private volatile String traceId;

    /**
     * The request headers captured from the http connection.
     */
    private volatile HashMap<String, String> requestHeaders;

    /**
     * Indicates if the request throws a IOException
     */
    private volatile boolean isIoException;

    /**
     * Wraps an existing {@link HttpURLConnection} with the Embrace network logic.
     *
     * @param connection          the connection to wrap
     * @param enableWrapIoStreams true if we should transparently ungzip the response, else false
     */
    public EmbraceUrlConnectionOverride(@NonNull T connection, boolean enableWrapIoStreams) {
        this.connection = connection;
        this.createdTime = System.currentTimeMillis();
        this.enableWrapIoStreams = enableWrapIoStreams;
    }

    @Override
    public void addRequestProperty(@NonNull String key, @Nullable String value) {
        this.connection.addRequestProperty(key, value);
    }

    @Override
    public void connect() throws IOException {
        identifyTraceId();
        this.connection.connect();
    }

    /**
     * Logs the network call (using the creation time of this wrapper as the start time, unless it
     * was already logged) and then disconnects the wrapped connection.
     */
    @Override
    public void disconnect() {
        identifyTraceId();
        // The network call must be logged before we close the transport
        logNetworkCall(this.createdTime);
        this.connection.disconnect();
    }

    @Override
    public boolean getAllowUserInteraction() {
        return this.connection.getAllowUserInteraction();
    }

    @Override
    public void setAllowUserInteraction(boolean allowUserInteraction) {
        this.connection.setAllowUserInteraction(allowUserInteraction);
    }

    @Override
    public int getConnectTimeout() {
        return this.connection.getConnectTimeout();
    }

    @Override
    public void setConnectTimeout(int timeout) {
        this.connection.setConnectTimeout(timeout);
    }

    @Override
    @Nullable
    public Object getContent() throws IOException {
        identifyTraceId();
        return this.connection.getContent();
    }

    @Override
    @Nullable
    public Object getContent(@NonNull Class<?>[] classes) throws IOException {
        identifyTraceId();
        return this.connection.getContent(classes);
    }

    /**
     * @return the content encoding of the response, or null when this wrapper transparently
     * decompresses a gzipped response (the encoding is then hidden from the caller)
     */
    @Override
    @Nullable
    public String getContentEncoding() {
        return shouldUncompressGzip() ? null : this.connection.getContentEncoding();
    }

    /**
     * @return the content length of the response, or -1 when this wrapper transparently
     * decompresses a gzipped response (the compressed length would be misleading)
     */
    @Override
    public int getContentLength() {
        return shouldUncompressGzip() ? -1 : this.connection.getContentLength();
    }

    /**
     * Long variant of {@link #getContentLength()}. On API levels below 24, where
     * {@code getContentLengthLong} is unavailable, the int-based content length is returned
     * instead of discarding the value.
     */
    @Override
    @TargetApi(24)
    public long getContentLengthLong() {
        if (shouldUncompressGzip()) {
            return -1;
        }
        if (Build.VERSION.SDK_INT < Build.VERSION_CODES.N) {
            return this.connection.getContentLength();
        }
        return this.connection.getContentLengthLong();
    }

    @Override
    @Nullable
    public String getContentType() {
        return this.connection.getContentType();
    }

    @Override
    public long getDate() {
        return this.connection.getDate();
    }

    @Override
    public boolean getDefaultUseCaches() {
        return this.connection.getDefaultUseCaches();
    }

    @Override
    public void setDefaultUseCaches(boolean defaultUseCaches) {
        this.connection.setDefaultUseCaches(defaultUseCaches);
    }

    @Override
    public boolean getDoInput() {
        return this.connection.getDoInput();
    }

    @Override
    public void setDoInput(boolean doInput) {
        this.connection.setDoInput(doInput);
    }

    @Override
    public boolean getDoOutput() {
        return this.connection.getDoOutput();
    }

    @Override
    public void setDoOutput(boolean doOutput) {
        this.connection.setDoOutput(doOutput);
    }

    /**
     * @return the (possibly wrapped) error stream of the connection, or null if the wrapped
     * connection has no error stream
     */
    @Override
    @Nullable
    public InputStream getErrorStream() {
        return getWrappedInputStream(this.connection.getErrorStream());
    }

    /**
     * Determines whether retrieval of the given header must be intercepted, which is the case for
     * the content encoding and content length headers of a response we transparently ungzip.
     *
     * @param key the header name, may be null
     * @return true if the header value must be hidden from the caller
     */
    @Override
    public boolean shouldInterceptHeaderRetrieval(@Nullable String key) {
        return shouldUncompressGzip()
            && key != null
            && (key.equalsIgnoreCase(CONTENT_ENCODING) || key.equalsIgnoreCase(CONTENT_LENGTH));
    }

    @Override
    @Nullable
    public String getHeaderField(int n) {
        // The key may be null (e.g. for the status line at index 0); the value is still returned.
        String key = this.connection.getHeaderFieldKey(n);
        return retrieveHeaderField(key,
            null,
            () -> connection.getHeaderField(n)
        );
    }

    @Override
    @Nullable
    public String getHeaderField(@Nullable String name) {
        return retrieveHeaderField(name,
            null,
            () -> connection.getHeaderField(name)
        );
    }

    @Override
    @Nullable
    public String getHeaderFieldKey(int n) {
        String key = this.connection.getHeaderFieldKey(n);
        return retrieveHeaderField(key,
            null,
            () -> key
        );
    }

    @Override
    public long getHeaderFieldDate(@NonNull String name, long defaultValue) {
        Long result = retrieveHeaderField(name,
            defaultValue,
            () -> connection.getHeaderFieldDate(name, defaultValue)
        );

        return result != null ? result : defaultValue;
    }

    @Override
    public int getHeaderFieldInt(@NonNull String name, int defaultValue) {
        Integer result = retrieveHeaderField(name,
            defaultValue,
            () -> connection.getHeaderFieldInt(name, defaultValue)
        );

        return result != null ? result : defaultValue;
    }

    @Override
    @TargetApi(24)
    public long getHeaderFieldLong(@NonNull String name, long defaultValue) {
        Long result = retrieveHeaderField(name,
            defaultValue,
            () -> headerFieldLongCompat(name, defaultValue)
        );
        return result != null ? result : defaultValue;
    }

    /**
     * Reads a header as a long. On API 24+ this delegates to the wrapped connection; on older
     * API levels the header is parsed manually so that the real value (or the caller's default)
     * is returned rather than a hard-coded -1.
     *
     * @param name         the header name
     * @param defaultValue the value to return if the header is missing or malformed
     * @return the parsed header value, or {@code defaultValue}
     */
    @TargetApi(24)
    private long headerFieldLongCompat(String name, long defaultValue) {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
            return this.connection.getHeaderFieldLong(name, defaultValue);
        }

        String rawValue = this.connection.getHeaderField(name);
        if (rawValue == null) {
            return defaultValue;
        }
        try {
            return Long.parseLong(rawValue);
        } catch (NumberFormatException ignored) {
            // Malformed header value: behave like the platform implementation and use the default.
            return defaultValue;
        }
    }

    /**
     * Returns the response headers and logs the network call.
     * <p>
     * When stream wrapping is enabled a mutable copy of the headers is returned. If the response
     * is gzipped and transparently decompressed by this wrapper, the content encoding and content
     * length headers are removed from that copy (matched case-insensitively), consistent with
     * the single-header accessors.
     *
     * @return the response headers, or null if the wrapped connection returns null
     */
    @Override
    @Nullable
    public Map<String, List<String>> getHeaderFields() {
        final long startTime = System.currentTimeMillis();
        final Map<String, List<String>> rawFields = this.connection.getHeaderFields();
        logNetworkCall(startTime);

        if (rawFields == null || !enableWrapIoStreams) {
            return rawFields;
        }

        // Evaluate once: the answer cannot change while we iterate over the headers.
        final boolean stripGzipHeaders = shouldUncompressGzip();
        Map<String, List<String>> headerFields = new HashMap<>();
        for (Map.Entry<String, List<String>> field : rawFields.entrySet()) {
            String key = field.getKey();
            boolean isGzipHeader = key != null
                && (key.equalsIgnoreCase(CONTENT_ENCODING) || key.equalsIgnoreCase(CONTENT_LENGTH));
            if (stripGzipHeaders && isGzipHeader) {
                continue;
            }
            headerFields.put(key, field.getValue());
        }
        return headerFields;
    }

    /**
     * Retrieves a header through the given action and logs the network call, unless the header
     * must be hidden because we transparently ungzip the response.
     * <p>
     * A null name is not short-circuited: it is passed on to the action, because the wrapped
     * connection may expose a value without a key (e.g. the status line as the 0th header field).
     *
     * @param name         the name of the header being retrieved, may be null
     * @param defaultValue the value returned when the header is intercepted
     * @param action       the retrieval to perform against the wrapped connection
     * @return the result of the action, or {@code defaultValue} if the header is intercepted
     */
    private <R> R retrieveHeaderField(@Nullable String name,
                                      R defaultValue,
                                      Function0<R> action) {
        long startTime = System.currentTimeMillis();
        if (shouldInterceptHeaderRetrieval(name)) {
            // Strip the content encoding and length headers, as we transparently ungzip the content
            return defaultValue;
        }

        R result = action.invoke();
        logNetworkCall(startTime);
        return result;
    }

    @Override
    public long getIfModifiedSince() {
        return this.connection.getIfModifiedSince();
    }

    @Override
    public void setIfModifiedSince(long ifModifiedSince) {
        this.connection.setIfModifiedSince(ifModifiedSince);
    }

    /**
     * Returns the (possibly wrapped) response stream. If the wrapped connection throws an
     * {@link IOException}, that fact is recorded before the exception is rethrown.
     */
    @Override
    @Nullable
    public InputStream getInputStream() throws IOException {
        try {
            return getWrappedInputStream(this.connection.getInputStream());
        } catch (IOException e) {
            isIoException = true;
            throw e;
        }
    }

    @Override
    public boolean getInstanceFollowRedirects() {
        return this.connection.getInstanceFollowRedirects();
    }

    @Override
    public void setInstanceFollowRedirects(boolean followRedirects) {
        this.connection.setInstanceFollowRedirects(followRedirects);
    }

    @Override
    public long getLastModified() {
        return this.connection.getLastModified();
    }

    /**
     * Returns the request body stream. When stream wrapping is enabled the stream is wrapped in a
     * {@link CountingOutputStream}; the same wrapper is returned on every call so that all bytes
     * written by the caller are counted.
     */
    @Override
    @Nullable
    public OutputStream getOutputStream() throws IOException {
        identifyTraceId();
        OutputStream out = connection.getOutputStream();
        if (!enableWrapIoStreams || out == null) {
            return out;
        }

        CountingOutputStream counting = this.outputStream;
        if (counting == null) {
            counting = new CountingOutputStream(out);
            this.outputStream = counting;
        }
        return counting;
    }

    @Override
    @Nullable
    public Permission getPermission() throws IOException {
        return this.connection.getPermission();
    }

    @Override
    public int getReadTimeout() {
        return this.connection.getReadTimeout();
    }

    @Override
    public void setReadTimeout(int timeout) {
        this.connection.setReadTimeout(timeout);
    }

    @Override
    @NonNull
    public String getRequestMethod() {
        return this.connection.getRequestMethod();
    }

    @Override
    public void setRequestMethod(@NonNull String method) throws ProtocolException {
        this.connection.setRequestMethod(method);
    }

    @Override
    @Nullable
    public Map<String, List<String>> getRequestProperties() {
        return this.connection.getRequestProperties();
    }

    @Override
    @Nullable
    public String getRequestProperty(@NonNull String key) {
        return this.connection.getRequestProperty(key);
    }

    @Override
    public int getResponseCode() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        int responseCode = this.connection.getResponseCode();
        logNetworkCall(startTime);
        return responseCode;
    }

    @Override
    @Nullable
    public String getResponseMessage() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        String responseMsg = this.connection.getResponseMessage();
        logNetworkCall(startTime);
        return responseMsg;
    }

    @Override
    @Nullable
    public URL getUrl() {
        return this.connection.getURL();
    }

    @Override
    public boolean getUseCaches() {
        return this.connection.getUseCaches();
    }

    @Override
    public void setUseCaches(boolean useCaches) {
        this.connection.setUseCaches(useCaches);
    }

    @Override
    public void setChunkedStreamingMode(int chunkLen) {
        this.connection.setChunkedStreamingMode(chunkLen);
    }

    @Override
    public void setFixedLengthStreamingMode(int contentLength) {
        this.connection.setFixedLengthStreamingMode(contentLength);
    }

    @Override
    @TargetApi(19)
    public void setFixedLengthStreamingMode(long contentLength) {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
            this.connection.setFixedLengthStreamingMode(contentLength);
        }
    }

    /**
     * Sets the request property on the wrapped connection and, if the request matches a network
     * capture rule, refreshes the snapshot of request headers kept for network capture.
     */
    @Override
    public void setRequestProperty(@NonNull String key, @Nullable String value) {
        this.connection.setRequestProperty(key, value);

        if (hasNetworkCaptureRules()) {
            this.requestHeaders = getProcessedHeaders(getRequestProperties());
        }
    }

    @Override
    @NonNull
    public String toString() {
        return this.connection.toString();
    }

    @Override
    public boolean usingProxy() {
        return this.connection.usingProxy();
    }

    /**
     * Given a start time (in milliseconds), logs the network call to Embrace using the current time as the end time.
     * <p>
     * If the network call has already been logged for this HttpURLConnection, this method is a no-op and is effectively
     * ignored.
     */
    synchronized void logNetworkCall(long startTime) {
        logNetworkCall(startTime, System.currentTimeMillis(), false, null, false, null);
    }

    /**
     * Given a start time and end time (in milliseconds), logs the network call to Embrace.
     * <p>
     * If the network call has already been logged for this HttpURLConnection, this method is a no-op and is effectively
     * ignored, unless {@code overwrite} is set.
     *
     * @param startTime         the time at which the call started, in milliseconds
     * @param endTime           the time at which the call ended, in milliseconds
     * @param overwrite         true to log the call even if it has been logged before
     * @param bytesIn           the number of bytes received, or null to fall back to the content length
     *                          reported by the connection
     * @param shouldCaptureBody true if network capture data should be collected (when capture rules apply)
     * @param responseBody      the captured response body, may be null
     */
    synchronized void logNetworkCall(long startTime, long endTime, boolean overwrite, Long
        bytesIn, boolean shouldCaptureBody, byte[] responseBody) {
        if (this.didLogNetworkCall && !overwrite) {
            return;
        }

        // We are proactive with setting this flag so that we don't get nested calls to log the network call by virtue of
        // extracting the data we need to log the network call.
        this.didLogNetworkCall = true;
        this.startTime = startTime;
        this.endTime = endTime;

        String url = EmbraceHttpPathOverride.getURLString(new EmbraceHttpUrlConnectionOverride(this.connection));

        NetworkCaptureData networkCaptureData = null;

        // If we don't have network capture rules, it's unnecessary to save these values
        if (hasNetworkCaptureRules() && (shouldCaptureBody || isIoException)) {
            Map<String, String> requestHeaders = this.requestHeaders;
            String requestQueryParams = connection.getURL().getQuery();
            byte[] requestBody = this.outputStream != null ? this.outputStream.getRequestBody() : null;
            Map<String, String> responseHeaders = getProcessedHeaders(getHeaderFields());

            networkCaptureData = new NetworkCaptureData(
                requestHeaders,
                requestQueryParams,
                requestBody,
                responseHeaders,
                responseBody,
                null
            );
        }

        try {
            long bytesOut = this.outputStream == null ? 0 : Math.max(this.outputStream.getCount(),
                0);
            long contentLength = bytesIn == null ? Math.max(this.connection.getContentLength(), 0) : bytesIn;

            Embrace.getInstance().logNetworkCall(
                url,
                HttpMethod.fromString(getRequestMethod()),
                getResponseCode(),
                startTime,
                endTime,
                bytesOut,
                contentLength,
                traceId,
                networkCaptureData
            );
        } catch (Exception e) {
            // Retrieving the response details failed, so the call is reported as a client error.
            String className = e.getClass().getCanonicalName();
            String message = e.getMessage();

            Embrace.getInstance().logNetworkClientError(
                url,
                HttpMethod.fromString(getRequestMethod()),
                startTime,
                endTime,
                className != null ? className : "",
                message != null ? message : "",
                traceId,
                networkCaptureData
            );
        }
    }

    /**
     * Flattens a multi-valued header map into a single-valued one by concatenating the non-null
     * values of each header (without a separator).
     *
     * @param properties the headers to flatten, may be null
     * @return the flattened headers, or null if {@code properties} is null
     */
    @Nullable
    private HashMap<String, String> getProcessedHeaders(@Nullable Map<String, List<String>> properties) {
        if (properties == null) {
            return null;
        }

        HashMap<String, String> headers = new HashMap<>();

        for (Map.Entry<String, List<String>> h : properties.entrySet()) {
            StringBuilder builder = new StringBuilder();
            List<String> values = h.getValue();
            if (values != null) {
                for (String value : values) {
                    if (value != null) {
                        builder.append(value);
                    }
                }
            }
            headers.put(h.getKey(), builder.toString());
        }

        return headers;
    }

    /**
     * Wraps an input stream with an input stream which counts the number of bytes read, and then
     * updates the network call service with the correct number of bytes read once the stream has
     * reached the end.
     *
     * @param inputStream the input stream to count
     * @return the wrapped input stream
     */
    private CountingInputStreamWithCallback countingInputStream(InputStream inputStream) {
        return new CountingInputStreamWithCallback(
            inputStream,
            hasNetworkCaptureRules(),
            (bytesCount, responseBody) -> {
                if (this.startTime != null && this.endTime != null) {
                    logNetworkCall(
                        this.startTime,
                        this.endTime,
                        true,
                        bytesCount,
                        true,
                        responseBody);
                }
            });
    }

    /**
     * We disable the automatic gzip decompression behavior of {@link HttpURLConnection} in the
     * {@link EmbraceHttpUrlStreamHandler} to ensure that we can count the bytes in the response
     * from the server. We decompress the response transparently to the user only if both:
     * <ul>
     * <li>The user did not specify an encoding</li>
     * <li>The server returned a gzipped response</li>
     * </ul>
     * <p>
     * If the user specified an encoding, even if it is gzip, we do not transparently decompress
     * the response. This is to mirror the behavior of {@link HttpURLConnection} whilst providing
     * us access to the content length.
     *
     * @return true if we should decompress the response, false otherwise
     * @see <a href="https://developer.android.com/reference/java/net/HttpURLConnection#performance">Android Docs</a>
     * @see <a href="https://android.googlesource.com/platform/external/okhttp/+/master/okhttp/src/main/java/com/squareup/okhttp/internal/http/HttpEngine.java">Android Source Code</a>
     */
    private boolean shouldUncompressGzip() {
        String contentEncoding = this.connection.getContentEncoding();
        return enableWrapIoStreams &&
            contentEncoding != null &&
            contentEncoding.equalsIgnoreCase("gzip");
    }

    /**
     * Reads the trace id from the request headers, if it has not been identified yet. Failures
     * are logged at debug level and otherwise ignored.
     */
    private void identifyTraceId() {
        if (traceId == null) {
            try {
                traceId = getRequestProperty(Embrace.getInstance().getTraceIdHeader());
            } catch (Exception e) {
                InternalStaticEmbraceLogger.logDebug("Failed to retrieve actual trace id header. Current: " + traceId);
            }
        }
    }

    @Override
    @Nullable
    public String getCipherSuite() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getCipherSuite();
        }

        return null;
    }

    @Override
    @Nullable
    public Certificate[] getLocalCertificates() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getLocalCertificates();
        }

        return new Certificate[0];
    }

    @Override
    @Nullable
    public Certificate[] getServerCertificates() throws SSLPeerUnverifiedException {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getServerCertificates();
        }

        return new Certificate[0];
    }

    @Override
    @Nullable
    public SSLSocketFactory getSslSocketFactory() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getSSLSocketFactory();
        }

        return null;
    }

    @Override
    public void setSslSocketFactory(@NonNull SSLSocketFactory factory) {
        if (this.connection instanceof HttpsURLConnection) {
            ((HttpsURLConnection) this.connection).setSSLSocketFactory(factory);
        }
    }

    @Override
    @Nullable
    public HostnameVerifier getHostnameVerifier() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getHostnameVerifier();
        }

        return null;
    }

    @Override
    public void setHostnameVerifier(@NonNull HostnameVerifier verifier) {
        if (this.connection instanceof HttpsURLConnection) {
            ((HttpsURLConnection) this.connection).setHostnameVerifier(verifier);
        }
    }

    @Override
    @Nullable
    public Principal getLocalPrincipal() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getLocalPrincipal();
        }

        return null;
    }

    @Override
    @Nullable
    public Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getPeerPrincipal();
        }

        return null;
    }

    /**
     * Wraps the given stream of the connection so that the bytes read from it are counted, and
     * transparently decompresses it if the response is gzipped. The network call is logged before
     * the stream is handed to the caller.
     *
     * @param connectionInputStream the stream obtained from the wrapped connection, may be null
     *                              (e.g. when the connection has no error stream)
     * @return the wrapped stream; the original stream if stream wrapping is disabled; or null if
     * {@code connectionInputStream} is null
     */
    @Nullable
    private InputStream getWrappedInputStream(@Nullable InputStream connectionInputStream) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();

        InputStream in = null;
        // A missing stream must stay null: wrapping it would hand the caller a non-null stream
        // that fails on first read and defeats their null check.
        if (connectionInputStream != null) {
            if (shouldUncompressGzip()) {
                try {
                    in = countingInputStream(new BufferedInputStream(Unchecked.wrap(() -> new GZIPInputStream(connectionInputStream))));
                } catch (Exception ignored) {
                    // This handles the case where no bytes are available, which causes the
                    // GZIPInputStream instantiation to fail. We fall back to the plain stream below.
                }
            }

            if (in == null) {
                in = enableWrapIoStreams ?
                    countingInputStream(new BufferedInputStream(connectionInputStream)) : connectionInputStream;
            }
        }

        logNetworkCall(startTime);
        return in;
    }

    /**
     * @return true if the URL and method of this request match a network capture rule, false if
     * they do not or if the connection has no URL
     */
    private boolean hasNetworkCaptureRules() {
        if (this.connection.getURL() == null) {
            return false;
        }
        String url = this.connection.getURL().toString();
        String method = this.connection.getRequestMethod();

        return Embrace.getInstance().shouldCaptureNetworkBody(url, method);
    }
}
