001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataOutputStream;
021import java.io.EOFException;
022import java.io.IOException;
023import java.net.Socket;
024import java.net.URI;
025import java.net.UnknownHostException;
026import java.nio.ByteBuffer;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import javax.net.SocketFactory;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLEngine;
032import javax.net.ssl.SSLEngineResult;
033import javax.net.ssl.SSLParameters;
034
035import org.apache.activemq.thread.TaskRunnerFactory;
036import org.apache.activemq.util.IOExceptionSupport;
037import org.apache.activemq.util.ServiceStopper;
038import org.apache.activemq.wireformat.WireFormat;
039
040/**
041 * This transport initializes the SSLEngine and reads the first command before
042 * handing off to the detected transport.
043 *
044 */
045public class AutoInitNioSSLTransport extends NIOSSLTransport {
046
047    public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
048        super(wireFormat, socketFactory, remoteLocation, localLocation);
049    }
050
051    public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
052        super(wireFormat, socket, null, null, null);
053    }
054
055    @Override
056    public void setSslContext(SSLContext sslContext) {
057        this.sslContext = sslContext;
058    }
059
060    public ByteBuffer getInputBuffer() {
061        return this.inputBuffer;
062    }
063
064    @Override
065    protected void initializeStreams() throws IOException {
066        NIOOutputStream outputStream = null;
067        try {
068            channel = socket.getChannel();
069            channel.configureBlocking(false);
070
071            if (sslContext == null) {
072                sslContext = SSLContext.getDefault();
073            }
074
075            String remoteHost = null;
076            int remotePort = -1;
077
078            try {
079                URI remoteAddress = new URI(this.getRemoteAddress());
080                remoteHost = remoteAddress.getHost();
081                remotePort = remoteAddress.getPort();
082            } catch (Exception e) {
083            }
084
085            // initialize engine, the initial sslSession we get will need to be
086            // updated once the ssl handshake process is completed.
087            if (remoteHost != null && remotePort != -1) {
088                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
089            } else {
090                sslEngine = sslContext.createSSLEngine();
091            }
092
093            if (verifyHostName) {
094                SSLParameters sslParams = new SSLParameters();
095                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
096                sslEngine.setSSLParameters(sslParams);
097            }
098
099            sslEngine.setUseClientMode(false);
100            if (enabledCipherSuites != null) {
101                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
102            }
103
104            if (enabledProtocols != null) {
105                sslEngine.setEnabledProtocols(enabledProtocols);
106            }
107
108            if (wantClientAuth) {
109                sslEngine.setWantClientAuth(wantClientAuth);
110            }
111
112            if (needClientAuth) {
113                sslEngine.setNeedClientAuth(needClientAuth);
114            }
115
116            sslSession = sslEngine.getSession();
117
118            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
119            inputBuffer.clear();
120
121            outputStream = new NIOOutputStream(channel);
122            outputStream.setEngine(sslEngine);
123            this.dataOut = new DataOutputStream(outputStream);
124            this.buffOut = outputStream;
125            sslEngine.beginHandshake();
126            handshakeStatus = sslEngine.getHandshakeStatus();
127            doHandshake();
128
129        } catch (Exception e) {
130            try {
131                if(outputStream != null) {
132                    outputStream.close();
133                }
134                super.closeStreams();
135            } catch (Exception ex) {}
136            throw new IOException(e);
137        }
138    }
139
140    @Override
141    protected void doOpenWireInit() throws Exception {
142
143    }
144
145    public SSLEngine getSslSession() {
146        return this.sslEngine;
147    }
148
149    private volatile byte[] readData;
150
151    private final AtomicInteger readSize = new AtomicInteger();
152
153    public byte[] getReadData() {
154        return readData != null ? readData : new byte[0];
155    }
156
157    public AtomicInteger getReadSize() {
158        return readSize;
159    }
160
161    //Prevent concurrent access to SSLEngine
162    @Override
163    public synchronized void serviceRead() {
164        try {
165            if (handshakeInProgress) {
166                doHandshake();
167            }
168
169            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
170            plain.position(plain.limit());
171
172            while (true) {
173                if (!plain.hasRemaining()) {
174                    int readCount = secureRead(plain);
175
176                    if (readCount == 0) {
177                        break;
178                    }
179
180                    // channel is closed, cleanup
181                    if (readCount == -1) {
182                        onException(new EOFException());
183                        break;
184                    }
185
186                    receiveCounter += readCount;
187                    readSize.addAndGet(readCount);
188                }
189
190                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
191                    processCommand(plain);
192                    //we have received enough bytes to detect the protocol
193                    if (receiveCounter >= 8) {
194                        break;
195                    }
196                }
197            }
198        } catch (IOException e) {
199            onException(e);
200        } catch (Throwable e) {
201            onException(IOExceptionSupport.create(e));
202        }
203    }
204
205    @Override
206    protected void processCommand(ByteBuffer plain) throws Exception {
207        ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
208        if (readData != null) {
209            newBuffer.put(readData);
210        }
211        newBuffer.put(plain);
212        newBuffer.flip();
213        readData = newBuffer.array();
214    }
215
216
217    @Override
218    public void doStart() throws Exception {
219        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
220        // no need to init as we can delay that until demand (eg in doHandshake)
221        connect();
222    }
223
224
225    @Override
226    protected void doStop(ServiceStopper stopper) throws Exception {
227        if (taskRunnerFactory != null) {
228            taskRunnerFactory.shutdownNow();
229            taskRunnerFactory = null;
230        }
231    }
232
233
234}