/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.apache.tomcat.websocket;

import static org.jboss.web.WebsocketsMessages.MESSAGES;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CoderResult;
import java.nio.charset.CodingErrorAction;

import javax.websocket.CloseReason;
import javax.websocket.CloseReason.CloseCodes;
import javax.websocket.MessageHandler;
import javax.websocket.PongMessage;

import org.apache.tomcat.util.ExceptionUtils;
import org.apache.tomcat.util.buf.Utf8Decoder;

/**
 * Takes the ServletInputStream, processes the WebSocket frames it contains and
 * extracts the messages. WebSocket Pings received will be responded to
 * automatically without any action required by the application.
 * <p>
 * This class only parses the bytes held in {@link #inputBuffer} between its
 * private read position and {@link #writePos}; it never reads from a stream
 * itself. Sub-classes are presumably expected to append data to the buffer,
 * advance {@link #writePos} and then call {@link #processInputBuffer()}.
 * <p>
 * Parsing is a three step state machine per frame: the two byte initial
 * header, the remaining header (extended length and mask) and finally the
 * payload. Each step may be suspended when the buffered data runs out and is
 * resumed on the next call to {@link #processInputBuffer()}.
 */
public abstract class WsFrameBase {

    // Reason phrase used when a 64-bit payload length has its most
    // significant bit set. It is a literal rather than an i18n message because
    // the message bundle is declared outside of this file.
    private static final String PAYLOAD_MSB_INVALID =
            "The most significant bit of a 64-bit payload length must be zero";

    // Connection level attributes
    protected final WsSession wsSession;
    protected final byte[] inputBuffer;

    // Attributes for control messages
    // Control messages can appear in the middle of other messages so need
    // separate attributes
    private final ByteBuffer controlBufferBinary = ByteBuffer.allocate(125);
    private final CharBuffer controlBufferText = CharBuffer.allocate(125);
    private final CharsetDecoder utf8DecoderControl = new Utf8Decoder().
            onMalformedInput(CodingErrorAction.REPORT).
            onUnmappableCharacter(CodingErrorAction.REPORT);

    // Attributes of the current message
    private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder().
            onMalformedInput(CodingErrorAction.REPORT).
            onUnmappableCharacter(CodingErrorAction.REPORT);
    private boolean continuationExpected = false;
    private boolean textMessage = false;
    private ByteBuffer messageBufferBinary;
    private CharBuffer messageBufferText;
    // Cache the message handler in force when the message starts so it is used
    // consistently for the entire message
    private MessageHandler binaryMsgHandler = null;
    private MessageHandler textMsgHandler = null;

    // Attributes of the current frame
    private boolean fin = false;
    private int rsv = 0;
    private byte opCode = 0;
    private final byte[] mask = new byte[4];
    private int maskIndex = 0;
    private long payloadLength = 0;
    private long payloadWritten = 0;

    // Attributes tracking state
    private State state = State.NEW_FRAME;
    private volatile boolean open = true;
    private int readPos = 0;
    protected int writePos = 0;

    /**
     * Creates a frame processor bound to the given session. The input buffer
     * is sized from {@code Constants.DEFAULT_BUFFER_SIZE} and the message
     * buffers from the maximum binary / text message buffer sizes currently
     * configured on the session.
     *
     * @param wsSession the session the received frames belong to
     */
    public WsFrameBase(WsSession wsSession) {
        this.wsSession = wsSession;
        inputBuffer = new byte[Constants.DEFAULT_BUFFER_SIZE];
        messageBufferBinary =
                ByteBuffer.allocate(wsSession.getMaxBinaryMessageBufferSize());
        messageBufferText =
                CharBuffer.allocate(wsSession.getMaxTextMessageBufferSize());
    }


