package com.liveperson.infra.network.socket;

import android.text.TextUtils;

import com.liveperson.infra.log.LPLog;
import com.liveperson.infra.model.SocketConnectionParams;
import com.liveperson.infra.utils.TlsUtil;
import com.liveperson.infra.utils.Utils;

import java.net.ProtocolException;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLPeerUnverifiedException;

import okhttp3.CertificatePinner;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.ByteString;

/**
 * Created by ofira on 11/22/17.
 * <p>
 * {@link ISocketWrapper} implementation backed by an OkHttp {@link WebSocket}.
 * <p>
 * Every {@link #connect(SocketConnectionParams)} builds a fresh {@link OkHttpClient} configured with
 * {@link TlsUtil#enableTls12ForKitKat}, a ping interval of {@code SocketHandler.PERIODIC_PING_TIME}
 * milliseconds and, when keys are supplied, certificate pinning for the request host. Socket
 * lifecycle changes, incoming text messages and disconnect reasons are forwarded to the
 * {@link ISocketCallbacks} given at construction time.
 */
public class SocketWrapperOK implements ISocketWrapper {

    private static final String TAG = "SocketWrapperOK";

    /** Close code sent on a client-initiated disconnect (RFC 6455 "Normal Closure"). */
    private static final int NORMAL_CLOSURE_CODE = 1000;

    /** Code reported to the callbacks when the handshake is rejected with HTTP 407. */
    private static final int INVALID_IDENTITY_TOKEN_CODE = 4407;

    /**
     * Fragment of the {@link ProtocolException} message OkHttp raises when it receives a 407
     * response while no proxy is configured.
     */
    private static final String HTTP_PROXY_AUTH_MARKER = "HTTP_PROXY_AUTH (407)";

    /** Socket created by the most recent {@link #connect}; {@code null} before the first connect. */
    private WebSocket mWebSocket;
    private final ISocketCallbacks mSocketCallbacks;

    /**
     * Creates the wrapper and immediately reports {@link SocketState#INIT}.
     *
     * @param socketCallbacks receiver of state changes, messages and disconnect notifications
     */
    public SocketWrapperOK(ISocketCallbacks socketCallbacks) {
        mSocketCallbacks = socketCallbacks;
        mSocketCallbacks.onStateChanged(SocketState.INIT);
    }

    /**
     * Starts an asynchronous web-socket connection and reports {@link SocketState#CONNECTING}.
     * The {@link SocketState#OPEN} state is reported later, from the listener's {@code onOpen}.
     * <p>
     * NOTE(review): a socket created by a previous call is overwritten without being closed, and its
     * listener keeps forwarding events to the same callbacks. Callers presumably always
     * {@link #disconnect()} first; confirm.
     *
     * @param connectionParams URL, extra request headers and optional certificate-pinning keys
     * @throws IllegalArgumentException if OkHttp rejects the URL (thrown by {@code Request.Builder#url})
     */
    @Override
    public void connect(SocketConnectionParams connectionParams) throws IllegalArgumentException {
        Request.Builder requestBuilder = new Request.Builder().url(connectionParams.getUrl());
        for (Map.Entry<String, String> header : connectionParams.getHeaders().entrySet()) {
            requestBuilder.addHeader(header.getKey(), header.getValue());
        }
        Request request = requestBuilder.build();

        OkHttpClient.Builder clientBuilder = new OkHttpClient.Builder();
        TlsUtil.enableTls12ForKitKat(clientBuilder);
        clientBuilder.pingInterval(SocketHandler.PERIODIC_PING_TIME, TimeUnit.MILLISECONDS);

        // Certificate pinning is only configured when a key list was supplied. Keys rejected by
        // Utils.isValidCertificateKey are logged but not pinned.
        if (connectionParams.getCertificatePinningKeys() != null) {
            String host = request.url().host();
            CertificatePinner.Builder pinnerBuilder = new CertificatePinner.Builder();
            for (String key : connectionParams.getCertificatePinningKeys()) {
                LPLog.INSTANCE.d(TAG, "Pinning Key: " + LPLog.INSTANCE.mask(key));
                if (Utils.isValidCertificateKey(key)) {
                    pinnerBuilder.add(host, key);
                }
            }
            clientBuilder.certificatePinner(pinnerBuilder.build());
        }

        OkHttpClient client = clientBuilder.build();

        LPLog.INSTANCE.d(TAG, "Socket connecting.... " + connectionParams.getUrl() +
                (connectionParams.getCertificatePinningKeys() != null ? " with Pinning Keys " +
                        LPLog.INSTANCE.mask(TextUtils.join(",", connectionParams.getCertificatePinningKeys())) : " with no Pinning Keys"));

        mWebSocket = client.newWebSocket(request, new WebSocketCallbacks());
        mSocketCallbacks.onStateChanged(SocketState.CONNECTING);
    }

