001 /**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.fusesource.hawtdispatch.transport;
019
020 import org.fusesource.hawtdispatch.Task;
021
022 import javax.net.ssl.*;
023 import java.io.EOFException;
024 import java.io.IOException;
025 import java.nio.ByteBuffer;
026 import java.nio.channels.GatheringByteChannel;
027 import java.nio.channels.ReadableByteChannel;
028 import java.nio.channels.ScatteringByteChannel;
029 import java.nio.channels.WritableByteChannel;
030 import java.security.cert.Certificate;
031 import java.security.cert.X509Certificate;
032 import java.util.ArrayList;
033
034 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
035 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
036
037 /**
038 * Implements the SSL protocol as a WrappingProtocolCodec. Useful for when
039 * you want to switch to the SSL protocol on a regular TCP Transport.
040 */
041 public class SslProtocolCodec implements WrappingProtocolCodec, SecuredSession {
042
043 private ReadableByteChannel readChannel;
044 private WritableByteChannel writeChannel;
045
046 public enum ClientAuth {
047 WANT, NEED, NONE
048 };
049
050 private SSLContext sslContext;
051 private SSLEngine engine;
052
053 private ByteBuffer readBuffer;
054 private boolean readUnderflow;
055
056 private ByteBuffer writeBuffer;
057 private boolean writeFlushing;
058
059 private ByteBuffer readOverflowBuffer;
060 Transport transport;
061
062 int lastReadSize;
063 int lastWriteSize;
064 long readCounter;
065 long writeCounter;
066
067 ProtocolCodec next;
068
069
070 public SslProtocolCodec() {
071 }
072
073 public ProtocolCodec getNext() {
074 return next;
075 }
076 public void setNext(ProtocolCodec next) {
077 this.next = next;
078 initNext();
079 }
080
081 private void initNext() {
082 if( next!=null ) {
083 this.next.setTransport(new TransportFilter(transport){
084 public ReadableByteChannel getReadChannel() {
085 return sslReadChannel;
086 }
087 public WritableByteChannel getWriteChannel() {
088 return sslWriteChannel;
089 }
090 });
091 }
092 }
093
094 public void setSSLContext(SSLContext ctx) {
095 assert engine == null;
096 this.sslContext = ctx;
097 }
098
099 public SslProtocolCodec client() throws Exception {
100 initializeEngine();
101 engine.setUseClientMode(true);
102 engine.beginHandshake();
103 return this;
104 }
105
106 public SslProtocolCodec server(ClientAuth clientAuth) throws Exception {
107 initializeEngine();
108 engine.setUseClientMode(false);
109 switch (clientAuth) {
110 case WANT: engine.setWantClientAuth(true); break;
111 case NEED: engine.setNeedClientAuth(true); break;
112 case NONE: engine.setWantClientAuth(false); break;
113 }
114 engine.beginHandshake();
115 return this;
116 }
117
118 protected void initializeEngine() throws Exception {
119 assert engine == null;
120 if( sslContext == null ) {
121 sslContext = SSLContext.getDefault();
122 }
123 engine = sslContext.createSSLEngine();
124 SSLSession session = engine.getSession();
125 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
126 readBuffer.flip();
127 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
128 }
129
130
131 public SSLSession getSSLSession() {
132 return engine==null ? null : engine.getSession();
133 }
134
135 public X509Certificate[] getPeerX509Certificates() {
136 if( engine==null ) {
137 return null;
138 }
139 try {
140 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
141 for( Certificate c:engine.getSession().getPeerCertificates() ) {
142 if(c instanceof X509Certificate) {
143 rc.add((X509Certificate) c);
144 }
145 }
146 return rc.toArray(new X509Certificate[rc.size()]);
147 } catch (SSLPeerUnverifiedException e) {
148 return null;
149 }
150 }
151
152 SSLReadChannel sslReadChannel = new SSLReadChannel();
153 SSLWriteChannel sslWriteChannel = new SSLWriteChannel();
154
155 public void setTransport(Transport transport) {
156 this.transport = transport;
157 this.readChannel = transport.getReadChannel();
158 this.writeChannel = transport.getWriteChannel();
159 initNext();
160 }
161
162 public void handshake() throws IOException {
163 if( !transportFlush() ) {
164 return;
165 }
166 switch (engine.getHandshakeStatus()) {
167 case NEED_TASK:
168 final Runnable task = engine.getDelegatedTask();
169 if( task!=null ) {
170 transport.getBlockingExecutor().execute(new Task() {
171 public void run() {
172 task.run();
173 transport.getDispatchQueue().execute(new Task() {
174 public void run() {
175 if (readChannel.isOpen() && writeChannel.isOpen()) {
176 try {
177 handshake();
178 } catch (IOException e) {
179 transport.getTransportListener().onTransportFailure(e);
180 }
181 }
182 }
183 });
184 }
185 });
186 }
187 break;
188
189 case NEED_WRAP:
190 secure_write(ByteBuffer.allocate(0));
191 break;
192
193 case NEED_UNWRAP:
194 if( secure_read(ByteBuffer.allocate(0)) == -1) {
195 throw new EOFException("Peer disconnected during ssl handshake");
196 }
197 break;
198
199 case FINISHED:
200 case NOT_HANDSHAKING:
201 transport.drainInbound();
202 transport.getTransportListener().onRefill();
203 break;
204
205 default:
206 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
207 break;
208 }
209 }
210
211 /**
212 * @return true if fully flushed.
213 * @throws IOException
214 */
215 protected boolean transportFlush() throws IOException {
216 while (true) {
217 if(writeFlushing) {
218 lastWriteSize = writeChannel.write(writeBuffer);
219 if( lastWriteSize > 0 ) {
220 writeCounter += lastWriteSize;
221 }
222 if( !writeBuffer.hasRemaining() ) {
223 writeBuffer.clear();
224 writeFlushing = false;
225 return true;
226 } else {
227 return false;
228 }
229 } else {
230 if( writeBuffer.position()!=0 ) {
231 writeBuffer.flip();
232 writeFlushing = true;
233 } else {
234 return true;
235 }
236 }
237 }
238 }
239
240 private int secure_read(ByteBuffer plain) throws IOException {
241 int rc=0;
242 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
243 if( readOverflowBuffer !=null ) {
244 if( plain.hasRemaining() ) {
245 // lets drain the overflow buffer before trying to suck down anymore
246 // network bytes.
247 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
248 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
249 readOverflowBuffer.position(readOverflowBuffer.position()+size);
250 if( !readOverflowBuffer.hasRemaining() ) {
251 readOverflowBuffer = null;
252 }
253 rc += size;
254 } else {
255 return rc;
256 }
257 } else if( readUnderflow ) {
258 lastReadSize = readChannel.read(readBuffer);
259 if( lastReadSize == -1 ) { // peer closed socket.
260 if (rc==0) {
261 return -1;
262 } else {
263 return rc;
264 }
265 }
266 if( lastReadSize==0 ) { // no data available right now.
267 return rc;
268 }
269 readCounter += lastReadSize;
270 // read in some more data, perhaps now we can unwrap.
271 readUnderflow = false;
272 readBuffer.flip();
273 } else {
274 SSLEngineResult result = engine.unwrap(readBuffer, plain);
275 rc += result.bytesProduced();
276 if( result.getStatus() == BUFFER_OVERFLOW ) {
277 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
278 result = engine.unwrap(readBuffer, readOverflowBuffer);
279 if( readOverflowBuffer.position()==0 ) {
280 readOverflowBuffer = null;
281 } else {
282 readOverflowBuffer.flip();
283 }
284 }
285 switch( result.getStatus() ) {
286 case CLOSED:
287 if (rc==0) {
288 engine.closeInbound();
289 return -1;
290 } else {
291 return rc;
292 }
293 case OK:
294 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
295 handshake();
296 }
297 break;
298 case BUFFER_UNDERFLOW:
299 readBuffer.compact();
300 readUnderflow = true;
301 break;
302 case BUFFER_OVERFLOW:
303 throw new AssertionError("Unexpected case.");
304 }
305 }
306 }
307 return rc;
308 }
309
310 private int secure_write(ByteBuffer plain) throws IOException {
311 if( !transportFlush() ) {
312 // can't write anymore until the write_secured_buffer gets fully flushed out..
313 return 0;
314 }
315 int rc = 0;
316 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
317 SSLEngineResult result = engine.wrap(plain, writeBuffer);
318 assert result.getStatus()!= BUFFER_OVERFLOW;
319 rc += result.bytesConsumed();
320 if( !transportFlush() ) {
321 break;
322 }
323 }
324 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
325 handshake();
326 }
327 return rc;
328 }
329
330 public class SSLReadChannel implements ScatteringByteChannel {
331
332 public int read(ByteBuffer plain) throws IOException {
333 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
334 handshake();
335 }
336 return secure_read(plain);
337 }
338
339 public boolean isOpen() {
340 return readChannel.isOpen();
341 }
342
343 public void close() throws IOException {
344 readChannel.close();
345 }
346
347 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
348 if(offset+length > dsts.length || length<0 || offset<0) {
349 throw new IndexOutOfBoundsException();
350 }
351 long rc=0;
352 for (int i = 0; i < length; i++) {
353 ByteBuffer dst = dsts[offset+i];
354 if(dst.hasRemaining()) {
355 rc += read(dst);
356 }
357 if( dst.hasRemaining() ) {
358 return rc;
359 }
360 }
361 return rc;
362 }
363
364 public long read(ByteBuffer[] dsts) throws IOException {
365 return read(dsts, 0, dsts.length);
366 }
367 }
368
369 public class SSLWriteChannel implements GatheringByteChannel {
370
371 public int write(ByteBuffer plain) throws IOException {
372 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
373 handshake();
374 }
375 return secure_write(plain);
376 }
377
378 public boolean isOpen() {
379 return writeChannel.isOpen();
380 }
381
382 public void close() throws IOException {
383 writeChannel.close();
384 }
385
386 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
387 if(offset+length > srcs.length || length<0 || offset<0) {
388 throw new IndexOutOfBoundsException();
389 }
390 long rc=0;
391 for (int i = 0; i < length; i++) {
392 ByteBuffer src = srcs[offset+i];
393 if(src.hasRemaining()) {
394 rc += write(src);
395 }
396 if( src.hasRemaining() ) {
397 return rc;
398 }
399 }
400 return rc;
401 }
402
403 public long write(ByteBuffer[] srcs) throws IOException {
404 return write(srcs, 0, srcs.length);
405 }
406 }
407
408 public void unread(byte[] buffer) {
409 readBuffer.compact();
410 if( readBuffer.remaining() < buffer.length) {
411 throw new IllegalStateException("Cannot unread now");
412 }
413 readBuffer.put(buffer);
414 readBuffer.flip();
415 }
416
417 public Object read() throws IOException {
418 return next.read();
419 }
420
421 public ProtocolCodec.BufferState write(Object value) throws IOException {
422 return next.write(value);
423 }
424
425 public ProtocolCodec.BufferState flush() throws IOException {
426 return next.flush();
427 }
428
429 public boolean full() {
430 return next.full();
431 }
432
433 public long getWriteCounter() {
434 return writeCounter;
435 }
436
437 public long getLastWriteSize() {
438 return lastWriteSize;
439 }
440
441 public long getReadCounter() {
442 return readCounter;
443 }
444
445 public long getLastReadSize() {
446 return lastReadSize;
447 }
448
449 public int getReadBufferSize() {
450 return readBuffer.capacity();
451 }
452
453 public int getWriteBufferSize() {
454 return writeBuffer.capacity();
455 }
456
457
458
459 }