/*
 * Decompiled with CFR 0.152.
 */
package net.dreamlu.iot.mqtt.core.server.http.websocket;

import java.nio.ByteBuffer;
import net.dreamlu.iot.mqtt.codec.ByteBufferUtil;
import net.dreamlu.iot.mqtt.codec.MqttMessage;
import net.dreamlu.iot.mqtt.codec.WriteBuffer;
import net.dreamlu.iot.mqtt.core.server.MqttServerCreator;
import org.tio.core.ChannelContext;
import org.tio.core.Tio;
import org.tio.core.TioConfig;
import org.tio.core.intf.AioHandler;
import org.tio.core.intf.Packet;
import org.tio.http.common.HttpRequest;
import org.tio.http.common.HttpResponse;
import org.tio.websocket.common.WsRequest;
import org.tio.websocket.common.WsResponse;
import org.tio.websocket.server.handler.IWsMsgHandler;

public class MqttWsMsgHandler
implements IWsMsgHandler {
    private static final String MQTT_WS_MSG_BODY_KEY = "MQTT_WS_MSG_BODY_KEY";
    private final MqttServerCreator serverCreator;
    private final String[] supportedSubProtocols;
    private final AioHandler mqttServerAioHandler;

    public MqttWsMsgHandler(MqttServerCreator serverCreator, AioHandler aioHandler) {
        this(serverCreator, new String[]{"mqtt", "mqttv3.1", "mqttv3.1.1"}, aioHandler);
    }

    public MqttWsMsgHandler(MqttServerCreator serverCreator, String[] supportedSubProtocols, AioHandler aioHandler) {
        this.serverCreator = serverCreator;
        this.supportedSubProtocols = supportedSubProtocols;
        this.mqttServerAioHandler = aioHandler;
    }

    public String[] getSupportedSubProtocols() {
        return this.supportedSubProtocols;
    }

    public HttpResponse handshake(HttpRequest request, HttpResponse httpResponse, ChannelContext channelContext) {
        if (this.serverCreator.isWebsocketEnable()) {
            return httpResponse;
        }
        return null;
    }

    public void onAfterHandshaked(HttpRequest request, HttpResponse response, ChannelContext context) {
        WriteBuffer wsBody = (WriteBuffer)context.get(MQTT_WS_MSG_BODY_KEY);
        if (wsBody == null) {
            wsBody = new WriteBuffer();
            context.set(MQTT_WS_MSG_BODY_KEY, (Object)wsBody);
        }
    }

    public Object onBytes(WsRequest wsRequest, byte[] bytes, ChannelContext context) throws Exception {
        WriteBuffer wsBody = (WriteBuffer)context.get(MQTT_WS_MSG_BODY_KEY);
        ByteBuffer buffer = MqttWsMsgHandler.getMqttBody(wsBody, bytes);
        if (buffer == null) {
            return null;
        }
        while (buffer.hasRemaining()) {
            Packet packet = this.mqttServerAioHandler.decode(buffer, 0, 0, buffer.remaining(), context);
            if (packet == null) {
                int remaining = buffer.remaining();
                if (remaining > 0) {
                    byte[] data = new byte[remaining];
                    buffer.get(data);
                    wsBody.writeBytes(data);
                }
                return null;
            }
            this.mqttServerAioHandler.handler(packet, context);
        }
        return null;
    }

    public WsResponse encodeSubProtocol(Packet packet, TioConfig tioConfig, ChannelContext context) {
        if (packet instanceof MqttMessage) {
            ByteBuffer buffer = this.mqttServerAioHandler.encode(packet, null, context);
            return WsResponse.fromBytes((byte[])buffer.array());
        }
        return null;
    }

    public Object onClose(WsRequest wsRequest, byte[] bytes, ChannelContext context) {
        Tio.remove((ChannelContext)context, (String)"Mqtt websocket close.");
        return null;
    }

    public Object onText(WsRequest wsRequest, String text, ChannelContext context) {
        return null;
    }

    private static synchronized ByteBuffer getMqttBody(WriteBuffer wsBody, byte[] bytes) {
        wsBody.writeBytes(bytes);
        int length = wsBody.size();
        if (length < 2) {
            return null;
        }
        ByteBuffer buffer = wsBody.toBuffer();
        int mqttLength = MqttWsMsgHandler.getMqttLength(buffer) + 2;
        if (length < mqttLength) {
            return null;
        }
        wsBody.reset();
        buffer.rewind();
        return buffer;
    }

    private static int getMqttLength(ByteBuffer buffer) {
        short digit;
        ByteBufferUtil.skipBytes((ByteBuffer)buffer, (int)1);
        int remainingLength = 0;
        int multiplier = 1;
        int loops = 0;
        do {
            digit = ByteBufferUtil.readUnsignedByte((ByteBuffer)buffer);
            remainingLength += (digit & 0x7F) * multiplier;
            multiplier *= 128;
        } while ((digit & 0x80) != 0 && ++loops < 4);
        return remainingLength;
    }
}

