package io.relayr.java.websocket;

import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken;
import org.eclipse.paho.client.mqttv3.IMqttToken;
import org.eclipse.paho.client.mqttv3.MqttAsyncClient;
import org.eclipse.paho.client.mqttv3.MqttCallback;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.MqttMessage;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import io.relayr.java.model.channel.DataChannel;
import rx.Observable;
import rx.Subscriber;

/**
 * MQTT implementation of {@link WebSocket} built on Paho's {@link MqttAsyncClient}.
 *
 * <p>Uses one shared subscribing client ({@code mClient}) plus one publish client per entry in
 * {@code mPublishClients}, keyed by the channel credentials' user. The user is presumably the
 * device id used by {@link #subscribeAction} (TODO confirm). Every client dispatches incoming
 * messages to the callbacks registered for the message's topic in {@code mTopicCallbacks}.
 */
class MqttWebSocket extends WebSocket<DataChannel> {

    private static final String TAG = "MqttWebSocket";

    /**
     * Returns an Observable that makes sure the shared MQTT client exists and is connected, then
     * emits {@code channel}.
     *
     * <p>If the client is already connected, {@code channel} is emitted immediately. Otherwise an
     * existing client is reconnected, or a new one is created from the channel's credentials and
     * then connected. Errors are delivered through {@code onError}. The work runs under
     * {@code mLock}.
     *
     * @param channel channel whose credentials are used to create and connect the client
     */
    @Override
    public Observable<DataChannel> createClient(final DataChannel channel) {
        return Observable.create(new Observable.OnSubscribe<DataChannel>() {
            @Override
            public void call(Subscriber<? super DataChannel> subscriber) {
                synchronized (mLock) {
                    if (channel == null) {
                        subscriber.onError(new IllegalArgumentException("DataChannel data can't be null"));
                        return;
                    }

                    if (mClient != null && mClient.isConnected()) {
                        subscriber.onNext(channel);
                        return;
                    }

                    final DataChannel.ChannelCredentials credentials = channel.getCredentials();
                    if (credentials == null) {
                        subscriber.onError(new IllegalArgumentException("DataChannel credentials can't be null"));
                        return;
                    }

                    // Reuse an existing (disconnected) client; only build a new one if none exists.
                    if (mClient == null && !createMqttClient(credentials.getClientId())) {
                        subscriber.onError(new IllegalStateException("Client not created!"));
                        return;
                    }

                    try {
                        connect(mClient, credentials.getUser(), credentials.getPassword());
                        subscriber.onNext(channel);
                    } catch (MqttException e) {
                        System.err.println("Failed to connect MQTT client");
                        subscriber.onError(e);
                    }
                }
            }
        });
    }

    /**
     * Returns an Observable that makes sure a connected publish client exists for the channel's
     * credentials user, then emits {@code channel}.
     *
     * <p>A publish client left disconnected by an earlier failed connect is reconnected rather
     * than reported as ready. Errors are delivered through {@code onError}. The work runs under
     * {@code mLock}.
     *
     * @param channel channel whose credentials identify, create and connect the publish client
     */
    @Override
    public Observable<DataChannel> createPublishClient(final DataChannel channel) {
        return Observable.create(new Observable.OnSubscribe<DataChannel>() {
            @Override
            public void call(Subscriber<? super DataChannel> subscriber) {
                synchronized (mLock) {
                    if (channel == null) {
                        subscriber.onError(new IllegalArgumentException("DataChannel data can't be null"));
                        return;
                    }

                    final DataChannel.ChannelCredentials credentials = channel.getCredentials();
                    if (credentials == null) {
                        subscriber.onError(new IllegalArgumentException("DataChannel credentials can't be null"));
                        return;
                    }

                    final String user = credentials.getUser();
                    MqttAsyncClient client = mPublishClients.get(user);
                    if (client != null && client.isConnected()) {
                        subscriber.onNext(channel);
                        return;
                    }

                    if (client == null) {
                        if (!createPublishClient(user, credentials.getClientId())) {
                            System.err.println("Failed to create MQTT publish client");
                            subscriber.onError(new IllegalStateException("Client not created!"));
                            return;
                        }
                        client = mPublishClients.get(user);
                    }

                    try {
                        connect(client, user, credentials.getPassword());
                        subscriber.onNext(channel);
                    } catch (MqttException e) {
                        System.err.println("Failed to connect MQTT publish client");
                        subscriber.onError(e);
                    }
                }
            }
        });
    }

