/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 */
package org.apache.qpid.server.protocol.v1_0;

import java.io.IOException;
import java.io.PrintWriter;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.security.PrivilegedAction;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.Subject;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.qpid.amqp_1_0.codec.FrameWriter;
import org.apache.qpid.amqp_1_0.codec.ProtocolHandler;
import org.apache.qpid.amqp_1_0.framing.AMQFrame;
import org.apache.qpid.amqp_1_0.framing.OversizeFrameException;
import org.apache.qpid.amqp_1_0.framing.SASLFrameHandler;
import org.apache.qpid.amqp_1_0.transport.SaslServerProvider;
import org.apache.qpid.amqp_1_0.transport.ConnectionEndpoint;
import org.apache.qpid.amqp_1_0.transport.Container;
import org.apache.qpid.amqp_1_0.transport.FrameOutputHandler;
import org.apache.qpid.amqp_1_0.type.Binary;
import org.apache.qpid.amqp_1_0.type.FrameBody;
import org.apache.qpid.amqp_1_0.type.Symbol;
import org.apache.qpid.amqp_1_0.type.transport.*;
import org.apache.qpid.amqp_1_0.type.transport.Error;
import org.apache.qpid.common.QpidProperties;
import org.apache.qpid.common.ServerPropertyNames;
import org.apache.qpid.protocol.ServerProtocolEngine;
import org.apache.qpid.server.model.Broker;
import org.apache.qpid.server.model.Port;
import org.apache.qpid.server.model.Transport;
import org.apache.qpid.server.security.SubjectCreator;
import org.apache.qpid.server.security.auth.UsernamePrincipal;
import org.apache.qpid.server.util.ServerScopedRuntimeException;
import org.apache.qpid.server.virtualhost.VirtualHost;
import org.apache.qpid.transport.Sender;
import org.apache.qpid.transport.TransportException;
import org.apache.qpid.transport.network.NetworkConnection;

/**
 * Server-side AMQP 1.0 protocol engine for connections that start with a SASL exchange.
 * <p>
 * When attached to a network connection (see {@link #setNetworkConnection}) it does three things:
 * <ul>
 *   <li>writes {@link #HEADER} to the peer;</li>
 *   <li>starts SASL with the mechanisms offered by the broker's {@link SubjectCreator};</li>
 *   <li>once SASL completes, writes {@link #PROTOCOL_HEADER} if the peer authenticated, or
 *       closes the network connection if it did not.</li>
 * </ul>
 * Incoming bytes first pass through an 8-byte header state machine. Everything after that is
 * handed to the current {@link ProtocolHandler}.
 * <p>
 * {@link #received(ByteBuffer)} is {@code synchronized}. Outgoing frames are serialised on a
 * private lock in {@link #send(AMQFrame, ByteBuffer)}.
 */
public class ProtocolEngine_1_0_0_SASL implements ServerProtocolEngine, FrameOutputHandler
{
    private static final org.apache.log4j.Logger
            _logger = org.apache.log4j.Logger.getLogger(ProtocolEngine_1_0_0_SASL.class);

    /** Logs the raw bytes received and sent, at FINE level. */
    private static final Logger RAW_LOGGER = Logger.getLogger("RAW");

    /** Logs each outgoing frame (channel and body), at FINE level. */
    private static final Logger FRAME_LOGGER = Logger.getLogger("FRM");

    private final Port _port;
    private final Transport _transport;
    private long _readBytes;
    private long _writtenBytes;

    private long _lastReadTime;
    private long _lastWriteTime;
    private final Broker _broker;
    private final long _createTime = System.currentTimeMillis();
    private ConnectionEndpoint _endpoint;
    private final long _connectionId;

    /**
     * Header "AMQP" 3.1.0.0, written to the peer before SASL negotiation starts.
     * Protocol id 3 identifies the SASL security layer in AMQP 1.0.
     * Always send a {@code duplicate()} so this shared buffer's position is never consumed.
     */
    private static final ByteBuffer HEADER =
           ByteBuffer.wrap(new byte[]
                   {
                       (byte)'A',
                       (byte)'M',
                       (byte)'Q',
                       (byte)'P',
                       (byte) 3,
                       (byte) 1,
                       (byte) 0,
                       (byte) 0
                   });

    /**
     * Header "AMQP" 0.1.0.0, written to the peer once SASL has completed and the peer is
     * authenticated. Always send a {@code duplicate()} of it.
     */
    private static final ByteBuffer PROTOCOL_HEADER =
        ByteBuffer.wrap(new byte[]
                {
                    (byte)'A',
                    (byte)'M',
                    (byte)'Q',
                    (byte)'P',
                    (byte) 0,
                    (byte) 1,
                    (byte) 0,
                    (byte) 0
                });


