You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by ij...@apache.org on 2017/09/22 19:58:29 UTC

kafka git commit: KAFKA-5920; Handle SSL handshake failures as authentication exceptions

Repository: kafka
Updated Branches:
  refs/heads/trunk e554dc518 -> d60f011d7


KAFKA-5920; Handle SSL handshake failures as authentication exceptions

1. Propagate `SSLException` as `SslAuthenticationException` so that clients can report these failures and avoid retries
2. Update `SslTransportLayer` to process received bytes even when end-of-stream has been reached
3. Tidy up authentication handling
4. Report exceptions in `SaslClientAuthenticator` as `AuthenticationException`s

Author: Rajini Sivaram <ra...@googlemail.com>

Reviewers: Ismael Juma <is...@juma.me.uk>

Closes #3918 from rajinisivaram/KAFKA-5920-SSL-handshake-failure


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/d60f011d
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/d60f011d
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/d60f011d

Branch: refs/heads/trunk
Commit: d60f011d77ce80a44b02d43bf0889a50a8797dcd
Parents: e554dc5
Author: Rajini Sivaram <ra...@googlemail.com>
Authored: Fri Sep 22 20:26:46 2017 +0100
Committer: Ismael Juma <is...@juma.me.uk>
Committed: Fri Sep 22 20:29:25 2017 +0100

----------------------------------------------------------------------
 .../common/errors/AuthenticationException.java  |   3 +
 .../errors/SslAuthenticationException.java      |  44 +++
 .../kafka/common/network/Authenticator.java     |  17 +-
 .../kafka/common/network/KafkaChannel.java      |  31 +-
 .../common/network/PlaintextChannelBuilder.java |   7 -
 .../apache/kafka/common/network/Selector.java   |   3 +
 .../kafka/common/network/SslChannelBuilder.java |   8 -
 .../kafka/common/network/SslTransportLayer.java | 296 +++++++++++-------
 .../kafka/common/network/TransportLayer.java    |  12 +-
 .../authenticator/SaslClientAuthenticator.java  |  31 +-
 .../authenticator/SaslServerAuthenticator.java  |  16 +-
 .../kafka/common/network/NioEchoServer.java     |  12 +-
 .../common/network/SslTransportLayerTest.java   | 305 +++++++++++++------
 13 files changed, 506 insertions(+), 279 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
index c56ac88..f6458c6 100644
--- a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
+++ b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.common.errors;
 
