You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that did not survive the plain-text conversion.
Posted to commits@activemq.apache.org by ta...@apache.org on 2012/10/17 23:17:46 UTC

svn commit: r1399438 - /activemq/trunk/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java

Author: tabish
Date: Wed Oct 17 21:17:45 2012
New Revision: 1399438

URL: http://svn.apache.org/viewvc?rev=1399438&view=rev
Log:
fix for: https://issues.apache.org/jira/browse/AMQ-3996


Modified:
    activemq/trunk/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java

Modified: activemq/trunk/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java
URL: http://svn.apache.org/viewvc/activemq/trunk/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java?rev=1399438&r1=1399437&r2=1399438&view=diff
==============================================================================
--- activemq/trunk/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java (original)
+++ activemq/trunk/activemq-core/src/main/java/org/apache/activemq/transport/nio/NIOSSLTransport.java Wed Oct 17 21:17:45 2012
@@ -41,8 +41,12 @@ import org.apache.activemq.thread.TaskRu
 import org.apache.activemq.util.IOExceptionSupport;
 import org.apache.activemq.util.ServiceStopper;
 import org.apache.activemq.wireformat.WireFormat;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-public class NIOSSLTransport extends NIOTransport  {
+public class NIOSSLTransport extends NIOTransport {
+
+    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
 
     protected boolean needClientAuth;
     protected boolean wantClientAuth;
@@ -79,15 +83,36 @@ public class NIOSSLTransport extends NIO
                 sslContext = SSLContext.getDefault();
             }
 
+            String remoteHost = null;
+            int remotePort = -1;
+
+            try {
+                URI remoteAddress = new URI(this.getRemoteAddress());
+                remoteHost = remoteAddress.getHost();
+                remotePort = remoteAddress.getPort();
+            } catch (Exception e) {
+            }
+
             // initialize engine, the initial sslSession we get will need to be
             // updated once the ssl handshake process is completed.
-            sslEngine = sslContext.createSSLEngine();
+            if (remoteHost != null && remotePort != -1) {
+                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
+            } else {
+                sslEngine = sslContext.createSSLEngine();
+            }
+
             sslEngine.setUseClientMode(false);
             if (enabledCipherSuites != null) {
                 sslEngine.setEnabledCipherSuites(enabledCipherSuites);
             }
-            sslEngine.setNeedClientAuth(needClientAuth);
-            sslEngine.setWantClientAuth(wantClientAuth);
+
+            if (wantClientAuth) {
+                sslEngine.setWantClientAuth(wantClientAuth);
+            }
+
+            if (needClientAuth) {
+                sslEngine.setNeedClientAuth(needClientAuth);
+            }
 
             sslSession = sslEngine.getSession();
 
@@ -107,31 +132,31 @@ public class NIOSSLTransport extends NIO
         }
     }
 