    /**
     * Processes as many frames as the data currently held in the input buffer
     * allows. Returns as soon as one of the parsing steps reports that it
     * needs more data; the parser state is retained so processing resumes at
     * the same step on the next call.
     *
     * @throws IOException if a frame is received after a close frame or if
     *                     one of the parsing steps rejects the data (those
     *                     steps throw {@code WsIOException} carrying the
     *                     close reason)
     */
    protected void processInputBuffer() throws IOException {
        while (true) {
            wsSession.updateLastActive();

            if (state == State.NEW_FRAME) {
                if (!processInitialHeader()) {
                    break;
                }
                // If a close frame has been received, no further data should
                // have been seen
                if (!open) {
                    throw new IOException(MESSAGES.receivedFrameAfterClose());
                }
            }
            if (state == State.PARTIAL_HEADER) {
                if (!processRemainingHeader()) {
                    break;
                }
            }
            if (state == State.DATA) {
                if (!processData()) {
                    break;
                }
            }
        }
    }


    /**
     * Parses the first two bytes of a frame: FIN flag, reserved bits, opcode,
     * mask flag and the 7-bit payload length. For the first frame of a data
     * message this also selects the message type, resizes the message buffer
     * if the session limit changed and caches the message handler.
     *
     * @return <code>true</code> if sufficient data was present to process all
     *         of the initial header
     *
     * @throws IOException if the header violates the protocol (reserved bits
     *                     set, fragmented or unknown control frame, missing
     *                     continuation, unknown data opcode, unmasked frame
     *                     where a mask is required) or the session is closed
     */
    private boolean processInitialHeader() throws IOException {
        // Need at least two bytes of data to do this
        if (writePos - readPos < 2) {
            return false;
        }
        int b = inputBuffer[readPos++];
        fin = (b & 0x80) > 0;
        rsv = (b & 0x70) >>> 4;
        if (rsv != 0) {
            // Note extensions may use rsv bits but currently no extensions are
            // supported
            throw new WsIOException(new CloseReason(
                    CloseCodes.PROTOCOL_ERROR,
                    MESSAGES.unsupportedReservedBitsSet(Integer.valueOf(rsv))));
        }
        opCode = (byte) (b & 0x0F);
        if (Util.isControl(opCode)) {
            // Control frames must not be fragmented
            if (!fin) {
                throw new WsIOException(new CloseReason(
                        CloseCodes.PROTOCOL_ERROR,
                        MESSAGES.invalidFragmentedControlFrame()));
            }
            if (opCode != Constants.OPCODE_PING &&
                    opCode != Constants.OPCODE_PONG &&
                    opCode != Constants.OPCODE_CLOSE) {
                throw new WsIOException(new CloseReason(
                        CloseCodes.PROTOCOL_ERROR,
                        MESSAGES.invalidFrameOpcode(opCode)));
            }
        } else {
            if (continuationExpected) {
                // In the middle of a fragmented message only continuation
                // frames (or control frames, handled above) are permitted
                if (opCode != Constants.OPCODE_CONTINUATION) {
                    throw new WsIOException(new CloseReason(
                            CloseCodes.PROTOCOL_ERROR,
                            MESSAGES.noContinuationFrame()));
                }
            } else {
                try {
                    if (opCode == Constants.OPCODE_BINARY) {
                        // New binary message
                        textMessage = false;
                        int size = wsSession.getMaxBinaryMessageBufferSize();
                        if (size != messageBufferBinary.capacity()) {
                            messageBufferBinary = ByteBuffer.allocate(size);
                        }
                        binaryMsgHandler = wsSession.getBinaryMessageHandler();
                        textMsgHandler = null;
                    } else if (opCode == Constants.OPCODE_TEXT) {
                        // New text message
                        textMessage = true;
                        int size = wsSession.getMaxTextMessageBufferSize();
                        if (size != messageBufferText.capacity()) {
                            messageBufferText = CharBuffer.allocate(size);
                        }
                        binaryMsgHandler = null;
                        textMsgHandler = wsSession.getTextMessageHandler();
                    } else {
                        throw new WsIOException(new CloseReason(
                                CloseCodes.PROTOCOL_ERROR,
                                MESSAGES.invalidFrameOpcode(opCode)));
                    }
                } catch (IllegalStateException ise) {
                    // Thrown if the session is already closed
                    throw new WsIOException(new CloseReason(
                            CloseCodes.PROTOCOL_ERROR,
                            MESSAGES.sessionClosed()));
                }
            }
            continuationExpected = !fin;
        }
        b = inputBuffer[readPos++];
        // Client data must be masked
        if ((b & 0x80) == 0 && isMasked()) {
            throw new WsIOException(new CloseReason(
                    CloseCodes.PROTOCOL_ERROR,
                    MESSAGES.frameWithoutMask()));
        }
        // 7-bit length; 126 and 127 signal that an extended length follows
        payloadLength = b & 0x7F;
        state = State.PARTIAL_HEADER;
        return true;
    }


