Posted to commits@kafka.apache.org by ij...@apache.org on 2017/09/08 17:42:53 UTC

kafka git commit: KAFKA-5790, KAFKA-5607; Improve error handling in SocketServer to avoid issues later

Repository: kafka
Updated Branches:
  refs/heads/trunk 4769e3d92 -> 6ebd9d35f


KAFKA-5790, KAFKA-5607; Improve error handling in SocketServer to avoid issues later

Changes:
  1. When an exception is encountered in any of the methods in `Processor` while processing a channel, log the exception and close the connection. Continue to process other channels.
  2. Fixes KAFKA-5790: SocketServer.processNewResponses should not skip a response if exception is thrown.
  3. For `IllegalStateException` and `IOException` in `poll()`, don't close the `Selector`. Log the exception and continue.
  4. Close channel on any failed send in `Selector`.
  5. When a send fails on a channel that is already closing, or a closing channel is closed, leave the channel state as-is so that it still reflects the state in which closing was triggered.
  6. Add tests for various failure scenarios.
  7. Fix timing issue in `SocketServerTest.testConnectionIdReuse` by waiting for new connections to be processed by the server.

Author: Rajini Sivaram <ra...@googlemail.com>

Reviewers: Jun Rao <ju...@gmail.com>, Ismael Juma <is...@juma.me.uk>

Closes #3548 from rajinisivaram/KAFKA-5607
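
For illustration only (not part of the patch), a minimal Scala sketch of the
per-channel error handling pattern described in items 1 and 2 above: each queued
response is handled inside its own try/catch, so a failure closes only the
offending connection and never skips the remaining responses. The names
`Response`, `sendResponse` and `closeChannel` are hypothetical stand-ins, not
the real SocketServer API.

    import scala.util.control.ControlThrowable

    object PerChannelErrorHandlingSketch {
      final case class Response(connectionId: String, payload: String)

      def processNewResponses(pending: Iterable[Response],
                              sendResponse: Response => Unit,
                              closeChannel: String => Unit): Unit = {
        pending.foreach { response =>
          try {
            // A failure here must not skip the remaining responses (KAFKA-5790).
            sendResponse(response)
          } catch {
            case e: ControlThrowable => throw e // never swallow control-flow throwables
            case e: Throwable =>
              // Close only the failing connection and keep processing the others.
              System.err.println(s"Closing ${response.connectionId} due to $e")
              closeChannel(response.connectionId)
          }
        }
      }
    }

In the actual patch this shows up as the `processChannelException` helper in
SocketServer.scala and the widened catch in `Selector.send`, visible in the
diff below.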


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/6ebd9d35
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/6ebd9d35
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/6ebd9d35

Branch: refs/heads/trunk
Commit: 6ebd9d35fc081bca3f1efeba8b70ec71cde85270
Parents: 4769e3d
Author: Rajini Sivaram <ra...@googlemail.com>
Authored: Fri Sep 8 18:42:42 2017 +0100
Committer: Ismael Juma <is...@juma.me.uk>
Committed: Fri Sep 8 18:42:42 2017 +0100

----------------------------------------------------------------------
 .../apache/kafka/common/network/Selector.java   |  73 +--
 .../kafka/common/network/SelectorTest.java      |  92 +++-
 .../kafka/common/network/SslSelectorTest.java   |   5 +
 .../main/scala/kafka/network/SocketServer.scala | 208 +++++---
 .../unit/kafka/network/SocketServerTest.scala   | 516 ++++++++++++++++++-
 5 files changed, 762 insertions(+), 132 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/clients/src/main/java/org/apache/kafka/common/network/Selector.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index e873252..d321ba5 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -20,7 +20,6 @@ import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.net.Socket;
 import java.nio.channels.CancelledKeyException;
-import java.nio.channels.ClosedChannelException;
 import java.nio.channels.SelectionKey;
 import java.nio.channels.SocketChannel;
 import java.nio.channels.UnresolvedAddressException;
@@ -216,19 +215,7 @@ public class Selector implements Selectable, AutoCloseable {
             throw e;
         }
         SelectionKey key = socketChannel.register(nioSelector, SelectionKey.OP_CONNECT);
-        KafkaChannel channel;
-        try {
-            channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool);
-        } catch (Exception e) {
-            try {
-                socketChannel.close();
-            } finally {
-                key.cancel();
-            }
-            throw new IOException("Channel could not be created for socket " + socketChannel, e);
-        }
-        key.attach(channel);
-        this.channels.put(id, channel);
+        KafkaChannel channel = buildChannel(socketChannel, id, key);
 
         if (connected) {
             // OP_CONNECT won't trigger for immediately connected channels
@@ -247,18 +234,36 @@ public class Selector implements Selectable, AutoCloseable {
      * Kafka brokers add an incrementing index to the connection id to avoid reuse in the timing window
      * where an existing connection may not yet have been closed by the broker when a new connection with
      * the same remote host:port is processed.
+     * </p><p>
+     * If a `KafkaChannel` cannot be created for this connection, the `socketChannel` is closed
+     * and its selection key cancelled.
      * </p>
      */
-    public void register(String id, SocketChannel socketChannel) throws ClosedChannelException {
+    public void register(String id, SocketChannel socketChannel) throws IOException {
         if (this.channels.containsKey(id))
             throw new IllegalStateException("There is already a connection for id " + id);
         if (this.closingChannels.containsKey(id))
             throw new IllegalStateException("There is already a connection for id " + id + " that is still being closed");
 
         SelectionKey key = socketChannel.register(nioSelector, SelectionKey.OP_READ);
-        KafkaChannel channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool);
+        buildChannel(socketChannel, id, key);
+    }
+
+    private KafkaChannel buildChannel(SocketChannel socketChannel, String id, SelectionKey key) throws IOException {
+        KafkaChannel channel;
+        try {
+            channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool);
+        } catch (Exception e) {
+            try {
+                socketChannel.close();
+            } finally {
+                key.cancel();
+            }
+            throw new IOException("Channel could not be created for socket " + socketChannel, e);
+        }
         key.attach(channel);
         this.channels.put(id, channel);
+        return channel;
     }
 
     /**
@@ -292,15 +297,24 @@ public class Selector implements Selectable, AutoCloseable {
      */
     public void send(Send send) {
         String connectionId = send.destination();
-        if (closingChannels.containsKey(connectionId))
+        KafkaChannel channel = openOrClosingChannelOrFail(connectionId);
+        if (closingChannels.containsKey(connectionId)) {
+            // ensure notification via `disconnected`, leave channel in the state in which closing was triggered
             this.failedSends.add(connectionId);
-        else {
-            KafkaChannel channel = channelOrFail(connectionId, false);
+        } else {
             try {
                 channel.setSend(send);
-            } catch (CancelledKeyException e) {
+            } catch (Exception e) {
+                // update the state for consistency, the channel will be discarded after `close`
+                channel.state(ChannelState.FAILED_SEND);
+                // ensure notification via `disconnected`
                 this.failedSends.add(connectionId);
                 close(channel, false);
+                if (!(e instanceof CancelledKeyException)) {
+                    log.error("Unexpected exception during send, closing connection {} and rethrowing exception {}",
+                            connectionId, e);
+                    throw e;
+                }
             }
         }
     }
@@ -535,7 +549,7 @@ public class Selector implements Selectable, AutoCloseable {
 
     @Override
     public void mute(String id) {
-        KafkaChannel channel = channelOrFail(id, true);
+        KafkaChannel channel = openOrClosingChannelOrFail(id);
         mute(channel);
     }
 
@@ -546,7 +560,7 @@ public class Selector implements Selectable, AutoCloseable {
 
     @Override
     public void unmute(String id) {
-        KafkaChannel channel = channelOrFail(id, true);
+        KafkaChannel channel = openOrClosingChannelOrFail(id);
         unmute(channel);
     }
 
@@ -603,12 +617,8 @@ public class Selector implements Selectable, AutoCloseable {
                 it.remove();
             }
         }
-        for (String channel : this.failedSends) {
-            KafkaChannel failedChannel = closingChannels.get(channel);
-            if (failedChannel != null)
-                failedChannel.state(ChannelState.FAILED_SEND);
+        for (String channel : this.failedSends)
             this.disconnected.put(channel, ChannelState.FAILED_SEND);
-        }
         this.failedSends.clear();
         this.madeReadProgressLastPoll = false;
     }
@@ -641,6 +651,11 @@ public class Selector implements Selectable, AutoCloseable {
             // channel state here anyway to avoid confusion.
             channel.state(ChannelState.LOCAL_CLOSE);
             close(channel, false);
+        } else {
+            KafkaChannel closingChannel = this.closingChannels.remove(id);
+            // Close any closing channel, leave the channel in the state in which closing was triggered
+            if (closingChannel != null)
+                doClose(closingChannel, false);
         }
     }
 
@@ -702,9 +717,9 @@ public class Selector implements Selectable, AutoCloseable {
         return channel != null && channel.ready();
     }
 
-    private KafkaChannel channelOrFail(String id, boolean maybeClosing) {
+    private KafkaChannel openOrClosingChannelOrFail(String id) {
         KafkaChannel channel = this.channels.get(id);
-        if (channel == null && maybeClosing)
+        if (channel == null)
             channel = this.closingChannels.get(id);
         if (channel == null)
             throw new IllegalStateException("Attempt to retrieve channel for which there is no connection. Connection id " + id + " existing connections " + channels.keySet());

http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
index 9671b36..e88a4ee 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.common.network;
 
+import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.memory.SimpleMemoryPool;
 import org.apache.kafka.common.metrics.Metrics;
@@ -43,6 +44,8 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -87,6 +90,10 @@ public class SelectorTest {
         this.metrics.close();
     }
 
+    public SecurityProtocol securityProtocol() {
+        return SecurityProtocol.PLAINTEXT;
+    }
+
     /**
      * Validate that when the server disconnects, a client send ends up with that node in the disconnected list.
      */
@@ -111,13 +118,20 @@ public class SelectorTest {
     /**
      * Sending a request with one already in flight should result in an exception
      */
-    @Test(expected = IllegalStateException.class)
+    @Test
     public void testCantSendWithInProgress() throws Exception {
         String node = "0";
         blockingConnect(node);
         selector.send(createSend(node, "test1"));
-        selector.send(createSend(node, "test2"));
-        selector.poll(1000L);
+        try {
+            selector.send(createSend(node, "test2"));
+            fail("IllegalStateException not thrown when sending a request with one in flight");
+        } catch (IllegalStateException e) {
+            // Expected exception
+        }
+        selector.poll(0);
+        assertTrue("Channel not closed", selector.disconnected().containsKey(node));
+        assertEquals(ChannelState.FAILED_SEND, selector.disconnected().get(node));
     }
 
     /**
@@ -274,6 +288,48 @@ public class SelectorTest {
         assertEquals("The response should be from the previously muted node", "1", selector.completedReceives().get(0).source());
     }
 
+    @Test
+    public void registerFailure() throws Exception {
+        ChannelBuilder channelBuilder = new PlaintextChannelBuilder() {
+            @Override
+            public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize,
+                    MemoryPool memoryPool) throws KafkaException {
+                throw new RuntimeException("Test exception");
+            }
+            @Override
+            public void close() {
+            }
+        };
+        Selector selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder);
+        SocketChannel socketChannel = SocketChannel.open();
+        socketChannel.configureBlocking(false);
+        try {
+            selector.register("1", socketChannel);
+            fail("Register did not fail");
+        } catch (IOException e) {
+            assertTrue("Unexpected exception: " + e, e.getCause().getMessage().contains("Test exception"));
+            assertFalse("Socket not closed", socketChannel.isOpen());
+        }
+        selector.close();
+    }
+
+    @Test
+    public void testCloseConnectionInClosingState() throws Exception {
+        KafkaChannel channel = createConnectionWithStagedReceives(5);
+        String id = channel.id();
+        time.sleep(6000); // The max idle time is 5000ms
+        selector.poll(0);
+        assertEquals(channel, selector.closingChannel(id));
+        assertNull("Channel not expired", selector.channel(id));
+        assertEquals(ChannelState.EXPIRED, channel.state());
+        selector.close(id);
+        assertNull("Channel not removed from channels", selector.channel(id));
+        assertNull("Channel not removed from closingChannels", selector.closingChannel(id));
+        assertTrue("Unexpected disconnect notification", selector.disconnected().isEmpty());
+        assertEquals(ChannelState.EXPIRED, channel.state());
+        selector.poll(0);
+        assertTrue("Unexpected disconnect notification", selector.disconnected().isEmpty());
+    }
 
     @Test
     public void testCloseOldestConnection() throws Exception {
@@ -297,22 +353,32 @@ public class SelectorTest {
         verifyCloseOldestConnectionWithStagedReceives(5);
     }
 
-    private void verifyCloseOldestConnectionWithStagedReceives(int maxStagedReceives) throws Exception {
+    private KafkaChannel createConnectionWithStagedReceives(int maxStagedReceives) throws Exception {
         String id = "0";
         blockingConnect(id);
         KafkaChannel channel = selector.channel(id);
+        int retries = 100;
 
-        selector.mute(id);
-        for (int i = 0; i <= maxStagedReceives; i++) {
-            selector.send(createSend(id, String.valueOf(i)));
-            selector.poll(1000);
-        }
-
-        selector.unmute(id);
         do {
-            selector.poll(1000);
-        } while (selector.completedReceives().isEmpty());
+            selector.mute(id);
+            for (int i = 0; i <= maxStagedReceives; i++) {
+                selector.send(createSend(id, String.valueOf(i)));
+                selector.poll(1000);
+            }
+
+            selector.unmute(id);
+            do {
+                selector.poll(1000);
+            } while (selector.completedReceives().isEmpty());
+        } while (selector.numStagedReceives(channel) == 0 && --retries > 0);
+        assertTrue("No staged receives after 100 attempts", selector.numStagedReceives(channel) > 0);
 
+        return channel;
+    }
+
+    private void verifyCloseOldestConnectionWithStagedReceives(int maxStagedReceives) throws Exception {
+        KafkaChannel channel = createConnectionWithStagedReceives(maxStagedReceives);
+        String id = channel.id();
         int stagedReceives = selector.numStagedReceives(channel);
         int completedReceives = 0;
         while (selector.disconnected().isEmpty()) {

http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
----------------------------------------------------------------------
diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
index 46d3b79..bada4d4 100644
--- a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java
@@ -75,6 +75,11 @@ public class SslSelectorTest extends SelectorTest {
         this.metrics.close();
     }
 
+    @Override
+    public SecurityProtocol securityProtocol() {
+        return SecurityProtocol.SSL;
+    }
+
     /**
      * Tests that SSL renegotiation initiated by the server are handled correctly by the client
      * @throws Exception

http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/core/src/main/scala/kafka/network/SocketServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 3d54c8a..c93fb55 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -35,7 +35,7 @@ import org.apache.kafka.common.errors.InvalidRequestException
 import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool}
 import org.apache.kafka.common.metrics._
 import org.apache.kafka.common.metrics.stats.Rate
-import org.apache.kafka.common.network.{ChannelBuilders, KafkaChannel, ListenerName, Selectable, Send, Selector => KSelector}
+import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, KafkaChannel, ListenerName, Selectable, Send, Selector => KSelector}
 import org.apache.kafka.common.protocol.SecurityProtocol
 import org.apache.kafka.common.security.auth.KafkaPrincipal
 import org.apache.kafka.common.protocol.types.SchemaException
@@ -44,7 +44,7 @@ import org.apache.kafka.common.utils.{KafkaThread, Time}
 
 import scala.collection._
 import JavaConverters._
-import scala.util.control.{ControlThrowable, NonFatal}
+import scala.util.control.ControlThrowable
 
 /**
  * An NIO socket server. The threading model is
@@ -229,20 +229,6 @@ private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQ
   protected def isRunning: Boolean = alive.get
 
   /**
-   * Close the connection identified by `connectionId` and decrement the connection count.
-   */
-  def close(selector: KSelector, connectionId: String): Unit = {
-    val channel = selector.channel(connectionId)
-    if (channel != null) {
-      debug(s"Closing selector connection $connectionId")
-      val address = channel.socketAddress
-      if (address != null)
-        connectionQuotas.dec(address)
-      selector.close(connectionId)
-    }
-  }
-
-  /**
    * Close `channel` and decrement the connection count.
    */
   def close(channel: SocketChannel): Unit = {
@@ -431,7 +417,10 @@ private[kafka] class Processor(val id: Int,
     Map("networkProcessor" -> id.toString)
   )
 
-  private val selector = new KSelector(
+  private val selector = createSelector(
+      ChannelBuilders.serverChannelBuilder(listenerName, securityProtocol, config, credentialProvider.credentialCache))
+  // Visible to override for testing
+  protected[network] def createSelector(channelBuilder: ChannelBuilder): KSelector = new KSelector(
     maxRequestSize,
     connectionsMaxIdleMs,
     metrics,
@@ -440,7 +429,7 @@ private[kafka] class Processor(val id: Int,
     metricTags,
     false,
     true,
-    ChannelBuilders.serverChannelBuilder(listenerName, securityProtocol, config, credentialProvider.credentialCache),
+    channelBuilder,
     memoryPool)
 
   // Connection ids have the format `localAddr:localPort-remoteAddr:remotePort-index`. The index is a
@@ -450,35 +439,53 @@ private[kafka] class Processor(val id: Int,
 
   override def run() {
     startupComplete()
-    while (isRunning) {
-      try {
-        // setup any new connections that have been queued up
-        configureNewConnections()
-        // register any new responses for writing
-        processNewResponses()
-        poll()
-        processCompletedReceives()
-        processCompletedSends()
-        processDisconnected()
-      } catch {
-        // We catch all the throwables here to prevent the processor thread from exiting. We do this because
-        // letting a processor exit might cause a bigger impact on the broker. Usually the exceptions thrown would
-        // be either associated with a specific socket channel or a bad request. We just ignore the bad socket channel
-        // or request. This behavior might need to be reviewed if we see an exception that need the entire broker to stop.
-        case e: ControlThrowable => throw e
-        case e: Throwable =>
-          error("Processor got uncaught exception.", e)
+    try {
+      while (isRunning) {
+        try {
+          // setup any new connections that have been queued up
+          configureNewConnections()
+          // register any new responses for writing
+          processNewResponses()
+          poll()
+          processCompletedReceives()
+          processCompletedSends()
+          processDisconnected()
+        } catch {
+          // We catch all the throwables here to prevent the processor thread from exiting. We do this because
+          // letting a processor exit might cause a bigger impact on the broker. This behavior might need to be
+          // reviewed if we see an exception that needs the entire broker to stop. Usually the exceptions thrown would
+          // be either associated with a specific socket channel or a bad request. These exceptions are caught and
+          // processed by the individual methods above which close the failing channel and continue processing other
+          // channels. So this catch block should only ever see ControlThrowables.
+          case e: Throwable => processException("Processor got uncaught exception.", e)
+        }
       }
+    } finally {
+      debug("Closing selector - processor " + id)
+      swallowError(closeAll())
+      shutdownComplete()
+    }
+  }
+
+  private def processException(errorMessage: String, throwable: Throwable) {
+    throwable match {
+      case e: ControlThrowable => throw e
+      case e => error(errorMessage, e)
     }
+  }
 
-    debug("Closing selector - processor " + id)
-    swallowError(closeAll())
-    shutdownComplete()
+  private def processChannelException(channelId: String, errorMessage: String, throwable: Throwable) {
+    if (openOrClosingChannel(channelId).isDefined) {
+      error(s"Closing socket for ${channelId} because of error", throwable)
+      close(channelId)
+    }
+    processException(errorMessage, throwable)
   }
 
   private def processNewResponses() {
-    var curr = requestChannel.receiveResponse(id)
-    while (curr != null) {
+    var curr: RequestChannel.Response = null
+    while ({curr = requestChannel.receiveResponse(id); curr != null}) {
+      val channelId = curr.request.context.connectionId
       try {
         curr.responseAction match {
           case RequestChannel.NoOpAction =>
@@ -486,9 +493,7 @@ private[kafka] class Processor(val id: Int,
             // that are sitting in the server's socket buffer
             updateRequestMetrics(curr)
             trace("Socket server received empty response to send, registering for read: " + curr)
-            val channelId = curr.request.context.connectionId
-            if (selector.channel(channelId) != null || selector.closingChannel(channelId) != null)
-                selector.unmute(channelId)
+            openOrClosingChannel(channelId).foreach(c => selector.unmute(c.id))
           case RequestChannel.SendAction =>
             val responseSend = curr.responseSend.getOrElse(
               throw new IllegalStateException(s"responseSend must be defined for SendAction, response: $curr"))
@@ -496,10 +501,11 @@ private[kafka] class Processor(val id: Int,
           case RequestChannel.CloseConnectionAction =>
             updateRequestMetrics(curr)
             trace("Closing socket connection actively according to the response code.")
-            close(selector, curr.request.context.connectionId)
+            close(channelId)
         }
-      } finally {
-        curr = requestChannel.receiveResponse(id)
+      } catch {
+        case e: Throwable =>
+          processChannelException(channelId, s"Exception while processing response for $channelId", e)
       }
     }
   }
@@ -526,43 +532,50 @@ private[kafka] class Processor(val id: Int,
     try selector.poll(300)
     catch {
       case e @ (_: IllegalStateException | _: IOException) =>
-        error(s"Closing processor $id due to illegal state or IO exception")
-        swallow(closeAll())
-        shutdownComplete()
-        throw e
+        // The exception is not re-thrown and any completed sends/receives/connections/disconnections
+        // from this poll will be processed.
+        error(s"Processor $id poll failed due to illegal state or IO exception")
     }
   }
 
   private def processCompletedReceives() {
     selector.completedReceives.asScala.foreach { receive =>
       try {
-        val openChannel = selector.channel(receive.source)
-        // Only methods that are safe to call on a disconnected channel should be invoked on 'openOrClosingChannel'.
-        val openOrClosingChannel = if (openChannel != null) openChannel else selector.closingChannel(receive.source)
-        val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, openOrClosingChannel.principal.getName)
-        val header = RequestHeader.parse(receive.payload)
-        val context = new RequestContext(header, receive.source, openOrClosingChannel.socketAddress,
-          principal, listenerName, securityProtocol)
-        val req = new RequestChannel.Request(processor = id, context = context,
-          startTimeNanos = time.nanoseconds, memoryPool, receive.payload)
-        requestChannel.sendRequest(req)
-        selector.mute(receive.source)
+        openOrClosingChannel(receive.source) match {
+          case Some(channel) =>
+            val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, channel.principal.getName)
+            val header = RequestHeader.parse(receive.payload)
+            val context = new RequestContext(header, receive.source, channel.socketAddress,
+              principal, listenerName, securityProtocol)
+            val req = new RequestChannel.Request(processor = id, context = context,
+              startTimeNanos = time.nanoseconds, memoryPool, receive.payload)
+            requestChannel.sendRequest(req)
+            selector.mute(receive.source)
+          case None =>
+            // This should never happen since completed receives are processed immediately after `poll()`
+            throw new IllegalStateException(s"Channel ${receive.source} removed from selector before processing completed receive")
+        }
       } catch {
-        case e @ (_: InvalidRequestException | _: SchemaException) =>
-          // note that even though we got an exception, we can assume that receive.source is valid. Issues with constructing a valid receive object were handled earlier
-          error(s"Closing socket for ${receive.source} because of error", e)
-          close(selector, receive.source)
+        // note that even though we got an exception, we can assume that receive.source is valid.
+        // Issues with constructing a valid receive object were handled earlier
+        case e: Throwable =>
+          processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e)
       }
     }
   }
 
   private def processCompletedSends() {
     selector.completedSends.asScala.foreach { send =>
-      val resp = inflightResponses.remove(send.destination).getOrElse {
-        throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`")
+      try {
+        val resp = inflightResponses.remove(send.destination).getOrElse {
+          throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`")
+        }
+        updateRequestMetrics(resp)
+        selector.unmute(send.destination)
+      } catch {
+        case e: Throwable => processChannelException(send.destination,
+            s"Exception while processing completed send to ${send.destination}", e)
       }
-      updateRequestMetrics(resp)
-      selector.unmute(send.destination)
     }
   }
 
@@ -574,12 +587,35 @@ private[kafka] class Processor(val id: Int,
 
   private def processDisconnected() {
     selector.disconnected.keySet.asScala.foreach { connectionId =>
-      val remoteHost = ConnectionId.fromString(connectionId).getOrElse {
-        throw new IllegalStateException(s"connectionId has unexpected format: $connectionId")
-      }.remoteHost
-      inflightResponses.remove(connectionId).foreach(updateRequestMetrics)
-      // the channel has been closed by the selector but the quotas still need to be updated
-      connectionQuotas.dec(InetAddress.getByName(remoteHost))
+      try {
+        val remoteHost = ConnectionId.fromString(connectionId).getOrElse {
+          throw new IllegalStateException(s"connectionId has unexpected format: $connectionId")
+        }.remoteHost
+        inflightResponses.remove(connectionId).foreach(updateRequestMetrics)
+        // the channel has been closed by the selector but the quotas still need to be updated
+        connectionQuotas.dec(InetAddress.getByName(remoteHost))
+      } catch {
+        case e: Throwable => processException(s"Exception while processing disconnection of $connectionId", e)
+      }
+    }
+  }
+
+  /**
+   * Close the connection identified by `connectionId` and decrement the connection count.
+   * The channel will be immediately removed from the selector's `channels` or `closingChannels`
+   * and no further disconnect notifications will be sent for this channel by the selector.
+   * If responses are pending for the channel, they are dropped and the metrics are updated.
+   * If the channel has already been removed from the selector, no action is taken.
+   */
+  private def close(connectionId: String): Unit = {
+    openOrClosingChannel(connectionId).foreach { channel =>
+      debug(s"Closing selector connection $connectionId")
+      val address = channel.socketAddress
+      if (address != null)
+        connectionQuotas.dec(address)
+      selector.close(connectionId)
+
+      inflightResponses.remove(connectionId).foreach(response => updateRequestMetrics(response))
     }
   }
 
@@ -601,13 +637,12 @@ private[kafka] class Processor(val id: Int,
         debug(s"Processor $id listening to new connection from ${channel.socket.getRemoteSocketAddress}")
         selector.register(connectionId(channel.socket), channel)
       } catch {
-        // We explicitly catch all non fatal exceptions and close the socket to avoid a socket leak. The other
-        // throwables will be caught in processor and logged as uncaught exceptions.
-        case NonFatal(e) =>
-          val remoteAddress = channel.getRemoteAddress
+        // We explicitly catch all exceptions and close the socket to avoid a socket leak.
+        case e: Throwable =>
+          val remoteAddress = channel.socket.getRemoteSocketAddress
           // need to close the channel here to avoid a socket leak.
           close(channel)
-          error(s"Processor $id closed connection from $remoteAddress", e)
+          processException(s"Processor $id closed connection from $remoteAddress", e)
       }
     }
   }
@@ -617,7 +652,7 @@ private[kafka] class Processor(val id: Int,
    */
   private def closeAll() {
     selector.channels.asScala.foreach { channel =>
-      close(selector, channel.id)
+      close(channel.id)
     }
     selector.close()
   }
@@ -633,10 +668,13 @@ private[kafka] class Processor(val id: Int,
     connId
   }
 
+  // Only for testing
+  private[network] def inflightResponseCount: Int = inflightResponses.size
+
   // Visible for testing
-  private[network] def openOrClosingChannel(connectionId: String): Option[KafkaChannel] = {
-    Option(selector.channel(connectionId)).orElse(Option(selector.closingChannel(connectionId)))
-  }
+  // Only methods that are safe to call on a disconnected channel should be invoked on 'openOrClosingChannel'.
+  private[network] def openOrClosingChannel(connectionId: String): Option[KafkaChannel] =
+     Option(selector.channel(connectionId)).orElse(Option(selector.closingChannel(connectionId)))
 
   /* For test usage */
   private[network] def channel(connectionId: String): Option[KafkaChannel] =

http://git-wip-us.apache.org/repos/asf/kafka/blob/6ebd9d35/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 9b32635..227714b 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -20,6 +20,7 @@ package kafka.network
 import java.io._
 import java.net._
 import java.nio.ByteBuffer
+import java.nio.channels.SocketChannel
 import java.util.{HashMap, Random}
 import javax.net.ssl._
 
@@ -32,7 +33,7 @@ import kafka.utils.TestUtils
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.memory.MemoryPool
 import org.apache.kafka.common.metrics.Metrics
-import org.apache.kafka.common.network.{KafkaChannel, ListenerName, NetworkSend, Send}
+import org.apache.kafka.common.network.{ChannelBuilder, ChannelState, KafkaChannel, ListenerName, NetworkReceive, NetworkSend, Selector, Send}
 import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol}
 import org.apache.kafka.common.record.{MemoryRecords, RecordBatch}
 import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader}
@@ -43,7 +44,9 @@ import org.junit._
 import org.scalatest.junit.JUnitSuite
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.util.control.ControlThrowable
 
 class SocketServerTest extends JUnitSuite {
   val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
@@ -58,6 +61,7 @@ class SocketServerTest extends JUnitSuite {
   val config = KafkaConfig.fromProps(props)
   val metrics = new Metrics
   val credentialProvider = new CredentialProvider(config.saslEnabledMechanisms)
+  val localAddress = InetAddress.getLoopbackAddress
 
   // Clean-up any metrics left around by previous tests
   for (metricName <- YammerMetrics.defaultRegistry.allMetrics.keySet.asScala)
@@ -116,6 +120,19 @@ class SocketServerTest extends JUnitSuite {
     socket
   }
 
+  // Create a client connection, process one request and return (client socket, connectionId)
+  def connectAndProcessRequest(s: SocketServer): (Socket, String) = {
+    val socket = connect(s)
+    val request = sendAndReceiveRequest(socket, s)
+    processRequest(s.requestChannel, request)
+    (socket, request.context.connectionId)
+  }
+
+  def sendAndReceiveRequest(socket: Socket, server: SocketServer): RequestChannel.Request = {
+    sendRequest(socket, producerRequestBytes)
+    receiveRequest(server.requestChannel)
+  }
+
   @After
   def tearDown() {
     metrics.close()
@@ -259,7 +276,9 @@ class SocketServerTest extends JUnitSuite {
     val idleTimeMs = 60000
     val time = new MockTime()
     props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString)
+    props.put("listeners", "PLAINTEXT://localhost:0")
     val serverMetrics = new Metrics
+    @volatile var selector: TestableSelector = null
     val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0"
     val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, time, credentialProvider) {
       override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
@@ -267,6 +286,11 @@ class SocketServerTest extends JUnitSuite {
         new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas,
           config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool) {
           override protected[network] def connectionId(socket: Socket): String = overrideConnectionId
+          override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = {
+           val testableSelector = new TestableSelector(config, channelBuilder, time, metrics)
+           selector = testableSelector
+           testableSelector
+        }
         }
       }
     }
@@ -275,15 +299,26 @@ class SocketServerTest extends JUnitSuite {
     def openOrClosingChannel: Option[KafkaChannel] = overrideServer.processor(0).openOrClosingChannel(overrideConnectionId)
     def connectionCount = overrideServer.connectionCount(InetAddress.getByName("127.0.0.1"))
 
+    // Create a client connection and wait for server to register the connection with the selector. For
+    // test scenarios below where `Selector.register` fails, the wait ensures that checks are performed
+    // only after `register` is processed by the server.
+    def connectAndWaitForConnectionRegister(): Socket = {
+      val connections = selector.operationCounts(SelectorOperation.Register)
+      val socket = connect(overrideServer)
+      TestUtils.waitUntilTrue(() =>
+        selector.operationCounts(SelectorOperation.Register) == connections + 1, "Connection not registered")
+      socket
+    }
+
     try {
       overrideServer.startup()
-      val socket1 = connect(overrideServer)
+      val socket1 = connectAndWaitForConnectionRegister()
       TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to create channel")
       val channel1 = openChannel.getOrElse(throw new RuntimeException("Channel not found"))
 
       // Create new connection with same id when `channel1` is still open and in Selector.channels
       // Check that new connection is closed and openChannel still contains `channel1`
-      connect(overrideServer)
+      connectAndWaitForConnectionRegister()
       TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel")
       assertSame(channel1, openChannel.getOrElse(throw new RuntimeException("Channel not found")))
 
@@ -297,7 +332,7 @@ class SocketServerTest extends JUnitSuite {
 
       // Create new connection with same id when when `channel1` is in Selector.closingChannels
       // Check that new connection is closed and openOrClosingChannel still contains `channel1`
-      connect(overrideServer)
+      connectAndWaitForConnectionRegister()
       TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel")
       assertSame(channel1, openOrClosingChannel.getOrElse(throw new RuntimeException("Channel not found")))
 
@@ -306,7 +341,7 @@ class SocketServerTest extends JUnitSuite {
       TestUtils.waitUntilTrue(() => connectionCount == 0 && openOrClosingChannel.isEmpty, "Failed to remove channel with failed send")
 
       // Check that new connections can be created with the same id since `channel1` is no longer in Selector
-      connect(overrideServer)
+      connectAndWaitForConnectionRegister()
       TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to open new channel")
       val newChannel = openChannel.getOrElse(throw new RuntimeException("Channel not found"))
       assertNotSame(channel1, newChannel)
@@ -593,4 +628,475 @@ class SocketServerTest extends JUnitSuite {
     }
   }
 
+  /**
+   * Tests exception handling in [[Processor.configureNewConnections]]. Exception is
+   * injected into [[Selector.register]] which is used to register each new connection.
+   * Test creates two connections in a single iteration by waking up the selector only
+   * when two connections are ready.
+   * Verifies that
+   * - first failed connection is closed
+   * - second connection is processed successfully after the first fails with an exception
+   * - processor is healthy after the exception
+   */
+  @Test
+  def configureNewConnectionException(): Unit = {
+    withTestableServer { testableServer =>
+      val testableSelector = testableServer.testableSelector
+
+      testableSelector.updateMinWakeup(2)
+      testableSelector.addFailure(SelectorOperation.Register)
+      val sockets = (1 to 2).map(_ => connect(testableServer))
+      testableSelector.waitForOperations(SelectorOperation.Register, 2)
+      TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 1, "Failed channel not removed")
+
+      assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+    }
+  }
+
+  /**
+   * Tests exception handling in [[Processor.processNewResponses]]. Exception is
+   * injected into [[Selector.send]] which is used to send the new response.
+   * Test creates two responses in a single iteration by waking up the selector only
+   * when two responses are ready.
+   * Verifies that
+   * - first failed channel is closed
+   * - second response is processed successfully after the first fails with an exception
+   * - processor is healthy after the exception
+   */
+  @Test
+  def processNewResponseException(): Unit = {
+    withTestableServer { testableServer =>
+      val testableSelector = testableServer.testableSelector
+      testableSelector.updateMinWakeup(2)
+
+      val sockets = (1 to 2).map(_ => connect(testableServer))
+      sockets.foreach(sendRequest(_, producerRequestBytes))
+
+      testableServer.testableSelector.addFailure(SelectorOperation.Send)
+      sockets.foreach(_ => processRequest(testableServer.requestChannel))
+      testableSelector.waitForOperations(SelectorOperation.Send, 2)
+      testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
+
+      assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+    }
+  }
+
+  /**
+   * Tests exception handling in [[Processor.processNewResponses]] when [[Selector.send]]
+   * fails with `CancelledKeyException`, which is handled by the selector using a different
+   * code path. Test scenario is similar to [[SocketServerTest.processNewResponseException]].
+   */
+  @Test
+  def sendCancelledKeyException(): Unit = {
+    withTestableServer { testableServer =>
+      val testableSelector = testableServer.testableSelector
+      testableSelector.updateMinWakeup(2)
+
+      val sockets = (1 to 2).map(_ => connect(testableServer))
+      sockets.foreach(sendRequest(_, producerRequestBytes))
+      val requestChannel = testableServer.requestChannel
+
+      val requests = sockets.map(_ => receiveRequest(requestChannel))
+      val failedConnectionId = requests(0).context.connectionId
+      // `KafkaChannel.disconnect()` cancels the selection key, triggering CancelledKeyException during send
+      testableSelector.channel(failedConnectionId).disconnect()
+      requests.foreach(processRequest(requestChannel, _))
+      testableSelector.waitForOperations(SelectorOperation.Send, 2)
+      testableServer.waitForChannelClose(failedConnectionId, locallyClosed = false)
+
+      val successfulSocket = if (isSocketConnectionId(failedConnectionId, sockets(0))) sockets(1) else sockets(0)
+      assertProcessorHealthy(testableServer, Seq(successfulSocket))
+    }
+  }
+
+  /**
+   * Tests exception handling in [[Processor.processNewResponses]] when [[Selector.send]]
+   * to a channel in closing state throws an exception. Test scenario is similar to
+   * [[SocketServerTest.processNewResponseException]].
+   */
+  @Test
+  def closingChannelException(): Unit = {
+    withTestableServer { testableServer =>
+      val testableSelector = testableServer.testableSelector
+      testableSelector.updateMinWakeup(2)
+
+      val sockets = (1 to 2).map(_ => connect(testableServer))
+      val serializedBytes = producerRequestBytes
+      val request = sendRequestsUntilStagedReceive(testableServer, sockets(0), serializedBytes)
+      sendRequest(sockets(1), serializedBytes)
+
+      testableSelector.addFailure(SelectorOperation.Send)
+      sockets(0).close()
+      processRequest(testableServer.requestChannel, request)
+      processRequest(testableServer.requestChannel) // Also process request from other channel
+      testableSelector.waitForOperations(SelectorOperation.Send, 2)
+      testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = true)
+
+      assertProcessorHealthy(testableServer, Seq(sockets(1)))
+    }
+  }
+
+  /**
+   * Tests exception handling in [[Processor.processCompletedReceives]]. Exception is
+   * injected into [[Selector.mute]] which is used to mute the channel when a receive is complete.
+   * Test creates two receives in a single iteration by caching completed receives until two receives
+   * are complete.
+   * Verifies that
+   * - first failed channel is closed
+   * - second receive is processed successfully after the first fails with an exception
+   * - processor is healthy after the exception
+   */
+  @Test
+  def processCompletedReceiveException(): Unit = {
+    withTestableServer { testableServer =>
+      val sockets = (1 to 2).map(_ => connect(testableServer))
+      val testableSelector = testableServer.testableSelector
+      val requestChannel = testableServer.requestChannel
+
+      testableSelector.cachedCompletedReceives.minPerPoll = 2
+      testableSelector.addFailure(SelectorOperation.Mute)
+      sockets.foreach(sendRequest(_, producerRequestBytes))
+      val requests = sockets.map(_ => receiveRequest(requestChannel))
+      testableSelector.waitForOperations(SelectorOperation.Mute, 2)
+      testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
+      requests.foreach(processRequest(requestChannel, _))
+
+      assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+    }
+  }
+
+  /**
+   * Tests exception handling in [[Processor.processCompletedSends]]. Exception is
+   * injected into [[Selector.unmute]] which is used to unmute the channel after send is complete.
+   * Test creates two completed sends in a single iteration by caching completed sends until two
+   * sends are complete.
+   * Verifies that
+   * - first failed channel is closed
+   * - second send is processed successfully after the first fails with an exception
+   * - processor is healthy after the exception
+   */
+  @Test
+  def processCompletedSendException(): Unit = {
+    withTestableServer { testableServer =>
+      val testableSelector = testableServer.testableSelector
+      val sockets = (1 to 2).map(_ => connect(testableServer))
+      val requests = sockets.map(sendAndReceiveRequest(_, testableServer))
+
+      testableSelector.addFailure(SelectorOperation.Unmute)
+      requests.foreach(processRequest(testableServer.requestChannel, _))
+      testableSelector.waitForOperations(SelectorOperation.Unmute, 2)
+      testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true)
+
+      assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets))
+    }
+  }
+
+  /**
+   * Tests exception handling in [[Processor.processDisconnected]]. An invalid connectionId
+   * is inserted into the disconnected list just before the actual valid one.
+   * Verifies that
+   * - first invalid connectionId is ignored
+   * - second disconnected channel is processed successfully after the first fails with an exception
+   * - processor is healthy after the exception
+   */
+  @Test
+  def processDisconnectedException(): Unit = {
+    withTestableServer { testableServer =>
+      val (socket, connectionId) = connectAndProcessRequest(testableServer)
+      val testableSelector = testableServer.testableSelector
+
+      // Add an invalid connectionId to `Selector.disconnected` list before the actual disconnected channel
+      // and check that the actual id is processed and the invalid one ignored.
+      testableSelector.cachedDisconnected.minPerPoll = 2
+      testableSelector.cachedDisconnected.deferredValues += "notAValidConnectionId" -> ChannelState.EXPIRED
+      socket.close()
+      testableSelector.operationCounts.clear()
+      testableSelector.waitForOperations(SelectorOperation.Poll, 1)
+      testableServer.waitForChannelClose(connectionId, locallyClosed = false)
+
+      assertProcessorHealthy(testableServer)
+    }
+  }
+
+  /**
+   * Tests that `Processor` continues to function correctly after a failed [[Selector.poll]].
+   */
+  @Test
+  def pollException(): Unit = {
+    withTestableServer { testableServer =>
+      val (socket, _) = connectAndProcessRequest(testableServer)
+      val testableSelector = testableServer.testableSelector
+
+      testableSelector.addFailure(SelectorOperation.Poll)
+      testableSelector.operationCounts.clear()
+      testableSelector.waitForOperations(SelectorOperation.Poll, 2)
+
+      assertProcessorHealthy(testableServer, Seq(socket))
+    }
+  }
+
+  /**
+   * Tests handling of `ControlThrowable`. Verifies that the selector is closed.
+   */
+  @Test
+  def controlThrowable(): Unit = {
+    withTestableServer { testableServer =>
+      val (socket, _) = connectAndProcessRequest(testableServer)
+      val testableSelector = testableServer.testableSelector
+
+      testableSelector.operationCounts.clear()
+      testableSelector.addFailure(SelectorOperation.Poll,
+          Some(new RuntimeException("ControlThrowable exception during poll()") with ControlThrowable))
+      testableSelector.waitForOperations(SelectorOperation.Poll, 1)
+
+      testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1)
+    }
+  }
+
+  private def withTestableServer(testWithServer: TestableSocketServer => Unit): Unit = {
+    props.put("listeners", "PLAINTEXT://localhost:0")
+    val testableServer = new TestableSocketServer
+    testableServer.startup()
+    try {
+        testWithServer(testableServer)
+    } finally {
+      testableServer.shutdown()
+      testableServer.metrics.close()
+    }
+  }
+
+  private def assertProcessorHealthy(testableServer: TestableSocketServer, healthySockets: Seq[Socket] = Seq.empty): Unit = {
+    val selector = testableServer.testableSelector
+    selector.reset()
+    val requestChannel = testableServer.requestChannel
+
+    // Check that existing channels behave as expected
+    healthySockets.foreach { socket =>
+      val request = sendAndReceiveRequest(socket, testableServer)
+      processRequest(requestChannel, request)
+      socket.close()
+    }
+    TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 0, "Channels not removed")
+
+    // Check new channel behaves as expected
+    val (socket, connectionId) = connectAndProcessRequest(testableServer)
+    assertArrayEquals(producerRequestBytes, receiveResponse(socket))
+    assertNotNull("Channel should not have been closed", selector.channel(connectionId))
+    assertNull("Channel should not be closing", selector.closingChannel(connectionId))
+    socket.close()
+    TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 0, "Channels not removed")
+  }
+
+  // Since all sockets use the same local host, it is sufficient to check the local port
+  def isSocketConnectionId(connectionId: String, socket: Socket): Boolean =
+    connectionId.contains(s":${socket.getLocalPort}-")
+
+  class TestableSocketServer extends SocketServer(KafkaConfig.fromProps(props),
+      new Metrics, Time.SYSTEM, credentialProvider) {
+
+    @volatile var selector: Option[TestableSelector] = None
+
+    override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
+                                protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = {
+      new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas,
+        config.connectionsMaxIdleMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool) {
+
+        override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = {
+           val testableSelector = new TestableSelector(config, channelBuilder, time, metrics)
+           assertEquals(None, selector)
+           selector = Some(testableSelector)
+           testableSelector
+        }
+      }
+    }
+
+    def testableSelector: TestableSelector =
+      selector.getOrElse(throw new IllegalStateException("Selector not created"))
+
+    def waitForChannelClose(connectionId: String, locallyClosed: Boolean): Unit = {
+      val selector = testableSelector
+      if (locallyClosed) {
+        TestUtils.waitUntilTrue(() => selector.allLocallyClosedChannels.contains(connectionId),
+            s"Channel not closed: $connectionId")
+        assertTrue("Unexpected disconnect notification", testableSelector.allDisconnectedChannels.isEmpty)
+      } else {
+        TestUtils.waitUntilTrue(() => selector.allDisconnectedChannels.contains(connectionId),
+            s"Disconnect notification not received: $connectionId")
+        assertTrue("Channel closed locally", testableSelector.allLocallyClosedChannels.isEmpty)
+      }
+      val openCount = selector.allChannels.size - 1 // minus one for the channel just closed above
+      TestUtils.waitUntilTrue(() => connectionCount(localAddress) == openCount, "Connection count not decremented")
+      TestUtils.waitUntilTrue(() =>
+        processor(0).inflightResponseCount == 0, "Inflight responses not cleared")
+      assertNull("Channel not removed", selector.channel(connectionId))
+      assertNull("Closing channel not removed", selector.closingChannel(connectionId))
+    }
+  }
+
+  sealed trait SelectorOperation
+  object SelectorOperation {
+    case object Register extends SelectorOperation
+    case object Poll extends SelectorOperation
+    case object Send extends SelectorOperation
+    case object Mute extends SelectorOperation
+    case object Unmute extends SelectorOperation
+    case object Wakeup extends SelectorOperation
+    case object Close extends SelectorOperation
+    case object CloseSelector extends SelectorOperation
+  }
+
+  class TestableSelector(config: KafkaConfig, channelBuilder: ChannelBuilder, time: Time, metrics: Metrics)
+        extends Selector(config.socketRequestMaxBytes, config.connectionsMaxIdleMs,
+            metrics, time, "socket-server", new HashMap, false, true, channelBuilder, MemoryPool.NONE) {
+
+    val failures = mutable.Map[SelectorOperation, Exception]()
+    val operationCounts = mutable.Map[SelectorOperation, Int]().withDefaultValue(0)
+    val allChannels = mutable.Set[String]()
+    val allLocallyClosedChannels = mutable.Set[String]()
+    val allDisconnectedChannels = mutable.Set[String]()
+    val allFailedChannels = mutable.Set[String]()
+
+    // Enable data from `Selector.poll()` to be deferred to a subsequent poll() until
+    // the number of elements of that type reaches `minPerPoll`. This enables tests to verify
+    // that failed processing doesn't impact subsequent processing within the same iteration.
+    class PollData[T] {
+      var minPerPoll = 1
+      val deferredValues = mutable.Buffer[T]()
+      val currentPollValues = mutable.Buffer[T]()
+      def update(newValues: mutable.Buffer[T]): Unit = {
+        if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) {
+          if (deferredValues.nonEmpty) {
+            currentPollValues ++= deferredValues
+            deferredValues.clear()
+          }
+          currentPollValues ++= newValues
+        } else
+          deferredValues ++= newValues
+      }
+      def reset(): Unit = {
+        currentPollValues.clear()
+      }
+    }
+    val cachedCompletedReceives = new PollData[NetworkReceive]()
+    val cachedCompletedSends = new PollData[Send]()
+    val cachedDisconnected = new PollData[(String, ChannelState)]()
+    val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected)
+    @volatile var minWakeupCount = 0
+    @volatile var pollTimeoutOverride: Option[Long] = None
+
+    def addFailure(operation: SelectorOperation, exception: Option[Exception] = None) {
+      failures += operation ->
+        exception.getOrElse(new IllegalStateException(s"Test exception during $operation"))
+    }
+
+    private def onOperation(operation: SelectorOperation,
+        connectionId: Option[String] = None, onFailure: => Unit = {}): Unit = {
+      operationCounts(operation) += 1
+      failures.remove(operation).foreach { e =>
+        connectionId.foreach(allFailedChannels.add)
+        onFailure
+        throw e
+      }
+    }
+
+    def waitForOperations(operation: SelectorOperation, minExpectedTotal: Int): Unit = {
+      TestUtils.waitUntilTrue(() =>
+        operationCounts.getOrElse(operation, 0) >= minExpectedTotal, "Operations not performed within timeout")
+    }
+
+    def runOp[T](operation: SelectorOperation, connectionId: Option[String],
+        onFailure: => Unit = {})(code: => T): T = {
+      // If a failure is set on `operation`, throw that exception even if `code` fails
+      try code
+      finally onOperation(operation, connectionId, onFailure)
+    }
+
+    override def register(id: String, socketChannel: SocketChannel): Unit = {
+      runOp(SelectorOperation.Register, Some(id), onFailure = close(id)) {
+        super.register(id, socketChannel)
+      }
+    }
+
+    override def send(s: Send): Unit = {
+      runOp(SelectorOperation.Send, Some(s.destination)) {
+        super.send(s)
+      }
+    }
+
+    override def poll(timeout: Long): Unit = {
+      try {
+        allCachedPollData.foreach(_.reset)
+        runOp(SelectorOperation.Poll, None) {
+          super.poll(pollTimeoutOverride.getOrElse(timeout))
+        }
+      } finally {
+        super.channels.asScala.foreach(allChannels += _.id)
+        allDisconnectedChannels ++= super.disconnected.asScala.keys
+        cachedCompletedReceives.update(super.completedReceives.asScala)
+        cachedCompletedSends.update(super.completedSends.asScala)
+        cachedDisconnected.update(super.disconnected.asScala.toBuffer)
+      }
+    }
+
+    override def mute(id: String): Unit = {
+      runOp(SelectorOperation.Mute, Some(id)) {
+        super.mute(id)
+      }
+    }
+
+    override def unmute(id: String): Unit = {
+      runOp(SelectorOperation.Unmute, Some(id)) {
+        super.unmute(id)
+      }
+    }
+
+    override def wakeup(): Unit = {
+      runOp(SelectorOperation.Wakeup, None) {
+        if (minWakeupCount > 0)
+          minWakeupCount -= 1
+        if (minWakeupCount <= 0)
+          super.wakeup()
+      }
+    }
+
+    override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava
+
+    override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava
+
+    override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava
+
+    override def close(id: String): Unit = {
+      runOp(SelectorOperation.Close, Some(id)) {
+        super.close(id)
+        allLocallyClosedChannels += id
+      }
+    }
+
+    override def close(): Unit = {
+      runOp(SelectorOperation.CloseSelector, None) {
+        super.close()
+      }
+    }
+
+    def updateMinWakeup(count: Int): Unit = {
+      minWakeupCount = count
+      // For tests that ignore wakeup to process responses together, increase poll timeout
+      // to ensure that poll doesn't complete before the responses are ready
+      pollTimeoutOverride = Some(1000L)
+      // Wakeup current poll to force new poll timeout to take effect
+      super.wakeup()
+    }
+
+    def reset(): Unit = {
+      failures.clear()
+      allCachedPollData.foreach(_.minPerPoll = 1)
+    }
+
+    def notFailed(sockets: Seq[Socket]): Seq[Socket] = {
+      // Each test generates failure for exactly one failed channel
+      assertEquals(1, allFailedChannels.size)
+      val failedConnectionId = allFailedChannels.head
+      sockets.filterNot(socket => isSocketConnectionId(failedConnectionId, socket))
+    }
+  }
 }