001 /**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.fusesource.hawtdispatch.transport;
019
020 import org.fusesource.hawtdispatch.Task;
021
022 import javax.net.ssl.*;
023 import java.io.EOFException;
024 import java.io.IOException;
025 import java.net.Socket;
026 import java.net.URI;
027 import java.nio.ByteBuffer;
028 import java.nio.channels.*;
029 import java.security.cert.Certificate;
030 import java.security.cert.X509Certificate;
031 import java.util.ArrayList;
032
033 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
034 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
035 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
036 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
037
038 /**
039 * An SSL Transport for secure communications.
040 *
041 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
042 */
043 public class SslTransport extends TcpTransport implements SecuredSession {
044
045
046 /**
047 * Maps uri schemes to a protocol algorithm names.
048 * Valid algorithm names listed at:
049 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
050 */
051 public static String protocol(String scheme) {
052 if( scheme.equals("tls") ) {
053 return "TLS";
054 } else if( scheme.startsWith("tlsv") ) {
055 return "TLSv"+scheme.substring(4);
056 } else if( scheme.equals("ssl") ) {
057 return "SSL";
058 } else if( scheme.startsWith("sslv") ) {
059 return "SSLv"+scheme.substring(4);
060 }
061 return null;
062 }
063
064 enum ClientAuth {
065 WANT, NEED, NONE
066 };
067
068 private ClientAuth clientAuth = ClientAuth.WANT;
069
070 private SSLContext sslContext;
071 private SSLEngine engine;
072
073 private ByteBuffer readBuffer;
074 private boolean readUnderflow;
075
076 private ByteBuffer writeBuffer;
077 private boolean writeFlushing;
078
079 private ByteBuffer readOverflowBuffer;
080 private SSLChannel ssl_channel = new SSLChannel();
081
082
083 public void setSSLContext(SSLContext ctx) {
084 this.sslContext = ctx;
085 }
086
087 /**
088 * Allows subclasses of TcpTransportFactory to create custom instances of
089 * TcpTransport.
090 */
091 public static SslTransport createTransport(URI uri) throws Exception {
092 String protocol = protocol(uri.getScheme());
093 if( protocol !=null ) {
094 SslTransport rc = new SslTransport();
095 rc.setSSLContext(SSLContext.getInstance(protocol));
096 return rc;
097 }
098 return null;
099 }
100
101 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
102
103 public int write(ByteBuffer plain) throws IOException {
104 return secure_write(plain);
105 }
106
107 public int read(ByteBuffer plain) throws IOException {
108 return secure_read(plain);
109 }
110
111 public boolean isOpen() {
112 return getSocketChannel().isOpen();
113 }
114
115 public void close() throws IOException {
116 getSocketChannel().close();
117 }
118
119 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
120 if(offset+length > srcs.length || length<0 || offset<0) {
121 throw new IndexOutOfBoundsException();
122 }
123 long rc=0;
124 for (int i = 0; i < length; i++) {
125 ByteBuffer src = srcs[offset+i];
126 if(src.hasRemaining()) {
127 rc += write(src);
128 }
129 if( src.hasRemaining() ) {
130 return rc;
131 }
132 }
133 return rc;
134 }
135
136 public long write(ByteBuffer[] srcs) throws IOException {
137 return write(srcs, 0, srcs.length);
138 }
139
140 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
141 if(offset+length > dsts.length || length<0 || offset<0) {
142 throw new IndexOutOfBoundsException();
143 }
144 long rc=0;
145 for (int i = 0; i < length; i++) {
146 ByteBuffer dst = dsts[offset+i];
147 if(dst.hasRemaining()) {
148 rc += read(dst);
149 }
150 if( dst.hasRemaining() ) {
151 return rc;
152 }
153 }
154 return rc;
155 }
156
157 public long read(ByteBuffer[] dsts) throws IOException {
158 return read(dsts, 0, dsts.length);
159 }
160
161 public Socket socket() {
162 SocketChannel c = channel;
163 if( c == null ) {
164 return null;
165 }
166 return c.socket();
167 }
168 }
169
170 public SSLSession getSSLSession() {
171 return engine==null ? null : engine.getSession();
172 }
173
174 public X509Certificate[] getPeerX509Certificates() {
175 if( engine==null ) {
176 return null;
177 }
178 try {
179 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
180 for( Certificate c:engine.getSession().getPeerCertificates() ) {
181 if(c instanceof X509Certificate) {
182 rc.add((X509Certificate) c);
183 }
184 }
185 return rc.toArray(new X509Certificate[rc.size()]);
186 } catch (SSLPeerUnverifiedException e) {
187 return null;
188 }
189 }
190
191 @Override
192 public void connecting(URI remoteLocation, URI localLocation) throws Exception {
193 assert engine == null;
194 engine = sslContext.createSSLEngine();
195 engine.setUseClientMode(true);
196 super.connecting(remoteLocation, localLocation);
197 }
198
199 @Override
200 public void connected(SocketChannel channel) throws Exception {
201 if (engine == null) {
202 engine = sslContext.createSSLEngine();
203 engine.setUseClientMode(false);
204 switch (clientAuth) {
205 case WANT: engine.setWantClientAuth(true); break;
206 case NEED: engine.setNeedClientAuth(true); break;
207 case NONE: engine.setWantClientAuth(false); break;
208 }
209
210 }
211 super.connected(channel);
212 }
213
214 @Override
215 protected void initializeChannel() throws Exception {
216 super.initializeChannel();
217 SSLSession session = engine.getSession();
218 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
219 readBuffer.flip();
220 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
221 }
222
223 @Override
224 protected void onConnected() throws IOException {
225 super.onConnected();
226 engine.beginHandshake();
227 handshake();
228 }
229
230 @Override
231 public void flush() {
232 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
233 handshake();
234 } else {
235 super.flush();
236 }
237 }
238
239 @Override
240 public void drainInbound() {
241 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
242 handshake();
243 } else {
244 super.drainInbound();
245 }
246 }
247
248 /**
249 * @return true if fully flushed.
250 * @throws IOException
251 */
252 protected boolean transportFlush() throws IOException {
253 while (true) {
254 if(writeFlushing) {
255 int count = super.getWriteChannel().write(writeBuffer);
256 if( !writeBuffer.hasRemaining() ) {
257 writeBuffer.clear();
258 writeFlushing = false;
259 suspendWrite();
260 return true;
261 } else {
262 return false;
263 }
264 } else {
265 if( writeBuffer.position()!=0 ) {
266 writeBuffer.flip();
267 writeFlushing = true;
268 resumeWrite();
269 } else {
270 return true;
271 }
272 }
273 }
274 }
275
276 private int secure_write(ByteBuffer plain) throws IOException {
277 if( !transportFlush() ) {
278 // can't write anymore until the write_secured_buffer gets fully flushed out..
279 return 0;
280 }
281 int rc = 0;
282 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
283 SSLEngineResult result = engine.wrap(plain, writeBuffer);
284 assert result.getStatus()!= BUFFER_OVERFLOW;
285 rc += result.bytesConsumed();
286 if( !transportFlush() ) {
287 break;
288 }
289 }
290 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
291 dispatchQueue.execute(new Task() {
292 public void run() {
293 handshake();
294 }
295 });
296 }
297 return rc;
298 }
299
300 private int secure_read(ByteBuffer plain) throws IOException {
301 int rc=0;
302 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
303 if( readOverflowBuffer !=null ) {
304 if( plain.hasRemaining() ) {
305 // lets drain the overflow buffer before trying to suck down anymore
306 // network bytes.
307 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
308 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
309 readOverflowBuffer.position(readOverflowBuffer.position()+size);
310 if( !readOverflowBuffer.hasRemaining() ) {
311 readOverflowBuffer = null;
312 }
313 rc += size;
314 } else {
315 return rc;
316 }
317 } else if( readUnderflow ) {
318 int count = super.getReadChannel().read(readBuffer);
319 if( count == -1 ) { // peer closed socket.
320 if (rc==0) {
321 return -1;
322 } else {
323 return rc;
324 }
325 }
326 if( count==0 ) { // no data available right now.
327 return rc;
328 }
329 // read in some more data, perhaps now we can unwrap.
330 readUnderflow = false;
331 readBuffer.flip();
332 } else {
333 SSLEngineResult result = engine.unwrap(readBuffer, plain);
334 rc += result.bytesProduced();
335 if( result.getStatus() == BUFFER_OVERFLOW ) {
336 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
337 result = engine.unwrap(readBuffer, readOverflowBuffer);
338 if( readOverflowBuffer.position()==0 ) {
339 readOverflowBuffer = null;
340 } else {
341 readOverflowBuffer.flip();
342 }
343 }
344 switch( result.getStatus() ) {
345 case CLOSED:
346 if (rc==0) {
347 engine.closeInbound();
348 return -1;
349 } else {
350 return rc;
351 }
352 case OK:
353 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
354 dispatchQueue.execute(new Task() {
355 public void run() {
356 handshake();
357 }
358 });
359 }
360 break;
361 case BUFFER_UNDERFLOW:
362 readBuffer.compact();
363 readUnderflow = true;
364 break;
365 case BUFFER_OVERFLOW:
366 throw new AssertionError("Unexpected case.");
367 }
368 }
369 }
370 return rc;
371 }
372
373 public void handshake() {
374 try {
375 if( !transportFlush() ) {
376 return;
377 }
378 switch (engine.getHandshakeStatus()) {
379 case NEED_TASK:
380 final Runnable task = engine.getDelegatedTask();
381 if( task!=null ) {
382 blockingExecutor.execute(new Task() {
383 public void run() {
384 task.run();
385 dispatchQueue.execute(new Task() {
386 public void run() {
387 if (isConnected()) {
388 handshake();
389 }
390 }
391 });
392 }
393 });
394 }
395 break;
396
397 case NEED_WRAP:
398 secure_write(ByteBuffer.allocate(0));
399 break;
400
401 case NEED_UNWRAP:
402 if( secure_read(ByteBuffer.allocate(0)) == -1) {
403 throw new EOFException("Peer disconnected during ssl handshake");
404 }
405 break;
406
407 case FINISHED:
408 case NOT_HANDSHAKING:
409 drainOutboundSource.merge(1);
410 drainInbound();
411 break;
412
413 default:
414 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
415 break;
416 }
417 } catch (IOException e ) {
418 onTransportFailure(e);
419 }
420 }
421
422
423 public ReadableByteChannel getReadChannel() {
424 return ssl_channel;
425 }
426
427 public WritableByteChannel getWriteChannel() {
428 return ssl_channel;
429 }
430
431 public String getClientAuth() {
432 return clientAuth.name();
433 }
434
435 public void setClientAuth(String clientAuth) {
436 this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
437 }
438 }
439
440