package com.openfin.desktop.channel.webrtc;

import dev.onvoid.webrtc.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;

import static java.util.Objects.nonNull;

/**
 * Wraps a WebRTC {@link RTCDataChannel} that carries UTF-8 text messages and re-publishes its state
 * changes and inbound messages to registered {@link DataChannelListener}s.
 *
 * <p>Messages passed to {@link #send(String)} before this endpoint has observed the channel OPEN are
 * queued and flushed, in submission order, from the OPEN state-change callback.
 */
public class DataChannel implements RTCDataChannelObserver {
    private final static Logger logger = LoggerFactory.getLogger(DataChannel.class);

    // Set to null by close(); read both from caller threads and from native callback threads.
    private volatile RTCDataChannel dataChannel;
    // Channel label, cached so getName() keeps working (e.g. for logging) after close().
    private final String name;
    private final CopyOnWriteArrayList<DataChannelListener> channelListeners;
    // Stores messages while this endpoint is not ready to send.  Both endpoints of a data channel may
    // not get the OPEN event at the 'same' time, so one endpoint may not be ready to send messages
    // while the other already believes the channel is OPEN.  Guarded by "this".
    private final Deque<String> outboundQueue;
    // True once the outbound queue has been flushed after OPEN.  From then on, messages are sent
    // directly instead of being queued.  Guarded by "this".
    private boolean queueDrained;

    /**
     * Creates a new data channel on the given peer connection and wraps it.
     *
     * @param peerConnection the connection to create the channel on
     * @param name the channel label
     */
    public DataChannel(RTCPeerConnection peerConnection, String name) {
        logger.debug("creating data channel {}", name);
        this.dataChannel = peerConnection.createDataChannel(name , new RTCDataChannelInit());
        this.name = this.dataChannel.getLabel();
        this.channelListeners = new CopyOnWriteArrayList<>();
        this.outboundQueue = new ArrayDeque<>();
        // Register last so callbacks never observe uninitialized fields.
        this.observeChannel();
    }

    /**
     * Wraps an existing native data channel.
     *
     * @param dataChannel the channel to wrap; this wrapper registers itself as its observer
     */
    public DataChannel(RTCDataChannel dataChannel) {
        logger.debug("wrapping data channel {}", dataChannel.getLabel());
        this.dataChannel = dataChannel;
        this.name = dataChannel.getLabel();
        this.channelListeners = new CopyOnWriteArrayList<>();
        this.outboundQueue = new ArrayDeque<>();
        // Register last so callbacks never observe uninitialized fields.
        this.observeChannel();
    }

    // Registers this wrapper as the channel observer.  If the channel is already OPEN, its OPEN
    // callback may have fired before registration, so direct sending is enabled up front.
    private void observeChannel() {
        RTCDataChannel channel = this.dataChannel;
        channel.registerObserver(this);
        // The native state is read outside the monitor.
        // NOTE(review): this assumes native calls may block on the callback thread, which itself takes
        // the monitor in checkOutboundQueue(). Confirm against the native threading model.
        if (channel.getState() == RTCDataChannelState.OPEN) {
            synchronized (this) {
                this.queueDrained = true;
            }
        }
    }

    /** Returns the channel label; still available after {@link #close()}. */
    public String getName() {
        return this.name;
    }

    /**
     * Returns the mapped channel state.
     *
     * @return OPEN or CLOSED, {@code null} for other native states (e.g. connecting/closing), and
     *     CLOSED once {@link #close()} has been called
     */
    public DataChannelListener.State getState() {
        RTCDataChannel channel = this.dataChannel;
        return nonNull(channel) ? this.mapState(channel.getState()) : DataChannelListener.State.CLOSED;
    }

    public boolean addChannelListener(DataChannelListener listener) {
        return this.channelListeners.add(listener);
    }

    public boolean removeChannelListener(DataChannelListener listener) {
        return this.channelListeners.remove(listener);
    }

    // Notifies listeners of the current state; native states that map to null are not reported.
    private void fireChannelStatusEvent() {
        var state = this.getState();
        if (nonNull(state)) {
            for (DataChannelListener listener : this.channelListeners) {
                listener.onStateChange(this, state);
            }
        }
    }

    private void fireChannelMessageEvent(String message) {
        for (DataChannelListener listener : this.channelListeners) {
            listener.onMessage(this, message);
        }
    }

    // Maps the native state onto the listener-facing state; returns null for other states.
    private DataChannelListener.State mapState(RTCDataChannelState state) {
        if (state == RTCDataChannelState.OPEN) {
            return DataChannelListener.State.OPEN;
        }
        if (state == RTCDataChannelState.CLOSED) {
            return DataChannelListener.State.CLOSED;
        }
        return null;
    }