    private FrameWriter _frameWriter;
    private ProtocolHandler _frameHandler;
    /** Guards frame encoding and writing in {@link #send(AMQFrame, ByteBuffer)}. */
    private final Object _sendLock = new Object();
    // Version bytes taken from the peer's header; recorded but not otherwise used in this class.
    private byte _major;
    private byte _minor;
    private byte _revision;
    private PrintWriter _out;
    private NetworkConnection _network;
    private Sender<ByteBuffer> _sender;
    private Connection_1_0 _connection;


    /**
     * Position within the incoming 8-byte protocol header.
     * The first five states consume "AMQP" and the protocol-id byte. MAJOR, MINOR and REVISION
     * record the version bytes. FRAME means the header is complete and bytes go to the frame
     * handler.
     */
    static enum State {
           A,
           M,
           Q,
           P,
           PROTOCOL,
           MAJOR,
           MINOR,
           REVISION,
           FRAME
       }

    private State _state = State.A;


    /**
     * Creates the engine.
     *
     * @param networkDriver the underlying connection. If {@code null}, {@link #setNetworkConnection}
     *                      must be called before the engine is used.
     * @param broker        the broker that owns the connection
     * @param id            the connection id
     * @param port          the port the connection arrived on
     * @param transport     the transport in use on that port
     */
    public ProtocolEngine_1_0_0_SASL(final NetworkConnection networkDriver, final Broker broker,
                                     long id, Port port, Transport transport)
    {
        _connectionId = id;
        _broker = broker;
        _port = port;
        _transport = transport;
        if(networkDriver != null)
        {
            setNetworkConnection(networkDriver, networkDriver.getSender());
        }
    }


    /** Returns the peer's address. Only valid after {@link #setNetworkConnection} has been called. */
    public SocketAddress getRemoteAddress()
    {
        return _network.getRemoteAddress();
    }

    /** Returns the local address. Only valid after {@link #setNetworkConnection} has been called. */
    public SocketAddress getLocalAddress()
    {
        return _network.getLocalAddress();
    }

    /** Returns the total number of bytes passed to {@link #received(ByteBuffer)}. */
    public long getReadBytes()
    {
        return _readBytes;
    }

    /** Returns the total number of encoded frame bytes handed to the sender. Header bytes are not counted. */
    public long getWrittenBytes()
    {
        return _writtenBytes;
    }

    public void writerIdle()
    {
        //Todo
    }

    public void readerIdle()
    {
        //Todo
    }

    /**
     * Attaches the engine to a network connection and starts SASL negotiation.
     * <p>
     * This method builds the {@link ConnectionEndpoint} and advertises the server properties
     * (product, version, build, instance name). It then wires up the {@link Connection_1_0},
     * writes {@link #HEADER}, and offers the mechanisms returned by
     * {@link SubjectCreator#getMechanisms()}, split on single spaces.
     *
     * @param network the network connection
     * @param sender  the sender used for all outgoing bytes
     */
    public void setNetworkConnection(final NetworkConnection network, final Sender<ByteBuffer> sender)
    {
        _network = network;
        _sender = sender;

        Container container = new Container(_broker.getId().toString());

        SubjectCreator subjectCreator = _broker.getSubjectCreator(getLocalAddress());
        _endpoint = new ConnectionEndpoint(container, asSaslServerProvider(subjectCreator));

        Map<Symbol,Object> serverProperties = new LinkedHashMap<Symbol, Object>();
        serverProperties.put(Symbol.valueOf(ServerPropertyNames.PRODUCT), QpidProperties.getProductName());
        serverProperties.put(Symbol.valueOf(ServerPropertyNames.VERSION), QpidProperties.getReleaseVersion());
        serverProperties.put(Symbol.valueOf(ServerPropertyNames.QPID_BUILD), QpidProperties.getBuildVersion());
        serverProperties.put(Symbol.valueOf(ServerPropertyNames.QPID_INSTANCE_NAME), _broker.getName());

        _endpoint.setProperties(serverProperties);

        _endpoint.setRemoteAddress(getRemoteAddress());
        _connection = new Connection_1_0(_broker, _endpoint, _connectionId, _port, _transport, subjectCreator);
        _endpoint.setConnectionEventListener(_connection);
        _endpoint.setFrameOutputHandler(this);
        _endpoint.setSaslFrameOutput(this);

        // After SASL: on success, move the peer on to the AMQP layer; on failure, drop the connection.
        _endpoint.setOnSaslComplete(new Runnable()
        {
            public void run()
            {
                if (_endpoint.isAuthenticated())
                {
                    _sender.send(PROTOCOL_HEADER.duplicate());
                    _sender.flush();
                }
                else
                {
                    _network.close();
                }
            }
        });
        _frameWriter =  new FrameWriter(_endpoint.getDescribedTypeRegistry());
        _frameHandler = new SASLFrameHandler(_endpoint);

        _sender.send(HEADER.duplicate());
        _sender.flush();

        _endpoint.initiateSASL(subjectCreator.getMechanisms().split(" "));


    }