+import javax.net.ssl.SSLException;
+
 /**
  * This exception indicates that SASL authentication has failed.
  * On authentication failure, clients abort the operation requested and raise one
@@ -27,6 +29,7 @@ package org.apache.kafka.common.errors;
  *   is not supported on the broker.</li>
  *   <li>{@link IllegalSaslStateException} if an unexpected request is received on during SASL
  *   handshake. This could be due to misconfigured security protocol.</li>
+ *   <li>{@link SslAuthenticationException} if SSL handshake failed due to any {@link SSLException}.
  * </ul>
  */
 public class AuthenticationException extends ApiException {

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java
new file mode 100644
index 0000000..3cdbf2a
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.errors;
+
+import javax.net.ssl.SSLException;
+
+/**
+ * This exception indicates that SSL handshake has failed. See {@link #getCause()}
+ * for the {@link SSLException} that caused this failure.
+ * <p>
+ * SSL handshake failures in clients may indicate client authentication
+ * failure due to untrusted certificates if server is configured to request
+ * client certificates. Handshake failures could also indicate misconfigured
+ * security including protocol/cipher suite mismatch, server certificate
+ * authentication failure or server host name verification failure.
+ * </p>
+ */
+public class SslAuthenticationException extends AuthenticationException {
+
+    private static final long serialVersionUID = 1L;
+
+    public SslAuthenticationException(String message) {
+        super(message);
+    }
+
+    public SslAuthenticationException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
index fa1123e..4e2e727 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java
@@ -16,7 +16,7 @@
  */
 package org.apache.kafka.common.network;
 
-import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 
 import java.io.Closeable;
@@ -28,15 +28,14 @@ import java.io.IOException;
 public interface Authenticator extends Closeable {
     /**
      * Implements any authentication mechanism. Use transportLayer to read or write tokens.
-     * If no further authentication needs to be done returns.
+     * For security protocols PLAINTEXT and SSL, this is a no-op since no further authentication
+     * needs to be done. For SASL_PLAINTEXT and SASL_SSL, this performs the SASL authentication.
+     *
+     * @throws AuthenticationException if authentication fails due to invalid credentials or
+     *      other security configuration errors
+     * @throws IOException if read/write fails due to an I/O error
      */
-    void authenticate() throws IOException;
-
-    /**
-     * Returns the first error encountered during authentication
-     * @return authentication error if authentication failed, Errors.NONE otherwise
-     */
-    Errors error();
+    void authenticate() throws AuthenticationException, IOException;
 
     /**
      * Returns Principal using PrincipalBuilder

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
index 24cd9cf..f07035a 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java
@@ -69,26 +69,21 @@ public class KafkaChannel {
     }
 
     /**
-     * Does handshake of transportLayer and authentication using configured authenticator
+     * Does handshake of transportLayer and authentication using configured authenticator.
+     * For SSL with client authentication enabled, {@link TransportLayer#handshake()} performs
+     * authentication. For SASL, authentication is performed by {@link Authenticator#authenticate()}.
      */
-    public void prepare() throws IOException {
-        if (!transportLayer.ready())
-            transportLayer.handshake();
-        if (transportLayer.ready() && !authenticator.complete()) {
-            try {
+    public void prepare() throws AuthenticationException, IOException {
+        try {
+            if (!transportLayer.ready())
+                transportLayer.handshake();
+            if (transportLayer.ready() && !authenticator.complete())
                 authenticator.authenticate();
-            } catch (AuthenticationException e) {
-                switch (authenticator.error()) {
-                    case SASL_AUTHENTICATION_FAILED:
-                    case ILLEGAL_SASL_STATE:
-                    case UNSUPPORTED_SASL_MECHANISM:
-                        state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e);
-                        break;
-                    default:
-                        // Other errors are handled as network exceptions in Selector
-                }
-                throw e;
-            }
+        } catch (AuthenticationException e) {
+            // Clients are notified of authentication exceptions to enable operations to be terminated
+            // without retries. Other errors are handled as network exceptions in Selector.
+            state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e);
+            throw e;
         }
         if (ready())
             state = ChannelState.READY;

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
index 95fd903..c0d1059 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java
@@ -18,7 +18,6 @@ package org.apache.kafka.common.network;
 
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.memory.MemoryPool;
-import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
 import org.apache.kafka.common.security.auth.PlaintextAuthenticationContext;
@@ -80,12 +79,6 @@ public class PlaintextChannelBuilder implements ChannelBuilder {
         }
 
         @Override
-        public Errors error() {
-            // PLAINTEXT never fails authentication
-            return Errors.NONE;
-        }
-
-        @Override
         public void close() {
             if (principalBuilder instanceof Closeable)
                 Utils.closeQuietly((Closeable) principalBuilder, "principal builder");

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/Selector.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index 7977879..b753745 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -42,6 +42,7 @@ import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.metrics.Measurable;
 import org.apache.kafka.common.metrics.MetricConfig;
 import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.metrics.stats.Avg;
@@ -485,6 +486,8 @@ public class Selector implements Selectable, AutoCloseable {
                 String desc = channel.socketDescription();
                 if (e instanceof IOException)
                     log.debug("Connection with {} disconnected", desc, e);
+                else if (e instanceof AuthenticationException) // will be logged later as error by clients
+                    log.debug("Connection with {} disconnected due to authentication exception", desc, e);
                 else
                     log.warn("Unexpected error from {}; closing connection", desc, e);
                 close(channel, true);

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
index 80b9e9a..9519e58 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java
@@ -18,7 +18,6 @@ package org.apache.kafka.common.network;
 
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.memory.MemoryPool;
-import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
 import org.apache.kafka.common.security.auth.SslAuthenticationContext;
@@ -158,12 +157,5 @@ public class SslChannelBuilder implements ChannelBuilder {
         public boolean complete() {
             return true;
         }
-
-        @Override
-        public Errors error() {
-            // SSL authentication failures are currently not propagated to clients
-            return Errors.NONE;
-        }
-
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
index 3cd0114..f5e1e70 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java
@@ -34,6 +34,7 @@ import javax.net.ssl.SSLHandshakeException;
 import javax.net.ssl.SSLSession;
 import javax.net.ssl.SSLPeerUnverifiedException;
 
+import org.apache.kafka.common.errors.SslAuthenticationException;
 import org.apache.kafka.common.security.auth.KafkaPrincipal;
 import org.apache.kafka.common.utils.Utils;
 import org.slf4j.Logger;
@@ -44,6 +45,14 @@ import org.slf4j.LoggerFactory;
  */
 public class SslTransportLayer implements TransportLayer {
     private static final Logger log = LoggerFactory.getLogger(SslTransportLayer.class);
+
+    private enum State {
+        HANDSHAKE,
+        HANDSHAKE_FAILED,
+        READY,
+        CLOSING
+    }
+
     private final String channelId;
     private final SSLEngine sslEngine;
     private final SelectionKey key;
@@ -52,8 +61,8 @@ public class SslTransportLayer implements TransportLayer {
 
     private HandshakeStatus handshakeStatus;
     private SSLEngineResult handshakeResult;
-    private boolean handshakeComplete = false;
-    private boolean closing = false;
+    private State state;
+    private SslAuthenticationException handshakeException;
     private ByteBuffer netReadBuffer;
     private ByteBuffer netWriteBuffer;
     private ByteBuffer appReadBuffer;
@@ -89,8 +98,7 @@ public class SslTransportLayer implements TransportLayer {
         netWriteBuffer.limit(0);
         netReadBuffer.position(0);
         netReadBuffer.limit(0);
-        handshakeComplete = false;
-        closing = false;
+        state = State.HANDSHAKE;
         //initiate handshake
         sslEngine.beginHandshake();
         handshakeStatus = sslEngine.getHandshakeStatus();
@@ -98,7 +106,7 @@ public class SslTransportLayer implements TransportLayer {
 
     @Override
     public boolean ready() {
-        return handshakeComplete;
+        return state == State.READY;
     }
 
     /**
@@ -141,8 +149,8 @@ public class SslTransportLayer implements TransportLayer {
     */
     @Override
     public void close() throws IOException {
-        if (closing) return;
-        closing = true;
+        if (state == State.CLOSING) return;
+        state = State.CLOSING;
         sslEngine.closeOutbound();
         try {
             if (isConnected()) {
@@ -183,12 +191,22 @@ public class SslTransportLayer implements TransportLayer {
     }
 
     /**
-    * Flushes the buffer to the network, non blocking
+     * Reads available bytes from socket channel to `netReadBuffer`.
+     * Visible for testing.
+     * @return  number of bytes read
+     */
+    protected int readFromSocketChannel() throws IOException {
+        return socketChannel.read(netReadBuffer);
+    }
+
+    /**
+    * Flushes the buffer to the network, non blocking.
+    * Visible for testing.
     * @param buf ByteBuffer
     * @return boolean true if the buffer has been emptied out, false otherwise
     * @throws IOException
     */
-    private boolean flush(ByteBuffer buf) throws IOException {
+    protected boolean flush(ByteBuffer buf) throws IOException {
         int remaining = buf.remaining();
         if (remaining > 0) {
             int written = socketChannel.write(buf);
@@ -217,101 +235,137 @@ public class SslTransportLayer implements TransportLayer {
     * | unwrap()    | Finished                         | FINISHED    |
     * +-------------+----------------------------------+-------------+
     *
-    * @throws IOException
+    * @throws IOException if read/write fails
+    * @throws SslAuthenticationException if handshake fails with an {@link SSLException}
     */
     @Override
     public void handshake() throws IOException {
+        // Reset state to support renegotiation. This can be removed if renegotiation support is removed.
+        if (state == State.READY)
+            state = State.HANDSHAKE;
+
+        int read = 0;
+        try {
+            // Read any available bytes before attempting any writes to ensure that handshake failures
+            // reported by the peer are processed even if writes fail (since peer closes connection
+            // if handshake fails)
+            if (key.isReadable())
+                read = readFromSocketChannel();
+
+            doHandshake();
+        } catch (SSLException e) {
+            handshakeFailure(e, true);
+        } catch (IOException e) {
+            maybeThrowSslAuthenticationException();
+
+            // this exception could be due to a write. If there is data available to unwrap,
+            // process the data so that any SSLExceptions are reported
+            if (handshakeStatus == HandshakeStatus.NEED_UNWRAP && netReadBuffer.position() > 0) {
+                try {
+                    handshakeUnwrap(false);
+                } catch (SSLException e1) {
+                    handshakeFailure(e1, false);
+                }
+            }
+            // If we get here, this is not a handshake failure, throw the original IOException
+            throw e;
+        }
+
+        // Read from socket failed, so throw any pending handshake exception or EOF exception.
+        if (read == -1) {
+            maybeThrowSslAuthenticationException();
+            throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus);
+        }
+    }
+
+    private void doHandshake() throws IOException {
         boolean read = key.isReadable();
         boolean write = key.isWritable();
-        handshakeComplete = false;
         handshakeStatus = sslEngine.getHandshakeStatus();
         if (!flush(netWriteBuffer)) {
             key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
             return;
         }
-        try {
-            switch (handshakeStatus) {
-                case NEED_TASK:
-                    log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
-                              channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-                    handshakeStatus = runDelegatedTasks();
+        // Throw any pending handshake exception since `netWriteBuffer` has been flushed
+        maybeThrowSslAuthenticationException();
+
+        switch (handshakeStatus) {
+            case NEED_TASK:
+                log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+                          channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+                handshakeStatus = runDelegatedTasks();
+                break;
+            case NEED_WRAP:
+                log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+                          channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+                handshakeResult = handshakeWrap(write);
+                if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
+                    int currentNetWriteBufferSize = netWriteBufferSize();
+                    netWriteBuffer.compact();
+                    netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize);
+                    netWriteBuffer.flip();
+                    if (netWriteBuffer.limit() >= currentNetWriteBufferSize) {
+                        throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() +
+                                                        ") >= network buffer size (" + currentNetWriteBufferSize + ")");
+                    }
+                } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
+                    throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP.");
+                } else if (handshakeResult.getStatus() == Status.CLOSED) {
+                    throw new EOFException();
+                }
+                log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+                       channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+                //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents
+                //we will break here otherwise we can do need_unwrap in the same call.
+                if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) {
+                    key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
                     break;
-                case NEED_WRAP:
-                    log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
-                              channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-                    handshakeResult = handshakeWrap(write);
+                }
+            case NEED_UNWRAP:
+                log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+                          channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+                do {
+                    handshakeResult = handshakeUnwrap(read);
                     if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
-                        int currentNetWriteBufferSize = netWriteBufferSize();
-                        netWriteBuffer.compact();
-                        netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize);
-                        netWriteBuffer.flip();
-                        if (netWriteBuffer.limit() >= currentNetWriteBufferSize) {
-                            throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() +
-                                                            ") >= network buffer size (" + currentNetWriteBufferSize + ")");
+                        int currentAppBufferSize = applicationBufferSize();
+                        appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize);
+                        if (appReadBuffer.position() > currentAppBufferSize) {
+                            throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() +
+                                                           ") > packet buffer size (" + currentAppBufferSize + ")");
                         }
-                    } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
-                        throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP.");
-                    } else if (handshakeResult.getStatus() == Status.CLOSED) {
-                        throw new EOFException();
                     }
-                    log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
-                              channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-                    //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents
-                    //we will break here otherwise we can do need_unwrap in the same call.
-                    if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) {
-                        key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
-                        break;
-                    }
-                case NEED_UNWRAP:
-                    log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
-                              channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-                    do {
-                        handshakeResult = handshakeUnwrap(read);
-                        if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) {
-                            int currentAppBufferSize = applicationBufferSize();
-                            appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize);
-                            if (appReadBuffer.position() > currentAppBufferSize) {
-                                throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() +
-                                                                ") > packet buffer size (" + currentAppBufferSize + ")");
-                            }
-                        }
-                    } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW);
-                    if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
-                        int currentNetReadBufferSize = netReadBufferSize();
-                        netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize);
-                        if (netReadBuffer.position() >= currentNetReadBufferSize) {
-                            throw new IllegalStateException("Buffer underflow when there is available data");
-                        }
-                    } else if (handshakeResult.getStatus() == Status.CLOSED) {
-                        throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP");
+                } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW);
+                if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) {
+                    int currentNetReadBufferSize = netReadBufferSize();
+                    netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize);
+                    if (netReadBuffer.position() >= currentNetReadBufferSize) {
+                        throw new IllegalStateException("Buffer underflow when there is available data");
                     }
-                    log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
-                              channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
-
-                    //if handshakeStatus completed than fall-through to finished status.
-                    //after handshake is finished there is no data left to read/write in socketChannel.
-                    //so the selector won't invoke this channel if we don't go through the handshakeFinished here.
-                    if (handshakeStatus != HandshakeStatus.FINISHED) {
-                        if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
-                            key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
-                        } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
-                            key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
-                        }
-                        break;
+                } else if (handshakeResult.getStatus() == Status.CLOSED) {
+                    throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP");
+                }
+                log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}",
+                          channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position());
+
+                //if handshakeStatus completed than fall-through to finished status.
+                //after handshake is finished there is no data left to read/write in socketChannel.
+                //so the selector won't invoke this channel if we don't go through the handshakeFinished here.
+                if (handshakeStatus != HandshakeStatus.FINISHED) {
+                    if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
+                        key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
+                    } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
+                        key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
                     }
