001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import javax.net.ssl.*;
021    import java.io.EOFException;
022    import java.io.IOException;
023    import java.net.Socket;
024    import java.net.URI;
025    import java.nio.ByteBuffer;
026    import java.nio.channels.*;
027    import java.security.cert.Certificate;
028    import java.security.cert.X509Certificate;
029    import java.util.ArrayList;
030    import java.util.concurrent.Executor;
031    
032    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
033    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
034    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
035    import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
036    
037    /**
038     * An SSL Transport for secure communications.
039     *
040     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
041     */
042    public class SslTransport extends TcpTransport implements SecureTransport {
043    
044    
045        /**
046         * Maps uri schemes to a protocol algorithm names.
047         * Valid algorithm names listed at:
048         * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
049         */
050        public static String protocol(String scheme) {
051            if( scheme.equals("tls") ) {
052                return "TLS";
053            } else if( scheme.startsWith("tlsv") ) {
054                return "TLSv"+scheme.substring(4);
055            } else if( scheme.equals("ssl") ) {
056                return "SSL";
057            } else if( scheme.startsWith("sslv") ) {
058                return "SSLv"+scheme.substring(4);
059            }
060            return null;
061        }
062    
063        private SSLContext sslContext;
064        private SSLEngine engine;
065    
066        private ByteBuffer readBuffer;
067        private boolean readUnderflow;
068    
069        private ByteBuffer writeBuffer;
070        private boolean writeFlushing;
071    
072        private ByteBuffer readOverflowBuffer;
073        private SSLChannel ssl_channel = new SSLChannel();
074    
075        private Executor blockingExecutor;
076    
077        public void setSSLContext(SSLContext ctx) {
078            this.sslContext = ctx;
079        }
080    
081        /**
082         * Allows subclasses of TcpTransportFactory to create custom instances of
083         * TcpTransport.
084         */
085        public static SslTransport createTransport(URI uri) throws Exception {
086            String protocol = protocol(uri.getScheme());
087            if( protocol !=null ) {
088                SslTransport rc = new SslTransport();
089                rc.setSSLContext(SSLContext.getInstance(protocol));
090                return rc;
091            }
092            return null;
093        }
094    
095        public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
096    
097            public int write(ByteBuffer plain) throws IOException {
098                return secure_write(plain);
099            }
100    
101            public int read(ByteBuffer plain) throws IOException {
102                return secure_read(plain);
103            }
104    
105            public boolean isOpen() {
106                return getSocketChannel().isOpen();
107            }
108    
109            public void close() throws IOException {
110                getSocketChannel().close();
111            }
112    
113            public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
114                if(offset+length > srcs.length || length<0 || offset<0) {
115                    throw new IndexOutOfBoundsException();
116                }
117                long rc=0;
118                for (int i = 0; i < length; i++) {
119                    ByteBuffer src = srcs[offset+i];
120                    if(src.hasRemaining()) {
121                        rc += write(src);
122                    }
123                    if( src.hasRemaining() ) {
124                        return rc;
125                    }
126                }
127                return rc;
128            }
129    
130            public long write(ByteBuffer[] srcs) throws IOException {
131                return write(srcs, 0, srcs.length);
132            }
133    
134            public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
135                if(offset+length > dsts.length || length<0 || offset<0) {
136                    throw new IndexOutOfBoundsException();
137                }
138                long rc=0;
139                for (int i = 0; i < length; i++) {
140                    ByteBuffer dst = dsts[offset+i];
141                    if(dst.hasRemaining()) {
142                        rc += read(dst);
143                    }
144                    if( dst.hasRemaining() ) {
145                        return rc;
146                    }
147                }
148                return rc;
149            }
150    
151            public long read(ByteBuffer[] dsts) throws IOException {
152                return read(dsts, 0, dsts.length);
153            }
154            
155            public Socket socket() {
156                SocketChannel c = channel;
157                if( c == null ) {
158                    return null;
159                }
160                return c.socket();
161            }
162        }
163    
164        public SSLSession getSSLSession() {
165            return engine==null ? null : engine.getSession();
166        }
167    
168        public X509Certificate[] getPeerX509Certificates() {
169            if( engine==null ) {
170                return null;
171            }
172            try {
173                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
174                for( Certificate c:engine.getSession().getPeerCertificates() ) {
175                    if(c instanceof X509Certificate) {
176                        rc.add((X509Certificate) c);
177                    }
178                }
179                return rc.toArray(new X509Certificate[rc.size()]);
180            } catch (SSLPeerUnverifiedException e) {
181                return null;
182            }
183        }
184    
185        @Override
186        public void connecting(URI remoteLocation, URI localLocation) throws Exception {
187            assert engine == null;
188            engine = sslContext.createSSLEngine();
189            engine.setUseClientMode(true);
190            super.connecting(remoteLocation, localLocation);
191        }
192    
193        @Override
194        public void connected(SocketChannel channel) throws Exception {
195            if (engine == null) {
196                engine = sslContext.createSSLEngine();
197                engine.setUseClientMode(false);
198                engine.setWantClientAuth(true);
199            }
200            super.connected(channel);
201        }
202    
203        @Override
204        protected void initializeChannel() throws Exception {
205            super.initializeChannel();
206            SSLSession session = engine.getSession();
207            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
208            readBuffer.flip();
209            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
210        }
211    
212        @Override
213        protected void onConnected() throws IOException {
214            super.onConnected();
215            engine.beginHandshake();
216            handshake();
217        }
218    
219        @Override
220        public void flush() {
221            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
222                handshake();
223            } else {
224                super.flush();
225            }
226        }
227    
228        @Override
229        protected void drainInbound() {
230            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
231                handshake();
232            } else {
233                super.drainInbound();
234            }
235        }
236    
237        /**
238         * @return true if fully flushed.
239         * @throws IOException
240         */
241        protected boolean transportFlush() throws IOException {
242            while (true) {
243                if(writeFlushing) {
244                    int count = super.writeChannel().write(writeBuffer);
245                    if( !writeBuffer.hasRemaining() ) {
246                        writeBuffer.clear();
247                        writeFlushing = false;
248                        suspendWrite();
249                        return true;
250                    } else {
251                        return false;
252                    }
253                } else {
254                    if( writeBuffer.position()!=0 ) {
255                        writeBuffer.flip();
256                        writeFlushing = true;
257                        resumeWrite();
258                    } else {
259                        return true;
260                    }
261                }
262            }
263        }
264    
265        private int secure_write(ByteBuffer plain) throws IOException {
266            if( !transportFlush() ) {
267                // can't write anymore until the write_secured_buffer gets fully flushed out..
268                return 0;
269            }
270            int rc = 0;
271            while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
272                SSLEngineResult result = engine.wrap(plain, writeBuffer);
273                assert result.getStatus()!= BUFFER_OVERFLOW;
274                rc += result.bytesConsumed();
275                if( !transportFlush() ) {
276                    break;
277                }
278            }
279            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
280                dispatchQueue.execute(new Runnable() {
281                    public void run() {
282                        handshake();
283                    }
284                });
285            }
286            return rc;
287        }
288    
289        private int secure_read(ByteBuffer plain) throws IOException {
290            int rc=0;
291            while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
292                if( readOverflowBuffer !=null ) {
293                    if(  plain.hasRemaining() ) {
294                        // lets drain the overflow buffer before trying to suck down anymore
295                        // network bytes.
296                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
297                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
298                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
299                        if( !readOverflowBuffer.hasRemaining() ) {
300                            readOverflowBuffer = null;
301                        }
302                        rc += size;
303                    } else {
304                        return rc;
305                    }
306                } else if( readUnderflow ) {
307                    int count = super.readChannel().read(readBuffer);
308                    if( count == -1 ) {  // peer closed socket.
309                        if (rc==0) {
310                            return -1;
311                        } else {
312                            return rc;
313                        }
314                    }
315                    if( count==0 ) {  // no data available right now.
316                        return rc;
317                    }
318                    // read in some more data, perhaps now we can unwrap.
319                    readUnderflow = false;
320                    readBuffer.flip();
321                } else {
322                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
323                    rc += result.bytesProduced();
324                    if( result.getStatus() == BUFFER_OVERFLOW ) {
325                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
326                        result = engine.unwrap(readBuffer, readOverflowBuffer);
327                        if( readOverflowBuffer.position()==0 ) {
328                            readOverflowBuffer = null;
329                        } else {
330                            readOverflowBuffer.flip();
331                        }
332                    }
333                    switch( result.getStatus() ) {
334                        case CLOSED:
335                            if (rc==0) {
336                                engine.closeInbound();
337                                return -1;
338                            } else {
339                                return rc;
340                            }
341                        case OK:
342                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
343                                dispatchQueue.execute(new Runnable() {
344                                    public void run() {
345                                        handshake();
346                                    }
347                                });
348                            }
349                            break;
350                        case BUFFER_UNDERFLOW:
351                            readBuffer.compact();
352                            readUnderflow = true;
353                            break;
354                        case BUFFER_OVERFLOW:
355                            throw new AssertionError("Unexpected case.");
356                    }
357                }
358            }
359            return rc;
360        }
361    
362        public void handshake() {
363            try {
364                if( !transportFlush() ) {
365                    return;
366                }
367                switch (engine.getHandshakeStatus()) {
368                    case NEED_TASK:
369                        final Runnable task = engine.getDelegatedTask();
370                        if( task!=null ) {
371                            blockingExecutor.execute(new Runnable() {
372                                public void run() {
373                                    task.run();
374                                    dispatchQueue.execute(new Runnable() {
375                                        public void run() {
376                                            if (isConnected()) {
377                                                handshake();
378                                            }
379                                        }
380                                    });
381                                }
382                            });
383                        }
384                        break;
385    
386                    case NEED_WRAP:
387                        secure_write(ByteBuffer.allocate(0));
388                        break;
389    
390                    case NEED_UNWRAP:
391                        if( secure_read(ByteBuffer.allocate(0)) == -1) {
392                            throw new EOFException("Peer disconnected during ssl handshake");
393                        }
394                        break;
395    
396                    case FINISHED:
397                    case NOT_HANDSHAKING:
398                        drainOutboundSource.merge(1);
399                        break;
400    
401                    default:
402                        System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
403                        break;
404                }
405            } catch (IOException e ) {
406                onTransportFailure(e);
407            }
408        }
409    
410    
411        public ReadableByteChannel readChannel() {
412            return ssl_channel;
413        }
414    
415        public WritableByteChannel writeChannel() {
416            return ssl_channel;
417        }
418    
419        public Executor getBlockingExecutor() {
420            return blockingExecutor;
421        }
422    
423        public void setBlockingExecutor(Executor blockingExecutor) {
424            this.blockingExecutor = blockingExecutor;
425        }
426    }
427    
428