Posted to commits@kafka.apache.org by jg...@apache.org on 2021/02/23 04:35:32 UTC

[kafka] branch trunk updated: MINOR: Move `RequestChannel.Response` creation logic into `RequestChannel` (#9912)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 23b8541  MINOR: Move `RequestChannel.Response` creation logic into `RequestChannel` (#9912)
23b8541 is described below

commit 23b85417b3ed53bf092877baef42608f7b778843
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Mon Feb 22 20:34:25 2021 -0800

    MINOR: Move `RequestChannel.Response` creation logic into `RequestChannel` (#9912)
    
    This patch moves some common response creation logic from `RequestHandlerHelper` into `RequestChannel`. This refactor has the following benefits:
    
    - It eliminates logic that was previously duplicated in `RequestHandlerHelper` and `TestRaftRequestHandler`.
    - It removes the need to rely on the caller to invoke `updateErrorMetrics`, since this is now handled internally by `RequestChannel`.
    - It provides better encapsulation of the quota workflow, which relies on custom `Response` objects. Previously it was confusing for `KafkaApis` to handle this directly through the `sendResponse` API (see the sketch below).
    
    Reviewers: Ismael Juma <is...@juma.me.uk>
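
A minimal usage sketch (editorial note, not part of the commit message): after this change, request handlers complete requests through the new `RequestChannel` entry points instead of constructing `RequestChannel.Response` objects themselves. The helper method below is hypothetical; it only illustrates the three call sites added in the diff and assumes the channel, request, and response are already in scope.

    import kafka.network.RequestChannel
    import org.apache.kafka.common.requests.AbstractResponse

    // Hypothetical handler-side helper showing the new RequestChannel entry points.
    def complete(requestChannel: RequestChannel,
                 request: RequestChannel.Request,
                 response: AbstractResponse,
                 expectsResponse: Boolean,
                 hadError: Boolean): Unit = {
      if (expectsResponse)
        requestChannel.sendResponse(request, response, None)           // error metrics updated internally
      else if (hadError)
        requestChannel.closeConnection(request, response.errorCounts)  // e.g. produce with acks=0 that failed
      else
        requestChannel.sendNoOpResponse(request)                       // acks=0 and no error: nothing to send
    }
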
---
 .../kafka/common/requests/RequestContext.java      |  15 +
 .../main/scala/kafka/network/RequestChannel.scala  |  40 +-
 .../scala/kafka/server/ClientQuotaManager.scala    |   8 +-
 .../main/scala/kafka/server/ControllerApis.scala   |   2 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |  29 +-
 .../scala/kafka/server/RequestHandlerHelper.scala  | 106 ++---
 .../main/scala/kafka/server/ThrottledChannel.scala |  22 +-
 .../scala/kafka/tools/TestRaftRequestHandler.scala |  63 +--
 .../unit/kafka/network/SocketServerTest.scala      |  16 +-
 .../kafka/server/BaseClientQuotaManagerTest.scala  |  16 +-
 .../unit/kafka/server/ControllerApisTest.scala     |  81 ++--
 .../scala/unit/kafka/server/KafkaApisTest.scala    | 508 ++++++++++-----------
 .../server/ThrottledChannelExpirationTest.scala    |  55 +--
 13 files changed, 449 insertions(+), 512 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java
index 225db37..d7a6df1 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java
@@ -175,4 +175,19 @@ public class RequestContext implements AuthorizableRequestContext {
     public int correlationId() {
         return header.correlationId();
     }
+
+    @Override
+    public String toString() {
+        return "RequestContext(" +
+            "header=" + header +
+            ", connectionId='" + connectionId + '\'' +
+            ", clientAddress=" + clientAddress +
+            ", principal=" + principal +
+            ", listenerName=" + listenerName +
+            ", securityProtocol=" + securityProtocol +
+            ", clientInformation=" + clientInformation +
+            ", fromPrivilegedListener=" + fromPrivilegedListener +
+            ", principalSerde=" + principalSerde +
+            ')';
+    }
 }
diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala
index 48f723f..40d0bd6 100644
--- a/core/src/main/scala/kafka/network/RequestChannel.scala
+++ b/core/src/main/scala/kafka/network/RequestChannel.scala
@@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.JsonNode
 import com.typesafe.scalalogging.Logger
 import com.yammer.metrics.core.Meter
 import kafka.metrics.KafkaMetricsGroup
+import kafka.network
 import kafka.server.KafkaConfig
 import kafka.utils.{Logging, NotNothing, Pool}
 import kafka.utils.Implicits._
@@ -372,9 +373,44 @@ class RequestChannel(val queueSize: Int,
     requestQueue.put(request)
   }
 
-  /** Send a response back to the socket server to be sent over the network */
-  def sendResponse(response: RequestChannel.Response): Unit = {
+  def closeConnection(
+    request: RequestChannel.Request,
+    errorCounts: java.util.Map[Errors, Integer]
+  ): Unit = {
+    // This case is used when the request handler has encountered an error, but the client
+    // does not expect a response (e.g. when produce request has acks set to 0)
+    updateErrorMetrics(request.header.apiKey, errorCounts.asScala)
+    sendResponse(new RequestChannel.CloseConnectionResponse(request))
+  }
+
+  def sendResponse(
+    request: RequestChannel.Request,
+    response: AbstractResponse,
+    onComplete: Option[Send => Unit]
+  ): Unit = {
+    updateErrorMetrics(request.header.apiKey, response.errorCounts.asScala)
+    sendResponse(new RequestChannel.SendResponse(
+      request,
+      request.buildResponseSend(response),
+      request.responseNode(response),
+      onComplete
+    ))
+  }
+
+  def sendNoOpResponse(request: RequestChannel.Request): Unit = {
+    sendResponse(new network.RequestChannel.NoOpResponse(request))
+  }
 
+  def startThrottling(request: RequestChannel.Request): Unit = {
+    sendResponse(new RequestChannel.StartThrottlingResponse(request))
+  }
+
+  def endThrottling(request: RequestChannel.Request): Unit = {
+    sendResponse(new EndThrottlingResponse(request))
+  }
+
+  /** Send a response back to the socket server to be sent over the network */
+  private[network] def sendResponse(response: RequestChannel.Response): Unit = {
     if (isTraceEnabled) {
       val requestHeader = response.request.headerForLoggingOrThrottling()
       val message = response match {
diff --git a/core/src/main/scala/kafka/server/ClientQuotaManager.scala b/core/src/main/scala/kafka/server/ClientQuotaManager.scala
index e32978c..1f5b752 100644
--- a/core/src/main/scala/kafka/server/ClientQuotaManager.scala
+++ b/core/src/main/scala/kafka/server/ClientQuotaManager.scala
@@ -335,11 +335,15 @@ class ClientQuotaManager(private val config: ClientQuotaManagerConfig,
    * @param throttleTimeMs Duration in milliseconds for which the channel is to be muted.
    * @param channelThrottlingCallback Callback for channel throttling
    */
-  def throttle(request: RequestChannel.Request, throttleTimeMs: Int, channelThrottlingCallback: Response => Unit): Unit = {
+  def throttle(
+    request: RequestChannel.Request,
+    throttleCallback: ThrottleCallback,
+    throttleTimeMs: Int
+  ): Unit = {
     if (throttleTimeMs > 0) {
       val clientSensors = getOrCreateQuotaSensors(request.session, request.headerForLoggingOrThrottling().clientId)
       clientSensors.throttleTimeSensor.record(throttleTimeMs)
-      val throttledChannel = new ThrottledChannel(request, time, throttleTimeMs, channelThrottlingCallback)
+      val throttledChannel = new ThrottledChannel(time, throttleTimeMs, throttleCallback)
       delayQueue.add(throttledChannel)
       delayQueueSensor.record()
       debug("Channel throttled for sensor (%s). Delay time: (%d)".format(clientSensors.quotaSensor.name(), throttleTimeMs))
diff --git a/core/src/main/scala/kafka/server/ControllerApis.scala b/core/src/main/scala/kafka/server/ControllerApis.scala
index abd8506..336775c 100644
--- a/core/src/main/scala/kafka/server/ControllerApis.scala
+++ b/core/src/main/scala/kafka/server/ControllerApis.scala
@@ -63,7 +63,7 @@ class ControllerApis(val requestChannel: RequestChannel,
                      val controllerNodes: Seq[Node]) extends ApiRequestHandler with Logging {
 
   val authHelper = new AuthHelper(authorizer)
-  val requestHelper = new RequestHandlerHelper(requestChannel, quotas, time, s"[ControllerApis id=${config.nodeId}] ")
+  val requestHelper = new RequestHandlerHelper(requestChannel, quotas, time)
 
   var supportedApiKeys = Set(
     ApiKeys.FETCH,
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index 5a926d4..1245fe7 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -120,7 +120,7 @@ class KafkaApis(val requestChannel: RequestChannel,
   private val alterAclsPurgatory = new DelayedFuturePurgatory(purgatoryName = "AlterAcls", brokerId = config.brokerId)
 
   val authHelper = new AuthHelper(authorizer)
-  val requestHelper = new RequestHandlerHelper(requestChannel, quotas, time, logIdent)
+  val requestHelper = new RequestHandlerHelper(requestChannel, quotas, time)
 
   def close(): Unit = {
     alterAclsPurgatory.shutdown()
@@ -142,7 +142,7 @@ class KafkaApis(val requestChannel: RequestChannel,
           info(s"The client connection will be closed due to controller responded " +
             s"unsupported version exception during $request forwarding. " +
             s"This could happen when the controller changed after the connection was established.")
-          requestHelper.closeConnection(request, Collections.emptyMap())
+          requestChannel.closeConnection(request, Collections.emptyMap())
       }
     }
 
@@ -226,7 +226,10 @@ class KafkaApis(val requestChannel: RequestChannel,
       }
     } catch {
       case e: FatalExitError => throw e
-      case e: Throwable => requestHelper.handleError(request, e)
+      case e: Throwable =>
+        error(s"Unexpected error handling request ${request.requestDesc(true)} " +
+          s"with context ${request.context}", e)
+        requestHelper.handleError(request, e)
     } finally {
       // try to complete delayed action. In order to avoid conflicting locking, the actions to complete delayed requests
       // are kept in a queue. We add the logic to check the ReplicaManager queue at the end of KafkaApis.handle() and the
@@ -593,9 +596,9 @@ class KafkaApis(val requestChannel: RequestChannel,
       if (maxThrottleTimeMs > 0) {
         request.apiThrottleTimeMs = maxThrottleTimeMs
         if (bandwidthThrottleTimeMs > requestThrottleTimeMs) {
-          quotas.produce.throttle(request, bandwidthThrottleTimeMs, requestChannel.sendResponse)
+          requestHelper.throttle(quotas.produce, request, bandwidthThrottleTimeMs)
         } else {
-          quotas.request.throttle(request, requestThrottleTimeMs, requestChannel.sendResponse)
+          requestHelper.throttle(quotas.request, request, requestThrottleTimeMs)
         }
       }
 
@@ -613,14 +616,14 @@ class KafkaApis(val requestChannel: RequestChannel,
               s"from client id ${request.header.clientId} with ack=0\n" +
               s"Topic and partition to exceptions: $exceptionsSummary"
           )
-          requestHelper.closeConnection(request, new ProduceResponse(mergedResponseStatus.asJava).errorCounts)
+          requestChannel.closeConnection(request, new ProduceResponse(mergedResponseStatus.asJava).errorCounts)
         } else {
           // Note that although request throttling is exempt for acks == 0, the channel may be throttled due to
           // bandwidth quota violation.
           requestHelper.sendNoOpResponseExemptThrottle(request)
         }
       } else {
-        requestHelper.sendResponse(request, Some(new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs)), None)
+        requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs), None)
       }
     }
 
@@ -872,9 +875,9 @@ class KafkaApis(val requestChannel: RequestChannel,
           // from the fetch quota because we are going to return an empty response.
           quotas.fetch.unrecordQuotaSensor(request, responseSize, timeMs)
           if (bandwidthThrottleTimeMs > requestThrottleTimeMs) {
-            quotas.fetch.throttle(request, bandwidthThrottleTimeMs, requestChannel.sendResponse)
+            requestHelper.throttle(quotas.fetch, request, bandwidthThrottleTimeMs)
           } else {
-            quotas.request.throttle(request, requestThrottleTimeMs, requestChannel.sendResponse)
+            requestHelper.throttle(quotas.request, request, requestThrottleTimeMs)
           }
           // If throttling is required, return an empty response.
           unconvertedFetchResponse = fetchContext.getThrottledResponse(maxThrottleTimeMs)
@@ -885,7 +888,7 @@ class KafkaApis(val requestChannel: RequestChannel,
         }
 
         // Send the response immediately.
-        requestHelper.sendResponse(request, Some(createResponse(maxThrottleTimeMs)), Some(updateConversionStats))
+        requestChannel.sendResponse(request, createResponse(maxThrottleTimeMs), Some(updateConversionStats))
       }
     }
 
@@ -3207,13 +3210,13 @@ class KafkaApis(val requestChannel: RequestChannel,
     if (!isForwardingEnabled(request)) {
       info(s"Closing connection ${request.context.connectionId} because it sent an `Envelope` " +
         "request even though forwarding has not been enabled")
-      requestHelper.closeConnection(request, Collections.emptyMap())
+      requestChannel.closeConnection(request, Collections.emptyMap())
       return
     } else if (!request.context.fromPrivilegedListener) {
       info(s"Closing connection ${request.context.connectionId} from listener ${request.context.listenerName} " +
         s"because it sent an `Envelope` request, which is only accepted on the inter-broker listener " +
         s"${config.interBrokerListenerName}.")
-      requestHelper.closeConnection(request, Collections.emptyMap())
+      requestChannel.closeConnection(request, Collections.emptyMap())
       return
     } else if (!authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME)) {
       requestHelper.sendErrorResponseMaybeThrottle(request, new ClusterAuthorizationException(
@@ -3225,7 +3228,7 @@ class KafkaApis(val requestChannel: RequestChannel,
       return
     }
     EnvelopeUtils.handleEnvelopeRequest(request, requestChannel.metrics, handle)
-    }
+  }
 
   def handleDescribeProducersRequest(request: RequestChannel.Request): Unit = {
     val describeProducersRequest = request.body[DescribeProducersRequest]
diff --git a/core/src/main/scala/kafka/server/RequestHandlerHelper.scala b/core/src/main/scala/kafka/server/RequestHandlerHelper.scala
index cf2b816..ef9ff82 100644
--- a/core/src/main/scala/kafka/server/RequestHandlerHelper.scala
+++ b/core/src/main/scala/kafka/server/RequestHandlerHelper.scala
@@ -22,17 +22,12 @@ import kafka.coordinator.group.GroupCoordinator
 import kafka.coordinator.transaction.TransactionCoordinator
 import kafka.network.RequestChannel
 import kafka.server.QuotaFactory.QuotaManagers
-import kafka.utils.Logging
 import org.apache.kafka.common.errors.ClusterAuthorizationException
 import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.network.Send
-import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse}
 import org.apache.kafka.common.utils.Time
 
-import scala.jdk.CollectionConverters._
-
-
 object RequestHandlerHelper {
 
   def onLeadershipChange(groupCoordinator: GroupCoordinator,
@@ -56,38 +51,55 @@ object RequestHandlerHelper {
         txnCoordinator.onResignation(partition.partitionId, Some(partition.getLeaderEpoch))
     }
   }
-}
-
 
+}
 
-class RequestHandlerHelper(requestChannel: RequestChannel,
-                           quotas: QuotaManagers,
-                           time: Time,
-                           logPrefix: String) extends Logging {
-
-  this.logIdent = logPrefix
+class RequestHandlerHelper(
+  requestChannel: RequestChannel,
+  quotas: QuotaManagers,
+  time: Time
+) {
+
+  def throttle(
+    quotaManager: ClientQuotaManager,
+    request: RequestChannel.Request,
+    throttleTimeMs: Int
+  ): Unit = {
+    val callback = new ThrottleCallback {
+      override def startThrottling(): Unit = requestChannel.startThrottling(request)
+      override def endThrottling(): Unit = requestChannel.endThrottling(request)
+    }
+    quotaManager.throttle(request, callback, throttleTimeMs)
+  }
 
   def handleError(request: RequestChannel.Request, e: Throwable): Unit = {
     val mayThrottle = e.isInstanceOf[ClusterAuthorizationException] || !request.header.apiKey.clusterAction
-    error("Error when handling request: " +
-      s"clientId=${request.header.clientId}, " +
-      s"correlationId=${request.header.correlationId}, " +
-      s"api=${request.header.apiKey}, " +
-      s"version=${request.header.apiVersion}, " +
-      s"body=${request.body[AbstractRequest]}", e)
     if (mayThrottle)
       sendErrorResponseMaybeThrottle(request, e)
     else
       sendErrorResponseExemptThrottle(request, e)
   }
 
+  def sendErrorOrCloseConnection(
+    request: RequestChannel.Request,
+    error: Throwable,
+    throttleMs: Int
+  ): Unit = {
+    val requestBody = request.body[AbstractRequest]
+    val response = requestBody.getErrorResponse(throttleMs, error)
+    if (response == null)
+      requestChannel.closeConnection(request, requestBody.errorCounts(error))
+    else
+      requestChannel.sendResponse(request, response, None)
+  }
+
   def sendForwardedResponse(request: RequestChannel.Request,
                             response: AbstractResponse): Unit = {
     // For forwarded requests, we take the throttle time from the broker that
     // the request was forwarded to
     val throttleTimeMs = response.throttleTimeMs()
-    quotas.request.throttle(request, throttleTimeMs, requestChannel.sendResponse)
-    sendResponse(request, Some(response), None)
+    throttle(quotas.request, request, throttleTimeMs)
+    requestChannel.sendResponse(request, response, None)
   }
 
   // Throttle the channel if the request quota is enabled but has been violated. Regardless of throttling, send the
@@ -97,15 +109,15 @@ class RequestHandlerHelper(requestChannel: RequestChannel,
     val throttleTimeMs = maybeRecordAndGetThrottleTimeMs(request)
     // Only throttle non-forwarded requests
     if (!request.isForwarded)
-      quotas.request.throttle(request, throttleTimeMs, requestChannel.sendResponse)
-    sendResponse(request, Some(createResponse(throttleTimeMs)), None)
+      throttle(quotas.request, request, throttleTimeMs)
+    requestChannel.sendResponse(request, createResponse(throttleTimeMs), None)
   }
 
   def sendErrorResponseMaybeThrottle(request: RequestChannel.Request, error: Throwable): Unit = {
     val throttleTimeMs = maybeRecordAndGetThrottleTimeMs(request)
     // Only throttle non-forwarded requests or cluster authorization failures
     if (error.isInstanceOf[ClusterAuthorizationException] || !request.isForwarded)
-      quotas.request.throttle(request, throttleTimeMs, requestChannel.sendResponse)
+      throttle(quotas.request, request, throttleTimeMs)
     sendErrorOrCloseConnection(request, error, throttleTimeMs)
   }
 
@@ -130,29 +142,20 @@ class RequestHandlerHelper(requestChannel: RequestChannel,
     if (maxThrottleTimeMs > 0 && !request.isForwarded) {
       request.apiThrottleTimeMs = maxThrottleTimeMs
       if (controllerThrottleTimeMs > requestThrottleTimeMs) {
-        quotas.controllerMutation.throttle(request, controllerThrottleTimeMs, requestChannel.sendResponse)
+        throttle(quotas.controllerMutation, request, controllerThrottleTimeMs)
       } else {
-        quotas.request.throttle(request, requestThrottleTimeMs, requestChannel.sendResponse)
+        throttle(quotas.request, request, requestThrottleTimeMs)
       }
     }
 
-    sendResponse(request, Some(createResponse(maxThrottleTimeMs)), None)
+    requestChannel.sendResponse(request, createResponse(maxThrottleTimeMs), None)
   }
 
   def sendResponseExemptThrottle(request: RequestChannel.Request,
                                  response: AbstractResponse,
                                  onComplete: Option[Send => Unit] = None): Unit = {
     quotas.request.maybeRecordExempt(request)
-    sendResponse(request, Some(response), onComplete)
-  }
-
-  def sendErrorOrCloseConnection(request: RequestChannel.Request, error: Throwable, throttleMs: Int): Unit = {
-    val requestBody = request.body[AbstractRequest]
-    val response = requestBody.getErrorResponse(throttleMs, error)
-    if (response == null)
-      closeConnection(request, requestBody.errorCounts(error))
-    else
-      sendResponse(request, Some(response), None)
+    requestChannel.sendResponse(request, response, onComplete)
   }
 
   def sendErrorResponseExemptThrottle(request: RequestChannel.Request, error: Throwable): Unit = {
@@ -162,34 +165,7 @@ class RequestHandlerHelper(requestChannel: RequestChannel,
 
   def sendNoOpResponseExemptThrottle(request: RequestChannel.Request): Unit = {
     quotas.request.maybeRecordExempt(request)
-    sendResponse(request, None, None)
-  }
-
-  def closeConnection(request: RequestChannel.Request, errorCounts: java.util.Map[Errors, Integer]): Unit = {
-    // This case is used when the request handler has encountered an error, but the client
-    // does not expect a response (e.g. when produce request has acks set to 0)
-    requestChannel.updateErrorMetrics(request.header.apiKey, errorCounts.asScala)
-    requestChannel.sendResponse(new RequestChannel.CloseConnectionResponse(request))
+    requestChannel.sendNoOpResponse(request)
   }
 
-  def sendResponse(request: RequestChannel.Request,
-                   responseOpt: Option[AbstractResponse],
-                   onComplete: Option[Send => Unit]): Unit = {
-    // Update error metrics for each error code in the response including Errors.NONE
-    responseOpt.foreach(response => requestChannel.updateErrorMetrics(request.header.apiKey, response.errorCounts.asScala))
-
-    val response = responseOpt match {
-      case Some(response) =>
-        new RequestChannel.SendResponse(
-          request,
-          request.buildResponseSend(response),
-          request.responseNode(response),
-          onComplete
-        )
-      case None =>
-        new RequestChannel.NoOpResponse(request)
-    }
-
-    requestChannel.sendResponse(response)
-  }
 }
diff --git a/core/src/main/scala/kafka/server/ThrottledChannel.scala b/core/src/main/scala/kafka/server/ThrottledChannel.scala
index 531ef5d..8091678 100644
--- a/core/src/main/scala/kafka/server/ThrottledChannel.scala
+++ b/core/src/main/scala/kafka/server/ThrottledChannel.scala
@@ -19,33 +19,35 @@ package kafka.server
 
 import java.util.concurrent.{Delayed, TimeUnit}
 
-import kafka.network
-import kafka.network.RequestChannel
-import kafka.network.RequestChannel.Response
 import kafka.utils.Logging
 import org.apache.kafka.common.utils.Time
 
+trait ThrottleCallback {
+  def startThrottling(): Unit
+  def endThrottling(): Unit
+}
 
 /**
   * Represents a request whose response has been delayed.
-  * @param request The request that has been delayed
   * @param time Time instance to use
   * @param throttleTimeMs Delay associated with this request
-  * @param channelThrottlingCallback Callback for channel throttling
+  * @param callback Callback for channel throttling
   */
-class ThrottledChannel(val request: RequestChannel.Request, val time: Time, val throttleTimeMs: Int,
-                       channelThrottlingCallback: Response => Unit)
-  extends Delayed with Logging {
+class ThrottledChannel(
+  val time: Time,
+  val throttleTimeMs: Int,
+  val callback: ThrottleCallback
+) extends Delayed with Logging {
 
   private val endTimeNanos = time.nanoseconds() + TimeUnit.MILLISECONDS.toNanos(throttleTimeMs)
 
   // Notify the socket server that throttling has started for this channel.
-  channelThrottlingCallback(new RequestChannel.StartThrottlingResponse(request))
+  callback.startThrottling()
 
   // Notify the socket server that throttling has been done for this channel.
   def notifyThrottlingDone(): Unit = {
     trace(s"Channel throttled for: $throttleTimeMs ms")
-    channelThrottlingCallback(new network.RequestChannel.EndThrottlingResponse(request))
+    callback.endThrottling()
   }
 
   override def getDelay(unit: TimeUnit): Long = {
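
Editorial note: with the request and response objects removed from its constructor, `ThrottledChannel` only drives the callback. A rough lifecycle sketch under assumed wiring (the println callbacks are placeholders; in the broker they notify the socket server as shown in the diff above):

    import org.apache.kafka.common.utils.Time

    val callback = new ThrottleCallback {
      override def startThrottling(): Unit = println("channel muted")    // fired from the constructor
      override def endThrottling(): Unit = println("channel unmuted")    // fired via notifyThrottlingDone()
    }
    val channel = new ThrottledChannel(Time.SYSTEM, 100, callback)
    // In the broker, ClientQuotaManager's throttle reaper calls this once the delay expires.
    channel.notifyThrottlingDone()
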
diff --git a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala
index e4dec2e..db825ff 100644
--- a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala
+++ b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala
@@ -18,19 +18,16 @@
 package kafka.tools
 
 import kafka.network.RequestChannel
-import kafka.network.RequestConvertToJson
 import kafka.raft.RaftManager
 import kafka.server.{ApiRequestHandler, ApiVersionManager}
 import kafka.utils.Logging
 import org.apache.kafka.common.internals.FatalExitError
 import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, FetchSnapshotResponseData, VoteResponseData}
-import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
+import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage}
 import org.apache.kafka.common.record.BaseRecords
 import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, BeginQuorumEpochResponse, EndQuorumEpochResponse, FetchResponse, FetchSnapshotResponse, VoteResponse}
 import org.apache.kafka.common.utils.Time
 
-import scala.jdk.CollectionConverters._
-
 /**
  * Simple request handler implementation for use by [[TestRaftServer]].
  */
@@ -43,8 +40,7 @@ class TestRaftRequestHandler(
 
   override def handle(request: RequestChannel.Request): Unit = {
     try {
-      trace(s"Handling request:${request.requestDesc(true)} from connection ${request.context.connectionId};" +
-        s"securityProtocol:${request.context.securityProtocol},principal:${request.context.principal}")
+      trace(s"Handling request:${request.requestDesc(true)} with context ${request.context}")
       request.header.apiKey match {
         case ApiKeys.API_VERSIONS => handleApiVersions(request)
         case ApiKeys.VOTE => handleVote(request)
@@ -56,7 +52,11 @@ class TestRaftRequestHandler(
       }
     } catch {
       case e: FatalExitError => throw e
-      case e: Throwable => handleError(request, e)
+      case e: Throwable =>
+        error(s"Unexpected error handling request ${request.requestDesc(true)} " +
+          s"with context ${request.context}", e)
+        val errorResponse = request.body[AbstractRequest].getErrorResponse(e)
+        requestChannel.sendResponse(request, errorResponse, None)
     } finally {
       // The local completion time may be set while processing the request. Only record it if it's unset.
       if (request.apiLocalCompleteTimeNanos < 0)
@@ -65,7 +65,7 @@ class TestRaftRequestHandler(
   }
 
   private def handleApiVersions(request: RequestChannel.Request): Unit = {
-    sendResponse(request, Some(apiVersionManager.apiVersionResponse(throttleTimeMs = 0)))
+    requestChannel.sendResponse(request, apiVersionManager.apiVersionResponse(throttleTimeMs = 0), None)
   }
 
   private def handleVote(request: RequestChannel.Request): Unit = {
@@ -106,53 +106,8 @@ class TestRaftRequestHandler(
       } else {
         buildResponse(response)
       }
-      sendResponse(request, Some(res))
+      requestChannel.sendResponse(request, res, None)
     })
   }
 
-  private def handleError(request: RequestChannel.Request, err: Throwable): Unit = {
-    error("Error when handling request: " +
-      s"clientId=${request.header.clientId}, " +
-      s"correlationId=${request.header.correlationId}, " +
-      s"api=${request.header.apiKey}, " +
-      s"version=${request.header.apiVersion}, " +
-      s"body=${request.body[AbstractRequest]}", err)
-
-    val requestBody = request.body[AbstractRequest]
-    val response = requestBody.getErrorResponse(0, err)
-    if (response == null)
-      closeConnection(request, requestBody.errorCounts(err))
-    else
-      sendResponse(request, Some(response))
-  }
-
-  private def closeConnection(request: RequestChannel.Request, errorCounts: java.util.Map[Errors, Integer]): Unit = {
-    // This case is used when the request handler has encountered an error, but the client
-    // does not expect a response (e.g. when produce request has acks set to 0)
-    requestChannel.updateErrorMetrics(request.header.apiKey, errorCounts.asScala)
-    requestChannel.sendResponse(new RequestChannel.CloseConnectionResponse(request))
-  }
-
-  private def sendResponse(request: RequestChannel.Request,
-                           responseOpt: Option[AbstractResponse]): Unit = {
-    // Update error metrics for each error code in the response including Errors.NONE
-    responseOpt.foreach(response => requestChannel.updateErrorMetrics(request.header.apiKey, response.errorCounts.asScala))
-
-    val response = responseOpt match {
-      case Some(response) =>
-        val responseSend = request.context.buildResponseSend(response)
-        val responseString =
-          if (RequestChannel.isRequestLoggingEnabled) Some(RequestConvertToJson.response(response, request.context.apiVersion))
-          else None
-        new RequestChannel.SendResponse(request, responseSend, responseString, None)
-      case None =>
-        new RequestChannel.NoOpResponse(request)
-    }
-    sendResponse(response)
-  }
-
-  private def sendResponse(response: RequestChannel.Response): Unit = {
-    requestChannel.sendResponse(response)
-  }
-
 }
diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
index 2936144..d323030 100644
--- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
+++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala
@@ -31,7 +31,7 @@ import com.yammer.metrics.core.{Gauge, Meter}
 import javax.net.ssl._
 import kafka.metrics.KafkaYammerMetrics
 import kafka.security.CredentialProvider
-import kafka.server.{KafkaConfig, SimpleApiVersionManager, ThrottledChannel}
+import kafka.server.{KafkaConfig, SimpleApiVersionManager, ThrottleCallback, ThrottledChannel}
 import kafka.utils.Implicits._
 import kafka.utils.TestUtils
 import org.apache.kafka.common.memory.MemoryPool
@@ -147,7 +147,7 @@ class SocketServerTest {
   }
 
   def processRequestNoOpResponse(channel: RequestChannel, request: RequestChannel.Request): Unit = {
-    channel.sendResponse(new RequestChannel.NoOpResponse(request))
+    channel.sendNoOpResponse(request)
   }
 
   def connect(s: SocketServer = server,
@@ -247,7 +247,7 @@ class SocketServerTest {
     assertEquals(ClientInformation.UNKNOWN_NAME_OR_VERSION, receivedReq.context.clientInformation.softwareName)
     assertEquals(ClientInformation.UNKNOWN_NAME_OR_VERSION, receivedReq.context.clientInformation.softwareVersion)
 
-    server.dataPlaneRequestChannel.sendResponse(new RequestChannel.NoOpResponse(receivedReq))
+    server.dataPlaneRequestChannel.sendNoOpResponse(receivedReq)
 
     // Send ProduceRequest - client info expected
     sendRequest(plainSocket, producerRequestBytes())
@@ -256,7 +256,7 @@ class SocketServerTest {
     assertEquals(expectedClientSoftwareName, receivedReq.context.clientInformation.softwareName)
     assertEquals(expectedClientSoftwareVersion, receivedReq.context.clientInformation.softwareVersion)
 
-    server.dataPlaneRequestChannel.sendResponse(new RequestChannel.NoOpResponse(receivedReq))
+    server.dataPlaneRequestChannel.sendNoOpResponse(receivedReq)
 
     // Close the socket
     plainSocket.setSoLinger(true, 0)
@@ -678,10 +678,12 @@ class SocketServerTest {
     val request = receiveRequest(server.dataPlaneRequestChannel)
     val byteBuffer = RequestTestUtils.serializeRequestWithHeader(request.header, request.body[AbstractRequest])
     val send = new NetworkSend(request.context.connectionId, ByteBufferSend.sizePrefixed(byteBuffer))
-    def channelThrottlingCallback(response: RequestChannel.Response): Unit = {
-      server.dataPlaneRequestChannel.sendResponse(response)
+
+    val channelThrottlingCallback = new ThrottleCallback {
+      override def startThrottling(): Unit = server.dataPlaneRequestChannel.startThrottling(request)
+      override def endThrottling(): Unit = server.dataPlaneRequestChannel.endThrottling(request)
     }
-    val throttledChannel = new ThrottledChannel(request, new MockTime(), 100, channelThrottlingCallback)
+    val throttledChannel = new ThrottledChannel(new MockTime(), 100, channelThrottlingCallback)
     val headerLog = RequestConvertToJson.requestHeaderNode(request.header)
     val response =
       if (!noOpResponse)
diff --git a/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala b/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala
index 76b3d3e..48379ca 100644
--- a/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala
@@ -20,9 +20,7 @@ import java.net.InetAddress
 import java.util
 import java.util.Collections
 import kafka.network.RequestChannel
-import kafka.network.RequestChannel.EndThrottlingResponse
 import kafka.network.RequestChannel.Session
-import kafka.network.RequestChannel.StartThrottlingResponse
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.memory.MemoryPool
 import org.apache.kafka.common.metrics.MetricConfig
@@ -47,11 +45,11 @@ class BaseClientQuotaManagerTest {
     metrics.close()
   }
 
-  protected def callback(response: RequestChannel.Response): Unit = {
-    // Count how many times this callback is called for notifyThrottlingDone().
-    (response: @unchecked) match {
-      case _: StartThrottlingResponse =>
-      case _: EndThrottlingResponse => numCallbacks += 1
+  protected def callback: ThrottleCallback = new ThrottleCallback {
+    override def startThrottling(): Unit = {}
+    override def endThrottling(): Unit = {
+      // Count how many times this callback is called for notifyThrottlingDone().
+      numCallbacks += 1
     }
   }
 
@@ -82,8 +80,8 @@ class BaseClientQuotaManagerTest {
   }
 
   protected def throttle(quotaManager: ClientQuotaManager, user: String, clientId: String, throttleTimeMs: Int,
-                         channelThrottlingCallback: RequestChannel.Response => Unit): Unit = {
+                         channelThrottlingCallback: ThrottleCallback): Unit = {
     val (_, request) = buildRequest(FetchRequest.Builder.forConsumer(0, 1000, new util.HashMap[TopicPartition, PartitionData]))
-    quotaManager.throttle(request, throttleTimeMs, channelThrottlingCallback)
+    quotaManager.throttle(request, channelThrottlingCallback, throttleTimeMs)
   }
 }
diff --git a/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala b/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala
index 3533fcf..0ca44f3 100644
--- a/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala
@@ -26,33 +26,36 @@ import kafka.server.QuotaFactory.QuotaManagers
 import kafka.server.{ClientQuotaManager, ClientRequestQuotaManager, ControllerApis, ControllerMutationQuotaManager, KafkaConfig, MetaProperties, ReplicationQuotaManager}
 import kafka.utils.MockTime
 import org.apache.kafka.common.Uuid
-import org.apache.kafka.common.errors.ClusterAuthorizationException
 import org.apache.kafka.common.memory.MemoryPool
 import org.apache.kafka.common.message.BrokerRegistrationRequestData
 import org.apache.kafka.common.network.{ClientInformation, ListenerName}
 import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.requests.{AbstractRequest, BrokerRegistrationRequest, RequestContext, RequestHeader, RequestTestUtils}
+import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, BrokerRegistrationRequest, BrokerRegistrationResponse, RequestContext, RequestHeader, RequestTestUtils}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.controller.Controller
-import org.apache.kafka.metadata.{ApiMessageAndVersion, VersionRange}
-import org.apache.kafka.server.authorizer.{AuthorizableRequestContext, AuthorizationResult, Authorizer}
-import org.easymock.{Capture, EasyMock, IAnswer}
+import org.apache.kafka.metadata.ApiMessageAndVersion
+import org.apache.kafka.server.authorizer.{Action, AuthorizableRequestContext, AuthorizationResult, Authorizer}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test}
+import org.mockito.ArgumentMatchers._
+import org.mockito.Mockito._
+import org.mockito.{ArgumentCaptor, ArgumentMatchers}
+import scala.jdk.CollectionConverters._
 
 class ControllerApisTest {
-  // Mocks
   private val nodeId = 1
   private val brokerRack = "Rack1"
   private val clientID = "Client1"
-  private val requestChannelMetrics: RequestChannel.Metrics = EasyMock.createNiceMock(classOf[RequestChannel.Metrics])
-  private val requestChannel: RequestChannel = EasyMock.createNiceMock(classOf[RequestChannel])
+  private val requestChannelMetrics: RequestChannel.Metrics = mock(classOf[RequestChannel.Metrics])
+  private val requestChannel: RequestChannel = mock(classOf[RequestChannel])
   private val time = new MockTime
-  private val clientQuotaManager: ClientQuotaManager = EasyMock.createNiceMock(classOf[ClientQuotaManager])
-  private val clientRequestQuotaManager: ClientRequestQuotaManager = EasyMock.createNiceMock(classOf[ClientRequestQuotaManager])
-  private val clientControllerQuotaManager: ControllerMutationQuotaManager = EasyMock.createNiceMock(classOf[ControllerMutationQuotaManager])
-  private val replicaQuotaManager: ReplicationQuotaManager = EasyMock.createNiceMock(classOf[ReplicationQuotaManager])
-  private val raftManager: RaftManager[ApiMessageAndVersion] = EasyMock.createNiceMock(classOf[RaftManager[ApiMessageAndVersion]])
+  private val clientQuotaManager: ClientQuotaManager = mock(classOf[ClientQuotaManager])
+  private val clientRequestQuotaManager: ClientRequestQuotaManager = mock(classOf[ClientRequestQuotaManager])
+  private val clientControllerQuotaManager: ControllerMutationQuotaManager = mock(classOf[ControllerMutationQuotaManager])
+  private val replicaQuotaManager: ReplicationQuotaManager = mock(classOf[ReplicationQuotaManager])
+  private val raftManager: RaftManager[ApiMessageAndVersion] = mock(classOf[RaftManager[ApiMessageAndVersion]])
+  private val authorizer: Authorizer = mock(classOf[Authorizer])
+
   private val quotas = QuotaManagers(
     clientQuotaManager,
     clientQuotaManager,
@@ -62,24 +65,21 @@ class ControllerApisTest {
     replicaQuotaManager,
     replicaQuotaManager,
     None)
-  private val controller: Controller = EasyMock.createNiceMock(classOf[Controller])
+  private val controller: Controller = mock(classOf[Controller])
 
-  private def createControllerApis(authorizer: Option[Authorizer],
-                                   supportedFeatures: Map[String, VersionRange] = Map.empty): ControllerApis = {
+  private def createControllerApis(): ControllerApis = {
     val props = new Properties()
     props.put(KafkaConfig.NodeIdProp, nodeId: java.lang.Integer)
     props.put(KafkaConfig.ProcessRolesProp, "controller")
     new ControllerApis(
       requestChannel,
-      authorizer,
+      Some(authorizer),
       quotas,
       time,
-      supportedFeatures,
+      Map.empty,
       controller,
       raftManager,
       new KafkaConfig(props),
-
-      // FIXME: Would make more sense to set controllerId here
       MetaProperties(Uuid.fromString("JgxuGe9URy-E-ceaL04lEw"), nodeId = nodeId),
       Seq.empty
     )
@@ -93,8 +93,10 @@ class ControllerApisTest {
    * @tparam T - Type of AbstractRequest
    * @return
    */
-  private def buildRequest[T <: AbstractRequest](request: AbstractRequest,
-                                                 listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)): RequestChannel.Request = {
+  private def buildRequest[T <: AbstractRequest](
+    request: AbstractRequest,
+    listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
+  ): RequestChannel.Request = {
     val buffer = RequestTestUtils.serializeRequestWithHeader(
       new RequestHeader(request.apiKey, request.version, clientID, 0), request)
 
@@ -107,7 +109,7 @@ class ControllerApisTest {
   }
 
   @Test
-  def testBrokerRegistration(): Unit = {
+  def testUnauthorizedBrokerRegistration(): Unit = {
     val brokerRegistrationRequest = new BrokerRegistrationRequest.Builder(
       new BrokerRegistrationRequestData()
         .setBrokerId(nodeId)
@@ -115,25 +117,26 @@ class ControllerApisTest {
     ).build()
 
     val request = buildRequest(brokerRegistrationRequest)
+    val capturedResponse: ArgumentCaptor[AbstractResponse] = ArgumentCaptor.forClass(classOf[AbstractResponse])
 
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
-
-    val authorizer = Some[Authorizer](EasyMock.createNiceMock(classOf[Authorizer]))
-    EasyMock.expect(authorizer.get.authorize(EasyMock.anyObject[AuthorizableRequestContext](), EasyMock.anyObject())).andAnswer(
-      new IAnswer[java.util.List[AuthorizationResult]]() {
-        override def answer(): java.util.List[AuthorizationResult] = {
-          new java.util.ArrayList[AuthorizationResult](){
-            add(AuthorizationResult.DENIED)
-          }
-        }
-      }
+    when(authorizer.authorize(
+      any(classOf[AuthorizableRequestContext]),
+      any(classOf[java.util.List[Action]])
+    )).thenReturn(
+      java.util.Collections.singletonList(AuthorizationResult.DENIED)
     )
-    EasyMock.replay(requestChannel, authorizer.get)
 
-    val assertion = assertThrows(classOf[ClusterAuthorizationException],
-      () => createControllerApis(authorizer = authorizer).handleBrokerRegistration(request))
-    assert(Errors.forException(assertion) == Errors.CLUSTER_AUTHORIZATION_FAILED)
+    createControllerApis().handle(request)
+    verify(requestChannel).sendResponse(
+      ArgumentMatchers.eq(request),
+      capturedResponse.capture(),
+      ArgumentMatchers.eq(None))
+
+    assertNotNull(capturedResponse.getValue)
+
+    val brokerRegistrationResponse = capturedResponse.getValue.asInstanceOf[BrokerRegistrationResponse]
+    assertEquals(Map(Errors.CLUSTER_AUTHORIZATION_FAILED -> 1),
+      brokerRegistrationResponse.errorCounts().asScala)
   }
 
   @AfterEach
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index e80c6eb..43b9ca5 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -32,7 +32,6 @@ import kafka.coordinator.group._
 import kafka.coordinator.transaction.{InitProducerIdResult, TransactionCoordinator}
 import kafka.log.AppendOrigin
 import kafka.network.RequestChannel
-import kafka.network.RequestChannel.{CloseConnectionResponse, SendResponse}
 import kafka.server.QuotaFactory.QuotaManagers
 import kafka.server.metadata.{CachedConfigRepository, ConfigRepository, RaftMetadataCache}
 import kafka.utils.{MockTime, TestUtils}
@@ -94,7 +93,6 @@ class KafkaApisTest {
   private val forwardingManager: ForwardingManager = EasyMock.createNiceMock(classOf[ForwardingManager])
   private val autoTopicCreationManager: AutoTopicCreationManager = EasyMock.createNiceMock(classOf[AutoTopicCreationManager])
 
-  private val hostAddress: Array[Byte] = InetAddress.getByName("192.168.1.1").getAddress
   private val kafkaPrincipalSerde = new KafkaPrincipalSerde {
     override def serialize(principal: KafkaPrincipal): Array[Byte] = Utils.utf8(principal.toString)
     override def deserialize(bytes: Array[Byte]): KafkaPrincipal = SecurityUtils.parseKafkaPrincipal(Utils.utf8(bytes))
@@ -211,8 +209,6 @@ class KafkaApisTest {
       .andReturn(Seq(AuthorizationResult.ALLOWED).asJava)
       .once()
 
-    val capturedResponse = expectNoThrottling()
-
     val configRepository: ConfigRepository = EasyMock.strictMock(classOf[ConfigRepository])
     val topicConfigs = new Properties()
     val propName = "min.insync.replicas"
@@ -229,8 +225,6 @@ class KafkaApisTest {
 
     expect(metadataCache.contains(resourceName)).andReturn(true)
 
-    EasyMock.replay(metadataCache, replicaManager, clientRequestQuotaManager, requestChannel, authorizer, configRepository, adminManager)
-
     val describeConfigsRequest = new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData()
       .setIncludeSynonyms(true)
       .setResources(List(new DescribeConfigsRequestData.DescribeConfigsResource()
@@ -239,12 +233,16 @@ class KafkaApisTest {
       .build(requestHeader.apiVersion)
     val request = buildRequest(describeConfigsRequest,
       requestHeader = Option(requestHeader))
-    createKafkaApis(authorizer = Some(authorizer), configRepository = configRepository).handleDescribeConfigsRequest(request)
+    val capturedResponse = expectNoThrottling(request)
+
+    EasyMock.replay(metadataCache, replicaManager, clientRequestQuotaManager, requestChannel,
+      authorizer, configRepository, adminManager)
+    createKafkaApis(authorizer = Some(authorizer), configRepository = configRepository)
+      .handleDescribeConfigsRequest(request)
 
     verify(authorizer, replicaManager)
 
-    val response = readResponse(describeConfigsRequest, capturedResponse)
-      .asInstanceOf[DescribeConfigsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[DescribeConfigsResponse]
     val results = response.data().results()
     assertEquals(1, results.size())
     val describeConfigsResult: DescribeConfigsResult = results.get(0)
@@ -252,7 +250,7 @@ class KafkaApisTest {
     assertEquals(resourceName, describeConfigsResult.resourceName())
     val configs = describeConfigsResult.configs().asScala.filter(_.name() == propName)
     assertEquals(1, configs.length)
-    val describeConfigsResponseData = configs(0)
+    val describeConfigsResponseData = configs.head
     assertEquals(propName, describeConfigsResponseData.name())
     assertEquals(propValue, describeConfigsResponseData.value())
   }
@@ -290,37 +288,33 @@ class KafkaApisTest {
 
     authorizeResource(authorizer, operation, ResourceType.TOPIC, resourceName, AuthorizationResult.ALLOWED)
 
-    val capturedResponse = expectNoThrottling()
-
     val configResource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName)
     EasyMock.expect(adminManager.alterConfigs(anyObject(), EasyMock.eq(false)))
       .andAnswer(() => {
         Map(configResource -> alterConfigHandler.apply())
       })
 
-    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer,
-      adminManager, controller)
-
     val configs = Map(
       configResource -> new AlterConfigsRequest.Config(
         Seq(new AlterConfigsRequest.ConfigEntry("foo", "bar")).asJava))
     val alterConfigsRequest = new AlterConfigsRequest.Builder(configs.asJava, false).build(requestHeader.apiVersion)
 
     val request = buildRequestWithEnvelope(alterConfigsRequest, fromPrivilegedListener = true)
+    val capturedResponse = EasyMock.newCapture[AbstractResponse]()
+    val capturedRequest = EasyMock.newCapture[RequestChannel.Request]()
 
-    createKafkaApis(authorizer = Some(authorizer), enableForwarding = true).handle(request)
-
-    val envelopeRequest = request.body[EnvelopeRequest]
-    val response = readResponse(envelopeRequest, capturedResponse)
-      .asInstanceOf[EnvelopeResponse]
-
-    assertEquals(Errors.NONE, response.error)
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.capture(capturedRequest),
+      EasyMock.capture(capturedResponse),
+      EasyMock.anyObject()
+    ))
 
-    val innerResponse = AbstractResponse.parseResponse(
-      response.responseData(),
-      requestHeader
-    ).asInstanceOf[AlterConfigsResponse]
+    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer,
+      adminManager, controller)
+    createKafkaApis(authorizer = Some(authorizer), enableForwarding = true).handle(request)
 
+    assertEquals(Some(request), capturedRequest.getValue.envelope)
+    val innerResponse = capturedResponse.getValue.asInstanceOf[AlterConfigsResponse]
     val responseMap = innerResponse.data.responses().asScala.map { resourceResponse =>
       resourceResponse.resourceName() -> Errors.forCode(resourceResponse.errorCode)
     }.toMap
@@ -336,29 +330,16 @@ class KafkaApisTest {
       clientId, 0)
     val leaveGroupRequest = new LeaveGroupRequest.Builder("group",
       Collections.singletonList(new MemberIdentity())).build(requestHeader.apiVersion)
-    val serializedRequestData = RequestTestUtils.serializeRequestWithHeader(requestHeader, leaveGroupRequest)
-
-    resetToStrict(requestChannel)
 
     EasyMock.expect(controller.isActive).andReturn(true)
 
-    EasyMock.expect(requestChannel.metrics).andReturn(EasyMock.niceMock(classOf[RequestChannel.Metrics]))
-    EasyMock.expect(requestChannel.updateErrorMetrics(ApiKeys.ENVELOPE, Map(Errors.INVALID_REQUEST -> 1)))
-    val capturedResponse = expectNoThrottling()
-
-    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, controller)
-
-    val envelopeHeader = new RequestHeader(ApiKeys.ENVELOPE, ApiKeys.ENVELOPE.latestVersion,
-      clientId, 0)
-
-    val envelopeRequest = new EnvelopeRequest.Builder(serializedRequestData, new Array[Byte](0), hostAddress)
-      .build(envelopeHeader.apiVersion)
     val request = buildRequestWithEnvelope(leaveGroupRequest, fromPrivilegedListener = true)
+    val capturedResponse = expectNoThrottling(request)
 
+    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, controller)
     createKafkaApis(enableForwarding = true).handle(request)
 
-    val response = readResponse(envelopeRequest, capturedResponse)
-      .asInstanceOf[EnvelopeResponse]
+    val response = capturedResponse.getValue.asInstanceOf[EnvelopeResponse]
     assertEquals(Errors.INVALID_REQUEST, response.error())
   }
 
@@ -397,13 +378,8 @@ class KafkaApisTest {
 
     EasyMock.expect(controller.isActive).andReturn(isActiveController)
 
-    val capturedResponse = expectNoThrottling()
-
     val configResource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName)
 
-    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer,
-      adminManager, controller)
-
     val configs = Map(
       configResource -> new AlterConfigsRequest.Config(
         Seq(new AlterConfigsRequest.ConfigEntry("foo", "bar")).asJava))
@@ -412,19 +388,31 @@ class KafkaApisTest {
 
     val request = buildRequestWithEnvelope(alterConfigsRequest,
       fromPrivilegedListener = fromPrivilegedListener)
-    createKafkaApis(authorizer = Some(authorizer), enableForwarding = true).handle(request)
 
+    val capturedResponse = EasyMock.newCapture[AbstractResponse]()
     if (shouldCloseConnection) {
-      assertTrue(capturedResponse.getValue.isInstanceOf[CloseConnectionResponse])
+      EasyMock.expect(requestChannel.closeConnection(
+        EasyMock.eq(request),
+        EasyMock.eq(java.util.Collections.emptyMap())
+      ))
     } else {
-      val envelopeRequest = request.body[EnvelopeRequest]
-      val response = readResponse(envelopeRequest, capturedResponse)
-        .asInstanceOf[EnvelopeResponse]
+      EasyMock.expect(requestChannel.sendResponse(
+        EasyMock.eq(request),
+        EasyMock.capture(capturedResponse),
+        EasyMock.eq(None)
+      ))
+    }
 
-      assertEquals(expectedError, response.error())
+    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer,
+      adminManager, controller)
+    createKafkaApis(authorizer = Some(authorizer), enableForwarding = true).handle(request)
 
-      verify(authorizer, adminManager)
+    if (!shouldCloseConnection) {
+      val response = capturedResponse.getValue.asInstanceOf[EnvelopeResponse]
+      assertEquals(expectedError, response.error)
     }
+
+    verify(authorizer, adminManager, requestChannel)
   }
 
   @Test
@@ -452,7 +440,7 @@ class KafkaApisTest {
 
     EasyMock.expect(controller.isActive).andReturn(false)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
 
     EasyMock.expect(adminManager.alterConfigs(anyObject(), EasyMock.eq(false)))
       .andReturn(Map(authorizedResource -> ApiError.NONE))
@@ -483,7 +471,7 @@ class KafkaApisTest {
 
     EasyMock.expect(controller.isActive).andReturn(false)
 
-    expectNoThrottling()
+    expectNoThrottling(request)
 
     EasyMock.expect(forwardingManager.forwardRequest(
       EasyMock.eq(request),
@@ -519,10 +507,9 @@ class KafkaApisTest {
   }
 
   private def verifyAlterConfigResult(alterConfigsRequest: AlterConfigsRequest,
-                                      capturedResponse: Capture[RequestChannel.Response],
+                                      capturedResponse: Capture[AbstractResponse],
                                       expectedResults: Map[String, Errors]): Unit = {
-    val response = readResponse(alterConfigsRequest, capturedResponse)
-      .asInstanceOf[AlterConfigsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[AlterConfigsResponse]
     val responseMap = response.data.responses().asScala.map { resourceResponse =>
       resourceResponse.resourceName() -> Errors.forCode(resourceResponse.errorCode)
     }.toMap
@@ -559,19 +546,19 @@ class KafkaApisTest {
 
     EasyMock.expect(controller.isActive).andReturn(true)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
 
     EasyMock.expect(adminManager.incrementalAlterConfigs(anyObject(), EasyMock.eq(false)))
       .andReturn(Map(authorizedResource -> ApiError.NONE))
 
     EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer,
       adminManager, controller)
-
     createKafkaApis(authorizer = Some(authorizer)).handleIncrementalAlterConfigsRequest(request)
 
-    verifyIncrementalAlterConfigResult(incrementalAlterConfigsRequest,
-      capturedResponse, Map(authorizedTopic -> Errors.NONE,
-        unauthorizedTopic -> Errors.TOPIC_AUTHORIZATION_FAILED))
+    verifyIncrementalAlterConfigResult(capturedResponse, Map(
+      authorizedTopic -> Errors.NONE,
+      unauthorizedTopic -> Errors.TOPIC_AUTHORIZATION_FAILED
+    ))
 
     verify(authorizer, adminManager)
   }
@@ -593,15 +580,12 @@ class KafkaApisTest {
     new IncrementalAlterConfigsRequest.Builder(resourceMap, false)
   }
 
-  private def verifyIncrementalAlterConfigResult(incrementalAlterConfigsRequest: IncrementalAlterConfigsRequest,
-                                                 capturedResponse: Capture[RequestChannel.Response],
+  private def verifyIncrementalAlterConfigResult(capturedResponse: Capture[AbstractResponse],
                                                  expectedResults: Map[String, Errors]): Unit = {
-    val response = readResponse(incrementalAlterConfigsRequest, capturedResponse)
-      .asInstanceOf[IncrementalAlterConfigsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[IncrementalAlterConfigsResponse]
     val responseMap = response.data.responses().asScala.map { resourceResponse =>
       resourceResponse.resourceName() -> Errors.forCode(resourceResponse.errorCode)
     }.toMap
-
     assertEquals(expectedResults, responseMap)
   }
 
@@ -624,15 +608,13 @@ class KafkaApisTest {
 
     EasyMock.expect(controller.isActive).andReturn(true)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
 
     EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer,
       adminManager, controller)
-
     createKafkaApis(authorizer = Some(authorizer)).handleAlterClientQuotasRequest(request)
 
-    verifyAlterClientQuotaResult(alterClientQuotasRequest,
-      capturedResponse, Map(quotaEntity -> Errors.CLUSTER_AUTHORIZATION_FAILED))
+    verifyAlterClientQuotaResult(capturedResponse, Map(quotaEntity -> Errors.CLUSTER_AUTHORIZATION_FAILED))
 
     verify(authorizer, adminManager)
   }
@@ -643,11 +625,9 @@ class KafkaApisTest {
     testForwardableAPI(ApiKeys.ALTER_CLIENT_QUOTAS, requestBuilder)
   }
 
-  private def verifyAlterClientQuotaResult(alterClientQuotasRequest: AlterClientQuotasRequest,
-                                           capturedResponse: Capture[RequestChannel.Response],
+  private def verifyAlterClientQuotaResult(capturedResponse: Capture[AbstractResponse],
                                            expected: Map[ClientQuotaEntity, Errors]): Unit = {
-    val response = readResponse(alterClientQuotasRequest, capturedResponse)
-      .asInstanceOf[AlterClientQuotasResponse]
+    val response = capturedResponse.getValue.asInstanceOf[AlterClientQuotasResponse]
     val futures = expected.keys.map(quotaEntity => quotaEntity -> new KafkaFutureImpl[Void]()).toMap
     response.complete(futures.asJava)
     futures.foreach {
@@ -678,8 +658,6 @@ class KafkaApisTest {
 
     EasyMock.expect(controller.isActive).andReturn(true)
 
-    val capturedResponse = expectNoThrottling()
-
     val topics = new CreateTopicsRequestData.CreatableTopicCollection(2)
     val topicToCreate = new CreateTopicsRequestData.CreatableTopic()
       .setName(authorizedTopic)
@@ -699,6 +677,8 @@ class KafkaApisTest {
     val request = buildRequest(createTopicsRequest,
       fromPrivilegedListener = true, requestHeader = Option(requestHeader))
 
+    val capturedResponse = expectNoThrottling(request)
+
     EasyMock.expect(clientControllerQuotaManager.newQuotaFor(
       EasyMock.eq(request), EasyMock.eq(6))).andReturn(UnboundedControllerMutationQuota)
 
@@ -776,10 +756,9 @@ class KafkaApisTest {
   }
 
   private def verifyCreateTopicsResult(createTopicsRequest: CreateTopicsRequest,
-                                       capturedResponse: Capture[RequestChannel.Response],
+                                       capturedResponse: Capture[AbstractResponse],
                                        expectedResults: Map[String, Errors]): Unit = {
-    val response = readResponse(createTopicsRequest, capturedResponse)
-      .asInstanceOf[CreateTopicsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[CreateTopicsResponse]
     val responseMap = response.data.topics().asScala.map { topicResponse =>
       topicResponse.name() -> Errors.forCode(topicResponse.errorCode)
     }.toMap
@@ -910,7 +889,7 @@ class KafkaApisTest {
     ).build(requestHeader.apiVersion)
     val request = buildRequest(findCoordinatorRequest)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
 
     verifyTopicCreation(topicName, true, true, request)
 
@@ -920,8 +899,7 @@ class KafkaApisTest {
     createKafkaApis(authorizer = Some(authorizer),
       overrideProperties = topicConfigOverride).handleFindCoordinatorRequest(request)
 
-    val response = readResponse(findCoordinatorRequest, capturedResponse)
-      .asInstanceOf[FindCoordinatorResponse]
+    val response = capturedResponse.getValue.asInstanceOf[FindCoordinatorResponse]
     assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, response.error())
 
     verify(authorizer, autoTopicCreationManager)
@@ -1012,7 +990,7 @@ class KafkaApisTest {
     ).build(requestHeader.apiVersion)
     val request = buildRequest(metadataRequest)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
 
     verifyTopicCreation(topicName, enableAutoTopicCreation, isInternal, request)
 
@@ -1022,9 +1000,7 @@ class KafkaApisTest {
     createKafkaApis(authorizer = Some(authorizer), enableForwarding = enableAutoTopicCreation,
       overrideProperties = topicConfigOverride).handleTopicMetadataRequest(request)
 
-    val response = readResponse(metadataRequest, capturedResponse)
-      .asInstanceOf[MetadataResponse]
-
+    val response = capturedResponse.getValue.asInstanceOf[MetadataResponse]
     val expectedMetadataResponse = util.Collections.singletonList(new TopicMetadata(
       expectedError,
       topicName,
@@ -1088,12 +1064,11 @@ class KafkaApisTest {
           ))).build()
 
       val request = buildRequest(offsetCommitRequest)
-      val capturedResponse = expectNoThrottling()
+      val capturedResponse = expectNoThrottling(request)
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
       createKafkaApis().handleOffsetCommitRequest(request)
 
-      val response = readResponse(offsetCommitRequest, capturedResponse)
-        .asInstanceOf[OffsetCommitResponse]
+      val response = capturedResponse.getValue.asInstanceOf[OffsetCommitResponse]
       assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION,
         Errors.forCode(response.data().topics().get(0).partitions().get(0).errorCode()))
     }
@@ -1122,12 +1097,11 @@ class KafkaApisTest {
       ).build()
       val request = buildRequest(offsetCommitRequest)
 
-      val capturedResponse = expectNoThrottling()
+      val capturedResponse = expectNoThrottling(request)
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
       createKafkaApis().handleTxnOffsetCommitRequest(request)
 
-      val response = readResponse(offsetCommitRequest, capturedResponse)
-        .asInstanceOf[TxnOffsetCommitResponse]
+      val response = capturedResponse.getValue.asInstanceOf[TxnOffsetCommitResponse]
       assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors().get(invalidTopicPartition))
     }
 
@@ -1144,7 +1118,7 @@ class KafkaApisTest {
       EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, groupCoordinator)
 
       val topicPartition = new TopicPartition(topic, 1)
-      val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+      val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
       val responseCallback: Capture[Map[TopicPartition, Errors] => Unit] = EasyMock.newCapture()
 
       val partitionOffsetCommitData = new TxnOffsetCommitRequest.CommittedOffset(15L, "", Optional.empty())
@@ -1175,14 +1149,17 @@ class KafkaApisTest {
       )).andAnswer(
         () => responseCallback.getValue.apply(Map(topicPartition -> Errors.COORDINATOR_LOAD_IN_PROGRESS)))
 
-      EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+      EasyMock.expect(requestChannel.sendResponse(
+        EasyMock.eq(request),
+        EasyMock.capture(capturedResponse),
+        EasyMock.eq(None)
+      ))
 
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, groupCoordinator)
 
       createKafkaApis().handleTxnOffsetCommitRequest(request)
 
-      val response = readResponse(offsetCommitRequest, capturedResponse)
-        .asInstanceOf[TxnOffsetCommitResponse]
+      val response = capturedResponse.getValue.asInstanceOf[TxnOffsetCommitResponse]
 
       if (version < 2) {
         assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, response.errors().get(topicPartition))
@@ -1201,7 +1178,7 @@ class KafkaApisTest {
 
       EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
 
-      val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+      val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
       val responseCallback: Capture[InitProducerIdResult => Unit] = EasyMock.newCapture()
 
       val transactionalId = "txnId"
@@ -1240,14 +1217,17 @@ class KafkaApisTest {
       )).andAnswer(
         () => responseCallback.getValue.apply(InitProducerIdResult(producerId, epoch, Errors.PRODUCER_FENCED)))
 
-      EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+      EasyMock.expect(requestChannel.sendResponse(
+        EasyMock.eq(request),
+        EasyMock.capture(capturedResponse),
+        EasyMock.eq(None)
+      ))
 
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
 
       createKafkaApis().handleInitProducerIdRequest(request)
 
-      val response = readResponse(initProducerIdRequest, capturedResponse)
-        .asInstanceOf[InitProducerIdResponse]
+      val response = capturedResponse.getValue.asInstanceOf[InitProducerIdResponse]
 
       if (version < 4) {
         assertEquals(Errors.INVALID_PRODUCER_EPOCH.code, response.data.errorCode)
@@ -1266,7 +1246,7 @@ class KafkaApisTest {
 
       EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, groupCoordinator, txnCoordinator)
 
-      val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+      val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
       val responseCallback: Capture[Errors => Unit] = EasyMock.newCapture()
 
       val groupId = "groupId"
@@ -1297,14 +1277,17 @@ class KafkaApisTest {
       )).andAnswer(
         () => responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
 
-      EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+      EasyMock.expect(requestChannel.sendResponse(
+        EasyMock.eq(request),
+        EasyMock.capture(capturedResponse),
+        EasyMock.eq(None)
+      ))
 
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, groupCoordinator)
 
       createKafkaApis().handleAddOffsetsToTxnRequest(request)
 
-      val response = readResponse(addOffsetsToTxnRequest, capturedResponse)
-        .asInstanceOf[AddOffsetsToTxnResponse]
+      val response = capturedResponse.getValue.asInstanceOf[AddOffsetsToTxnResponse]
 
       if (version < 2) {
         assertEquals(Errors.INVALID_PRODUCER_EPOCH.code, response.data.errorCode)
@@ -1323,7 +1306,7 @@ class KafkaApisTest {
 
       EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
 
-      val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+      val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
       val responseCallback: Capture[Errors => Unit] = EasyMock.newCapture()
 
       val transactionalId = "txnId"
@@ -1351,14 +1334,17 @@ class KafkaApisTest {
       )).andAnswer(
         () => responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
 
-      EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+      EasyMock.expect(requestChannel.sendResponse(
+        EasyMock.eq(request),
+        EasyMock.capture(capturedResponse),
+        EasyMock.eq(None)
+      ))
 
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
 
       createKafkaApis().handleAddPartitionToTxnRequest(request)
 
-      val response = readResponse(addPartitionsToTxnRequest, capturedResponse)
-        .asInstanceOf[AddPartitionsToTxnResponse]
+      val response = capturedResponse.getValue.asInstanceOf[AddPartitionsToTxnResponse]
 
       if (version < 2) {
         assertEquals(Collections.singletonMap(topicPartition, Errors.INVALID_PRODUCER_EPOCH), response.errors())
@@ -1374,10 +1360,9 @@ class KafkaApisTest {
     addTopicToMetadataCache(topic, numPartitions = 2)
 
     for (version <- ApiKeys.END_TXN.oldestVersion to ApiKeys.END_TXN.latestVersion) {
-
       EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
 
-      val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+      val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
       val responseCallback: Capture[Errors => Unit]  = EasyMock.newCapture()
 
       val transactionalId = "txnId"
@@ -1402,14 +1387,16 @@ class KafkaApisTest {
       )).andAnswer(
         () => responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
 
-      EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+      EasyMock.expect(requestChannel.sendResponse(
+        EasyMock.eq(request),
+        EasyMock.capture(capturedResponse),
+        EasyMock.eq(None)
+      ))
 
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator)
-
       createKafkaApis().handleEndTxnRequest(request)
 
-      val response = readResponse(endTxnRequest, capturedResponse)
-        .asInstanceOf[EndTxnResponse]
+      val response = capturedResponse.getValue.asInstanceOf[EndTxnResponse]
 
       if (version < 2) {
         assertEquals(Errors.INVALID_PRODUCER_EPOCH.code, response.data.errorCode)
@@ -1456,7 +1443,7 @@ class KafkaApisTest {
         EasyMock.anyObject())
       ).andAnswer(() => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.INVALID_PRODUCER_EPOCH))))
 
-      val capturedResponse = expectNoThrottling()
+      val capturedResponse = expectNoThrottling(request)
       EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
         anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0)
 
@@ -1464,8 +1451,7 @@ class KafkaApisTest {
 
       createKafkaApis().handleProduceRequest(request)
 
-      val response = readResponse(produceRequest, capturedResponse)
-        .asInstanceOf[ProduceResponse]
+      val response = capturedResponse.getValue.asInstanceOf[ProduceResponse]
 
       assertEquals(1, response.responses().size())
       for (partitionResponse <- response.responses().asScala) {
@@ -1488,12 +1474,11 @@ class KafkaApisTest {
       ).build()
       val request = buildRequest(addPartitionsToTxnRequest)
 
-      val capturedResponse = expectNoThrottling()
+      val capturedResponse = expectNoThrottling(request)
       EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
       createKafkaApis().handleAddPartitionToTxnRequest(request)
 
-      val response = readResponse(addPartitionsToTxnRequest, capturedResponse)
-        .asInstanceOf[AddPartitionsToTxnResponse]
+      val response = capturedResponse.getValue.asInstanceOf[AddPartitionsToTxnResponse]
       assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors().get(invalidTopicPartition))
     }
 
@@ -1531,17 +1516,20 @@ class KafkaApisTest {
     val topicPartition = new TopicPartition("t", 0)
     val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(topicPartition))
     val expectedErrors = Map(topicPartition -> Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT).asJava
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
 
     EasyMock.expect(replicaManager.getMagic(topicPartition))
       .andReturn(Some(RecordBatch.MAGIC_VALUE_V1))
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
     EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel)
 
     createKafkaApis().handleWriteTxnMarkersRequest(request)
 
-    val markersResponse = readResponse(writeTxnMarkersRequest, capturedResponse)
-      .asInstanceOf[WriteTxnMarkersResponse]
+    val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse]
     assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L))
   }
 
@@ -1550,17 +1538,20 @@ class KafkaApisTest {
     val topicPartition = new TopicPartition("t", 0)
     val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(topicPartition))
     val expectedErrors = Map(topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION).asJava
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
 
     EasyMock.expect(replicaManager.getMagic(topicPartition))
       .andReturn(None)
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
     EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel)
 
     createKafkaApis().handleWriteTxnMarkersRequest(request)
 
-    val markersResponse = readResponse(writeTxnMarkersRequest, capturedResponse)
-      .asInstanceOf[WriteTxnMarkersResponse]
+    val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse]
     assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L))
   }
 
@@ -1571,7 +1562,7 @@ class KafkaApisTest {
     val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(tp1, tp2))
     val expectedErrors = Map(tp1 -> Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT, tp2 -> Errors.NONE).asJava
 
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
     val responseCallback: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture()
 
     EasyMock.expect(replicaManager.getMagic(tp1))
@@ -1589,13 +1580,16 @@ class KafkaApisTest {
       EasyMock.anyObject())
     ).andAnswer(() => responseCallback.getValue.apply(Map(tp2 -> new PartitionResponse(Errors.NONE))))
 
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
     EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel)
 
     createKafkaApis().handleWriteTxnMarkersRequest(request)
 
-    val markersResponse = readResponse(writeTxnMarkersRequest, capturedResponse)
-      .asInstanceOf[WriteTxnMarkersResponse]
+    val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse]
     assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L))
     EasyMock.verify(replicaManager)
   }
@@ -1708,7 +1702,7 @@ class KafkaApisTest {
     val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(tp1, tp2))
     val expectedErrors = Map(tp1 -> Errors.UNKNOWN_TOPIC_OR_PARTITION, tp2 -> Errors.NONE).asJava
 
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
     val responseCallback: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture()
 
     EasyMock.expect(replicaManager.getMagic(tp1))
@@ -1726,13 +1720,16 @@ class KafkaApisTest {
       EasyMock.anyObject())
     ).andAnswer(() => responseCallback.getValue.apply(Map(tp2 -> new PartitionResponse(Errors.NONE))))
 
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
     EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel)
 
     createKafkaApis().handleWriteTxnMarkersRequest(request)
 
-    val markersResponse = readResponse(writeTxnMarkersRequest, capturedResponse)
-      .asInstanceOf[WriteTxnMarkersResponse]
+    val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse]
     assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L))
     EasyMock.verify(replicaManager)
   }
@@ -1798,15 +1795,14 @@ class KafkaApisTest {
     ).build()
     val request = buildRequest(describeGroupsRequest)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
     EasyMock.expect(groupCoordinator.handleDescribeGroup(EasyMock.eq(groupId)))
       .andReturn((Errors.NONE, groupSummary))
     EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel)
 
     createKafkaApis().handleDescribeGroupRequest(request)
 
-    val response = readResponse(describeGroupsRequest, capturedResponse)
-      .asInstanceOf[DescribeGroupsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[DescribeGroupsResponse]
 
     val group = response.data().groups().get(0)
     assertEquals(Errors.NONE, Errors.forCode(group.errorCode()))
@@ -1852,7 +1848,7 @@ class KafkaApisTest {
     ).build()
     val request = buildRequest(offsetDeleteRequest)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
     EasyMock.expect(groupCoordinator.handleDeleteOffsets(
       EasyMock.eq(group),
       EasyMock.eq(Seq(
@@ -1872,8 +1868,7 @@ class KafkaApisTest {
 
     createKafkaApis().handleOffsetDeleteRequest(request)
 
-    val response = readResponse(offsetDeleteRequest, capturedResponse)
-      .asInstanceOf[OffsetDeleteResponse]
+    val response = capturedResponse.getValue.asInstanceOf[OffsetDeleteResponse]
 
     def errorForPartition(topic: String, partition: Int): Errors = {
       Errors.forCode(response.data.topics.find(topic).partitions.find(partition).errorCode())
@@ -1906,16 +1901,15 @@ class KafkaApisTest {
           .setTopics(topics)
       ).build()
       val request = buildRequest(offsetDeleteRequest)
+      val capturedResponse = expectNoThrottling(request)
 
-      val capturedResponse = expectNoThrottling()
       EasyMock.expect(groupCoordinator.handleDeleteOffsets(EasyMock.eq(group), EasyMock.eq(Seq.empty)))
         .andReturn((Errors.NONE, Map.empty))
       EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel)
 
       createKafkaApis().handleOffsetDeleteRequest(request)
 
-      val response = readResponse(offsetDeleteRequest, capturedResponse)
-        .asInstanceOf[OffsetDeleteResponse]
+      val response = capturedResponse.getValue.asInstanceOf[OffsetDeleteResponse]
 
       assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION,
         Errors.forCode(response.data.topics.find(topic).partitions.find(invalidPartitionId).errorCode()))
@@ -1937,15 +1931,14 @@ class KafkaApisTest {
     ).build()
     val request = buildRequest(offsetDeleteRequest)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
     EasyMock.expect(groupCoordinator.handleDeleteOffsets(EasyMock.eq(group), EasyMock.eq(Seq.empty)))
       .andReturn((Errors.GROUP_ID_NOT_FOUND, Map.empty))
     EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel)
 
     createKafkaApis().handleOffsetDeleteRequest(request)
 
-    val response = readResponse(offsetDeleteRequest, capturedResponse)
-      .asInstanceOf[OffsetDeleteResponse]
+    val response = capturedResponse.getValue.asInstanceOf[OffsetDeleteResponse]
 
     assertEquals(Errors.GROUP_ID_NOT_FOUND, Errors.forCode(response.data.errorCode()))
   }
@@ -1963,9 +1956,6 @@ class KafkaApisTest {
       fetchOnlyFromLeader = EasyMock.eq(true))
     ).andThrow(error.exception)
 
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
-
     val targetTimes = List(new ListOffsetsTopic()
       .setName(tp.topic)
       .setPartitions(List(new ListOffsetsPartition()
@@ -1975,10 +1965,12 @@ class KafkaApisTest {
     val listOffsetRequest = ListOffsetsRequest.Builder.forConsumer(true, isolationLevel)
       .setTargetTimes(targetTimes).build()
     val request = buildRequest(listOffsetRequest)
+    val capturedResponse = expectNoThrottling(request)
+
+    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
     createKafkaApis().handleListOffsetRequest(request)
 
-    val response = readResponse(listOffsetRequest, capturedResponse)
-      .asInstanceOf[ListOffsetsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[ListOffsetsResponse]
     val partitionDataOptional = response.topics.asScala.find(_.name == tp.topic).get
       .partitions.asScala.find(_.partitionIndex == tp.partition)
     assertTrue(partitionDataOptional.isDefined)
@@ -2125,19 +2117,18 @@ class KafkaApisTest {
       anyObject[util.List[TopicPartition]],
       anyBoolean)).andReturn(fetchContext)
 
-    val capturedResponse = expectNoThrottling()
     EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0)
 
-    EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, fetchManager)
-
     val fetchRequest = new FetchRequest.Builder(9, 9, -1, 100, 0, fetchData)
       .build()
     val request = buildRequest(fetchRequest)
+    val capturedResponse = expectNoThrottling(request)
+
+    EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, fetchManager)
     createKafkaApis().handleFetchRequest(request)
 
-    val response = readResponse(fetchRequest, capturedResponse)
-      .asInstanceOf[FetchResponse[BaseRecords]]
+    val response = capturedResponse.getValue.asInstanceOf[FetchResponse[BaseRecords]]
     assertTrue(response.responseData.containsKey(tp))
 
     val partitionData = response.responseData.get(tp)
@@ -2216,8 +2207,6 @@ class KafkaApisTest {
   def testJoinGroupWhenAnErrorOccurs(version: Short): Unit = {
     EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
 
-    val capturedResponse = expectNoThrottling()
-
     val groupId = "group"
     val memberId = "member1"
     val protocolType = "consumer"
@@ -2250,17 +2239,16 @@ class KafkaApisTest {
     ).build(version)
 
     val requestChannelRequest = buildRequest(joinGroupRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
 
     EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
-
     createKafkaApis().handleJoinGroupRequest(requestChannelRequest)
 
     EasyMock.verify(groupCoordinator)
 
     capturedCallback.getValue.apply(JoinGroupResult(memberId, Errors.INCONSISTENT_GROUP_PROTOCOL))
 
-    val response = readResponse(joinGroupRequest, capturedResponse)
-      .asInstanceOf[JoinGroupResponse]
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
 
     assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error)
     assertEquals(0, response.data.members.size)
@@ -2288,8 +2276,6 @@ class KafkaApisTest {
   def testJoinGroupProtocolType(version: Short): Unit = {
     EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
 
-    val capturedResponse = expectNoThrottling()
-
     val groupId = "group"
     val memberId = "member1"
     val protocolType = "consumer"
@@ -2323,9 +2309,9 @@ class KafkaApisTest {
     ).build(version)
 
     val requestChannelRequest = buildRequest(joinGroupRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
 
     EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
-
     createKafkaApis().handleJoinGroupRequest(requestChannelRequest)
 
     EasyMock.verify(groupCoordinator)
@@ -2340,8 +2326,7 @@ class KafkaApisTest {
       error = Errors.NONE
     ))
 
-    val response = readResponse(joinGroupRequest, capturedResponse)
-      .asInstanceOf[JoinGroupResponse]
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
 
     assertEquals(Errors.NONE, response.error)
     assertEquals(0, response.data.members.size)
@@ -2349,12 +2334,7 @@ class KafkaApisTest {
     assertEquals(0, response.data.generationId)
     assertEquals(memberId, response.data.leader)
     assertEquals(protocolName, response.data.protocolName)
-
-    if (version >= 7) {
-      assertEquals(protocolType, response.data.protocolType)
-    } else {
-      assertNull(response.data.protocolType)
-    }
+    assertEquals(protocolType, response.data.protocolType)
 
     EasyMock.verify(clientRequestQuotaManager, requestChannel)
   }
@@ -2369,8 +2349,6 @@ class KafkaApisTest {
   def testSyncGroupProtocolTypeAndName(version: Short): Unit = {
     EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
 
-    val capturedResponse = expectNoThrottling()
-
     val groupId = "group"
     val memberId = "member1"
     val protocolType = "consumer"
@@ -2399,9 +2377,9 @@ class KafkaApisTest {
     ).build(version)
 
     val requestChannelRequest = buildRequest(syncGroupRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
 
     EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
-
     createKafkaApis().handleSyncGroupRequest(requestChannelRequest)
 
     EasyMock.verify(groupCoordinator)
@@ -2413,17 +2391,11 @@ class KafkaApisTest {
       error = Errors.NONE
     ))
 
-    val response = readResponse(syncGroupRequest, capturedResponse)
-      .asInstanceOf[SyncGroupResponse]
+    val response = capturedResponse.getValue.asInstanceOf[SyncGroupResponse]
 
     assertEquals(Errors.NONE, response.error)
     assertArrayEquals(Array.empty[Byte], response.data.assignment)
-
-    if (version >= 5) {
-      assertEquals(protocolType, response.data.protocolType)
-    } else {
-      assertNull(response.data.protocolType)
-    }
+    assertEquals(protocolType, response.data.protocolType)
 
     EasyMock.verify(clientRequestQuotaManager, requestChannel)
   }
@@ -2438,8 +2410,6 @@ class KafkaApisTest {
   def testSyncGroupProtocolTypeAndNameAreMandatorySinceV5(version: Short): Unit = {
     EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
 
-    val capturedResponse = expectNoThrottling()
-
     val groupId = "group"
     val memberId = "member1"
     val protocolType = "consumer"
@@ -2468,9 +2438,9 @@ class KafkaApisTest {
     ).build(version)
 
     val requestChannelRequest = buildRequest(syncGroupRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
 
     EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager)
-
     createKafkaApis().handleSyncGroupRequest(requestChannelRequest)
 
     EasyMock.verify(groupCoordinator)
@@ -2484,8 +2454,7 @@ class KafkaApisTest {
       ))
     }
 
-    val response = readResponse(syncGroupRequest, capturedResponse)
-      .asInstanceOf[SyncGroupResponse]
+    val response = capturedResponse.getValue.asInstanceOf[SyncGroupResponse]
 
     if (version < 5) {
       assertEquals(Errors.NONE, response.error)
@@ -2498,9 +2467,6 @@ class KafkaApisTest {
 
   @Test
   def rejectJoinGroupRequestWhenStaticMembershipNotSupported(): Unit = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val joinGroupRequest = new JoinGroupRequest.Builder(
       new JoinGroupRequestData()
         .setGroupId("test")
@@ -2511,18 +2477,18 @@ class KafkaApisTest {
     ).build()
 
     val requestChannelRequest = buildRequest(joinGroupRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis(KAFKA_2_2_IV1).handleJoinGroupRequest(requestChannelRequest)
 
-    val response = readResponse(joinGroupRequest, capturedResponse).asInstanceOf[JoinGroupResponse]
+    val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse]
     assertEquals(Errors.UNSUPPORTED_VERSION, response.error())
     EasyMock.replay(groupCoordinator)
   }
 
   @Test
   def rejectSyncGroupRequestWhenStaticMembershipNotSupported(): Unit = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val syncGroupRequest = new SyncGroupRequest.Builder(
       new SyncGroupRequestData()
         .setGroupId("test")
@@ -2532,18 +2498,18 @@ class KafkaApisTest {
     ).build()
 
     val requestChannelRequest = buildRequest(syncGroupRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis(KAFKA_2_2_IV1).handleSyncGroupRequest(requestChannelRequest)
 
-    val response = readResponse(syncGroupRequest, capturedResponse).asInstanceOf[SyncGroupResponse]
+    val response = capturedResponse.getValue.asInstanceOf[SyncGroupResponse]
     assertEquals(Errors.UNSUPPORTED_VERSION, response.error)
     EasyMock.replay(groupCoordinator)
   }
 
   @Test
   def rejectHeartbeatRequestWhenStaticMembershipNotSupported(): Unit = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val heartbeatRequest = new HeartbeatRequest.Builder(
       new HeartbeatRequestData()
         .setGroupId("test")
@@ -2552,18 +2518,18 @@ class KafkaApisTest {
         .setGenerationId(1)
     ).build()
     val requestChannelRequest = buildRequest(heartbeatRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis(KAFKA_2_2_IV1).handleHeartbeatRequest(requestChannelRequest)
 
-    val response = readResponse(heartbeatRequest, capturedResponse).asInstanceOf[HeartbeatResponse]
+    val response = capturedResponse.getValue.asInstanceOf[HeartbeatResponse]
     assertEquals(Errors.UNSUPPORTED_VERSION, response.error())
     EasyMock.replay(groupCoordinator)
   }
 
   @Test
   def rejectOffsetCommitRequestWhenStaticMembershipNotSupported(): Unit = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val offsetCommitRequest = new OffsetCommitRequest.Builder(
       new OffsetCommitRequestData()
         .setGroupId("test")
@@ -2584,6 +2550,9 @@ class KafkaApisTest {
     ).build()
 
     val requestChannelRequest = buildRequest(offsetCommitRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis(KAFKA_2_2_IV1).handleOffsetCommitRequest(requestChannelRequest)
 
     val expectedTopicErrors = Collections.singletonList(
@@ -2595,7 +2564,7 @@ class KafkaApisTest {
             .setErrorCode(Errors.UNSUPPORTED_VERSION.code())
         ))
     )
-    val response = readResponse(offsetCommitRequest, capturedResponse).asInstanceOf[OffsetCommitResponse]
+    val response = capturedResponse.getValue.asInstanceOf[OffsetCommitResponse]
     assertEquals(expectedTopicErrors, response.data.topics())
     EasyMock.replay(groupCoordinator)
   }
@@ -2723,9 +2692,6 @@ class KafkaApisTest {
 
   @Test
   def rejectInitProducerIdWhenIdButNotEpochProvided(): Unit = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val initProducerIdRequest = new InitProducerIdRequest.Builder(
       new InitProducerIdRequestData()
         .setTransactionalId("known")
@@ -2735,18 +2701,17 @@ class KafkaApisTest {
     ).build()
 
     val requestChannelRequest = buildRequest(initProducerIdRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis(KAFKA_2_2_IV1).handleInitProducerIdRequest(requestChannelRequest)
 
-    val response = readResponse(initProducerIdRequest, capturedResponse)
-      .asInstanceOf[InitProducerIdResponse]
+    val response = capturedResponse.getValue.asInstanceOf[InitProducerIdResponse]
     assertEquals(Errors.INVALID_REQUEST, response.error)
   }
 
   @Test
   def rejectInitProducerIdWhenEpochButNotIdProvided(): Unit = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val initProducerIdRequest = new InitProducerIdRequest.Builder(
       new InitProducerIdRequestData()
         .setTransactionalId("known")
@@ -2755,9 +2720,12 @@ class KafkaApisTest {
         .setProducerEpoch(2)
     ).build()
     val requestChannelRequest = buildRequest(initProducerIdRequest)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis(KAFKA_2_2_IV1).handleInitProducerIdRequest(requestChannelRequest)
 
-    val response = readResponse(initProducerIdRequest, capturedResponse).asInstanceOf[InitProducerIdResponse]
+    val response = capturedResponse.getValue.asInstanceOf[InitProducerIdResponse]
     assertEquals(Errors.INVALID_REQUEST, response.error)
   }
 
@@ -2783,7 +2751,7 @@ class KafkaApisTest {
     val updateMetadataRequest = createBasicMetadataRequest("topicA", 1, brokerEpochInRequest, 1)
     val request = buildRequest(updateMetadataRequest)
 
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
 
     EasyMock.expect(controller.brokerEpoch).andStubReturn(currentBrokerEpoch)
     EasyMock.expect(replicaManager.maybeUpdateMetadataCache(
@@ -2793,12 +2761,15 @@ class KafkaApisTest {
       Seq()
     )
 
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
     EasyMock.replay(replicaManager, controller, requestChannel)
 
     createKafkaApis().handleUpdateMetadataRequest(request)
-    val updateMetadataResponse = readResponse(updateMetadataRequest, capturedResponse)
-      .asInstanceOf[UpdateMetadataResponse]
+    val updateMetadataResponse = capturedResponse.getValue.asInstanceOf[UpdateMetadataResponse]
     assertEquals(expectedError, updateMetadataResponse.error())
     EasyMock.verify(replicaManager)
   }
@@ -2824,7 +2795,7 @@ class KafkaApisTest {
   def testLeaderAndIsrRequest(currentBrokerEpoch: Long, brokerEpochInRequest: Long, expectedError: Errors): Unit = {
     val controllerId = 2
     val controllerEpoch = 6
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
     val partitionStates = Seq(
       new LeaderAndIsrRequestData.LeaderAndIsrPartitionState()
         .setTopicName("topicW")
@@ -2860,12 +2831,15 @@ class KafkaApisTest {
       response
     )
 
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
     EasyMock.replay(replicaManager, controller, requestChannel)
 
     createKafkaApis().handleLeaderAndIsrRequest(request)
-    val leaderAndIsrResponse = readResponse(leaderAndIsrRequest, capturedResponse)
-      .asInstanceOf[LeaderAndIsrResponse]
+    val leaderAndIsrResponse = capturedResponse.getValue.asInstanceOf[LeaderAndIsrResponse]
     assertEquals(expectedError, leaderAndIsrResponse.error())
     EasyMock.verify(replicaManager)
   }
@@ -2891,7 +2865,7 @@ class KafkaApisTest {
   def testStopReplicaRequest(currentBrokerEpoch: Long, brokerEpochInRequest: Long, expectedError: Errors): Unit = {
     val controllerId = 0
     val controllerEpoch = 5
-    val capturedResponse: Capture[RequestChannel.Response] = EasyMock.newCapture()
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
     val fooPartition = new TopicPartition("foo", 0)
     val topicStates = Seq(
       new StopReplicaTopicState()
@@ -2923,13 +2897,16 @@ class KafkaApisTest {
         fooPartition -> Errors.NONE
       ), Errors.NONE)
     )
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.eq(None)
+    ))
 
     EasyMock.replay(controller, replicaManager, requestChannel)
 
     createKafkaApis().handleStopReplicaRequest(request)
-    val stopReplicaResponse = readResponse(stopReplicaRequest, capturedResponse)
-      .asInstanceOf[StopReplicaResponse]
+    val stopReplicaResponse = capturedResponse.getValue.asInstanceOf[StopReplicaResponse]
     assertEquals(expectedError, stopReplicaResponse.error())
     EasyMock.verify(replicaManager)
   }
@@ -2965,7 +2942,7 @@ class KafkaApisTest {
     val listGroupsRequest = new ListGroupsRequest.Builder(data).build()
     val requestChannelRequest = buildRequest(listGroupsRequest)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
     val expectedStates: Set[String] = if (state.isDefined) Set(state.get) else Set()
     EasyMock.expect(groupCoordinator.handleListGroups(expectedStates))
       .andReturn((Errors.NONE, overviews))
@@ -2973,7 +2950,7 @@ class KafkaApisTest {
 
     createKafkaApis().handleListGroupsRequest(requestChannelRequest)
 
-    val response = readResponse(listGroupsRequest, capturedResponse).asInstanceOf[ListGroupsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[ListGroupsResponse]
     assertEquals(Errors.NONE.code, response.data.errorCode)
     response
   }
@@ -3006,17 +2983,16 @@ class KafkaApisTest {
       0, 0, Seq.empty[UpdateMetadataPartitionState].asJava, brokers.asJava, Collections.emptyMap()).build()
     metadataCache.updateMetadata(correlationId = 0, updateMetadataRequest)
 
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val describeClusterRequest = new DescribeClusterRequest.Builder(new DescribeClusterRequestData()
       .setIncludeClusterAuthorizedOperations(true)).build()
 
     val request = buildRequest(describeClusterRequest, plaintextListener)
+    val capturedResponse = expectNoThrottling(request)
+
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
     createKafkaApis().handleDescribeCluster(request)
 
-    val describeClusterResponse = readResponse(describeClusterRequest, capturedResponse)
-      .asInstanceOf[DescribeClusterResponse]
+    val describeClusterResponse = capturedResponse.getValue.asInstanceOf[DescribeClusterResponse]
 
     assertEquals(metadataCache.getControllerId.get, describeClusterResponse.data.controllerId)
     assertEquals(clusterId, describeClusterResponse.data.clusterId)
@@ -3064,14 +3040,14 @@ class KafkaApisTest {
   }
 
   private def sendMetadataRequestWithInconsistentListeners(requestListener: ListenerName): MetadataResponse = {
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(clientRequestQuotaManager, requestChannel)
-
     val metadataRequest = MetadataRequest.Builder.allTopics.build()
     val requestChannelRequest = buildRequest(metadataRequest, requestListener)
+    val capturedResponse = expectNoThrottling(requestChannelRequest)
+    EasyMock.replay(clientRequestQuotaManager, requestChannel)
+
     createKafkaApis().handleTopicMetadataRequest(requestChannelRequest)
 
-    readResponse(metadataRequest, capturedResponse).asInstanceOf[MetadataResponse]
+    capturedResponse.getValue.asInstanceOf[MetadataResponse]
   }
 
   private def testConsumerListOffsetLatest(isolationLevel: IsolationLevel): Unit = {
@@ -3087,9 +3063,6 @@ class KafkaApisTest {
       fetchOnlyFromLeader = EasyMock.eq(true))
     ).andReturn(Some(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, latestOffset, currentLeaderEpoch)))
 
-    val capturedResponse = expectNoThrottling()
-    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
-
     val targetTimes = List(new ListOffsetsTopic()
       .setName(tp.topic)
       .setPartitions(List(new ListOffsetsPartition()
@@ -3098,9 +3071,12 @@ class KafkaApisTest {
     val listOffsetRequest = ListOffsetsRequest.Builder.forConsumer(true, isolationLevel)
       .setTargetTimes(targetTimes).build()
     val request = buildRequest(listOffsetRequest)
+    val capturedResponse = expectNoThrottling(request)
+
+    EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
     createKafkaApis().handleListOffsetRequest(request)
 
-    val response = readResponse(listOffsetRequest, capturedResponse).asInstanceOf[ListOffsetsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[ListOffsetsResponse]
     val partitionDataOptional = response.topics.asScala.find(_.name == tp.topic).get
       .partitions.asScala.find(_.partitionIndex == tp.partition)
     assertTrue(partitionDataOptional.isDefined)
@@ -3164,28 +3140,22 @@ class KafkaApisTest {
       requestChannelMetrics, envelope = None)
   }
 
-  private def readResponse(request: AbstractRequest, capturedResponse: Capture[RequestChannel.Response]) = {
-    val api = request.apiKey
-    val response = capturedResponse.getValue
-    assertTrue(response.isInstanceOf[SendResponse], s"Unexpected response type: ${response.getClass}")
-    val sendResponse = response.asInstanceOf[SendResponse]
-    val send = sendResponse.responseSend
-    val channel = new ByteBufferChannel(send.size)
-    send.writeTo(channel)
-    channel.close()
-    channel.buffer.getInt() // read the size
-    ResponseHeader.parse(channel.buffer, api.responseHeaderVersion(request.version))
-    AbstractResponse.parseResponse(api, channel.buffer, request.version)
-  }
-
-  private def expectNoThrottling(): Capture[RequestChannel.Response] = {
+  private def expectNoThrottling(request: RequestChannel.Request): Capture[AbstractResponse] = {
     EasyMock.expect(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(EasyMock.anyObject[RequestChannel.Request](),
       EasyMock.anyObject[Long])).andReturn(0)
-    EasyMock.expect(clientRequestQuotaManager.throttle(EasyMock.anyObject[RequestChannel.Request](), EasyMock.eq(0),
-      EasyMock.anyObject[RequestChannel.Response => Unit]()))
 
-    val capturedResponse = EasyMock.newCapture[RequestChannel.Response]()
-    EasyMock.expect(requestChannel.sendResponse(EasyMock.capture(capturedResponse)))
+    EasyMock.expect(clientRequestQuotaManager.throttle(
+      EasyMock.eq(request),
+      EasyMock.anyObject[ThrottleCallback](),
+      EasyMock.eq(0)))
+
+    val capturedResponse = EasyMock.newCapture[AbstractResponse]()
+    EasyMock.expect(requestChannel.sendResponse(
+      EasyMock.eq(request),
+      EasyMock.capture(capturedResponse),
+      EasyMock.anyObject()
+    ))
+
     capturedResponse
   }
 
@@ -3244,7 +3214,7 @@ class KafkaApisTest {
 
     EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel)
 
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
     val t0p0 = new TopicPartition("t0", 0)
     val t0p1 = new TopicPartition("t0", 1)
     val t0p2 = new TopicPartition("t0", 2)
@@ -3261,8 +3231,7 @@ class KafkaApisTest {
 
     createKafkaApis().handleAlterReplicaLogDirsRequest(request)
 
-    val response = readResponse(alterReplicaLogDirsRequest, capturedResponse)
-      .asInstanceOf[AlterReplicaLogDirsResponse]
+    val response = capturedResponse.getValue.asInstanceOf[AlterReplicaLogDirsResponse]
     assertEquals(partitionResults, response.data.results.asScala.flatMap { tr =>
       tr.partitions().asScala.map { pr =>
         new TopicPartition(tr.topicName, pr.partitionIndex) -> Errors.forCode(pr.errorCode)
@@ -3358,13 +3327,12 @@ class KafkaApisTest {
 
     val describeProducersRequest = new DescribeProducersRequest.Builder(data).build()
     val request = buildRequest(describeProducersRequest)
-    val capturedResponse = expectNoThrottling()
+    val capturedResponse = expectNoThrottling(request)
 
     EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, authorizer)
     createKafkaApis(authorizer = Some(authorizer)).handleDescribeProducersRequest(request)
 
-    val response = readResponse(describeProducersRequest, capturedResponse)
-      .asInstanceOf[DescribeProducersResponse]
+    val response = capturedResponse.getValue.asInstanceOf[DescribeProducersResponse]
     assertEquals(3, response.data.topics.size())
     assertEquals(Set("foo", "bar", "baz"), response.data.topics.asScala.map(_.name).toSet)
 
diff --git a/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala b/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala
index d7dec5b..15ad22d 100644
--- a/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala
@@ -18,22 +18,11 @@
 package kafka.server
 
 
-import java.net.InetAddress
-import java.util
 import java.util.Collections
 import java.util.concurrent.{DelayQueue, TimeUnit}
-import kafka.network.RequestChannel
-import kafka.network.RequestChannel.{EndThrottlingResponse, Response, StartThrottlingResponse}
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.memory.MemoryPool
+
 import org.apache.kafka.common.metrics.MetricConfig
-import org.apache.kafka.common.network.ClientInformation
-import org.apache.kafka.common.network.ListenerName
-import org.apache.kafka.common.requests.FetchRequest.PartitionData
-import org.apache.kafka.common.requests.{AbstractRequest, FetchRequest, RequestContext, RequestHeader, RequestTestUtils}
-import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.common.utils.MockTime
-import org.easymock.EasyMock
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{BeforeEach, Test}
 
@@ -44,28 +33,13 @@ class ThrottledChannelExpirationTest {
   private val metrics = new org.apache.kafka.common.metrics.Metrics(new MetricConfig(),
                                                                     Collections.emptyList(),
                                                                     time)
-  private val request = buildRequest(FetchRequest.Builder.forConsumer(0, 1000, new util.HashMap[TopicPartition, PartitionData]))._2
-
-  private def buildRequest[T <: AbstractRequest](builder: AbstractRequest.Builder[T],
-                                                 listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)): (T, RequestChannel.Request) = {
-
-    val request = builder.build()
-    val buffer = RequestTestUtils.serializeRequestWithHeader(
-      new RequestHeader(builder.apiKey, request.version, "", 0), request)
-    val requestChannelMetrics: RequestChannel.Metrics = EasyMock.createNiceMock(classOf[RequestChannel.Metrics])
-
-    // read the header from the buffer first so that the body can be read next from the Request constructor
-    val header = RequestHeader.parse(buffer)
-    val context = new RequestContext(header, "1", InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS,
-      listenerName, SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false)
-    (request, new RequestChannel.Request(processor = 1, context = context, startTimeNanos =  0, MemoryPool.NONE, buffer,
-      requestChannelMetrics))
-  }
+  private val callback = new ThrottleCallback {
+    override def startThrottling(): Unit = {
+      numCallbacksForStartThrottling += 1
+    }
 
-  def callback(response: Response): Unit = {
-    (response: @unchecked) match {
-      case _: StartThrottlingResponse => numCallbacksForStartThrottling += 1
-      case _: EndThrottlingResponse => numCallbacksForEndThrottling += 1
+    override def endThrottling(): Unit = {
+      numCallbacksForEndThrottling += 1
     }
   }
 
@@ -83,10 +57,10 @@ class ThrottledChannelExpirationTest {
     val reaper = new clientMetrics.ThrottledChannelReaper(delayQueue, "")
     try {
       // Add 4 elements to the queue out of order. Add 2 elements with the same expire timestamp.
-      val channel1 = new ThrottledChannel(request, time, 10, callback)
-      val channel2 = new ThrottledChannel(request, time, 30, callback)
-      val channel3 = new ThrottledChannel(request, time, 30, callback)
-      val channel4 = new ThrottledChannel(request, time, 20, callback)
+      val channel1 = new ThrottledChannel(time, 10, callback)
+      val channel2 = new ThrottledChannel(time, 30, callback)
+      val channel3 = new ThrottledChannel(time, 30, callback)
+      val channel4 = new ThrottledChannel(time, 20, callback)
       delayQueue.add(channel1)
       delayQueue.add(channel2)
       delayQueue.add(channel3)
@@ -110,9 +84,9 @@ class ThrottledChannelExpirationTest {
 
   @Test
   def testThrottledChannelDelay(): Unit = {
-    val t1: ThrottledChannel = new ThrottledChannel(request, time, 10, callback)
-    val t2: ThrottledChannel = new ThrottledChannel(request, time, 20, callback)
-    val t3: ThrottledChannel = new ThrottledChannel(request, time, 20, callback)
+    val t1: ThrottledChannel = new ThrottledChannel(time, 10, callback)
+    val t2: ThrottledChannel = new ThrottledChannel(time, 20, callback)
+    val t3: ThrottledChannel = new ThrottledChannel(time, 20, callback)
     assertEquals(10, t1.throttleTimeMs)
     assertEquals(20, t2.throttleTimeMs)
     assertEquals(20, t3.throttleTimeMs)
@@ -124,4 +98,5 @@ class ThrottledChannelExpirationTest {
       time.sleep(10)
     }
   }
+
 }