-    protected void finishHandshake() throws Exception  {
-          if (handshakeInProgress) {
-              handshakeInProgress = false;
-              nextFrameSize = -1;
-
-              // Once handshake completes we need to ask for the now real sslSession
-              // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
-              // cipher suite.
-              sslSession = sslEngine.getSession();
-
-              // listen for events telling us when the socket is readable.
-              selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
-                  public void onSelect(SelectorSelection selection) {
-                      serviceRead();
-                  }
-
-                  public void onError(SelectorSelection selection, Throwable error) {
-                      if (error instanceof IOException) {
-                          onException((IOException) error);
-                      } else {
-                          onException(IOExceptionSupport.create(error));
-                      }
-                  }
-              });
-          }
+    protected void finishHandshake() throws Exception {
+        if (handshakeInProgress) {
+            handshakeInProgress = false;
+            nextFrameSize = -1;
+
+            // Once handshake completes we need to ask for the now real sslSession
+            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
+            // cipher suite.
+            sslSession = sslEngine.getSession();
+
+            // listen for events telling us when the socket is readable.
+            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
+                public void onSelect(SelectorSelection selection) {
+                    serviceRead();
+                }
+
+                public void onError(SelectorSelection selection, Throwable error) {
+                    if (error instanceof IOException) {
+                        onException((IOException) error);
+                    } else {
+                        onException(IOExceptionSupport.create(error));
+                    }
+                }
+            });
+        }
     }
 
     protected void serviceRead() {
@@ -143,7 +168,7 @@ public class NIOSSLTransport extends NIO
             ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
             plain.position(plain.limit());
 
-            while(true) {
+            while (true) {
                 if (!plain.hasRemaining()) {
 
                     if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
@@ -153,12 +178,11 @@ public class NIOSSLTransport extends NIO
                     }
                     int readCount = secureRead(plain);
 
-
                     if (readCount == 0)
                         break;
 
                     // channel is closed, cleanup
-                    if (readCount== -1) {
+                    if (readCount == -1) {
                         onException(new EOFException());
                         selection.close();
                         break;
@@ -181,7 +205,8 @@ public class NIOSSLTransport extends NIO
         if (wireFormat instanceof OpenWireFormat) {
             long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
             if (nextFrameSize > maxFrameSize) {
-                throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
+                throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
+                                       " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
             }
         }
         currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
@@ -213,8 +238,7 @@ public class NIOSSLTransport extends NIO
 
             if (bytesRead == -1) {
                 sslEngine.closeInbound();
-                if (inputBuffer.position() == 0 ||
-                        status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
+                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                     return -1;
                 }
             }
@@ -226,18 +250,17 @@ public class NIOSSLTransport extends NIO
         SSLEngineResult res;
         do {
             res = sslEngine.unwrap(inputBuffer, plain);
-        } while (res.getStatus() == SSLEngineResult.Status.OK &&
-                res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP &&
-                res.bytesProduced() == 0);
+        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
+                && res.bytesProduced() == 0);
 
         if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
-           finishHandshake();
+            finishHandshake();
         }
 
         status = res.getStatus();
         handshakeStatus = res.getHandshakeStatus();
 
-        //TODO deal with BUFFER_OVERFLOW
+        // TODO deal with BUFFER_OVERFLOW
 
         if (status == SSLEngineResult.Status.CLOSED) {
             sslEngine.closeInbound();
@@ -254,22 +277,22 @@ public class NIOSSLTransport extends NIO
         handshakeInProgress = true;
         while (true) {
             switch (sslEngine.getHandshakeStatus()) {
-                case NEED_UNWRAP:
-                    secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
-                    break;
-                case NEED_TASK:
-                    Runnable task;
-                    while ((task = sslEngine.getDelegatedTask()) != null) {
-                        taskRunnerFactory.execute(task);
-                    }
-                    break;
-                case NEED_WRAP:
-                    ((NIOOutputStream)buffOut).write(ByteBuffer.allocate(0));
-                    break;
-                case FINISHED:
-                case NOT_HANDSHAKING:
-                    finishHandshake();
-                    return;
+            case NEED_UNWRAP:
+                secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
+                break;
+            case NEED_TASK:
+                Runnable task;
+                while ((task = sslEngine.getDelegatedTask()) != null) {
+                    taskRunnerFactory.execute(task);
+                }
+                break;
+            case NEED_WRAP:
+                ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
+                break;
+            case FINISHED:
+            case NOT_HANDSHAKING:
+                finishHandshake();
+                return;
             }
         }
     }
@@ -295,14 +318,15 @@ public class NIOSSLTransport extends NIO
     }
 
     /**
-     * Overriding in order to add the client's certificates to ConnectionInfo Commmands.
+     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
      *
-     * @param command The Command coming in.
+     * @param command
+     *            The Command coming in.
      */
     @Override
     public void doConsume(Object command) {
         if (command instanceof ConnectionInfo) {
-            ConnectionInfo connectionInfo = (ConnectionInfo)command;
+            ConnectionInfo connectionInfo = (ConnectionInfo) command;
             connectionInfo.setTransportContext(getPeerCertificates());
         }
         super.doConsume(command);
@@ -315,10 +339,13 @@ public class NIOSSLTransport extends NIO
 
         X509Certificate[] clientCertChain = null;
         try {
-            if (sslSession != null) {
-                clientCertChain = (X509Certificate[])sslSession.getPeerCertificates();
+            if (sslEngine.getSession() != null) {
+                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
             }
         } catch (SSLPeerUnverifiedException e) {
+            if (LOG.isTraceEnabled()) {
+                LOG.trace("Failed to get peer certificates.", e);
+            }
         }
 
         return clientCertChain;