    /**
     * Adapts a {@link SubjectCreator} to the {@link SaslServerProvider} interface expected by the
     * endpoint. The authenticated principal is a {@link UsernamePrincipal} built from the SASL
     * server's authorization id.
     */
    private SaslServerProvider asSaslServerProvider(final SubjectCreator subjectCreator)
    {
        return new SaslServerProvider()
        {
            @Override
            public SaslServer getSaslServer(String mechanism, String fqdn) throws SaslException
            {
                return subjectCreator.createSaslServer(mechanism, fqdn, _network.getPeerPrincipal());
            }

            @Override
            public Principal getAuthenticatedPrincipal(SaslServer server)
            {
                return new UsernamePrincipal(server.getAuthorizationID());
            }
        };
    }

    /** Returns the remote address as a string. */
    public String getAddress()
    {
        return getRemoteAddress().toString();
    }

    public boolean isDurable()
    {
        return false;
    }


    /**
     * Consumes bytes from the peer.
     * <p>
     * Until 8 header bytes have been seen, the bytes are consumed by a fall-through state
     * machine:
     * <ul>
     *   <li>each {@code case} consumes one byte and falls into the next case;</li>
     *   <li>if the buffer runs out, the current state is saved and the switch exits, so the next
     *       call resumes at the same point;</li>
     *   <li>once the header is complete, the remaining bytes (in this and later calls) are parsed
     *       by the frame handler as the connection's {@link Subject}.</li>
     * </ul>
     * Runtime exceptions are routed to {@link #exception(Throwable)}.
     * <p>
     * NOTE(review): the header bytes are consumed but not compared with the expected values.
     * This presumably relies on the caller having selected this engine by header already;
     * confirm before adding validation.
     *
     * @param msg the received bytes; its position is advanced as bytes are consumed
     */
    public synchronized void received(final ByteBuffer msg)
    {
        try
        {
            _lastReadTime = System.currentTimeMillis();
            if(RAW_LOGGER.isLoggable(Level.FINE))
            {
                ByteBuffer dup = msg.duplicate();
                byte[] data = new byte[dup.remaining()];
                dup.get(data);
                Binary bin = new Binary(data);
                RAW_LOGGER.fine("RECV[" + getRemoteAddress() + "] : " + bin.toString());
            }
            _readBytes += msg.remaining();
            // Intentional fall-through between cases: see method javadoc.
            switch(_state)
            {
                case A:
                    if (msg.hasRemaining())
                    {
                        msg.get();
                    }
                    else
                    {
                        break;
                    }
                case M:
                    if (msg.hasRemaining())
                    {
                        msg.get();
                    }
                    else
                    {
                        _state = State.M;
                        break;
                    }

                case Q:
                    if (msg.hasRemaining())
                    {
                        msg.get();
                    }
                    else
                    {
                        _state = State.Q;
                        break;
                    }
                case P:
                    if (msg.hasRemaining())
                    {
                        msg.get();
                    }
                    else
                    {
                        _state = State.P;
                        break;
                    }
                case PROTOCOL:
                    if (msg.hasRemaining())
                    {
                        msg.get();
                    }
                    else
                    {
                        _state = State.PROTOCOL;
                        break;
                    }
                case MAJOR:
                    if (msg.hasRemaining())
                    {
                        _major = msg.get();
                    }
                    else
                    {
                        _state = State.MAJOR;
                        break;
                    }
                case MINOR:
                    if (msg.hasRemaining())
                    {
                        _minor = msg.get();
                    }
                    else
                    {
                        _state = State.MINOR;
                        break;
                    }
                case REVISION:
                    if (msg.hasRemaining())
                    {
                        _revision = msg.get();

                        _state = State.FRAME;
                    }
                    else
                    {
                        _state = State.REVISION;
                        break;
                    }
                case FRAME:
                    if (msg.hasRemaining())
                    {
                        // Parse as the connection's subject. The handler may replace itself,
                        // e.g. when moving from SASL frames to AMQP frames.
                        Subject.doAs(_connection.getSubject(), new PrivilegedAction<Void>()
                        {
                            @Override
                            public Void run()
                            {
                                _frameHandler = _frameHandler.parse(msg);
                                return null;
                            }
                        });

                    }
            }
        }
        catch(RuntimeException e)
        {
            exception(e);
        }
     }