    /**
     * Indicates whether received frames carry a mask. When this returns
     * <code>true</code> a frame without the mask bit is rejected and four mask
     * bytes are read from the header and applied to the payload.
     *
     * @return <code>true</code> if received frames are expected to be masked
     */
    protected abstract boolean isMasked();


    /**
     * Parses the part of the header that follows the initial two bytes: the
     * optional 16 or 64 bit extended payload length and the optional four
     * byte mask.
     *
     * @return <code>true</code> if sufficient data was present to complete the
     *         processing of the header
     *
     * @throws IOException if the payload length is invalid (64-bit length
     *                     with the most significant bit set, or a control
     *                     frame payload larger than 125 bytes) or a control
     *                     frame is not final
     */
    private boolean processRemainingHeader() throws IOException {
        // Ignore the 2 bytes already read. 4 for the mask
        int headerLength = isMasked() ? 4 : 0;
        // Add additional bytes depending on length
        if (payloadLength == 126) {
            headerLength += 2;
        } else if (payloadLength == 127) {
            headerLength += 8;
        }
        if (writePos - readPos < headerLength) {
            return false;
        }
        // Calculate new payload length if necessary
        if (payloadLength == 126) {
            payloadLength = byteArrayToLong(inputBuffer, readPos, 2);
            readPos += 2;
        } else if (payloadLength == 127) {
            payloadLength = byteArrayToLong(inputBuffer, readPos, 8);
            readPos += 8;
            // The most significant bit of the 64-bit length is required to be
            // zero (RFC 6455, section 5.2). If it is set the value is negative
            // as a long. A negative length must be rejected here since the
            // payload processing would otherwise never reach the end of the
            // frame.
            if (payloadLength < 0) {
                throw new WsIOException(new CloseReason(
                        CloseCodes.PROTOCOL_ERROR, PAYLOAD_MSB_INVALID));
            }
        }
        if (Util.isControl(opCode)) {
            if (payloadLength > 125) {
                throw new WsIOException(new CloseReason(
                        CloseCodes.PROTOCOL_ERROR,
                        MESSAGES.controlFramePayloadTooLarge(Long.valueOf(payloadLength))));
            }
            if (!fin) {
                throw new WsIOException(new CloseReason(
                        CloseCodes.PROTOCOL_ERROR,
                        MESSAGES.controlFrameWithoutFin()));
            }
        }
        if (isMasked()) {
            System.arraycopy(inputBuffer, readPos, mask, 0, 4);
            readPos += 4;
        }
        state = State.DATA;
        return true;
    }


    /**
     * Dispatches the payload of the current frame to the control, text or
     * binary processor and afterwards makes sure the input buffer has room
     * for the rest of the payload.
     *
     * @return <code>true</code> if the frame was fully processed,
     *         <code>false</code> if more input data is required
     *
     * @throws IOException if the payload is rejected by the processor
     */
    private boolean processData() throws IOException {
        boolean result;
        if (Util.isControl(opCode)) {
            result = processDataControl();
        } else if (textMessage) {
            result = processDataText();
        } else {
            result = processDataBinary();
        }
        checkRoomPayload();
        return result;
    }