    /**
     * Reports {@link SocketState#CLOSING} and starts a normal close of the current socket.
     * Does nothing (beyond logging) if {@link #connect} was never called.
     */
    @Override
    public void disconnect() {
        LPLog.INSTANCE.d(TAG, "Socket disconnect was called");
        if (mWebSocket != null) {
            mSocketCallbacks.onStateChanged(SocketState.CLOSING);
            mWebSocket.close(NORMAL_CLOSURE_CODE, "Disconnected by device");
        }
    }

    /**
     * Enqueues a text message on the current socket. The message is silently dropped
     * if {@link #connect} was never called.
     *
     * @param message text payload to send
     */
    @Override
    public void send(String message) {
        LPLog.INSTANCE.d(TAG, "Socket send " + LPLog.INSTANCE.mask(message));
        if (mWebSocket != null) {
            mWebSocket.send(message);
        }
    }

    /** Translates OkHttp listener events into {@link ISocketCallbacks} notifications. */
    private class WebSocketCallbacks extends WebSocketListener {

        /**
         * Invoked when a web socket has been accepted by the remote peer and may begin transmitting
         * messages.
         */
        @Override
        public void onOpen(WebSocket webSocket, Response response) {
            LPLog.INSTANCE.i(TAG, "onOpen() called with: response = [" + response + "]");
            mSocketCallbacks.onStateChanged(SocketState.OPEN);
        }

        /**
         * Invoked when a text (type {@code 0x1}) message has been received; forwarded as-is.
         */
        @Override
        public void onMessage(WebSocket webSocket, String text) {
            LPLog.INSTANCE.i(TAG, "---- Socket onMessage callback with text: " + LPLog.INSTANCE.mask(text));
            mSocketCallbacks.onMessage(text);
        }

        /**
         * Invoked when a binary (type {@code 0x2}) message has been received.
         * Binary frames are currently only logged and otherwise ignored.
         */
        @Override
        public void onMessage(WebSocket webSocket, ByteString bytes) {
            LPLog.INSTANCE.i(TAG, "Socket onMessage callback with ByteString");
        }

        /**
         * Invoked when the peer has indicated that no more incoming messages will be transmitted.
         * Reports the peer's close reason and code, then {@link SocketState#CLOSING}.
         */
        @Override
        public void onClosing(WebSocket webSocket, int code, String reason) {
            LPLog.INSTANCE.i(TAG, "onClosing() called with: code = [" + code + "], reason = [" + reason + "]");
            mSocketCallbacks.onDisconnected(reason, code);
            mSocketCallbacks.onStateChanged(SocketState.CLOSING);
        }

        /**
         * Invoked when both peers have indicated that no more messages will be transmitted and the
         * connection has been successfully released. No further calls to this listener will be made.
         */
        @Override
        public void onClosed(WebSocket webSocket, int code, String reason) {
            LPLog.INSTANCE.i(TAG, "onClosed() called with: code = [" + code + "]," + " reason = [" + reason + "]");
            mSocketCallbacks.onStateChanged(SocketState.CLOSED);
        }

        /**
         * Invoked when a web socket has been closed due to an error reading from or writing to the
         * network. Both outgoing and incoming messages may have been lost. No further calls to this
         * listener will be made.
         * <p>
         * Always reports {@link SocketState#CLOSED}. Two failure kinds also produce a disconnect
         * notification: an {@link SSLPeerUnverifiedException} (reported as
         * {@code SocketHandler.CERTIFICATE_ERROR}), and a 407 handshake rejection (reported as an
         * invalid identity token).
         */
        @Override
        public void onFailure(WebSocket webSocket, Throwable t, Response response) {
            // Throwable#getMessage() may legitimately be null; normalise it so the string checks
            // below cannot throw a NullPointerException inside OkHttp's callback.
            String rawMessage = t != null ? t.getMessage() : null;
            String errorMessage = rawMessage != null ? rawMessage : "";
            LPLog.INSTANCE.i(TAG, "onFailure() called with: webSocket = [" + webSocket + "], throwable = [" + t + "], response = [" + response + "], " + "ErrorMessage = " + errorMessage);
            mSocketCallbacks.onStateChanged(SocketState.CLOSED);
            if (t instanceof SSLPeerUnverifiedException) {
                // e.g. the server certificate did not match any configured pin.
                mSocketCallbacks.onDisconnected(t.getMessage(), SocketHandler.CERTIFICATE_ERROR);
            } else if (t instanceof ProtocolException && errorMessage.contains(HTTP_PROXY_AUTH_MARKER)) {
                // OkHttp itself throws "Received HTTP_PROXY_AUTH (407) code while not using proxy"
                // instead of surfacing the response; see RetryAndFollowUpInterceptor (around line 279)
                // for reference.
                mSocketCallbacks.onDisconnected("identity token is invalid", INVALID_IDENTITY_TOKEN_CODE);
            }
        }
    }
}