    /**
     * Drops every callback registered for {@code topic} and unsubscribes the shared client from it.
     *
     * <p>Blocks for up to {@code UNSUBSCRIBE_TIMEOUT} waiting for the broker. NOTE(review): only
     * {@code mClient} is unsubscribed. Topics subscribed through a publish client (see
     * {@link #subscribeAction}) stay subscribed on that client.
     *
     * @param topic topic to unsubscribe from; {@code null} is rejected
     * @return {@code true} if the unsubscribe completed, {@code false} otherwise
     */
    @Override
    public boolean unSubscribe(String topic) {
        if (topic == null) {
            System.err.println(TAG + ": Topic can't be null!");
            return false;
        }

        mTopicCallbacks.remove(topic);

        if (mClient == null) {
            System.err.println(TAG + ": MQTT client not created, can't unsubscribe from " + topic);
            return false;
        }

        try {
            final IMqttToken unSubscribeToken = mClient.unsubscribe(topic);
            unSubscribeToken.waitForCompletion(UNSUBSCRIBE_TIMEOUT);
            return true;
        } catch (MqttException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Subscribes the shared client to {@code topic} and registers {@code callback} for its
     * messages.
     *
     * @return {@code true} if the callback was registered, {@code false} on invalid input or failure
     */
    @Override
    public boolean subscribe(String topic, String channelId, final WebSocketCallback callback) {
        return subscribe(mClient, topic, channelId, callback);
    }

    /**
     * Subscribes the publish client stored under {@code deviceId} to {@code topic} and registers
     * {@code callback} for its messages.
     *
     * @return {@code true} if the callback was registered, {@code false} on invalid input, when no
     *     publish client exists for {@code deviceId}, or on failure
     */
    @Override
    boolean subscribeAction(String topic, String deviceId, String channelId, WebSocketCallback callback) {
        return subscribe(mPublishClients.get(deviceId), topic, channelId, callback);
    }

    /**
     * Creates a publish client for {@code clientId}, wires up the message callbacks, and stores the
     * client in {@code mPublishClients} under {@code deviceId}.
     *
     * @return {@code true} on success. On failure all registered callbacks are notified and
     *     {@code false} is returned.
     */
    private boolean createPublishClient(String deviceId, String clientId) {
        try {
            final MqttAsyncClient client = new MqttAsyncClient(SslUtil.instance().getBroker(), clientId, null);
            registerMqttCallback(client);
            mPublishClients.put(deviceId, client);
            return true;
        } catch (MqttException e) {
            System.err.println("Failed to create MQTT publish  client");
            e.printStackTrace();
            notifyDisconnect(e);
            return false;
        }
    }

    /**
     * Creates the shared client for {@code clientId}, assigns it to {@code mClient}, and wires up
     * the message callbacks.
     *
     * @return {@code true} on success. On failure all registered callbacks are notified and
     *     {@code false} is returned.
     */
    private boolean createMqttClient(String clientId) {
        try {
            mClient = new MqttAsyncClient(SslUtil.instance().getBroker(), clientId, null);
            registerMqttCallback(mClient);
            return true;
        } catch (MqttException e) {
            System.err.println("Failed to create MQTT client");
            e.printStackTrace();
            notifyDisconnect(e);
            return false;
        }
    }

    /** Forwards {@code cause} to the {@code disconnectCallback} of every registered callback. */
    private void notifyDisconnect(Throwable cause) {
        if (mTopicCallbacks == null || mTopicCallbacks.isEmpty()) return;

        for (List<WebSocketCallback> callbacks : mTopicCallbacks.values())
            for (WebSocketCallback socketCallback : callbacks)
                socketCallback.disconnectCallback(cause);
    }

    /**
     * Installs the Paho callback that routes connection loss and incoming messages from
     * {@code client} to the callbacks registered in {@code mTopicCallbacks}.
     */
    private void registerMqttCallback(MqttAsyncClient client) {
        client.setCallback(new MqttCallback() {
            @Override
            public void connectionLost(Throwable cause) {
                if (cause != null) cause.printStackTrace();
                notifyDisconnect(cause);
            }

            @Override
            public void messageArrived(String topic, MqttMessage message) {
                if (mTopicCallbacks == null || mTopicCallbacks.isEmpty()) return;

                // A message can arrive for a topic with no callbacks, e.g. one just removed by
                // unSubscribe. Paho shuts the client down if this method throws, so an NPE here
                // would tear down the connection.
                final List<WebSocketCallback> callbacks = mTopicCallbacks.get(topic);
                if (callbacks == null) return;

                for (WebSocketCallback socketCallback : callbacks)
                    socketCallback.successCallback(message);
            }

            @Override
            public void deliveryComplete(IMqttDeliveryToken token) {
            }
        });
    }

    /**
     * Connects {@code client} if it is not connected yet, blocking for up to
     * {@code CONNECT_TIMEOUT}.
     *
     * @throws MqttException if the connect fails or does not finish in time
     */
    private void connect(MqttAsyncClient client, String username, String password) throws MqttException {
        if (!client.isConnected()) {
            final IMqttToken connectToken = client.connect(SslUtil.instance().getConnectOptions(username, password));
            connectToken.waitForCompletion(CONNECT_TIMEOUT);
        }
    }

    /** Appends {@code callback} to the list of callbacks for {@code topic}, creating the list if needed. */
    private void addCallback(String topic, WebSocketCallback callback) {
        if (mTopicCallbacks.get(topic) == null)
            mTopicCallbacks.put(topic, new ArrayList<>(Arrays.asList(callback)));
        else
            mTopicCallbacks.get(topic).add(callback);
    }

    /**
     * Registers {@code callback} for {@code topic}, subscribing {@code client} when the topic is
     * not registered yet.
     *
     * <p>If the topic already has callbacks, {@code callback} is only appended and no broker call
     * or {@code connectCallback} happens. Otherwise {@code client} is subscribed first. On success
     * {@code connectCallback} is invoked; on an MQTT failure {@code disconnectCallback} is invoked.
     *
     * @return {@code true} if the callback was registered, {@code false} otherwise
     */
    private boolean subscribe(MqttAsyncClient client, String topic, String channelId, final WebSocketCallback callback) {
        if (callback == null) {
            System.err.println(TAG + ": Argument WebSocketCallback can not be null!");
            return false;
        }

        if (topic == null) {
            callback.errorCallback(new IllegalArgumentException("Topic can't be null!"));
            return false;
        }

        if (mTopicCallbacks.containsKey(topic)) {
            addCallback(topic, callback);
            return true;
        }

        // Without a client nothing would be subscribed; don't report a subscription that never happened.
        if (client == null) {
            callback.errorCallback(new IllegalStateException("MQTT client not created, can't subscribe to " + topic));
            return false;
        }

        try {
            subscribe(client, topic);
            addCallback(topic, callback);
            callback.connectCallback("Subscribed to " + channelId);
        } catch (MqttException e) {
            callback.disconnectCallback(e);
            return false;
        }

        return true;
    }

    /**
     * Subscribes {@code client} at QoS 1 to {@code topic} and to every topic already registered in
     * {@code mTopicCallbacks}, using a single request. Blocks for up to {@code SUBSCRIBE_TIMEOUT}.
     *
     * <p>NOTE(review): the registered topics include those added through other clients (shared
     * and publish), so each client ends up subscribed to all of them. Confirm this
     * resubscribe-everything behavior is intended.
     *
     * @param client non-null client to subscribe
     * @throws MqttException if the subscribe fails or does not finish in time
     */
    private void subscribe(MqttAsyncClient client, String topic) throws MqttException {
        final List<String> topics = new ArrayList<>();
        topics.add(topic);
        topics.addAll(mTopicCallbacks.keySet());

        final int[] qos = new int[topics.size()];
        Arrays.fill(qos, 1);

        final IMqttToken subscribeToken = client.subscribe(topics.toArray(new String[topics.size()]), qos);
        subscribeToken.waitForCompletion(SUBSCRIBE_TIMEOUT);
    }
}