-                case FINISHED:
-                    handshakeFinished();
-                    break;
-                case NOT_HANDSHAKING:
-                    handshakeFinished();
                     break;
-                default:
-                    throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus));
-            }
-
-        } catch (SSLException e) {
-            handshakeFailure();
-            throw e;
+                }
+            case FINISHED:
+                handshakeFinished();
+                break;
+            case NOT_HANDSHAKING:
+                handshakeFinished();
+                break;
+            default:
+                throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus));
         }
     }
 
@@ -346,12 +400,12 @@ public class SslTransportLayer implements TransportLayer {
         // It can move from FINISHED status to NOT_HANDSHAKING after the handshake is completed.
         // Hence we also need to check handshakeResult.getHandshakeStatus() if the handshake finished or not
         if (handshakeResult.getHandshakeStatus() == HandshakeStatus.FINISHED) {
-            //we are complete if we have delivered the last package
-            handshakeComplete = !netWriteBuffer.hasRemaining();
+            //we are complete if we have delivered the last packet
             //remove OP_WRITE if we are complete, otherwise we still have data to write
-            if (!handshakeComplete)
+            if (netWriteBuffer.hasRemaining())
                 key.interestOps(key.interestOps() | SelectionKey.OP_WRITE);
             else {
+                state = State.READY;
                 key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE);
                 SSLSession session = sslEngine.getSession();
                 log.debug("SSL handshake completed successfully with peerHost '{}' peerPort {} peerPrincipal '{}' cipherSuite '{}'",
@@ -400,10 +454,9 @@ public class SslTransportLayer implements TransportLayer {
     private SSLEngineResult handshakeUnwrap(boolean doRead) throws IOException {
         log.trace("SSLHandshake handshakeUnwrap {}", channelId);
         SSLEngineResult result;
-        if (doRead)  {
-            int read = socketChannel.read(netReadBuffer);
-            if (read == -1) throw new EOFException("EOF during handshake.");
-        }
+        int read = 0;
+        if (doRead)
+            read = readFromSocketChannel();
         boolean cont;
         do {
             //prepare the buffer with the incoming data
@@ -420,6 +473,11 @@ public class SslTransportLayer implements TransportLayer {
             log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus());
         } while (netReadBuffer.position() != 0 && cont);
 
+        // Throw EOF exception for failed read after processing already received data
+        // so that handshake failures are reported correctly
+        if (read == -1)
+            throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus);
+
         return result;
     }
 
@@ -429,27 +487,27 @@ public class SslTransportLayer implements TransportLayer {
     *
     * @param dst The buffer into which bytes are to be transferred
     * @return The number of bytes read, possible zero or -1 if the channel has reached end-of-stream
+    *         and no more data is available
     * @throws IOException if some other I/O error occurs
     */
     @Override
     public int read(ByteBuffer dst) throws IOException {
-        if (closing) return -1;
-        int read = 0;
-        if (!handshakeComplete) return read;
+        if (state == State.CLOSING) return -1;
+        else if (state != State.READY) return 0;
 
         //if we have unread decrypted data in appReadBuffer read that into dst buffer.
+        int read = 0;
         if (appReadBuffer.position() > 0) {
             read = readFromAppBuffer(dst);
         }
 
+        int netread = 0;
         if (dst.remaining() > 0) {
             netReadBuffer = Utils.ensureCapacity(netReadBuffer, netReadBufferSize());
-            if (netReadBuffer.remaining() > 0) {
-                int netread = socketChannel.read(netReadBuffer);
-                if (netread == 0 && netReadBuffer.position() == 0) return read;
-                else if (netread < 0) throw new EOFException("EOF during read");
-            }
-            do {
+            if (netReadBuffer.remaining() > 0)
+                netread = readFromSocketChannel();
+
+            while (netReadBuffer.position() > 0) {
                 netReadBuffer.flip();
                 SSLEngineResult unwrapResult = sslEngine.unwrap(netReadBuffer, appReadBuffer);
                 netReadBuffer.compact();
@@ -493,8 +551,12 @@ public class SslTransportLayer implements TransportLayer {
                     else
                         break;
                 }
-            } while (netReadBuffer.position() != 0);
+            }
         }
+        // If data has been read and unwrapped, return the data even if end-of-stream, channel will be closed
+        // on a subsequent poll.
+        if (read == 0 && netread < 0)
+            throw new EOFException("EOF during read");
         return read;
     }
 
@@ -553,8 +615,8 @@ public class SslTransportLayer implements TransportLayer {
     @Override
     public int write(ByteBuffer src) throws IOException {
         int written = 0;
-        if (closing) throw new IllegalStateException("Channel is in closing state");
-        if (!handshakeComplete) return written;
+        if (state == State.CLOSING) throw new IllegalStateException("Channel is in closing state");
+        if (state != State.READY) return written;
 
         if (!flush(netWriteBuffer))
             return written;
@@ -662,7 +724,7 @@ public class SslTransportLayer implements TransportLayer {
     public void addInterestOps(int ops) {
         if (!key.isValid())
             throw new CancelledKeyException();
-        else if (!handshakeComplete)
+        else if (state != State.READY)
             throw new IllegalStateException("handshake is not completed");
 
         key.interestOps(key.interestOps() | ops);
@@ -676,7 +738,7 @@ public class SslTransportLayer implements TransportLayer {
     public void removeInterestOps(int ops) {
         if (!key.isValid())
             throw new CancelledKeyException();
-        else if (!handshakeComplete)
+        else if (state != State.READY)
             throw new IllegalStateException("handshake is not completed");
 
         key.interestOps(key.interestOps() & ~ops);
@@ -723,7 +785,12 @@ public class SslTransportLayer implements TransportLayer {
         return netReadBuffer;
     }
 
-    private void handshakeFailure() {
+    /**
+     * SSL exceptions are propagated as authentication failures so that clients can avoid
+     * retries and report the failure. If `flush` is true, exceptions are propagated after
+     * any pending outgoing bytes are flushed to ensure that the peer is notified of the failure.
+     */
+    private void handshakeFailure(SSLException sslException, boolean flush) throws IOException {
         //Release all resources such as internal buffers that SSLEngine is managing
         sslEngine.closeOutbound();
         try {
@@ -731,6 +798,17 @@ public class SslTransportLayer implements TransportLayer {
         } catch (SSLException e) {
             log.debug("SSLEngine.closeInBound() raised an exception.", e);
         }
+
+        state = State.HANDSHAKE_FAILED;
+        handshakeException = new SslAuthenticationException("SSL handshake failed", sslException);
+        if (!flush || flush(netWriteBuffer))
+            throw handshakeException;
+    }
+
+    // If handshake has already failed, throw the authentication exception.
+    private void maybeThrowSslAuthenticationException() {
+        if (handshakeException != null)
+            throw handshakeException;
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
index be56ad5..23f866b 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java
@@ -31,6 +31,7 @@ import java.nio.channels.GatheringByteChannel;
 
 import java.security.Principal;
 
+import org.apache.kafka.common.errors.AuthenticationException;
 
 public interface TransportLayer extends ScatteringByteChannel, GatheringByteChannel {
 
@@ -61,11 +62,14 @@ public interface TransportLayer extends ScatteringByteChannel, GatheringByteChan
 
 
     /**
-     * Performs SSL handshake hence is a no-op for the non-secure
-     * implementation
-     * @throws IOException
+     * This is a no-op for the non-secure PLAINTEXT implementation. For SSL, this performs
+     * the SSL handshake. The SSL handshake includes client authentication if configured using
+     * {@link org.apache.kafka.common.config.SslConfigs#SSL_CLIENT_AUTH_CONFIG}.
+     * @throws AuthenticationException if handshake fails due to an
+     *         {@link javax.net.ssl.SSLException}.
+     * @throws IOException if read or write fails with an I/O error.
     */
-    void handshake() throws IOException;
+    void handshake() throws AuthenticationException, IOException;
 
     /**
      * Returns true if there are any pending writes

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
index 8207a5a..d9e4f0c 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java
@@ -20,8 +20,8 @@ import org.apache.kafka.clients.CommonClientConfigs;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.config.SaslConfigs;
-import org.apache.kafka.common.errors.AuthenticationException;
 import org.apache.kafka.common.errors.IllegalSaslStateException;
+import org.apache.kafka.common.errors.SaslAuthenticationException;
 import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
 import org.apache.kafka.common.network.Authenticator;
 import org.apache.kafka.common.network.Mode;
@@ -104,9 +104,6 @@ public class SaslClientAuthenticator implements Authenticator {
     private RequestHeader currentRequestHeader;
     // Version of SaslAuthenticate request/responses
     private short saslAuthenticateVersion;
-    // Sasl authentication error which may be one of NONE, UNSUPPORTED_SASL_MECHANISM, ILLEGAL_SASL_STATE,
-    // SASL_AUTHENTICATION_FAILED or NETWORK_EXCEPTION
-    private Errors error;
 
     public SaslClientAuthenticator(Map<String, ?> configs,
                                    String node,
@@ -125,7 +122,6 @@ public class SaslClientAuthenticator implements Authenticator {
         this.transportLayer = transportLayer;
         this.configs = configs;
         this.saslAuthenticateVersion = DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER;
-        this.error = Errors.NONE;
 
         try {
             setSaslState(handshakeRequestEnable ? SaslState.SEND_APIVERSIONS_REQUEST : SaslState.INITIAL);
@@ -143,7 +139,7 @@ public class SaslClientAuthenticator implements Authenticator {
 
             saslClient = createSaslClient();
         } catch (Exception e) {
-            throw new KafkaException("Failed to configure SaslClientAuthenticator", e);
+            throw new SaslAuthenticationException("Failed to configure SaslClientAuthenticator", e);
         }
     }
 
@@ -158,7 +154,7 @@ public class SaslClientAuthenticator implements Authenticator {
                 }
             });
         } catch (PrivilegedActionException e) {
-            throw new KafkaException("Failed to create SaslClient with mechanism " + mechanism, e.getCause());
+            throw new SaslAuthenticationException("Failed to create SaslClient with mechanism " + mechanism, e.getCause());
         }
     }
 
@@ -236,11 +232,6 @@ public class SaslClientAuthenticator implements Authenticator {
         }
     }
 
-    @Override
-    public Errors error() {
-        return error;
-    }
-
     private RequestHeader nextRequestHeader(ApiKeys apiKey, short version) {
         String clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG);
         currentRequestHeader = new RequestHeader(apiKey, version, clientId, correlationId++);
@@ -345,8 +336,8 @@ public class SaslClientAuthenticator implements Authenticator {
         } else {
             SaslAuthenticateResponse response = (SaslAuthenticateResponse) receiveKafkaResponse();
             if (response != null) {
-                this.error = response.error();
-                if (this.error != Errors.NONE) {
+                Errors error = response.error();
+                if (error != Errors.NONE) {
                     setSaslState(SaslState.FAILED);
                     String errMsg = response.errorMessage();
                     throw errMsg == null ? error.exception() : error.exception(errMsg);
@@ -360,7 +351,7 @@ public class SaslClientAuthenticator implements Authenticator {
 
     private byte[] createSaslToken(final byte[] saslToken, boolean isInitial) throws SaslException {
         if (saslToken == null)
-            throw new SaslException("Error authenticating with the Kafka Broker: received a `null` saslToken.");
+            throw new IllegalSaslStateException("Error authenticating with the Kafka Broker: received a `null` saslToken.");
 
         try {
             if (isInitial && !saslClient.hasInitialResponse())
@@ -384,9 +375,9 @@ public class SaslClientAuthenticator implements Authenticator {
                     " Users must configure FQDN of kafka brokers when authenticating using SASL and" +
                     " `socketChannel.socket().getInetAddress().getHostName()` must match the hostname in `principal/hostname@realm`";
             }
-            error += " Kafka Client will go to AUTH_FAILED state.";
+            error += " Kafka Client will go to AUTHENTICATION_FAILED state.";
             //Unwrap the SaslException inside `PrivilegedActionException`
-            throw new SaslException(error, e.getCause());
+            throw new SaslAuthenticationException(error, e.getCause());
         }
     }
 
@@ -410,12 +401,12 @@ public class SaslClientAuthenticator implements Authenticator {
         } catch (SchemaException | IllegalArgumentException e) {
             LOG.debug("Invalid SASL mechanism response, server may be expecting only GSSAPI tokens");
             setSaslState(SaslState.FAILED);
-            throw new AuthenticationException("Invalid SASL mechanism response", e);
+            throw new IllegalSaslStateException("Invalid SASL mechanism response, server may be expecting a different protocol", e);
         }
     }
 
     private void handleSaslHandshakeResponse(SaslHandshakeResponse response) {
-        this.error = response.error();
+        Errors error = response.error();
         if (error != Errors.NONE)
             setSaslState(SaslState.FAILED);
         switch (error) {
@@ -428,7 +419,7 @@ public class SaslClientAuthenticator implements Authenticator {
                 throw new IllegalSaslStateException(String.format("Unexpected handshake request with client mechanism %s, enabled mechanisms are %s",
                     mechanism, response.enabledMechanisms()));
             default:
-                throw new AuthenticationException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s",
+                throw new IllegalSaslStateException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s",
                     response.error(), mechanism, response.enabledMechanisms()));
         }
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
index 6202131..fe57d27 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java
@@ -123,8 +123,6 @@ public class SaslServerAuthenticator implements Authenticator {
     private Send netOutBuffer;
     // flag indicating if sasl tokens are sent as Kafka SaslAuthenticate request/responses
     private boolean enableKafkaSaslAuthenticateHeaders;
-    // authentication error if authentication failed
-    private Errors error;
 
     public SaslServerAuthenticator(Map<String, ?> configs,
                                    String connectionId,
@@ -144,7 +142,6 @@ public class SaslServerAuthenticator implements Authenticator {
         this.listenerName = listenerName;
         this.securityProtocol = securityProtocol;
         this.enableKafkaSaslAuthenticateHeaders = false;
-        this.error = Errors.NONE;
 
         this.transportLayer = transportLayer;
 
@@ -288,11 +285,6 @@ public class SaslServerAuthenticator implements Authenticator {
     }
 
     @Override
-    public Errors error() {
-        return error;
-    }
-
-    @Override
     public boolean complete() {
         return saslState == SaslState.COMPLETE;
     }
@@ -366,13 +358,11 @@ public class SaslServerAuthenticator implements Authenticator {
                     KafkaPrincipal.ANONYMOUS, listenerName, securityProtocol);
             RequestAndSize requestAndSize = requestContext.parseRequest(requestBuffer);
             if (apiKey != ApiKeys.SASL_AUTHENTICATE) {
-                this.error = Errors.ILLEGAL_SASL_STATE;
                 IllegalSaslStateException e = new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL authentication.");
                 sendKafkaResponse(requestContext, requestAndSize.request.getErrorResponse(e));
                 throw e;
             }
             if (!apiKey.isVersionSupported(version)) {
-                this.error = Errors.UNSUPPORTED_VERSION;
                 // We cannot create an error response if the request version of SaslAuthenticate is not supported
                 // This should not normally occur since clients typically check supported versions using ApiVersionsRequest
                 throw new UnsupportedVersionException("Version " + version + " is not supported for apiKey " + apiKey);
@@ -385,8 +375,7 @@ public class SaslServerAuthenticator implements Authenticator {
                 ByteBuffer responseBuf = responseToken == null ? EMPTY_BUFFER : ByteBuffer.wrap(responseToken);
                 sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.NONE, null, responseBuf));
             } catch (SaslException e) {
-                this.error = Errors.SASL_AUTHENTICATION_FAILED;
-                sendKafkaResponse(requestContext, new SaslAuthenticateResponse(this.error,
+                sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED,
                         "Authentication failed due to invalid credentials with SASL mechanism " + saslMechanism));
                 throw e;
             }
@@ -462,8 +451,7 @@ public class SaslServerAuthenticator implements Authenticator {
             return clientMechanism;
         } else {
             LOG.debug("SASL mechanism '{}' requested by client is not supported", clientMechanism);
-            this.error = Errors.UNSUPPORTED_SASL_MECHANISM;
-            sendKafkaResponse(context, new SaslHandshakeResponse(this.error, enabledMechanisms));
+            sendKafkaResponse(context, new SaslHandshakeResponse(Errors.UNSUPPORTED_SASL_MECHANISM, enabledMechanisms));
             throw new UnsupportedSaslMechanismException("Unsupported SASL mechanism " + clientMechanism);
         }
     }

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
index e456d68..190fa3d 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java
@@ -85,12 +85,14 @@ public class NioEchoServer extends Thread {
             acceptorThread.start();
             while (serverSocketChannel.isOpen()) {
                 selector.poll(1000);
-                for (SocketChannel socketChannel : newChannels) {
-                    String id = id(socketChannel);
-                    selector.register(id, socketChannel);
-                    socketChannels.add(socketChannel);
+                synchronized (newChannels) {
+                    for (SocketChannel socketChannel : newChannels) {
+                        String id = id(socketChannel);
+                        selector.register(id, socketChannel);
+                        socketChannels.add(socketChannel);
+                    }
+                    newChannels.clear();
                 }
-                newChannels.clear();
 
                 List<NetworkReceive> completedReceives = selector.completedReceives();
                 for (NetworkReceive rcv : completedReceives) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/d60f011d/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
index cffcc89..90c8cd5 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java
@@ -48,8 +48,11 @@ import java.nio.channels.SocketChannel;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -159,7 +162,7 @@ public class SslTransportLayerTest {
         InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
 
     /**
@@ -187,17 +190,13 @@ public class SslTransportLayerTest {
         sslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores);
 
         // Create a server with endpoint validation enabled on the server SSL engine
-        SslChannelBuilder serverChannelBuilder = new SslChannelBuilder(Mode.SERVER) {
+        SslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER) {
             @Override
-            protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
-                SocketChannel socketChannel = (SocketChannel) key.channel();
-                SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort());
+            protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException {
                 SSLParameters sslParams = sslEngine.getSSLParameters();
                 sslParams.setEndpointIdentificationAlgorithm("HTTPS");
                 sslEngine.setSSLParameters(sslParams);
-                TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, BUFFER_SIZE, BUFFER_SIZE, BUFFER_SIZE);
-                transportLayer.startHandshake();
-                return transportLayer;
+                return super.newTransportLayer(id, key, sslEngine);
             }
         };
         serverChannelBuilder.configure(sslServerConfigs);
@@ -211,7 +210,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that server certificate with invalid host name is not accepted by
      * a client that validates server endpoint. Server certificate uses
@@ -230,9 +229,9 @@ public class SslTransportLayerTest {
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
-    
+
     /**
      * Tests that server certificate with invalid IP address is accepted by
      * a client that has disabled endpoint validation
@@ -252,7 +251,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that server accepts connections from clients with a trusted certificate
      * when client authentication is required.
@@ -295,7 +294,7 @@ public class SslTransportLayerTest {
         sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
         createSelector(sslClientConfigs);
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
         selector.close();
         server.close();
 
@@ -308,7 +307,7 @@ public class SslTransportLayerTest {
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that server does not accept connections from clients with an untrusted certificate
      * when client authentication is required.
@@ -323,9 +322,9 @@ public class SslTransportLayerTest {
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
-    
+
     /**
      * Tests that server does not accept connections from clients which don't
      * provide a certificate when client authentication is required.
@@ -335,7 +334,7 @@ public class SslTransportLayerTest {
         String node = "0";
         sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
         server = createEchoServer(SecurityProtocol.SSL);
-        
+
         sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
         sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
         sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
@@ -343,9 +342,9 @@ public class SslTransportLayerTest {
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
-    
+
     /**
      * Tests that server accepts connections from a client configured
      * with an untrusted certificate if client authentication is disabled
@@ -362,7 +361,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that server accepts connections from a client that does not provide
      * a certificate if client authentication is disabled
@@ -372,7 +371,7 @@ public class SslTransportLayerTest {
         String node = "0";
         sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none");
         server = createEchoServer(SecurityProtocol.SSL);
-        
+
         sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
         sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
         sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
@@ -382,7 +381,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that server accepts connections from a client configured
      * with a valid certificate if client authentication is requested
@@ -398,7 +397,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that server accepts connections from a client that does not provide
      * a certificate if client authentication is requested but not required
@@ -408,7 +407,7 @@ public class SslTransportLayerTest {
         String node = "0";
         sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "requested");
         server = createEchoServer(SecurityProtocol.SSL);
-        
+
         sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG);
         sslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG);
         sslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG);
@@ -448,7 +447,7 @@ public class SslTransportLayerTest {
             // Expected exception
         }
     }
-    
+
     /**
      * Tests that channels cannot be created if keystore cannot be loaded
      */
@@ -481,7 +480,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
     }
-    
+
     /**
      * Tests that client connections cannot be created to a server
      * if key password is invalid
@@ -495,9 +494,9 @@ public class SslTransportLayerTest {
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
-    
+
     /**
      * Tests that connections cannot be made with unsupported TLS versions
      */
@@ -506,15 +505,15 @@ public class SslTransportLayerTest {
         String node = "0";
         sslServerConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.2"));
         server = createEchoServer(SecurityProtocol.SSL);
-        
+
         sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.1"));
         createSelector(sslClientConfigs);
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
-    
+
     /**
      * Tests that connections cannot be made with unsupported TLS cipher suites
      */
@@ -524,13 +523,13 @@ public class SslTransportLayerTest {
         String[] cipherSuites = SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites();
         sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[0]));
         server = createEchoServer(SecurityProtocol.SSL);
-        
+
         sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[1]));
         createSelector(sslClientConfigs);
         InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
         selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
 
-        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.AUTHENTICATE.state());
+        NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
     }
 
     /**
@@ -546,7 +545,7 @@ public class SslTransportLayerTest {
 
         NetworkTestUtils.checkClientConnection(selector, node, 64000, 10);
     }
-    
+
     /**
      * Tests handling of BUFFER_OVERFLOW during wrap when network write buffer is smaller than SSL session packet buffer size.
      */
@@ -602,14 +601,98 @@ public class SslTransportLayerTest {
         }
         assertTrue("Send time not recorded", channel.getAndResetNetworkThreadTimeNanos() > 0);
         assertEquals("Time not reset", 0, channel.getAndResetNetworkThreadTimeNanos());
+        assertFalse("Unexpected bytes buffered", channel.hasBytesBuffered());
+        assertEquals(0, selector.completedReceives().size());
 
         selector.unmute(node);
         while (selector.completedReceives().isEmpty()) {
             selector.poll(100L);
+            assertEquals(0, selector.numStagedReceives(channel));
         }
         assertTrue("Receive time not recorded", channel.getAndResetNetworkThreadTimeNanos() > 0);
     }
 
+    /**
+     * Checks that an IOException injected into a socket read during the SSL handshake is not reported as an authentication failure.
+     */
+    @Test
+    public void testIOExceptionsDuringHandshakeRead() throws Exception {
+        testIOExceptionsDuringHandshake(/* failRead = */ true, /* failWrite = */ false);
+    }
+
+    /**
+     * Checks that an IOException injected into a socket flush during the SSL handshake is not reported as an authentication failure.
+     */
+    @Test
+    public void testIOExceptionsDuringHandshakeWrite() throws Exception {
+        testIOExceptionsDuringHandshake(/* failRead = */ false, /* failWrite = */ true);
+    }
+
+    /** For i = 1..100, fails the i-th pre-ready read (failRead) or flush (failWrite); each connection must become ready or drop in AUTHENTICATE. */
+    private void testIOExceptionsDuringHandshake(boolean failRead, boolean failWrite) throws Exception {
+        server = createEchoServer(SecurityProtocol.SSL);
+        TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT);
+        boolean done = false;
+        for (int i = 1; i <= 100; i++) {
+            channelBuilder.readFailureIndex = failRead ? i : Long.MAX_VALUE; // Long.MAX_VALUE: effectively never fail
+            channelBuilder.flushFailureIndex = failWrite ? i : Long.MAX_VALUE;
+            String node = String.valueOf(i);
+            channelBuilder.configure(sslClientConfigs);
+            this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext());
+            InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
+            selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
+            for (int j = 0; j < 30; j++) {
+                selector.poll(1000L);
+                KafkaChannel channel = selector.channel(node);
+                if (channel != null && channel.ready()) {
+                    done = true;
+                    break;
+                }
+                if (selector.disconnected().containsKey(node)) {
+                    // Injected I/O errors must not be reported as AUTHENTICATION_FAILED
+                    assertEquals(ChannelState.State.AUTHENTICATE, selector.disconnected().get(node).state());
+                    break;
+                }
+            }
+            KafkaChannel channel = selector.channel(node);
+            if (channel != null)
+                assertTrue("Channel not ready or disconnected:" + channel.state().state(), channel.ready());
+            // Close this iteration's selector; otherwise every replaced Selector (and its sockets) leaks
+            selector.close();
+        }
+        assertTrue("Too many invocations of read/write during SslTransportLayer.handshake()", done);
+    }
+
+    /**
+     * Checks that a handshake rejected by the server (untrusting config, client auth required) closes the
+     * client channel in state AUTHENTICATION_FAILED rather than as a plain connection failure, both with
+     * immediate server flushes and when each server flush is first held back for one or two attempts.
+     */
+    @Test
+    public void testPeerNotifiedOfHandshakeFailure() throws Exception {
+        final String node = "0";
+        sslServerConfigs = serverCertStores.getUntrustingConfig();
+        sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required");
+        // A delay of 0 is the baseline; 1 and 2 check that the failure still reaches the client when writes lag
+        for (int flushDelay = 0; flushDelay < 3; flushDelay++) {
+            TestSslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER);
+            serverChannelBuilder.configure(sslServerConfigs);
+            serverChannelBuilder.flushDelayCount = flushDelay;
+            ListenerName listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.SSL);
+            TestSecurityConfig serverConfig = new TestSecurityConfig(sslServerConfigs);
+            server = new NioEchoServer(listenerName, SecurityProtocol.SSL, serverConfig, "localhost", serverChannelBuilder, null);
+            server.start();
+
+            createSelector(sslClientConfigs);
+            selector.connect(node, new InetSocketAddress("localhost", server.port()), BUFFER_SIZE, BUFFER_SIZE);
+            NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
+
+            // Tear down this round's server and client before trying the next delay
+            server.close();
+            selector.close();
+        }
+    }
+
     @Test
     public void testCloseSsl() throws Exception {
         testClose(SecurityProtocol.SSL, new SslChannelBuilder(Mode.CLIENT));
@@ -654,24 +737,13 @@ public class SslTransportLayerTest {
 
     private void createSelector(Map<String, Object> sslClientConfigs) {
         createSelector(sslClientConfigs, null, null, null);
-    }      
+    }
 
     private void createSelector(Map<String, Object> sslClientConfigs, final Integer netReadBufSize,
                                  final Integer netWriteBufSize, final Integer appBufSize) {
-        
-        this.channelBuilder = new SslChannelBuilder(Mode.CLIENT) {
-
-            @Override
-            protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
-                SocketChannel socketChannel = (SocketChannel) key.channel();
-                SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort());
-                TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, netReadBufSize, netWriteBufSize, appBufSize);
-                transportLayer.startHandshake();
-                return transportLayer;
-            }
-
-
-        };
+        TestSslChannelBuilder testChannelBuilder = new TestSslChannelBuilder(Mode.CLIENT); // distinct name: don't shadow the field
+        testChannelBuilder.configureBufferSizes(netReadBufSize, netWriteBufSize, appBufSize);
+        this.channelBuilder = testChannelBuilder;
         this.channelBuilder.configure(sslClientConfigs);
         this.selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext());
     }
@@ -683,47 +755,111 @@ public class SslTransportLayerTest {
     private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception {
         return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol);
     }
-    
-    /**
-     * SSLTransportLayer with overrides for packet and application buffer size to test buffer resize
-     * code path. The overridden buffer size starts with a small value and increases in size when the buffer
-     * size is retrieved to handle overflow/underflow, until the actual session buffer size is reached.
-     */
-    private static class TestSslTransportLayer extends SslTransportLayer {
-
-        private final ResizeableBufferSize netReadBufSize;
-        private final ResizeableBufferSize netWriteBufSize;
-        private final ResizeableBufferSize appBufSize;
-
-        public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine,
-                                     Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) throws IOException {
-            super(channelId, key, sslEngine, false);
-            this.netReadBufSize = new ResizeableBufferSize(netReadBufSize);
-            this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSize);
-            this.appBufSize = new ResizeableBufferSize(appBufSize);
+
+    private static class TestSslChannelBuilder extends SslChannelBuilder {
+
+        private Integer netReadBufSizeOverride; // buffer-size overrides read by each TestSslTransportLayer constructor
+        private Integer netWriteBufSizeOverride;
+        private Integer appBufSizeOverride;
+        long readFailureIndex = Long.MAX_VALUE; // n-th socket read throws IOException if the channel is not yet ready
+        long flushFailureIndex = Long.MAX_VALUE; // n-th flush throws IOException if the channel is not yet ready
+        int flushDelayCount = 0; // flush() returns false this many times before each real flush
+
+        public TestSslChannelBuilder(Mode mode) {
+            super(mode);
         }
-        
-        @Override
-        protected int netReadBufferSize() {
-            ByteBuffer netReadBuffer = netReadBuffer();
-            // netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read
-            // operation. To avoid the read buffer being expanded too early, increase buffer size
-            // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always
-            // triggered in testNetReadBufferResize().
-            boolean updateBufSize = netReadBuffer != null && !netReadBuffer().hasRemaining();
-            return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize);
+
+        public void configureBufferSizes(Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) { // applies to layers built afterwards
+            this.netReadBufSizeOverride = netReadBufSize;
+            this.netWriteBufSizeOverride = netWriteBufSize;
+            this.appBufSizeOverride = appBufSize;
         }
-        
+
         @Override
-        protected int netWriteBufferSize() {
-            return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true);
+        protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, String host) throws IOException {
+            SocketChannel socketChannel = (SocketChannel) key.channel();
+            SSLEngine sslEngine = sslFactory.createSslEngine(host, socketChannel.socket().getPort());
+            TestSslTransportLayer transportLayer = newTransportLayer(id, key, sslEngine); // overridable factory hook
+            transportLayer.startHandshake();
+            return transportLayer;
         }
 
-        @Override
-        protected int applicationBufferSize() {
-            return appBufSize.updateAndGet(super.applicationBufferSize(), true);
+        protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException {
+            return new TestSslTransportLayer(id, key, sslEngine); // inner class: reads this builder's settings
+        }
+
+        /**
+         * SslTransportLayer with test hooks, initialised from the enclosing builder's settings:
+         * <ul>
+         * <li>Overrides for packet and application buffer size to test the buffer resize code path.
+         * The overridden size starts with a small value and increases when the buffer size is retrieved
+         * to handle overflow/underflow, until the actual session buffer size is reached.</li>
+         * <li>IOException injection into the n-th socket read or flush made before the channel is ready.</li>
+         * <li>Delayed flushes (flush() returns false flushDelayCount times first) to test failure notification to the peer.</li>
+         * </ul>
+         */
+        class TestSslTransportLayer extends SslTransportLayer {
+
+            private final ResizeableBufferSize netReadBufSize;
+            private final ResizeableBufferSize netWriteBufSize;
+            private final ResizeableBufferSize appBufSize;
+            private final AtomicLong numReadsRemaining; // reaches 0 on the readFailureIndex-th read
+            private final AtomicLong numFlushesRemaining; // reaches 0 on the flushFailureIndex-th flush
+            private final AtomicInteger numDelayedFlushesRemaining; // no-op flushes left before the next real one
+
+            public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine) throws IOException {
+                super(channelId, key, sslEngine, false);
+                this.netReadBufSize = new ResizeableBufferSize(netReadBufSizeOverride);
+                this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSizeOverride);
+                this.appBufSize = new ResizeableBufferSize(appBufSizeOverride);
+                numReadsRemaining = new AtomicLong(readFailureIndex);
+                numFlushesRemaining = new AtomicLong(flushFailureIndex);
+                numDelayedFlushesRemaining = new AtomicInteger(flushDelayCount);
+            }
+
+            @Override
+            protected int netReadBufferSize() {
+                ByteBuffer netReadBuffer = netReadBuffer();
+                // netReadBufferSize() is invoked in SslTransportLayer.read() prior to the read
+                // operation. To avoid the read buffer being expanded too early, increase buffer size
+                // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always
+                // triggered in testNetReadBufferResize().
+                boolean updateBufSize = netReadBuffer != null && !netReadBuffer.hasRemaining();
+                return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize);
+            }
+
+            @Override
+            protected int netWriteBufferSize() {
+                return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true);
+            }
+
+            @Override
+            protected int applicationBufferSize() {
+                return appBufSize.updateAndGet(super.applicationBufferSize(), true);
+            }
+
+            @Override
+            protected int readFromSocketChannel() throws IOException {
+                if (numReadsRemaining.decrementAndGet() == 0 && !ready()) // injected failure, only before ready()
+                    throw new IOException("Test exception during read");
+                return super.readFromSocketChannel();
+            }
+
+            @Override
+            protected boolean flush(ByteBuffer buf) throws IOException {
+                if (numFlushesRemaining.decrementAndGet() == 0 && !ready()) // injected failure, only before ready()
+                    throw new IOException("Test exception during write");
+                else if (numDelayedFlushesRemaining.getAndDecrement() != 0) // simulate a flush that made no progress
+                    return false;
+                resetDelayedFlush(); // the next real flush is again preceded by flushDelayCount no-op attempts
+                return super.flush(buf);
+            }
+
+            private void resetDelayedFlush() {
+                numDelayedFlushesRemaining.set(flushDelayCount);
+            }
         }
-        
+
         private static class ResizeableBufferSize {
             private Integer bufSizeOverride;
             ResizeableBufferSize(Integer bufSizeOverride) {
@@ -740,5 +876,4 @@ public class SslTransportLayerTest {
             }
         }
     }
-
 }