You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by rs...@apache.org on 2018/07/18 11:32:42 UTC

[kafka] branch 2.0 updated: KAFKA-7168: Treat connection close during SSL handshake as retriable (#5371)

This is an automated email from the ASF dual-hosted git repository.

rsivaram pushed a commit to branch 2.0
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.0 by this push:
     new 984a702  KAFKA-7168: Treat connection close during SSL handshake as retriable (#5371)
984a702 is described below

commit 984a70241fcddeff075cf2e6f91dca1f884ddf32
Author: Rajini Sivaram <ra...@googlemail.com>
AuthorDate: Wed Jul 18 12:27:21 2018 +0100

    KAFKA-7168: Treat connection close during SSL handshake as retriable (#5371)
    
    An SSL `close_notify` alert sent when a broker closes a connection was processed by clients as a handshake failure if it was unwrapped while a handshake was in progress. This is now handled as a retriable IOException rather than a non-retriable SslAuthenticationException, to avoid authentication exceptions in clients during a rolling restart of brokers.
    
    Reviewers: Ismael Juma <is...@juma.me.uk>
---
 .../kafka/common/network/SslTransportLayer.java    | 36 ++++++++--
 .../apache/kafka/common/network/NioEchoServer.java | 26 +++++--
 .../common/network/SslTransportLayerTest.java      | 81 ++++++++++++++++++----
 3 files changed, 120 insertions(+), 23 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
index 06e7e93..838a6a7 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
@@ -31,8 +31,10 @@ import javax.net.ssl.SSLEngineResult.HandshakeStatus;
 import javax.net.ssl.SSLEngineResult.Status;
 import javax.net.ssl.SSLException;
 import javax.net.ssl.SSLHandshakeException;
-import javax.net.ssl.SSLSession;
+import javax.net.ssl.SSLKeyException;
 import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLProtocolException;
+import javax.net.ssl.SSLSession;
 
 import org.apache.kafka.common.errors.SslAuthenticationException;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
@@ -255,17 +257,17 @@ public class SslTransportLayer implements TransportLayer {
 
             doHandshake();
         } catch (SSLException e) {
-            handshakeFailure(e, true);
+            maybeProcessHandshakeFailure(e, true, null);
         } catch (IOException e) {
             maybeThrowSslAuthenticationException();
 
             // this exception could be due to a write. If there is data available to unwrap,
-            // process the data so that any SSLExceptions are reported
+            // process the data so that any SSL handshake exceptions are reported
             if (handshakeStatus == HandshakeStatus.NEED_UNWRAP && netReadBuffer.position() > 0) {
                 try {
                     handshakeUnwrap(false);
                 } catch (SSLException e1) {
-                    handshakeFailure(e1, false);
+                    maybeProcessHandshakeFailure(e1, false, e);
                 }
             }
             // If we get here, this is not a handshake failure, throw the original IOException
@@ -824,6 +826,32 @@ public class SslTransportLayer implements TransportLayer {
             throw handshakeException;
     }
 
+    // SSL handshake failures are typically thrown as SSLHandshakeException, SSLProtocolException,
+    // SSLPeerUnverifiedException or SSLKeyException if the cause is known. These exceptions indicate
+    // authentication failures (e.g. configuration errors) which should not be retried. But the SSL engine
+    // may also throw exceptions using the base class SSLException in a few cases:
+    //   a) If there are no matching ciphers or TLS version or the private key is invalid, client will be
+    //      unable to process the server message and an SSLException is thrown:
+    //      javax.net.ssl.SSLException: Unrecognized SSL message, plaintext connection?
+    //   b) If server closes the connection gracefully during handshake, client may receive close_notify
+    //      and an SSLException is thrown:
+    //      javax.net.ssl.SSLException: Received close_notify during handshake
+    // We want to handle a) as a non-retriable SslAuthenticationException and b) as a retriable IOException.
+    // To do this we need to rely on the exception string. Since it is safer to throw a retriable exception
+    // when we are not sure, we will treat only the first exception string as a handshake exception.
+    private void maybeProcessHandshakeFailure(SSLException sslException, boolean flush, IOException ioException) throws IOException {
+        if (sslException instanceof SSLHandshakeException || sslException instanceof SSLProtocolException ||
+                sslException instanceof SSLPeerUnverifiedException || sslException instanceof SSLKeyException ||
+                sslException.getMessage().contains("Unrecognized SSL message"))
+            handshakeFailure(sslException, flush);
+        else if (ioException == null)
+            throw sslException;
+        else {
+            log.debug("SSLException while unwrapping data after IOException, original IOException will be propagated", sslException);
+            throw ioException;
+        }
+    }
+
     // If handshake has already failed, throw the authentication exception.
     private void maybeThrowSslAuthenticationException() {
         if (handshakeException != null)
diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
index 2ce9671..0c81b53 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
@@ -64,7 +64,8 @@ public class NioEchoServer extends Thread {
     private volatile WritableByteChannel outputChannel;
     private final CredentialCache credentialCache;
     private final Metrics metrics;
-    private int numSent = 0;
+    private volatile int numSent = 0;
+    private volatile boolean closeKafkaChannels;
     private final DelegationTokenCache tokenCache;
 
     public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config,
@@ -155,6 +156,11 @@ public class NioEchoServer extends Thread {
                     }
                     newChannels.clear();
                 }
+                if (closeKafkaChannels) {
+                    for (KafkaChannel channel : selector.channels())
+                        selector.close(channel.id());
+                    closeKafkaChannels = false;
+                }
 
                 List<NetworkReceive> completedReceives = selector.completedReceives();
                 for (NetworkReceive rcv : completedReceives) {
@@ -174,7 +180,6 @@ public class NioEchoServer extends Thread {
                     selector.unmute(send.destination());
                     numSent += 1;
                 }
-
             }
         } catch (IOException e) {
             // ignore
@@ -208,15 +213,26 @@ public class NioEchoServer extends Thread {
         return selector;
     }
 
-    public void closeConnections() throws IOException {
-        for (SocketChannel channel : socketChannels)
+    public void closeKafkaChannels() throws IOException {
+        closeKafkaChannels = true;
+        selector.wakeup();
+        try {
+            TestUtils.waitForCondition(() -> selector.channels().isEmpty(), "Channels not closed");
+        } catch (InterruptedException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public void closeSocketChannels() throws IOException {
+        for (SocketChannel channel : socketChannels) {
             channel.close();
+        }
         socketChannels.clear();
     }
 
     public void close() throws IOException, InterruptedException {
         this.serverSocketChannel.close();
-        closeConnections();
+        closeSocketChannels();
         acceptorThread.interrupt();
         acceptorThread.join();
         interrupt();
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
index 1f62c10..6aef2f7 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
@@ -699,7 +699,8 @@ public class SslTransportLayerTest {
      */
     @Test
     public void testIOExceptionsDuringHandshakeRead() throws Exception {
-        testIOExceptionsDuringHandshake(true, false);
+        server = createEchoServer(SecurityProtocol.SSL);
+        testIOExceptionsDuringHandshake(FailureAction.THROW_IO_EXCEPTION, FailureAction.NO_OP);
     }
 
     /**
@@ -707,20 +708,60 @@ public class SslTransportLayerTest {
      */
     @Test
     public void testIOExceptionsDuringHandshakeWrite() throws Exception {
-        testIOExceptionsDuringHandshake(false, true);
+        server = createEchoServer(SecurityProtocol.SSL);
+        testIOExceptionsDuringHandshake(FailureAction.NO_OP, FailureAction.THROW_IO_EXCEPTION);
+    }
+
+    /**
+     * Tests that if the remote end closes connection ungracefully during SSL handshake while reading data,
+     * the disconnection is not treated as an authentication failure.
+     */
+    @Test
+    public void testUngracefulRemoteCloseDuringHandshakeRead() throws Exception {
+        server = createEchoServer(SecurityProtocol.SSL);
+        testIOExceptionsDuringHandshake(server::closeSocketChannels, FailureAction.NO_OP);
+    }
+
+    /**
+     * Tests that if the remote end closes connection ungracefully during SSL handshake while writing data,
+     * the disconnection is not treated as an authentication failure.
+     */
+    @Test
+    public void testUngracefulRemoteCloseDuringHandshakeWrite() throws Exception {
+        server = createEchoServer(SecurityProtocol.SSL);
+        testIOExceptionsDuringHandshake(FailureAction.NO_OP, server::closeSocketChannels);
     }
 
-    private void testIOExceptionsDuringHandshake(boolean failRead, boolean failWrite) throws Exception {
+    /**
+     * Tests that if the remote end closes the connection during SSL handshake while reading data,
+     * the disconnection is not treated as an authentication failure.
+     */
+    @Test
+    public void testGracefulRemoteCloseDuringHandshakeRead() throws Exception {
+        server = createEchoServer(SecurityProtocol.SSL);
+        testIOExceptionsDuringHandshake(FailureAction.NO_OP, server::closeKafkaChannels);
+    }
+
+    /**
+     * Tests that if the remote end closes the connection during SSL handshake while writing data,
+     * the disconnection is not treated as an authentication failure.
+     */
+    @Test
+    public void testGracefulRemoteCloseDuringHandshakeWrite() throws Exception {
         server = createEchoServer(SecurityProtocol.SSL);
+        testIOExceptionsDuringHandshake(server::closeKafkaChannels, FailureAction.NO_OP);
+    }
+
+    private void testIOExceptionsDuringHandshake(FailureAction readFailureAction,
+                                                 FailureAction flushFailureAction) throws Exception {
         TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
         boolean done = false;
         for (int i = 1; i <= 100; i++) {
-            int readFailureIndex = failRead ? i : Integer.MAX_VALUE;
-            int flushFailureIndex = failWrite ? i : Integer.MAX_VALUE;
             String node = String.valueOf(i);
 
-            channelBuilder.readFailureIndex = readFailureIndex;
-            channelBuilder.flushFailureIndex = flushFailureIndex;
+            channelBuilder.readFailureAction = readFailureAction;
+            channelBuilder.flushFailureAction = flushFailureAction;
+            channelBuilder.failureIndex = i;
             channelBuilder.configure(sslClientConfigs);
             this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext());
 
@@ -734,7 +775,9 @@ public class SslTransportLayerTest {
                     break;
                 }
                 if (selector.disconnected().containsKey(node)) {
-                    assertEquals(ChannelState.State.AUTHENTICATE, selector.disconnected().get(node).state());
+                    ChannelState.State state = selector.disconnected().get(node).state();
+                    assertTrue("Unexpected channel state " + state,
+                            state == ChannelState.State.AUTHENTICATE || state == ChannelState.State.READY);
                     break;
                 }
             }
@@ -973,13 +1016,23 @@ public class SslTransportLayerTest {
         return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol);
     }
 
+    @FunctionalInterface
+    private interface FailureAction {
+        FailureAction NO_OP = () -> { };
+        FailureAction THROW_IO_EXCEPTION = () -> {
+            throw new IOException("Test IO exception");
+        };
+        void run() throws IOException;
+    }
+
     private static class TestSslChannelBuilder extends SslChannelBuilder {
 
         private Integer netReadBufSizeOverride;
         private Integer netWriteBufSizeOverride;
         private Integer appBufSizeOverride;
-        long readFailureIndex = Long.MAX_VALUE;
-        long flushFailureIndex = Long.MAX_VALUE;
+        private long failureIndex = Long.MAX_VALUE;
+        FailureAction readFailureAction = FailureAction.NO_OP;
+        FailureAction flushFailureAction = FailureAction.NO_OP;
         int flushDelayCount = 0;
 
         public TestSslChannelBuilder(Mode mode) {
@@ -1029,8 +1082,8 @@ public class SslTransportLayerTest {
                 this.netReadBufSize = new ResizeableBufferSize(netReadBufSizeOverride);
                 this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSizeOverride);
                 this.appBufSize = new ResizeableBufferSize(appBufSizeOverride);
-                numReadsRemaining = new AtomicLong(readFailureIndex);
-                numFlushesRemaining = new AtomicLong(flushFailureIndex);
+                numReadsRemaining = new AtomicLong(failureIndex);
+                numFlushesRemaining = new AtomicLong(failureIndex);
                 numDelayedFlushesRemaining = new AtomicInteger(flushDelayCount);
             }
 
@@ -1058,14 +1111,14 @@ public class SslTransportLayerTest {
             @Override
             protected int readFromSocketChannel() throws IOException {
                 if (numReadsRemaining.decrementAndGet() == 0 && !ready())
-                    throw new IOException("Test exception during read");
+                    readFailureAction.run();
                 return super.readFromSocketChannel();
             }
 
             @Override
             protected boolean flush(ByteBuffer buf) throws IOException {
                 if (numFlushesRemaining.decrementAndGet() == 0 && !ready())
-                    throw new IOException("Test exception during write");
+                    flushFailureAction.run();
                 else if (numDelayedFlushesRemaining.getAndDecrement() != 0)
                     return false;
                 resetDelayedFlush();