Posted to commits@kafka.apache.org by ju...@apache.org on 2016/04/06 00:16:52 UTC

kafka git commit: KAFKA-3489; Update request metrics if a client closes a connection while the broker response is in flight

Repository: kafka
Updated Branches:
  refs/heads/trunk aee8ebb46 -> 893e79af8


KAFKA-3489; Update request metrics if a client closes a connection while the broker response is in flight

I also fixed a few issues in `SocketServerTest` and included some clean-ups.

Author: Ismael Juma <is...@juma.me.uk>

Reviewers: Jun Rao <ju...@gmail.com>

Closes #1172 from ijuma/kafka-3489-update-request-metrics-if-client-closes
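
In outline, the patch records request metrics on two paths that previously dropped them. A condensed sketch of the `SocketServer.scala` changes below (paraphrased from the diff, not the exact patch):

    // If the selector already closed the connection (e.g. it idled out or the
    // client went away), `selector.channel` returns null, so the metrics are
    // recorded immediately instead of registering a send that can never finish.
    protected[network] def sendResponse(response: RequestChannel.Response) {
      val channel = selector.channel(response.responseSend.destination)
      if (channel == null)
        response.request.updateRequestMetrics()
      else {
        selector.send(response.responseSend)
        inflightResponses += (response.request.connectionId -> response)
      }
    }

    // If the connection drops while the response is in flight, the entry is
    // removed from `inflightResponses` and its metrics are recorded there.
    private def processDisconnected() {
      selector.disconnected.asScala.foreach { connectionId =>
        inflightResponses.remove(connectionId).foreach(_.request.updateRequestMetrics())
      }
    }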


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/893e79af
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/893e79af
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/893e79af

Branch: refs/heads/trunk
Commit: 893e79af88016f1d8659e72d48a25d4a825b8867
Parents: aee8ebb
Author: Ismael Juma <is...@juma.me.uk>
Authored: Tue Apr 5 18:16:48 2016 -0400
Committer: Jun Rao <ju...@gmail.com>
Committed: Tue Apr 5 18:16:48 2016 -0400

----------------------------------------------------------------------
 .../apache/kafka/common/network/Selector.java   |   6 +-
 .../scala/kafka/network/RequestChannel.scala    |  53 +++---
 .../main/scala/kafka/network/SocketServer.scala | 185 +++++++++++--------
 .../unit/kafka/network/SocketServerTest.scala   | 141 +++++++++++---
 4 files changed, 257 insertions(+), 128 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/893e79af/clients/src/main/java/org/apache/kafka/common/network/Selector.java
----------------------------------------------------------------------
diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
index 698b99c..c333741 100644
--- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java
+++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java
@@ -491,7 +491,7 @@ public class Selector implements Selectable {
     private KafkaChannel channelOrFail(String id) {
         KafkaChannel channel = this.channels.get(id);
         if (channel == null)
-            throw new IllegalStateException("Attempt to retrieve channel for which there is no open connection. Connection id " + id + " existing connections " + channels.keySet().toString());
+            throw new IllegalStateException("Attempt to retrieve channel for which there is no open connection. Connection id " + id + " existing connections " + channels.keySet());
         return channel;
     }
 
@@ -551,7 +551,7 @@ public class Selector implements Selectable {
      * checks if there are any staged receives and adds to completedReceives
      */
     private void addToCompletedReceives() {
-        if (this.stagedReceives.size() > 0) {
+        if (!this.stagedReceives.isEmpty()) {
             Iterator<Map.Entry<KafkaChannel, Deque<NetworkReceive>>> iter = this.stagedReceives.entrySet().iterator();
             while (iter.hasNext()) {
                 Map.Entry<KafkaChannel, Deque<NetworkReceive>> entry = iter.next();
@@ -561,7 +561,7 @@ public class Selector implements Selectable {
                     NetworkReceive networkReceive = deque.poll();
                     this.completedReceives.add(networkReceive);
                     this.sensors.recordBytesReceived(channel.id(), networkReceive.payload().limit());
-                    if (deque.size() == 0)
+                    if (deque.isEmpty())
                         iter.remove();
                 }
             }

http://git-wip-us.apache.org/repos/asf/kafka/blob/893e79af/core/src/main/scala/kafka/network/RequestChannel.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala
index 1105802..17c5b9b 100644
--- a/core/src/main/scala/kafka/network/RequestChannel.scala
+++ b/core/src/main/scala/kafka/network/RequestChannel.scala
@@ -117,36 +117,39 @@ object RequestChannel extends Logging {
       if (apiRemoteCompleteTimeMs < 0)
         apiRemoteCompleteTimeMs = responseCompleteTimeMs
 
-      val requestQueueTime = (requestDequeueTimeMs - startTimeMs).max(0L)
-      val apiLocalTime = (apiLocalCompleteTimeMs - requestDequeueTimeMs).max(0L)
-      val apiRemoteTime = (apiRemoteCompleteTimeMs - apiLocalCompleteTimeMs).max(0L)
-      val apiThrottleTime = (responseCompleteTimeMs - apiRemoteCompleteTimeMs).max(0L)
-      val responseQueueTime = (responseDequeueTimeMs - responseCompleteTimeMs).max(0L)
-      val responseSendTime = (endTimeMs - responseDequeueTimeMs).max(0L)
+      val requestQueueTime = math.max(requestDequeueTimeMs - startTimeMs, 0)
+      val apiLocalTime = math.max(apiLocalCompleteTimeMs - requestDequeueTimeMs, 0)
+      val apiRemoteTime = math.max(apiRemoteCompleteTimeMs - apiLocalCompleteTimeMs, 0)
+      val apiThrottleTime = math.max(responseCompleteTimeMs - apiRemoteCompleteTimeMs, 0)
+      val responseQueueTime = math.max(responseDequeueTimeMs - responseCompleteTimeMs, 0)
+      val responseSendTime = math.max(endTimeMs - responseDequeueTimeMs, 0)
       val totalTime = endTimeMs - startTimeMs
-      var metricsList = List(RequestMetrics.metricsMap(ApiKeys.forId(requestId).name))
-      if (requestId == ApiKeys.FETCH.id) {
-        val isFromFollower = requestObj.asInstanceOf[FetchRequest].isFromFollower
-        metricsList ::= ( if (isFromFollower)
-                            RequestMetrics.metricsMap(RequestMetrics.followFetchMetricName)
-                          else
-                            RequestMetrics.metricsMap(RequestMetrics.consumerFetchMetricName) )
-      }
-      metricsList.foreach{
-        m => m.requestRate.mark()
-             m.requestQueueTimeHist.update(requestQueueTime)
-             m.localTimeHist.update(apiLocalTime)
-             m.remoteTimeHist.update(apiRemoteTime)
-             m.throttleTimeHist.update(apiThrottleTime)
-             m.responseQueueTimeHist.update(responseQueueTime)
-             m.responseSendTimeHist.update(responseSendTime)
-             m.totalTimeHist.update(totalTime)
+      val fetchMetricNames =
+        if (requestId == ApiKeys.FETCH.id) {
+          val isFromFollower = requestObj.asInstanceOf[FetchRequest].isFromFollower
+          Seq(
+            if (isFromFollower) RequestMetrics.followFetchMetricName
+            else RequestMetrics.consumerFetchMetricName
+          )
+        }
+        else Seq.empty
+      val metricNames = fetchMetricNames :+ ApiKeys.forId(requestId).name
+      metricNames.foreach { metricName =>
+        val m = RequestMetrics.metricsMap(metricName)
+        m.requestRate.mark()
+        m.requestQueueTimeHist.update(requestQueueTime)
+        m.localTimeHist.update(apiLocalTime)
+        m.remoteTimeHist.update(apiRemoteTime)
+        m.throttleTimeHist.update(apiThrottleTime)
+        m.responseQueueTimeHist.update(responseQueueTime)
+        m.responseSendTimeHist.update(responseSendTime)
+        m.totalTimeHist.update(totalTime)
       }
 
-      if(requestLogger.isTraceEnabled)
+      if (requestLogger.isTraceEnabled)
         requestLogger.trace("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d,securityProtocol:%s,principal:%s"
           .format(requestDesc(true), connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime, securityProtocol, session.principal))
-      else if(requestLogger.isDebugEnabled)
+      else if (requestLogger.isDebugEnabled)
         requestLogger.debug("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d,securityProtocol:%s,principal:%s"
           .format(requestDesc(false), connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime, securityProtocol, session.principal))
     }
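
For reference, the refactored block above turns the request lifecycle timestamps into one non-negative duration per phase; each value is clamped at zero with `math.max`, presumably because some timestamps can be unset or recorded out of order. A minimal standalone sketch of the same arithmetic, using illustrative timestamps rather than values from the patch:

    val startTimeMs             = 1000L  // request received
    val requestDequeueTimeMs    = 1003L
    val apiLocalCompleteTimeMs  = 1010L
    val apiRemoteCompleteTimeMs = 1010L  // no remote work in this example
    val responseCompleteTimeMs  = 1012L
    val responseDequeueTimeMs   = 1013L
    val endTimeMs               = 1015L

    val requestQueueTime  = math.max(requestDequeueTimeMs - startTimeMs, 0)               // 3 ms
    val apiLocalTime      = math.max(apiLocalCompleteTimeMs - requestDequeueTimeMs, 0)    // 7 ms
    val apiRemoteTime     = math.max(apiRemoteCompleteTimeMs - apiLocalCompleteTimeMs, 0) // 0 ms
    val apiThrottleTime   = math.max(responseCompleteTimeMs - apiRemoteCompleteTimeMs, 0) // 2 ms
    val responseQueueTime = math.max(responseDequeueTimeMs - responseCompleteTimeMs, 0)   // 1 ms
    val responseSendTime  = math.max(endTimeMs - responseDequeueTimeMs, 0)                // 2 ms
    val totalTime         = endTimeMs - startTimeMs                                       // 15 ms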

http://git-wip-us.apache.org/repos/asf/kafka/blob/893e79af/core/src/main/scala/kafka/network/SocketServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala
index 5c31ac6..f1ec2ef 100644
--- a/core/src/main/scala/kafka/network/SocketServer.scala
+++ b/core/src/main/scala/kafka/network/SocketServer.scala
@@ -31,9 +31,8 @@ import kafka.common.KafkaException
 import kafka.metrics.KafkaMetricsGroup
 import kafka.server.KafkaConfig
 import kafka.utils._
-import org.apache.kafka.common.MetricName
 import org.apache.kafka.common.metrics._
-import org.apache.kafka.common.network.{Selector => KSelector, LoginType, Mode, ChannelBuilders}
+import org.apache.kafka.common.network.{ChannelBuilders, KafkaChannel, LoginType, Mode, Selector => KSelector}
 import org.apache.kafka.common.security.auth.KafkaPrincipal
 import org.apache.kafka.common.protocol.SecurityProtocol
 import org.apache.kafka.common.protocol.types.SchemaException
@@ -41,7 +40,7 @@ import org.apache.kafka.common.utils.{Time, Utils}
 
 import scala.collection._
 import JavaConverters._
-import scala.util.control.{NonFatal, ControlThrowable}
+import scala.util.control.{ControlThrowable, NonFatal}
 
 /**
  * An NIO socket server. The threading model is
@@ -83,8 +82,6 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
 
       val sendBufferSize = config.socketSendBufferBytes
       val recvBufferSize = config.socketReceiveBufferBytes
-      val maxRequestSize = config.socketRequestMaxBytes
-      val connectionsMaxIdleMs = config.connectionsMaxIdleMs
       val brokerId = config.brokerId
 
       var processorBeginIndex = 0
@@ -92,18 +89,8 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
         val protocol = endpoint.protocolType
         val processorEndIndex = processorBeginIndex + numProcessorThreads
 
-        for (i <- processorBeginIndex until processorEndIndex) {
-          processors(i) = new Processor(i,
-            time,
-            maxRequestSize,
-            requestChannel,
-            connectionQuotas,
-            connectionsMaxIdleMs,
-            protocol,
-            config.values,
-            metrics
-          )
-        }
+        for (i <- processorBeginIndex until processorEndIndex)
+          processors(i) = newProcessor(i, connectionQuotas, protocol)
 
         val acceptor = new Acceptor(endpoint, sendBufferSize, recvBufferSize, brokerId,
           processors.slice(processorBeginIndex, processorEndIndex), connectionQuotas)
@@ -148,10 +135,27 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
     }
   }
 
+  /* `protected` for test usage */
+  protected[network] def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, protocol: SecurityProtocol): Processor = {
+    new Processor(id,
+      time,
+      config.socketRequestMaxBytes,
+      requestChannel,
+      connectionQuotas,
+      config.connectionsMaxIdleMs,
+      protocol,
+      config.values,
+      metrics
+    )
+  }
+
   /* For test usage */
   private[network] def connectionCount(address: InetAddress): Int =
     Option(connectionQuotas).fold(0)(_.get(address))
 
+  /* For test usage */
+  private[network] def processor(index: Int): Processor = processors(index)
+
 }
 
 /**
@@ -376,10 +380,7 @@ private[kafka] class Processor(val id: Int,
 
   private val newConnections = new ConcurrentLinkedQueue[SocketChannel]()
   private val inflightResponses = mutable.Map[String, RequestChannel.Response]()
-  private val channelBuilder = ChannelBuilders.create(protocol, Mode.SERVER, LoginType.SERVER, channelConfigs)
-  private val metricTags = new util.HashMap[String, String]()
-  metricTags.put("networkProcessor", id.toString)
-
+  private val metricTags = Map("networkProcessor" -> id.toString).asJava
 
   newGauge("IdlePercent",
     new Gauge[Double] {
@@ -398,65 +399,27 @@ private[kafka] class Processor(val id: Int,
     "socket-server",
     metricTags,
     false,
-    channelBuilder)
+    ChannelBuilders.create(protocol, Mode.SERVER, LoginType.SERVER, channelConfigs))
 
   override def run() {
     startupComplete()
-    while(isRunning) {
+    while (isRunning) {
       try {
         // setup any new connections that have been queued up
         configureNewConnections()
         // register any new responses for writing
         processNewResponses()
-
-        try {
-          selector.poll(300)
-        } catch {
-          case e @ (_: IllegalStateException | _: IOException) =>
-            error("Closing processor %s due to illegal state or IO exception".format(id))
-            swallow(closeAll())
-            shutdownComplete()
-            throw e
-        }
-        selector.completedReceives.asScala.foreach { receive =>
-          try {
-            val channel = selector.channel(receive.source)
-            val session = RequestChannel.Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, channel.principal.getName),
-              channel.socketAddress)
-            val req = RequestChannel.Request(processor = id, connectionId = receive.source, session = session, buffer = receive.payload, startTimeMs = time.milliseconds, securityProtocol = protocol)
-            requestChannel.sendRequest(req)
-            selector.mute(receive.source)
-          } catch {
-            case e @ (_: InvalidRequestException | _: SchemaException) =>
-              // note that even though we got an exception, we can assume that receive.source is valid. Issues with constructing a valid receive object were handled earlier
-              error("Closing socket for " + receive.source + " because of error", e)
-              close(selector, receive.source)
-          }
-        }
-
-        selector.completedSends.asScala.foreach { send =>
-          val resp = inflightResponses.remove(send.destination).getOrElse {
-            throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`")
-          }
-          resp.request.updateRequestMetrics()
-          selector.unmute(send.destination)
-        }
-
-        selector.disconnected.asScala.foreach { connectionId =>
-          val remoteHost = ConnectionId.fromString(connectionId).getOrElse {
-            throw new IllegalStateException(s"connectionId has unexpected format: $connectionId")
-          }.remoteHost
-          // the channel has been closed by the selector but the quotas still need to be updated
-          connectionQuotas.dec(InetAddress.getByName(remoteHost))
-        }
-
+        poll()
+        processCompletedReceives()
+        processCompletedSends()
+        processDisconnected()
       } catch {
         // We catch all the throwables here to prevent the processor thread from exiting. We do this because
-        // letting a processor exit might cause bigger impact on the broker. Usually the exceptions thrown would
+        // letting a processor exit might cause a bigger impact on the broker. Usually the exceptions thrown would
         // be either associated with a specific socket channel or a bad request. We just ignore the bad socket channel
         // or request. This behavior might need to be reviewed if we see an exception that need the entire broker to stop.
-        case e : ControlThrowable => throw e
-        case e : Throwable =>
+        case e: ControlThrowable => throw e
+        case e: Throwable =>
           error("Processor got uncaught exception.", e)
       }
     }
@@ -468,7 +431,7 @@ private[kafka] class Processor(val id: Int,
 
   private def processNewResponses() {
     var curr = requestChannel.receiveResponse(id)
-    while(curr != null) {
+    while (curr != null) {
       try {
         curr.responseAction match {
           case RequestChannel.NoOpAction =>
@@ -478,9 +441,7 @@ private[kafka] class Processor(val id: Int,
             trace("Socket server received empty response to send, registering for read: " + curr)
             selector.unmute(curr.request.connectionId)
           case RequestChannel.SendAction =>
-            trace("Socket server received response to send, registering for write and sending data: " + curr)
-            selector.send(curr.responseSend)
-            inflightResponses += (curr.request.connectionId -> curr)
+            sendResponse(curr)
           case RequestChannel.CloseConnectionAction =>
             curr.request.updateRequestMetrics
             trace("Closing socket connection actively according to the response code.")
@@ -492,6 +453,71 @@ private[kafka] class Processor(val id: Int,
     }
   }
 
+  /* `protected` for test usage */
+  protected[network] def sendResponse(response: RequestChannel.Response) {
+    trace(s"Socket server received response to send, registering for write and sending data: $response")
+    val channel = selector.channel(response.responseSend.destination)
+    // `channel` can be null if the selector closed the connection because it was idle for too long
+    if (channel == null) {
+      warn(s"Attempting to send response via channel for which there is no open connection, connection id $id")
+      response.request.updateRequestMetrics()
+    }
+    else {
+      selector.send(response.responseSend)
+      inflightResponses += (response.request.connectionId -> response)
+    }
+  }
+
+  private def poll() {
+    try selector.poll(300)
+    catch {
+      case e @ (_: IllegalStateException | _: IOException) =>
+        error(s"Closing processor $id due to illegal state or IO exception")
+        swallow(closeAll())
+        shutdownComplete()
+        throw e
+    }
+  }
+
+  private def processCompletedReceives() {
+    selector.completedReceives.asScala.foreach { receive =>
+      try {
+        val channel = selector.channel(receive.source)
+        val session = RequestChannel.Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, channel.principal.getName),
+          channel.socketAddress)
+        val req = RequestChannel.Request(processor = id, connectionId = receive.source, session = session, buffer = receive.payload, startTimeMs = time.milliseconds, securityProtocol = protocol)
+        requestChannel.sendRequest(req)
+        selector.mute(receive.source)
+      } catch {
+        case e @ (_: InvalidRequestException | _: SchemaException) =>
+          // note that even though we got an exception, we can assume that receive.source is valid. Issues with constructing a valid receive object were handled earlier
+          error(s"Closing socket for ${receive.source} because of error", e)
+          close(selector, receive.source)
+      }
+    }
+  }
+
+  private def processCompletedSends() {
+    selector.completedSends.asScala.foreach { send =>
+      val resp = inflightResponses.remove(send.destination).getOrElse {
+        throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`")
+      }
+      resp.request.updateRequestMetrics()
+      selector.unmute(send.destination)
+    }
+  }
+
+  private def processDisconnected() {
+    selector.disconnected.asScala.foreach { connectionId =>
+      val remoteHost = ConnectionId.fromString(connectionId).getOrElse {
+        throw new IllegalStateException(s"connectionId has unexpected format: $connectionId")
+      }.remoteHost
+      inflightResponses.remove(connectionId).foreach(_.request.updateRequestMetrics())
+      // the channel has been closed by the selector but the quotas still need to be updated
+      connectionQuotas.dec(InetAddress.getByName(remoteHost))
+    }
+  }
+
   /**
    * Queue up a new connection for reading
    */
@@ -504,10 +530,10 @@ private[kafka] class Processor(val id: Int,
    * Register any new connections that have been queued up
    */
   private def configureNewConnections() {
-    while(!newConnections.isEmpty) {
+    while (!newConnections.isEmpty) {
       val channel = newConnections.poll()
       try {
-        debug("Processor " + id + " listening to new connection from " + channel.socket.getRemoteSocketAddress)
+        debug(s"Processor $id listening to new connection from ${channel.socket.getRemoteSocketAddress}")
         val localHost = channel.socket().getLocalAddress.getHostAddress
         val localPort = channel.socket().getLocalPort
         val remoteHost = channel.socket().getInetAddress.getHostAddress
@@ -515,12 +541,12 @@ private[kafka] class Processor(val id: Int,
         val connectionId = ConnectionId(localHost, localPort, remoteHost, remotePort).toString
         selector.register(connectionId, channel)
       } catch {
-        // We explicitly catch all non fatal exceptions and close the socket to avoid socket leak. The other
-        // throwables will be caught in processor and logged as uncaught exception.
+        // We explicitly catch all non fatal exceptions and close the socket to avoid a socket leak. The other
+        // throwables will be caught in processor and logged as uncaught exceptions.
         case NonFatal(e) =>
-          // need to close the channel here to avoid socket leak.
+          // need to close the channel here to avoid a socket leak.
           close(channel)
-          error("Processor " + id + " closed connection from " + channel.getRemoteAddress, e)
+          error(s"Processor $id closed connection from ${channel.getRemoteAddress}", e)
       }
     }
   }
@@ -535,6 +561,9 @@ private[kafka] class Processor(val id: Int,
     selector.close()
   }
 
+  /* For test usage */
+  private[network] def channel(connectionId: String): Option[KafkaChannel] =
+    Option(selector.channel(connectionId))
 
   /**
    * Wakeup the thread for selection.
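
Taken together, the two new code paths plus the existing completed-send path mean a request's metrics are recorded on exactly one of three terminal outcomes. A hypothetical, self-contained sketch of that invariant (all names illustrative, not Kafka's):

    import scala.collection.mutable

    object MetricsInvariantSketch {
      var recorded = 0
      def updateRequestMetrics(): Unit = recorded += 1

      // connectionId -> pending response, mirroring `inflightResponses`
      val inflight = mutable.Map[String, String]()

      def onSendResponse(connectionId: String, channelOpen: Boolean): Unit =
        if (!channelOpen) updateRequestMetrics()        // channel already closed at send time
        else inflight += (connectionId -> "response")

      def onCompletedSend(connectionId: String): Unit = {
        inflight.remove(connectionId)
        updateRequestMetrics()                          // normal completion
      }

      def onDisconnected(connectionId: String): Unit =
        inflight.remove(connectionId).foreach(_ => updateRequestMetrics()) // dropped mid-send

      def main(args: Array[String]): Unit = {
        onSendResponse("a", channelOpen = true);  onCompletedSend("a")
        onSendResponse("b", channelOpen = false)
        onSendResponse("c", channelOpen = true);  onDisconnected("c")
        assert(recorded == 3)                           // each request counted exactly once
      }
    }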

http://git-wip-us.apache.org/repos/asf/kafka/blob/893e79af/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 5d28894..81e5232 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -39,7 +39,7 @@ import org.junit.Assert._
 import org.junit._
 import org.scalatest.junit.JUnitSuite
 
-import scala.collection.Map
+import scala.collection.mutable.ArrayBuffer
 
 class SocketServerTest extends JUnitSuite {
   val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
@@ -55,6 +55,7 @@ class SocketServerTest extends JUnitSuite {
   val metrics = new Metrics
   val server = new SocketServer(config, metrics, new SystemTime)
   server.startup()
+  val sockets = new ArrayBuffer[Socket]
 
   def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None) {
     val outgoing = new DataOutputStream(socket.getOutputStream)
@@ -79,7 +80,12 @@ class SocketServerTest extends JUnitSuite {
 
   /* A simple request handler that just echos back the response */
   def processRequest(channel: RequestChannel) {
-    val request = channel.receiveRequest
+    val request = channel.receiveRequest(2000)
+    assertNotNull("receiveRequest timed out", request)
+    processRequest(channel, request)
+  }
+
+  def processRequest(channel: RequestChannel, request: RequestChannel.Request) {
     val byteBuffer = ByteBuffer.allocate(request.header.sizeOf + request.body.sizeOf)
     request.header.writeTo(byteBuffer)
     request.body.writeTo(byteBuffer)
@@ -89,13 +95,18 @@ class SocketServerTest extends JUnitSuite {
     channel.sendResponse(new RequestChannel.Response(request.processor, request, send))
   }
 
-  def connect(s: SocketServer = server, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT) =
-    new Socket("localhost", server.boundPort(protocol))
+  def connect(s: SocketServer = server, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT) = {
+    val socket = new Socket("localhost", s.boundPort(protocol))
+    sockets += socket
+    socket
+  }
 
   @After
-  def cleanup() {
+  def tearDown() {
     metrics.close()
     server.shutdown()
+    sockets.foreach(_.close())
+    sockets.clear()
   }
 
   private def producerRequestBytes: Array[Byte] = {
@@ -183,7 +194,7 @@ class SocketServerTest extends JUnitSuite {
 
   @Test
   def testMaxConnectionsPerIp() {
-    // make the maximum allowable number of connections and then leak them
+    // make the maximum allowable number of connections
     val conns = (0 until server.config.maxConnectionsPerIp).map(_ => connect())
     // now try one more (should fail)
     val conn = connect()
@@ -201,27 +212,30 @@ class SocketServerTest extends JUnitSuite {
     sendRequest(conn2, serializedBytes)
     val request = server.requestChannel.receiveRequest(2000)
     assertNotNull(request)
-    conn2.close()
-    conns.tail.foreach(_.close())
   }
 
   @Test
-  def testMaxConnectionsPerIPOverrides() {
-    val overrideNum = 6
-    val overrides = Map("localhost" -> overrideNum)
+  def testMaxConnectionsPerIpOverrides() {
+    val overrideNum = server.config.maxConnectionsPerIp + 1
     val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
+    overrideProps.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, s"localhost:$overrideNum")
     val serverMetrics = new Metrics()
-    val overrideServer: SocketServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime())
+    val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime())
     try {
       overrideServer.startup()
-      // make the maximum allowable number of connections and then leak them
-      val conns = ((0 until overrideNum).map(i => connect(overrideServer)))
+      // make the maximum allowable number of connections
+      val conns = (0 until overrideNum).map(_ => connect(overrideServer))
+
+      // it should succeed
+      val serializedBytes = producerRequestBytes
+      sendRequest(conns.last, serializedBytes)
+      val request = overrideServer.requestChannel.receiveRequest(2000)
+      assertNotNull(request)
+
       // now try one more (should fail)
       val conn = connect(overrideServer)
       conn.setSoTimeout(3000)
       assertEquals(-1, conn.getInputStream.read())
-      conn.close()
-      conns.foreach(_.close())
     } finally {
       overrideServer.shutdown()
       serverMetrics.close()
@@ -229,16 +243,16 @@ class SocketServerTest extends JUnitSuite {
   }
 
   @Test
-  def testSslSocketServer(): Unit = {
+  def testSslSocketServer() {
     val trustStoreFile = File.createTempFile("truststore", ".jks")
     val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, interBrokerSecurityProtocol = Some(SecurityProtocol.SSL),
       trustStoreFile = Some(trustStoreFile))
     overrideProps.put(KafkaConfig.ListenersProp, "SSL://localhost:0")
 
     val serverMetrics = new Metrics
-    val overrideServer: SocketServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime)
-    overrideServer.startup()
+    val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime)
     try {
+      overrideServer.startup()
       val sslContext = SSLContext.getInstance("TLSv1.2")
       sslContext.init(null, Array(TestUtils.trustAllCerts), new java.security.SecureRandom())
       val socketFactory = sslContext.getSocketFactory
@@ -271,12 +285,95 @@ class SocketServerTest extends JUnitSuite {
   }
 
   @Test
-  def testSessionPrincipal(): Unit = {
+  def testSessionPrincipal() {
     val socket = connect()
     val bytes = new Array[Byte](40)
     sendRequest(socket, bytes, Some(0))
-    assertEquals(KafkaPrincipal.ANONYMOUS, server.requestChannel.receiveRequest().session.principal)
-    socket.close()
+    assertEquals(KafkaPrincipal.ANONYMOUS, server.requestChannel.receiveRequest(2000).session.principal)
+  }
+
+  /* Test that we update request metrics if the client closes the connection while the broker response is in flight. */
+  @Test
+  def testClientDisconnectionUpdatesRequestMetrics() {
+    val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
+    val serverMetrics = new Metrics
+    var conn: Socket = null
+    val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, new SystemTime) {
+      override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, protocol: SecurityProtocol): Processor = {
+        new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas,
+          config.connectionsMaxIdleMs, protocol, config.values, metrics) {
+          override protected[network] def sendResponse(response: RequestChannel.Response) {
+            conn.close()
+            super.sendResponse(response)
+          }
+        }
+      }
+    }
+    try {
+      overrideServer.startup()
+      conn = connect(overrideServer)
+      val serializedBytes = producerRequestBytes
+      sendRequest(conn, serializedBytes)
+
+      val channel = overrideServer.requestChannel
+      val request = channel.receiveRequest(2000)
+
+      val requestMetrics = RequestMetrics.metricsMap(ApiKeys.forId(request.requestId).name)
+      def totalTimeHistCount(): Long = requestMetrics.totalTimeHist.count
+      val expectedTotalTimeCount = totalTimeHistCount() + 1
+
+      // send a large buffer to ensure that the broker detects the client disconnection while writing to the socket channel.
+      // On Mac OS X, the initial write seems to always succeed and it is able to write up to 102400 bytes on the initial
+      // write. If the buffer is smaller than this, the write is considered complete and the disconnection is not
+      // detected. If the buffer is larger than 102400 bytes, a second write is attempted and it fails with an
+      // IOException.
+      val send = new NetworkSend(request.connectionId, ByteBuffer.allocate(550000))
+      channel.sendResponse(new RequestChannel.Response(request.processor, request, send))
+      TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount,
+        s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}")
+
+    } finally {
+      overrideServer.shutdown()
+      serverMetrics.close()
+    }
+  }
+
+  /*
+   * Test that we update request metrics if the channel has been removed from the selector when the broker calls
+   * `selector.send` (selector closes old connections, for example).
+   */
+  @Test
+  def testBrokerSendAfterChannelClosedUpdatesRequestMetrics() {
+    val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
+    props.setProperty(KafkaConfig.ConnectionsMaxIdleMsProp, "100")
+    val serverMetrics = new Metrics
+    var conn: Socket = null
+    val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, new SystemTime)
+    try {
+      overrideServer.startup()
+      conn = connect(overrideServer)
+      val serializedBytes = producerRequestBytes
+      sendRequest(conn, serializedBytes)
+      val channel = overrideServer.requestChannel
+      val request = channel.receiveRequest(2000)
+
+      TestUtils.waitUntilTrue(() => overrideServer.processor(request.processor).channel(request.connectionId).isEmpty,
+        s"Idle connection `${request.connectionId}` was not closed by selector")
+
+      val requestMetrics = RequestMetrics.metricsMap(ApiKeys.forId(request.requestId).name)
+      def totalTimeHistCount(): Long = requestMetrics.totalTimeHist.count
+      val expectedTotalTimeCount = totalTimeHistCount() + 1
+
+      processRequest(channel, request)
+
+      TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount,
+        s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}")
+
+    } finally {
+      overrideServer.shutdown()
+      serverMetrics.close()
+    }
+
   }
 
 }