    /**
     * Processes the payload of a control frame. A close frame marks this
     * processor as closed and notifies the session, a ping is answered with a
     * pong carrying the same payload (if the session is still open) and a
     * pong is passed to the session's pong message handler, if any.
     *
     * @return <code>true</code> if the frame was fully processed,
     *         <code>false</code> if more input data is required
     *
     * @throws IOException if the close payload is invalid (exactly one byte
     *                     or a reason that is not valid UTF-8), the opcode is
     *                     unknown or sending the pong fails
     */
    private boolean processDataControl() throws IOException {
        if (!appendPayloadToMessage(controlBufferBinary)) {
            return false;
        }
        controlBufferBinary.flip();
        if (opCode == Constants.OPCODE_CLOSE) {
            open = false;
            String reason = null;
            int code = CloseCodes.NORMAL_CLOSURE.getCode();
            if (controlBufferBinary.remaining() == 1) {
                controlBufferBinary.clear();
                // Payload must be zero or greater than 2
                throw new WsIOException(new CloseReason(
                        CloseCodes.PROTOCOL_ERROR,
                        MESSAGES.invalidOneByteClose()));
            }
            if (controlBufferBinary.remaining() > 1) {
                // First two bytes are the close code, the rest (if any) is the
                // UTF-8 encoded reason phrase
                code = controlBufferBinary.getShort();
                if (controlBufferBinary.remaining() > 0) {
                    CoderResult cr = utf8DecoderControl.decode(
                            controlBufferBinary, controlBufferText, true);
                    if (cr.isError()) {
                        controlBufferBinary.clear();
                        controlBufferText.clear();
                        throw new WsIOException(new CloseReason(
                                CloseCodes.PROTOCOL_ERROR,
                                MESSAGES.invalidUtf8Close()));
                    }
                    // There will be no overflow as the output buffer is big
                    // enough. There will be no underflow as all the data is
                    // passed to the decoder in a single call.
                    controlBufferText.flip();
                    reason = controlBufferText.toString();
                }
            }
            wsSession.onClose(new CloseReason(Util.getCloseCode(code), reason));
        } else if (opCode == Constants.OPCODE_PING) {
            if (wsSession.isOpen()) {
                wsSession.getBasicRemote().sendPong(controlBufferBinary);
            }
        } else if (opCode == Constants.OPCODE_PONG) {
            MessageHandler.Whole<PongMessage> mhPong =
                    wsSession.getPongMessageHandler();
            if (mhPong != null) {
                try {
                    mhPong.onMessage(new WsPongMessage(controlBufferBinary));
                } catch (Throwable t) {
                    // Application errors are reported to the endpoint rather
                    // than terminating frame processing
                    ExceptionUtils.handleThrowable(t);
                    wsSession.getLocal().onError(wsSession, t);
                } finally {
                    controlBufferBinary.clear();
                }
            }
        } else {
            // Should have caught this earlier but just in case...
            controlBufferBinary.clear();
            throw new WsIOException(new CloseReason(
                    CloseCodes.PROTOCOL_ERROR,
                    MESSAGES.invalidFrameOpcode(Integer.valueOf(opCode))));
        }
        controlBufferBinary.clear();
        newFrame();
        return true;
    }