    /**
     * Sends a UTF-8 text message on the channel.
     *
     * <p>While the channel is connecting, or open but the queued backlog has not been flushed yet, the
     * message is queued so that messages leave in submission order. On a closing or closed channel
     * (including after {@link #close()}), the message is dropped and an error is logged.
     *
     * @param s the message text
     * @throws Exception if the native channel rejects the message
     */
    public void send(String s) throws Exception {
        logger.debug("datachannel {} sending {}", this.getName(), s);
        RTCDataChannel channel = this.dataChannel;
        RTCDataChannelState state = nonNull(channel) ? channel.getState() : RTCDataChannelState.CLOSED;
        if (state == RTCDataChannelState.OPEN || state == RTCDataChannelState.CONNECTING) {
            // The queue-or-send decision is made under the monitor (see queueMessage), so a message
            // can no longer slip into the queue after it has been flushed.  The native send itself
            // happens outside the monitor.
            if (!this.queueMessage(s)) {
                this.sendString(s);
            }
        }
        else {
            logger.error("datachannel {} not open for sending {}", this.getName(), state);
        }
    }

    // Queues s unless the backlog has already been flushed after OPEN.
    // Returns true if queued, false if the caller must send directly.
    private synchronized boolean queueMessage(String s) {
        if (this.queueDrained) {
            return false;
        }
        this.outboundQueue.add(s);
        logger.debug("Queuing outbound message {} {}", this.getName(), this.outboundQueue.size());
        return true;
    }

    // Encodes s as UTF-8 and hands it to the native channel as a text (non-binary) buffer.
    private void sendString(String s) throws Exception {
        RTCDataChannel channel = this.dataChannel;
        if (channel == null) {
            throw new IllegalStateException("datachannel " + this.getName() + " is closed");
        }
        logger.debug("datachannel {} sending string {}", this.getName(), s);
        ByteBuffer data = ByteBuffer.wrap(s.getBytes(StandardCharsets.UTF_8));
        RTCDataChannelBuffer buffer = new RTCDataChannelBuffer(data, false);
        channel.send(buffer);
    }

    // Flushes the queued backlog in order and switches to direct sending.  Each message is removed
    // before it is sent, so a failure never causes earlier messages to be re-sent.
    private synchronized void checkOutboundQueue() throws Exception {
        this.queueDrained = true;
        if (this.outboundQueue.isEmpty()) {
            return;
        }
        logger.debug("Sending queued outbound message {} {}", this.getName(), this.outboundQueue.size());
        try {
            String next;
            while ((next = this.outboundQueue.poll()) != null) {
                this.sendString(next);
            }
        } finally {
            // After a failed send, nothing would ever flush the remainder; drop it explicitly.
            if (!this.outboundQueue.isEmpty()) {
                logger.error("datachannel {} discarding {} queued messages", this.getName(), this.outboundQueue.size());
                this.outboundQueue.clear();
            }
        }
    }

    /** Closes, unregisters from and disposes the native channel; later calls are no-ops. */
    public void close() {
        RTCDataChannel channel = this.dataChannel;
        if (nonNull(channel)) {
            logger.debug("Closing channel {}", getName());
            channel.close();
            channel.unregisterObserver();
            channel.dispose();
            this.dataChannel = null;
        }
    }

    @Override
    public void onBufferedAmountChange(long previousAmount) {
        logger.debug("onBufferedAmountChange {}", previousAmount);
    }

    @Override
    public void onStateChange() {
        // IMPORTANT NOTE:  if any callback from WebRTC native code throws exception, native code may break with weird errors
        //                  any callback must do try/catch to prevent exceptions bubbled up.
        try {
            RTCDataChannel channel = this.dataChannel;
            if (channel == null) {
                return;  // already closed via close(); nothing to report
            }
            logger.debug("onStateChange {} {} {}", this.getName(), channel.getId(), channel.getState().toString());
            if (this.getState() == DataChannelListener.State.OPEN) {
                this.checkOutboundQueue();
            }
            this.fireChannelStatusEvent();
        } catch (Exception ex) {
            logger.error("onStateChange", ex);
        }
    }

    @Override
    public void onMessage(RTCDataChannelBuffer rtcDataChannelBuffer) {
        try {
            String m = decodeMessage(rtcDataChannelBuffer);
            this.fireChannelMessageEvent(m);
        } catch (Exception ex) {
            logger.error("onMessage", ex);
        }
    }

    // Decodes the readable window [position, limit) of the buffer as UTF-8.  array() would expose the
    // whole backing array (ignoring arrayOffset/position/limit); duplicate() copies through an
    // independent view, so the original buffer's position is left untouched.
    private String decodeMessage(RTCDataChannelBuffer buffer) {
        ByteBuffer view = buffer.data.duplicate();
        byte[] payload = new byte[view.remaining()];
        view.get(payload);
        return new String(payload, StandardCharsets.UTF_8);
    }

}
