001 /**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.fusesource.hawtdispatch.transport;
019
020 import javax.net.ssl.*;
021 import java.io.EOFException;
022 import java.io.IOException;
023 import java.net.Socket;
024 import java.net.URI;
025 import java.nio.ByteBuffer;
026 import java.nio.channels.*;
027 import java.security.cert.Certificate;
028 import java.security.cert.X509Certificate;
029 import java.util.ArrayList;
030 import java.util.concurrent.Executor;
031
032 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
033 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
034 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
035 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
036
037 /**
038 * An SSL Transport for secure communications.
039 *
040 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
041 */
042 public class SslTransport extends TcpTransport implements SecureTransport {
043
044
045 /**
046 * Maps uri schemes to a protocol algorithm names.
047 * Valid algorithm names listed at:
048 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
049 */
050 public static String protocol(String scheme) {
051 if( scheme.equals("tls") ) {
052 return "TLS";
053 } else if( scheme.startsWith("tlsv") ) {
054 return "TLSv"+scheme.substring(4);
055 } else if( scheme.equals("ssl") ) {
056 return "SSL";
057 } else if( scheme.startsWith("sslv") ) {
058 return "SSLv"+scheme.substring(4);
059 }
060 return null;
061 }
062
063 private SSLContext sslContext;
064 private SSLEngine engine;
065
066 private ByteBuffer readBuffer;
067 private boolean readUnderflow;
068
069 private ByteBuffer writeBuffer;
070 private boolean writeFlushing;
071
072 private ByteBuffer readOverflowBuffer;
073 private SSLChannel ssl_channel = new SSLChannel();
074
075 private Executor blockingExecutor;
076
077 public void setSSLContext(SSLContext ctx) {
078 this.sslContext = ctx;
079 }
080
081 /**
082 * Allows subclasses of TcpTransportFactory to create custom instances of
083 * TcpTransport.
084 */
085 public static SslTransport createTransport(URI uri) throws Exception {
086 String protocol = protocol(uri.getScheme());
087 if( protocol !=null ) {
088 SslTransport rc = new SslTransport();
089 rc.setSSLContext(SSLContext.getInstance(protocol));
090 return rc;
091 }
092 return null;
093 }
094
095 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
096
097 public int write(ByteBuffer plain) throws IOException {
098 return secure_write(plain);
099 }
100
101 public int read(ByteBuffer plain) throws IOException {
102 return secure_read(plain);
103 }
104
105 public boolean isOpen() {
106 return getSocketChannel().isOpen();
107 }
108
109 public void close() throws IOException {
110 getSocketChannel().close();
111 }
112
113 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
114 if(offset+length > srcs.length || length<0 || offset<0) {
115 throw new IndexOutOfBoundsException();
116 }
117 long rc=0;
118 for (int i = 0; i < length; i++) {
119 ByteBuffer src = srcs[offset+i];
120 if(src.hasRemaining()) {
121 rc += write(src);
122 }
123 if( src.hasRemaining() ) {
124 return rc;
125 }
126 }
127 return rc;
128 }
129
130 public long write(ByteBuffer[] srcs) throws IOException {
131 return write(srcs, 0, srcs.length);
132 }
133
134 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
135 if(offset+length > dsts.length || length<0 || offset<0) {
136 throw new IndexOutOfBoundsException();
137 }
138 long rc=0;
139 for (int i = 0; i < length; i++) {
140 ByteBuffer dst = dsts[offset+i];
141 if(dst.hasRemaining()) {
142 rc += read(dst);
143 }
144 if( dst.hasRemaining() ) {
145 return rc;
146 }
147 }
148 return rc;
149 }
150
151 public long read(ByteBuffer[] dsts) throws IOException {
152 return read(dsts, 0, dsts.length);
153 }
154
155 public Socket socket() {
156 SocketChannel c = channel;
157 if( c == null ) {
158 return null;
159 }
160 return c.socket();
161 }
162 }
163
164 public SSLSession getSSLSession() {
165 return engine==null ? null : engine.getSession();
166 }
167
168 public X509Certificate[] getPeerX509Certificates() {
169 if( engine==null ) {
170 return null;
171 }
172 try {
173 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
174 for( Certificate c:engine.getSession().getPeerCertificates() ) {
175 if(c instanceof X509Certificate) {
176 rc.add((X509Certificate) c);
177 }
178 }
179 return rc.toArray(new X509Certificate[rc.size()]);
180 } catch (SSLPeerUnverifiedException e) {
181 return null;
182 }
183 }
184
185 @Override
186 public void connecting(URI remoteLocation, URI localLocation) throws Exception {
187 assert engine == null;
188 engine = sslContext.createSSLEngine();
189 engine.setUseClientMode(true);
190 super.connecting(remoteLocation, localLocation);
191 }
192
193 @Override
194 public void connected(SocketChannel channel) throws Exception {
195 if (engine == null) {
196 engine = sslContext.createSSLEngine();
197 engine.setUseClientMode(false);
198 engine.setWantClientAuth(true);
199 }
200 super.connected(channel);
201 }
202
203 @Override
204 protected void initializeChannel() throws Exception {
205 super.initializeChannel();
206 SSLSession session = engine.getSession();
207 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
208 readBuffer.flip();
209 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
210 }
211
212 @Override
213 protected void onConnected() throws IOException {
214 super.onConnected();
215 engine.beginHandshake();
216 handshake();
217 }
218
219 @Override
220 public void flush() {
221 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
222 handshake();
223 } else {
224 super.flush();
225 }
226 }
227
228 @Override
229 protected void drainInbound() {
230 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
231 handshake();
232 } else {
233 super.drainInbound();
234 }
235 }
236
237 /**
238 * @return true if fully flushed.
239 * @throws IOException
240 */
241 protected boolean transportFlush() throws IOException {
242 while (true) {
243 if(writeFlushing) {
244 int count = super.writeChannel().write(writeBuffer);
245 if( !writeBuffer.hasRemaining() ) {
246 writeBuffer.clear();
247 writeFlushing = false;
248 suspendWrite();
249 return true;
250 } else {
251 return false;
252 }
253 } else {
254 if( writeBuffer.position()!=0 ) {
255 writeBuffer.flip();
256 writeFlushing = true;
257 resumeWrite();
258 } else {
259 return true;
260 }
261 }
262 }
263 }
264
265 private int secure_write(ByteBuffer plain) throws IOException {
266 if( !transportFlush() ) {
267 // can't write anymore until the write_secured_buffer gets fully flushed out..
268 return 0;
269 }
270 int rc = 0;
271 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
272 SSLEngineResult result = engine.wrap(plain, writeBuffer);
273 assert result.getStatus()!= BUFFER_OVERFLOW;
274 rc += result.bytesConsumed();
275 if( !transportFlush() ) {
276 break;
277 }
278 }
279 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
280 dispatchQueue.execute(new Runnable() {
281 public void run() {
282 handshake();
283 }
284 });
285 }
286 return rc;
287 }
288
289 private int secure_read(ByteBuffer plain) throws IOException {
290 int rc=0;
291 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
292 if( readOverflowBuffer !=null ) {
293 if( plain.hasRemaining() ) {
294 // lets drain the overflow buffer before trying to suck down anymore
295 // network bytes.
296 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
297 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
298 readOverflowBuffer.position(readOverflowBuffer.position()+size);
299 if( !readOverflowBuffer.hasRemaining() ) {
300 readOverflowBuffer = null;
301 }
302 rc += size;
303 } else {
304 return rc;
305 }
306 } else if( readUnderflow ) {
307 int count = super.readChannel().read(readBuffer);
308 if( count == -1 ) { // peer closed socket.
309 if (rc==0) {
310 return -1;
311 } else {
312 return rc;
313 }
314 }
315 if( count==0 ) { // no data available right now.
316 return rc;
317 }
318 // read in some more data, perhaps now we can unwrap.
319 readUnderflow = false;
320 readBuffer.flip();
321 } else {
322 SSLEngineResult result = engine.unwrap(readBuffer, plain);
323 rc += result.bytesProduced();
324 if( result.getStatus() == BUFFER_OVERFLOW ) {
325 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
326 result = engine.unwrap(readBuffer, readOverflowBuffer);
327 if( readOverflowBuffer.position()==0 ) {
328 readOverflowBuffer = null;
329 } else {
330 readOverflowBuffer.flip();
331 }
332 }
333 switch( result.getStatus() ) {
334 case CLOSED:
335 if (rc==0) {
336 engine.closeInbound();
337 return -1;
338 } else {
339 return rc;
340 }
341 case OK:
342 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
343 dispatchQueue.execute(new Runnable() {
344 public void run() {
345 handshake();
346 }
347 });
348 }
349 break;
350 case BUFFER_UNDERFLOW:
351 readBuffer.compact();
352 readUnderflow = true;
353 break;
354 case BUFFER_OVERFLOW:
355 throw new AssertionError("Unexpected case.");
356 }
357 }
358 }
359 return rc;
360 }
361
362 public void handshake() {
363 try {
364 if( !transportFlush() ) {
365 return;
366 }
367 switch (engine.getHandshakeStatus()) {
368 case NEED_TASK:
369 final Runnable task = engine.getDelegatedTask();
370 if( task!=null ) {
371 blockingExecutor.execute(new Runnable() {
372 public void run() {
373 task.run();
374 dispatchQueue.execute(new Runnable() {
375 public void run() {
376 if (isConnected()) {
377 handshake();
378 }
379 }
380 });
381 }
382 });
383 }
384 break;
385
386 case NEED_WRAP:
387 secure_write(ByteBuffer.allocate(0));
388 break;
389
390 case NEED_UNWRAP:
391 if( secure_read(ByteBuffer.allocate(0)) == -1) {
392 throw new EOFException("Peer disconnected during ssl handshake");
393 }
394 break;
395
396 case FINISHED:
397 case NOT_HANDSHAKING:
398 drainOutboundSource.merge(1);
399 break;
400
401 default:
402 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
403 break;
404 }
405 } catch (IOException e ) {
406 onTransportFailure(e);
407 }
408 }
409
410
411 public ReadableByteChannel readChannel() {
412 return ssl_channel;
413 }
414
415 public WritableByteChannel writeChannel() {
416 return ssl_channel;
417 }
418
419 public Executor getBlockingExecutor() {
420 return blockingExecutor;
421 }
422
423 public void setBlockingExecutor(Executor blockingExecutor) {
424 this.blockingExecutor = blockingExecutor;
425 }
426 }
427
428