You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jo...@apache.org on 2023/11/10 05:14:30 UTC

(kafka) branch 3.6 updated: cherrypick KAFKA-15653: Pass requestLocal as argument to callback so we use the correct one for the thread (#14712)

This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch 3.6
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.6 by this push:
     new 814de813eaa cherrypick KAFKA-15653: Pass requestLocal as argument to callback so we use the correct one for the thread  (#14712)
814de813eaa is described below

commit 814de813eaa35ff47f526983f3ec63db558864d5
Author: Justine Olshan <jo...@confluent.io>
AuthorDate: Thu Nov 9 21:14:25 2023 -0800

    cherrypick KAFKA-15653: Pass requestLocal as argument to callback so we use the correct one for the thread  (#14712)
    
    With the new callback mechanism, we were accidentally passing a context that carried the wrong RequestLocal. The callback now takes a RequestLocal as an explicit argument.
    
    Also made the arguments passed through the callback clearer by extracting appendEntries into its own method.
    
    Added a test to ensure we use the request handler's request local and not the one passed in when the callback is executed via the request handler.
    
    Reviewers: Ismael Juma <ismael@juma.me.uk>, Divij Vaidya <diviv@amazon.com>, David Jacot <djacot@confluent.io>, Jason Gustafson <jason@confluent.io>, Artem Livshits <alivshits@confluent.io>, Jun Rao <junrao@gmail.com>
    
    Conflicts:
    core/src/main/scala/kafka/server/KafkaRequestHandler.scala
    core/src/main/scala/kafka/server/ReplicaManager.scala
    core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
    
    The conflicts were around the verification guard, running the callback on the same thread, and checking the coordinator node before calling AddPartitionsToTxnManager. Removed a test that is not applicable because this branch does not include https://github.com/apache/kafka/commit/08aa33127a4254497456aa7a0c1646c7c38adf81
---
 .../main/scala/kafka/network/RequestChannel.scala  |   4 +-
 .../scala/kafka/server/KafkaRequestHandler.scala   |  21 +-
 .../main/scala/kafka/server/ReplicaManager.scala   | 245 ++++++++++++---------
 .../kafka/server/KafkaRequestHandlerTest.scala     |  64 +++++-
 4 files changed, 205 insertions(+), 129 deletions(-)

diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala
index 477f02a9c98..88686082998 100644
--- a/core/src/main/scala/kafka/network/RequestChannel.scala
+++ b/core/src/main/scala/kafka/network/RequestChannel.scala
@@ -24,7 +24,7 @@ import com.fasterxml.jackson.databind.JsonNode
 import com.typesafe.scalalogging.Logger
 import com.yammer.metrics.core.Meter
 import kafka.network
-import kafka.server.KafkaConfig
+import kafka.server.{KafkaConfig, RequestLocal}
 import kafka.utils.{Logging, NotNothing, Pool}
 import kafka.utils.Implicits._
 import org.apache.kafka.common.config.ConfigResource
@@ -80,7 +80,7 @@ object RequestChannel extends Logging {
     }
   }
 
-  case class CallbackRequest(fun: () => Unit,
+  case class CallbackRequest(fun: RequestLocal => Unit,
                              originalRequest: Request) extends BaseRequest
 
   class Request(val processor: Int,
diff --git a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
index 316cf92ca5a..7885b436c59 100755
--- a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
+++ b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala
@@ -54,23 +54,27 @@ object KafkaRequestHandler {
   }
 
   /**
-   * Wrap callback to schedule it on a request thread.
-   * NOTE: this function must be called on a request thread.
-   * @param fun Callback function to execute
-   * @return Wrapped callback that would execute `fun` on a request thread
+   * Creates a wrapped callback to be executed asynchronously on an arbitrary request thread.
+   * NOTE: this function must be originally called from a request thread.
+   * @param asyncCompletionCallback A callback method that we intend to call from another thread after an asynchronous
+   *                                action completes. The RequestLocal passed in must belong to the request handler
+   *                                thread that is executing the callback.
+   * @param requestLocal The RequestLocal for the current request handler thread in case we need to execute the callback
+   *                     function synchronously from the calling thread (used for testing)
+   * @return Wrapped callback will schedule `asyncCompletionCallback` on an arbitrary request thread
    */
-  def wrap[T](fun: T => Unit): T => Unit = {
+  def wrapAsyncCallback[T](asyncCompletionCallback: (RequestLocal, T) => Unit, requestLocal: RequestLocal): T => Unit = {
     val requestChannel = threadRequestChannel.get()
     val currentRequest = threadCurrentRequest.get()
     if (requestChannel == null || currentRequest == null) {
       if (!bypassThreadCheck)
         throw new IllegalStateException("Attempted to reschedule to request handler thread from non-request handler thread.")
-      T => fun(T)
+      T => asyncCompletionCallback(requestLocal, T)
     } else {
       T => {
         // The requestChannel and request are captured in this lambda, so when it's executed on the callback thread
         // we can re-schedule the original callback on a request thread and update the metrics accordingly.
-        requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(() => fun(T), currentRequest))
+        requestChannel.sendCallbackRequest(RequestChannel.CallbackRequest(newRequestLocal => asyncCompletionCallback(newRequestLocal, T), currentRequest))
       }
     }
   }
@@ -127,7 +131,7 @@ class KafkaRequestHandler(id: Int,
             }
             
             threadCurrentRequest.set(originalRequest)
-            callback.fun()
+            callback.fun(requestLocal)
           } catch {
             case e: FatalExitError =>
               completeShutdown()
@@ -169,6 +173,7 @@ class KafkaRequestHandler(id: Int,
 
   private def completeShutdown(): Unit = {
     requestLocal.close()
+    threadRequestChannel.remove()
     shutdownComplete.countDown()
   }
 
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index a574ba37716..53a9cd12efe 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -727,7 +727,6 @@ class ReplicaManager(val config: KafkaConfig,
                     transactionStatePartition: Option[Int] = None,
                     actionQueue: ActionQueue = this.actionQueue): Unit = {
     if (isValidRequiredAcks(requiredAcks)) {
-      val sTime = time.milliseconds
 
       val verificationGuards: mutable.Map[TopicPartition, Object] = mutable.Map[TopicPartition, Object]()
       val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition, errorsPerPartition) =
@@ -741,117 +740,19 @@ class ReplicaManager(val config: KafkaConfig,
           (verifiedEntries.toMap, unverifiedEntries.toMap, errorEntries.toMap)
         }
 
-      def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
-        val verifiedEntries = 
-          if (unverifiedEntries.isEmpty)
-            allEntries 
-          else
-            allEntries.filter { case (tp, _) =>
-              !unverifiedEntries.contains(tp)
-            }
-        
-        val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
-          origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
-        debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
-
-        def produceStatusResult(appendResult: Map[TopicPartition, LogAppendResult],
-                                useCustomMessage: Boolean): Map[TopicPartition, ProducePartitionStatus] = {
-          appendResult.map { case (topicPartition, result) =>
-            topicPartition -> ProducePartitionStatus(
-              result.info.lastOffset + 1, // required offset
-              new PartitionResponse(
-                result.error,
-                result.info.firstOffset.map[Long](_.messageOffset).orElse(-1L),
-                result.info.lastOffset,
-                result.info.logAppendTime,
-                result.info.logStartOffset,
-                result.info.recordErrors,
-                if (useCustomMessage) result.exception.get.getMessage else result.info.errorMessage
-              )
-            ) // response status
-          }
-        }
-        
-        val unverifiedResults = unverifiedEntries.map {
-          case (topicPartition, error) =>
-            val finalException =
-              error match {
-                case Errors.INVALID_TXN_STATE => error.exception("Partition was not added to the transaction")
-                case Errors.CONCURRENT_TRANSACTIONS |
-                     Errors.COORDINATOR_LOAD_IN_PROGRESS |
-                     Errors.COORDINATOR_NOT_AVAILABLE |
-                     Errors.NOT_COORDINATOR => new NotEnoughReplicasException(
-                         s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}")
-                case _ => error.exception()
-            }
-            topicPartition -> LogAppendResult(
-              LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
-              Some(finalException)
-            )
-        }
-
-        val errorResults = errorsPerPartition.map {
-          case (topicPartition, error) =>
-            topicPartition -> LogAppendResult(
-              LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
-              Some(error.exception())
-            )
-        }
-
-        val produceStatus = Set((localProduceResults, false), (unverifiedResults, true), (errorResults, false)).flatMap {
-          case (results, useCustomError) => produceStatusResult(results, useCustomError)
-        }.toMap
-        val allResults = localProduceResults ++ unverifiedResults ++ errorResults
-
-        actionQueue.add {
-          () => allResults.foreach { case (topicPartition, result) =>
-            val requestKey = TopicPartitionOperationKey(topicPartition)
-            result.info.leaderHwChange match {
-              case LeaderHwChange.INCREASED =>
-                // some delayed operations may be unblocked after HW changed
-                delayedProducePurgatory.checkAndComplete(requestKey)
-                delayedFetchPurgatory.checkAndComplete(requestKey)
-                delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
-              case LeaderHwChange.SAME =>
-                // probably unblock some follower fetch requests since log end offset has been updated
-                delayedFetchPurgatory.checkAndComplete(requestKey)
-              case LeaderHwChange.NONE =>
-              // nothing
-            }
-          }
-        }
-
-        recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
-
-        if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
-          // create delayed produce operation
-          val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
-          val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
-
-          // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
-          val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
-
-          // try to complete the request immediately, otherwise put it into the purgatory
-          // this is because while the delayed produce operation is being created, new
-          // requests may arrive and hence make this operation completable.
-          delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
-        } else {
-          // we can respond immediately
-          val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
-          responseCallback(produceResponseStatus)
-        }
-      }
-
       if (notYetVerifiedEntriesPerPartition.isEmpty || addPartitionsToTxnManager.isEmpty) {
-        appendEntries(verifiedEntriesPerPartition)(Map.empty)
+        appendEntries(verifiedEntriesPerPartition, internalTopicsAllowed, origin, requiredAcks, verificationGuards.toMap,
+          errorsPerPartition, recordConversionStatsCallback, timeout, responseCallback, delayedProduceLock)(requestLocal, Map.empty)
       } else {
         // For unverified entries, send a request to verify. When verified, the append process will proceed via the callback.
         val (error, node) = getTransactionCoordinator(transactionStatePartition.get)
 
         if (error != Errors.NONE) {
-          appendEntries(entriesPerPartition)(notYetVerifiedEntriesPerPartition.map {
-            case (tp, _) => (tp, error)
-          }.toMap)
+          appendEntries(verifiedEntriesPerPartition, internalTopicsAllowed, origin, requiredAcks, verificationGuards.toMap,
+            errorsPerPartition, recordConversionStatsCallback, timeout, responseCallback, delayedProduceLock)(requestLocal,
+            notYetVerifiedEntriesPerPartition.map {
+              case (tp, _) => (tp, error)
+            }.toMap)
         } else {
           val topicGrouping = notYetVerifiedEntriesPerPartition.keySet.groupBy(tp => tp.topic())
           val topicCollection = new AddPartitionsToTxnTopicCollection()
@@ -871,7 +772,21 @@ class ReplicaManager(val config: KafkaConfig,
             .setVerifyOnly(true)
             .setTopics(topicCollection)
 
-          addPartitionsToTxnManager.foreach(_.addTxnData(node, notYetVerifiedTransaction, KafkaRequestHandler.wrap(appendEntries(entriesPerPartition)(_))))
+          addPartitionsToTxnManager.foreach(_.addTxnData(node, notYetVerifiedTransaction, KafkaRequestHandler.wrapAsyncCallback(
+            appendEntries(
+              entriesPerPartition,
+              internalTopicsAllowed,
+              origin,
+              requiredAcks,
+              verificationGuards.toMap,
+              errorsPerPartition,
+              recordConversionStatsCallback,
+              timeout,
+              responseCallback,
+              delayedProduceLock
+            ),
+            requestLocal)
+          ))
         }
       }
     } else {
@@ -889,6 +804,122 @@ class ReplicaManager(val config: KafkaConfig,
     }
   }
 
+  /*
+   * Note: This method can be used as a callback in a different request thread. Ensure that correct RequestLocal
+   * is passed when executing this method. Accessing non-thread-safe data structures should be avoided if possible.
+   */
+  private def appendEntries(allEntries: Map[TopicPartition, MemoryRecords],
+                            internalTopicsAllowed: Boolean,
+                            origin: AppendOrigin,
+                            requiredAcks: Short,
+                            verificationGuards: Map[TopicPartition, Object],
+                            errorsPerPartition: Map[TopicPartition, Errors],
+                            recordConversionStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit,
+                            timeout: Long,
+                            responseCallback: Map[TopicPartition, PartitionResponse] => Unit,
+                            delayedProduceLock: Option[Lock])
+                           (requestLocal: RequestLocal, unverifiedEntries: Map[TopicPartition, Errors]): Unit = {
+    val sTime = time.milliseconds
+    val verifiedEntries =
+      if (unverifiedEntries.isEmpty)
+        allEntries
+      else
+        allEntries.filter { case (tp, _) =>
+          !unverifiedEntries.contains(tp)
+        }
+
+    val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed,
+      origin, verifiedEntries, requiredAcks, requestLocal, verificationGuards.toMap)
+    debug("Produce to local log in %d ms".format(time.milliseconds - sTime))
+
+    def produceStatusResult(appendResult: Map[TopicPartition, LogAppendResult],
+                            useCustomMessage: Boolean): Map[TopicPartition, ProducePartitionStatus] = {
+      appendResult.map { case (topicPartition, result) =>
+        topicPartition -> ProducePartitionStatus(
+          result.info.lastOffset + 1, // required offset
+          new PartitionResponse(
+            result.error,
+            result.info.firstOffset.map[Long](_.messageOffset).orElse(-1L),
+            result.info.lastOffset,
+            result.info.logAppendTime,
+            result.info.logStartOffset,
+            result.info.recordErrors,
+            if (useCustomMessage) result.exception.get.getMessage else result.info.errorMessage
+          )
+        ) // response status
+      }
+    }
+
+    val unverifiedResults = unverifiedEntries.map {
+      case (topicPartition, error) =>
+        val finalException =
+          error match {
+            case Errors.INVALID_TXN_STATE => error.exception("Partition was not added to the transaction")
+            case Errors.CONCURRENT_TRANSACTIONS |
+                 Errors.COORDINATOR_LOAD_IN_PROGRESS |
+                 Errors.COORDINATOR_NOT_AVAILABLE |
+                 Errors.NOT_COORDINATOR => new NotEnoughReplicasException(
+              s"Unable to verify the partition has been added to the transaction. Underlying error: ${error.toString}")
+            case _ => error.exception()
+          }
+        topicPartition -> LogAppendResult(
+          LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+          Some(finalException)
+        )
+    }
+
+    val errorResults = errorsPerPartition.map {
+      case (topicPartition, error) =>
+        topicPartition -> LogAppendResult(
+          LogAppendInfo.UNKNOWN_LOG_APPEND_INFO,
+          Some(error.exception())
+        )
+    }
+
+    val produceStatus = Set((localProduceResults, false), (unverifiedResults, true), (errorResults, false)).flatMap {
+      case (results, useCustomError) => produceStatusResult(results, useCustomError)
+    }.toMap
+    val allResults = localProduceResults ++ unverifiedResults ++ errorResults
+
+    actionQueue.add {
+      () => allResults.foreach { case (topicPartition, result) =>
+        val requestKey = TopicPartitionOperationKey(topicPartition)
+        result.info.leaderHwChange match {
+          case LeaderHwChange.INCREASED =>
+            // some delayed operations may be unblocked after HW changed
+            delayedProducePurgatory.checkAndComplete(requestKey)
+            delayedFetchPurgatory.checkAndComplete(requestKey)
+            delayedDeleteRecordsPurgatory.checkAndComplete(requestKey)
+          case LeaderHwChange.SAME =>
+            // probably unblock some follower fetch requests since log end offset has been updated
+            delayedFetchPurgatory.checkAndComplete(requestKey)
+          case LeaderHwChange.NONE =>
+          // nothing
+        }
+      }
+    }
+
+    recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> v.info.recordConversionStats })
+
+    if (delayedProduceRequestRequired(requiredAcks, allEntries, allResults)) {
+      // create delayed produce operation
+      val produceMetadata = ProduceMetadata(requiredAcks, produceStatus)
+      val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock)
+
+      // create a list of (topic, partition) pairs to use as keys for this delayed produce operation
+      val producerRequestKeys = allEntries.keys.map(TopicPartitionOperationKey(_)).toSeq
+
+      // try to complete the request immediately, otherwise put it into the purgatory
+      // this is because while the delayed produce operation is being created, new
+      // requests may arrive and hence make this operation completable.
+      delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys)
+    } else {
+      // we can respond immediately
+      val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus }
+      responseCallback(produceResponseStatus)
+    }
+  }
+
   private def partitionEntriesForVerification(verificationGuards: mutable.Map[TopicPartition, Object],
                                               entriesPerPartition: Map[TopicPartition, MemoryRecords],
                                               verifiedEntries: mutable.Map[TopicPartition, MemoryRecords],
diff --git a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
index 822dd5263d6..dd3f7996999 100644
--- a/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
+++ b/core/src/test/scala/kafka/server/KafkaRequestHandlerTest.scala
@@ -32,10 +32,11 @@ import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 import org.mockito.ArgumentMatchers
 import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{mock, when}
+import org.mockito.Mockito.{mock, times, verify, when}
 
 import java.net.InetAddress
 import java.nio.ByteBuffer
+import java.util.concurrent.CompletableFuture
 import java.util.concurrent.atomic.AtomicInteger
 
 class KafkaRequestHandlerTest {
@@ -53,14 +54,17 @@ class KafkaRequestHandlerTest {
       val request = makeRequest(time, metrics)
       requestChannel.sendRequest(request)
 
-      def callback(ms: Int): Unit = {
-        time.sleep(ms)
-        handler.stop()
-      }
-
       when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
         time.sleep(2)
-        KafkaRequestHandler.wrap(callback(_: Int))(1)
+        // Prepare the callback.
+        val callback = KafkaRequestHandler.wrapAsyncCallback(
+          (reqLocal: RequestLocal, ms: Int) => {
+            time.sleep(ms)
+            handler.stop()
+          },
+          RequestLocal.NoCaching)
+        // Execute the callback asynchronously.
+        CompletableFuture.runAsync(() => callback(1))
         request.apiLocalCompleteTimeNanos = time.nanoseconds
       }
 
@@ -86,16 +90,19 @@ class KafkaRequestHandlerTest {
     var handledCount = 0
     var tryCompleteActionCount = 0
 
-    def callback(x: Int): Unit = {
-      handler.stop()
-    }
-
     val request = makeRequest(time, metrics)
     requestChannel.sendRequest(request)
 
     when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
       handledCount = handledCount + 1
-      KafkaRequestHandler.wrap(callback(_: Int))(1)
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          handler.stop()
+        },
+        RequestLocal.NoCaching)
+      // Execute the callback asynchronously.
+      CompletableFuture.runAsync(() => callback(1))
     }
 
     when(apiHandler.tryCompleteActions()).thenAnswer { _ =>
@@ -108,6 +115,39 @@ class KafkaRequestHandlerTest {
     assertEquals(1, tryCompleteActionCount)
   }
 
+  @Test
+  def testHandlingCallbackOnNewThread(): Unit = {
+    val time = new MockTime()
+    val metrics = mock(classOf[RequestChannel.Metrics])
+    val apiHandler = mock(classOf[ApiRequestHandler])
+    val requestChannel = new RequestChannel(10, "", time, metrics)
+    val handler = new KafkaRequestHandler(0, 0, mock(classOf[Meter]), new AtomicInteger(1), requestChannel, apiHandler, time)
+
+    val originalRequestLocal = mock(classOf[RequestLocal])
+
+    var handledCount = 0
+
+    val request = makeRequest(time, metrics)
+    requestChannel.sendRequest(request)
+
+    when(apiHandler.handle(ArgumentMatchers.eq(request), any())).thenAnswer { _ =>
+      // Prepare the callback.
+      val callback = KafkaRequestHandler.wrapAsyncCallback(
+        (reqLocal: RequestLocal, ms: Int) => {
+          reqLocal.bufferSupplier.close()
+          handledCount = handledCount + 1
+          handler.stop()
+        },
+        originalRequestLocal)
+      // Execute the callback asynchronously.
+      CompletableFuture.runAsync(() => callback(1))
+    }
+
+    handler.run()
+    // Verify that we don't use the request local that we passed in.
+    verify(originalRequestLocal, times(0)).bufferSupplier
+    assertEquals(1, handledCount)
+  }
 
   @ParameterizedTest
   @ValueSource(booleans = Array(true, false))