    /**
     * Handles a failure on this connection.
     * <p>
     * An {@link IOException} is only logged. For any other throwable:
     * <ul>
     *   <li>the endpoint is closed with an {@code amqp:internal-error} carrying the throwable's
     *       message, and the sender is closed (a {@link TransportException} during this cleanup
     *       is logged);</li>
     *   <li>afterwards, {@link java.lang.Error}s and {@link ServerScopedRuntimeException}s are
     *       rethrown.</li>
     * </ul>
     *
     * @param throwable the failure
     */
    public void exception(Throwable throwable)
    {
        if (throwable instanceof IOException)
        {
            _logger.info("IOException caught in " + this + ", connection closed implicitly: " + throwable);
        }
        else
        {

            try
            {
                final Error err = new Error();
                err.setCondition(AmqpError.INTERNAL_ERROR);
                err.setDescription(throwable.getMessage());
                _endpoint.close(err);
                close();
            }
            catch(TransportException e)
            {
                _logger.info("Error when handling exception",e);
            }
            finally
            {
                if(throwable instanceof java.lang.Error)
                {
                    throw (java.lang.Error) throwable;
                }
                if(throwable instanceof ServerScopedRuntimeException)
                {
                    throw (ServerScopedRuntimeException) throwable;
                }
            }
        }
    }

    /**
     * Notifies the endpoint and its {@link Connection_1_0} listener that the input side of the
     * network connection has closed.
     * <p>
     * This is a no-op if the engine was never attached to a network connection (no endpoint).
     */
    public void closed()
    {
        try
        {
            // todo
            // Guard before dereferencing. The endpoint is only created in setNetworkConnection,
            // and calling inputClosed() first would NPE on an unattached engine.
            if (_endpoint != null)
            {
                _endpoint.inputClosed();
                if (_endpoint.getConnectionEventListener() != null)
                {
                    ((Connection_1_0) _endpoint.getConnectionEventListener()).closed();
                }
            }
        }
        catch(RuntimeException e)
        {
            exception(e);
        }
    }

    public long getCreateTime()
    {
        return _createTime;
    }


    public boolean canSend()
    {
        return true;
    }

    /** Encodes and writes a frame. Equivalent to {@code send(amqFrame, null)}. */
    public void send(final AMQFrame amqFrame)
    {
        send(amqFrame, null);
    }


    /**
     * Encodes a frame and writes it to the sender, then flushes.
     * <p>
     * Encoding and writing are serialised on {@code _sendLock}.
     *
     * @param amqFrame the frame to send
     * @param buf      currently unused
     * @throws OversizeFrameException if the encoded frame exceeds the endpoint's max frame size
     */
    public void send(final AMQFrame amqFrame, ByteBuffer buf)
    {

        synchronized (_sendLock)
        {
            _lastWriteTime = System.currentTimeMillis();
            if (FRAME_LOGGER.isLoggable(Level.FINE))
            {
                FRAME_LOGGER.fine("SEND[" + getRemoteAddress() + "|" + amqFrame.getChannel() + "] : " + amqFrame.getFrameBody());
            }

            _frameWriter.setValue(amqFrame);

            ByteBuffer dup = ByteBuffer.allocate(_endpoint.getMaxFrameSize());

            int size = _frameWriter.writeToBuffer(dup);
            if (size > _endpoint.getMaxFrameSize())
            {
                throw new OversizeFrameException(amqFrame, size);
            }

            dup.flip();
            _writtenBytes += dup.limit();

            if (RAW_LOGGER.isLoggable(Level.FINE))
            {
                ByteBuffer dup2 = dup.duplicate();
                byte[] data = new byte[dup2.remaining()];
                dup2.get(data);
                Binary bin = new Binary(data);
                RAW_LOGGER.fine("SEND[" + getRemoteAddress() + "] : " + bin.toString());
            }

            _sender.send(dup);
            _sender.flush();


        }
    }

    /** Wraps {@code body} in a frame on {@code channel} and sends it. */
    public void send(short channel, FrameBody body)
    {
        AMQFrame frame = AMQFrame.createAMQFrame(channel, body);
        send(frame);

    }

    /** Closes the underlying sender. */
    public void close()
    {
        _sender.close();
    }

    public void setLogOutput(final PrintWriter out)
    {
        _out = out;
    }

    public long getConnectionId()
    {
        return _connectionId;
    }

    @Override
    public Subject getSubject()
    {
        return _connection.getSubject();
    }

    public long getLastReadTime()
    {
        return _lastReadTime;
    }

    public long getLastWriteTime()
    {
        return _lastWriteTime;
    }
}