    /**
     * Passes the current (already flipped) content of the text message buffer
     * to the cached text message handler, if there is one. Exceptions thrown
     * by the handler are reported to the endpoint's error callback. The text
     * buffer is cleared once the handler has been called.
     *
     * @param last <code>true</code> if this is the final part of the message
     *
     * @throws WsIOException if the handler is a {@code WrappedMessageHandler}
     *                       with a maximum message size that the buffered
     *                       text exceeds
     */
    @SuppressWarnings("unchecked")
    private void sendMessageText(boolean last) throws WsIOException {
        if (textMsgHandler == null) {
            return;
        }

        if (textMsgHandler instanceof WrappedMessageHandler) {
            long maxMessageSize =
                    ((WrappedMessageHandler) textMsgHandler).getMaxMessageSize();
            if (maxMessageSize > -1 &&
                    messageBufferText.remaining() > maxMessageSize) {
                throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG,
                        MESSAGES.messageTooLarge(Long.valueOf(messageBufferText.remaining()),
                                Long.valueOf(maxMessageSize))));
            }
        }

        try {
            if (textMsgHandler instanceof MessageHandler.Partial<?>) {
                ((MessageHandler.Partial<String>) textMsgHandler).onMessage(
                        messageBufferText.toString(), last);
            } else {
                // Caller ensures last == true if this branch is used
                ((MessageHandler.Whole<String>) textMsgHandler).onMessage(
                        messageBufferText.toString());
            }
        } catch (Throwable t) {
            ExceptionUtils.handleThrowable(t);
            wsSession.getLocal().onError(wsSession, t);
        } finally {
            messageBufferText.clear();
        }
    }


    /**
     * Delivers the text decoded so far as a non-final part of the message and
     * empties the text buffer so decoding can continue.
     *
     * @throws WsIOException if the decoded text exceeds the handler's limit
     */
    private void flushMessageText() throws WsIOException {
        messageBufferText.flip();
        sendMessageText(false);
        messageBufferText.clear();
    }


    /**
     * Processes the payload of a text (or text continuation) frame. Payload
     * bytes are unmasked into the binary message buffer and decoded as UTF-8
     * into the text message buffer. When the text buffer fills up the content
     * is delivered as a partial message if the handler supports that,
     * otherwise the message is rejected as too big.
     *
     * @return <code>true</code> if the frame was fully processed,
     *         <code>false</code> if more input data is required
     *
     * @throws IOException if the payload is not valid UTF-8 or the message is
     *                     too large
     */
    private boolean processDataText() throws IOException {
        // Copy the available data to the buffer
        while (!appendPayloadToMessage(messageBufferBinary)) {
            // Frame not complete - we ran out of something
            // Convert bytes to UTF-8
            messageBufferBinary.flip();
            while (true) {
                CoderResult cr = utf8DecoderMessage.decode(
                        messageBufferBinary, messageBufferText, false);
                if (cr.isError()) {
                    throw new WsIOException(new CloseReason(
                            CloseCodes.NOT_CONSISTENT,
                            MESSAGES.invalidUtf8()));
                } else if (cr.isOverflow()) {
                    // Ran out of space in text buffer - flush it
                    if (usePartial()) {
                        flushMessageText();
                    } else {
                        throw new WsIOException(new CloseReason(
                                CloseCodes.TOO_BIG,
                                MESSAGES.textMessageTooLarge()));
                    }
                } else if (cr.isUnderflow()) {
                    // Need more input
                    // Compact what we have to create as much space as possible
                    messageBufferBinary.compact();

                    // What did we run out of?
                    if (readPos == writePos) {
                        // Ran out of input data - get some more
                        return false;
                    } else {
                        // Ran out of message buffer - exit inner loop and
                        // refill
                        break;
                    }
                }
            }
        }

        messageBufferBinary.flip();
        boolean last = false;
        // Frame is fully received
        // Convert bytes to UTF-8
        while (true) {
            CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary,
                    messageBufferText, last);
            if (cr.isError()) {
                throw new WsIOException(new CloseReason(
                        CloseCodes.NOT_CONSISTENT,
                        MESSAGES.invalidUtf8()));
            } else if (cr.isOverflow()) {
                // Ran out of space in text buffer - flush it
                if (usePartial()) {
                    flushMessageText();
                } else {
                    throw new WsIOException(new CloseReason(
                            CloseCodes.TOO_BIG,
                            MESSAGES.textMessageTooLarge()));
                }
            } else if (cr.isUnderflow() && !last) {
                // End of frame and possible message as well.

                if (continuationExpected) {
                    // If partial messages are supported, send what we have
                    // managed to decode
                    if (usePartial()) {
                        flushMessageText();
                    }
                    // Keep any bytes of an incomplete character for the next
                    // frame
                    messageBufferBinary.compact();
                    newFrame();
                    // Process next frame
                    return true;
                } else {
                    // Make sure coder has flushed all output
                    last = true;
                }
            } else {
                // End of message
                messageBufferText.flip();
                sendMessageText(true);

                newMessage();
                return true;
            }
        }
    }


    /**
     * Processes the payload of a binary (or binary continuation) frame.
     * Payload bytes are unmasked into the binary message buffer. When the
     * buffer fills up the content is delivered as a partial message if the
     * handler supports that, otherwise the message is rejected as too big.
     *
     * @return <code>true</code> if the frame was fully processed,
     *         <code>false</code> if more input data is required
     *
     * @throws IOException if the message does not fit into the message buffer
     *                     and partial messages are not supported, or the
     *                     handler's size limit is exceeded
     */
    private boolean processDataBinary() throws IOException {
        // Copy the available data to the buffer
        while (!appendPayloadToMessage(messageBufferBinary)) {
            // Frame not complete - what did we run out of?
            if (readPos == writePos) {
                // Ran out of input data - get some more
                return false;
            } else {
                // Ran out of message buffer - flush it
                if (!usePartial()) {
                    CloseReason cr = new CloseReason(CloseCodes.TOO_BIG,
                            MESSAGES.bufferTooSmall(Integer.valueOf(messageBufferBinary.capacity()),
                                    Long.valueOf(payloadLength)));
                    throw new WsIOException(cr);
                }
                flushMessageBinary(false);
            }
        }

        // Frame is fully received
        // Send the message if either:
        // - partial messages are supported
        // - the message is complete
        if (usePartial() || !continuationExpected) {
            flushMessageBinary(!continuationExpected);
        }

        if (continuationExpected) {
            // More data for this message expected, start a new frame
            newFrame();
        } else {
            // Message is complete, start a new message
            newMessage();
        }

        return true;
    }


    /**
     * Delivers a copy of the bytes currently held in the binary message
     * buffer and then empties the buffer. A copy is handed over so the
     * handler never sees the buffer this class keeps re-using.
     *
     * @param last <code>true</code> if this is the final part of the message
     *
     * @throws WsIOException if the data exceeds the handler's limit
     */
    private void flushMessageBinary(boolean last) throws WsIOException {
        messageBufferBinary.flip();
        ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit());
        copy.put(messageBufferBinary);
        copy.flip();
        sendMessageBinary(copy, last);
        messageBufferBinary.clear();
    }


    /**
     * Passes the given data to the cached binary message handler, if there is
     * one. Exceptions thrown by the handler are reported to the endpoint's
     * error callback.
     *
     * @param msg  the data to deliver, ready for reading
     * @param last <code>true</code> if this is the final part of the message
     *
     * @throws WsIOException if the handler is a {@code WrappedMessageHandler}
     *                       with a maximum message size that the data exceeds
     */
    @SuppressWarnings("unchecked")
    private void sendMessageBinary(ByteBuffer msg, boolean last)
            throws WsIOException {
        if (binaryMsgHandler == null) {
            return;
        }

        if (binaryMsgHandler instanceof WrappedMessageHandler) {
            long maxMessageSize =
                    ((WrappedMessageHandler) binaryMsgHandler).getMaxMessageSize();
            if (maxMessageSize > -1 && msg.remaining() > maxMessageSize) {
                throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG,
                        MESSAGES.messageTooLarge(
                                Long.valueOf(msg.remaining()),
                                Long.valueOf(maxMessageSize))));
            }
        }
        try {
            if (binaryMsgHandler instanceof MessageHandler.Partial<?>) {
                ((MessageHandler.Partial<ByteBuffer>) binaryMsgHandler).onMessage(msg, last);
            } else {
                // Caller ensures last == true if this branch is used
                ((MessageHandler.Whole<ByteBuffer>) binaryMsgHandler).onMessage(msg);
            }
        } catch (Throwable t) {
            ExceptionUtils.handleThrowable(t);
            wsSession.getLocal().onError(wsSession, t);
        }
    }


    /**
     * Resets the per-message state (message buffers, UTF-8 decoder and
     * continuation flag) and then the per-frame state.
     */
    private void newMessage() {
        messageBufferBinary.clear();
        messageBufferText.clear();
        utf8DecoderMessage.reset();
        continuationExpected = false;
        newFrame();
    }


    /**
     * Resets the per-frame state so the next frame header can be parsed. If
     * all buffered input has been consumed the input buffer positions are
     * rewound to the start of the buffer.
     */
    private void newFrame() {
        if (readPos == writePos) {
            readPos = 0;
            writePos = 0;
        }

        maskIndex = 0;
        payloadWritten = 0;
        state = State.NEW_FRAME;

        // fin, rsv, opCode and payloadLength are overwritten by
        // processInitialHeader(); mask is overwritten by
        // processRemainingHeader()

        checkRoomHeaders();
    }


    /**
     * Moves unread data to the start of the input buffer if fewer than 131
     * bytes remain between the read position and the end of the buffer.
     */
    private void checkRoomHeaders() {
        // Is the start of the current frame too near the end of the input
        // buffer?
        if (inputBuffer.length - readPos < 131) {
            // Limit based on a control frame with a full payload
            makeRoom();
        }
    }


    /**
     * Moves unread data to the start of the input buffer if the space between
     * the read position and the end of the buffer is smaller than the part of
     * the current payload that has not been processed yet.
     */
    private void checkRoomPayload() {
        if (inputBuffer.length - readPos - payloadLength + payloadWritten < 0) {
            makeRoom();
        }
    }


    /**
     * Shifts the unread bytes to the start of the input buffer and adjusts
     * the read and write positions accordingly.
     */
    private void makeRoom() {
        int unread = writePos - readPos;
        System.arraycopy(inputBuffer, readPos, inputBuffer, 0, unread);
        writePos = unread;
        readPos = 0;
    }


    /**
     * @return <code>true</code> if the current frame is a data frame and the
     *         handler cached for the current message type accepts partial
     *         messages
     */
    private boolean usePartial() {
        if (Util.isControl(opCode)) {
            return false;
        }
        // instanceof is false for null so a missing handler yields false
        MessageHandler handler = textMessage ? textMsgHandler : binaryMsgHandler;
        return handler instanceof MessageHandler.Partial<?>;
    }


    /**
     * Copies payload bytes of the current frame from the input buffer to the
     * given destination, removing the mask if frames are masked. Copying
     * stops at the end of the payload, at the end of the buffered input or
     * when the destination is full, whichever comes first.
     *
     * @param dest the buffer to append the payload to
     *
     * @return <code>true</code> if the complete payload of the current frame
     *         has now been copied, otherwise <code>false</code>
     */
    private boolean appendPayloadToMessage(ByteBuffer dest) {
        if (isMasked()) {
            while (payloadWritten < payloadLength && readPos < writePos &&
                    dest.hasRemaining()) {
                byte b = (byte) (inputBuffer[readPos] ^ mask[maskIndex]);
                // The mask is applied cyclically, 4 bytes at a time
                maskIndex++;
                if (maskIndex == 4) {
                    maskIndex = 0;
                }
                readPos++;
                payloadWritten++;
                dest.put(b);
            }
        } else {
            // Bounded by the buffered input, so the value always fits an int
            int toWrite = (int) Math.min(
                    payloadLength - payloadWritten, writePos - readPos);
            toWrite = Math.min(toWrite, dest.remaining());

            dest.put(inputBuffer, readPos, toWrite);
            readPos += toWrite;
            payloadWritten += toWrite;
        }
        return payloadWritten == payloadLength;
    }


    /**
     * Converts up to eight bytes, most significant byte first, to a
     * <code>long</code>. If all eight bytes are used and the first has its
     * high bit set the result is negative.
     *
     * @param b     the array holding the bytes
     * @param start index of the first (most significant) byte
     * @param len   number of bytes to convert
     *
     * @return the value represented by the bytes
     *
     * @throws IOException if more than eight bytes are requested
     */
    protected static long byteArrayToLong(byte[] b, int start, int len)
            throws IOException {
        if (len > 8) {
            throw new IOException(MESSAGES.invalidLong(Long.valueOf(len)));
        }
        long result = 0;
        for (int i = start; i < start + len; i++) {
            // Mask with a long literal so the arithmetic is done in 64 bits;
            // an int shift would wrap for the upper four bytes
            result = (result << 8) | (b[i] & 0xFFL);
        }
        return result;
    }


    /**
     * @return <code>false</code> once a close frame has been processed,
     *         otherwise <code>true</code>
     */
    protected boolean isOpen() {
        return open;
    }


    /**
     * Parsing position within the current frame.
     */
    private enum State {
        /** Waiting for the two byte initial header. */
        NEW_FRAME,
        /** Waiting for the extended length and / or the mask. */
        PARTIAL_HEADER,
        /** Header complete, processing the payload. */
        DATA
    }
}
