Posted to commits@kafka.apache.org by cm...@apache.org on 2024/03/14 00:01:30 UTC

(kafka) branch 3.6 updated: KAFKA-16180: Fix UMR and LAIR handling during ZK migration (#15293)

This is an automated email from the ASF dual-hosted git repository.

cmccabe pushed a commit to branch 3.6
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.6 by this push:
     new 4cec48f86ec KAFKA-16180: Fix UMR and LAIR handling during ZK migration (#15293)
4cec48f86ec is described below

commit 4cec48f86ec95848b2e3893f04dc90d1d8415a47
Author: Colin Patrick McCabe <cm...@apache.org>
AuthorDate: Fri Feb 2 15:49:10 2024 -0800

    KAFKA-16180: Fix UMR and LAIR handling during ZK migration (#15293)
    
    While migrating from ZK mode to KRaft mode, the broker passes through a "hybrid" phase, in which it
    receives LeaderAndIsrRequest and UpdateMetadataRequest RPCs from the KRaft controller. For the most
    part, these RPCs can be handled just like their traditional equivalents from a ZK-based controller.
    However, there is one thing that is different: the way topic deletions are handled.
    
    In ZK mode, there is a "deleting" state which topics enter prior to being completely removed.
    Partitions stay in this state until they are removed from the disks of all replicas. And partitions
    associated with these deleting topics show up in the UMR and LAIR as having a leader of -2 (which
    is not a valid broker ID, of course, because it's negative). When brokers receive these RPCs, they
    know to remove the associated partitions from their metadata caches and disks. When a full UMR or
    LAIR is sent, deleting partitions are included as well.
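
    For illustration (a sketch, not part of this change; the topic name and ISR values below are
    made up), a "deleting" entry in a UMR looks roughly like this, with LeaderAndIsr.duringDelete
    supplying the -2 leader:

        import kafka.api.LeaderAndIsr
        import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState
        import scala.jdk.CollectionConverters._

        val lisr = LeaderAndIsr.duringDelete(List(4, 5, 6))  // leader is -2 while deleting
        val deleting = new UpdateMetadataPartitionState()
          .setTopicName("foo")                               // hypothetical topic
          .setPartitionIndex(0)
          .setLeader(lisr.leader)                            // -2: not a valid broker ID
          .setIsr(lisr.isr.map(Integer.valueOf).asJava)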
    
    In hybrid mode, in contrast, there is no "deleting" state. Topic deletion happens immediately. We
    can do this because we know that we have topic IDs that are never reused. This means that we can
    always tell the difference between a broker that had an old version of some topic, and a broker
    that has a new version that was re-created with the same name. To make this work, when handling a
    full UMR or LAIR, hybrid brokers must compare the full state that was sent over the wire to their
    own local state, and adjust accordingly.
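
    A minimal sketch of that comparison, assuming hypothetical delete/create helpers (the actual
    logic lives in ZkMetadataCache and LogManager, in the diff below):

        import org.apache.kafka.common.Uuid

        // Hypothetical helpers, named for illustration only.
        def delete(name: String, id: Uuid): Unit = println(s"delete $name ($id)")
        def create(name: String, id: Uuid): Unit = println(s"create $name ($id)")

        // local: what the broker has; wire: what the full request describes.
        def reconcile(local: Map[String, Uuid], wire: Map[String, Uuid]): Unit =
          local.foreach { case (name, localId) =>
            wire.get(name) match {
              case None                      => delete(name, localId)  // implied deletion
              case Some(id) if id != localId =>                        // same name, new ID:
                delete(name, localId); create(name, id)                // remove old, add new
              case Some(_)                   => ()                     // unchanged; keep it
            }
          }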
    
    Prior to this PR, the code for handling those adjustments had several major flaws. The biggest was
    that it did not correctly handle the "re-creation" case, where a topic named FOO appears in the
    RPC but with a different ID than the broker's local FOO. Another flaw was that a problem with a
    single partition would prevent the whole request from being handled.
    
    In ZkMetadataCache.scala, we handle full UMR requests from KRaft controllers by rewriting the UMR
    so that it contains the implied deletions. I fixed this code so that deletions always appear at the
    start of the list of topic states. This is important for the re-creation case since it means that a
    single request can both delete the old FOO and add a new FOO to the cache. Also, rather than
    modifying the request in-place, as the previous code did, I build a whole new request with the
    desired list of topic states. This is much safer because it avoids unforeseen interactions with
    other parts of the code that deal with requests (like request logging). While this new copy may
    sound expensive, it should actually not be: we do a "shallow copy" that references the previous
    list's topic state entries.
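
    Concretely, the shallow copy is just a new request wrapping a new data object whose fields point
    at the existing values (a sketch, abridged from the ZkMetadataCache change below):

        import org.apache.kafka.common.message.UpdateMetadataRequestData
        import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataTopicState
        import org.apache.kafka.common.requests.UpdateMetadataRequest

        def rebuild(req: UpdateMetadataRequest,
                    newTopicStates: java.util.List[UpdateMetadataTopicState]): UpdateMetadataRequest = {
          val d = req.data()
          new UpdateMetadataRequest(new UpdateMetadataRequestData().
            setControllerId(d.controllerId()).
            setIsKRaftController(d.isKRaftController).
            setType(d.`type`()).
            setControllerEpoch(d.controllerEpoch()).
            setBrokerEpoch(d.brokerEpoch()).
            setTopicStates(newTopicStates).                 // the rewritten list
            setLiveBrokers(d.liveBrokers()),                // shared, not copied
            req.version())
        }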
    
    I also reworked ZkMetadataCache.updateMetadata so that if a partition is re-created, it does not
    appear in the returned set of deleted TopicPartitions. Since this set is used only by the group
    manager, this seemed appropriate. (If I were in the consumer group for the previous iteration of
    FOO, I should still be in the consumer group for the new iteration.)
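
    The bookkeeping is simple (a sketch of the mechanism visible in the diff below; foo-0 stands in
    for any partition named in the request):

        import org.apache.kafka.common.TopicPartition

        val deletedPartitions = new java.util.LinkedHashSet[TopicPartition]
        val tp = new TopicPartition("foo", 0)
        deletedPartitions.add(tp)     // a deletion entry (leader == -2) was seen
        deletedPartitions.remove(tp)  // a later entry re-created the partition
        // deletedPartitions is now empty, so the group manager never sees foo-0.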
    
    On the ReplicaManager.scala side, we handle full LAIR requests by treating anything which does not
    appear in them as a "stray replica." (But we do not rewrite the request objects as we do with UMR.)
    I moved the logic for finding stray replicas from ReplicaManager into LogManager. It makes more
    sense there, since the information about what is on disk is managed in LogManager. Also, the stray
    replica detection logic for KRaft mode is there, so it makes sense to put the stray replica
    detection logic for hybrid mode there as well.
    
    Since the stray replica detection is now in LogManager, I moved the unit tests there as well.
    Previously some of those tests had been in BrokerMetadataPublisherTest for historical reasons.
    
    The main advantage of the new LAIR logic is that it takes topic ID into account. A replica can be a
    stray even if the LAIR contains a topic with the same name but a different ID. I also moved the
    stray replica handling earlier in the becomeLeaderOrFollower function, so that we could correctly
    handle the "delete and re-create FOO" case.
    
    Reviewers: David Arthur <mu...@gmail.com>
    
    Conflicts: For this cherry-pick to 3.6, there were numerous import statement conflicts. Code that
        set directories had to be changed to not do so (since JBOD-style directories aren't
        in 3.6).
---
 .../kafka/common/requests/LeaderAndIsrRequest.java |   4 +-
 .../common/requests/UpdateMetadataRequest.java     |   4 +-
 core/src/main/scala/kafka/log/LogManager.scala     |  87 +++-
 .../main/scala/kafka/server/ReplicaManager.scala   |  36 +-
 .../kafka/server/metadata/ZkMetadataCache.scala    | 176 +++++---
 .../java/kafka/testkit/KafkaClusterTestKit.java    |   1 -
 .../test/scala/unit/kafka/log/LogManagerTest.scala | 194 +++++++-
 .../unit/kafka/server/MetadataCacheTest.scala      | 497 +++++++++++++++++----
 .../unit/kafka/server/ReplicaManagerTest.scala     | 113 ++++-
 .../metadata/BrokerMetadataPublisherTest.scala     | 100 +----
 10 files changed, 948 insertions(+), 264 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java
index 9a07a88a35d..251810e30e9 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java
@@ -40,7 +40,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
 
-public class LeaderAndIsrRequest extends AbstractControlRequest {
+public final class LeaderAndIsrRequest extends AbstractControlRequest {
 
     public static class Builder extends AbstractControlRequest.Builder<LeaderAndIsrRequest> {
 
@@ -129,7 +129,7 @@ public class LeaderAndIsrRequest extends AbstractControlRequest {
 
     private final LeaderAndIsrRequestData data;
 
-    LeaderAndIsrRequest(LeaderAndIsrRequestData data, short version) {
+    public LeaderAndIsrRequest(LeaderAndIsrRequestData data, short version) {
         super(ApiKeys.LEADER_AND_ISR, version);
         this.data = data;
         // Do this from the constructor to make it thread-safe (even though it's only needed when some methods are called)
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java
index ea9ae814198..245fff7ffce 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java
@@ -41,7 +41,7 @@ import java.util.Map;
 
 import static java.util.Collections.singletonList;
 
-public class UpdateMetadataRequest extends AbstractControlRequest {
+public final class UpdateMetadataRequest extends AbstractControlRequest {
 
     public static class Builder extends AbstractControlRequest.Builder<UpdateMetadataRequest> {
         private final List<UpdateMetadataPartitionState> partitionStates;
@@ -149,7 +149,7 @@ public class UpdateMetadataRequest extends AbstractControlRequest {
 
     private final UpdateMetadataRequestData data;
 
-    UpdateMetadataRequest(UpdateMetadataRequestData data, short version) {
+    public UpdateMetadataRequest(UpdateMetadataRequestData data, short version) {
         super(ApiKeys.UPDATE_METADATA, version);
         this.data = data;
         // Do this from the constructor to make it thread-safe (even though it's only needed when some methods are called)
diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala
index 8725b99711d..bb5d4e9270a 100755
--- a/core/src/main/scala/kafka/log/LogManager.scala
+++ b/core/src/main/scala/kafka/log/LogManager.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
 import kafka.server.checkpoints.OffsetCheckpointFile
 import kafka.server.metadata.ConfigRepository
 import kafka.server._
+import kafka.server.metadata.BrokerMetadataPublisher.info
 import kafka.utils._
 import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid}
 import org.apache.kafka.common.utils.{KafkaThread, Time, Utils}
@@ -35,6 +36,8 @@ import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}
 import kafka.utils.Implicits._
 import org.apache.kafka.common.config.TopicConfig
+import org.apache.kafka.common.requests.{AbstractControlRequest, LeaderAndIsrRequest}
+import org.apache.kafka.image.TopicsImage
 
 import java.util.Properties
 import org.apache.kafka.server.common.MetadataVersion
@@ -43,6 +46,7 @@ import org.apache.kafka.server.metrics.KafkaMetricsGroup
 import org.apache.kafka.server.util.Scheduler
 import org.apache.kafka.storage.internals.log.{CleanerConfig, LogConfig, LogDirFailureChannel, ProducerStateManagerConfig, RemoteIndexCache}
 
+import java.util
 import scala.annotation.nowarn
 
 /**
@@ -1223,7 +1227,7 @@ class LogManager(logDirs: Seq[File],
    * @param errorHandler The error handler that will be called when a exception for a particular
    *                     topic-partition is raised
    */
-  def asyncDelete(topicPartitions: Set[TopicPartition],
+  def asyncDelete(topicPartitions: Iterable[TopicPartition],
                   isStray: Boolean,
                   errorHandler: (TopicPartition, Throwable) => Unit): Unit = {
     val logDirs = mutable.Set.empty[File]
@@ -1451,4 +1455,85 @@ object LogManager {
       remoteStorageSystemEnable = config.remoteLogManagerConfig.enableRemoteStorageSystem())
   }
 
+  /**
+   * Find logs which should not be on the current broker, according to the metadata image.
+   *
+   * @param brokerId        The ID of the current broker.
+   * @param newTopicsImage  The new topics image after broker has been reloaded
+   * @param logs            A collection of Log objects.
+   *
+   * @return          The topic partitions which are no longer needed on this broker.
+   */
+  def findStrayReplicas(
+    brokerId: Int,
+    newTopicsImage: TopicsImage,
+    logs: Iterable[UnifiedLog]
+  ): Iterable[TopicPartition] = {
+    logs.flatMap { log =>
+      val topicId = log.topicId.getOrElse {
+        throw new RuntimeException(s"The log dir $log does not have a topic ID, " +
+          "which is not allowed when running in KRaft mode.")
+      }
+
+      val partitionId = log.topicPartition.partition()
+      Option(newTopicsImage.getPartition(topicId, partitionId)) match {
+        case Some(partition) =>
+          if (!partition.replicas.contains(brokerId)) {
+            info(s"Found stray log dir $log: the current replica assignment ${partition.replicas} " +
+              s"does not contain the local brokerId $brokerId.")
+            Some(log.topicPartition)
+          } else {
+            None
+          }
+
+        case None =>
+          info(s"Found stray log dir $log: the topicId $topicId does not exist in the metadata image")
+          Some(log.topicPartition)
+      }
+    }
+  }
+
+  /**
+   * Find logs which should not be on the current broker, according to the full LeaderAndIsrRequest.
+   *
+   * @param brokerId        The ID of the current broker.
+   * @param request         The full LeaderAndIsrRequest, containing all partitions owned by the broker.
+   * @param logs            A collection of Log objects.
+   *
+   * @return                The topic partitions which are no longer needed on this broker.
+   */
+  def findStrayReplicas(
+    brokerId: Int,
+    request: LeaderAndIsrRequest,
+    logs: Iterable[UnifiedLog]
+  ): Iterable[TopicPartition] = {
+    if (request.requestType() != AbstractControlRequest.Type.FULL) {
+      throw new RuntimeException("Cannot use incremental LeaderAndIsrRequest to find strays.")
+    }
+    val partitions = new util.HashMap[TopicPartition, Uuid]()
+    request.data().topicStates().forEach(topicState => {
+      topicState.partitionStates().forEach(partition => {
+        partitions.put(new TopicPartition(topicState.topicName(), partition.partitionIndex()),
+          topicState.topicId());
+      })
+    })
+    logs.flatMap { log =>
+      val topicId = log.topicId.getOrElse {
+        throw new RuntimeException(s"The log dir $log does not have a topic ID, " +
+          "which is not allowed when running in KRaft mode.")
+      }
+      Option(partitions.get(log.topicPartition)) match {
+        case Some(id) =>
+          if (id.equals(topicId)) {
+            None
+          } else {
+            info(s"Found stray log dir $log: this partition now exists with topic ID $id not $topicId.")
+            Some(log.topicPartition)
+          }
+        case None =>
+          info(s"Found stray log dir $log: this partition does not exist in the new full LeaderAndIsrRequest.")
+          Some(log.topicPartition)
+      }
+    }
+  }
 }
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 63a6f049247..ae1bb1e7a1a 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -394,7 +394,7 @@ class ReplicaManager(val config: KafkaConfig,
       brokerTopicStats.removeMetrics(topic)
   }
 
-  private[server] def updateStrayLogs(strayPartitions: Set[TopicPartition]): Unit = {
+  private[server] def updateStrayLogs(strayPartitions: Iterable[TopicPartition]): Unit = {
     if (strayPartitions.isEmpty) {
       return
     }
@@ -430,11 +430,6 @@ class ReplicaManager(val config: KafkaConfig,
     })
   }
 
-  // Find logs which exist on the broker, but aren't present in the full LISR
-  private[server] def findStrayPartitionsFromLeaderAndIsr(partitionsFromRequest: Set[TopicPartition]): Set[TopicPartition] = {
-    logManager.allLogs.map(_.topicPartition).filterNot(partitionsFromRequest.contains).toSet
-  }
-
   protected def completeDelayedFetchOrProduceRequests(topicPartition: TopicPartition): Unit = {
     val topicPartitionOperationKey = TopicPartitionOperationKey(topicPartition)
     delayedProducePurgatory.checkAndComplete(topicPartitionOperationKey)
@@ -1775,6 +1770,24 @@ class ReplicaManager(val config: KafkaConfig,
             s"Latest known controller epoch is $controllerEpoch")
           leaderAndIsrRequest.getErrorResponse(0, Errors.STALE_CONTROLLER_EPOCH.exception)
         } else {
+          // In migration mode, reconcile missed topic deletions when handling full LISR from KRaft controller.
+          // LISR "type" field was previously unspecified (0), so if we see it set to Full (2), then we know the
+          // request came from a KRaft controller.
+          //
+          // Note that we have to do this first, before anything else, since topics may be recreated with the same
+          // name, but a different ID. And in that case, we need to move aside the old version of those topics
+          // (with the obsolete topic ID) before doing anything else.
+          if (config.migrationEnabled &&
+            leaderAndIsrRequest.isKRaftController &&
+            leaderAndIsrRequest.requestType() == AbstractControlRequest.Type.FULL)
+          {
+            val strays = LogManager.findStrayReplicas(localBrokerId, leaderAndIsrRequest, logManager.allLogs)
+            stateChangeLogger.info(s"While handling full LeaderAndIsr request from KRaft " +
+              s"controller $controllerId with correlation id $correlationId, found ${strays.size} " +
+              "stray partition(s).")
+            updateStrayLogs(strays)
+          }
+
           val responseMap = new mutable.HashMap[TopicPartition, Errors]
           controllerEpoch = leaderAndIsrRequest.controllerEpoch
 
@@ -1896,17 +1909,6 @@ class ReplicaManager(val config: KafkaConfig,
           // have been completely populated before starting the checkpointing there by avoiding weird race conditions
           startHighWatermarkCheckPointThread()
 
-          // In migration mode, reconcile missed topic deletions when handling full LISR from KRaft controller.
-          // LISR "type" field was previously unspecified (0), so if we see it set to Full (2), then we know the
-          // request came from a KRaft controller.
-          if (
-            config.migrationEnabled &&
-            leaderAndIsrRequest.isKRaftController &&
-            leaderAndIsrRequest.requestType() == AbstractControlRequest.Type.FULL
-          ) {
-            updateStrayLogs(findStrayPartitionsFromLeaderAndIsr(allTopicPartitionsInRequest))
-          }
-
           maybeAddLogDirFetchers(partitions, highWatermarkCheckpoints, topicIdFromRequest)
 
           replicaFetcherManager.shutdownIdleFetcherThreads()
diff --git a/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala b/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala
index 84ef973b8a6..5a91e3c0751 100755
--- a/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala
+++ b/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala
@@ -65,48 +65,88 @@ case class MetadataSnapshot(partitionStates: mutable.AnyRefMap[String, mutable.L
 }
 
 object ZkMetadataCache {
-  /**
-   * Create topic deletions (leader=-2) for topics that are missing in a FULL UpdateMetadataRequest coming from a
-   * KRaft controller during a ZK migration. This will modify the UpdateMetadataRequest object passed into this method.
-   */
-  def maybeInjectDeletedPartitionsFromFullMetadataRequest(
+  def transformKRaftControllerFullMetadataRequest(
     currentMetadata: MetadataSnapshot,
     requestControllerEpoch: Int,
     requestTopicStates: util.List[UpdateMetadataTopicState],
-  ): Seq[Uuid] = {
-    val prevTopicIds = currentMetadata.topicIds.values.toSet
-    val requestTopics = requestTopicStates.asScala.map { topicState =>
-      topicState.topicName() -> topicState.topicId()
-    }.toMap
-
-    val deleteTopics = prevTopicIds -- requestTopics.values.toSet
-    if (deleteTopics.isEmpty) {
-      return Seq.empty
+    handleLogMessage: String => Unit,
+  ): util.List[UpdateMetadataTopicState] = {
+    val topicIdToNewState = new util.HashMap[Uuid, UpdateMetadataTopicState]()
+    requestTopicStates.forEach(state => topicIdToNewState.put(state.topicId(), state))
+    val newRequestTopicStates = new util.ArrayList[UpdateMetadataTopicState]()
+    currentMetadata.topicNames.forKeyValue((id, name) => {
+      try {
+        Option(topicIdToNewState.get(id)) match {
+          case None =>
+            currentMetadata.partitionStates.get(name) match {
+              case None => handleLogMessage(s"Error: topic ${name} appeared in currentMetadata.topicNames, " +
+                "but not in currentMetadata.partitionStates.")
+              case Some(curPartitionStates) =>
+                handleLogMessage(s"Removing topic ${name} with ID ${id} from the metadata cache since " +
+                  "the full UMR did not include it.")
+                newRequestTopicStates.add(createDeletionEntries(name,
+                  id,
+                  curPartitionStates.values,
+                  requestControllerEpoch))
+            }
+          case Some(newTopicState) =>
+            val indexToState = new util.HashMap[Integer, UpdateMetadataPartitionState]
+            newTopicState.partitionStates().forEach(part => indexToState.put(part.partitionIndex, part))
+            currentMetadata.partitionStates.get(name) match {
+              case None => handleLogMessage(s"Error: topic ${name} appeared in currentMetadata.topicNames, " +
+                "but not in currentMetadata.partitionStates.")
+              case Some(curPartitionStates) =>
+                curPartitionStates.foreach(state => indexToState.remove(state._1.toInt))
+                if (!indexToState.isEmpty) {
+                  handleLogMessage(s"Removing ${indexToState.size()} partition(s) from topic ${name} with " +
+                    s"ID ${id} from the metadata cache since the full UMR did not include them.")
+                  newRequestTopicStates.add(createDeletionEntries(name,
+                    id,
+                    indexToState.values().asScala,
+                    requestControllerEpoch))
+                }
+            }
+        }
+      } catch {
+        case e: Exception => handleLogMessage(s"Error: ${e}")
+      }
+    })
+    if (newRequestTopicStates.isEmpty) {
+      // If the output is the same as the input, optimize by just returning the input.
+      requestTopicStates
+    } else {
+      // If the output has some new entries, they should all appear at the beginning. This will
+      // ensure that the old stuff is cleared out before the new stuff is added. We will need a
+      // new list for this, of course.
+      newRequestTopicStates.addAll(requestTopicStates)
+      newRequestTopicStates
     }
+  }
 
-    deleteTopics.foreach { deletedTopicId =>
-      val topicName = currentMetadata.topicNames(deletedTopicId)
-      val topicState = new UpdateMetadataRequestData.UpdateMetadataTopicState()
-        .setTopicId(deletedTopicId)
+  def createDeletionEntries(
+    topicName: String,
+    topicId: Uuid,
+    partitions: Iterable[UpdateMetadataPartitionState],
+    requestControllerEpoch: Int
+  ): UpdateMetadataTopicState = {
+    val topicState = new UpdateMetadataRequestData.UpdateMetadataTopicState()
+      .setTopicId(topicId)
+      .setTopicName(topicName)
+      .setPartitionStates(new util.ArrayList())
+    partitions.foreach(partition => {
+      val lisr = LeaderAndIsr.duringDelete(partition.isr().asScala.map(_.intValue()).toList)
+      val newPartitionState = new UpdateMetadataPartitionState()
+        .setPartitionIndex(partition.partitionIndex().toInt)
         .setTopicName(topicName)
-        .setPartitionStates(new util.ArrayList())
-
-      currentMetadata.partitionStates(topicName).foreach { case (partitionId, partitionState) =>
-        val lisr = LeaderAndIsr.duringDelete(partitionState.isr().asScala.map(_.intValue()).toList)
-        val newPartitionState = new UpdateMetadataPartitionState()
-          .setPartitionIndex(partitionId.toInt)
-          .setTopicName(topicName)
-          .setLeader(lisr.leader)
-          .setLeaderEpoch(lisr.leaderEpoch)
-          .setControllerEpoch(requestControllerEpoch)
-          .setReplicas(partitionState.replicas())
-          .setZkVersion(lisr.partitionEpoch)
-          .setIsr(lisr.isr.map(Integer.valueOf).asJava)
-        topicState.partitionStates().add(newPartitionState)
-      }
-      requestTopicStates.add(topicState)
-    }
-    deleteTopics.toSeq
+        .setLeader(lisr.leader)
+        .setLeaderEpoch(lisr.leaderEpoch)
+        .setControllerEpoch(requestControllerEpoch)
+        .setReplicas(partition.replicas())
+        .setZkVersion(lisr.partitionEpoch)
+        .setIsr(lisr.isr.map(Integer.valueOf).asJava)
+      topicState.partitionStates().add(newPartitionState)
+    })
+    topicState
   }
 }
 
@@ -429,26 +469,59 @@ class ZkMetadataCache(
       controllerId(snapshot).orNull)
   }
 
-  // This method returns the deleted TopicPartitions received from UpdateMetadataRequest
-  def updateMetadata(correlationId: Int, updateMetadataRequest: UpdateMetadataRequest): Seq[TopicPartition] = {
+  // This method returns the deleted TopicPartitions received from UpdateMetadataRequest.
+  // Note: if this ZK broker is migrating to KRaft, a singular UMR may sometimes both delete a
+  // partition and re-create a new partition with that same name. In that case, it will not appear
+  // in the return value of this function.
+  def updateMetadata(
+    correlationId: Int,
+    originalUpdateMetadataRequest: UpdateMetadataRequest
+  ): Seq[TopicPartition] = {
+    var updateMetadataRequest = originalUpdateMetadataRequest
     inWriteLock(partitionMetadataLock) {
       if (
         updateMetadataRequest.isKRaftController &&
         updateMetadataRequest.updateType() == AbstractControlRequest.Type.FULL
       ) {
-        if (!zkMigrationEnabled) {
+        if (updateMetadataRequest.version() < 8) {
+          stateChangeLogger.error(s"Received UpdateMetadataRequest with Type=FULL (2), but version of " +
+            updateMetadataRequest.version() + ", which should not be possible. Not treating this as a full " +
+            "metadata update")
+        } else if (!zkMigrationEnabled) {
           stateChangeLogger.error(s"Received UpdateMetadataRequest with Type=FULL (2), but ZK migrations " +
             s"are not enabled on this broker. Not treating this as a full metadata update")
         } else {
-          val deletedTopicIds = ZkMetadataCache.maybeInjectDeletedPartitionsFromFullMetadataRequest(
-            metadataSnapshot, updateMetadataRequest.controllerEpoch(), updateMetadataRequest.topicStates())
-          if (deletedTopicIds.isEmpty) {
-            stateChangeLogger.trace(s"Received UpdateMetadataRequest with Type=FULL (2), " +
-              s"but no deleted topics were detected.")
-          } else {
-            stateChangeLogger.debug(s"Received UpdateMetadataRequest with Type=FULL (2), " +
-              s"found ${deletedTopicIds.size} deleted topic ID(s): $deletedTopicIds.")
-          }
+          // When handling a UMR from a KRaft controller, we may have to insert some partition
+          // deletions at the beginning, to handle the different way topic deletion works in KRaft
+          // mode (and also migration mode).
+          //
+          // After we've done that, we re-create the whole UpdateMetadataRequest object using the
+          // updated list of topic info. This ensures that UpdateMetadataRequest.normalize is called
+          // on the new, updated topic data. Note that we don't mutate the old request object; it may
+          // be used elsewhere.
+          val newTopicStates = ZkMetadataCache.transformKRaftControllerFullMetadataRequest(
+            metadataSnapshot,
+            updateMetadataRequest.controllerEpoch(),
+            updateMetadataRequest.topicStates(),
+            logMessage => if (logMessage.startsWith("Error")) {
+              stateChangeLogger.error(logMessage)
+            } else {
+              stateChangeLogger.info(logMessage)
+            })
+
+          // It would be nice if we could call duplicate() here, but we don't want to copy the
+          // old topicStates array. That would be quite costly, and we're not going to use it anyway.
+          // Instead, we copy each field that we need.
+          val originalRequestData = updateMetadataRequest.data()
+          val newData = new UpdateMetadataRequestData().
+            setControllerId(originalRequestData.controllerId()).
+            setIsKRaftController(originalRequestData.isKRaftController).
+            setType(originalRequestData.`type`()).
+            setControllerEpoch(originalRequestData.controllerEpoch()).
+            setBrokerEpoch(originalRequestData.brokerEpoch()).
+            setTopicStates(newTopicStates).
+            setLiveBrokers(originalRequestData.liveBrokers())
+          updateMetadataRequest = new UpdateMetadataRequest(newData, updateMetadataRequest.version())
         }
       }
 
@@ -491,7 +564,7 @@ class ZkMetadataCache(
       newZeroIds.foreach { case (zeroIdTopic, _) => topicIds.remove(zeroIdTopic) }
       topicIds ++= newTopicIds.toMap
 
-      val deletedPartitions = new mutable.ArrayBuffer[TopicPartition]
+      val deletedPartitions = new java.util.LinkedHashSet[TopicPartition]
       if (!updateMetadataRequest.partitionStates.iterator.hasNext) {
         metadataSnapshot = MetadataSnapshot(metadataSnapshot.partitionStates, topicIds.toMap,
           controllerIdOpt, aliveBrokers, aliveNodes)
@@ -516,9 +589,10 @@ class ZkMetadataCache(
             if (traceEnabled)
               stateChangeLogger.trace(s"Deleted partition $tp from metadata cache in response to UpdateMetadata " +
                 s"request sent by controller $controllerId epoch $controllerEpoch with correlation id $correlationId")
-            deletedPartitions += tp
+            deletedPartitions.add(tp)
           } else {
             addOrUpdatePartitionInfo(partitionStates, tp.topic, tp.partition, state)
+            deletedPartitions.remove(tp)
             if (traceEnabled)
               stateChangeLogger.trace(s"Cached leader info $state for partition $tp in response to " +
                 s"UpdateMetadata request sent by controller $controllerId epoch $controllerEpoch with correlation id $correlationId")
@@ -530,7 +604,7 @@ class ZkMetadataCache(
 
         metadataSnapshot = MetadataSnapshot(partitionStates, topicIds.toMap, controllerIdOpt, aliveBrokers, aliveNodes)
       }
-      deletedPartitions
+      deletedPartitions.asScala.toSeq
     }
   }
 
diff --git a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
index f6e4566b479..0e37a20dcba 100644
--- a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
+++ b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
@@ -36,7 +36,6 @@ import org.apache.kafka.common.utils.ThreadUtils;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.controller.Controller;
-import org.apache.kafka.metadata.bootstrap.BootstrapMetadata;
 import org.apache.kafka.raft.RaftConfig;
 import org.apache.kafka.server.common.ApiMessageAndVersion;
 import org.apache.kafka.server.common.MetadataVersion;
diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
index ee70a518910..231b80d037c 100755
--- a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala
@@ -25,8 +25,13 @@ import kafka.utils._
 import org.apache.directory.api.util.FileUtils
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.errors.OffsetOutOfRangeException
+import org.apache.kafka.common.message.LeaderAndIsrRequestData
+import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrTopicState
+import org.apache.kafka.common.requests.{AbstractControlRequest, LeaderAndIsrRequest}
 import org.apache.kafka.common.utils.Utils
-import org.apache.kafka.common.{KafkaException, TopicPartition}
+import org.apache.kafka.common.{KafkaException, TopicIdPartition, TopicPartition, Uuid}
+import org.apache.kafka.image.{TopicImage, TopicsImage}
+import org.apache.kafka.metadata.{LeaderRecoveryState, PartitionRegistration}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, BeforeEach, Test}
 import org.mockito.ArgumentMatchers.any
@@ -47,6 +52,7 @@ import scala.jdk.CollectionConverters._
 import scala.util.{Failure, Try}
 
 class LogManagerTest {
+  import LogManagerTest._
 
   val time = new MockTime()
   val maxRollInterval = 100
@@ -1010,4 +1016,190 @@ class LogManagerTest {
     assertEquals(8, invokedCount)
     assertEquals(4, failureCount)
   }
+
+  val foo0 = new TopicIdPartition(Uuid.fromString("Sl08ZXU2QW6uF5hIoSzc8w"), new TopicPartition("foo", 0))
+  val foo1 = new TopicIdPartition(Uuid.fromString("Sl08ZXU2QW6uF5hIoSzc8w"), new TopicPartition("foo", 1))
+  val bar0 = new TopicIdPartition(Uuid.fromString("69O438ZkTSeqqclTtZO2KA"), new TopicPartition("bar", 0))
+  val bar1 = new TopicIdPartition(Uuid.fromString("69O438ZkTSeqqclTtZO2KA"), new TopicPartition("bar", 1))
+  val baz0 = new TopicIdPartition(Uuid.fromString("2Ik9_5-oRDOKpSXd2SuG5w"), new TopicPartition("baz", 0))
+  val baz1 = new TopicIdPartition(Uuid.fromString("2Ik9_5-oRDOKpSXd2SuG5w"), new TopicPartition("baz", 1))
+  val baz2 = new TopicIdPartition(Uuid.fromString("2Ik9_5-oRDOKpSXd2SuG5w"), new TopicPartition("baz", 2))
+  val quux0 = new TopicIdPartition(Uuid.fromString("YS9owjv5TG2OlsvBM0Qw6g"), new TopicPartition("quux", 0))
+  val recreatedFoo0 = new TopicIdPartition(Uuid.fromString("_dOOzPe3TfiWV21Lh7Vmqg"), new TopicPartition("foo", 0))
+  val recreatedFoo1 = new TopicIdPartition(Uuid.fromString("_dOOzPe3TfiWV21Lh7Vmqg"), new TopicPartition("foo", 1))
+
+  @Test
+  def testFindStrayReplicasInEmptyImage(): Unit = {
+    val image: TopicsImage  = topicsImage(Seq())
+    val onDisk = Seq(foo0, foo1, bar0, bar1, quux0)
+    val expected = onDisk.map(_.topicPartition()).toSet
+    assertEquals(expected,
+      LogManager.findStrayReplicas(0,
+        image, onDisk.map(mockLog(_)).toSet))
+  }
+
+  @Test
+  def testFindSomeStrayReplicasInImage(): Unit = {
+    val image: TopicsImage  = topicsImage(Seq(
+      topicImage(Map(
+        foo0 -> Seq(0, 1, 2),
+      )),
+      topicImage(Map(
+        bar0 -> Seq(0, 1, 2),
+        bar1 -> Seq(0, 1, 2),
+      ))
+    ))
+    val onDisk = Seq(foo0, foo1, bar0, bar1, quux0).map(mockLog(_))
+    val expected = Set(foo1, quux0).map(_.topicPartition)
+    assertEquals(expected,
+      LogManager.findStrayReplicas(0,
+        image, onDisk).toSet)
+  }
+
+  @Test
+  def testFindSomeStrayReplicasInImageWithRemoteReplicas(): Unit = {
+    val image: TopicsImage  = topicsImage(Seq(
+      topicImage(Map(
+        foo0 -> Seq(0, 1, 2),
+      )),
+      topicImage(Map(
+        bar0 -> Seq(1, 2, 3),
+        bar1 -> Seq(2, 3, 0),
+      ))
+    ))
+    val onDisk = Seq(foo0, bar0, bar1).map(mockLog(_))
+    val expected = Set(bar0).map(_.topicPartition)
+    assertEquals(expected,
+      LogManager.findStrayReplicas(0,
+        image, onDisk).toSet)
+  }
+
+  @Test
+  def testFindStrayReplicasInEmptyLAIR(): Unit = {
+    val onDisk = Seq(foo0, foo1, bar0, bar1, baz0, baz1, baz2, quux0)
+    val expected = onDisk.map(_.topicPartition()).toSet
+    assertEquals(expected,
+      LogManager.findStrayReplicas(0,
+        createLeaderAndIsrRequestForStrayDetection(Seq()),
+          onDisk.map(mockLog(_))).toSet)
+  }
+
+  @Test
+  def testFindNoStrayReplicasInFullLAIR(): Unit = {
+    val onDisk = Seq(foo0, foo1, bar0, bar1, baz0, baz1, baz2, quux0)
+    assertEquals(Set(),
+      LogManager.findStrayReplicas(0,
+      createLeaderAndIsrRequestForStrayDetection(onDisk),
+        onDisk.map(mockLog(_))).toSet)
+  }
+
+  @Test
+  def testFindSomeStrayReplicasInFullLAIR(): Unit = {
+    val onDisk = Seq(foo0, foo1, bar0, bar1, baz0, baz1, baz2, quux0)
+    val present = Seq(foo0, bar0, bar1, quux0)
+    val expected = Seq(foo1, baz0, baz1, baz2).map(_.topicPartition()).toSet
+    assertEquals(expected,
+      LogManager.findStrayReplicas(0,
+        createLeaderAndIsrRequestForStrayDetection(present),
+        onDisk.map(mockLog(_))).toSet)
+  }
+
+  @Test
+  def testTopicRecreationInFullLAIR(): Unit = {
+    val onDisk = Seq(foo0, foo1, bar0, bar1, baz0, baz1, baz2, quux0)
+    val present = Seq(recreatedFoo0, recreatedFoo1, bar0, baz0, baz1, baz2, quux0)
+    val expected = Seq(foo0, foo1, bar1).map(_.topicPartition()).toSet
+    assertEquals(expected,
+      LogManager.findStrayReplicas(0,
+        createLeaderAndIsrRequestForStrayDetection(present),
+        onDisk.map(mockLog(_))).toSet)
+  }
 }
+
+object LogManagerTest {
+  def mockLog(
+    topicIdPartition: TopicIdPartition
+  ): UnifiedLog = {
+    val log = Mockito.mock(classOf[UnifiedLog])
+    Mockito.when(log.topicId).thenReturn(Some(topicIdPartition.topicId()))
+    Mockito.when(log.topicPartition).thenReturn(topicIdPartition.topicPartition())
+    log
+  }
+
+  def topicImage(
+    partitions: Map[TopicIdPartition, Seq[Int]]
+  ): TopicImage = {
+    var topicName: String = null
+    var topicId: Uuid = null
+    partitions.keySet.foreach {
+      partition => if (topicId == null) {
+        topicId = partition.topicId()
+      } else if (!topicId.equals(partition.topicId())) {
+        throw new IllegalArgumentException("partition topic IDs did not match")
+      }
+        if (topicName == null) {
+          topicName = partition.topic()
+        } else if (!topicName.equals(partition.topic())) {
+          throw new IllegalArgumentException("partition topic names did not match")
+        }
+    }
+    if (topicId == null) {
+      throw new IllegalArgumentException("Invalid empty partitions map.")
+    }
+    val partitionRegistrations = partitions.map { case (partition, replicas) =>
+      Int.box(partition.partition()) -> new PartitionRegistration.Builder().
+        setReplicas(replicas.toArray).
+        setIsr(replicas.toArray).
+        setLeader(replicas.head).
+        setLeaderRecoveryState(LeaderRecoveryState.RECOVERED).
+        setLeaderEpoch(0).
+        setPartitionEpoch(0).
+        build()
+    }
+    new TopicImage(topicName, topicId, partitionRegistrations.asJava)
+  }
+
+  def topicsImage(
+    topics: Seq[TopicImage]
+  ): TopicsImage = {
+    var retval = TopicsImage.EMPTY
+    topics.foreach { t => retval = retval.including(t) }
+    retval
+  }
+
+  def createLeaderAndIsrRequestForStrayDetection(
+    partitions: Iterable[TopicIdPartition],
+    leaders: Iterable[Int] = Seq(),
+  ): LeaderAndIsrRequest = {
+    val nextLeaderIter = leaders.iterator
+    def nextLeader(): Int = {
+      if (nextLeaderIter.hasNext) {
+        nextLeaderIter.next()
+      } else {
+        3
+      }
+    }
+    val data = new LeaderAndIsrRequestData().
+      setControllerId(1000).
+      setIsKRaftController(true).
+      setType(AbstractControlRequest.Type.FULL.toByte)
+    val topics = new java.util.LinkedHashMap[String, LeaderAndIsrTopicState]
+    partitions.foreach(partition => {
+      val topicState = topics.computeIfAbsent(partition.topic(),
+        _ => new LeaderAndIsrTopicState().
+          setTopicId(partition.topicId()).
+          setTopicName(partition.topic()))
+      topicState.partitionStates().add(new LeaderAndIsrRequestData.LeaderAndIsrPartitionState().
+        setTopicName(partition.topic()).
+        setPartitionIndex(partition.partition()).
+        setControllerEpoch(123).
+        setLeader(nextLeader()).
+        setLeaderEpoch(456).
+        setIsr(java.util.Arrays.asList(3, 4, 5)).
+        setReplicas(java.util.Arrays.asList(3, 4, 5)).
+        setLeaderRecoveryState(LeaderRecoveryState.RECOVERED.value()))
+    })
+    data.topicStates().addAll(topics.values())
+    new LeaderAndIsrRequest(data, 7.toShort)
+  }
+}
\ No newline at end of file
diff --git a/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala b/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala
index 8a8a3f4d38f..ecac9135ce9 100644
--- a/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala
+++ b/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala
@@ -22,7 +22,9 @@ import java.util
 import java.util.Arrays.asList
 import java.util.Collections
 import kafka.api.LeaderAndIsr
+import kafka.cluster.Broker
 import kafka.server.metadata.{KRaftMetadataCache, MetadataSnapshot, ZkMetadataCache}
+import org.apache.kafka.common.message.UpdateMetadataRequestData
 import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState, UpdateMetadataTopicState}
 import org.apache.kafka.common.network.ListenerName
 import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
@@ -851,93 +853,6 @@ class MetadataCacheTest {
     (initialTopicIds, initialTopicStates, newTopicIds, newPartitionStates)
   }
 
-  /**
-   * Verify that ZkMetadataCache#maybeInjectDeletedPartitionsFromFullMetadataRequest correctly
-   * generates deleted topic partition state when deleted topics are detected. This does not check
-   * any of the logic about when this method should be called, only that it does the correct thing
-   * when called.
-   */
-  @Test
-  def testMaybeInjectDeletedPartitionsFromFullMetadataRequest(): Unit = {
-    val (initialTopicIds, initialTopicStates, newTopicIds, _) = setupInitialAndFullMetadata()
-
-    val initialSnapshot = MetadataSnapshot(
-      partitionStates = initialTopicStates,
-      topicIds = initialTopicIds,
-      controllerId = Some(KRaftCachedControllerId(3000)),
-      aliveBrokers = mutable.LongMap.empty,
-      aliveNodes = mutable.LongMap.empty)
-
-    def verifyTopicStates(
-      updateMetadataRequest: UpdateMetadataRequest
-    )(
-      verifier: mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]] => Unit
-    ): Unit  = {
-      val finalTopicStates = mutable.AnyRefMap.empty[String, mutable.LongMap[UpdateMetadataPartitionState]]
-      updateMetadataRequest.topicStates().forEach { topicState =>
-        finalTopicStates.put(topicState.topicName(), mutable.LongMap.empty[UpdateMetadataPartitionState])
-        topicState.partitionStates().forEach { partitionState =>
-          finalTopicStates(topicState.topicName()).put(partitionState.partitionIndex(), partitionState)
-        }
-      }
-      verifier.apply(finalTopicStates)
-    }
-
-    // Empty UMR, deletes everything
-    var updateMetadataRequest = new UpdateMetadataRequest.Builder(8, 1, 42, brokerEpoch,
-      Seq.empty.asJava, Seq.empty.asJava, Map.empty[String, Uuid].asJava, true, AbstractControlRequest.Type.FULL).build()
-    assertEquals(
-      Seq(Uuid.fromString("IQ2F1tpCRoSbjfq4zBJwpg"), Uuid.fromString("4N8_J-q7SdWHPFkos275pQ")),
-      ZkMetadataCache.maybeInjectDeletedPartitionsFromFullMetadataRequest(
-        initialSnapshot, 42, updateMetadataRequest.topicStates())
-    )
-    verifyTopicStates(updateMetadataRequest) { topicStates =>
-      assertEquals(2, topicStates.size)
-      assertEquals(3, topicStates("test-topic-1").values.toSeq.count(_.leader() == -2))
-      assertEquals(3, topicStates("test-topic-2").values.toSeq.count(_.leader() == -2))
-    }
-
-    // One different topic, should remove other two
-    val oneTopicPartitionState = Seq(new UpdateMetadataPartitionState()
-      .setTopicName("different-topic")
-      .setPartitionIndex(0)
-      .setControllerEpoch(42)
-      .setLeader(0)
-      .setLeaderEpoch(10)
-      .setIsr(asList[Integer](0, 1, 2))
-      .setZkVersion(1)
-      .setReplicas(asList[Integer](0, 1, 2)))
-    updateMetadataRequest = new UpdateMetadataRequest.Builder(8, 1, 42, brokerEpoch,
-      oneTopicPartitionState.asJava, Seq.empty.asJava, newTopicIds.asJava, true, AbstractControlRequest.Type.FULL).build()
-    assertEquals(
-      Seq(Uuid.fromString("IQ2F1tpCRoSbjfq4zBJwpg"), Uuid.fromString("4N8_J-q7SdWHPFkos275pQ")),
-      ZkMetadataCache.maybeInjectDeletedPartitionsFromFullMetadataRequest(
-        initialSnapshot, 42, updateMetadataRequest.topicStates())
-    )
-    verifyTopicStates(updateMetadataRequest) { topicStates =>
-      assertEquals(3, topicStates.size)
-      assertEquals(3, topicStates("test-topic-1").values.toSeq.count(_.leader() == -2))
-      assertEquals(3, topicStates("test-topic-2").values.toSeq.count(_.leader() == -2))
-    }
-
-    // Existing two plus one new topic, nothing gets deleted, all topics should be present
-    val allTopicStates = initialTopicStates.flatMap(_._2.values).toSeq ++ oneTopicPartitionState
-    val allTopicIds = initialTopicIds ++ newTopicIds
-    updateMetadataRequest = new UpdateMetadataRequest.Builder(8, 1, 42, brokerEpoch,
-      allTopicStates.asJava, Seq.empty.asJava, allTopicIds.asJava, true, AbstractControlRequest.Type.FULL).build()
-    assertEquals(
-      Seq.empty,
-      ZkMetadataCache.maybeInjectDeletedPartitionsFromFullMetadataRequest(
-        initialSnapshot, 42, updateMetadataRequest.topicStates())
-    )
-    verifyTopicStates(updateMetadataRequest) { topicStates =>
-      assertEquals(3, topicStates.size)
-      // Ensure these two weren't deleted (leader = -2)
-      assertEquals(0, topicStates("test-topic-1").values.toSeq.count(_.leader() == -2))
-      assertEquals(0, topicStates("test-topic-2").values.toSeq.count(_.leader() == -2))
-    }
-  }
-
   /**
    * Verify the behavior of ZkMetadataCache when handling "Full" UpdateMetadataRequest
    */
@@ -1006,4 +921,412 @@ class MetadataCacheTest {
       assertTrue(cache.contains("test-topic-1"))
     }
   }
+
+  val oldRequestControllerEpoch: Int = 122
+  val newRequestControllerEpoch: Int = 123
+
+  val fooTopicName: String = "foo"
+  val fooTopicId: Uuid = Uuid.fromString("HDceyWK0Ry-j3XLR8DvvGA")
+  val oldFooPart0 = new UpdateMetadataPartitionState().
+    setTopicName(fooTopicName).
+    setPartitionIndex(0).
+    setControllerEpoch(oldRequestControllerEpoch).
+    setLeader(4).
+    setIsr(java.util.Arrays.asList(4, 5, 6)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val newFooPart0 = new UpdateMetadataPartitionState().
+    setTopicName(fooTopicName).
+    setPartitionIndex(0).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(5).
+    setIsr(java.util.Arrays.asList(4, 5, 6)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val oldFooPart1 = new UpdateMetadataPartitionState().
+    setTopicName(fooTopicName).
+    setPartitionIndex(1).
+    setControllerEpoch(oldRequestControllerEpoch).
+    setLeader(5).
+    setIsr(java.util.Arrays.asList(4, 5, 6)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val newFooPart1 = new UpdateMetadataPartitionState().
+    setTopicName(fooTopicName).
+    setPartitionIndex(1).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(5).
+    setIsr(java.util.Arrays.asList(4, 5)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+
+  val barTopicName: String = "bar"
+  val barTopicId: Uuid = Uuid.fromString("97FBD1g4QyyNNZNY94bkRA")
+  val recreatedBarTopicId: Uuid = Uuid.fromString("lZokxuaPRty7c5P4dNdTYA")
+  val oldBarPart0 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(0).
+    setControllerEpoch(oldRequestControllerEpoch).
+    setLeader(7).
+    setIsr(java.util.Arrays.asList(7, 8)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(7, 8, 9)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val newBarPart0 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(0).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(7).
+    setIsr(java.util.Arrays.asList(7, 8)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(7, 8, 9)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val deletedBarPart0 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(0).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(-2).
+    setIsr(java.util.Arrays.asList(7, 8)).
+    setZkVersion(0).
+    setReplicas(java.util.Arrays.asList(7, 8, 9)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val oldBarPart1 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(1).
+    setControllerEpoch(oldRequestControllerEpoch).
+    setLeader(5).
+    setIsr(java.util.Arrays.asList(4, 5, 6)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val newBarPart1 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(1).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(5).
+    setIsr(java.util.Arrays.asList(4, 5, 6)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val deletedBarPart1 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(1).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(-2).
+    setIsr(java.util.Arrays.asList(4, 5, 6)).
+    setZkVersion(0).
+    setReplicas(java.util.Arrays.asList(4, 5, 6)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val oldBarPart2 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(2).
+    setControllerEpoch(oldRequestControllerEpoch).
+    setLeader(9).
+    setIsr(java.util.Arrays.asList(7, 8, 9)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(7, 8, 9)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val newBarPart2 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(2).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(8).
+    setIsr(java.util.Arrays.asList(7, 8)).
+    setZkVersion(789).
+    setReplicas(java.util.Arrays.asList(7, 8, 9)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+  val deletedBarPart2 = new UpdateMetadataPartitionState().
+    setTopicName(barTopicName).
+    setPartitionIndex(2).
+    setControllerEpoch(newRequestControllerEpoch).
+    setLeader(-2).
+    setIsr(java.util.Arrays.asList(7, 8, 9)).
+    setZkVersion(0).
+    setReplicas(java.util.Arrays.asList(7, 8, 9)).
+    setOfflineReplicas(java.util.Collections.emptyList())
+
+  @Test
+  def testCreateDeletionEntries(): Unit = {
+    assertEquals(new UpdateMetadataTopicState().
+      setTopicName(fooTopicName).
+      setTopicId(fooTopicId).
+      setPartitionStates(Seq(
+        new UpdateMetadataPartitionState().
+          setTopicName(fooTopicName).
+          setPartitionIndex(0).
+          setControllerEpoch(newRequestControllerEpoch).
+          setLeader(-2).
+          setIsr(java.util.Arrays.asList(4, 5, 6)).
+          setZkVersion(0).
+          setReplicas(java.util.Arrays.asList(4, 5, 6)).
+          setOfflineReplicas(java.util.Collections.emptyList()),
+        new UpdateMetadataPartitionState().
+          setTopicName(fooTopicName).
+          setPartitionIndex(1).
+          setControllerEpoch(newRequestControllerEpoch).
+          setLeader(-2).
+          setIsr(java.util.Arrays.asList(4, 5, 6)).
+          setZkVersion(0).
+          setReplicas(java.util.Arrays.asList(4, 5, 6)).
+          setOfflineReplicas(java.util.Collections.emptyList())
+      ).asJava),
+    ZkMetadataCache.createDeletionEntries(fooTopicName,
+      fooTopicId,
+      Seq(oldFooPart0, oldFooPart1),
+      newRequestControllerEpoch))
+  }
+
+  val prevSnapshot: MetadataSnapshot = {
+    val parts = new mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]]
+    val fooParts = new mutable.LongMap[UpdateMetadataPartitionState]
+    fooParts.put(0L, oldFooPart0)
+    fooParts.put(1L, oldFooPart1)
+    parts.put(fooTopicName, fooParts)
+    val barParts = new mutable.LongMap[UpdateMetadataPartitionState]
+    barParts.put(0L, oldBarPart0)
+    barParts.put(1L, oldBarPart1)
+    barParts.put(2L, oldBarPart2)
+    parts.put(barTopicName, barParts)
+    MetadataSnapshot(parts,
+      Map[String, Uuid](
+        fooTopicName -> fooTopicId,
+        barTopicName -> barTopicId
+      ),
+      Some(KRaftCachedControllerId(1)),
+      mutable.LongMap[Broker](),
+      mutable.LongMap[collection.Map[ListenerName, Node]]()
+    )
+  }
+
+  def transformKRaftControllerFullMetadataRequest(
+    currentMetadata: MetadataSnapshot,
+    requestControllerEpoch: Int,
+    requestTopicStates: util.List[UpdateMetadataTopicState],
+  ): (util.List[UpdateMetadataTopicState], util.List[String]) = {
+
+    val logs = new util.ArrayList[String]
+    val results = ZkMetadataCache.transformKRaftControllerFullMetadataRequest(
+      currentMetadata, requestControllerEpoch, requestTopicStates, log => logs.add(log))
+    (results, logs)
+  }
+
+  @Test
+  def transformUMRWithNoChanges(): Unit = {
+    assertEquals((Seq(
+        new UpdateMetadataTopicState().
+          setTopicName(fooTopicName).
+          setTopicId(fooTopicId).
+          setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+        new UpdateMetadataTopicState().
+          setTopicName(barTopicName).
+          setTopicId(barTopicId).
+          setPartitionStates(Seq(newBarPart0, newBarPart1, newBarPart2).asJava)
+      ).asJava,
+      List[String]().asJava),
+      transformKRaftControllerFullMetadataRequest(prevSnapshot,
+        newRequestControllerEpoch,
+        Seq(
+          new UpdateMetadataTopicState().
+            setTopicName(fooTopicName).
+            setTopicId(fooTopicId).
+            setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+          new UpdateMetadataTopicState().
+            setTopicName(barTopicName).
+            setTopicId(barTopicId).
+            setPartitionStates(Seq(newBarPart0, newBarPart1, newBarPart2).asJava)
+        ).asJava
+      )
+    )
+  }
+
+  @Test
+  def transformUMRWithMissingBar(): Unit = {
+    assertEquals((Seq(
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(barTopicId).
+        setPartitionStates(Seq(deletedBarPart0, deletedBarPart1, deletedBarPart2).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(fooTopicName).
+        setTopicId(fooTopicId).
+        setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+    ).asJava,
+      List[String](
+        "Removing topic bar with ID 97FBD1g4QyyNNZNY94bkRA from the metadata cache since the full UMR did not include it.",
+      ).asJava),
+      transformKRaftControllerFullMetadataRequest(prevSnapshot,
+        newRequestControllerEpoch,
+        Seq(
+          new UpdateMetadataTopicState().
+            setTopicName(fooTopicName).
+            setTopicId(fooTopicId).
+            setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+        ).asJava
+      )
+    )
+  }
+
+  @Test
+  def transformUMRWithRecreatedBar(): Unit = {
+    assertEquals((Seq(
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(barTopicId).
+        setPartitionStates(Seq(deletedBarPart0, deletedBarPart1, deletedBarPart2).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(fooTopicName).
+        setTopicId(fooTopicId).
+        setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(recreatedBarTopicId).
+        setPartitionStates(Seq(newBarPart0, newBarPart1, newBarPart2).asJava),
+    ).asJava,
+      List[String](
+        "Removing topic bar with ID 97FBD1g4QyyNNZNY94bkRA from the metadata cache since the full UMR did not include it.",
+      ).asJava),
+      transformKRaftControllerFullMetadataRequest(prevSnapshot,
+        newRequestControllerEpoch,
+        Seq(
+          new UpdateMetadataTopicState().
+            setTopicName(fooTopicName).
+            setTopicId(fooTopicId).
+            setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+          new UpdateMetadataTopicState().
+            setTopicName(barTopicName).
+            setTopicId(recreatedBarTopicId).
+            setPartitionStates(Seq(newBarPart0, newBarPart1, newBarPart2).asJava)
+        ).asJava
+      )
+    )
+  }
+
+  val buggySnapshot: MetadataSnapshot = new MetadataSnapshot(
+    new mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]],
+    prevSnapshot.topicIds,
+    prevSnapshot.controllerId,
+    prevSnapshot.aliveBrokers,
+    prevSnapshot.aliveNodes)
+
+  @Test
+  def transformUMRWithBuggySnapshot(): Unit = {
+    assertEquals((Seq(
+      new UpdateMetadataTopicState().
+        setTopicName(fooTopicName).
+        setTopicId(fooTopicId).
+        setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(barTopicId).
+        setPartitionStates(Seq(newBarPart0, newBarPart1, newBarPart2).asJava),
+    ).asJava,
+      List[String](
+        "Error: topic foo appeared in currentMetadata.topicNames, but not in currentMetadata.partitionStates.",
+        "Error: topic bar appeared in currentMetadata.topicNames, but not in currentMetadata.partitionStates.",
+      ).asJava),
+      transformKRaftControllerFullMetadataRequest(buggySnapshot,
+        newRequestControllerEpoch,
+        Seq(
+          new UpdateMetadataTopicState().
+            setTopicName(fooTopicName).
+            setTopicId(fooTopicId).
+            setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+          new UpdateMetadataTopicState().
+            setTopicName(barTopicName).
+            setTopicId(barTopicId).
+            setPartitionStates(Seq(newBarPart0, newBarPart1, newBarPart2).asJava)
+        ).asJava
+      )
+    )
+  }
+
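+  // A full UMR from a KRaft controller should fully populate an empty ZK metadata cache.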
+  @Test
+  def testUpdateZkMetadataCacheViaHybridUMR(): Unit = {
+    val cache = MetadataCache.zkMetadataCache(1, MetadataVersion.latest())
+    cache.updateMetadata(123, createFullUMR(Seq(
+      new UpdateMetadataTopicState().
+        setTopicName(fooTopicName).
+        setTopicId(fooTopicId).
+        setPartitionStates(Seq(oldFooPart0, oldFooPart1).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(barTopicId).
+        setPartitionStates(Seq(oldBarPart0, oldBarPart1).asJava),
+    )))
+    checkCacheContents(cache, Map(
+      fooTopicId -> Seq(oldFooPart0, oldFooPart1),
+      barTopicId -> Seq(oldBarPart0, oldBarPart1),
+    ))
+  }
+
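+  // A second full UMR with updated partition states for foo should replace foo's cached
+  // entries while leaving bar's unchanged.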
+  @Test
+  def testUpdateZkMetadataCacheWithRecreatedTopic(): Unit = {
+    val cache = MetadataCache.zkMetadataCache(1, MetadataVersion.latest())
+    cache.updateMetadata(123, createFullUMR(Seq(
+      new UpdateMetadataTopicState().
+        setTopicName(fooTopicName).
+        setTopicId(fooTopicId).
+        setPartitionStates(Seq(oldFooPart0, oldFooPart1).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(barTopicId).
+        setPartitionStates(Seq(oldBarPart0, oldBarPart1).asJava),
+    )))
+    cache.updateMetadata(124, createFullUMR(Seq(
+      new UpdateMetadataTopicState().
+        setTopicName(fooTopicName).
+        setTopicId(fooTopicId).
+        setPartitionStates(Seq(newFooPart0, newFooPart1).asJava),
+      new UpdateMetadataTopicState().
+        setTopicName(barTopicName).
+        setTopicId(barTopicId).
+        setPartitionStates(Seq(oldBarPart0, oldBarPart1).asJava),
+    )))
+    checkCacheContents(cache, Map(
+      fooTopicId -> Seq(newFooPart0, newFooPart1),
+      barTopicId -> Seq(oldBarPart0, oldBarPart1),
+    ))
+  }
+
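+  // Builds a full UpdateMetadataRequest of the kind a KRaft controller sends in hybrid mode.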
+  def createFullUMR(
+    topicStates: Seq[UpdateMetadataTopicState]
+  ): UpdateMetadataRequest = {
+    val data = new UpdateMetadataRequestData().
+      setControllerId(0).
+      setIsKRaftController(true).
+      setControllerEpoch(123).
+      setBrokerEpoch(456).
+      setTopicStates(topicStates.asJava)
+    new UpdateMetadataRequest(data, 8.toShort)
+  }
+
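+  // Checks that the cache's name-to-ID, ID-to-name, and partition mappings contain exactly
+  // the expected topics and nothing else.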
+  def checkCacheContents(
+    cache: ZkMetadataCache,
+    expected: Map[Uuid, Iterable[UpdateMetadataPartitionState]],
+  ): Unit = {
+    val expectedTopics = new util.HashMap[String, Uuid]
+    val expectedIds = new util.HashMap[Uuid, String]
+    val expectedParts = new util.HashMap[String, util.Set[TopicPartition]]
+    expected.foreach {
+      case (id, states) =>
+        states.foreach { state =>
+          expectedTopics.put(state.topicName(), id)
+          expectedIds.put(id, state.topicName())
+          expectedParts.computeIfAbsent(state.topicName(),
+            _ => new util.HashSet[TopicPartition]()).
+            add(new TopicPartition(state.topicName(), state.partitionIndex()))
+        }
+    }
+    assertEquals(expectedTopics, cache.topicNamesToIds())
+    assertEquals(expectedIds, cache.topicIdsToNames())
+    cache.getAllTopics().foreach(topic =>
+      assertEquals(expectedParts.getOrDefault(topic, Collections.emptySet()),
+        cache.getTopicPartitions(topic).asJava)
+    )
+  }
 }
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 70ee501d566..4720f2cab01 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -2735,9 +2735,16 @@ class ReplicaManagerTest {
     createHostedLogs("hosted-stray", numLogs = 10, replicaManager).toSet
     createStrayLogs(10, logManager)
 
-    val allReplicasFromLISR = Set(new TopicPartition("hosted-topic", 0), new TopicPartition("hosted-topic", 1))
+    val allReplicasFromLISR = Set(
+      new TopicPartition("hosted-topic", 0),
+      new TopicPartition("hosted-topic", 1)
+    ).map(p => new TopicIdPartition(new Uuid(p.topic().hashCode, p.topic().hashCode), p))
 
-    replicaManager.updateStrayLogs(replicaManager.findStrayPartitionsFromLeaderAndIsr(allReplicasFromLISR))
+    replicaManager.updateStrayLogs(
+      LogManager.findStrayReplicas(
+        config.nodeId,
+        LogManagerTest.createLeaderAndIsrRequestForStrayDetection(allReplicasFromLISR),
+        logManager.allLogs))
 
     assertEquals(validLogs, logManager.allLogs.toSet)
     assertEquals(validLogs.size, replicaManager.partitionCount.value)
@@ -2751,7 +2758,7 @@ class ReplicaManagerTest {
       val topicPartition = new TopicPartition(name, i)
       val partition = replicaManager.createPartition(topicPartition)
       partition.createLogIfNotExists(isNew = true, isFutureReplica = false,
-        new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), topicId = None)
+        new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), topicId = Some(new Uuid(name.hashCode, name.hashCode)))
       partition.log.get
     }
   }
@@ -2759,7 +2766,7 @@ class ReplicaManagerTest {
   private def createStrayLogs(numLogs: Int, logManager: LogManager): Seq[UnifiedLog] = {
     val name = "stray"
     for (i <- 0 until numLogs)
-      yield logManager.getOrCreateLog(new TopicPartition(name, i), topicId = None)
+      yield logManager.getOrCreateLog(new TopicPartition(name, i), topicId = Some(new Uuid(name.hashCode, name.hashCode)))
   }
 
   private def sendProducerAppend(
@@ -3229,6 +3236,9 @@ class ReplicaManagerTest {
     val path2 = TestUtils.tempRelativeDir("data2").getAbsolutePath
     props.put("log.dirs", path1 + "," + path2)
     propsModifier.apply(props)
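+    // Migration mode requires a single log dir here, so drop the second path in that case.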
+    if ("true".equals(props.getProperty(KafkaConfig.MigrationEnabledProp))) {
+      props.put("log.dirs", path1)
+    }
     val config = KafkaConfig.fromProps(props)
     val logProps = new Properties()
     val mockLog = setupMockLog(path1)
@@ -5779,6 +5789,101 @@ class ReplicaManagerTest {
 
     verify(spyRm).checkpointHighWatermarks()
   }
+
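+  // newFoo0 reuses the name "foo" under a different topic ID, modelling a topic that was
+  // deleted and re-created while the broker was in hybrid (migration) mode.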
+  val foo0 = new TopicIdPartition(Uuid.fromString("Sl08ZXU2QW6uF5hIoSzc8w"), new TopicPartition("foo", 0))
+  val foo1 = new TopicIdPartition(Uuid.fromString("Sl08ZXU2QW6uF5hIoSzc8w"), new TopicPartition("foo", 1))
+  val newFoo0 = new TopicIdPartition(Uuid.fromString("JRCmVxWxQamFs4S8NXYufg"), new TopicPartition("foo", 0))
+  val bar0 = new TopicIdPartition(Uuid.fromString("69O438ZkTSeqqclTtZO2KA"), new TopicPartition("bar", 0))
+
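+  // Brings up a ReplicaManager configured for ZK-to-KRaft migration, so that full
+  // LeaderAndIsrRequests from the KRaft controller trigger stray-replica handling.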
+  def setupReplicaManagerForKRaftMigrationTest(): ReplicaManager = {
+    setupReplicaManagerWithMockedPurgatories(
+      brokerId = 3,
+      timer = new MockTimer(time),
+      aliveBrokerIds = Seq(0, 1, 2),
+      propsModifier = props => {
+        props.setProperty(KafkaConfig.MigrationEnabledProp, "true")
+        props.setProperty(KafkaConfig.QuorumVotersProp, "1000@localhost:9093")
+        props.setProperty(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER")
+        props.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, "CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT")
+      })
+  }
+
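+  // Asserts that the partition is hosted online and that its log carries the expected topic ID.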
+  def verifyPartitionIsOnlineAndHasId(
+    replicaManager: ReplicaManager,
+    topicIdPartition: TopicIdPartition
+  ): Unit = {
+    val partition = replicaManager.getPartition(topicIdPartition.topicPartition())
+    assertTrue(partition.isInstanceOf[HostedPartition.Online],
+      s"Expected ${topicIdPartition} to be in state: HostedPartition.Online. But was in state: ${partition}")
+    val hostedPartition = partition.asInstanceOf[HostedPartition.Online]
+    assertTrue(hostedPartition.partition.log.isDefined,
+      s"Expected ${topicIdPartition} to have a log set in ReplicaManager, but it did not.")
+    assertTrue(hostedPartition.partition.log.get.topicId.isDefined,
+      s"Expected the log for ${topicIdPartition} to topic ID set in LogManager, but it did not.")
+    assertEquals(topicIdPartition.topicId(), hostedPartition.partition.log.get.topicId.get)
+    assertEquals(topicIdPartition.topicPartition(), hostedPartition.partition.topicPartition)
+  }
+
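+  // Asserts that the ReplicaManager no longer hosts the partition at all.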
+  def verifyPartitionIsOffline(
+    replicaManager: ReplicaManager,
+    topicIdPartition: TopicIdPartition
+  ): Unit = {
+    val partition = replicaManager.getPartition(topicIdPartition.topicPartition())
+    assertEquals(HostedPartition.None, partition, s"Expected ${topicIdPartition} to be offline, but it was: ${partition}")
+  }
+
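+  // A full LAIR naming every hosted partition should leave all replicas online.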
+  @Test
+  def testFullLairDuringKRaftMigration(): Unit = {
+    val replicaManager = setupReplicaManagerForKRaftMigrationTest()
+    try {
+      val becomeLeaderRequest = LogManagerTest.createLeaderAndIsrRequestForStrayDetection(
+        Seq(foo0, foo1, bar0), Seq(3, 4, 3))
+      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
+      verifyPartitionIsOnlineAndHasId(replicaManager, foo0)
+      verifyPartitionIsOnlineAndHasId(replicaManager, foo1)
+      verifyPartitionIsOnlineAndHasId(replicaManager, bar0)
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
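+  // A later full LAIR that omits foo should remove both foo partitions as strays while
+  // bar remains online.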
+  @Test
+  def testFullLairDuringKRaftMigrationRemovesOld(): Unit = {
+    val replicaManager = setupReplicaManagerForKRaftMigrationTest()
+    try {
+      val becomeLeaderRequest1 = LogManagerTest.createLeaderAndIsrRequestForStrayDetection(
+        Seq(foo0, foo1, bar0), Seq(3, 4, 3))
+      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest1, (_, _) => ())
+      val becomeLeaderRequest2 = LogManagerTest.createLeaderAndIsrRequestForStrayDetection(
+        Seq(bar0), Seq(3, 4, 3))
+      replicaManager.becomeLeaderOrFollower(2, becomeLeaderRequest2, (_, _) => ())
+
+      verifyPartitionIsOffline(replicaManager, foo0)
+      verifyPartitionIsOffline(replicaManager, foo1)
+      verifyPartitionIsOnlineAndHasId(replicaManager, bar0)
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
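+  // Re-creation case: the second full LAIR carries foo-0 under a new topic ID and omits
+  // foo-1, so the old foo replicas are removed and the new foo-0 comes online.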
+  @Test
+  def testFullLairDuringKRaftMigrationWithTopicRecreations(): Unit = {
+    val replicaManager = setupReplicaManagerForKRaftMigrationTest()
+    try {
+      val becomeLeaderRequest1 = LogManagerTest.createLeaderAndIsrRequestForStrayDetection(
+        Seq(foo0, foo1, bar0), Seq(3, 4, 3))
+      replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest1, (_, _) => ())
+      val becomeLeaderRequest2 = LogManagerTest.createLeaderAndIsrRequestForStrayDetection(
+        Seq(newFoo0, bar0), Seq(3, 4, 3))
+      replicaManager.becomeLeaderOrFollower(2, becomeLeaderRequest2, (_, _) => ())
+
+      verifyPartitionIsOnlineAndHasId(replicaManager, newFoo0)
+      verifyPartitionIsOffline(replicaManager, foo1)
+      verifyPartitionIsOnlineAndHasId(replicaManager, bar0)
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
 }
 
 class MockReplicaSelector extends ReplicaSelector {
diff --git a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala
index 136fb87e894..b271f433d22 100644
--- a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala
+++ b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala
@@ -22,7 +22,7 @@ import kafka.coordinator.transaction.TransactionCoordinator
 import java.util.Collections.{singleton, singletonList, singletonMap}
 import java.util.Properties
 import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
-import kafka.log.{LogManager, UnifiedLog}
+import kafka.log.LogManager
 import kafka.server.{BrokerServer, KafkaConfig, ReplicaManager}
 import kafka.testkit.{KafkaClusterTestKit, TestKitNodes}
 import kafka.utils.TestUtils
@@ -31,12 +31,9 @@ import org.apache.kafka.clients.admin.{Admin, AlterConfigOp, ConfigEntry, NewTop
 import org.apache.kafka.common.config.ConfigResource
 import org.apache.kafka.common.config.ConfigResource.Type.BROKER
 import org.apache.kafka.common.utils.Exit
-import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.coordinator.group.GroupCoordinator
-import org.apache.kafka.image.{MetadataDelta, MetadataImage, MetadataImageTest, MetadataProvenance, TopicImage, TopicsImage}
+import org.apache.kafka.image.{MetadataDelta, MetadataImage, MetadataImageTest, MetadataProvenance}
 import org.apache.kafka.image.loader.LogDeltaManifest
-import org.apache.kafka.metadata.LeaderRecoveryState
-import org.apache.kafka.metadata.PartitionRegistration
 import org.apache.kafka.raft.LeaderAndEpoch
 import org.apache.kafka.server.fault.FaultHandler
 import org.junit.jupiter.api.Assertions.{assertEquals, assertNotNull, assertTrue}
@@ -87,99 +84,6 @@ class BrokerMetadataPublisherTest {
       MetadataImageTest.DELTA1).isDefined, "Expected to see delta for changed topic")
   }
 
-  @Test
-  def testFindStrayReplicas(): Unit = {
-    val brokerId = 0
-
-    // Topic has been deleted
-    val deletedTopic = "a"
-    val deletedTopicId = Uuid.randomUuid()
-    val deletedTopicPartition1 = new TopicPartition(deletedTopic, 0)
-    val deletedTopicLog1 = mockLog(deletedTopicId, deletedTopicPartition1)
-    val deletedTopicPartition2 = new TopicPartition(deletedTopic, 1)
-    val deletedTopicLog2 = mockLog(deletedTopicId, deletedTopicPartition2)
-
-    // Topic was deleted and recreated
-    val recreatedTopic = "b"
-    val recreatedTopicPartition = new TopicPartition(recreatedTopic, 0)
-    val recreatedTopicLog = mockLog(Uuid.randomUuid(), recreatedTopicPartition)
-    val recreatedTopicImage = topicImage(Uuid.randomUuid(), recreatedTopic, Map(
-      recreatedTopicPartition.partition -> Seq(0, 1, 2)
-    ))
-
-    // Topic exists, but some partitions were reassigned
-    val reassignedTopic = "c"
-    val reassignedTopicId = Uuid.randomUuid()
-    val reassignedTopicPartition = new TopicPartition(reassignedTopic, 0)
-    val reassignedTopicLog = mockLog(reassignedTopicId, reassignedTopicPartition)
-    val retainedTopicPartition = new TopicPartition(reassignedTopic, 1)
-    val retainedTopicLog = mockLog(reassignedTopicId, retainedTopicPartition)
-
-    val reassignedTopicImage = topicImage(reassignedTopicId, reassignedTopic, Map(
-      reassignedTopicPartition.partition -> Seq(1, 2, 3),
-      retainedTopicPartition.partition -> Seq(0, 2, 3)
-    ))
-
-    val logs = Seq(
-      deletedTopicLog1,
-      deletedTopicLog2,
-      recreatedTopicLog,
-      reassignedTopicLog,
-      retainedTopicLog
-    )
-
-    val image = topicsImage(Seq(
-      recreatedTopicImage,
-      reassignedTopicImage
-    ))
-
-    val expectedStrayPartitions = Set(
-      deletedTopicPartition1,
-      deletedTopicPartition2,
-      recreatedTopicPartition,
-      reassignedTopicPartition
-    )
-
-    val strayPartitions = BrokerMetadataPublisher.findStrayPartitions(brokerId, image, logs).toSet
-    assertEquals(expectedStrayPartitions, strayPartitions)
-  }
-
-  private def mockLog(
-    topicId: Uuid,
-    topicPartition: TopicPartition
-  ): UnifiedLog = {
-    val log = Mockito.mock(classOf[UnifiedLog])
-    Mockito.when(log.topicId).thenReturn(Some(topicId))
-    Mockito.when(log.topicPartition).thenReturn(topicPartition)
-    log
-  }
-
-  private def topicImage(
-    topicId: Uuid,
-    topic: String,
-    partitions: Map[Int, Seq[Int]]
-  ): TopicImage = {
-    val partitionRegistrations = partitions.map { case (partitionId, replicas) =>
-      Int.box(partitionId) -> new PartitionRegistration.Builder().
-        setReplicas(replicas.toArray).
-        setIsr(replicas.toArray).
-        setLeader(replicas.head).
-        setLeaderRecoveryState(LeaderRecoveryState.RECOVERED).
-        setLeaderEpoch(0).
-        setPartitionEpoch(0).
-        build()
-    }
-    new TopicImage(topic, topicId, partitionRegistrations.asJava)
-  }
-
-  private def topicsImage(
-    topics: Seq[TopicImage]
-  ): TopicsImage = {
-    var retval = TopicsImage.EMPTY
-    topics.foreach { t => retval = retval.including(t) }
-    retval
-  }
-
   private def newMockDynamicConfigPublisher(
     broker: BrokerServer,
     errorHandler: FaultHandler