Posted to commits@kafka.apache.org by da...@apache.org on 2021/11/15 09:09:20 UTC

[kafka] branch 3.1 updated: KAFKA-13111: Re-evaluate Fetch Sessions when using topic IDs (#11331)

This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.1 by this push:
     new f424324  KAFKA-13111: Re-evaluate Fetch Sessions when using topic IDs (#11331)
f424324 is described below

commit f424324fd9f39ce81184fe868194b713ebea03b9
Author: Justine Olshan <jo...@confluent.io>
AuthorDate: Mon Nov 15 01:04:43 2021 -0800

    KAFKA-13111: Re-evaluate Fetch Sessions when using topic IDs (#11331)
    
    With the changes for topic IDs, we have a different flow. When a broker receives a request, it uses a map to convert topic IDs to topic names. If a topic ID is not found in the map, we return a top-level error and close the session. This decision was motivated by the difficulty of storing “unresolved” partitions in the session. In earlier iterations we stored an “unresolved” partition object in the cache, but it was somewhat hard to reason about and required extra logic to try to r [...]
    
    One helpful simplifying factor is that we only allow one type of request (using topic IDs or not using them) in a session. That means we can rely on a session continuing to carry the same kind of information: we don’t have to convert topics only known by name to topic IDs for a response, nor topics only known by ID to names.
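    
    As a hedged sketch of that gate (the names request.version(), session.usesTopicIds(), session.id() and throttleTimeMs are hypothetical, not the actual server code, which lives in FetchSession.scala): a request whose topic ID usage differs from the session's is answered with a top-level FETCH_SESSION_TOPIC_ID_ERROR.
    
        // Sessions created with topic IDs (v13+) only accept topic-ID requests,
        // and vice versa; a mismatch ends the session with a top-level error.
        boolean requestUsesTopicIds = request.version() >= 13;
        if (session.usesTopicIds() != requestUsesTopicIds) {
            return FetchResponse.of(Errors.FETCH_SESSION_TOPIC_ID_ERROR,
                throttleTimeMs, session.id(), new LinkedHashMap<>());
        }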
    
    This PR changes the approach to store the "unresolved" partitions in the cached partition object. If a version 13+ request is sent with a topic ID that is unknown, a cached partition will be created with that fetch request data and a null topic name. On subsequent incremental requests, unresolved partitions may be resolved with the new IDs found in the metadata cache. When handling the request, getting all partitions will return a TopicIdPartition object that will be used to handle the [...]
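    
    A minimal sketch of that cached-partition shape (class and field names here are hypothetical; the real change is to CachedPartition in FetchSession.scala):
    
        import java.util.Map;
        import org.apache.kafka.common.Uuid;
    
        // A cached partition keyed by topic ID may carry a null topic name until
        // a later request's metadata can resolve it.
        final class UnresolvedCachedPartition {
            final Uuid topicId;
            final int partition;
            String topic; // null while the topic ID is unknown to this broker
    
            UnresolvedCachedPartition(Uuid topicId, int partition, String topic) {
                this.topicId = topicId;
                this.partition = partition;
                this.topic = topic;
            }
    
            // Called on subsequent incremental requests with the broker's current
            // ID-to-name mapping; fills in the name once it becomes known.
            void maybeResolveName(Map<Uuid, String> topicNames) {
                if (topic == null) {
                    topic = topicNames.get(topicId);
                }
            }
        }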
    
    This PR involves changes in both FetchSessionHandler and FetchSession. Some major changes are outlined below.
    
    1. FetchSessionHandler: Forgetting a topic and adding a new topic with the same name - We may have a case where there is a topic foo with ID 1 in the session. Upon a subsequent metadata update, we may have topic foo with ID 2, which means topic foo has been deleted and recreated. When sending fetch requests with version 13+, we will send a request to add foo with ID 2 to the session and remove foo with ID 1. Otherwise, we fall back to the same behavior as versions 12 and below.
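    
    The client-side detection is essentially the check below, condensed from the FetchSessionHandler.Builder change in this patch:
    
        // A partition whose topic ID changed between requests (and where both IDs
        // are known) was deleted and recreated: re-send it in the fetch set and
        // record the old ID in the "replaced" set so v13+ requests forget it.
        if (!prevData.topicId.equals(nextData.topicId)
                && !prevData.topicId.equals(Uuid.ZERO_UUID)
                && !nextData.topicId.equals(Uuid.ZERO_UUID)) {
            next.put(topicPartition, nextData);
            replaced.add(new TopicIdPartition(prevData.topicId, topicPartition));
        }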
    
    2. FetchSession: Resolving in Incremental Sessions - Incremental sessions contain two distinct sets of partitions: the partitions sent in the latest request (new, updated, or forgotten partitions) and the partitions already in the session. If we want to resolve unknown topic IDs, we need to handle both cases.
        * Partitions in the request - these partitions are either new or update/forget previous partitions in the session. The new partitions are trivial: we either have a resolved partition or create one that is unresolved. For the other cases, we need to be a bit more careful.
            * For updated partitions we have a few cases (keep in mind that we may not programmatically know whether a partition is an update):
                1. partition in session is resolved, update is resolved: trivial

                2. partition in session is unresolved, update is unresolved: in code, this is equivalent to the case above, so trivial as well

                3. partition in session is unresolved, update is resolved: this means the partition in the session does not have a name, but the metadata cache now contains the name. To fix this, we check whether a cached partition with the given ID exists and update it with both the partition update and the topic name.

                4. partition in session is resolved, update is unresolved: this means the partition in the session has a name, but the update could not be resolved (i.e., the topic is deleted). This is the odd case. We will look up the partition using the ID and find the old version with a name, but we will not replace the name. This will lead to an UNKNOWN_TOPIC_OR_PARTITION or INCONSISTENT_TOPIC_ID error, which will be handled with a metadata update. Likely a future request will forg [...]
                5. the partition in the session and the update both have IDs, but they are different: only one topic ID should exist in the metadata at a time, so likely only one topic ID is in the fetch set; the other should be in the toForget set, and we will be able to remove that partition from the session. If for some reason we don't try to forget this partition, one of the partitions in the session will cause an inconsistent topic ID error and the metadata for this partition will be refreshed; this sho [...]
            * For the forgotten partitions we have the same cases:
                1. partition in session is resolved, forgotten is resolved: trivial

                2. partition in session is unresolved, forgotten is unresolved: in code, this is equivalent to the case above, so trivial as well

                3. partition in session is unresolved, forgotten is resolved: this means the partition in the session does not have a name, but the metadata cache now contains the name. To fix this, we check whether a cached partition with the given ID exists and try to forget it before we check the resolved-name case.

                4. partition in session is resolved, forgotten is unresolved: this means the partition in the session has a name, but the forgotten partition could not be resolved (i.e., the topic is deleted). We will look up the partition using the ID, find the old version with a name, and be able to delete it.

                5. the partition in the session and the forgotten partition both have IDs, but they are different: this should be the same case as described above. If we somehow do not have the ID in the session, no partition will be removed. This should not happen unless the FetchSessionHandler is out of sync.

        * Partitions in the session - there may be some partitions already in the session that are unresolved. We can resolve them in forEachPartition using a method that checks whether the partition is unresolved and tries to resolve it using a topicNames map from the request. The partition is resolved before the function using the cached partition is applied (see the sketch below).
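    
    A minimal sketch of that iteration, reusing the hypothetical UnresolvedCachedPartition type from the earlier sketch (the real method is forEachPartition in FetchSession.scala):
    
        import java.util.Collection;
        import java.util.Map;
        import java.util.function.Consumer;
        import org.apache.kafka.common.Uuid;
    
        // Resolve each unresolved cached partition with the request's ID-to-name
        // map before the per-partition function is applied.
        static void forEachPartition(Collection<UnresolvedCachedPartition> sessionPartitions,
                                     Map<Uuid, String> topicNames,
                                     Consumer<UnresolvedCachedPartition> fn) {
            for (UnresolvedCachedPartition cached : sessionPartitions) {
                cached.maybeResolveName(topicNames);
                fn.accept(cached);
            }
        }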
    
    Reviewers: David Jacot <dj...@confluent.io>
---
 checkstyle/suppressions.xml                        |    2 +-
 .../apache/kafka/clients/FetchSessionHandler.java  |  165 ++-
 .../kafka/clients/consumer/internals/Fetcher.java  |   27 +-
 .../apache/kafka/common/requests/FetchRequest.java |  148 +--
 .../kafka/common/requests/FetchResponse.java       |   53 +-
 .../kafka/clients/FetchSessionHandlerTest.java     |  420 +++++---
 .../kafka/clients/consumer/KafkaConsumerTest.java  |   16 +-
 .../clients/consumer/internals/FetcherTest.java    |  626 ++++++++---
 .../kafka/common/requests/FetchRequestTest.java    |  210 ++++
 .../kafka/common/requests/RequestResponseTest.java |  138 +--
 .../kafka/common/requests/RequestTestUtils.java    |   17 +
 .../scala/kafka/server/AbstractFetcherThread.scala |   18 +-
 .../src/main/scala/kafka/server/DelayedFetch.scala |   29 +-
 .../scala/kafka/server/DelayedOperationKey.scala   |    5 +-
 .../src/main/scala/kafka/server/FetchSession.scala |  246 ++---
 core/src/main/scala/kafka/server/KafkaApis.scala   |  116 +-
 .../kafka/server/ReplicaAlterLogDirsThread.scala   |   21 +-
 .../scala/kafka/server/ReplicaFetcherThread.scala  |   14 +-
 .../main/scala/kafka/server/ReplicaManager.scala   |   66 +-
 .../kafka/tools/ReplicaVerificationTool.scala      |    6 +-
 .../kafka/api/AuthorizerIntegrationTest.scala      |   17 +-
 .../kafka/server/DelayedFetchTest.scala            |   62 +-
 .../FetchRequestBetweenDifferentIbpTest.scala      |    6 +-
 .../kafka/server/AbstractFetcherThreadTest.scala   |   62 +-
 .../kafka/server/BaseClientQuotaManagerTest.scala  |    2 +-
 .../unit/kafka/server/BaseFetchRequestTest.scala   |   10 +-
 .../FetchRequestDownConversionConfigTest.scala     |   15 +-
 .../kafka/server/FetchRequestMaxBytesTest.scala    |    4 +-
 .../scala/unit/kafka/server/FetchRequestTest.scala |  100 +-
 .../FetchRequestWithLegacyMessageFormatTest.scala  |    3 +-
 .../scala/unit/kafka/server/FetchSessionTest.scala | 1108 ++++++++++++++------
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  138 ++-
 .../scala/unit/kafka/server/LogOffsetTest.scala    |    7 +-
 .../server/ReplicaAlterLogDirsThreadTest.scala     |   50 +-
 .../kafka/server/ReplicaFetcherThreadTest.scala    |   79 +-
 .../server/ReplicaManagerConcurrencyTest.scala     |   21 +-
 .../kafka/server/ReplicaManagerQuotasTest.scala    |   37 +-
 .../unit/kafka/server/ReplicaManagerTest.scala     |  222 ++--
 .../scala/unit/kafka/server/RequestQuotaTest.scala |    4 +-
 .../TopicIdWithOldInterBrokerProtocolTest.scala    |   22 +-
 .../util/ReplicaFetcherMockBlockingSend.scala      |   10 +-
 .../kafka/jmh/common/FetchRequestBenchmark.java    |   14 +-
 .../kafka/jmh/common/FetchResponseBenchmark.java   |    9 +-
 .../jmh/fetcher/ReplicaFetcherThreadBenchmark.java |    7 +-
 .../jmh/fetchsession/FetchSessionBenchmark.java    |   17 +-
 45 files changed, 2854 insertions(+), 1515 deletions(-)

diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 20224a4..c1f9b34 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -68,7 +68,7 @@
               files="(Utils|Topic|KafkaLZ4BlockOutputStream|AclData|JoinGroupRequest).java"/>
 
     <suppress checks="CyclomaticComplexity"
-              files="(ConsumerCoordinator|Fetcher|KafkaProducer|ConfigDef|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|AbstractCoordinator|TransactionManager|AbstractStickyAssignor|DefaultSslEngineFactory|Authorizer|RecordAccumulator|MemoryRecords).java"/>
+              files="(ConsumerCoordinator|Fetcher|KafkaProducer|ConfigDef|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator|AbstractCoordinator|TransactionManager|AbstractStickyAssignor|DefaultSslEngineFactory|Authorizer|RecordAccumulator|MemoryRecords|FetchSessionHandler).java"/>
 
     <suppress checks="JavaNCSS"
               files="(AbstractRequest|AbstractResponse|KerberosLogin|WorkerSinkTaskTest|TransactionManagerTest|SenderTest|KafkaAdminClient|ConsumerCoordinatorTest|KafkaAdminClientTest|KafkaRaftClientTest).java"/>
diff --git a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
index 8e29f9c..aca847c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
+++ b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.kafka.clients;
 
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.protocol.Errors;
@@ -77,11 +78,6 @@ public class FetchSessionHandler {
         new LinkedHashMap<>(0);
 
     /**
-     * All of the topic ids mapped to topic names for topics which exist in the fetch request session.
-     */
-    private Map<String, Uuid> sessionTopicIds = new HashMap<>(0);
-
-    /**
      * All of the topic names mapped to topic ids for topics which exist in the fetch request session.
      */
     private Map<Uuid, String> sessionTopicNames = new HashMap<>(0);
@@ -99,22 +95,18 @@ public class FetchSessionHandler {
         /**
          * The partitions to send in the request's "forget" list.
          */
-        private final List<TopicPartition> toForget;
-
-        /**
-         * All of the partitions which exist in the fetch request session.
-         */
-        private final Map<TopicPartition, PartitionData> sessionPartitions;
+        private final List<TopicIdPartition> toForget;
 
         /**
-         * All of the topic IDs for topics which exist in the fetch request session.
+         * The partitions to send in the request's "forget" list if
+         * the version is >= 13.
          */
-        private final Map<String, Uuid> topicIds;
+        private final List<TopicIdPartition> toReplace;
 
         /**
-         * All of the topic IDs for topics which exist in the fetch request session
+         * All of the partitions which exist in the fetch request session.
          */
-        private final Map<Uuid, String> topicNames;
+        private final Map<TopicPartition, PartitionData> sessionPartitions;
 
         /**
          * The metadata to use in this fetch request.
@@ -128,17 +120,15 @@ public class FetchSessionHandler {
         private final boolean canUseTopicIds;
 
         FetchRequestData(Map<TopicPartition, PartitionData> toSend,
-                         List<TopicPartition> toForget,
+                         List<TopicIdPartition> toForget,
+                         List<TopicIdPartition> toReplace,
                          Map<TopicPartition, PartitionData> sessionPartitions,
-                         Map<String, Uuid> topicIds,
-                         Map<Uuid, String> topicNames,
                          FetchMetadata metadata,
                          boolean canUseTopicIds) {
             this.toSend = toSend;
             this.toForget = toForget;
+            this.toReplace = toReplace;
             this.sessionPartitions = sessionPartitions;
-            this.topicIds = topicIds;
-            this.topicNames = topicNames;
             this.metadata = metadata;
             this.canUseTopicIds = canUseTopicIds;
         }
@@ -153,25 +143,24 @@ public class FetchSessionHandler {
         /**
          * Get a list of partitions to forget in this fetch request.
          */
-        public List<TopicPartition> toForget() {
+        public List<TopicIdPartition> toForget() {
             return toForget;
         }
 
         /**
+         * Get a list of partitions to replace in this fetch request.
+         */
+        public List<TopicIdPartition> toReplace() {
+            return toReplace;
+        }
+
+        /**
          * Get the full set of partitions involved in this fetch request.
          */
         public Map<TopicPartition, PartitionData> sessionPartitions() {
             return sessionPartitions;
         }
 
-        public Map<String, Uuid> topicIds() {
-            return topicIds;
-        }
-
-        public Map<Uuid, String> topicNames() {
-            return topicNames;
-        }
-
         public FetchMetadata metadata() {
             return metadata;
         }
@@ -201,7 +190,14 @@ public class FetchSessionHandler {
                 }
                 bld.append("), toForget=(");
                 prefix = "";
-                for (TopicPartition partition : toForget) {
+                for (TopicIdPartition partition : toForget) {
+                    bld.append(prefix);
+                    bld.append(partition);
+                    prefix = ", ";
+                }
+                bld.append("), toReplace=(");
+                prefix = "";
+                for (TopicIdPartition partition : toReplace) {
                     bld.append(prefix);
                     bld.append(partition);
                     prefix = ", ";
@@ -216,15 +212,6 @@ public class FetchSessionHandler {
                     }
                 }
             }
-            bld.append("), topicIds=(");
-            String prefix = "";
-            for (Map.Entry<String, Uuid> entry : topicIds.entrySet()) {
-                bld.append(prefix);
-                bld.append(entry.getKey());
-                bld.append(": ");
-                bld.append(entry.getValue());
-                prefix = ", ";
-            }
             if (canUseTopicIds) {
                 bld.append("), canUseTopicIds=True");
             } else {
@@ -250,32 +237,32 @@ public class FetchSessionHandler {
          * incremental fetch requests (see below).
          */
         private LinkedHashMap<TopicPartition, PartitionData> next;
-        private Map<String, Uuid> topicIds;
+        private Map<Uuid, String> topicNames;
         private final boolean copySessionPartitions;
         private int partitionsWithoutTopicIds = 0;
 
         Builder() {
             this.next = new LinkedHashMap<>();
-            this.topicIds = new HashMap<>();
+            this.topicNames = new HashMap<>();
             this.copySessionPartitions = true;
         }
 
         Builder(int initialSize, boolean copySessionPartitions) {
             this.next = new LinkedHashMap<>(initialSize);
-            this.topicIds = new HashMap<>(initialSize);
+            this.topicNames = new HashMap<>();
             this.copySessionPartitions = copySessionPartitions;
         }
 
         /**
          * Mark that we want data from this partition in the upcoming fetch.
          */
-        public void add(TopicPartition topicPartition, Uuid topicId, PartitionData data) {
+        public void add(TopicPartition topicPartition, PartitionData data) {
             next.put(topicPartition, data);
             // topicIds should not change between adding partitions and building, so we can use putIfAbsent
-            if (!topicId.equals(Uuid.ZERO_UUID)) {
-                topicIds.putIfAbsent(topicPartition.topic(), topicId);
-            } else {
+            if (data.topicId.equals(Uuid.ZERO_UUID)) {
                 partitionsWithoutTopicIds++;
+            } else {
+                topicNames.putIfAbsent(data.topicId, topicPartition.topic());
             }
         }
 
@@ -285,52 +272,55 @@ public class FetchSessionHandler {
             if (nextMetadata.isFull()) {
                 if (log.isDebugEnabled()) {
                     log.debug("Built full fetch {} for node {} with {}.",
-                              nextMetadata, node, partitionsToLogString(next.keySet()));
+                            nextMetadata, node, topicPartitionsToLogString(next.keySet()));
                 }
                 sessionPartitions = next;
                 next = null;
                 // Only add topic IDs to the session if we are using topic IDs.
                 if (canUseTopicIds) {
-                    sessionTopicIds = topicIds;
-                    sessionTopicNames = new HashMap<>(topicIds.size());
-                    topicIds.forEach((name, id) -> sessionTopicNames.put(id, name));
+                    sessionTopicNames = topicNames;
                 } else {
-                    sessionTopicIds = new HashMap<>();
-                    sessionTopicNames = new HashMap<>();
+                    sessionTopicNames = Collections.emptyMap();
                 }
-                topicIds = null;
                 Map<TopicPartition, PartitionData> toSend =
-                    Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions));
-                Map<String, Uuid> toSendTopicIds =
-                    Collections.unmodifiableMap(new HashMap<>(sessionTopicIds));
-                Map<Uuid, String> toSendTopicNames =
-                    Collections.unmodifiableMap(new HashMap<>(sessionTopicNames));
-                return new FetchRequestData(toSend, Collections.emptyList(), toSend, toSendTopicIds, toSendTopicNames, nextMetadata, canUseTopicIds);
+                        Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions));
+                return new FetchRequestData(toSend, Collections.emptyList(), Collections.emptyList(), toSend, nextMetadata, canUseTopicIds);
             }
 
-            List<TopicPartition> added = new ArrayList<>();
-            List<TopicPartition> removed = new ArrayList<>();
-            List<TopicPartition> altered = new ArrayList<>();
+            List<TopicIdPartition> added = new ArrayList<>();
+            List<TopicIdPartition> removed = new ArrayList<>();
+            List<TopicIdPartition> altered = new ArrayList<>();
+            List<TopicIdPartition> replaced = new ArrayList<>();
             for (Iterator<Entry<TopicPartition, PartitionData>> iter =
-                     sessionPartitions.entrySet().iterator(); iter.hasNext(); ) {
+                 sessionPartitions.entrySet().iterator(); iter.hasNext(); ) {
                 Entry<TopicPartition, PartitionData> entry = iter.next();
                 TopicPartition topicPartition = entry.getKey();
                 PartitionData prevData = entry.getValue();
                 PartitionData nextData = next.remove(topicPartition);
                 if (nextData != null) {
-                    if (!prevData.equals(nextData)) {
+                    // We basically check if the new partition had the same topic ID. If not,
+                    // we add it to the "replaced" set. If the request is version 13 or higher, the replaced
+                    // partition will be forgotten. In any case, we will send the new partition in the request.
+                    if (!prevData.topicId.equals(nextData.topicId)
+                            && !prevData.topicId.equals(Uuid.ZERO_UUID)
+                            && !nextData.topicId.equals(Uuid.ZERO_UUID)) {
+                        // Re-add the replaced partition to the end of 'next'
+                        next.put(topicPartition, nextData);
+                        entry.setValue(nextData);
+                        replaced.add(new TopicIdPartition(prevData.topicId, topicPartition));
+                    } else if (!prevData.equals(nextData)) {
                         // Re-add the altered partition to the end of 'next'
                         next.put(topicPartition, nextData);
                         entry.setValue(nextData);
-                        altered.add(topicPartition);
+                        altered.add(new TopicIdPartition(nextData.topicId, topicPartition));
                     }
                 } else {
                     // Remove this partition from the session.
                     iter.remove();
                     // Indicate that we no longer want to listen to this partition.
-                    removed.add(topicPartition);
+                    removed.add(new TopicIdPartition(prevData.topicId, topicPartition));
                     // If we do not have this topic ID in the builder or the session, we can not use topic IDs.
-                    if (canUseTopicIds && !topicIds.containsKey(topicPartition.topic()) && !sessionTopicIds.containsKey(topicPartition.topic()))
+                    if (canUseTopicIds && prevData.topicId.equals(Uuid.ZERO_UUID))
                         canUseTopicIds = false;
                 }
             }
@@ -346,38 +336,34 @@ public class FetchSessionHandler {
                     break;
                 }
                 sessionPartitions.put(topicPartition, nextData);
-                added.add(topicPartition);
+                added.add(new TopicIdPartition(nextData.topicId, topicPartition));
             }
 
             // Add topic IDs to session if we can use them. If an ID is inconsistent, we will handle in the receiving broker.
             // If we switched from using topic IDs to not using them (or vice versa), that error will also be handled in the receiving broker.
             if (canUseTopicIds) {
-                for (Map.Entry<String, Uuid> topic : topicIds.entrySet()) {
-                    String topicName = topic.getKey();
-                    Uuid addedId = topic.getValue();
-                    sessionTopicIds.put(topicName, addedId);
-                    sessionTopicNames.put(addedId, topicName);
-                }
+                sessionTopicNames = topicNames;
+            } else {
+                sessionTopicNames = Collections.emptyMap();
             }
 
             if (log.isDebugEnabled()) {
-                log.debug("Built incremental fetch {} for node {}. Added {}, altered {}, removed {} " +
-                          "out of {}", nextMetadata, node, partitionsToLogString(added),
-                          partitionsToLogString(altered), partitionsToLogString(removed),
-                          partitionsToLogString(sessionPartitions.keySet()));
+                log.debug("Built incremental fetch {} for node {}. Added {}, altered {}, removed {}, " +
+                          "replaced {} out of {}", nextMetadata, node, topicIdPartitionsToLogString(added),
+                          topicIdPartitionsToLogString(altered), topicIdPartitionsToLogString(removed),
+                          topicIdPartitionsToLogString(replaced), topicPartitionsToLogString(sessionPartitions.keySet()));
             }
             Map<TopicPartition, PartitionData> toSend = Collections.unmodifiableMap(next);
             Map<TopicPartition, PartitionData> curSessionPartitions = copySessionPartitions
                     ? Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions))
                     : Collections.unmodifiableMap(sessionPartitions);
-            Map<String, Uuid> toSendTopicIds =
-                Collections.unmodifiableMap(new HashMap<>(sessionTopicIds));
-            Map<Uuid, String> toSendTopicNames =
-                Collections.unmodifiableMap(new HashMap<>(sessionTopicNames));
             next = null;
-            topicIds = null;
-            return new FetchRequestData(toSend, Collections.unmodifiableList(removed), curSessionPartitions,
-                                        toSendTopicIds, toSendTopicNames, nextMetadata, canUseTopicIds);
+            return new FetchRequestData(toSend,
+                    Collections.unmodifiableList(removed),
+                    Collections.unmodifiableList(replaced),
+                    curSessionPartitions,
+                    nextMetadata,
+                    canUseTopicIds);
         }
     }
 
@@ -397,7 +383,14 @@ public class FetchSessionHandler {
         return new Builder(size, copySessionPartitions);
     }
 
-    private String partitionsToLogString(Collection<TopicPartition> partitions) {
+    private String topicPartitionsToLogString(Collection<TopicPartition> partitions) {
+        if (!log.isTraceEnabled()) {
+            return String.format("%d partition(s)", partitions.size());
+        }
+        return "(" + Utils.join(partitions, ", ") + ")";
+    }
+
+    private String topicIdPartitionsToLogString(Collection<TopicIdPartition> partitions) {
         if (!log.isTraceEnabled()) {
             return String.format("%d partition(s)", partitions.size());
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 055f2d1..d567f5b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -262,11 +262,12 @@ public class Fetcher<K, V> implements Closeable {
                 maxVersion = ApiKeys.FETCH.latestVersion();
             }
             final FetchRequest.Builder request = FetchRequest.Builder
-                    .forConsumer(maxVersion, this.maxWaitMs, this.minBytes, data.toSend(), data.topicIds())
+                    .forConsumer(maxVersion, this.maxWaitMs, this.minBytes, data.toSend())
                     .isolationLevel(isolationLevel)
                     .setMaxBytes(this.maxBytes)
                     .metadata(data.metadata())
-                    .toForget(data.toForget())
+                    .removed(data.toForget())
+                    .replaced(data.toReplace())
                     .rackId(clientRackId);
 
             if (log.isDebugEnabled()) {
@@ -291,15 +292,13 @@ public class Fetcher<K, V> implements Closeable {
                                 return;
                             }
                             if (!handler.handleResponse(response, resp.requestHeader().apiVersion())) {
-                                if (response.error() == Errors.FETCH_SESSION_TOPIC_ID_ERROR
-                                        || response.error() == Errors.UNKNOWN_TOPIC_ID
-                                        || response.error() == Errors.INCONSISTENT_TOPIC_ID) {
+                                if (response.error() == Errors.FETCH_SESSION_TOPIC_ID_ERROR) {
                                     metadata.requestUpdate();
                                 }
                                 return;
                             }
 
-                            Map<TopicPartition, FetchResponseData.PartitionData> responseData = response.responseData(data.topicNames(), resp.requestHeader().apiVersion());
+                            Map<TopicPartition, FetchResponseData.PartitionData> responseData = response.responseData(handler.sessionTopicNames(), resp.requestHeader().apiVersion());
                             Set<TopicPartition> partitions = new HashSet<>(responseData.keySet());
                             FetchResponseMetricAggregator metricAggregator = new FetchResponseMetricAggregator(sensors, partitions);
 
@@ -314,8 +313,8 @@ public class Fetcher<K, V> implements Closeable {
                                                 new Object[]{partition, data.metadata()}).getMessage();
                                     } else {
                                         message = MessageFormatter.arrayFormat(
-                                                "Response for missing session request partition: partition={}; metadata={}; toSend={}; toForget={}",
-                                                new Object[]{partition, data.metadata(), data.toSend(), data.toForget()}).getMessage();
+                                                "Response for missing session request partition: partition={}; metadata={}; toSend={}; toForget={}; toReplace={}",
+                                                new Object[]{partition, data.metadata(), data.toSend(), data.toForget(), data.toReplace()}).getMessage();
                                     }
 
                                     // Received fetch response for missing session partition
@@ -1238,9 +1237,9 @@ public class Fetcher<K, V> implements Closeable {
                     builder = handler.newBuilder();
                     fetchable.put(node, builder);
                 }
-
-                builder.add(partition, topicIds.getOrDefault(partition.topic(), Uuid.ZERO_UUID), new FetchRequest.PartitionData(position.offset,
-                    FetchRequest.INVALID_LOG_START_OFFSET, this.fetchSize,
+                builder.add(partition, new FetchRequest.PartitionData(
+                    topicIds.getOrDefault(partition.topic(), Uuid.ZERO_UUID),
+                    position.offset, FetchRequest.INVALID_LOG_START_OFFSET, this.fetchSize,
                     position.currentLeader.epoch, Optional.empty()));
 
                 log.debug("Added {} fetch request for partition {} at position {} to node {}", isolationLevel,
@@ -1353,6 +1352,12 @@ public class Fetcher<K, V> implements Closeable {
             } else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) {
                 log.warn("Received unknown topic or partition error in fetch for partition {}", tp);
                 this.metadata.requestUpdate();
+            } else if (error == Errors.UNKNOWN_TOPIC_ID) {
+                log.warn("Received unknown topic ID error in fetch for partition {}", tp);
+                this.metadata.requestUpdate();
+            } else if (error == Errors.INCONSISTENT_TOPIC_ID) {
+                log.warn("Received inconsistent topic ID error in fetch for partition {}", tp);
+                this.metadata.requestUpdate();
             } else if (error == Errors.OFFSET_OUT_OF_RANGE) {
                 Optional<Integer> clearedReplicaId = subscriptions.clearPreferredReadReplica(tp);
                 if (!clearedReplicaId.isPresent()) {
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
index 1531433..48ba022 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java
@@ -17,10 +17,11 @@
 package org.apache.kafka.common.requests;
 
 import org.apache.kafka.common.IsolationLevel;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
-import org.apache.kafka.common.errors.UnknownTopicIdException;
 import org.apache.kafka.common.message.FetchRequestData;
+import org.apache.kafka.common.message.FetchRequestData.ForgottenTopic;
 import org.apache.kafka.common.message.FetchResponseData;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.ByteBufferAccessor;
@@ -47,13 +48,14 @@ public class FetchRequest extends AbstractRequest {
     public static final long INVALID_LOG_START_OFFSET = -1L;
 
     private final FetchRequestData data;
-    private volatile LinkedHashMap<TopicPartition, PartitionData> fetchData = null;
-    private volatile List<TopicPartition> toForget = null;
+    private volatile LinkedHashMap<TopicIdPartition, PartitionData> fetchData = null;
+    private volatile List<TopicIdPartition> toForget = null;
 
     // This is an immutable read-only structure derived from FetchRequestData
     private final FetchMetadata metadata;
 
     public static final class PartitionData {
+        public final Uuid topicId;
         public final long fetchOffset;
         public final long logStartOffset;
         public final int maxBytes;
@@ -61,21 +63,24 @@ public class FetchRequest extends AbstractRequest {
         public final Optional<Integer> lastFetchedEpoch;
 
         public PartitionData(
+            Uuid topicId,
             long fetchOffset,
             long logStartOffset,
             int maxBytes,
             Optional<Integer> currentLeaderEpoch
         ) {
-            this(fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, Optional.empty());
+            this(topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, Optional.empty());
         }
 
         public PartitionData(
+            Uuid topicId,
             long fetchOffset,
             long logStartOffset,
             int maxBytes,
             Optional<Integer> currentLeaderEpoch,
             Optional<Integer> lastFetchedEpoch
         ) {
+            this.topicId = topicId;
             this.fetchOffset = fetchOffset;
             this.logStartOffset = logStartOffset;
             this.maxBytes = maxBytes;
@@ -88,7 +93,8 @@ public class FetchRequest extends AbstractRequest {
             if (this == o) return true;
             if (o == null || getClass() != o.getClass()) return false;
             PartitionData that = (PartitionData) o;
-            return fetchOffset == that.fetchOffset &&
+            return Objects.equals(topicId, that.topicId) &&
+                fetchOffset == that.fetchOffset &&
                 logStartOffset == that.logStartOffset &&
                 maxBytes == that.maxBytes &&
                 Objects.equals(currentLeaderEpoch, that.currentLeaderEpoch) &&
@@ -97,13 +103,14 @@ public class FetchRequest extends AbstractRequest {
 
         @Override
         public int hashCode() {
-            return Objects.hash(fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch);
+            return Objects.hash(topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch);
         }
 
         @Override
         public String toString() {
             return "PartitionData(" +
-                "fetchOffset=" + fetchOffset +
+                "topicId=" + topicId +
+                ", fetchOffset=" + fetchOffset +
                 ", logStartOffset=" + logStartOffset +
                 ", maxBytes=" + maxBytes +
                 ", currentLeaderEpoch=" + currentLeaderEpoch +
@@ -124,33 +131,31 @@ public class FetchRequest extends AbstractRequest {
         private final int maxWait;
         private final int minBytes;
         private final int replicaId;
-        private final Map<TopicPartition, PartitionData> fetchData;
-        private final Map<String, Uuid> topicIds;
+        private final Map<TopicPartition, PartitionData> toFetch;
         private IsolationLevel isolationLevel = IsolationLevel.READ_UNCOMMITTED;
         private int maxBytes = DEFAULT_RESPONSE_MAX_BYTES;
         private FetchMetadata metadata = FetchMetadata.LEGACY;
-        private List<TopicPartition> toForget = Collections.emptyList();
+        private List<TopicIdPartition> removed = Collections.emptyList();
+        private List<TopicIdPartition> replaced = Collections.emptyList();
         private String rackId = "";
 
-        public static Builder forConsumer(short maxVersion, int maxWait, int minBytes, Map<TopicPartition, PartitionData> fetchData,
-                                          Map<String, Uuid> topicIds) {
+        public static Builder forConsumer(short maxVersion, int maxWait, int minBytes, Map<TopicPartition, PartitionData> fetchData) {
             return new Builder(ApiKeys.FETCH.oldestVersion(), maxVersion,
-                CONSUMER_REPLICA_ID, maxWait, minBytes, fetchData, topicIds);
+                CONSUMER_REPLICA_ID, maxWait, minBytes, fetchData);
         }
 
         public static Builder forReplica(short allowedVersion, int replicaId, int maxWait, int minBytes,
-                                         Map<TopicPartition, PartitionData> fetchData, Map<String, Uuid> topicIds) {
-            return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData, topicIds);
+                                         Map<TopicPartition, PartitionData> fetchData) {
+            return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData);
         }
 
         public Builder(short minVersion, short maxVersion, int replicaId, int maxWait, int minBytes,
-                        Map<TopicPartition, PartitionData> fetchData, Map<String, Uuid> topicIds) {
+                       Map<TopicPartition, PartitionData> fetchData) {
             super(ApiKeys.FETCH, minVersion, maxVersion);
             this.replicaId = replicaId;
             this.maxWait = maxWait;
             this.minBytes = minBytes;
-            this.fetchData = fetchData;
-            this.topicIds = topicIds;
+            this.toFetch = fetchData;
         }
 
         public Builder isolationLevel(IsolationLevel isolationLevel) {
@@ -169,7 +174,7 @@ public class FetchRequest extends AbstractRequest {
         }
 
         public Map<TopicPartition, PartitionData> fetchData() {
-            return this.fetchData;
+            return this.toFetch;
         }
 
         public Builder setMaxBytes(int maxBytes) {
@@ -177,15 +182,37 @@ public class FetchRequest extends AbstractRequest {
             return this;
         }
 
-        public List<TopicPartition> toForget() {
-            return toForget;
+        public List<TopicIdPartition> removed() {
+            return removed;
         }
 
-        public Builder toForget(List<TopicPartition> toForget) {
-            this.toForget = toForget;
+        public Builder removed(List<TopicIdPartition> removed) {
+            this.removed = removed;
             return this;
         }
 
+        public List<TopicIdPartition> replaced() {
+            return replaced;
+        }
+
+        public Builder replaced(List<TopicIdPartition> replaced) {
+            this.replaced = replaced;
+            return this;
+        }
+
+        private void addToForgottenTopicMap(List<TopicIdPartition> toForget, Map<String, FetchRequestData.ForgottenTopic> forgottenTopicMap) {
+            toForget.forEach(topicIdPartition -> {
+                FetchRequestData.ForgottenTopic forgottenTopic = forgottenTopicMap.get(topicIdPartition.topic());
+                if (forgottenTopic == null) {
+                    forgottenTopic = new ForgottenTopic()
+                            .setTopic(topicIdPartition.topic())
+                            .setTopicId(topicIdPartition.topicId());
+                    forgottenTopicMap.put(topicIdPartition.topic(), forgottenTopic);
+                }
+                forgottenTopic.partitions().add(topicIdPartition.partition());
+            });
+        }
+
         @Override
         public FetchRequest build(short version) {
             if (version < 3) {
@@ -199,26 +226,31 @@ public class FetchRequest extends AbstractRequest {
             fetchRequestData.setMaxBytes(maxBytes);
             fetchRequestData.setIsolationLevel(isolationLevel.id());
             fetchRequestData.setForgottenTopicsData(new ArrayList<>());
-            toForget.stream()
-                .collect(Collectors.groupingBy(TopicPartition::topic, LinkedHashMap::new, Collectors.toList()))
-                .forEach((topic, partitions) ->
-                    fetchRequestData.forgottenTopicsData().add(new FetchRequestData.ForgottenTopic()
-                        .setTopic(topic)
-                        .setTopicId(topicIds.getOrDefault(topic, Uuid.ZERO_UUID))
-                        .setPartitions(partitions.stream().map(TopicPartition::partition).collect(Collectors.toList())))
-                );
-            fetchRequestData.setTopics(new ArrayList<>());
+
+            Map<String, FetchRequestData.ForgottenTopic> forgottenTopicMap = new LinkedHashMap<>();
+            addToForgottenTopicMap(removed, forgottenTopicMap);
+
+            // If a version older than v13 is used, topic-partition which were replaced
+            // by a topic-partition with the same name but a different topic ID are not
+            // sent out in the "forget" set in order to not remove the newly added
+            // partition in the "fetch" set.
+            if (version >= 13) {
+                addToForgottenTopicMap(replaced, forgottenTopicMap);
+            }
+
+            forgottenTopicMap.forEach((topic, forgottenTopic) -> fetchRequestData.forgottenTopicsData().add(forgottenTopic));
 
             // We collect the partitions in a single FetchTopic only if they appear sequentially in the fetchData
+            fetchRequestData.setTopics(new ArrayList<>());
             FetchRequestData.FetchTopic fetchTopic = null;
-            for (Map.Entry<TopicPartition, PartitionData> entry : fetchData.entrySet()) {
+            for (Map.Entry<TopicPartition, PartitionData> entry : toFetch.entrySet()) {
                 TopicPartition topicPartition = entry.getKey();
                 PartitionData partitionData = entry.getValue();
 
                 if (fetchTopic == null || !topicPartition.topic().equals(fetchTopic.topic())) {
                     fetchTopic = new FetchRequestData.FetchTopic()
                        .setTopic(topicPartition.topic())
-                       .setTopicId(topicIds.getOrDefault(topicPartition.topic(), Uuid.ZERO_UUID))
+                       .setTopicId(partitionData.topicId)
                        .setPartitions(new ArrayList<>());
                     fetchRequestData.topics().add(fetchTopic);
                 }
@@ -251,9 +283,10 @@ public class FetchRequest extends AbstractRequest {
                     append(", maxWait=").append(maxWait).
                     append(", minBytes=").append(minBytes).
                     append(", maxBytes=").append(maxBytes).
-                    append(", fetchData=").append(fetchData).
+                    append(", fetchData=").append(toFetch).
                     append(", isolationLevel=").append(isolationLevel).
-                    append(", toForget=").append(Utils.join(toForget, ", ")).
+                    append(", removed=").append(Utils.join(removed, ", ")).
+                    append(", replaced=").append(Utils.join(replaced, ", ")).
                     append(", metadata=").append(metadata).
                     append(", rackId=").append(rackId).
                     append(")");
@@ -314,8 +347,7 @@ public class FetchRequest extends AbstractRequest {
 
     // For versions < 13, builds the partitionData map using only the FetchRequestData.
     // For versions 13+, builds the partitionData map using both the FetchRequestData and a mapping of topic IDs to names.
-    // Throws UnknownTopicIdException for versions 13+ if the topic ID was unknown to the server.
-    public Map<TopicPartition, PartitionData> fetchData(Map<Uuid, String> topicNames) throws UnknownTopicIdException {
+    public Map<TopicIdPartition, PartitionData> fetchData(Map<Uuid, String> topicNames) {
         if (fetchData == null) {
             synchronized (this) {
                 if (fetchData == null) {
@@ -328,22 +360,19 @@ public class FetchRequest extends AbstractRequest {
                         } else {
                             name = topicNames.get(fetchTopic.topicId());
                         }
-                        if (name != null) {
-                            // If topic name is resolved, simply add to fetchData map
-                            fetchTopic.partitions().forEach(fetchPartition ->
-                                    fetchData.put(new TopicPartition(name, fetchPartition.partition()),
-                                            new PartitionData(
-                                                    fetchPartition.fetchOffset(),
-                                                    fetchPartition.logStartOffset(),
-                                                    fetchPartition.partitionMaxBytes(),
-                                                    optionalEpoch(fetchPartition.currentLeaderEpoch()),
-                                                    optionalEpoch(fetchPartition.lastFetchedEpoch())
-                                            )
-                                    )
-                            );
-                        } else {
-                            throw new UnknownTopicIdException(String.format("Topic Id %s in FetchRequest was unknown to the server", fetchTopic.topicId()));
-                        }
+                        fetchTopic.partitions().forEach(fetchPartition ->
+                                // Topic name may be null here if the topic name was unable to be resolved using the topicNames map.
+                                fetchData.put(new TopicIdPartition(fetchTopic.topicId(), new TopicPartition(name, fetchPartition.partition())),
+                                        new PartitionData(
+                                                fetchTopic.topicId(),
+                                                fetchPartition.fetchOffset(),
+                                                fetchPartition.logStartOffset(),
+                                                fetchPartition.partitionMaxBytes(),
+                                                optionalEpoch(fetchPartition.currentLeaderEpoch()),
+                                                optionalEpoch(fetchPartition.lastFetchedEpoch())
+                                        )
+                                )
+                        );
                     });
                 }
             }
@@ -351,8 +380,9 @@ public class FetchRequest extends AbstractRequest {
         return fetchData;
     }
 
-    // For versions 13+, throws UnknownTopicIdException if the topic ID was unknown to the server.
-    public List<TopicPartition> forgottenTopics(Map<Uuid, String> topicNames) throws UnknownTopicIdException {
+    // For versions < 13, builds the forgotten topics list using only the FetchRequestData.
+    // For versions 13+, builds the forgotten topics list using both the FetchRequestData and a mapping of topic IDs to names.
+    public List<TopicIdPartition> forgottenTopics(Map<Uuid, String> topicNames) {
         if (toForget == null) {
             synchronized (this) {
                 if (toForget == null) {
@@ -364,10 +394,8 @@ public class FetchRequest extends AbstractRequest {
                         } else {
                             name = topicNames.get(forgottenTopic.topicId());
                         }
-                        if (name == null) {
-                            throw new UnknownTopicIdException(String.format("Topic Id %s in FetchRequest was unknown to the server", forgottenTopic.topicId()));
-                        }
-                        forgottenTopic.partitions().forEach(partitionId -> toForget.add(new TopicPartition(name, partitionId)));
+                        // Topic name may be null here if the topic name was unable to be resolved using the topicNames map.
+                        forgottenTopic.partitions().forEach(partitionId -> toForget.add(new TopicIdPartition(forgottenTopic.topicId(), new TopicPartition(name, partitionId))));
                     });
                 }
             }
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java
index f5fa40f..2e0a02e 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.common.requests;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.message.FetchResponseData;
 import org.apache.kafka.common.protocol.ApiKeys;
@@ -57,8 +58,7 @@ import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
  *     not supported by the fetch request version
  * - {@link Errors#CORRUPT_MESSAGE} If corrupt message encountered, e.g. when the broker scans the log to find
  *     the fetch offset after the index lookup
- * - {@link Errors#UNKNOWN_TOPIC_ID} If the request contains a topic ID unknown to the broker or a partition in the session has
- *     an ID that differs from the broker
+ * - {@link Errors#UNKNOWN_TOPIC_ID} If the request contains a topic ID unknown to the broker
  * - {@link Errors#FETCH_SESSION_TOPIC_ID_ERROR} If the request version supports topic IDs but the session does not or vice versa,
  *     or a topic ID in the request is inconsistent with a topic ID in the session
  * - {@link Errors#INCONSISTENT_TOPIC_ID} If a topic ID in the session does not match the topic ID in the log
@@ -82,9 +82,9 @@ public class FetchResponse extends AbstractResponse {
     /**
      * From version 3 or later, the authorized and existing entries in `FetchRequest.fetchData` should be in the same order in `responseData`.
      * Version 13 introduces topic IDs which can lead to a few new errors. If there is any unknown topic ID in the request, the
-     * response will contain a top-level UNKNOWN_TOPIC_ID error.
+     * response will contain a partition-level UNKNOWN_TOPIC_ID error for that partition.
      * If a request's topic ID usage is inconsistent with the session, we will return a top level FETCH_SESSION_TOPIC_ID_ERROR error.
-     * We may also return INCONSISTENT_TOPIC_ID error as a top-level error when a partition in the session has a topic ID
+     * We may also return INCONSISTENT_TOPIC_ID error as a partition-level error when a partition in the session has a topic ID
      * inconsistent with the log.
      */
     public FetchResponse(FetchResponseData fetchResponseData) {
@@ -110,7 +110,7 @@ public class FetchResponse extends AbstractResponse {
                         }
                         if (name != null) {
                             topicResponse.partitions().forEach(partition ->
-                                    responseData.put(new TopicPartition(name, partition.partitionIndex()), partition));
+                                responseData.put(new TopicPartition(name, partition.partitionIndex()), partition));
                         }
                     });
                 }
@@ -154,16 +154,14 @@ public class FetchResponse extends AbstractResponse {
      *
      * @param version       The version of the response to use.
      * @param partIterator  The partition iterator.
-     * @param topicIds      The mapping from topic name to topic ID.
      * @return              The response size in bytes.
      */
     public static int sizeOf(short version,
-                             Iterator<Map.Entry<TopicPartition,
-                             FetchResponseData.PartitionData>> partIterator,
-                             Map<String, Uuid> topicIds) {
+                             Iterator<Map.Entry<TopicIdPartition,
+                             FetchResponseData.PartitionData>> partIterator) {
         // Since the throttleTimeMs and metadata field sizes are constant and fixed, we can
         // use arbitrary values here without affecting the result.
-        FetchResponseData data = toMessage(Errors.NONE, 0, INVALID_SESSION_ID, partIterator, topicIds);
+        FetchResponseData data = toMessage(Errors.NONE, 0, INVALID_SESSION_ID, partIterator);
         ObjectSerializationCache cache = new ObjectSerializationCache();
         return 4 + data.size(cache, version);
     }
@@ -191,6 +189,10 @@ public class FetchResponse extends AbstractResponse {
         return partitionResponse.preferredReadReplica() != INVALID_PREFERRED_REPLICA_ID;
     }
 
+    public static FetchResponseData.PartitionData partitionResponse(TopicIdPartition topicIdPartition, Errors error) {
+        return partitionResponse(topicIdPartition.topicPartition().partition(), error);
+    }
+
     public static FetchResponseData.PartitionData partitionResponse(int partition, Errors error) {
         return new FetchResponseData.PartitionData()
             .setPartitionIndex(partition)
@@ -226,36 +228,45 @@ public class FetchResponse extends AbstractResponse {
     public static FetchResponse of(Errors error,
                                    int throttleTimeMs,
                                    int sessionId,
-                                   LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData,
-                                   Map<String, Uuid> topicIds) {
-        return new FetchResponse(toMessage(error, throttleTimeMs, sessionId, responseData.entrySet().iterator(), topicIds));
+                                   LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData) {
+        return new FetchResponse(toMessage(error, throttleTimeMs, sessionId, responseData.entrySet().iterator()));
+    }
+
+    private static boolean matchingTopic(FetchResponseData.FetchableTopicResponse previousTopic, TopicIdPartition currentTopic) {
+        if (previousTopic == null)
+            return false;
+        if (!previousTopic.topicId().equals(Uuid.ZERO_UUID))
+            return previousTopic.topicId().equals(currentTopic.topicId());
+        else
+            return previousTopic.topic().equals(currentTopic.topicPartition().topic());
+
     }
 
     private static FetchResponseData toMessage(Errors error,
                                                int throttleTimeMs,
                                                int sessionId,
-                                               Iterator<Map.Entry<TopicPartition, FetchResponseData.PartitionData>> partIterator,
-                                               Map<String, Uuid> topicIds) {
+                                               Iterator<Map.Entry<TopicIdPartition, FetchResponseData.PartitionData>> partIterator) {
         List<FetchResponseData.FetchableTopicResponse> topicResponseList = new ArrayList<>();
-        partIterator.forEachRemaining(entry -> {
+        while (partIterator.hasNext()) {
+            Map.Entry<TopicIdPartition, FetchResponseData.PartitionData> entry = partIterator.next();
             FetchResponseData.PartitionData partitionData = entry.getValue();
             // Since PartitionData alone doesn't know the partition ID, we set it here
-            partitionData.setPartitionIndex(entry.getKey().partition());
+            partitionData.setPartitionIndex(entry.getKey().topicPartition().partition());
             // We have to keep the order of input topic-partition. Hence, we batch the partitions only if the last
             // batch is in the same topic group.
             FetchResponseData.FetchableTopicResponse previousTopic = topicResponseList.isEmpty() ? null
                 : topicResponseList.get(topicResponseList.size() - 1);
-            if (previousTopic != null && previousTopic.topic().equals(entry.getKey().topic()))
+            if (matchingTopic(previousTopic, entry.getKey()))
                 previousTopic.partitions().add(partitionData);
             else {
                 List<FetchResponseData.PartitionData> partitionResponses = new ArrayList<>();
                 partitionResponses.add(partitionData);
                 topicResponseList.add(new FetchResponseData.FetchableTopicResponse()
-                    .setTopic(entry.getKey().topic())
-                    .setTopicId(topicIds.getOrDefault(entry.getKey().topic(), Uuid.ZERO_UUID))
+                    .setTopic(entry.getKey().topicPartition().topic())
+                    .setTopicId(entry.getKey().topicId())
                     .setPartitions(partitionResponses));
             }
-        });
+        }
 
         return new FetchResponseData()
             .setThrottleTimeMs(throttleTimeMs)
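
Taken together, matchingTopic and the rewritten loop mean that consecutive map entries for the same topic collapse into a single FetchableTopicResponse, matched by topic ID when one is present and by topic name otherwise. A small illustration, not part of the patch (class name, topic name, session ID, and offsets are made up):

    import java.util.LinkedHashMap;

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.FetchResponse;

    public class GroupingSketch {
        public static void main(String[] args) {
            Uuid fooId = Uuid.randomUuid();
            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
            responseData.put(new TopicIdPartition(fooId, new TopicPartition("foo", 0)),
                    new FetchResponseData.PartitionData().setHighWatermark(10).setLogStartOffset(0));
            responseData.put(new TopicIdPartition(fooId, new TopicPartition("foo", 1)),
                    new FetchResponseData.PartitionData().setHighWatermark(10).setLogStartOffset(0));
            FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, responseData);
            // Both partitions share a topic ID, so they collapse into one FetchableTopicResponse.
            System.out.println(resp.data().responses().size());                      // 1
            System.out.println(resp.data().responses().get(0).partitions().size()); // 2
        }
    }
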
diff --git a/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java
index 8eb09f0..4bf53d9 100644
--- a/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.clients;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.message.FetchResponseData;
@@ -28,6 +29,8 @@ import org.apache.kafka.common.utils.LogContext;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Timeout;
 import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.junit.jupiter.params.provider.ValueSource;
 
 import java.util.ArrayList;
@@ -41,6 +44,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Stream;
 import java.util.TreeSet;
 
 import static org.apache.kafka.common.requests.FetchMetadata.INITIAL_EPOCH;
@@ -97,9 +101,9 @@ public class FetchSessionHandlerTest {
         final TopicPartition part;
         final FetchRequest.PartitionData data;
 
-        ReqEntry(String topic, int partition, long fetchOffset, long logStartOffset, int maxBytes) {
+        ReqEntry(String topic, Uuid topicId, int partition, long fetchOffset, long logStartOffset, int maxBytes) {
             this.part = new TopicPartition(topic, partition);
-            this.data = new FetchRequest.PartitionData(fetchOffset, logStartOffset, maxBytes, Optional.empty());
+            this.data = new FetchRequest.PartitionData(topicId, fetchOffset, logStartOffset, maxBytes, Optional.empty());
         }
     }
 
@@ -143,13 +147,13 @@ public class FetchSessionHandlerTest {
         }
     }
 
-    private static void assertListEquals(List<TopicPartition> expected, List<TopicPartition> actual) {
-        for (TopicPartition expectedPart : expected) {
+    private static void assertListEquals(List<TopicIdPartition> expected, List<TopicIdPartition> actual) {
+        for (TopicIdPartition expectedPart : expected) {
             if (!actual.contains(expectedPart)) {
                 fail("Failed to find expected partition " + expectedPart);
             }
         }
-        for (TopicPartition actualPart : actual) {
+        for (TopicIdPartition actualPart : actual) {
             if (!expected.contains(actualPart)) {
                 fail("Found unexpected partition " + actualPart);
             }
@@ -157,11 +161,11 @@ public class FetchSessionHandlerTest {
     }
 
     private static final class RespEntry {
-        final TopicPartition part;
+        final TopicIdPartition part;
         final FetchResponseData.PartitionData data;
 
-        RespEntry(String topic, int partition, long highWatermark, long lastStableOffset) {
-            this.part = new TopicPartition(topic, partition);
+        RespEntry(String topic, int partition, Uuid topicId, long highWatermark, long lastStableOffset) {
+            this.part = new TopicIdPartition(topicId, new TopicPartition(topic, partition));
 
             this.data = new FetchResponseData.PartitionData()
                 .setPartitionIndex(partition)
@@ -170,8 +174,8 @@ public class FetchSessionHandlerTest {
                 .setLogStartOffset(0);
         }
 
-        RespEntry(String topic, int partition, Errors error) {
-            this.part = new TopicPartition(topic, partition);
+        RespEntry(String topic, int partition, Uuid topicId, Errors error) {
+            this.part = new TopicIdPartition(topicId, new TopicPartition(topic, partition));
 
             this.data = new FetchResponseData.PartitionData()
                     .setPartitionIndex(partition)
@@ -180,8 +184,8 @@ public class FetchSessionHandlerTest {
         }
     }
 
-    private static LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> respMap(RespEntry... entries) {
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> map = new LinkedHashMap<>();
+    private static LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respMap(RespEntry... entries) {
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> map = new LinkedHashMap<>();
         for (RespEntry entry : entries) {
             map.put(entry.part, entry.data);
         }
@@ -202,29 +206,30 @@ public class FetchSessionHandlerTest {
             FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
             FetchSessionHandler.Builder builder = handler.newBuilder();
             addTopicId(topicIds, topicNames, "foo", version);
-            builder.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-            builder.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
+            Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+            builder.add(new TopicPartition("foo", 0),
+                    new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+            builder.add(new TopicPartition("foo", 1),
+                    new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
             FetchSessionHandler.FetchRequestData data = builder.build();
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
-                    new ReqEntry("foo", 1, 10, 110, 210)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200),
+                    new ReqEntry("foo", fooId, 1, 10, 110, 210)),
                     data.toSend(), data.sessionPartitions());
             assertEquals(INVALID_SESSION_ID, data.metadata().sessionId());
             assertEquals(INITIAL_EPOCH, data.metadata().epoch());
 
             FetchResponse resp = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-                respMap(new RespEntry("foo", 0, 0, 0),
-                        new RespEntry("foo", 1, 0, 0)), topicIds);
+                respMap(new RespEntry("foo", 0, fooId, 0, 0),
+                        new RespEntry("foo", 1, fooId, 0, 0)));
             handler.handleResponse(resp, version);
 
             FetchSessionHandler.Builder builder2 = handler.newBuilder();
-            builder2.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+            builder2.add(new TopicPartition("foo", 0),
+                    new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
             FetchSessionHandler.FetchRequestData data2 = builder2.build();
             assertEquals(INVALID_SESSION_ID, data2.metadata().sessionId());
             assertEquals(INITIAL_EPOCH, data2.metadata().epoch());
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)),
                     data2.toSend(), data2.sessionPartitions());
         });
     }
@@ -242,65 +247,62 @@ public class FetchSessionHandlerTest {
             FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
             FetchSessionHandler.Builder builder = handler.newBuilder();
             addTopicId(topicIds, topicNames, "foo", version);
-            builder.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-            builder.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
+            Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+            TopicPartition foo0 = new TopicPartition("foo", 0);
+            TopicPartition foo1 = new TopicPartition("foo", 1);
+            builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+            builder.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
             FetchSessionHandler.FetchRequestData data = builder.build();
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
-                    new ReqEntry("foo", 1, 10, 110, 210)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200),
+                    new ReqEntry("foo", fooId, 1, 10, 110, 210)),
                     data.toSend(), data.sessionPartitions());
             assertEquals(INVALID_SESSION_ID, data.metadata().sessionId());
             assertEquals(INITIAL_EPOCH, data.metadata().epoch());
 
             FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-                respMap(new RespEntry("foo", 0, 10, 20),
-                        new RespEntry("foo", 1, 10, 20)), topicIds);
+                respMap(new RespEntry("foo", 0, fooId, 10, 20),
+                        new RespEntry("foo", 1, fooId, 10, 20)));
             handler.handleResponse(resp, version);
 
             // Test an incremental fetch request which adds one partition and modifies another.
             FetchSessionHandler.Builder builder2 = handler.newBuilder();
             addTopicId(topicIds, topicNames, "bar", version);
-            builder2.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-            builder2.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 120, 210, Optional.empty()));
-            builder2.add(new TopicPartition("bar", 0), topicIds.getOrDefault("bar", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(20, 200, 200, Optional.empty()));
+            Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID);
+            TopicPartition bar0 = new TopicPartition("bar", 0);
+            builder2.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+            builder2.add(foo1, new FetchRequest.PartitionData(fooId, 10, 120, 210, Optional.empty()));
+            builder2.add(bar0, new FetchRequest.PartitionData(barId, 20, 200, 200, Optional.empty()));
             FetchSessionHandler.FetchRequestData data2 = builder2.build();
             assertFalse(data2.metadata().isFull());
-            assertMapEquals(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
-                    new ReqEntry("foo", 1, 10, 120, 210),
-                    new ReqEntry("bar", 0, 20, 200, 200)),
+            assertMapEquals(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200),
+                    new ReqEntry("foo", fooId, 1, 10, 120, 210),
+                    new ReqEntry("bar", barId, 0, 20, 200, 200)),
                     data2.sessionPartitions());
-            assertMapEquals(reqMap(new ReqEntry("bar", 0, 20, 200, 200),
-                    new ReqEntry("foo", 1, 10, 120, 210)),
+            assertMapEquals(reqMap(new ReqEntry("bar", barId, 0, 20, 200, 200),
+                    new ReqEntry("foo", fooId, 1, 10, 120, 210)),
                     data2.toSend());
 
             FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123,
-                respMap(new RespEntry("foo", 1, 20, 20)), topicIds);
+                respMap(new RespEntry("foo", 1, fooId, 20, 20)));
             handler.handleResponse(resp2, version);
 
             // Skip building a new request.  Test that handling an invalid fetch session epoch response results
             // in a request which closes the session.
             FetchResponse resp3 = FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, INVALID_SESSION_ID,
-                respMap(), topicIds);
+                respMap());
             handler.handleResponse(resp3, version);
 
             FetchSessionHandler.Builder builder4 = handler.newBuilder();
-            builder4.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-            builder4.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 120, 210, Optional.empty()));
-            builder4.add(new TopicPartition("bar", 0), topicIds.getOrDefault("bar", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(20, 200, 200, Optional.empty()));
+            builder4.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+            builder4.add(foo1, new FetchRequest.PartitionData(fooId, 10, 120, 210, Optional.empty()));
+            builder4.add(bar0, new FetchRequest.PartitionData(barId, 20, 200, 200, Optional.empty()));
             FetchSessionHandler.FetchRequestData data4 = builder4.build();
             assertTrue(data4.metadata().isFull());
             assertEquals(data2.metadata().sessionId(), data4.metadata().sessionId());
             assertEquals(INITIAL_EPOCH, data4.metadata().epoch());
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
-                    new ReqEntry("foo", 1, 10, 120, 210),
-                    new ReqEntry("bar", 0, 20, 200, 200)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200),
+                    new ReqEntry("foo", fooId, 1, 10, 120, 210),
+                    new ReqEntry("bar", barId, 0, 20, 200, 200)),
                     data4.sessionPartitions(), data4.toSend());
         });
     }
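
The test above walks through the client-side session lifecycle: a full request seeds the session, and later builds produce incremental requests in which unchanged partitions stay out of toSend() and dropped partitions surface in toForget() as TopicIdPartitions. A condensed sketch of that flow, not part of the patch (topic, offsets, and session ID are made up):

    import java.util.LinkedHashMap;
    import java.util.Optional;

    import org.apache.kafka.clients.FetchSessionHandler;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.FetchRequest;
    import org.apache.kafka.common.requests.FetchResponse;
    import org.apache.kafka.common.utils.LogContext;

    public class SessionLifecycleSketch {
        public static void main(String[] args) {
            FetchSessionHandler handler = new FetchSessionHandler(new LogContext(), 1);
            Uuid fooId = Uuid.randomUuid();
            TopicPartition foo0 = new TopicPartition("foo", 0);

            // Full request: the partition appears in both toSend() and sessionPartitions().
            FetchSessionHandler.Builder builder = handler.newBuilder();
            builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
            FetchSessionHandler.FetchRequestData full = builder.build();
            System.out.println(full.metadata().isFull()); // true

            // The broker's response establishes session 123.
            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData = new LinkedHashMap<>();
            respData.put(new TopicIdPartition(fooId, foo0),
                    new FetchResponseData.PartitionData().setPartitionIndex(0).setHighWatermark(10).setLogStartOffset(0));
            handler.handleResponse(FetchResponse.of(Errors.NONE, 0, 123, respData), ApiKeys.FETCH.latestVersion());

            // Incremental request with no partitions added: foo-0 is forgotten as a TopicIdPartition.
            FetchSessionHandler.FetchRequestData incremental = handler.newBuilder().build();
            System.out.println(incremental.toSend());   // empty map: nothing changed
            System.out.println(incremental.toForget()); // contains the TopicIdPartition for foo-0
        }
    }
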
@@ -312,8 +314,8 @@ public class FetchSessionHandlerTest {
     public void testDoubleBuild() {
         FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
         FetchSessionHandler.Builder builder = handler.newBuilder();
-        builder.add(new TopicPartition("foo", 0), Uuid.randomUuid(),
-            new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+        builder.add(new TopicPartition("foo", 0),
+            new FetchRequest.PartitionData(Uuid.randomUuid(), 0, 100, 200, Optional.empty()));
         builder.build();
         try {
             builder.build();
@@ -334,55 +336,55 @@ public class FetchSessionHandlerTest {
             FetchSessionHandler.Builder builder = handler.newBuilder();
             addTopicId(topicIds, topicNames, "foo", version);
             addTopicId(topicIds, topicNames, "bar", version);
-            builder.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-            builder.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
-            builder.add(new TopicPartition("bar", 0), topicIds.getOrDefault("bar", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(20, 120, 220, Optional.empty()));
+            Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+            Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID);
+            TopicPartition foo0 = new TopicPartition("foo", 0);
+            TopicPartition foo1 = new TopicPartition("foo", 1);
+            TopicPartition bar0 = new TopicPartition("bar", 0);
+            builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+            builder.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
+            builder.add(bar0, new FetchRequest.PartitionData(barId, 20, 120, 220, Optional.empty()));
             FetchSessionHandler.FetchRequestData data = builder.build();
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200),
-                    new ReqEntry("foo", 1, 10, 110, 210),
-                    new ReqEntry("bar", 0, 20, 120, 220)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200),
+                    new ReqEntry("foo", fooId, 1, 10, 110, 210),
+                    new ReqEntry("bar", barId, 0, 20, 120, 220)),
                     data.toSend(), data.sessionPartitions());
             assertTrue(data.metadata().isFull());
 
             FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-                respMap(new RespEntry("foo", 0, 10, 20),
-                        new RespEntry("foo", 1, 10, 20),
-                        new RespEntry("bar", 0, 10, 20)), topicIds);
+                respMap(new RespEntry("foo", 0, fooId, 10, 20),
+                        new RespEntry("foo", 1, fooId, 10, 20),
+                        new RespEntry("bar", 0, barId, 10, 20)));
             handler.handleResponse(resp, version);
 
             // Test an incremental fetch request which removes two partitions.
             FetchSessionHandler.Builder builder2 = handler.newBuilder();
-            builder2.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
+            builder2.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
             FetchSessionHandler.FetchRequestData data2 = builder2.build();
             assertFalse(data2.metadata().isFull());
             assertEquals(123, data2.metadata().sessionId());
             assertEquals(1, data2.metadata().epoch());
-            assertMapEquals(reqMap(new ReqEntry("foo", 1, 10, 110, 210)),
+            assertMapEquals(reqMap(new ReqEntry("foo", fooId, 1, 10, 110, 210)),
                     data2.sessionPartitions());
             assertMapEquals(reqMap(), data2.toSend());
-            ArrayList<TopicPartition> expectedToForget2 = new ArrayList<>();
-            expectedToForget2.add(new TopicPartition("foo", 0));
-            expectedToForget2.add(new TopicPartition("bar", 0));
+            ArrayList<TopicIdPartition> expectedToForget2 = new ArrayList<>();
+            expectedToForget2.add(new TopicIdPartition(fooId, foo0));
+            expectedToForget2.add(new TopicIdPartition(barId, bar0));
             assertListEquals(expectedToForget2, data2.toForget());
 
             // A FETCH_SESSION_ID_NOT_FOUND response triggers us to close the session.
             // The next request is a session establishing FULL request.
             FetchResponse resp2 = FetchResponse.of(Errors.FETCH_SESSION_ID_NOT_FOUND, 0, INVALID_SESSION_ID,
-                respMap(), topicIds);
+                respMap());
             handler.handleResponse(resp2, version);
 
             FetchSessionHandler.Builder builder3 = handler.newBuilder();
-            builder3.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+            builder3.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
             FetchSessionHandler.FetchRequestData data3 = builder3.build();
             assertTrue(data3.metadata().isFull());
             assertEquals(INVALID_SESSION_ID, data3.metadata().sessionId());
             assertEquals(INITIAL_EPOCH, data3.metadata().epoch());
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)),
                     data3.sessionPartitions(), data3.toSend());
         });
     }
@@ -396,28 +398,30 @@ public class FetchSessionHandlerTest {
             String testType = partition == 0 ? "updating a partition" : "adding a new partition";
             FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
             FetchSessionHandler.Builder builder = handler.newBuilder();
-            builder.add(new TopicPartition("foo", 0),  Uuid.ZERO_UUID,
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+            builder.add(new TopicPartition("foo", 0),
+                    new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 100, 200, Optional.empty()));
             FetchSessionHandler.FetchRequestData data = builder.build();
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", Uuid.ZERO_UUID, 0, 0, 100, 200)),
                     data.toSend(), data.sessionPartitions());
             assertTrue(data.metadata().isFull());
             assertFalse(data.canUseTopicIds());
 
             FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-                    respMap(new RespEntry("foo", 0, 10, 20)), Collections.emptyMap());
+                    respMap(new RespEntry("foo", 0, Uuid.ZERO_UUID, 10, 20)));
             handler.handleResponse(resp, (short) 12);
 
             // Try to add a topic ID to an already existing topic partition (0) or a new partition (1) in the session.
+            Uuid topicId = Uuid.randomUuid();
             FetchSessionHandler.Builder builder2 = handler.newBuilder();
-            builder2.add(new TopicPartition("foo", partition), Uuid.randomUuid(),
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
+            builder2.add(new TopicPartition("foo", partition),
+                    new FetchRequest.PartitionData(topicId, 10, 110, 210, Optional.empty()));
             FetchSessionHandler.FetchRequestData data2 = builder2.build();
-            // Should have the same session ID and next epoch, but we can now use topic IDs.
-            // The receiving broker will close the session if we were previously not using topic IDs.
+            // Should have the same session ID and next epoch, and can only use topic IDs if the partition was updated.
+            boolean updated = partition == 0;
+            // The receiving broker will handle closing the session.
             assertEquals(123, data2.metadata().sessionId(), "Did not use same session when " + testType);
             assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch when " + testType);
-            assertTrue(data2.canUseTopicIds());
+            assertEquals(updated, data2.canUseTopicIds());
         });
     }
 
@@ -428,62 +432,173 @@ public class FetchSessionHandlerTest {
         List<Integer> partitions = Arrays.asList(0, 1);
         partitions.forEach(partition -> {
             String testType = partition == 0 ? "updating a partition" : "adding a new partition";
-            Map<String, Uuid> topicIds = Collections.singletonMap("foo", Uuid.randomUuid());
+            Uuid fooId = Uuid.randomUuid();
             FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
             FetchSessionHandler.Builder builder = handler.newBuilder();
-            builder.add(new TopicPartition("foo", 0),  topicIds.get("foo"),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+            builder.add(new TopicPartition("foo", 0),
+                    new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
             FetchSessionHandler.FetchRequestData data = builder.build();
-            assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+            assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)),
                     data.toSend(), data.sessionPartitions());
             assertTrue(data.metadata().isFull());
             assertTrue(data.canUseTopicIds());
 
             FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-                    respMap(new RespEntry("foo", 0, 10, 20)), topicIds);
+                    respMap(new RespEntry("foo", 0, fooId, 10, 20)));
             handler.handleResponse(resp, ApiKeys.FETCH.latestVersion());
 
             // Try to remove a topic ID from an existing topic partition (0) or add a new topic partition (1) without an ID.
             FetchSessionHandler.Builder builder2 = handler.newBuilder();
-            builder2.add(new TopicPartition("foo", partition), Uuid.ZERO_UUID,
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
+            builder2.add(new TopicPartition("foo", partition),
+                    new FetchRequest.PartitionData(Uuid.ZERO_UUID, 10, 110, 210, Optional.empty()));
             FetchSessionHandler.FetchRequestData data2 = builder2.build();
-            // Should have the same session ID and next epoch, but can no longer use topic IDs.
-            // The receiving broker will close the session if we were previously using topic IDs.
+            // Should have the same session ID and next epoch, but can no longer use topic IDs.
+            // The receiving broker will handle closing the session.
             assertEquals(123, data2.metadata().sessionId(), "Did not use same session when " + testType);
             assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch when " + testType);
             assertFalse(data2.canUseTopicIds());
         });
     }
 
+    private static Stream<Arguments> idUsageCombinations() {
+        return Stream.of(
+                Arguments.of(true, true),
+                Arguments.of(true, false),
+                Arguments.of(false, true),
+                Arguments.of(false, false)
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource("idUsageCombinations")
+    public void testTopicIdReplaced(boolean startsWithTopicIds, boolean endsWithTopicIds) {
+        TopicPartition tp = new TopicPartition("foo", 0);
+        FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
+        FetchSessionHandler.Builder builder = handler.newBuilder();
+        Uuid topicId1 = startsWithTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID;
+        builder.add(tp, new FetchRequest.PartitionData(topicId1, 0, 100, 200, Optional.empty()));
+        FetchSessionHandler.FetchRequestData data = builder.build();
+        assertMapsEqual(reqMap(new ReqEntry("foo", topicId1, 0, 0, 100, 200)),
+                data.toSend(), data.sessionPartitions());
+        assertTrue(data.metadata().isFull());
+        assertEquals(startsWithTopicIds, data.canUseTopicIds());
+
+        FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, respMap(new RespEntry("foo", 0, topicId1, 10, 20)));
+        short version = startsWithTopicIds ? ApiKeys.FETCH.latestVersion() : 12;
+        handler.handleResponse(resp, version);
+
+        // Try to add a new topic ID.
+        FetchSessionHandler.Builder builder2 = handler.newBuilder();
+        Uuid topicId2 = endsWithTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID;
+        // Use the same data besides the topic ID.
+        FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData(topicId2, 0, 100, 200, Optional.empty());
+        builder2.add(tp, partitionData);
+        FetchSessionHandler.FetchRequestData data2 = builder2.build();
+
+        if (startsWithTopicIds && endsWithTopicIds) {
+            // If we started with an ID, only a new ID will count as replaced.
+            // The old topic ID partition should be in toReplace, and the new one should be in toSend.
+            assertEquals(Collections.singletonList(new TopicIdPartition(topicId1, tp)), data2.toReplace());
+            assertMapsEqual(reqMap(new ReqEntry("foo", topicId2, 0, 0, 100, 200)),
+                    data2.toSend(), data2.sessionPartitions());
+
+            // sessionTopicNames should contain only the second topic ID.
+            assertEquals(Collections.singletonMap(topicId2, tp.topic()), handler.sessionTopicNames());
+
+        } else if (startsWithTopicIds || endsWithTopicIds) {
+            // If we downgraded to not using topic IDs, we still want to send this data.
+            // However, we will not mark the partition as replaced. In this scenario, we should see the session
+            // close due to the change in request types.
+            // We will have the new topic ID in the session partition map.
+            assertEquals(Collections.emptyList(), data2.toReplace());
+            assertMapsEqual(reqMap(new ReqEntry("foo", topicId2, 0, 0, 100, 200)),
+                    data2.toSend(), data2.sessionPartitions());
+            // The topic names map will contain the new topic ID if it is valid.
+            // Otherwise the old topic ID is removed, and the map is empty because the request no longer uses topic IDs.
+            if (endsWithTopicIds) {
+                assertEquals(Collections.singletonMap(topicId2, tp.topic()), handler.sessionTopicNames());
+            } else {
+                assertEquals(Collections.emptyMap(), handler.sessionTopicNames());
+            }
+
+        } else {
+            // Otherwise, there is no partition in toReplace, and since neither the partition nor the topic ID was updated, there is no data to send.
+            assertEquals(Collections.emptyList(), data2.toReplace());
+            assertEquals(Collections.emptyMap(), data2.toSend());
+            assertMapsEqual(reqMap(new ReqEntry("foo", topicId2, 0, 0, 100, 200)), data2.sessionPartitions());
+            // There is also nothing in the sessionTopicNames map, as there are no topic IDs used.
+            assertEquals(Collections.emptyMap(), handler.sessionTopicNames());
+        }
+
+        // Should have the same session ID and next epoch, and can use topic IDs only if the request ended with topic IDs.
+        assertEquals(123, data2.metadata().sessionId(), "Did not use same session");
+        assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch");
+        assertEquals(endsWithTopicIds, data2.canUseTopicIds());
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void testSessionEpochWhenMixedUsageOfTopicIDs(boolean startsWithTopicIds) {
+        Uuid fooId = startsWithTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID;
+        Uuid barId = startsWithTopicIds ? Uuid.ZERO_UUID : Uuid.randomUuid();
+        short responseVersion = startsWithTopicIds ? ApiKeys.FETCH.latestVersion() : 12;
+
+        TopicPartition tp0 = new TopicPartition("foo", 0);
+        TopicPartition tp1 = new TopicPartition("bar", 1);
+
+        FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
+        FetchSessionHandler.Builder builder = handler.newBuilder();
+        builder.add(tp0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+        FetchSessionHandler.FetchRequestData data = builder.build();
+        assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)),
+                data.toSend(), data.sessionPartitions());
+        assertTrue(data.metadata().isFull());
+        assertEquals(startsWithTopicIds, data.canUseTopicIds());
+
+        FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
+                respMap(new RespEntry("foo", 0, fooId, 10, 20)));
+        handler.handleResponse(resp, responseVersion);
+
+        // Re-add the first partition. Then add a partition with opposite ID usage.
+        FetchSessionHandler.Builder builder2 = handler.newBuilder();
+        builder2.add(tp0, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
+        builder2.add(tp1, new FetchRequest.PartitionData(barId, 0, 100, 200, Optional.empty()));
+        FetchSessionHandler.FetchRequestData data2 = builder2.build();
+        // Should have the same session ID and next epoch, but cannot use topic IDs.
+        // The receiving broker will handle closing the session.
+        assertEquals(123, data2.metadata().sessionId(), "Did not use same session");
+        assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch");
+        assertFalse(data2.canUseTopicIds());
+    }
+
     @ParameterizedTest
     @ValueSource(booleans = {true, false})
     public void testIdUsageWithAllForgottenPartitions(boolean useTopicIds) {
         // We want to test when all topics are removed from the session
+        TopicPartition foo0 = new TopicPartition("foo", 0);
         Uuid topicId = useTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID;
-        Short responseVersion = useTopicIds ? ApiKeys.FETCH.latestVersion() : 12;
-        Map<String, Uuid> topicIds = Collections.singletonMap("foo", topicId);
+        short responseVersion = useTopicIds ? ApiKeys.FETCH.latestVersion() : 12;
         FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
 
         // Add topic foo to the session
         FetchSessionHandler.Builder builder = handler.newBuilder();
-        builder.add(new TopicPartition("foo", 0), topicIds.get("foo"),
-                new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+        builder.add(foo0, new FetchRequest.PartitionData(topicId, 0, 100, 200, Optional.empty()));
         FetchSessionHandler.FetchRequestData data = builder.build();
-        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+        assertMapsEqual(reqMap(new ReqEntry("foo", topicId, 0, 0, 100, 200)),
                 data.toSend(), data.sessionPartitions());
         assertTrue(data.metadata().isFull());
         assertEquals(useTopicIds, data.canUseTopicIds());
 
         FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-                respMap(new RespEntry("foo", 0, 10, 20)), topicIds);
-        handler.handleResponse(resp, responseVersion.shortValue());
+                respMap(new RespEntry("foo", 0, topicId, 10, 20)));
+        handler.handleResponse(resp, responseVersion);
 
         // Remove the topic from the session
         FetchSessionHandler.Builder builder2 = handler.newBuilder();
         FetchSessionHandler.FetchRequestData data2 = builder2.build();
-        // Should have the same session ID and next epoch, but can no longer use topic IDs.
-        // The receiving broker will close the session if we were previously using topic IDs.
+        assertEquals(Collections.singletonList(new TopicIdPartition(topicId, foo0)), data2.toForget());
+        // Should have the same session ID, next epoch, and same ID usage.
         assertEquals(123, data2.metadata().sessionId(), "Did not use same session when useTopicIds was " + useTopicIds);
         assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch when useTopicIds was " + useTopicIds);
         assertEquals(useTopicIds, data2.canUseTopicIds());
@@ -491,19 +606,19 @@ public class FetchSessionHandlerTest {
 
     @Test
     public void testOkToAddNewIdAfterTopicRemovedFromSession() {
-        Map<String, Uuid> topicIds = Collections.singletonMap("foo", Uuid.randomUuid());
+        Uuid topicId = Uuid.randomUuid();
         FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
         FetchSessionHandler.Builder builder = handler.newBuilder();
-        builder.add(new TopicPartition("foo", 0),  topicIds.get("foo"),
-                new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+        builder.add(new TopicPartition("foo", 0),
+                new FetchRequest.PartitionData(topicId, 0, 100, 200, Optional.empty()));
         FetchSessionHandler.FetchRequestData data = builder.build();
-        assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)),
+        assertMapsEqual(reqMap(new ReqEntry("foo", topicId, 0, 0, 100, 200)),
                 data.toSend(), data.sessionPartitions());
         assertTrue(data.metadata().isFull());
         assertTrue(data.canUseTopicIds());
 
         FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-                respMap(new RespEntry("foo", 0, 10, 20)), topicIds);
+                respMap(new RespEntry("foo", 0, topicId, 10, 20)));
         handler.handleResponse(resp, ApiKeys.FETCH.latestVersion());
 
         // Remove the partition from the session. Return a session ID as though the session is still open.
@@ -512,13 +627,13 @@ public class FetchSessionHandlerTest {
         assertMapsEqual(new LinkedHashMap<>(),
                 data2.toSend(), data2.sessionPartitions());
         FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123,
-                new LinkedHashMap<>(), topicIds);
+                new LinkedHashMap<>());
         handler.handleResponse(resp2, ApiKeys.FETCH.latestVersion());
 
         // After the topic is removed, add a recreated topic with a new ID.
         FetchSessionHandler.Builder builder3 = handler.newBuilder();
-        builder3.add(new TopicPartition("foo", 0),  Uuid.randomUuid(),
-                new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+        builder3.add(new TopicPartition("foo", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 0, 100, 200, Optional.empty()));
         FetchSessionHandler.FetchRequestData data3 = builder3.build();
         // Should have the same session ID and epoch 2.
         assertEquals(123, data3.metadata().sessionId(), "Did not use same session");
@@ -536,32 +651,34 @@ public class FetchSessionHandlerTest {
             FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
             addTopicId(topicIds, topicNames, "foo", version);
             addTopicId(topicIds, topicNames, "bar", version);
+            Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+            Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID);
+            TopicPartition foo0 = new TopicPartition("foo", 0);
+            TopicPartition foo1 = new TopicPartition("foo", 1);
+            TopicPartition bar0 = new TopicPartition("bar", 0);
             FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-                respMap(new RespEntry("foo", 0, 10, 20),
-                        new RespEntry("foo", 1, 10, 20),
-                        new RespEntry("bar", 0, 10, 20)), topicIds);
+                respMap(new RespEntry("foo", 0, fooId, 10, 20),
+                        new RespEntry("foo", 1, fooId, 10, 20),
+                        new RespEntry("bar", 0, barId, 10, 20)));
             String issue = handler.verifyFullFetchResponsePartitions(resp1.responseData(topicNames, version).keySet(),
                     resp1.topicIds(), version);
             assertTrue(issue.contains("extraPartitions="));
             assertFalse(issue.contains("omittedPartitions="));
             FetchSessionHandler.Builder builder = handler.newBuilder();
-            builder.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-            builder.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
-            builder.add(new TopicPartition("bar", 0), topicIds.getOrDefault("bar", Uuid.ZERO_UUID),
-                    new FetchRequest.PartitionData(20, 120, 220, Optional.empty()));
+            builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+            builder.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
+            builder.add(bar0, new FetchRequest.PartitionData(barId, 20, 120, 220, Optional.empty()));
             builder.build();
             FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-                respMap(new RespEntry("foo", 0, 10, 20),
-                        new RespEntry("foo", 1, 10, 20),
-                        new RespEntry("bar", 0, 10, 20)), topicIds);
+                respMap(new RespEntry("foo", 0, fooId, 10, 20),
+                        new RespEntry("foo", 1, fooId, 10, 20),
+                        new RespEntry("bar", 0, barId, 10, 20)));
             String issue2 = handler.verifyFullFetchResponsePartitions(resp2.responseData(topicNames, version).keySet(),
                     resp2.topicIds(), version);
             assertNull(issue2);
             FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-                respMap(new RespEntry("foo", 0, 10, 20),
-                        new RespEntry("foo", 1, 10, 20)), topicIds);
+                respMap(new RespEntry("foo", 0, fooId, 10, 20),
+                        new RespEntry("foo", 1, fooId, 10, 20)));
             String issue3 = handler.verifyFullFetchResponsePartitions(resp3.responseData(topicNames, version).keySet(),
                     resp3.topicIds(), version);
             assertFalse(issue3.contains("extraPartitions="));
@@ -576,36 +693,32 @@ public class FetchSessionHandlerTest {
         FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
         addTopicId(topicIds, topicNames, "foo", ApiKeys.FETCH.latestVersion());
         addTopicId(topicIds, topicNames, "bar", ApiKeys.FETCH.latestVersion());
-        Uuid extraId = Uuid.randomUuid();
-        topicIds.put("extra2", extraId);
+        addTopicId(topicIds, topicNames, "extra2", ApiKeys.FETCH.latestVersion());
         FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-            respMap(new RespEntry("foo", 0, 10, 20),
-                    new RespEntry("extra2", 1, 10, 20),
-                    new RespEntry("bar", 0, 10, 20)), topicIds);
+            respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20),
+                    new RespEntry("extra2", 1, topicIds.get("extra2"), 10, 20),
+                    new RespEntry("bar", 0, topicIds.get("bar"), 10, 20)));
         String issue = handler.verifyFullFetchResponsePartitions(resp1.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet(),
                 resp1.topicIds(), ApiKeys.FETCH.latestVersion());
         assertTrue(issue.contains("extraPartitions="));
-        assertTrue(issue.contains("extraIds="));
         assertFalse(issue.contains("omittedPartitions="));
         FetchSessionHandler.Builder builder = handler.newBuilder();
-        builder.add(new TopicPartition("foo", 0), topicIds.get("foo"),
-                new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-        builder.add(new TopicPartition("bar", 0), topicIds.get("bar"),
-                new FetchRequest.PartitionData(20, 120, 220, Optional.empty()));
+        builder.add(new TopicPartition("foo", 0),
+                new FetchRequest.PartitionData(topicIds.get("foo"), 0, 100, 200, Optional.empty()));
+        builder.add(new TopicPartition("bar", 0),
+                new FetchRequest.PartitionData(topicIds.get("bar"), 20, 120, 220, Optional.empty()));
         builder.build();
         FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-            respMap(new RespEntry("foo", 0, 10, 20),
-                    new RespEntry("extra2", 1, 10, 20),
-                    new RespEntry("bar", 0, 10, 20)), topicIds);
+            respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20),
+                    new RespEntry("extra2", 1, topicIds.get("extra2"), 10, 20),
+                    new RespEntry("bar", 0, topicIds.get("bar"), 10, 20)));
         String issue2 = handler.verifyFullFetchResponsePartitions(resp2.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet(),
                 resp2.topicIds(), ApiKeys.FETCH.latestVersion());
-        assertFalse(issue2.contains("extraPartitions="));
-        assertTrue(issue2.contains("extraIds="));
+        assertTrue(issue2.contains("extraPartitions="));
         assertFalse(issue2.contains("omittedPartitions="));
-        topicNames.put(extraId, "extra2");
         FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID,
-            respMap(new RespEntry("foo", 0, 10, 20),
-                    new RespEntry("bar", 0, 10, 20)), topicIds);
+            respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20),
+                    new RespEntry("bar", 0, topicIds.get("bar"), 10, 20)));
         String issue3 = handler.verifyFullFetchResponsePartitions(resp3.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet(),
                 resp3.topicIds(), ApiKeys.FETCH.latestVersion());
         assertNull(issue3);
@@ -618,24 +731,25 @@ public class FetchSessionHandlerTest {
         FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1);
         FetchSessionHandler.Builder builder = handler.newBuilder();
         addTopicId(topicIds, topicNames, "foo", ApiKeys.FETCH.latestVersion());
-        builder.add(new TopicPartition("foo", 0), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
-        builder.add(new TopicPartition("foo", 1), topicIds.getOrDefault("foo", Uuid.ZERO_UUID),
-                new FetchRequest.PartitionData(10, 110, 210, Optional.empty()));
+        Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID);
+        builder.add(new TopicPartition("foo", 0),
+                new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty()));
+        builder.add(new TopicPartition("foo", 1),
+                new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty()));
         FetchSessionHandler.FetchRequestData data = builder.build();
         assertEquals(INVALID_SESSION_ID, data.metadata().sessionId());
         assertEquals(INITIAL_EPOCH, data.metadata().epoch());
 
         FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123,
-            respMap(new RespEntry("foo", 0, 10, 20),
-                    new RespEntry("foo", 1, 10, 20)), topicIds);
+            respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20),
+                    new RespEntry("foo", 1, topicIds.get("foo"), 10, 20)));
         handler.handleResponse(resp, ApiKeys.FETCH.latestVersion());
 
         // Test an incremental fetch request which adds an ID unknown to the broker.
         FetchSessionHandler.Builder builder2 = handler.newBuilder();
         addTopicId(topicIds, topicNames, "unknown", ApiKeys.FETCH.latestVersion());
-        builder2.add(new TopicPartition("unknown", 0), topicIds.getOrDefault("unknown", Uuid.ZERO_UUID),
-                new FetchRequest.PartitionData(0, 100, 200, Optional.empty()));
+        builder2.add(new TopicPartition("unknown", 0),
+                new FetchRequest.PartitionData(topicIds.getOrDefault("unknown", Uuid.ZERO_UUID), 0, 100, 200, Optional.empty()));
         FetchSessionHandler.FetchRequestData data2 = builder2.build();
         assertFalse(data2.metadata().isFull());
         assertEquals(123, data2.metadata().sessionId());
@@ -643,7 +757,7 @@ public class FetchSessionHandlerTest {
 
         // Return and handle a response with a top level error
         FetchResponse resp2 = FetchResponse.of(Errors.UNKNOWN_TOPIC_ID, 0, 123,
-            respMap(new RespEntry("unknown", 0, Errors.UNKNOWN_TOPIC_ID)), topicIds);
+            respMap(new RespEntry("unknown", 0, Uuid.randomUuid(), Errors.UNKNOWN_TOPIC_ID)));
         assertFalse(handler.handleResponse(resp2, ApiKeys.FETCH.latestVersion()));
 
         // Ensure we start with a new epoch. This will close the session in the next request.
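
testTopicIdReplaced above pins down the new toReplace() semantics: when both requests use topic IDs and the ID behind a topic name changes (the topic was deleted and recreated), the old ID goes into toReplace() while the partition is re-sent under the new ID. A minimal sketch of just that branch, not part of the patch (IDs, offsets, and session ID are made up):

    import java.util.LinkedHashMap;
    import java.util.Optional;

    import org.apache.kafka.clients.FetchSessionHandler;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.FetchRequest;
    import org.apache.kafka.common.requests.FetchResponse;
    import org.apache.kafka.common.utils.LogContext;

    public class TopicIdReplacedSketch {
        public static void main(String[] args) {
            FetchSessionHandler handler = new FetchSessionHandler(new LogContext(), 1);
            TopicPartition tp = new TopicPartition("foo", 0);
            Uuid oldId = Uuid.randomUuid();

            // Establish a session with the original topic ID.
            FetchSessionHandler.Builder builder = handler.newBuilder();
            builder.add(tp, new FetchRequest.PartitionData(oldId, 0, 100, 200, Optional.empty()));
            builder.build();
            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respData = new LinkedHashMap<>();
            respData.put(new TopicIdPartition(oldId, tp),
                    new FetchResponseData.PartitionData().setPartitionIndex(0).setHighWatermark(10).setLogStartOffset(0));
            handler.handleResponse(FetchResponse.of(Errors.NONE, 0, 123, respData), ApiKeys.FETCH.latestVersion());

            // The topic was deleted and recreated, so the same name now maps to a new ID.
            Uuid newId = Uuid.randomUuid();
            FetchSessionHandler.Builder builder2 = handler.newBuilder();
            builder2.add(tp, new FetchRequest.PartitionData(newId, 0, 100, 200, Optional.empty()));
            FetchSessionHandler.FetchRequestData data2 = builder2.build();
            System.out.println(data2.toReplace()); // contains the TopicIdPartition under the old ID
            System.out.println(data2.toSend());    // re-sends foo-0 under the new ID
        }
    }
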
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index bc7d506..2872983 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -38,6 +38,7 @@ import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.config.SslConfigs;
 import org.apache.kafka.common.errors.AuthenticationException;
@@ -715,9 +716,10 @@ public class KafkaConsumerTest {
         client.prepareResponse(
             body -> {
                 FetchRequest request = (FetchRequest) body;
-                Map<TopicPartition, FetchRequest.PartitionData> fetchData = request.fetchData(topicNames);
-                return fetchData.keySet().equals(singleton(tp0)) &&
-                        fetchData.get(tp0).fetchOffset == 50L;
+                Map<TopicIdPartition, FetchRequest.PartitionData> fetchData = request.fetchData(topicNames);
+                TopicIdPartition tidp0 = new TopicIdPartition(topicIds.get(tp0.topic()), tp0);
+                return fetchData.keySet().equals(singleton(tidp0)) &&
+                        fetchData.get(tidp0).fetchOffset == 50L;
 
             }, fetchResponse(tp0, 50L, 5));
 
@@ -1762,7 +1764,7 @@ public class KafkaConsumerTest {
         client.prepareResponseFrom(syncGroupResponse(singletonList(tp0), Errors.NONE), coordinator);
 
         client.prepareResponseFrom(body -> body instanceof FetchRequest 
-            && ((FetchRequest) body).fetchData(topicNames).containsKey(tp0), fetchResponse(tp0, 1, 1), node);
+            && ((FetchRequest) body).fetchData(topicNames).containsKey(new TopicIdPartition(topicId, tp0)), fetchResponse(tp0, 1, 1), node);
         time.sleep(heartbeatIntervalMs);
         Thread.sleep(heartbeatIntervalMs);
         consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE));
@@ -2541,7 +2543,7 @@ public class KafkaConsumerTest {
     }
 
     private FetchResponse fetchResponse(Map<TopicPartition, FetchInfo> fetches) {
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> tpResponses = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> tpResponses = new LinkedHashMap<>();
         for (Map.Entry<TopicPartition, FetchInfo> fetchEntry : fetches.entrySet()) {
             TopicPartition partition = fetchEntry.getKey();
             long fetchOffset = fetchEntry.getValue().offset;
@@ -2558,14 +2560,14 @@ public class KafkaConsumerTest {
                     builder.append(0L, ("key-" + i).getBytes(), ("value-" + i).getBytes());
                 records = builder.build();
             }
-            tpResponses.put(partition,
+            tpResponses.put(new TopicIdPartition(topicIds.get(partition.topic()), partition),
                 new FetchResponseData.PartitionData()
                     .setPartitionIndex(partition.partition())
                     .setHighWatermark(highWatermark)
                     .setLogStartOffset(logStartOffset)
                     .setRecords(records));
         }
-        return FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, tpResponses, topicIds);
+        return FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, tpResponses);
     }
 
     private FetchResponse fetchResponse(TopicPartition partition, long fetchOffset, int count) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 7b9f33d..7509eb6 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -40,6 +40,7 @@ import org.apache.kafka.common.MetricNameTemplate;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.errors.InvalidTopicException;
 import org.apache.kafka.common.errors.RecordTooLargeException;
@@ -68,6 +69,7 @@ import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.FetchRequest.PartitionData;
 import org.apache.kafka.common.utils.BufferSupplier;
 import org.apache.kafka.common.record.CompressionType;
 import org.apache.kafka.common.record.ControlRecordType;
@@ -136,10 +138,12 @@ import static java.util.Arrays.asList;
 import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singleton;
+import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
 import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH;
 import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET;
+import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.apache.kafka.test.TestUtils.assertOptional;
 import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -169,6 +173,10 @@ public class FetcherTest {
     private TopicPartition tp1 = new TopicPartition(topicName, 1);
     private TopicPartition tp2 = new TopicPartition(topicName, 2);
     private TopicPartition tp3 = new TopicPartition(topicName, 3);
+    private TopicIdPartition tidp0 = new TopicIdPartition(topicId, tp0);
+    private TopicIdPartition tidp1 = new TopicIdPartition(topicId, tp1);
+    private TopicIdPartition tidp2 = new TopicIdPartition(topicId, tp2);
+    private TopicIdPartition tidp3 = new TopicIdPartition(topicId, tp3);
     private int validLeaderEpoch = 0;
     private MetadataResponse initialUpdateResponse =
         RequestTestUtils.metadataUpdateWithIds(1, singletonMap(topicName, 4), topicIds);
@@ -247,7 +255,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -265,28 +273,31 @@ public class FetcherTest {
     }
 
     @Test
-    public void testFetchWithNoId() {
+    public void testFetchWithNoTopicId() {
         // Should work and default to using old request type.
         buildFetcher();
 
-        TopicPartition noId = new TopicPartition("noId", 0);
-        assignFromUserNoId(singleton(noId));
-        subscriptions.seek(noId, 0);
+        TopicIdPartition noId = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("noId", 0));
+        assignFromUserNoId(singleton(noId.topicPartition()));
+        subscriptions.seek(noId.topicPartition(), 0);
 
-        // fetch should use request version 12
+        // Fetch should use request version 12
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(noId, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(
+            fetchRequestMatcher((short) 12, noId, 0, Optional.of(validLeaderEpoch)),
+            fullFetchResponse(noId, this.records, Errors.NONE, 100L, 0)
+        );
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
-        assertTrue(partitionRecords.containsKey(noId));
+        assertTrue(partitionRecords.containsKey(noId.topicPartition()));
 
-        List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(noId);
+        List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(noId.topicPartition());
         assertEquals(3, records.size());
-        assertEquals(4L, subscriptions.position(noId).offset); // this is the next fetching position
+        assertEquals(4L, subscriptions.position(noId.topicPartition()).offset); // this is the next fetching position
         long offset = 1;
         for (ConsumerRecord<byte[], byte[]> record : records) {
             assertEquals(offset, record.offset());
@@ -295,6 +306,288 @@ public class FetcherTest {
     }
 
     @Test
+    public void testFetchWithTopicId() {
+        buildFetcher();
+
+        TopicIdPartition tp = new TopicIdPartition(topicId, new TopicPartition(topicName, 0));
+        assignFromUser(singleton(tp.topicPartition()));
+        subscriptions.seek(tp.topicPartition(), 0);
+
+        // Fetch should use latest version
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        client.prepareResponse(
+            fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), tp, 0, Optional.of(validLeaderEpoch)),
+            fullFetchResponse(tp, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
+        assertTrue(partitionRecords.containsKey(tp.topicPartition()));
+
+        List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(tp.topicPartition());
+        assertEquals(3, records.size());
+        assertEquals(4L, subscriptions.position(tp.topicPartition()).offset); // this is the next fetching position
+        long offset = 1;
+        for (ConsumerRecord<byte[], byte[]> record : records) {
+            assertEquals(offset, record.offset());
+            offset += 1;
+        }
+    }
+
+    @Test
+    public void testFetchForgetTopicIdWhenUnassigned() {
+        buildFetcher();
+
+        TopicIdPartition foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+        TopicIdPartition bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0));
+
+        // Assign foo and bar.
+        subscriptions.assignFromUser(singleton(foo.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(foo), tp -> validLeaderEpoch));
+        subscriptions.seek(foo.topicPartition(), 0);
+
+        // Fetch should use latest version.
+        assertEquals(1, fetcher.sendFetches());
+
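+        // The first fetch is a full request that registers foo in a new session.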
+        client.prepareResponse(
+            fetchRequestMatcher(ApiKeys.FETCH.latestVersion(),
+                singletonMap(foo, new PartitionData(
+                    foo.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                emptyList()
+            ),
+            fullFetchResponse(1, foo, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+
+        // Assign bar and unassign foo.
+        subscriptions.assignFromUser(singleton(bar.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(bar), tp -> validLeaderEpoch));
+        subscriptions.seek(bar.topicPartition(), 0);
+
+        // Fetch should use latest version.
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
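+        // The incremental fetch should add bar and forget foo, which is no longer assigned.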
+        client.prepareResponse(
+            fetchRequestMatcher(ApiKeys.FETCH.latestVersion(),
+                singletonMap(bar, new PartitionData(
+                    bar.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                singletonList(foo)
+            ),
+            fullFetchResponse(1, bar, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+    }
+
+    @Test
+    public void testFetchForgetTopicIdWhenReplaced() {
+        buildFetcher();
+
+        TopicIdPartition fooWithOldTopicId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+        TopicIdPartition fooWithNewTopicId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+
+        // Assign foo with old topic id.
+        subscriptions.assignFromUser(singleton(fooWithOldTopicId.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithOldTopicId), tp -> validLeaderEpoch));
+        subscriptions.seek(fooWithOldTopicId.topicPartition(), 0);
+
+        // Fetch should use latest version.
+        assertEquals(1, fetcher.sendFetches());
+
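+        // The first fetch is a full request that registers foo with the old topic id.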
+        client.prepareResponse(
+            fetchRequestMatcher(ApiKeys.FETCH.latestVersion(),
+                singletonMap(fooWithOldTopicId, new PartitionData(
+                    fooWithOldTopicId.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                emptyList()
+            ),
+            fullFetchResponse(1, fooWithOldTopicId, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+
+        // Replace foo with old topic id with foo with new topic id.
+        subscriptions.assignFromUser(singleton(fooWithNewTopicId.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithNewTopicId), tp -> validLeaderEpoch));
+        subscriptions.seek(fooWithNewTopicId.topicPartition(), 0);
+
+        // Fetch should use latest version.
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        // foo with old topic id should be removed from the session.
+        client.prepareResponse(
+            fetchRequestMatcher(ApiKeys.FETCH.latestVersion(),
+                singletonMap(fooWithNewTopicId, new PartitionData(
+                    fooWithNewTopicId.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                singletonList(fooWithOldTopicId)
+            ),
+            fullFetchResponse(1, fooWithNewTopicId, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+    }
+
+    @Test
+    public void testFetchTopicIdUpgradeDowngrade() {
+        buildFetcher();
+
+        TopicIdPartition fooWithoutId = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0));
+        TopicIdPartition fooWithId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0));
+
+        // Assign foo without a topic id.
+        subscriptions.assignFromUser(singleton(fooWithoutId.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithoutId), tp -> validLeaderEpoch));
+        subscriptions.seek(fooWithoutId.topicPartition(), 0);
+
+        // Fetch should use version 12.
+        assertEquals(1, fetcher.sendFetches());
+
+        client.prepareResponse(
+            fetchRequestMatcher((short) 12,
+                singletonMap(fooWithoutId, new PartitionData(
+                    fooWithoutId.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                emptyList()
+            ),
+            fullFetchResponse(1, fooWithoutId, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+
+        // Upgrade.
+        subscriptions.assignFromUser(singleton(fooWithId.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithId), tp -> validLeaderEpoch));
+        subscriptions.seek(fooWithId.topicPartition(), 0);
+
+        // Fetch should use latest version.
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        // Switching to topic ids forces a new session with a full request, so no topics are forgotten.
+        client.prepareResponse(
+            fetchRequestMatcher(ApiKeys.FETCH.latestVersion(),
+                singletonMap(fooWithId, new PartitionData(
+                    fooWithId.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                emptyList()
+            ),
+            fullFetchResponse(1, fooWithId, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+
+        // Downgrade.
+        subscriptions.assignFromUser(singleton(fooWithoutId.topicPartition()));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithoutId), tp -> validLeaderEpoch));
+        subscriptions.seek(fooWithoutId.topicPartition(), 0);
+
+        // Fetch should use version 12.
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        // Dropping topic ids likewise forces a new session with a full request, so no topics are forgotten.
+        client.prepareResponse(
+            fetchRequestMatcher((short) 12,
+                singletonMap(fooWithoutId, new PartitionData(
+                    fooWithoutId.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                emptyList()
+            ),
+            fullFetchResponse(1, fooWithoutId, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+    }
+
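+    // Matches a fetch request for a single partition at the given offset, with no forgotten topics.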
+    private MockClient.RequestMatcher fetchRequestMatcher(
+        short expectedVersion,
+        TopicIdPartition tp,
+        long expectedFetchOffset,
+        Optional<Integer> expectedCurrentLeaderEpoch
+    ) {
+        return fetchRequestMatcher(
+            expectedVersion,
+            singletonMap(tp, new PartitionData(
+                tp.topicId(),
+                expectedFetchOffset,
+                FetchRequest.INVALID_LOG_START_OFFSET,
+                fetchSize,
+                expectedCurrentLeaderEpoch
+            )),
+            emptyList()
+        );
+    }
+
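+    // Asserts that the outgoing request is a FetchRequest with the expected version, partition data, and forgotten topics.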
+    private MockClient.RequestMatcher fetchRequestMatcher(
+        short expectedVersion,
+        Map<TopicIdPartition, PartitionData> fetch,
+        List<TopicIdPartition> forgotten
+    ) {
+        return body -> {
+            if (body instanceof FetchRequest) {
+                FetchRequest fetchRequest = (FetchRequest) body;
+                assertEquals(expectedVersion, fetchRequest.version());
+                assertEquals(fetch, fetchRequest.fetchData(topicNames(new ArrayList<>(fetch.keySet()))));
+                assertEquals(forgotten, fetchRequest.forgottenTopics(topicNames(forgotten)));
+                return true;
+            } else {
+                fail("Should have seen FetchRequest");
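+                // Unreachable in practice: fail() throws, but the lambda must still return a boolean.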
+                return false;
+            }
+        };
+    }
+
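+    // Builds the topic id to name map needed to decode the request's fetch data and forgotten topics.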
+    private Map<Uuid, String> topicNames(List<TopicIdPartition> partitions) {
+        Map<Uuid, String> topicNames = new HashMap<>();
+        partitions.forEach(partition -> topicNames.putIfAbsent(partition.topicId(), partition.topic()));
+        return topicNames;
+    }
+
+    @Test
     public void testMissingLeaderEpochInRecords() {
         buildFetcher();
 
@@ -312,7 +605,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -363,7 +656,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -388,7 +681,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
         Set<TopicPartition> newAssignedTopicPartitions = new HashSet<>();
@@ -440,7 +733,7 @@ public class FetcherTest {
 
         buffer.flip();
 
-        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -465,7 +758,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_OR_FOLLOWER, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NOT_LEADER_OR_FOLLOWER, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -473,10 +766,10 @@ public class FetcherTest {
         assertFalse(partitionRecords.containsKey(tp0));
     }
 
-    private MockClient.RequestMatcher matchesOffset(final TopicPartition tp, final long offset) {
+    private MockClient.RequestMatcher matchesOffset(final TopicIdPartition tp, final long offset) {
         return body -> {
             FetchRequest fetch = (FetchRequest) body;
-            Map<TopicPartition, FetchRequest.PartitionData> fetchData =  fetch.fetchData(topicNames);
+            Map<TopicIdPartition, FetchRequest.PartitionData> fetchData = fetch.fetchData(topicNames);
             return fetchData.containsKey(tp) &&
                     fetchData.get(tp).fetchOffset == offset;
         };
@@ -504,7 +797,7 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
-        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
@@ -568,7 +861,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         // the first fetchedRecords() should return the first valid message
@@ -606,7 +899,7 @@ public class FetcherTest {
         // Should not throw exception after the seek.
         fetcher.fetchedRecords();
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
@@ -642,7 +935,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         // the fetchedRecords() should always throw exception due to the bad batch.
@@ -674,7 +967,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         try {
             fetcher.fetchedRecords();
@@ -707,7 +1000,7 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
-        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, memoryRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
@@ -738,8 +1031,8 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 1);
 
-        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
-        client.prepareResponse(matchesOffset(tp0, 4), fullFetchResponse(tp0, this.nextRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tidp0, 4), fullFetchResponse(tidp0, this.nextRecords, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
@@ -782,7 +1075,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 1);
 
         // Returns 3 records while `max.poll.records` is configured to 2
-        client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
 
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
@@ -794,7 +1087,7 @@ public class FetcherTest {
         assertEquals(2, records.get(1).offset());
 
         assignFromUser(singleton(tp1));
-        client.prepareResponse(matchesOffset(tp1, 4), fullFetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(matchesOffset(tidp1, 4), fullFetchResponse(tidp1, this.nextRecords, Errors.NONE, 100L, 0));
         subscriptions.seek(tp1, 4);
 
         assertEquals(1, fetcher.sendFetches());
@@ -827,7 +1120,7 @@ public class FetcherTest {
 
         // normal fetch
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
         consumerRecords = recordsByPartition.get(tp0);
@@ -890,7 +1183,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
         MemoryRecords partialRecord = MemoryRecords.readableRecords(
             ByteBuffer.wrap(new byte[]{0, 0, 0, 0, 0, 0, 0, 0}));
-        client.prepareResponse(fullFetchResponse(tp0, partialRecord, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, partialRecord, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
     }
@@ -904,7 +1197,7 @@ public class FetcherTest {
 
         // resize the limit of the buffer to pretend it is only fetch-size large
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0));
         consumerClient.poll(time.timer(0));
         try {
             fetcher.fetchedRecords();
@@ -931,7 +1224,7 @@ public class FetcherTest {
         subscriptions.assignFromSubscribed(Collections.emptyList());
 
         subscriptions.assignFromSubscribed(singleton(tp0));
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         // The active fetch should be ignored since its position is no longer valid
@@ -954,7 +1247,7 @@ public class FetcherTest {
         // Now the cooperative rebalance happens and fetch positions are NOT cleared for unrevoked partitions
         subscriptions.assignFromSubscribed(singleton(tp0));
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
@@ -974,7 +1267,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         subscriptions.pause(tp0);
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertNull(fetcher.fetchedRecords().get(tp0));
     }
@@ -1002,7 +1295,7 @@ public class FetcherTest {
 
         subscriptions.pause(tp0);
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
@@ -1034,20 +1327,20 @@ public class FetcherTest {
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords;
 
-        assignFromUser(Utils.mkSet(tp0, tp1));
+        assignFromUser(mkSet(tp0, tp1));
 
         // seek to tp0 and tp1 in two polls to generate 2 complete requests and responses
 
         // #1 seek, request, poll, response
         subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp0)));
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         // #2 seek, request, poll, response
         subscriptions.seekUnvalidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1)));
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp1, this.nextRecords, Errors.NONE, 100L, 0));
 
         subscriptions.pause(tp0);
         consumerClient.poll(time.timer(0));
@@ -1069,20 +1362,20 @@ public class FetcherTest {
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords;
 
-        assignFromUser(Utils.mkSet(tp0, tp1));
+        assignFromUser(mkSet(tp0, tp1));
 
         // seek to tp0 and tp1 in two polls to generate 2 complete requests and responses
 
         // #1 seek, request, poll, response
         subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp0)));
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         // #2 seek, request, poll, response
         subscriptions.seekUnvalidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1)));
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp1, this.nextRecords, Errors.NONE, 100L, 0));
 
         subscriptions.pause(tp0);
         subscriptions.pause(tp1);
@@ -1104,11 +1397,11 @@ public class FetcherTest {
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords;
 
-        assignFromUser(Utils.mkSet(tp0, tp1));
+        assignFromUser(mkSet(tp0, tp1));
 
         subscriptions.seek(tp0, 1);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         fetchedRecords = fetchedRecords();
@@ -1143,7 +1436,7 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
         subscriptions.pause(tp0);
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
 
         subscriptions.seek(tp0, 3);
         subscriptions.resume(tp0);
@@ -1163,7 +1456,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_OR_FOLLOWER, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NOT_LEADER_OR_FOLLOWER, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -1176,7 +1469,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -1189,7 +1482,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponseWithTopLevelError(tp0, Errors.UNKNOWN_TOPIC_ID, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.UNKNOWN_TOPIC_ID, -1L, 0));
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -1202,7 +1495,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponseWithTopLevelError(tp0, Errors.FETCH_SESSION_TOPIC_ID_ERROR, 0));
+        client.prepareResponse(fetchResponseWithTopLevelError(tidp0, Errors.FETCH_SESSION_TOPIC_ID_ERROR, 0));
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -1215,7 +1508,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fetchResponseWithTopLevelError(tp0, Errors.INCONSISTENT_TOPIC_ID, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.INCONSISTENT_TOPIC_ID, -1L, 0));
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()));
@@ -1228,7 +1521,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.FENCED_LEADER_EPOCH, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.FENCED_LEADER_EPOCH, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         assertEquals(0, fetcher.fetchedRecords().size(), "Should not return any records");
@@ -1242,7 +1535,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.UNKNOWN_LEADER_EPOCH, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.UNKNOWN_LEADER_EPOCH, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         assertEquals(0, fetcher.fetchedRecords().size(), "Should not return any records");
@@ -1274,7 +1567,7 @@ public class FetcherTest {
                 return false;
             }
         };
-        client.prepareResponse(matcher, fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(matcher, fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.pollNoWakeup();
     }
 
@@ -1285,7 +1578,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertTrue(subscriptions.isOffsetResetNeeded(tp0));
@@ -1302,7 +1595,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         subscriptions.seek(tp0, 1);
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
@@ -1319,7 +1612,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertTrue(fetcher.sendFetches() > 0);
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         subscriptions.seek(tp0, 2);
@@ -1335,7 +1628,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         fetcher.sendFetches();
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
@@ -1353,22 +1646,22 @@ public class FetcherTest {
         // some fetched partitions cause Exception. This ensures that consumer won't lose record upon exception
         buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
-        assignFromUser(Utils.mkSet(tp0, tp1));
+        assignFromUser(mkSet(tp0, tp1));
         subscriptions.seek(tp0, 1);
         subscriptions.seek(tp1, 1);
 
         assertEquals(1, fetcher.sendFetches());
 
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
-        partitions.put(tp1, new FetchResponseData.PartitionData()
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
+        partitions.put(tidp1, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp1.partition())
                 .setHighWatermark(100)
                 .setRecords(records));
-        partitions.put(tp0, new FetchResponseData.PartitionData()
+        partitions.put(tidp0, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp0.partition())
                 .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code())
                 .setHighWatermark(100));
-        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)));
         consumerClient.poll(time.timer(0));
 
         List<ConsumerRecord<byte[], byte[]>> allFetchedRecords = new ArrayList<>();
@@ -1399,7 +1692,7 @@ public class FetcherTest {
         // Ensure the removal of completed fetches that cause an Exception if and only if they contain empty records.
         buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
-        assignFromUser(Utils.mkSet(tp0, tp1, tp2, tp3));
+        assignFromUser(mkSet(tp0, tp1, tp2, tp3));
 
         subscriptions.seek(tp0, 1);
         subscriptions.seek(tp1, 1);
@@ -1408,28 +1701,28 @@ public class FetcherTest {
 
         assertEquals(1, fetcher.sendFetches());
 
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
-        partitions.put(tp1, new FetchResponseData.PartitionData()
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
+        partitions.put(tidp1, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp1.partition())
                 .setHighWatermark(100)
                 .setRecords(records));
-        partitions.put(tp0, new FetchResponseData.PartitionData()
+        partitions.put(tidp0, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp0.partition())
                 .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code())
                 .setHighWatermark(100));
-        partitions.put(tp2, new FetchResponseData.PartitionData()
+        partitions.put(tidp2, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp2.partition())
                 .setHighWatermark(100)
                 .setLastStableOffset(4)
                 .setLogStartOffset(0)
                 .setRecords(nextRecords));
-        partitions.put(tp3, new FetchResponseData.PartitionData()
+        partitions.put(tidp3, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp3.partition())
                 .setHighWatermark(100)
                 .setLastStableOffset(4)
                 .setLogStartOffset(0)
                 .setRecords(partialRecords));
-        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)));
         consumerClient.poll(time.timer(0));
 
         List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
@@ -1484,29 +1777,29 @@ public class FetcherTest {
         buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
                 new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED);
 
-        assignFromUser(Utils.mkSet(tp0));
+        assignFromUser(mkSet(tp0));
         subscriptions.seek(tp0, 1);
         assertEquals(1, fetcher.sendFetches());
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = new HashMap<>();
-        partitions.put(tp0, new FetchResponseData.PartitionData()
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = new HashMap<>();
+        partitions.put(tidp0, new FetchResponseData.PartitionData()
                         .setPartitionIndex(tp0.partition())
                         .setHighWatermark(100)
                         .setRecords(records));
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
 
         assertEquals(2, fetcher.fetchedRecords().get(tp0).size());
 
-        subscriptions.assignFromUser(Utils.mkSet(tp0, tp1));
+        subscriptions.assignFromUser(mkSet(tp0, tp1));
         subscriptions.seekUnvalidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1)));
 
         assertEquals(1, fetcher.sendFetches());
         partitions = new HashMap<>();
-        partitions.put(tp1, new FetchResponseData.PartitionData()
+        partitions.put(tidp1, new FetchResponseData.PartitionData()
                         .setPartitionIndex(tp1.partition())
                         .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code())
                         .setHighWatermark(100));
-        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)));
         consumerClient.poll(time.timer(0));
         assertEquals(1, fetcher.fetchedRecords().get(tp0).size());
 
@@ -1523,7 +1816,7 @@ public class FetcherTest {
         subscriptions.seek(tp0, 0);
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0), true);
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0), true);
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
 
@@ -2196,12 +2489,12 @@ public class FetcherTest {
 
         for (int i = 1; i <= 3; i++) {
             int throttleTimeMs = 100 * i;
-            FetchRequest.Builder builder = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 100, 100, new LinkedHashMap<>(), topicIds);
+            FetchRequest.Builder builder = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 100, 100, new LinkedHashMap<>());
             builder.rackId("");
             ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true);
             client.send(request, time.milliseconds());
             client.poll(1, time.milliseconds());
-            FetchResponse response = fullFetchResponse(tp0, nextRecords, Errors.NONE, i, throttleTimeMs);
+            FetchResponse response = fullFetchResponse(tidp0, nextRecords, Errors.NONE, i, throttleTimeMs);
             buffer = RequestTestUtils.serializeResponseWithHeader(response, ApiKeys.FETCH.latestVersion(), request.correlationId());
             selector.completeReceive(new NetworkReceive(node.idString(), buffer));
             client.poll(1, time.milliseconds());
@@ -2240,7 +2533,7 @@ public class FetcherTest {
         assertEquals(Double.NaN, (Double) recordsFetchLagMax.metricValue(), EPSILON);
 
         // recordsFetchLagMax should be hw - fetchOffset after receiving an empty FetchResponse
-        fetchRecords(tp0, MemoryRecords.EMPTY, Errors.NONE, 100L, 0);
+        fetchRecords(tidp0, MemoryRecords.EMPTY, Errors.NONE, 100L, 0);
         assertEquals(100, (Double) recordsFetchLagMax.metricValue(), EPSILON);
 
         KafkaMetric partitionLag = allMetrics.get(partitionLagMetric);
@@ -2251,7 +2544,7 @@ public class FetcherTest {
                 TimestampType.CREATE_TIME, 0L);
         for (int v = 0; v < 3; v++)
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
-        fetchRecords(tp0, builder.build(), Errors.NONE, 200L, 0);
+        fetchRecords(tidp0, builder.build(), Errors.NONE, 200L, 0);
         assertEquals(197, (Double) recordsFetchLagMax.metricValue(), EPSILON);
         assertEquals(197, (Double) partitionLag.metricValue(), EPSILON);
 
@@ -2281,7 +2574,7 @@ public class FetcherTest {
         assertEquals(Double.NaN, (Double) recordsFetchLeadMin.metricValue(), EPSILON);
 
         // recordsFetchLeadMin should be position - logStartOffset after receiving an empty FetchResponse
-        fetchRecords(tp0, MemoryRecords.EMPTY, Errors.NONE, 100L, -1L, 0L, 0);
+        fetchRecords(tidp0, MemoryRecords.EMPTY, Errors.NONE, 100L, -1L, 0L, 0);
         assertEquals(0L, (Double) recordsFetchLeadMin.metricValue(), EPSILON);
 
         KafkaMetric partitionLead = allMetrics.get(partitionLeadMetric);
@@ -2293,7 +2586,7 @@ public class FetcherTest {
         for (int v = 0; v < 3; v++) {
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
         }
-        fetchRecords(tp0, builder.build(), Errors.NONE, 200L, -1L, 0L, 0);
+        fetchRecords(tidp0, builder.build(), Errors.NONE, 200L, -1L, 0L, 0);
         assertEquals(0L, (Double) recordsFetchLeadMin.metricValue(), EPSILON);
         assertEquals(3L, (Double) partitionLead.metricValue(), EPSILON);
 
@@ -2325,7 +2618,7 @@ public class FetcherTest {
         assertEquals(Double.NaN, (Double) recordsFetchLagMax.metricValue(), EPSILON);
 
         // recordsFetchLagMax should be lso - fetchOffset after receiving an empty FetchResponse
-        fetchRecords(tp0, MemoryRecords.EMPTY, Errors.NONE, 100L, 50L, 0);
+        fetchRecords(tidp0, MemoryRecords.EMPTY, Errors.NONE, 100L, 50L, 0);
         assertEquals(50, (Double) recordsFetchLagMax.metricValue(), EPSILON);
 
         KafkaMetric partitionLag = allMetrics.get(partitionLagMetric);
@@ -2336,7 +2629,7 @@ public class FetcherTest {
                 TimestampType.CREATE_TIME, 0L);
         for (int v = 0; v < 3; v++)
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
-        fetchRecords(tp0, builder.build(), Errors.NONE, 200L, 150L, 0);
+        fetchRecords(tidp0, builder.build(), Errors.NONE, 200L, 150L, 0);
         assertEquals(147, (Double) recordsFetchLagMax.metricValue(), EPSILON);
         assertEquals(147, (Double) partitionLag.metricValue(), EPSILON);
 
@@ -2355,20 +2648,22 @@ public class FetcherTest {
         TopicPartition tp1 = new TopicPartition(topic1, 0);
         TopicPartition tp2 = new TopicPartition(topic2, 0);
 
-        subscriptions.assignFromUser(Utils.mkSet(tp1, tp2));
+        subscriptions.assignFromUser(mkSet(tp1, tp2));
 
         Map<String, Integer> partitionCounts = new HashMap<>();
         partitionCounts.put(topic1, 1);
         partitionCounts.put(topic2, 1);
         topicIds.put(topic1, Uuid.randomUuid());
         topicIds.put(topic2, Uuid.randomUuid());
+        TopicIdPartition tidp1 = new TopicIdPartition(topicIds.get(topic1), tp1);
+        TopicIdPartition tidp2 = new TopicIdPartition(topicIds.get(topic2), tp2);
         client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, partitionCounts, tp -> validLeaderEpoch, topicIds));
 
         int expectedBytes = 0;
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> fetchPartitionData = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> fetchPartitionData = new LinkedHashMap<>();
 
-        for (TopicPartition tp : Utils.mkSet(tp1, tp2)) {
-            subscriptions.seek(tp, 0);
+        for (TopicIdPartition tp : mkSet(tidp1, tidp2)) {
+            subscriptions.seek(tp.topicPartition(), 0);
 
             MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE,
                     TimestampType.CREATE_TIME, 0L);
@@ -2379,14 +2674,14 @@ public class FetcherTest {
                 expectedBytes += record.sizeInBytes();
 
             fetchPartitionData.put(tp, new FetchResponseData.PartitionData()
-                    .setPartitionIndex(tp.partition())
+                    .setPartitionIndex(tp.topicPartition().partition())
                     .setHighWatermark(15)
                     .setLogStartOffset(0)
                     .setRecords(records));
         }
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, fetchPartitionData, topicIds));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, fetchPartitionData));
         consumerClient.poll(time.timer(0));
 
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchedRecords = fetchedRecords();
@@ -2423,7 +2718,7 @@ public class FetcherTest {
                 expectedBytes += record.sizeInBytes();
         }
 
-        fetchRecords(tp0, records, Errors.NONE, 100L, 0);
+        fetchRecords(tidp0, records, Errors.NONE, 100L, 0);
         assertEquals(expectedBytes, (Double) fetchSizeAverage.metricValue(), EPSILON);
         assertEquals(2, (Double) recordsCountAverage.metricValue(), EPSILON);
     }
@@ -2431,7 +2726,7 @@ public class FetcherTest {
     @Test
     public void testFetchResponseMetricsWithOnePartitionError() {
         buildFetcher();
-        assignFromUser(Utils.mkSet(tp0, tp1));
+        assignFromUser(mkSet(tp0, tp1));
         subscriptions.seek(tp0, 0);
         subscriptions.seek(tp1, 0);
 
@@ -2445,20 +2740,20 @@ public class FetcherTest {
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
         MemoryRecords records = builder.build();
 
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = new HashMap<>();
-        partitions.put(tp0, new FetchResponseData.PartitionData()
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = new HashMap<>();
+        partitions.put(tidp0, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp0.partition())
                 .setHighWatermark(100)
                 .setLogStartOffset(0)
                 .setRecords(records));
-        partitions.put(tp1, new FetchResponseData.PartitionData()
+        partitions.put(tidp1, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp1.partition())
                 .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code())
                 .setHighWatermark(100)
                 .setLogStartOffset(0));
 
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)));
         consumerClient.poll(time.timer(0));
         fetcher.fetchedRecords();
 
@@ -2474,7 +2769,7 @@ public class FetcherTest {
     public void testFetchResponseMetricsWithOnePartitionAtTheWrongOffset() {
         buildFetcher();
 
-        assignFromUser(Utils.mkSet(tp0, tp1));
+        assignFromUser(mkSet(tp0, tp1));
         subscriptions.seek(tp0, 0);
         subscriptions.seek(tp1, 0);
 
@@ -2492,19 +2787,19 @@ public class FetcherTest {
             builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes());
         MemoryRecords records = builder.build();
 
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = new HashMap<>();
-        partitions.put(tp0, new FetchResponseData.PartitionData()
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = new HashMap<>();
+        partitions.put(tidp0, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp0.partition())
                 .setHighWatermark(100)
                 .setLogStartOffset(0)
                 .setRecords(records));
-        partitions.put(tp1, new FetchResponseData.PartitionData()
+        partitions.put(tidp1, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp1.partition())
                 .setHighWatermark(100)
                 .setLogStartOffset(0)
                 .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("val".getBytes()))));
 
-        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)));
         consumerClient.poll(time.timer(0));
         fetcher.fetchedRecords();
 
@@ -2527,7 +2822,7 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
@@ -2547,12 +2842,12 @@ public class FetcherTest {
     }
 
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchRecords(
-            TopicPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
+            TopicIdPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
         return fetchRecords(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime);
     }
 
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchRecords(
-            TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) {
+            TopicIdPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fullFetchResponse(tp, records, error, hw, lastStableOffset, throttleTime));
         consumerClient.poll(time.timer(0));
@@ -2560,7 +2855,7 @@ public class FetcherTest {
     }
 
     private Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> fetchRecords(
-            TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, long logStartOffset, int throttleTime) {
+            TopicIdPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, long logStartOffset, int throttleTime) {
         assertEquals(1, fetcher.sendFetches());
         client.prepareResponse(fetchResponse(tp, records, error, hw, lastStableOffset, logStartOffset, throttleTime));
         consumerClient.poll(time.timer(0));
@@ -2632,7 +2927,7 @@ public class FetcherTest {
         for (Errors retriableError : retriableErrors) {
             buildFetcher();
 
-            subscriptions.assignFromUser(Utils.mkSet(tp0, tp1));
+            subscriptions.assignFromUser(mkSet(tp0, tp1));
             client.updateMetadata(initialUpdateResponse);
 
             final long fetchTimestamp = 10L;
@@ -2784,7 +3079,7 @@ public class FetcherTest {
     public void testGetOffsetsForTimesWhenSomeTopicPartitionLeadersNotKnownInitially() {
         buildFetcher();
 
-        subscriptions.assignFromUser(Utils.mkSet(tp0, tp1));
+        subscriptions.assignFromUser(mkSet(tp0, tp1));
         final String anotherTopic = "another-topic";
         final TopicPartition t2p0 = new TopicPartition(anotherTopic, 0);
 
@@ -2831,7 +3126,7 @@ public class FetcherTest {
         buildFetcher();
         final String anotherTopic = "another-topic";
         final TopicPartition t2p0 = new TopicPartition(anotherTopic, 0);
-        subscriptions.assignFromUser(Utils.mkSet(tp0, t2p0));
+        subscriptions.assignFromUser(mkSet(tp0, t2p0));
 
         client.reset();
 
@@ -3044,7 +3339,7 @@ public class FetcherTest {
         for (ConsumerRecord<byte[], byte[]> consumerRecord : fetchedConsumerRecords) {
             fetchedKeys.add(new String(consumerRecord.key(), StandardCharsets.UTF_8));
         }
-        assertEquals(Utils.mkSet("commit1-1", "commit1-2", "commit2-1"), fetchedKeys);
+        assertEquals(mkSet("commit1-1", "commit1-2", "commit2-1"), fetchedKeys);
     }
 
     @Test
@@ -3164,7 +3459,7 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, compactedRecords, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -3201,7 +3496,7 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         assertEquals(1, fetcher.sendFetches());
-        client.prepareResponse(fullFetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -3355,19 +3650,19 @@ public class FetcherTest {
         subscriptions.seekValidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1)));
 
         // Fetch some records and establish an incremental fetch session.
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> partitions1 = new LinkedHashMap<>();
-        partitions1.put(tp0, new FetchResponseData.PartitionData()
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> partitions1 = new LinkedHashMap<>();
+        partitions1.put(tidp0, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp0.partition())
                 .setHighWatermark(2)
                 .setLastStableOffset(2)
                 .setLogStartOffset(0)
                 .setRecords(this.records));
-        partitions1.put(tp1, new FetchResponseData.PartitionData()
+        partitions1.put(tidp1, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp1.partition())
                 .setHighWatermark(100)
                 .setLogStartOffset(0)
                 .setRecords(emptyRecords));
-        FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, 123, partitions1, topicIds);
+        FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, 123, partitions1);
         client.prepareResponse(resp1);
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
@@ -3392,8 +3687,8 @@ public class FetcherTest {
         assertEquals(4L, subscriptions.position(tp0).offset);
 
         // The second response contains no new records.
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> partitions2 = new LinkedHashMap<>();
-        FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123, partitions2, topicIds);
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> partitions2 = new LinkedHashMap<>();
+        FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123, partitions2);
         client.prepareResponse(resp2);
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
@@ -3403,14 +3698,14 @@ public class FetcherTest {
         assertEquals(1L, subscriptions.position(tp1).offset);
 
         // The third response contains some new records for tp0.
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> partitions3 = new LinkedHashMap<>();
-        partitions3.put(tp0, new FetchResponseData.PartitionData()
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> partitions3 = new LinkedHashMap<>();
+        partitions3.put(tidp0, new FetchResponseData.PartitionData()
                 .setPartitionIndex(tp0.partition())
                 .setHighWatermark(100)
                 .setLastStableOffset(4)
                 .setLogStartOffset(0)
                 .setRecords(this.nextRecords));
-        FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, 123, partitions3, topicIds);
+        FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, 123, partitions3);
         client.prepareResponse(resp3);
         assertEquals(1, fetcher.sendFetches());
         consumerClient.poll(time.timer(0));
@@ -3517,18 +3812,18 @@ public class FetcherTest {
                     if (!client.requests().isEmpty()) {
                         ClientRequest request = client.requests().peek();
                         FetchRequest fetchRequest = (FetchRequest) request.requestBuilder().build();
-                        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseMap = new LinkedHashMap<>();
-                        for (Map.Entry<TopicPartition, FetchRequest.PartitionData> entry : fetchRequest.fetchData(topicNames).entrySet()) {
-                            TopicPartition tp = entry.getKey();
+                        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseMap = new LinkedHashMap<>();
+                        for (Map.Entry<TopicIdPartition, FetchRequest.PartitionData> entry : fetchRequest.fetchData(topicNames).entrySet()) {
+                            TopicIdPartition tp = entry.getKey();
                             long offset = entry.getValue().fetchOffset;
                             responseMap.put(tp, new FetchResponseData.PartitionData()
-                                    .setPartitionIndex(tp.partition())
+                                    .setPartitionIndex(tp.topicPartition().partition())
                                     .setHighWatermark(offset + 2)
                                     .setLastStableOffset(offset + 2)
                                     .setLogStartOffset(0)
                                     .setRecords(buildRecords(offset, 2, offset)));
                         }
-                        client.respondToRequest(request, FetchResponse.of(Errors.NONE, 0, 123, responseMap, topicIds));
+                        client.respondToRequest(request, FetchResponse.of(Errors.NONE, 0, 123, responseMap));
                         consumerClient.poll(time.timer(0));
                     }
                 }
@@ -3583,15 +3878,15 @@ public class FetcherTest {
                         assertTrue(epoch == 0 || epoch == nextEpoch,
                             String.format("Unexpected epoch expected %d got %d", nextEpoch, epoch));
                         nextEpoch++;
-                        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseMap = new LinkedHashMap<>();
-                        responseMap.put(tp0, new FetchResponseData.PartitionData()
+                        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseMap = new LinkedHashMap<>();
+                        responseMap.put(tidp0, new FetchResponseData.PartitionData()
                                 .setPartitionIndex(tp0.partition())
                                 .setHighWatermark(nextOffset + 2)
                                 .setLastStableOffset(nextOffset + 2)
                                 .setLogStartOffset(0)
                                 .setRecords(buildRecords(nextOffset, 2, nextOffset)));
                         nextOffset += 2;
-                        client.respondToRequest(request, FetchResponse.of(Errors.NONE, 0, 123, responseMap, topicIds));
+                        client.respondToRequest(request, FetchResponse.of(Errors.NONE, 0, 123, responseMap));
                         consumerClient.poll(time.timer(0));
                     }
                 }
@@ -3857,7 +4152,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0));
         consumerClient.pollNoWakeup();
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -3871,7 +4166,7 @@ public class FetcherTest {
     @Test
     public void testOffsetValidationRequestGrouping() {
         buildFetcher();
-        assignFromUser(Utils.mkSet(tp0, tp1, tp2, tp3));
+        assignFromUser(mkSet(tp0, tp1, tp2, tp3));
 
         metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 3,
             Collections.emptyMap(), singletonMap(topicName, 4),
@@ -4314,7 +4609,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0));
         consumerClient.pollNoWakeup();
         assertTrue(fetcher.hasCompletedFetches());
 
@@ -4342,7 +4637,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         // Set preferred read replica to node=1
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L,
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1)));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
@@ -4359,7 +4654,7 @@ public class FetcherTest {
         assertFalse(fetcher.hasCompletedFetches());
 
         // Set preferred read replica to node=2, which isn't in our metadata, should revert to leader
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L,
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(2)));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
@@ -4381,7 +4676,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L,
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1)));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
@@ -4395,7 +4690,7 @@ public class FetcherTest {
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
 
-        client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L,
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L,
                 FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.empty()));
         consumerClient.poll(time.timer(0));
         assertTrue(fetcher.hasCompletedFetches());
@@ -4412,7 +4707,7 @@ public class FetcherTest {
         assignFromUser(singleton(tp0));
         subscriptions.seek(tp0, 0);
         fetcher.sendFetches();
-        client.prepareResponse(fullFetchResponse(tp0, buildRecords(1L, 1, 1), Errors.NONE, 100L, 0));
+        client.prepareResponse(fullFetchResponse(tidp0, buildRecords(1L, 1, 1), Errors.NONE, 100L, 0));
         consumerClient.poll(time.timer(0));
         fetchedRecords();
 
@@ -4447,7 +4742,7 @@ public class FetcherTest {
 
         // Prepare a response with the CORRUPT_MESSAGE error.
         client.prepareResponse(fullFetchResponse(
-                tp0,
+                tidp0,
                 buildRecords(1L, 1, 1),
                 Errors.CORRUPT_MESSAGE,
                 100L, 0));
@@ -4621,22 +4916,22 @@ public class FetcherTest {
         return new ListOffsetsResponse(data);
     }
 
-    private FetchResponse fetchResponseWithTopLevelError(TopicPartition tp, Errors error, int throttleTime) {
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
+    private FetchResponse fetchResponseWithTopLevelError(TopicIdPartition tp, Errors error, int throttleTime) {
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
                 new FetchResponseData.PartitionData()
-                        .setPartitionIndex(tp.partition())
+                        .setPartitionIndex(tp.topicPartition().partition())
                         .setErrorCode(error.code())
                         .setHighWatermark(FetchResponse.INVALID_HIGH_WATERMARK));
-        return FetchResponse.of(error, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds);
+        return FetchResponse.of(error, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions));
     }
 
     private FetchResponse fullFetchResponseWithAbortedTransactions(MemoryRecords records,
-                                                                                  List<FetchResponseData.AbortedTransaction> abortedTransactions,
-                                                                                  Errors error,
-                                                                                  long lastStableOffset,
-                                                                                  long hw,
-                                                                                  int throttleTime) {
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp0,
+                                                                   List<FetchResponseData.AbortedTransaction> abortedTransactions,
+                                                                   Errors error,
+                                                                   long lastStableOffset,
+                                                                   long hw,
+                                                                   int throttleTime) {
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tidp0,
                 new FetchResponseData.PartitionData()
                         .setPartitionIndex(tp0.partition())
                         .setErrorCode(error.code())
@@ -4645,51 +4940,60 @@ public class FetcherTest {
                         .setLogStartOffset(0)
                         .setAbortedTransactions(abortedTransactions)
                         .setRecords(records));
-        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds);
+        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions));
+    }
+
+    private FetchResponse fullFetchResponse(int sessionId, TopicIdPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
+        return fullFetchResponse(sessionId, tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime);
     }
 
-    private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
+    private FetchResponse fullFetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) {
         return fullFetchResponse(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime);
     }
 
-    private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw,
+    private FetchResponse fullFetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw,
+                                            long lastStableOffset, int throttleTime) {
+        return fullFetchResponse(INVALID_SESSION_ID, tp, records, error, hw, lastStableOffset, throttleTime);
+    }
+
+    private FetchResponse fullFetchResponse(int sessionId, TopicIdPartition tp, MemoryRecords records, Errors error, long hw,
                                             long lastStableOffset, int throttleTime) {
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
                 new FetchResponseData.PartitionData()
-                        .setPartitionIndex(tp.partition())
+                        .setPartitionIndex(tp.topicPartition().partition())
                         .setErrorCode(error.code())
                         .setHighWatermark(hw)
                         .setLastStableOffset(lastStableOffset)
                         .setLogStartOffset(0)
                         .setRecords(records));
-        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds);
+        return FetchResponse.of(Errors.NONE, throttleTime, sessionId, new LinkedHashMap<>(partitions));
     }
 
-    private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw,
-                                                           long lastStableOffset, int throttleTime, Optional<Integer> preferredReplicaId) {
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
+    private FetchResponse fullFetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw,
+                                            long lastStableOffset, int throttleTime, Optional<Integer> preferredReplicaId) {
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
                 new FetchResponseData.PartitionData()
-                        .setPartitionIndex(tp.partition())
+                        .setPartitionIndex(tp.topicPartition().partition())
                         .setErrorCode(error.code())
                         .setHighWatermark(hw)
                         .setLastStableOffset(lastStableOffset)
                         .setLogStartOffset(0)
                         .setRecords(records)
                         .setPreferredReadReplica(preferredReplicaId.orElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID)));
-        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds);
+        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions));
     }
 
-    private FetchResponse fetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw,
+    private FetchResponse fetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw,
                                         long lastStableOffset, long logStartOffset, int throttleTime) {
-        Map<TopicPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = Collections.singletonMap(tp,
                 new FetchResponseData.PartitionData()
-                        .setPartitionIndex(tp.partition())
+                        .setPartitionIndex(tp.topicPartition().partition())
                         .setErrorCode(error.code())
                         .setHighWatermark(hw)
                         .setLastStableOffset(lastStableOffset)
                         .setLogStartOffset(logStartOffset)
                         .setRecords(records));
-        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions), topicIds);
+        return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions));
     }
 
     private MetadataResponse newMetadataResponse(String topic, Errors error) {
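
A note on the shape of the helper rewrites above: FetchResponse.of is now keyed by
TopicIdPartition and no longer takes a separate topicIds map, since each key carries
its own topic ID. A minimal sketch of the new call shape (the topic name, offsets,
and use of an invalid session id below are illustrative, not taken from the patch):

    import java.util.LinkedHashMap;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.record.MemoryRecords;
    import org.apache.kafka.common.requests.FetchMetadata;
    import org.apache.kafka.common.requests.FetchResponse;

    public class FetchResponseSketch {
        // Build a one-partition full fetch response with the new keyed API.
        static FetchResponse singlePartitionResponse() {
            Uuid topicId = Uuid.randomUuid();
            TopicIdPartition tidp = new TopicIdPartition(topicId, new TopicPartition("some-topic", 0));
            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
            partitions.put(tidp, new FetchResponseData.PartitionData()
                    .setPartitionIndex(tidp.topicPartition().partition())
                    .setHighWatermark(100L)
                    .setLogStartOffset(0L)
                    .setRecords(MemoryRecords.EMPTY));
            // The topic name/ID mapping travels inside each key, so no side map is passed.
            return FetchResponse.of(Errors.NONE, 0, FetchMetadata.INVALID_SESSION_ID, partitions);
        }
    }
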
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/FetchRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/FetchRequestTest.java
new file mode 100644
index 0000000..a567d43
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/requests/FetchRequestTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.requests;
+
+import org.apache.kafka.common.TopicIdPartition;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Stream;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+
+public class FetchRequestTest {
+
+    private static Stream<Arguments> fetchVersions() {
+        return ApiKeys.FETCH.allVersions().stream().map(version -> Arguments.of(version));
+    }
+
+    @ParameterizedTest
+    @MethodSource("fetchVersions")
+    public void testToReplaceWithDifferentVersions(short version) {
+        boolean fetchRequestUsesTopicIds = version >= 13;
+        Uuid topicId = Uuid.randomUuid();
+        TopicIdPartition tp = new TopicIdPartition(topicId, 0, "topic");
+
+        Map<TopicPartition, FetchRequest.PartitionData> partitionData = Collections.singletonMap(tp.topicPartition(),
+                new FetchRequest.PartitionData(topicId, 0, 0, 0, Optional.empty()));
+        List<TopicIdPartition> toReplace = Collections.singletonList(tp);
+
+        FetchRequest fetchRequest = FetchRequest.Builder
+                .forReplica(version, 0, 1, 1, partitionData)
+                .removed(Collections.emptyList())
+                .replaced(toReplace)
+                .metadata(FetchMetadata.newIncremental(123)).build(version);
+
+        // If version < 13, we should not see any partitions in forgottenTopics. This is because we cannot
+        // distinguish different topic IDs on versions earlier than 13.
+        assertEquals(fetchRequestUsesTopicIds, fetchRequest.data().forgottenTopicsData().size() > 0);
+        fetchRequest.data().forgottenTopicsData().forEach(forgottenTopic -> {
+            // Since we didn't serialize, we should see the topic name and ID regardless of the version.
+            assertEquals(tp.topic(), forgottenTopic.topic());
+            assertEquals(topicId, forgottenTopic.topicId());
+        });
+
+        assertEquals(1, fetchRequest.data().topics().size());
+        fetchRequest.data().topics().forEach(topic -> {
+            // Since we didn't serialize, we should see the topic name and ID regardless of the version.
+            assertEquals(tp.topic(), topic.topic());
+            assertEquals(topicId, topic.topicId());
+        });
+    }
+
+    @ParameterizedTest
+    @MethodSource("fetchVersions")
+    public void testFetchData(short version) {
+        TopicPartition topicPartition0 = new TopicPartition("topic", 0);
+        TopicPartition topicPartition1 = new TopicPartition("unknownIdTopic", 0);
+        Uuid topicId0 = Uuid.randomUuid();
+        Uuid topicId1 = Uuid.randomUuid();
+
+        // Only include the topic ID for the first topic partition.
+        Map<Uuid, String> topicNames = Collections.singletonMap(topicId0, topicPartition0.topic());
+        List<TopicIdPartition> topicIdPartitions = new LinkedList<>();
+        topicIdPartitions.add(new TopicIdPartition(topicId0, topicPartition0));
+        topicIdPartitions.add(new TopicIdPartition(topicId1, topicPartition1));
+
+        // Include one topic whose ID is in the topicNames map and one whose ID is not.
+        Map<TopicPartition, FetchRequest.PartitionData> partitionData = new LinkedHashMap<>();
+        partitionData.put(topicPartition0, new FetchRequest.PartitionData(topicId0, 0, 0, 0, Optional.empty()));
+        partitionData.put(topicPartition1, new FetchRequest.PartitionData(topicId1, 0, 0, 0, Optional.empty()));
+        boolean fetchRequestUsesTopicIds = version >= 13;
+
+        FetchRequest fetchRequest = FetchRequest.parse(FetchRequest.Builder
+                .forReplica(version, 0, 1, 1, partitionData)
+                .removed(Collections.emptyList())
+                .replaced(Collections.emptyList())
+                .metadata(FetchMetadata.newIncremental(123)).build(version).serialize(), version);
+
+        // For versions < 13, we will be provided a topic name and a zero UUID in FetchRequestData.
+        // Versions 13+ will contain a valid topic ID but an empty topic name.
+        List<TopicIdPartition> expectedData = new LinkedList<>();
+        topicIdPartitions.forEach(tidp -> {
+            String expectedName = fetchRequestUsesTopicIds ? "" : tidp.topic();
+            Uuid expectedTopicId = fetchRequestUsesTopicIds ? tidp.topicId() : Uuid.ZERO_UUID;
+            expectedData.add(new TopicIdPartition(expectedTopicId, tidp.partition(), expectedName));
+        });
+
+        // Build the list of TopicIdPartitions based on the FetchRequestData that was serialized and parsed.
+        List<TopicIdPartition> convertedFetchData = new LinkedList<>();
+        fetchRequest.data().topics().forEach(topic ->
+                topic.partitions().forEach(partition ->
+                        convertedFetchData.add(new TopicIdPartition(topic.topicId(), partition.partition(), topic.topic()))
+                )
+        );
+        // The TopicIdPartitions built from the request data should match what we expect.
+        assertEquals(expectedData, convertedFetchData);
+
+        // For fetch request versions 13+, we expect topic names to be filled in for all topics in the topicNames map;
+        // topics whose IDs are not in the map should have a null name.
+        // For earlier request versions, we expect topic names and zero Uuids.
+        Map<TopicIdPartition, FetchRequest.PartitionData> expectedFetchData = new LinkedHashMap<>();
+        // Build the expected map based on fetchRequestUsesTopicIds.
+        expectedData.forEach(tidp -> {
+            String expectedName = fetchRequestUsesTopicIds ? topicNames.get(tidp.topicId()) : tidp.topic();
+            TopicIdPartition tpKey = new TopicIdPartition(tidp.topicId(), new TopicPartition(expectedName, tidp.partition()));
+            // logStartOffset was not a valid field in versions 4 and earlier.
+            int logStartOffset = version > 4 ? 0 : -1;
+            expectedFetchData.put(tpKey, new FetchRequest.PartitionData(tidp.topicId(), 0, logStartOffset, 0, Optional.empty()));
+        });
+        assertEquals(expectedFetchData, fetchRequest.fetchData(topicNames));
+    }
+
+    @ParameterizedTest
+    @MethodSource("fetchVersions")
+    public void testForgottenTopics(short version) {
+        // Forgotten topics are not allowed prior to version 7
+        if (version >= 7) {
+            TopicPartition topicPartition0 = new TopicPartition("topic", 0);
+            TopicPartition topicPartition1 = new TopicPartition("unknownIdTopic", 0);
+            Uuid topicId0 = Uuid.randomUuid();
+            Uuid topicId1 = Uuid.randomUuid();
+            // Only include the topic ID for the first topic partition.
+            Map<Uuid, String> topicNames = Collections.singletonMap(topicId0, topicPartition0.topic());
+
+            // Include one topic whose ID is in the topicNames map and one whose ID is not.
+            List<TopicIdPartition> toForgetTopics = new LinkedList<>();
+            toForgetTopics.add(new TopicIdPartition(topicId0, topicPartition0));
+            toForgetTopics.add(new TopicIdPartition(topicId1, topicPartition1));
+
+            boolean fetchRequestUsesTopicIds = version >= 13;
+
+            FetchRequest fetchRequest = FetchRequest.parse(FetchRequest.Builder
+                    .forReplica(version, 0, 1, 1, Collections.emptyMap())
+                    .removed(toForgetTopics)
+                    .replaced(Collections.emptyList())
+                    .metadata(FetchMetadata.newIncremental(123)).build(version).serialize(), version);
+
+            // For versions < 13, we will be provided a topic name and a zero Uuid in FetchRequestData.
+            // Versions 13+ will contain a valid topic ID but an empty topic name.
+            List<TopicIdPartition> expectedForgottenTopicData = new LinkedList<>();
+            toForgetTopics.forEach(tidp -> {
+                String expectedName = fetchRequestUsesTopicIds ? "" : tidp.topic();
+                Uuid expectedTopicId = fetchRequestUsesTopicIds ? tidp.topicId() : Uuid.ZERO_UUID;
+                expectedForgottenTopicData.add(new TopicIdPartition(expectedTopicId, tidp.partition(), expectedName));
+            });
+
+            // Build the list of TopicIdPartitions based on the FetchRequestData that was serialized and parsed.
+            List<TopicIdPartition> convertedForgottenTopicData = new LinkedList<>();
+            fetchRequest.data().forgottenTopicsData().forEach(forgottenTopic ->
+                    forgottenTopic.partitions().forEach(partition ->
+                            convertedForgottenTopicData.add(new TopicIdPartition(forgottenTopic.topicId(), partition, forgottenTopic.topic()))
+                    )
+            );
+            // The TopicIdPartitions built from the request data should match what we expect.
+            assertEquals(expectedForgottenTopicData, convertedForgottenTopicData);
+
+            // Get the forgottenTopics from the request data.
+            List<TopicIdPartition> forgottenTopics = fetchRequest.forgottenTopics(topicNames);
+
+            // For fetch request versions 13+, we expect topic names to be filled in for all topics in the topicNames map;
+            // topics whose IDs are not in the map should have a null name.
+            // For earlier request versions, we expect topic names and zero Uuids.
+            // Build the list of expected TopicIdPartitions. These differ from the earlier expected topicIdPartitions
+            // because empty strings are converted to nulls.
+            assertEquals(expectedForgottenTopicData.size(), forgottenTopics.size());
+            List<TopicIdPartition> expectedForgottenTopics = new LinkedList<>();
+            expectedForgottenTopicData.forEach(tidp -> {
+                String expectedName = fetchRequestUsesTopicIds ? topicNames.get(tidp.topicId()) : tidp.topic();
+                expectedForgottenTopics.add(new TopicIdPartition(tidp.topicId(), new TopicPartition(expectedName, tidp.partition())));
+            });
+            assertEquals(expectedForgottenTopics, forgottenTopics);
+        }
+    }
+
+    @Test
+    public void testPartitionDataEquals() {
+        assertEquals(new FetchRequest.PartitionData(Uuid.ZERO_UUID, 300, 0L, 300, Optional.of(300)),
+                new FetchRequest.PartitionData(Uuid.ZERO_UUID, 300, 0L, 300, Optional.of(300)));
+
+        assertNotEquals(new FetchRequest.PartitionData(Uuid.randomUuid(), 300, 0L, 300, Optional.of(300)),
+            new FetchRequest.PartitionData(Uuid.randomUuid(), 300, 0L, 300, Optional.of(300)));
+    }
+
+}
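
For orientation alongside these tests, here is a hedged sketch of the request side
under the new builder API: PartitionData embeds the topic ID, and the receiver
resolves IDs back to names via a topicNames map. The topic name "foo" and the
incremental session epoch are assumptions for illustration:

    import java.util.Collections;
    import java.util.Map;
    import java.util.Optional;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.requests.FetchMetadata;
    import org.apache.kafka.common.requests.FetchRequest;

    public class FetchRequestSketch {
        static Map<TopicIdPartition, FetchRequest.PartitionData> buildAndResolve() {
            Uuid fooId = Uuid.randomUuid();
            // The topic ID now lives in PartitionData rather than in a separate map.
            Map<TopicPartition, FetchRequest.PartitionData> fetchData = Collections.singletonMap(
                    new TopicPartition("foo", 0),
                    new FetchRequest.PartitionData(fooId, 0L, -1L, 1_000_000, Optional.empty()));
            FetchRequest request = FetchRequest.Builder
                    .forConsumer((short) 13, 500, 1, fetchData)
                    .metadata(FetchMetadata.newIncremental(123))
                    .build((short) 13);
            // On versions 13+, an ID missing from this map resolves to a null topic name.
            Map<Uuid, String> topicNames = Collections.singletonMap(fooId, "foo");
            return request.fetchData(topicNames);
        }
    }
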
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
index 0f52478..70b1e05 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
@@ -21,6 +21,7 @@ import org.apache.kafka.common.ElectionType;
 import org.apache.kafka.common.IsolationLevel;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.acl.AccessControlEntry;
 import org.apache.kafka.common.acl.AccessControlEntryFilter;
@@ -219,6 +220,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static java.util.Arrays.asList;
 import static java.util.Collections.emptyList;
@@ -258,10 +260,10 @@ public class RequestResponseTest {
         checkErrorResponse(createControlledShutdownRequest(0), unknownServerException, true);
         checkRequest(createFetchRequest(4), true);
         checkResponse(createFetchResponse(true), 4, true);
-        List<TopicPartition> toForgetTopics = new ArrayList<>();
-        toForgetTopics.add(new TopicPartition("foo", 0));
-        toForgetTopics.add(new TopicPartition("foo", 2));
-        toForgetTopics.add(new TopicPartition("bar", 0));
+        List<TopicIdPartition> toForgetTopics = new ArrayList<>();
+        toForgetTopics.add(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0)));
+        toForgetTopics.add(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 2)));
+        toForgetTopics.add(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("bar", 0)));
         checkRequest(createFetchRequest(7, new FetchMetadata(123, 456), toForgetTopics), true);
         checkResponse(createFetchResponse(123), 7, true);
         checkResponse(createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123), 7, true);
@@ -834,32 +836,41 @@ public class RequestResponseTest {
 
     @Test
     public void fetchResponseVersionTest() {
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
-        Map<Uuid, String> topicNames = new HashMap<>();
-        Map<String, Uuid> topicIds = new HashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
         Uuid id = Uuid.randomUuid();
-        topicNames.put(id, "test");
-        topicIds.put("test", id);
-
+        Map<Uuid, String> topicNames = Collections.singletonMap(id, "test");
+        TopicPartition tp = new TopicPartition("test", 0);
 
         MemoryRecords records = MemoryRecords.readableRecords(ByteBuffer.allocate(10));
-        responseData.put(new TopicPartition("test", 0),
-                new FetchResponseData.PartitionData()
-                        .setPartitionIndex(0)
-                        .setHighWatermark(1000000)
-                        .setLogStartOffset(-1)
-                        .setRecords(records));
+        FetchResponseData.PartitionData partitionData = new FetchResponseData.PartitionData()
+                .setPartitionIndex(0)
+                .setHighWatermark(1000000)
+                .setLogStartOffset(-1)
+                .setRecords(records);
+
+        // Use zero UUID since we are comparing with old request versions
+        responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, tp), partitionData);
+
+        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> tpResponseData = new LinkedHashMap<>();
+        tpResponseData.put(tp, partitionData);
 
-        FetchResponse v0Response = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, responseData, topicIds);
-        FetchResponse v1Response = FetchResponse.of(Errors.NONE, 10, INVALID_SESSION_ID, responseData, topicIds);
+        FetchResponse v0Response = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, responseData);
+        FetchResponse v1Response = FetchResponse.of(Errors.NONE, 10, INVALID_SESSION_ID, responseData);
         FetchResponse v0Deserialized = FetchResponse.parse(v0Response.serialize((short) 0), (short) 0);
         FetchResponse v1Deserialized = FetchResponse.parse(v1Response.serialize((short) 1), (short) 1);
         assertEquals(0, v0Deserialized.throttleTimeMs(), "Throttle time must be zero");
         assertEquals(10, v1Deserialized.throttleTimeMs(), "Throttle time must be 10");
-        assertEquals(responseData, v0Deserialized.responseData(topicNames, (short) 0), "Response data does not match");
-        assertEquals(responseData, v1Deserialized.responseData(topicNames, (short) 1), "Response data does not match");
+        assertEquals(tpResponseData, v0Deserialized.responseData(topicNames, (short) 0), "Response data does not match");
+        assertEquals(tpResponseData, v1Deserialized.responseData(topicNames, (short) 1), "Response data does not match");
 
-        FetchResponse idTestResponse = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, responseData, topicIds);
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> idResponseData = new LinkedHashMap<>();
+        idResponseData.put(new TopicIdPartition(id, new TopicPartition("test", 0)),
+                new FetchResponseData.PartitionData()
+                        .setPartitionIndex(0)
+                        .setHighWatermark(1000000)
+                        .setLogStartOffset(-1)
+                        .setRecords(records));
+        FetchResponse idTestResponse = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, idResponseData);
         FetchResponse v12Deserialized = FetchResponse.parse(idTestResponse.serialize((short) 12), (short) 12);
         FetchResponse newestDeserialized = FetchResponse.parse(idTestResponse.serialize(FETCH.latestVersion()), FETCH.latestVersion());
         assertTrue(v12Deserialized.topicIds().isEmpty());
@@ -869,40 +880,41 @@ public class RequestResponseTest {
 
     @Test
     public void testFetchResponseV4() {
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
         Map<Uuid, String> topicNames = new HashMap<>();
-        Map<String, Uuid> topicIds = new HashMap<>();
         topicNames.put(Uuid.randomUuid(), "bar");
         topicNames.put(Uuid.randomUuid(), "foo");
-        topicNames.forEach((id, name) -> topicIds.put(name, id));
         MemoryRecords records = MemoryRecords.readableRecords(ByteBuffer.allocate(10));
 
         List<FetchResponseData.AbortedTransaction> abortedTransactions = asList(
                 new FetchResponseData.AbortedTransaction().setProducerId(10).setFirstOffset(100),
                 new FetchResponseData.AbortedTransaction().setProducerId(15).setFirstOffset(50)
         );
-        responseData.put(new TopicPartition("bar", 0),
+
+        // Use zero UUID since this is an old request version.
+        responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("bar", 0)),
                 new FetchResponseData.PartitionData()
                         .setPartitionIndex(0)
                         .setHighWatermark(1000000)
                         .setAbortedTransactions(abortedTransactions)
                         .setRecords(records));
-        responseData.put(new TopicPartition("bar", 1),
+        responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("bar", 1)),
                 new FetchResponseData.PartitionData()
                         .setPartitionIndex(1)
                         .setHighWatermark(900000)
                         .setLastStableOffset(5)
                         .setRecords(records));
-        responseData.put(new TopicPartition("foo", 0),
+        responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0)),
                 new FetchResponseData.PartitionData()
                         .setPartitionIndex(0)
                         .setHighWatermark(70000)
                         .setLastStableOffset(6)
                         .setRecords(records));
 
-        FetchResponse response = FetchResponse.of(Errors.NONE, 10, INVALID_SESSION_ID, responseData, topicIds);
+        FetchResponse response = FetchResponse.of(Errors.NONE, 10, INVALID_SESSION_ID, responseData);
         FetchResponse deserialized = FetchResponse.parse(response.serialize((short) 4), (short) 4);
-        assertEquals(responseData, deserialized.responseData(topicNames, (short) 4));
+        assertEquals(responseData.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().topicPartition(), Map.Entry::getValue)),
+                deserialized.responseData(topicNames, (short) 4));
     }
 
     @Test
@@ -1015,10 +1027,9 @@ public class RequestResponseTest {
     @Test
     public void testFetchRequestCompat() {
         Map<TopicPartition, FetchRequest.PartitionData> fetchData = new HashMap<>();
-        fetchData.put(new TopicPartition("test", 0), new FetchRequest.PartitionData(100, 2, 100, Optional.of(42)));
-        Map<String, Uuid> topicIds = Collections.singletonMap("test1", Uuid.randomUuid());
+        fetchData.put(new TopicPartition("test", 0), new FetchRequest.PartitionData(Uuid.ZERO_UUID, 100, 2, 100, Optional.of(42)));
         FetchRequest req = FetchRequest.Builder
-                .forConsumer((short) 2, 100, 100, fetchData, topicIds)
+                .forConsumer((short) 2, 100, 100, fetchData)
                 .metadata(new FetchMetadata(10, 20))
                 .isolationLevel(IsolationLevel.READ_COMMITTED)
                 .build((short) 2);
@@ -1286,75 +1297,66 @@ public class RequestResponseTest {
             return FindCoordinatorResponse.prepareResponse(Errors.NONE, "group", node);
     }
 
-    private FetchRequest createFetchRequest(int version, FetchMetadata metadata, List<TopicPartition> toForget) {
+    private FetchRequest createFetchRequest(int version, FetchMetadata metadata, List<TopicIdPartition> toForget) {
         LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>();
-        fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, -1L,
-                1000000, Optional.empty()));
-        fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, -1L,
-                1000000, Optional.empty()));
-        Map<String, Uuid> topicIds = new HashMap<>();
-        topicIds.put("test1", Uuid.randomUuid());
-        topicIds.put("test2", Uuid.randomUuid());
-        return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData, topicIds).
-            metadata(metadata).setMaxBytes(1000).toForget(toForget).build((short) version);
+        fetchData.put(new TopicPartition("test1", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 100, -1L, 1000000, Optional.empty()));
+        fetchData.put(new TopicPartition("test2", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 200, -1L, 1000000, Optional.empty()));
+        return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData).
+            metadata(metadata).setMaxBytes(1000).removed(toForget).build((short) version);
     }
 
     private FetchRequest createFetchRequest(int version, IsolationLevel isolationLevel) {
         LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>();
-        fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, -1L,
-                1000000, Optional.empty()));
-        fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, -1L,
-                1000000, Optional.empty()));
-        Map<String, Uuid> topicIds = new HashMap<>();
-        topicIds.put("test1", Uuid.randomUuid());
-        topicIds.put("test2", Uuid.randomUuid());
-        return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData, topicIds).
+        fetchData.put(new TopicPartition("test1", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 100, -1L, 1000000, Optional.empty()));
+        fetchData.put(new TopicPartition("test2", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 200, -1L, 1000000, Optional.empty()));
+        return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData).
             isolationLevel(isolationLevel).setMaxBytes(1000).build((short) version);
     }
 
     private FetchRequest createFetchRequest(int version) {
         LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>();
-        fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, -1L,
-                1000000, Optional.empty()));
-        fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, -1L,
-                1000000, Optional.empty()));
-        Map<String, Uuid> topicIds = new HashMap<>();
-        topicIds.put("test1", Uuid.randomUuid());
-        topicIds.put("test2", Uuid.randomUuid());
-        return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData, topicIds).setMaxBytes(1000).build((short) version);
+        fetchData.put(new TopicPartition("test1", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 100, -1L, 1000000, Optional.empty()));
+        fetchData.put(new TopicPartition("test2", 0),
+                new FetchRequest.PartitionData(Uuid.randomUuid(), 200, -1L, 1000000, Optional.empty()));
+        return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData).setMaxBytes(1000).build((short) version);
     }
 
     private FetchResponse createFetchResponse(Errors error, int sessionId) {
         return FetchResponse.parse(
-            FetchResponse.of(error, 25, sessionId, new LinkedHashMap<>(),
-                             new HashMap<>()).serialize(FETCH.latestVersion()), FETCH.latestVersion());
+            FetchResponse.of(error, 25, sessionId, new LinkedHashMap<>()).serialize(FETCH.latestVersion()), FETCH.latestVersion());
     }
 
     private FetchResponse createFetchResponse(int sessionId) {
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
         Map<String, Uuid> topicIds = new HashMap<>();
         topicIds.put("test", Uuid.randomUuid());
         MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes()));
-        responseData.put(new TopicPartition("test", 0), new FetchResponseData.PartitionData()
+        responseData.put(new TopicIdPartition(topicIds.get("test"), new TopicPartition("test", 0)), new FetchResponseData.PartitionData()
                         .setPartitionIndex(0)
                         .setHighWatermark(1000000)
                         .setLogStartOffset(0)
                         .setRecords(records));
         List<FetchResponseData.AbortedTransaction> abortedTransactions = Collections.singletonList(
             new FetchResponseData.AbortedTransaction().setProducerId(234L).setFirstOffset(999L));
-        responseData.put(new TopicPartition("test", 1), new FetchResponseData.PartitionData()
+        responseData.put(new TopicIdPartition(topicIds.get("test"), new TopicPartition("test", 1)), new FetchResponseData.PartitionData()
                         .setPartitionIndex(1)
                         .setHighWatermark(1000000)
                         .setLogStartOffset(0)
                         .setAbortedTransactions(abortedTransactions));
         return FetchResponse.parse(FetchResponse.of(Errors.NONE, 25, sessionId,
-            responseData, topicIds).serialize(FETCH.latestVersion()), FETCH.latestVersion());
+            responseData).serialize(FETCH.latestVersion()), FETCH.latestVersion());
     }
 
     private FetchResponse createFetchResponse(boolean includeAborted) {
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData = new LinkedHashMap<>();
+        Uuid topicId = Uuid.randomUuid();
         MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes()));
-        responseData.put(new TopicPartition("test", 0), new FetchResponseData.PartitionData()
+        responseData.put(new TopicIdPartition(topicId, new TopicPartition("test", 0)), new FetchResponseData.PartitionData()
                         .setPartitionIndex(0)
                         .setHighWatermark(1000000)
                         .setLogStartOffset(0)
@@ -1365,13 +1367,13 @@ public class RequestResponseTest {
             abortedTransactions = Collections.singletonList(
                     new FetchResponseData.AbortedTransaction().setProducerId(234L).setFirstOffset(999L));
         }
-        responseData.put(new TopicPartition("test", 1), new FetchResponseData.PartitionData()
+        responseData.put(new TopicIdPartition(topicId, new TopicPartition("test", 1)), new FetchResponseData.PartitionData()
                         .setPartitionIndex(1)
                         .setHighWatermark(1000000)
                         .setLogStartOffset(0)
                         .setAbortedTransactions(abortedTransactions));
         return FetchResponse.parse(FetchResponse.of(Errors.NONE, 25, INVALID_SESSION_ID,
-            responseData, Collections.singletonMap("test", Uuid.randomUuid())).serialize(FETCH.latestVersion()), FETCH.latestVersion());
+            responseData).serialize(FETCH.latestVersion()), FETCH.latestVersion());
     }
 
     private HeartbeatRequest createHeartBeatRequest() {
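
These response tests all pivot on one wire-format fact: only fetch versions 13+
carry topic IDs, so serializing at an older version drops them. A small sketch of
that round trip, assuming a response built with non-zero topic IDs as in the
sketches above:

    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.FetchResponse;

    public class FetchResponseRoundTripSketch {
        static void roundTrip(FetchResponse response) {
            // v12 and below have no topic ID field, so none survive parsing.
            FetchResponse v12 = FetchResponse.parse(response.serialize((short) 12), (short) 12);
            assert v12.topicIds().isEmpty();
            // The latest version carries the IDs through intact
            // (assuming the response was built with non-zero topic IDs).
            short latest = ApiKeys.FETCH.latestVersion();
            FetchResponse newest = FetchResponse.parse(response.serialize(latest), latest);
            assert !newest.topicIds().isEmpty();
        }
    }
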
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java
index 8d8f4ca..e2366ce 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java
@@ -16,7 +16,10 @@
  */
 package org.apache.kafka.common.requests;
 
+import java.util.HashMap;
+import java.util.Set;
 import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.internals.Topic;
@@ -150,6 +153,20 @@ public class RequestTestUtils {
     }
 
     public static MetadataResponse metadataUpdateWithIds(final int numNodes,
+                                                         final Set<TopicIdPartition> partitions,
+                                                         final Function<TopicPartition, Integer> epochSupplier) {
+        final Map<String, Integer> topicPartitionCounts = new HashMap<>();
+        final Map<String, Uuid> topicIds = new HashMap<>();
+
+        partitions.forEach(partition -> {
+            topicPartitionCounts.compute(partition.topic(), (key, value) -> value == null ? 1 : value + 1);
+            topicIds.putIfAbsent(partition.topic(), partition.topicId());
+        });
+
+        return metadataUpdateWithIds(numNodes, topicPartitionCounts, epochSupplier, topicIds);
+    }
+
+    public static MetadataResponse metadataUpdateWithIds(final int numNodes,
                                                          final Map<String, Integer> topicPartitionCounts,
                                                          final Function<TopicPartition, Integer> epochSupplier,
                                                          final Map<String, Uuid> topicIds) {
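
The new overload above lets a test describe cluster metadata as a set of
TopicIdPartitions and have the per-topic partition counts and topic IDs derived
automatically. A hedged usage sketch (topic names, node count, and the constant
epoch are illustrative):

    import java.util.Set;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.requests.MetadataResponse;
    import org.apache.kafka.common.requests.RequestTestUtils;
    import org.apache.kafka.common.utils.Utils;

    public class MetadataUpdateSketch {
        static MetadataResponse twoTopicUpdate() {
            Uuid fooId = Uuid.randomUuid();
            Uuid barId = Uuid.randomUuid();
            Set<TopicIdPartition> partitions = Utils.mkSet(
                    new TopicIdPartition(fooId, new TopicPartition("foo", 0)),
                    new TopicIdPartition(fooId, new TopicPartition("foo", 1)),
                    new TopicIdPartition(barId, new TopicPartition("bar", 0)));
            // Derives counts (foo -> 2, bar -> 1) and the topicIds map from the set;
            // every partition reports leader epoch 1 via the supplier.
            return RequestTestUtils.metadataUpdateWithIds(3, partitions, tp -> 1);
        }
    }
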
diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
index eef7ecb..492cec4 100755
--- a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala
@@ -374,7 +374,7 @@ abstract class AbstractFetcherThread(name: String,
                       }
                     }
                   } catch {
-                    case ime@( _: CorruptRecordException | _: InvalidRecordException) =>
+                    case ime@(_: CorruptRecordException | _: InvalidRecordException) =>
                       // we log the error and continue. This ensures two things
                       // 1. If there is a corrupt message in a topic partition, it does not bring the fetcher thread
                      //    down and cause other topic partitions to also lag
@@ -413,8 +413,20 @@ abstract class AbstractFetcherThread(name: String,
 
                 case Errors.UNKNOWN_TOPIC_OR_PARTITION =>
                   warn(s"Received ${Errors.UNKNOWN_TOPIC_OR_PARTITION} from the leader for partition $topicPartition. " +
-                       "This error may be returned transiently when the partition is being created or deleted, but it is not " +
-                       "expected to persist.")
+                    "This error may be returned transiently when the partition is being created or deleted, but it is not " +
+                    "expected to persist.")
+                  partitionsWithError += topicPartition
+
+                case Errors.UNKNOWN_TOPIC_ID =>
+                  warn(s"Received ${Errors.UNKNOWN_TOPIC_ID} from the leader for partition $topicPartition. " +
+                    "This error may be returned transiently when the partition is being created or deleted, but it is not " +
+                    "expected to persist.")
+                  partitionsWithError += topicPartition
+
+                case Errors.INCONSISTENT_TOPIC_ID =>
+                  warn(s"Received ${Errors.INCONSISTENT_TOPIC_ID} from the leader for partition $topicPartition. " +
+                    "This error may be returned transiently when the partition is being created or deleted, but it is not " +
+                    "expected to persist.")
                   partitionsWithError += topicPartition
 
                 case partitionError =>
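
The two new error cases mirror the existing UNKNOWN_TOPIC_OR_PARTITION handling:
log a warning, add the partition to partitionsWithError, and let the fetcher retry,
on the expectation that the condition is transient around topic creation and
deletion. As a sketch, the grouping amounts to a classification like the following
(the helper name is hypothetical, not part of the patch):

    import org.apache.kafka.common.protocol.Errors;

    public class TopicErrorSketch {
        // Hypothetical helper: all three errors are treated as transient and retried.
        static boolean isTransientTopicError(Errors error) {
            return error == Errors.UNKNOWN_TOPIC_OR_PARTITION
                    || error == Errors.UNKNOWN_TOPIC_ID
                    || error == Errors.INCONSISTENT_TOPIC_ID;
        }
    }
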
diff --git a/core/src/main/scala/kafka/server/DelayedFetch.scala b/core/src/main/scala/kafka/server/DelayedFetch.scala
index d32ae13..1bc2a73 100644
--- a/core/src/main/scala/kafka/server/DelayedFetch.scala
+++ b/core/src/main/scala/kafka/server/DelayedFetch.scala
@@ -17,11 +17,10 @@
 
 package kafka.server
 
-import java.util
 import java.util.concurrent.TimeUnit
 
 import kafka.metrics.KafkaMetricsGroup
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.TopicIdPartition
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.replica.ClientMetadata
@@ -49,8 +48,7 @@ case class FetchMetadata(fetchMinBytes: Int,
                          fetchIsolation: FetchIsolation,
                          isFromFollower: Boolean,
                          replicaId: Int,
-                         topicIds: util.Map[String, Uuid],
-                         fetchPartitionStatus: Seq[(TopicPartition, FetchPartitionStatus)]) {
+                         fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)]) {
 
   override def toString = "FetchMetadata(minBytes=" + fetchMinBytes + ", " +
     "maxBytes=" + fetchMaxBytes + ", " +
@@ -68,7 +66,7 @@ class DelayedFetch(delayMs: Long,
                    replicaManager: ReplicaManager,
                    quota: ReplicaQuota,
                    clientMetadata: Option[ClientMetadata],
-                   responseCallback: Seq[(TopicPartition, FetchPartitionData)] => Unit)
+                   responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit)
   extends DelayedOperation(delayMs) {
 
   /**
@@ -87,12 +85,12 @@ class DelayedFetch(delayMs: Long,
   override def tryComplete(): Boolean = {
     var accumulatedSize = 0
     fetchMetadata.fetchPartitionStatus.foreach {
-      case (topicPartition, fetchStatus) =>
+      case (topicIdPartition, fetchStatus) =>
         val fetchOffset = fetchStatus.startOffsetMetadata
         val fetchLeaderEpoch = fetchStatus.fetchInfo.currentLeaderEpoch
         try {
           if (fetchOffset != LogOffsetMetadata.UnknownOffsetMetadata) {
-            val partition = replicaManager.getPartitionOrException(topicPartition)
+            val partition = replicaManager.getPartitionOrException(topicIdPartition.topicPartition)
             val offsetSnapshot = partition.fetchOffsetSnapshot(fetchLeaderEpoch, fetchMetadata.fetchOnlyLeader)
 
             val endOffset = fetchMetadata.fetchIsolation match {
@@ -107,7 +105,7 @@ class DelayedFetch(delayMs: Long,
             if (endOffset.messageOffset != fetchOffset.messageOffset) {
               if (endOffset.onOlderSegment(fetchOffset)) {
                 // Case F, this can happen when the new fetch operation is on a truncated leader
-                debug(s"Satisfying fetch $fetchMetadata since it is fetching later segments of partition $topicPartition.")
+                debug(s"Satisfying fetch $fetchMetadata since it is fetching later segments of partition $topicIdPartition.")
                 return forceComplete()
               } else if (fetchOffset.onOlderSegment(endOffset)) {
                 // Case F, this can happen when the fetch operation is falling behind the current segment
@@ -130,27 +128,27 @@ class DelayedFetch(delayMs: Long,
               if (epochEndOffset.errorCode != Errors.NONE.code()
                   || epochEndOffset.endOffset == UNDEFINED_EPOCH_OFFSET
                   || epochEndOffset.leaderEpoch == UNDEFINED_EPOCH) {
-                debug(s"Could not obtain last offset for leader epoch for partition $topicPartition, epochEndOffset=$epochEndOffset.")
+                debug(s"Could not obtain last offset for leader epoch for partition $topicIdPartition, epochEndOffset=$epochEndOffset.")
                 return forceComplete()
               } else if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchStatus.fetchInfo.fetchOffset) {
                 debug(s"Satisfying fetch $fetchMetadata since it has diverging epoch requiring truncation for partition " +
-                  s"$topicPartition epochEndOffset=$epochEndOffset fetchEpoch=$fetchEpoch fetchOffset=${fetchStatus.fetchInfo.fetchOffset}.")
+                  s"$topicIdPartition epochEndOffset=$epochEndOffset fetchEpoch=$fetchEpoch fetchOffset=${fetchStatus.fetchInfo.fetchOffset}.")
                 return forceComplete()
               }
             }
           }
         } catch {
           case _: NotLeaderOrFollowerException =>  // Case A or Case B
-            debug(s"Broker is no longer the leader or follower of $topicPartition, satisfy $fetchMetadata immediately")
+            debug(s"Broker is no longer the leader or follower of $topicIdPartition, satisfy $fetchMetadata immediately")
             return forceComplete()
           case _: UnknownTopicOrPartitionException => // Case C
-            debug(s"Broker no longer knows of partition $topicPartition, satisfy $fetchMetadata immediately")
+            debug(s"Broker no longer knows of partition $topicIdPartition, satisfy $fetchMetadata immediately")
             return forceComplete()
           case _: KafkaStorageException => // Case D
-            debug(s"Partition $topicPartition is in an offline log directory, satisfy $fetchMetadata immediately")
+            debug(s"Partition $topicIdPartition is in an offline log directory, satisfy $fetchMetadata immediately")
             return forceComplete()
           case _: FencedLeaderEpochException => // Case E
-            debug(s"Broker is the leader of partition $topicPartition, but the requested epoch " +
+            debug(s"Broker is the leader of partition $topicIdPartition, but the requested epoch " +
               s"$fetchLeaderEpoch is fenced by the latest leader epoch, satisfy $fetchMetadata immediately")
             return forceComplete()
         }
@@ -181,13 +179,12 @@ class DelayedFetch(delayMs: Long,
       fetchMaxBytes = fetchMetadata.fetchMaxBytes,
       hardMaxBytesLimit = fetchMetadata.hardMaxBytesLimit,
       readPartitionInfo = fetchMetadata.fetchPartitionStatus.map { case (tp, status) => tp -> status.fetchInfo },
-      topicIds = fetchMetadata.topicIds,
       clientMetadata = clientMetadata,
       quota = quota)
 
     val fetchPartitionData = logReadResults.map { case (tp, result) =>
       val isReassignmentFetch = fetchMetadata.isFromFollower &&
-        replicaManager.isAddingReplica(tp, fetchMetadata.replicaId)
+        replicaManager.isAddingReplica(tp.topicPartition, fetchMetadata.replicaId)
 
       tp -> result.toFetchPartitionData(isReassignmentFetch)
     }
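
The DelayedFetch and read-path changes above all hinge on TopicIdPartition pairing the immutable
topic ID with the name-based TopicPartition. A standalone sketch (not part of the patch; the
Sketch* types below are hypothetical stand-ins for the Kafka classes) of why keying by ID matters
when a topic is deleted and recreated:

    import java.util.UUID

    final case class SketchTopicPartition(topic: String, partition: Int)
    final case class SketchTopicIdPartition(topicId: UUID, topicPartition: SketchTopicPartition)

    object DelayedFetchKeySketch extends App {
      // "foo" is deleted and recreated: same name and partition, but a new topic ID.
      val fooOld = SketchTopicIdPartition(UUID.randomUUID(), SketchTopicPartition("foo", 0))
      val fooNew = SketchTopicIdPartition(UUID.randomUUID(), SketchTopicPartition("foo", 0))
      // Keyed by name alone the two incarnations are indistinguishable; keyed by ID they
      // are not, so a pending fetch on the old topic is not satisfied by the new one.
      assert(fooOld.topicPartition == fooNew.topicPartition)
      assert(fooOld != fooNew)
    }
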
diff --git a/core/src/main/scala/kafka/server/DelayedOperationKey.scala b/core/src/main/scala/kafka/server/DelayedOperationKey.scala
index 05a6a99..13ed462 100644
--- a/core/src/main/scala/kafka/server/DelayedOperationKey.scala
+++ b/core/src/main/scala/kafka/server/DelayedOperationKey.scala
@@ -17,7 +17,7 @@
 
 package kafka.server
 
-import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition}
 
 /**
  * Keys used for delayed operation metrics recording
@@ -39,6 +39,9 @@ object TopicPartitionOperationKey {
   def apply(topicPartition: TopicPartition): TopicPartitionOperationKey = {
     apply(topicPartition.topic, topicPartition.partition)
   }
+  def apply(topicIdPartition: TopicIdPartition): TopicPartitionOperationKey = {
+    apply(topicIdPartition.topic, topicIdPartition.partition)
+  }
 }
 
 /* used by delayed-join-group operations */
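
The new TopicPartitionOperationKey overload above simply projects a TopicIdPartition down to the
same name-based key. A minimal sketch of the overload pattern (hypothetical Sketch* types, not the
Kafka classes), included to show that the topic ID is deliberately dropped:

    import java.util.UUID

    final case class TPSketch(topic: String, partition: Int)
    final case class TIPSketch(topicId: UUID, tp: TPSketch)

    final case class OperationKeySketch(topic: String, partition: Int)

    object OperationKeySketch {
      def of(topic: String, partition: Int): OperationKeySketch = OperationKeySketch(topic, partition)
      def of(tp: TPSketch): OperationKeySketch = of(tp.topic, tp.partition)
      // The added overload: the ID is dropped, so purgatory keys stay name-based and
      // existing watch/complete call sites keep working unchanged.
      def of(tip: TIPSketch): OperationKeySketch = of(tip.tp.topic, tip.tp.partition)
    }
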
diff --git a/core/src/main/scala/kafka/server/FetchSession.scala b/core/src/main/scala/kafka/server/FetchSession.scala
index cec3dfe..f7d348d 100644
--- a/core/src/main/scala/kafka/server/FetchSession.scala
+++ b/core/src/main/scala/kafka/server/FetchSession.scala
@@ -19,32 +19,32 @@ package kafka.server
 
 import kafka.metrics.KafkaMetricsGroup
 import kafka.utils.Logging
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID}
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.utils.{ImplicitLinkedHashCollection, Time, Utils}
 import java.util
-import java.util.{Collections, Objects, Optional}
+import java.util.Optional
 import java.util.concurrent.{ThreadLocalRandom, TimeUnit}
 
 import scala.collection.{mutable, _}
 import scala.math.Ordered.orderingToOrdered
 
 object FetchSession {
-  type REQ_MAP = util.Map[TopicPartition, FetchRequest.PartitionData]
-  type RESP_MAP = util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+  type REQ_MAP = util.Map[TopicIdPartition, FetchRequest.PartitionData]
+  type RESP_MAP = util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
   type CACHE_MAP = ImplicitLinkedHashCollection[CachedPartition]
-  type RESP_MAP_ITER = util.Iterator[util.Map.Entry[TopicPartition, FetchResponseData.PartitionData]]
-  type TOPIC_ID_MAP = util.Map[String, Uuid]
+  type RESP_MAP_ITER = util.Iterator[util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData]]
+  type TOPIC_NAME_MAP = util.Map[Uuid, String]
 
   val NUM_INCREMENTAL_FETCH_SESSIONS = "NumIncrementalFetchSessions"
   val NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED = "NumIncrementalFetchPartitionsCached"
   val INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC = "IncrementalFetchSessionEvictionsPerSec"
   val EVICTIONS = "evictions"
 
-  def partitionsToLogString(partitions: util.Collection[TopicPartition], traceEnabled: Boolean): String = {
+  def partitionsToLogString(partitions: util.Collection[TopicIdPartition], traceEnabled: Boolean): String = {
     if (traceEnabled) {
       "(" + Utils.join(partitions, ", ") + ")"
     } else {
@@ -70,7 +70,7 @@ object FetchSession {
   * Note that fetcherLogStartOffset is the LSO of the follower performing the fetch, whereas
   * localLogStartOffset is the log start offset of the partition on this broker.
   */
-class CachedPartition(val topic: String,
+class CachedPartition(var topic: String,
                       val topicId: Uuid,
                       val partition: Int,
                       var maxBytes: Int,
@@ -90,22 +90,23 @@ class CachedPartition(val topic: String,
   override def prev: Int = cachedPrev
   override def setPrev(prev: Int): Unit = this.cachedPrev = prev
 
-  def this(topic: String, partition: Int, topicId: Uuid) =
+  def this(topic: String, topicId: Uuid, partition: Int) =
     this(topic, topicId, partition, -1, -1, -1, Optional.empty(), -1, -1, Optional.empty[Integer])
 
-  def this(part: TopicPartition, topicId: Uuid) =
-    this(part.topic, part.partition, topicId)
+  def this(part: TopicIdPartition) = {
+    this(part.topic, part.topicId, part.partition)
+  }
 
-  def this(part: TopicPartition, id: Uuid, reqData: FetchRequest.PartitionData) =
-    this(part.topic, id, part.partition, reqData.maxBytes, reqData.fetchOffset, -1,
+  def this(part: TopicIdPartition, reqData: FetchRequest.PartitionData) =
+    this(part.topic, part.topicId, part.partition, reqData.maxBytes, reqData.fetchOffset, -1,
       reqData.currentLeaderEpoch, reqData.logStartOffset, -1, reqData.lastFetchedEpoch)
 
-  def this(part: TopicPartition, id: Uuid, reqData: FetchRequest.PartitionData,
+  def this(part: TopicIdPartition, reqData: FetchRequest.PartitionData,
            respData: FetchResponseData.PartitionData) =
-    this(part.topic, id, part.partition, reqData.maxBytes, reqData.fetchOffset, respData.highWatermark,
+    this(part.topic, part.topicId, part.partition, reqData.maxBytes, reqData.fetchOffset, respData.highWatermark,
       reqData.currentLeaderEpoch, reqData.logStartOffset, respData.logStartOffset, reqData.lastFetchedEpoch)
 
-  def reqData = new FetchRequest.PartitionData(fetchOffset, fetcherLogStartOffset, maxBytes, leaderEpoch, lastFetchedEpoch)
+  def reqData = new FetchRequest.PartitionData(topicId, fetchOffset, fetcherLogStartOffset, maxBytes, leaderEpoch, lastFetchedEpoch)
 
   def updateRequestParams(reqData: FetchRequest.PartitionData): Unit = {
     // Update our cached request parameters.
@@ -116,6 +117,12 @@ class CachedPartition(val topic: String,
     lastFetchedEpoch = reqData.lastFetchedEpoch
   }
 
+  def maybeResolveUnknownName(topicNames: FetchSession.TOPIC_NAME_MAP): Unit = {
+    if (this.topic == null) {
+      this.topic = topicNames.get(this.topicId)
+    }
+  }
+
   /**
     * Determine whether or not the specified cached partition should be included in the FetchResponse we send back to
     * the fetcher and update it if requested.
@@ -163,18 +170,35 @@ class CachedPartition(val topic: String,
     mustRespond
   }
 
-  override def hashCode: Int = Objects.hash(new TopicPartition(topic, partition), topicId)
-
-  def canEqual(that: Any): Boolean = that.isInstanceOf[CachedPartition]
+  /**
+   * We have different equality checks depending on whether topic IDs are used.
+   * This means we need a different hash function as well: we use the topic name to calculate the hash if the ID is
+   * zero (i.e. unused); otherwise, we use the topic ID in the hash calculation.
+   *
+   * @return the hash code for the CachedPartition, depending on whether the session uses topic IDs.
+   */
+  override def hashCode: Int =
+    if (topicId != Uuid.ZERO_UUID)
+      (31 * partition) + topicId.hashCode
+    else
+      (31 * partition) + topic.hashCode
 
+  /**
+   * We have different equality checks depending on whether topic IDs are used.
+   *
+   * This is because when we use topic IDs, a partition with a given ID and an unknown name is the same as a partition with that
+   * ID and a known name. This means we can only use topic ID and partition when determining equality.
+   *
+   * On the other hand, if we are using topic names, all IDs are zero. This means we can only use topic name and partition
+   * when determining equality.
+   */
   override def equals(that: Any): Boolean =
     that match {
       case that: CachedPartition =>
-        this.eq(that) ||
-          (that.canEqual(this) &&
-            this.partition.equals(that.partition) &&
-            this.topic.equals(that.topic) &&
-            this.topicId.equals(that.topicId))
+        this.eq(that) || (if (this.topicId != Uuid.ZERO_UUID)
+          this.partition.equals(that.partition) && this.topicId.equals(that.topicId)
+        else
+          this.partition.equals(that.partition) && this.topic.equals(that.topic))
       case _ => false
     }
 
@@ -202,7 +226,6 @@ class CachedPartition(val topic: String,
   *                           are privileged; session created by consumers are not.
   * @param partitionMap       The CachedPartitionMap.
   * @param usesTopicIds       True if this session is using topic IDs
-  * @param sessionTopicIds    The mapping from topic name to topic ID for topics in the session.
   * @param creationMs         The time in milliseconds when this session was created.
   * @param lastUsedMs         The last used time in milliseconds.  This should only be updated by
   *                           FetchSessionCache#touch.
@@ -212,7 +235,6 @@ class FetchSession(val id: Int,
                    val privileged: Boolean,
                    val partitionMap: FetchSession.CACHE_MAP,
                    val usesTopicIds: Boolean,
-                   val sessionTopicIds: FetchSession.TOPIC_ID_MAP,
                    val creationMs: Long,
                    var lastUsedMs: Long,
                    var epoch: Int) {
@@ -238,35 +260,24 @@ class FetchSession(val id: Int,
 
   def metadata: JFetchMetadata = synchronized { new JFetchMetadata(id, epoch) }
 
-  def getFetchOffset(topicPartition: TopicPartition): Option[Long] = synchronized {
-    Option(partitionMap.find(new CachedPartition(topicPartition,
-      sessionTopicIds.getOrDefault(topicPartition.topic(), Uuid.ZERO_UUID)))).map(_.fetchOffset)
+  def getFetchOffset(topicIdPartition: TopicIdPartition): Option[Long] = synchronized {
+    Option(partitionMap.find(new CachedPartition(topicIdPartition))).map(_.fetchOffset)
   }
 
-  type TL = util.ArrayList[TopicPartition]
+  type TL = util.ArrayList[TopicIdPartition]
 
   // Update the cached partition data based on the request.
   def update(fetchData: FetchSession.REQ_MAP,
-             toForget: util.List[TopicPartition],
-             reqMetadata: JFetchMetadata,
-             topicIds: util.Map[String, Uuid]): (TL, TL, TL, TL) = synchronized {
+             toForget: util.List[TopicIdPartition],
+             reqMetadata: JFetchMetadata): (TL, TL, TL) = synchronized {
     val added = new TL
     val updated = new TL
     val removed = new TL
-    val inconsistentTopicIds = new TL
     fetchData.forEach { (topicPart, reqData) =>
-      // Get the topic ID on the broker, if it is valid and the topic is new to the session, add its ID.
-      // If the topic already existed, check that its ID is consistent.
-      val id = topicIds.getOrDefault(topicPart.topic, Uuid.ZERO_UUID)
-      val newCachedPart = new CachedPartition(topicPart, id, reqData)
-      if (id != Uuid.ZERO_UUID) {
-        val prevSessionTopicId = sessionTopicIds.putIfAbsent(topicPart.topic, id)
-        if (prevSessionTopicId != null && prevSessionTopicId != id)
-          inconsistentTopicIds.add(topicPart)
-      }
-      val cachedPart = partitionMap.find(newCachedPart)
+      val cachedPartitionKey = new CachedPartition(topicPart, reqData)
+      val cachedPart = partitionMap.find(cachedPartitionKey)
       if (cachedPart == null) {
-        partitionMap.mustAdd(newCachedPart)
+        partitionMap.mustAdd(cachedPartitionKey)
         added.add(topicPart)
       } else {
         cachedPart.updateRequestParams(reqData)
@@ -274,11 +285,11 @@ class FetchSession(val id: Int,
       }
     }
     toForget.forEach { p =>
-      if (partitionMap.remove(new CachedPartition(p.topic, p.partition, topicIds.getOrDefault(p.topic, Uuid.ZERO_UUID)))) {
+      if (partitionMap.remove(new CachedPartition(p))) {
         removed.add(p)
       }
     }
-    (added, updated, removed, inconsistentTopicIds)
+    (added, updated, removed)
   }
 
   override def toString: String = synchronized {
@@ -296,12 +307,12 @@ trait FetchContext extends Logging {
   /**
     * Get the fetch offset for a given partition.
     */
-  def getFetchOffset(part: TopicPartition): Option[Long]
+  def getFetchOffset(part: TopicIdPartition): Option[Long]
 
   /**
     * Apply a function to each partition in the fetch request.
     */
-  def foreachPartition(fun: (TopicPartition, Uuid, FetchRequest.PartitionData) => Unit): Unit
+  def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit
 
   /**
     * Get the response size to be used for quota computation. Since we are returning an empty response in case of
@@ -315,14 +326,14 @@ trait FetchContext extends Logging {
     */
   def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse
 
-  def partitionsToLogString(partitions: util.Collection[TopicPartition]): String =
+  def partitionsToLogString(partitions: util.Collection[TopicIdPartition]): String =
     FetchSession.partitionsToLogString(partitions, isTraceEnabled)
 
   /**
     * Return an empty throttled response due to quota violation.
     */
   def getThrottledResponse(throttleTimeMs: Int): FetchResponse =
-    FetchResponse.of(Errors.NONE, throttleTimeMs, INVALID_SESSION_ID, new FetchSession.RESP_MAP, Collections.emptyMap())
+    FetchResponse.of(Errors.NONE, throttleTimeMs, INVALID_SESSION_ID, new FetchSession.RESP_MAP)
 }
 
 /**
@@ -330,18 +341,18 @@ trait FetchContext extends Logging {
   */
 class SessionErrorContext(val error: Errors,
                           val reqMetadata: JFetchMetadata) extends FetchContext {
-  override def getFetchOffset(part: TopicPartition): Option[Long] = None
+  override def getFetchOffset(part: TopicIdPartition): Option[Long] = None
 
-  override def foreachPartition(fun: (TopicPartition, Uuid, FetchRequest.PartitionData) => Unit): Unit = {}
+  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {}
 
   override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator, Collections.emptyMap())
+    FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator)
   }
 
   // Because of the fetch session error, we don't know what partitions were supposed to be in this request.
   override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
     debug(s"Session error fetch context returning $error")
-    FetchResponse.of(error, 0, INVALID_SESSION_ID, new FetchSession.RESP_MAP, Collections.emptyMap())
+    FetchResponse.of(error, 0, INVALID_SESSION_ID, new FetchSession.RESP_MAP)
   }
 }
 
@@ -349,24 +360,22 @@ class SessionErrorContext(val error: Errors,
   * The fetch context for a sessionless fetch request.
   *
   * @param fetchData          The partition data from the fetch request.
-  * @param topicIds           The map from topic names to topic IDs.
   */
-class SessionlessFetchContext(val fetchData: util.Map[TopicPartition, FetchRequest.PartitionData],
-                              val topicIds: util.Map[String, Uuid]) extends FetchContext {
-  override def getFetchOffset(part: TopicPartition): Option[Long] =
+class SessionlessFetchContext(val fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData]) extends FetchContext {
+  override def getFetchOffset(part: TopicIdPartition): Option[Long] =
     Option(fetchData.get(part)).map(_.fetchOffset)
 
-  override def foreachPartition(fun: (TopicPartition, Uuid, FetchRequest.PartitionData) => Unit): Unit = {
-    fetchData.forEach((tp, data) => fun(tp, topicIds.get(tp.topic), data))
+  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {
+    fetchData.forEach((tp, data) => fun(tp, data))
   }
 
   override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    FetchResponse.sizeOf(versionId, updates.entrySet.iterator, topicIds)
+    FetchResponse.sizeOf(versionId, updates.entrySet.iterator)
   }
 
   override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
     debug(s"Sessionless fetch context returning ${partitionsToLogString(updates.keySet)}")
-    FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, updates, topicIds)
+    FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, updates)
   }
 }
 
@@ -378,54 +387,39 @@ class SessionlessFetchContext(val fetchData: util.Map[TopicPartition, FetchReque
   * @param reqMetadata        The request metadata.
   * @param fetchData          The partition data from the fetch request.
   * @param usesTopicIds       True if this session should use topic IDs.
-  * @param topicIds           The map from topic names to topic IDs.
   * @param isFromFollower     True if this fetch request came from a follower.
   */
 class FullFetchContext(private val time: Time,
                        private val cache: FetchSessionCache,
                        private val reqMetadata: JFetchMetadata,
-                       private val fetchData: util.Map[TopicPartition, FetchRequest.PartitionData],
+                       private val fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData],
                        private val usesTopicIds: Boolean,
-                       private val topicIds: util.Map[String, Uuid],
                        private val isFromFollower: Boolean) extends FetchContext {
-  override def getFetchOffset(part: TopicPartition): Option[Long] =
+  override def getFetchOffset(part: TopicIdPartition): Option[Long] =
     Option(fetchData.get(part)).map(_.fetchOffset)
 
-  override def foreachPartition(fun: (TopicPartition, Uuid, FetchRequest.PartitionData) => Unit): Unit = {
-    fetchData.forEach((tp, data) => fun(tp, topicIds.get(tp.topic), data))
+  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {
+    fetchData.forEach((tp, data) => fun(tp, data))
   }
 
   override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = {
-    FetchResponse.sizeOf(versionId, updates.entrySet.iterator, topicIds)
+    FetchResponse.sizeOf(versionId, updates.entrySet.iterator)
   }
 
   override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = {
-    var hasInconsistentTopicIds = false
-    def createNewSession: (FetchSession.CACHE_MAP, FetchSession.TOPIC_ID_MAP) = {
+    def createNewSession: FetchSession.CACHE_MAP = {
       val cachedPartitions = new FetchSession.CACHE_MAP(updates.size)
-      val sessionTopicIds = new util.HashMap[String, Uuid](updates.size)
       updates.forEach { (part, respData) =>
-        if (respData.errorCode() == Errors.INCONSISTENT_TOPIC_ID.code()) {
-          info(s"Session encountered an inconsistent topic ID for topicPartition $part.")
-          hasInconsistentTopicIds = true
-        }
         val reqData = fetchData.get(part)
-        val id = topicIds.getOrDefault(part.topic(), Uuid.ZERO_UUID)
-        cachedPartitions.mustAdd(new CachedPartition(part, id, reqData, respData))
-        if (id != Uuid.ZERO_UUID)
-          sessionTopicIds.put(part.topic, id)
+        cachedPartitions.mustAdd(new CachedPartition(part, reqData, respData))
       }
-      (cachedPartitions, sessionTopicIds)
+      cachedPartitions
     }
     val responseSessionId = cache.maybeCreateSession(time.milliseconds(), isFromFollower,
         updates.size, usesTopicIds, () => createNewSession)
-    if (hasInconsistentTopicIds) {
-      FetchResponse.of(Errors.INCONSISTENT_TOPIC_ID, 0, responseSessionId, new FetchSession.RESP_MAP, Collections.emptyMap())
-    } else {
-      debug(s"Full fetch context with session id $responseSessionId returning " +
-        s"${partitionsToLogString(updates.keySet)}")
-      FetchResponse.of(Errors.NONE, 0, responseSessionId, updates, topicIds)
-    }
+    debug(s"Full fetch context with session id $responseSessionId returning " +
+      s"${partitionsToLogString(updates.keySet)}")
+    FetchResponse.of(Errors.NONE, 0, responseSessionId, updates)
   }
 }
 
@@ -435,18 +429,23 @@ class FullFetchContext(private val time: Time,
   * @param time         The clock to use.
   * @param reqMetadata  The request metadata.
   * @param session      The incremental fetch request session.
+  * @param topicNames   A mapping from topic ID to topic name used to resolve partitions already in the session.
   */
 class IncrementalFetchContext(private val time: Time,
                               private val reqMetadata: JFetchMetadata,
-                              private val session: FetchSession) extends FetchContext {
+                              private val session: FetchSession,
+                              private val topicNames: FetchSession.TOPIC_NAME_MAP) extends FetchContext {
 
-  override def getFetchOffset(tp: TopicPartition): Option[Long] = session.getFetchOffset(tp)
+  override def getFetchOffset(tp: TopicIdPartition): Option[Long] = session.getFetchOffset(tp)
 
-  override def foreachPartition(fun: (TopicPartition, Uuid, FetchRequest.PartitionData) => Unit): Unit = {
+  override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {
     // Take the session lock and iterate over all the cached partitions.
     session.synchronized {
       session.partitionMap.forEach { part =>
-        fun(new TopicPartition(part.topic, part.partition), part.topicId, part.reqData)
+        // Try to resolve the partition's topic name if it is not yet known
+        if (session.usesTopicIds)
+          part.maybeResolveUnknownName(topicNames)
+        fun(new TopicIdPartition(part.topicId, new TopicPartition(part.topic, part.partition)), part.reqData)
       }
     }
   }
@@ -457,15 +456,14 @@ class IncrementalFetchContext(private val time: Time,
   private class PartitionIterator(val iter: FetchSession.RESP_MAP_ITER,
                                   val updateFetchContextAndRemoveUnselected: Boolean)
     extends FetchSession.RESP_MAP_ITER {
-    var nextElement: util.Map.Entry[TopicPartition, FetchResponseData.PartitionData] = null
+    var nextElement: util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData] = null
 
     override def hasNext: Boolean = {
       while ((nextElement == null) && iter.hasNext) {
         val element = iter.next()
         val topicPart = element.getKey
         val respData = element.getValue
-        val cachedPart = session.partitionMap.find(new CachedPartition(topicPart,
-          session.sessionTopicIds.getOrDefault(topicPart.topic(), Uuid.ZERO_UUID)))
+        val cachedPart = session.partitionMap.find(new CachedPartition(topicPart))
         val mustRespond = cachedPart.maybeUpdateResponseData(respData, updateFetchContextAndRemoveUnselected)
         if (mustRespond) {
           nextElement = element
@@ -482,7 +480,7 @@ class IncrementalFetchContext(private val time: Time,
       nextElement != null
     }
 
-    override def next(): util.Map.Entry[TopicPartition, FetchResponseData.PartitionData] = {
+    override def next(): util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData] = {
       if (!hasNext) throw new NoSuchElementException
       val element = nextElement
       nextElement = null
@@ -496,10 +494,10 @@ class IncrementalFetchContext(private val time: Time,
     session.synchronized {
       val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch)
       if (session.epoch != expectedEpoch) {
-        FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator, Collections.emptyMap())
+        FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator)
       } else {
         // Pass the partition iterator which updates neither the fetch context nor the partition map.
-        FetchResponse.sizeOf(versionId, new PartitionIterator(updates.entrySet.iterator, false), session.sessionTopicIds)
+        FetchResponse.sizeOf(versionId, new PartitionIterator(updates.entrySet.iterator, false))
       }
     }
   }
@@ -512,26 +510,16 @@ class IncrementalFetchContext(private val time: Time,
       if (session.epoch != expectedEpoch) {
         info(s"Incremental fetch session ${session.id} expected epoch $expectedEpoch, but " +
           s"got ${session.epoch}.  Possible duplicate request.")
-        FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, session.id, new FetchSession.RESP_MAP, Collections.emptyMap())
+        FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, session.id, new FetchSession.RESP_MAP)
       } else {
-        var hasInconsistentTopicIds = false
         // Iterate over the update list using PartitionIterator. This will prune updates which don't need to be sent
-        // It will also set the top-level error to INCONSISTENT_TOPIC_ID if any partitions had this error.
         val partitionIter = new PartitionIterator(updates.entrySet.iterator, true)
         while (partitionIter.hasNext) {
-          val entry = partitionIter.next()
-          if (entry.getValue.errorCode() == Errors.INCONSISTENT_TOPIC_ID.code()) {
-            info(s"Incremental fetch session ${session.id} encountered an inconsistent topic ID for topicPartition ${entry.getKey}.")
-            hasInconsistentTopicIds = true
-          }
-        }
-        if (hasInconsistentTopicIds) {
-          FetchResponse.of(Errors.INCONSISTENT_TOPIC_ID, 0, session.id, new FetchSession.RESP_MAP, Collections.emptyMap())
-        } else {
-          debug(s"Incremental fetch context with session id ${session.id} returning " +
-            s"${partitionsToLogString(updates.keySet)}")
-          FetchResponse.of(Errors.NONE, 0, session.id, updates, session.sessionTopicIds)
+          partitionIter.next()
         }
+        debug(s"Incremental fetch context with session id ${session.id} returning " +
+          s"${partitionsToLogString(updates.keySet)}")
+        FetchResponse.of(Errors.NONE, 0, session.id, updates)
       }
     }
   }
@@ -544,9 +532,9 @@ class IncrementalFetchContext(private val time: Time,
       if (session.epoch != expectedEpoch) {
         info(s"Incremental fetch session ${session.id} expected epoch $expectedEpoch, but " +
           s"got ${session.epoch}.  Possible duplicate request.")
-        FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, throttleTimeMs, session.id, new FetchSession.RESP_MAP, Collections.emptyMap())
+        FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, throttleTimeMs, session.id, new FetchSession.RESP_MAP)
       } else {
-        FetchResponse.of(Errors.NONE, throttleTimeMs, session.id, new FetchSession.RESP_MAP, Collections.emptyMap())
+        FetchResponse.of(Errors.NONE, throttleTimeMs, session.id, new FetchSession.RESP_MAP)
       }
     }
   }
@@ -653,13 +641,13 @@ class FetchSessionCache(private val maxEntries: Int,
                          privileged: Boolean,
                          size: Int,
                          usesTopicIds: Boolean,
-                         createPartitions: () => (FetchSession.CACHE_MAP, FetchSession.TOPIC_ID_MAP)): Int =
+                         createPartitions: () => FetchSession.CACHE_MAP): Int =
   synchronized {
     // If there is room, create a new session entry.
     if ((sessions.size < maxEntries) ||
         tryEvict(privileged, EvictableKey(privileged, size, 0), now)) {
-      val (partitionMap, topicIds) = createPartitions()
-      val session = new FetchSession(newSessionId(), privileged, partitionMap, usesTopicIds, topicIds,
+      val partitionMap = createPartitions()
+      val session = new FetchSession(newSessionId(), privileged, partitionMap, usesTopicIds,
           now, now, JFetchMetadata.nextEpoch(INITIAL_EPOCH))
       debug(s"Created fetch session ${session.toString}")
       sessions.put(session.id, session)
@@ -783,8 +771,8 @@ class FetchManager(private val time: Time,
                  reqMetadata: JFetchMetadata,
                  isFollower: Boolean,
                  fetchData: FetchSession.REQ_MAP,
-                 toForget: util.List[TopicPartition],
-                 topicIds: util.Map[String, Uuid]): FetchContext = {
+                 toForget: util.List[TopicIdPartition],
+                 topicNames: FetchSession.TOPIC_NAME_MAP): FetchContext = {
     val context = if (reqMetadata.isFull) {
       var removedFetchSessionStr = ""
       if (reqMetadata.sessionId != INVALID_SESSION_ID) {
@@ -797,9 +785,9 @@ class FetchManager(private val time: Time,
       val context = if (reqMetadata.epoch == FINAL_EPOCH) {
         // If the epoch is FINAL_EPOCH, don't try to create a new session.
         suffix = " Will not try to create a new session."
-        new SessionlessFetchContext(fetchData, topicIds)
+        new SessionlessFetchContext(fetchData)
       } else {
-        new FullFetchContext(time, cache, reqMetadata, fetchData, reqVersion >= 13, topicIds, isFollower)
+        new FullFetchContext(time, cache, reqMetadata, fetchData, reqVersion >= 13, isFollower)
       }
       debug(s"Created a new full FetchContext with ${partitionsToLogString(fetchData.keySet)}."+
         s"${removedFetchSessionStr}${suffix}")
@@ -822,17 +810,13 @@ class FetchManager(private val time: Time,
                 s", but request version $reqVersion means that we can not.")
               new SessionErrorContext(Errors.FETCH_SESSION_TOPIC_ID_ERROR, reqMetadata)
             } else {
-              val (added, updated, removed, inconsistent) = session.update(fetchData, toForget, reqMetadata, topicIds)
+              val (added, updated, removed) = session.update(fetchData, toForget, reqMetadata)
               if (session.isEmpty) {
                 debug(s"Created a new sessionless FetchContext and closing session id ${session.id}, " +
                   s"epoch ${session.epoch}: after removing ${partitionsToLogString(removed)}, " +
                   s"there are no more partitions left.")
                 cache.remove(session)
-                new SessionlessFetchContext(fetchData, topicIds)
-              } else if (!inconsistent.isEmpty) {
-                debug(s"Session error for session id ${session.id},epoch ${session.epoch}: after finding " +
-                  s"inconsistent topic IDs on partitions: ${partitionsToLogString(inconsistent)}.")
-                new SessionErrorContext(Errors.FETCH_SESSION_TOPIC_ID_ERROR, reqMetadata)
+                new SessionlessFetchContext(fetchData)
               } else {
                 cache.touch(session, time.milliseconds())
                 session.epoch = JFetchMetadata.nextEpoch(session.epoch)
@@ -840,7 +824,7 @@ class FetchManager(private val time: Time,
                   s"epoch ${session.epoch}: added ${partitionsToLogString(added)}, " +
                   s"updated ${partitionsToLogString(updated)}, " +
                   s"removed ${partitionsToLogString(removed)}")
-                new IncrementalFetchContext(time, reqMetadata, session)
+                new IncrementalFetchContext(time, reqMetadata, session, topicNames)
               }
             }
           }
@@ -850,6 +834,6 @@ class FetchManager(private val time: Time,
     context
   }
 
-  def partitionsToLogString(partitions: util.Collection[TopicPartition]): String =
+  def partitionsToLogString(partitions: util.Collection[TopicIdPartition]): String =
     FetchSession.partitionsToLogString(partitions, isTraceEnabled)
 }
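
The dual hashCode/equals on CachedPartition is what lets an entry cached with a null name (a topic
ID the broker could not resolve yet) be found again and repaired later via maybeResolveUnknownName.
A compact sketch of that mechanism under simplified, hypothetical types (the zero UUID stands in
for Uuid.ZERO_UUID):

    import java.util.UUID

    object CachedPartitionSketch {
      val ZeroUuid = new UUID(0L, 0L) // stand-in for Uuid.ZERO_UUID
    }

    final class CachedPartitionSketch(var topic: String, val topicId: UUID, val partition: Int) {
      import CachedPartitionSketch.ZeroUuid

      // ID-based identity for sessions using topic IDs, name-based otherwise.
      override def hashCode: Int =
        if (topicId != ZeroUuid) 31 * partition + topicId.hashCode
        else 31 * partition + topic.hashCode

      override def equals(other: Any): Boolean = other match {
        case that: CachedPartitionSketch =>
          if (topicId != ZeroUuid) partition == that.partition && topicId == that.topicId
          else partition == that.partition && topic == that.topic
        case _ => false
      }

      // Fill in the name once a later metadata snapshot knows this ID.
      def maybeResolve(names: Map[UUID, String]): Unit =
        if (topic == null) topic = names.getOrElse(topicId, null)
    }

    object CachedPartitionSketchDemo extends App {
      val id = UUID.randomUUID()
      val unresolved = new CachedPartitionSketch(null, id, 0)
      val resolved   = new CachedPartitionSketch("foo", id, 0)
      assert(unresolved == resolved) // names are ignored when an ID is present
      unresolved.maybeResolve(Map(id -> "foo"))
      assert(unresolved.topic == "foo")
    }
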
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index e20a36a..a4b38be 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -70,7 +70,7 @@ import org.apache.kafka.common.resource.{Resource, ResourceType}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol}
 import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation}
 import org.apache.kafka.common.utils.{ProducerIdAndEpoch, Time}
-import org.apache.kafka.common.{Node, TopicPartition, Uuid}
+import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.server.authorizer._
 import java.lang.{Long => JLong}
 import java.nio.ByteBuffer
@@ -672,21 +672,14 @@ class KafkaApis(val requestChannel: RequestChannel,
     val versionId = request.header.apiVersion
     val clientId = request.header.clientId
     val fetchRequest = request.body[FetchRequest]
-    val (topicIds, topicNames) =
+    val topicNames =
       if (fetchRequest.version() >= 13)
-        metadataCache.topicIdInfo()
+        metadataCache.topicIdsToNames()
       else
-        (Collections.emptyMap[String, Uuid](), Collections.emptyMap[Uuid, String]())
+        Collections.emptyMap[Uuid, String]()
 
-    // If fetchData or forgottenTopics contain an unknown topic ID, return a top level error.
-    var fetchData: util.Map[TopicPartition, FetchRequest.PartitionData] = null
-    var forgottenTopics: util.List[TopicPartition] = null
-    try {
-      fetchData = fetchRequest.fetchData(topicNames)
-      forgottenTopics = fetchRequest.forgottenTopics(topicNames)
-    } catch {
-      case e: UnknownTopicIdException => throw e
-    }
+    val fetchData = fetchRequest.fetchData(topicNames)
+    val forgottenTopics = fetchRequest.forgottenTopics(topicNames)
 
     val fetchContext = fetchManager.newContext(
       fetchRequest.version,
@@ -694,7 +687,7 @@ class KafkaApis(val requestChannel: RequestChannel,
       fetchRequest.isFromFollower,
       fetchData,
       forgottenTopics,
-      topicIds)
+      topicNames)
 
     val clientMetadata: Option[ClientMetadata] = if (versionId >= 11) {
       // Fetch API version 11 added preferred replica logic
@@ -708,40 +701,41 @@ class KafkaApis(val requestChannel: RequestChannel,
       None
     }
 
-    val erroneous = mutable.ArrayBuffer[(TopicPartition, FetchResponseData.PartitionData)]()
-    val interesting = mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]()
-    val sessionTopicIds = mutable.Map[String, Uuid]()
+    val erroneous = mutable.ArrayBuffer[(TopicIdPartition, FetchResponseData.PartitionData)]()
+    val interesting = mutable.ArrayBuffer[(TopicIdPartition, FetchRequest.PartitionData)]()
     if (fetchRequest.isFromFollower) {
       // The follower must have ClusterAction on ClusterResource in order to fetch partition data.
       if (authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME)) {
-        fetchContext.foreachPartition { (topicPartition, topicId, data) =>
-          sessionTopicIds.put(topicPartition.topic(), topicId)
-          if (!metadataCache.contains(topicPartition))
-            erroneous += topicPartition -> FetchResponse.partitionResponse(topicPartition.partition, Errors.UNKNOWN_TOPIC_OR_PARTITION)
+        fetchContext.foreachPartition { (topicIdPartition, data) =>
+          if (topicIdPartition.topic == null)
+            erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_ID)
+          else if (!metadataCache.contains(topicIdPartition.topicPartition))
+            erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION)
           else
-            interesting += (topicPartition -> data)
+            interesting += topicIdPartition -> data
         }
       } else {
-        fetchContext.foreachPartition { (part, topicId, _) =>
-          sessionTopicIds.put(part.topic(), topicId)
-          erroneous += part -> FetchResponse.partitionResponse(part.partition, Errors.TOPIC_AUTHORIZATION_FAILED)
+        fetchContext.foreachPartition { (topicIdPartition, _) =>
+          erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.TOPIC_AUTHORIZATION_FAILED)
         }
       }
     } else {
       // Regular Kafka consumers need READ permission on each partition they are fetching.
-      val partitionDatas = new mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]
-      fetchContext.foreachPartition { (topicPartition, topicId, partitionData) =>
-        partitionDatas += topicPartition -> partitionData
-        sessionTopicIds.put(topicPartition.topic(), topicId)
-      }
-      val authorizedTopics = authHelper.filterByAuthorized(request.context, READ, TOPIC, partitionDatas)(_._1.topic)
-      partitionDatas.foreach { case (topicPartition, data) =>
-        if (!authorizedTopics.contains(topicPartition.topic))
-          erroneous += topicPartition -> FetchResponse.partitionResponse(topicPartition.partition, Errors.TOPIC_AUTHORIZATION_FAILED)
-        else if (!metadataCache.contains(topicPartition))
-          erroneous += topicPartition -> FetchResponse.partitionResponse(topicPartition.partition, Errors.UNKNOWN_TOPIC_OR_PARTITION)
+      val partitionDatas = new mutable.ArrayBuffer[(TopicIdPartition, FetchRequest.PartitionData)]
+      fetchContext.foreachPartition { (topicIdPartition, partitionData) =>
+        if (topicIdPartition.topic == null)
+          erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_ID)
         else
-          interesting += (topicPartition -> data)
+          partitionDatas += topicIdPartition -> partitionData
+      }
+      val authorizedTopics = authHelper.filterByAuthorized(request.context, READ, TOPIC, partitionDatas)(_._1.topicPartition.topic)
+      partitionDatas.foreach { case (topicIdPartition, data) =>
+        if (!authorizedTopics.contains(topicIdPartition.topic))
+          erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.TOPIC_AUTHORIZATION_FAILED)
+        else if (!metadataCache.contains(topicIdPartition.topicPartition))
+          erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION)
+        else
+          interesting += topicIdPartition -> data
       }
     }
 
@@ -757,13 +751,14 @@ class KafkaApis(val requestChannel: RequestChannel,
       }
     }
 
-    def maybeConvertFetchedData(tp: TopicPartition,
+    def maybeConvertFetchedData(tp: TopicIdPartition,
                                 partitionData: FetchResponseData.PartitionData): FetchResponseData.PartitionData = {
-      val logConfig = replicaManager.getLogConfig(tp)
+      // We will never return a logConfig when the topic is unresolved and the name is null. This is ok since we won't have any records to convert.
+      val logConfig = replicaManager.getLogConfig(tp.topicPartition)
 
       if (logConfig.exists(_.compressionType == ZStdCompressionCodec.name) && versionId < 10) {
         trace(s"Fetching messages is disabled for ZStandard compressed partition $tp. Sending unsupported version response to $clientId.")
-        FetchResponse.partitionResponse(tp.partition, Errors.UNSUPPORTED_COMPRESSION_TYPE)
+        FetchResponse.partitionResponse(tp, Errors.UNSUPPORTED_COMPRESSION_TYPE)
       } else {
         // Down-conversion of fetched records is needed when the on-disk magic value is greater than what is
         // supported by the fetch request version.
@@ -792,7 +787,7 @@ class KafkaApis(val requestChannel: RequestChannel,
             // For fetch requests from clients, check if down-conversion is disabled for the particular partition
             if (!fetchRequest.isFromFollower && !logConfig.forall(_.messageDownConversionEnable)) {
               trace(s"Conversion to message format ${downConvertMagic.get} is disabled for partition $tp. Sending unsupported version response to $clientId.")
-              FetchResponse.partitionResponse(tp.partition, Errors.UNSUPPORTED_VERSION)
+              FetchResponse.partitionResponse(tp, Errors.UNSUPPORTED_VERSION)
             } else {
               try {
                 trace(s"Down converting records from partition $tp to message format version $magic for fetch request from $clientId")
@@ -807,12 +802,12 @@ class KafkaApis(val requestChannel: RequestChannel,
                   .setLastStableOffset(partitionData.lastStableOffset)
                   .setLogStartOffset(partitionData.logStartOffset)
                   .setAbortedTransactions(partitionData.abortedTransactions)
-                  .setRecords(new LazyDownConversionRecords(tp, unconvertedRecords, magic, fetchContext.getFetchOffset(tp).get, time))
+                  .setRecords(new LazyDownConversionRecords(tp.topicPartition, unconvertedRecords, magic, fetchContext.getFetchOffset(tp).get, time))
                   .setPreferredReadReplica(partitionData.preferredReadReplica())
               } catch {
                 case e: UnsupportedCompressionTypeException =>
                   trace("Received unsupported compression type error during down-conversion", e)
-                  FetchResponse.partitionResponse(tp.partition, Errors.UNSUPPORTED_COMPRESSION_TYPE)
+                  FetchResponse.partitionResponse(tp, Errors.UNSUPPORTED_COMPRESSION_TYPE)
               }
             }
           case None =>
@@ -831,9 +826,9 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
 
     // the callback for processing a fetch response, invoked before throttling
-    def processResponseCallback(responsePartitionData: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
-      val partitions = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-      val reassigningPartitions = mutable.Set[TopicPartition]()
+    def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
+      val partitions = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+      val reassigningPartitions = mutable.Set[TopicIdPartition]()
       responsePartitionData.foreach { case (tp, data) =>
         val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull
         val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
@@ -856,10 +851,10 @@ class KafkaApis(val requestChannel: RequestChannel,
 
       def createResponse(throttleTimeMs: Int): FetchResponse = {
         // Down-convert messages for each partition if required
-        val convertedData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+        val convertedData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
         unconvertedFetchResponse.data().responses().forEach { topicResponse =>
           topicResponse.partitions().forEach { unconvertedPartitionData =>
-            val tp = new TopicPartition(topicResponse.topic(), unconvertedPartitionData.partitionIndex())
+            val tp = new TopicIdPartition(topicResponse.topicId, new TopicPartition(topicResponse.topic, unconvertedPartitionData.partitionIndex()))
             val error = Errors.forCode(unconvertedPartitionData.errorCode)
             if (error != Errors.NONE)
               debug(s"Fetch request with correlation id ${request.header.correlationId} from client $clientId " +
@@ -870,12 +865,15 @@ class KafkaApis(val requestChannel: RequestChannel,
 
         // Prepare fetch response from converted data
         val response =
-          FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData, sessionTopicIds.asJava)
+          FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData)
         // record the bytes out metrics only when the response is being sent
-        response.data().responses().forEach { topicResponse =>
-          topicResponse.partitions().forEach { data =>
-            val tp = new TopicPartition(topicResponse.topic(), data.partitionIndex())
-            brokerTopicStats.updateBytesOut(tp.topic, fetchRequest.isFromFollower, reassigningPartitions.contains(tp), FetchResponse.recordsSize(data))
+        response.data.responses.forEach { topicResponse =>
+          topicResponse.partitions.forEach { data =>
+            // If the topic name could not be resolved, there are no bytes out to record.
+            if (topicResponse.topic != null) {
+              val tp = new TopicIdPartition(topicResponse.topicId, new TopicPartition(topicResponse.topic, data.partitionIndex))
+              brokerTopicStats.updateBytesOut(tp.topic, fetchRequest.isFromFollower, reassigningPartitions.contains(tp), FetchResponse.recordsSize(data))
+            }
           }
         }
         response
@@ -894,7 +892,7 @@ class KafkaApis(val requestChannel: RequestChannel,
       if (fetchRequest.isFromFollower) {
         // We've already evaluated against the quota and are good to go. Just need to record it now.
         unconvertedFetchResponse = fetchContext.updateAndGenerateResponseData(partitions)
-        val responseSize = KafkaApis.sizeOfThrottledPartitions(versionId, unconvertedFetchResponse, quotas.leader, sessionTopicIds.asJava)
+        val responseSize = KafkaApis.sizeOfThrottledPartitions(versionId, unconvertedFetchResponse, quotas.leader)
         quotas.leader.record(responseSize)
         val responsePartitionsSize = unconvertedFetchResponse.data().responses().stream().mapToInt(_.partitions().size()).sum()
         trace(s"Sending Fetch response with partitions.size=$responsePartitionsSize, " +
@@ -960,7 +958,6 @@ class KafkaApis(val requestChannel: RequestChannel,
         fetchMaxBytes,
         versionId <= 2,
         interesting,
-        sessionTopicIds.asJava,
         replicationQuota(fetchRequest),
         processResponseCallback,
         fetchRequest.isolationLevel,
@@ -3497,14 +3494,13 @@ object KafkaApis {
   // TODO: remove resolvedResponseData method when sizeOf can take a data object.
   private[server] def sizeOfThrottledPartitions(versionId: Short,
                                                 unconvertedResponse: FetchResponse,
-                                                quota: ReplicationQuotaManager,
-                                                topicIds: util.Map[String, Uuid]): Int = {
-    val responseData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+                                                quota: ReplicationQuotaManager): Int = {
+    val responseData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     unconvertedResponse.data.responses().forEach(topicResponse =>
       topicResponse.partitions().forEach(partition =>
-        responseData.put(new TopicPartition(topicResponse.topic(), partition.partitionIndex()), partition)))
+        responseData.put(new TopicIdPartition(topicResponse.topicId, new TopicPartition(topicResponse.topic(), partition.partitionIndex)), partition)))
     FetchResponse.sizeOf(versionId, responseData.entrySet
-      .iterator.asScala.filter(element => quota.isThrottled(element.getKey)).asJava, topicIds)
+      .iterator.asScala.filter(element => element.getKey.topicPartition.topic != null && quota.isThrottled(element.getKey.topicPartition)).asJava)
   }
 
   // visible for testing
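
A consequence visible in the handleFetchRequest changes: a partition whose topic ID could not be
mapped to a name is no longer a session-killing top-level error but a per-partition
UNKNOWN_TOPIC_ID response. A hedged sketch of that classification step, using stand-in types and a
metadataContains predicate rather than the broker's actual signatures:

    import java.util.UUID

    sealed trait PartitionErrorSketch
    case object UnknownTopicId extends PartitionErrorSketch
    case object UnknownTopicOrPartition extends PartitionErrorSketch

    final case class TipSketch(topicId: UUID, topic: Option[String], partition: Int)

    object FetchClassifySketch {
      // metadataContains stands in for metadataCache.contains(topicPartition).
      def classify(parts: Seq[TipSketch], metadataContains: String => Boolean)
          : (Seq[(TipSketch, PartitionErrorSketch)], Seq[TipSketch]) = {
        val erroneous   = Seq.newBuilder[(TipSketch, PartitionErrorSketch)]
        val interesting = Seq.newBuilder[TipSketch]
        parts.foreach { tip =>
          tip.topic match {
            case None                                  => erroneous += tip -> UnknownTopicId
            case Some(name) if !metadataContains(name) => erroneous += tip -> UnknownTopicOrPartition
            case _                                     => interesting += tip
          }
        }
        (erroneous.result(), interesting.result())
      }
    }
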
diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
index 1984486..2ce33c8 100644
--- a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala
@@ -22,16 +22,15 @@ import kafka.cluster.BrokerEndPoint
 import kafka.log.{LeaderOffsetIncremented, LogAppendInfo}
 import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
 import kafka.server.QuotaFactory.UnboundedQuota
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.errors.KafkaStorageException
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, RequestUtils}
-
 import java.util
-import java.util.{Collections, Optional}
+import java.util.Optional
 import scala.collection.{Map, Seq, Set, mutable}
 import scala.compat.java8.OptionConverters._
 import scala.jdk.CollectionConverters._
@@ -78,20 +77,18 @@ class ReplicaAlterLogDirsThread(name: String,
 
     // We can build the map from the request since it contains topic IDs and names.
     // Only one ID can be associated with a name and vice versa.
-    val topicIds = new mutable.HashMap[String, Uuid]()
     val topicNames = new mutable.HashMap[Uuid, String]()
     request.data.topics.forEach { topic =>
-      topicIds.put(topic.topic, topic.topicId)
       topicNames.put(topic.topicId, topic.topic)
     }
 
 
-    def processResponseCallback(responsePartitionData: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
+    def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
       partitionData = responsePartitionData.map { case (tp, data) =>
         val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull
         val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET)
-        tp -> new FetchResponseData.PartitionData()
-          .setPartitionIndex(tp.partition)
+        tp.topicPartition -> new FetchResponseData.PartitionData()
+          .setPartitionIndex(tp.topicPartition.partition)
           .setErrorCode(data.error.code)
           .setHighWatermark(data.highWatermark)
           .setLastStableOffset(lastStableOffset)
@@ -101,7 +98,6 @@ class ReplicaAlterLogDirsThread(name: String,
       }
     }
 
-    // Will throw UnknownTopicIdException if a topic ID is unknown.
     val fetchData = request.fetchData(topicNames.asJava)
 
     replicaMgr.fetchMessages(
@@ -111,7 +107,6 @@ class ReplicaAlterLogDirsThread(name: String,
       request.maxBytes,
       false,
       fetchData.asScala.toSeq,
-      topicIds.asJava,
       UnboundedQuota,
       processResponseCallback,
       request.isolationLevel,
@@ -277,7 +272,8 @@ class ReplicaAlterLogDirsThread(name: String,
         fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
       else
         Optional.empty[Integer]
-      requestMap.put(tp, new FetchRequest.PartitionData(fetchState.fetchOffset, logStartOffset,
+      val topicId = fetchState.topicId.getOrElse(Uuid.ZERO_UUID)
+      requestMap.put(tp, new FetchRequest.PartitionData(topicId, fetchState.fetchOffset, logStartOffset,
         fetchSize, Optional.of(fetchState.currentLeaderEpoch), lastFetchedEpoch))
     } catch {
       case e: KafkaStorageException =>
@@ -294,8 +290,7 @@ class ReplicaAlterLogDirsThread(name: String,
         ApiKeys.FETCH.latestVersion
       // Set maxWait and minBytes to 0 because the response should return immediately if
       // the future log has caught up with the current log of the partition
-      val requestBuilder = FetchRequest.Builder.forReplica(version, replicaId, 0, 0, requestMap,
-        Collections.singletonMap(tp.topic, fetchState.topicId.getOrElse(Uuid.ZERO_UUID))).setMaxBytes(maxBytes)
+      val requestBuilder = FetchRequest.Builder.forReplica(version, replicaId, 0, 0, requestMap).setMaxBytes(maxBytes)
       Some(ReplicaFetch(requestMap, requestBuilder))
     }
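
With this patch the topic ID travels inside FetchRequest.PartitionData itself rather than in a
side map. A rough sketch of the construction above, with illustrative names and a zero-UUID
fallback mirroring fetchState.topicId.getOrElse(Uuid.ZERO_UUID):

    import java.util.{Optional, UUID}

    // Field order loosely mirrors the new PartitionData; names are illustrative only.
    final case class PartitionDataSketch(topicId: UUID,
                                         fetchOffset: Long,
                                         logStartOffset: Long,
                                         maxBytes: Int,
                                         currentLeaderEpoch: Optional[Integer],
                                         lastFetchedEpoch: Optional[Integer])

    object BuildPartitionDataSketch {
      private val ZeroUuid = new UUID(0L, 0L) // stand-in for Uuid.ZERO_UUID

      // If the fetch state has not learned the topic ID yet, fall back to the zero UUID.
      def build(topicId: Option[UUID], fetchOffset: Long, logStartOffset: Long,
                fetchSize: Int, currentLeaderEpoch: Int): PartitionDataSketch =
        PartitionDataSketch(topicId.getOrElse(ZeroUuid), fetchOffset, logStartOffset,
          fetchSize, Optional.of(Integer.valueOf(currentLeaderEpoch)), Optional.empty[Integer]())
    }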
 
diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
index 97910f2..57d89dc 100644
--- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
+++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala
@@ -224,10 +224,8 @@ class ReplicaFetcherThread(name: String,
     }
     val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse]
     if (!fetchSessionHandler.handleResponse(fetchResponse, clientResponse.requestHeader().apiVersion())) {
-      // If we had a topic ID related error, throw it, otherwise return an empty fetch data map.
-      if (fetchResponse.error == Errors.UNKNOWN_TOPIC_ID ||
-          fetchResponse.error == Errors.FETCH_SESSION_TOPIC_ID_ERROR ||
-          fetchResponse.error == Errors.INCONSISTENT_TOPIC_ID) {
+      // If we had a topic ID related session error, throw it; otherwise return an empty fetch data map.
+      if (fetchResponse.error == Errors.FETCH_SESSION_TOPIC_ID_ERROR) {
         throw Errors.forCode(fetchResponse.error().code()).exception()
       } else {
         Map.empty
@@ -284,7 +282,8 @@ class ReplicaFetcherThread(name: String,
             fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
           else
             Optional.empty[Integer]
-          builder.add(topicPartition, fetchState.topicId.getOrElse(Uuid.ZERO_UUID), new FetchRequest.PartitionData(
+          builder.add(topicPartition, new FetchRequest.PartitionData(
+            fetchState.topicId.getOrElse(Uuid.ZERO_UUID),
             fetchState.fetchOffset,
             logStartOffset,
             fetchSize,
@@ -305,9 +304,10 @@ class ReplicaFetcherThread(name: String,
     } else {
       val version: Short = if (fetchRequestVersion >= 13 && !fetchData.canUseTopicIds) 12 else fetchRequestVersion
       val requestBuilder = FetchRequest.Builder
-        .forReplica(version, replicaId, maxWait, minBytes, fetchData.toSend, fetchData.topicIds)
+        .forReplica(version, replicaId, maxWait, minBytes, fetchData.toSend)
         .setMaxBytes(maxBytes)
-        .toForget(fetchData.toForget)
+        .removed(fetchData.toForget)
+        .replaced(fetchData.toReplace)
         .metadata(fetchData.metadata)
       Some(ReplicaFetch(fetchData.sessionPartitions(), requestBuilder))
     }
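
The version-selection rule above is small but load-bearing: a v13+ session silently degrades to
v12 whenever any partition in the send set still lacks a topic ID, so names are sent instead. A
standalone sketch of the rule with a couple of spot checks (not the fetcher's actual code):

    object FetchVersionSketch extends App {
      // Mirrors: if (fetchRequestVersion >= 13 && !fetchData.canUseTopicIds) 12 else fetchRequestVersion
      def effectiveVersion(requestedVersion: Short, canUseTopicIds: Boolean): Short =
        if (requestedVersion >= 13 && !canUseTopicIds) 12 else requestedVersion

      assert(effectiveVersion(13, canUseTopicIds = false) == 12) // some partition lacks an ID: fall back
      assert(effectiveVersion(13, canUseTopicIds = true)  == 13) // all IDs known: keep v13
      assert(effectiveVersion(12, canUseTopicIds = true)  == 12) // older versions unaffected
    }
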
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 7dbdbd5..1985cb1 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -17,7 +17,6 @@
 package kafka.server
 
 import java.io.File
-import java.util
 import java.util.Optional
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicBoolean
@@ -37,7 +36,7 @@ import kafka.server.metadata.ZkMetadataCache
 import kafka.utils._
 import kafka.utils.Implicits._
 import kafka.zk.KafkaZkClient
-import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicPartition, Uuid}
+import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState
@@ -983,10 +982,9 @@ class ReplicaManager(val config: KafkaConfig,
                     fetchMinBytes: Int,
                     fetchMaxBytes: Int,
                     hardMaxBytesLimit: Boolean,
-                    fetchInfos: Seq[(TopicPartition, PartitionData)],
-                    topicIds: util.Map[String, Uuid],
+                    fetchInfos: Seq[(TopicIdPartition, PartitionData)],
                     quota: ReplicaQuota,
-                    responseCallback: Seq[(TopicPartition, FetchPartitionData)] => Unit,
+                    responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit,
                     isolationLevel: IsolationLevel,
                     clientMetadata: Option[ClientMetadata]): Unit = {
     val isFromFollower = Request.isValidBrokerId(replicaId)
@@ -1000,7 +998,7 @@ class ReplicaManager(val config: KafkaConfig,
 
     // Restrict fetching to leader if request is from follower or from a client with older version (no ClientMetadata)
     val fetchOnlyFromLeader = isFromFollower || (isFromConsumer && clientMetadata.isEmpty)
-    def readFromLog(): Seq[(TopicPartition, LogReadResult)] = {
+    def readFromLog(): Seq[(TopicIdPartition, LogReadResult)] = {
       val result = readFromLocalLog(
         replicaId = replicaId,
         fetchOnlyFromLeader = fetchOnlyFromLeader,
@@ -1008,7 +1006,6 @@ class ReplicaManager(val config: KafkaConfig,
         fetchMaxBytes = fetchMaxBytes,
         hardMaxBytesLimit = hardMaxBytesLimit,
         readPartitionInfo = fetchInfos,
-        topicIds = topicIds,
         quota = quota,
         clientMetadata = clientMetadata)
       if (isFromFollower) updateFollowerFetchState(replicaId, result)
@@ -1021,9 +1018,9 @@ class ReplicaManager(val config: KafkaConfig,
     var bytesReadable: Long = 0
     var errorReadingData = false
     var hasDivergingEpoch = false
-    val logReadResultMap = new mutable.HashMap[TopicPartition, LogReadResult]
-    logReadResults.foreach { case (topicPartition, logReadResult) =>
-      brokerTopicStats.topicStats(topicPartition.topic).totalFetchRequestRate.mark()
+    val logReadResultMap = new mutable.HashMap[TopicIdPartition, LogReadResult]
+    logReadResults.foreach { case (topicIdPartition, logReadResult) =>
+      brokerTopicStats.topicStats(topicIdPartition.topicPartition.topic).totalFetchRequestRate.mark()
       brokerTopicStats.allTopicsStats.totalFetchRequestRate.mark()
 
       if (logReadResult.error != Errors.NONE)
@@ -1031,7 +1028,7 @@ class ReplicaManager(val config: KafkaConfig,
       if (logReadResult.divergingEpoch.nonEmpty)
         hasDivergingEpoch = true
       bytesReadable = bytesReadable + logReadResult.info.records.sizeInBytes
-      logReadResultMap.put(topicPartition, logReadResult)
+      logReadResultMap.put(topicIdPartition, logReadResult)
     }
 
     // respond immediately if 1) fetch request does not want to wait
@@ -1041,21 +1038,21 @@ class ReplicaManager(val config: KafkaConfig,
     //                        5) we found a diverging epoch
     if (timeout <= 0 || fetchInfos.isEmpty || bytesReadable >= fetchMinBytes || errorReadingData || hasDivergingEpoch) {
       val fetchPartitionData = logReadResults.map { case (tp, result) =>
-        val isReassignmentFetch = isFromFollower && isAddingReplica(tp, replicaId)
+        val isReassignmentFetch = isFromFollower && isAddingReplica(tp.topicPartition, replicaId)
         tp -> result.toFetchPartitionData(isReassignmentFetch)
       }
       responseCallback(fetchPartitionData)
     } else {
       // construct the fetch results from the read results
-      val fetchPartitionStatus = new mutable.ArrayBuffer[(TopicPartition, FetchPartitionStatus)]
-      fetchInfos.foreach { case (topicPartition, partitionData) =>
-        logReadResultMap.get(topicPartition).foreach(logReadResult => {
+      val fetchPartitionStatus = new mutable.ArrayBuffer[(TopicIdPartition, FetchPartitionStatus)]
+      fetchInfos.foreach { case (topicIdPartition, partitionData) =>
+        logReadResultMap.get(topicIdPartition).foreach(logReadResult => {
           val logOffsetMetadata = logReadResult.info.fetchOffsetMetadata
-          fetchPartitionStatus += (topicPartition -> FetchPartitionStatus(logOffsetMetadata, partitionData))
+          fetchPartitionStatus += (topicIdPartition -> FetchPartitionStatus(logOffsetMetadata, partitionData))
         })
       }
       val fetchMetadata: SFetchMetadata = SFetchMetadata(fetchMinBytes, fetchMaxBytes, hardMaxBytesLimit,
-        fetchOnlyFromLeader, fetchIsolation, isFromFollower, replicaId, topicIds, fetchPartitionStatus)
+        fetchOnlyFromLeader, fetchIsolation, isFromFollower, replicaId, fetchPartitionStatus)
       val delayedFetch = new DelayedFetch(timeout, fetchMetadata, this, quota, clientMetadata,
         responseCallback)
 
@@ -1077,22 +1074,12 @@ class ReplicaManager(val config: KafkaConfig,
                        fetchIsolation: FetchIsolation,
                        fetchMaxBytes: Int,
                        hardMaxBytesLimit: Boolean,
-                       readPartitionInfo: Seq[(TopicPartition, PartitionData)],
-                       topicIds: util.Map[String, Uuid],
+                       readPartitionInfo: Seq[(TopicIdPartition, PartitionData)],
                        quota: ReplicaQuota,
-                       clientMetadata: Option[ClientMetadata]): Seq[(TopicPartition, LogReadResult)] = {
+                       clientMetadata: Option[ClientMetadata]): Seq[(TopicIdPartition, LogReadResult)] = {
     val traceEnabled = isTraceEnabled
 
-    def topicIdFromSession(topicName: String): Option[Uuid] = {
-      val topicId = topicIds.get(topicName)
-      // if invalid topic ID return None
-      if (topicId == null || topicId == Uuid.ZERO_UUID)
-        None
-      else
-        Some(topicId)
-    }
-
-    def read(tp: TopicPartition, fetchInfo: PartitionData, limitBytes: Int, minOneMessage: Boolean): LogReadResult = {
+    def read(tp: TopicIdPartition, fetchInfo: PartitionData, limitBytes: Int, minOneMessage: Boolean): LogReadResult = {
       val offset = fetchInfo.fetchOffset
       val partitionFetchSize = fetchInfo.maxBytes
       val followerLogStartOffset = fetchInfo.logStartOffset
@@ -1104,11 +1091,12 @@ class ReplicaManager(val config: KafkaConfig,
             s"remaining response limit $limitBytes" +
             (if (minOneMessage) s", ignoring response/partition size limits" else ""))
 
-        val partition = getPartitionOrException(tp)
+        val partition = getPartitionOrException(tp.topicPartition)
         val fetchTimeMs = time.milliseconds
 
         // Check if topic ID from the fetch request/session matches the ID in the log
-        if (!hasConsistentTopicId(topicIdFromSession(partition.topic), partition.topicId))
+        val topicId = if (tp.topicId == Uuid.ZERO_UUID) None else Some(tp.topicId)
+        if (!hasConsistentTopicId(topicId, partition.topicId))
           throw new InconsistentTopicIdException("Topic ID in the fetch session did not match the topic ID in the log.")
 
         // If we are the leader, determine the preferred read-replica
@@ -1206,7 +1194,7 @@ class ReplicaManager(val config: KafkaConfig,
     }
 
     var limitBytes = fetchMaxBytes
-    val result = new mutable.ArrayBuffer[(TopicPartition, LogReadResult)]
+    val result = new mutable.ArrayBuffer[(TopicIdPartition, LogReadResult)]
     var minOneMessage = !hardMaxBytesLimit
     readPartitionInfo.foreach { case (tp, fetchInfo) =>
       val readResult = read(tp, fetchInfo, limitBytes, minOneMessage)
@@ -1814,8 +1802,8 @@ class ReplicaManager(val config: KafkaConfig,
    * fails with any error, follower fetch state is not updated.
    */
   private def updateFollowerFetchState(followerId: Int,
-                                       readResults: Seq[(TopicPartition, LogReadResult)]): Seq[(TopicPartition, LogReadResult)] = {
-    readResults.map { case (topicPartition, readResult) =>
+                                       readResults: Seq[(TopicIdPartition, LogReadResult)]): Seq[(TopicIdPartition, LogReadResult)] = {
+    readResults.map { case (topicIdPartition, readResult) =>
       val updatedReadResult = if (readResult.error != Errors.NONE) {
         debug(s"Skipping update of fetch state for follower $followerId since the " +
           s"log read returned error ${readResult.error}")
@@ -1825,7 +1813,7 @@ class ReplicaManager(val config: KafkaConfig,
           s"log read returned diverging epoch ${readResult.divergingEpoch}")
         readResult
       } else {
-        onlinePartition(topicPartition) match {
+        onlinePartition(topicIdPartition.topicPartition) match {
           case Some(partition) =>
             if (partition.updateFollowerFetchState(followerId,
               followerFetchOffsetMetadata = readResult.info.fetchOffsetMetadata,
@@ -1837,15 +1825,15 @@ class ReplicaManager(val config: KafkaConfig,
               warn(s"Leader $localBrokerId failed to record follower $followerId's position " +
                 s"${readResult.info.fetchOffsetMetadata.messageOffset}, and last sent HW since the replica " +
                 s"is not recognized to be one of the assigned replicas ${partition.assignmentState.replicas.mkString(",")} " +
-                s"for partition $topicPartition. Empty records will be returned for this partition.")
+                s"for partition $topicIdPartition. Empty records will be returned for this partition.")
               readResult.withEmptyFetchInfo
             }
           case None =>
-            warn(s"While recording the replica LEO, the partition $topicPartition hasn't been created.")
+            warn(s"While recording the replica LEO, the partition $topicIdPartition hasn't been created.")
             readResult
         }
       }
-      topicPartition -> updatedReadResult
+      topicIdPartition -> updatedReadResult
     }
   }
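
As a reading aid, the consistency check in readFromLocalLog above reduces to the sketch below. The body of hasConsistentTopicId is an assumption inferred from its call site, not copied from this patch.

    import org.apache.kafka.common.{TopicIdPartition, Uuid}

    // Uuid.ZERO_UUID is the sentinel for "no topic ID" (pre-v13 requests).
    def requestTopicId(tp: TopicIdPartition): Option[Uuid] =
      if (tp.topicId == Uuid.ZERO_UUID) None else Some(tp.topicId)

    // Assumed contract: only two concrete, differing IDs are inconsistent.
    def hasConsistentTopicId(requestId: Option[Uuid], logId: Option[Uuid]): Boolean =
      (requestId, logId) match {
        case (Some(a), Some(b)) => a == b // both sides known: they must match
        case _                  => true   // either side unknown: nothing to compare
      }

A mismatch makes read() throw InconsistentTopicIdException, which then surfaces as a per-partition error rather than failing the whole fetch.
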
 
diff --git a/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala b/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala
index b2bece4..2c74550 100644
--- a/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala
+++ b/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala
@@ -32,7 +32,6 @@ import org.apache.kafka.common.requests.{AbstractRequest, FetchResponse, ListOff
 import org.apache.kafka.common.serialization.StringDeserializer
 import org.apache.kafka.common.utils.{LogContext, Time}
 import org.apache.kafka.common.{Node, TopicPartition, Uuid}
-
 import java.net.SocketTimeoutException
 import java.text.SimpleDateFormat
 import java.util
@@ -40,6 +39,7 @@ import java.util.concurrent.CountDownLatch
 import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 import java.util.regex.{Pattern, PatternSyntaxException}
 import java.util.{Date, Optional, Properties}
+
 import scala.collection.Seq
 import scala.jdk.CollectionConverters._
 
@@ -407,11 +407,11 @@ private class ReplicaFetcher(name: String, sourceBroker: Node, topicPartitions:
 
     val requestMap = new util.LinkedHashMap[TopicPartition, JFetchRequest.PartitionData]
     for (topicPartition <- topicPartitions)
-      requestMap.put(topicPartition, new JFetchRequest.PartitionData(replicaBuffer.getOffset(topicPartition),
+      requestMap.put(topicPartition, new JFetchRequest.PartitionData(topicIds.getOrElse(topicPartition.topic, Uuid.ZERO_UUID), replicaBuffer.getOffset(topicPartition),
         0L, fetchSize, Optional.empty()))
 
     val fetchRequestBuilder = JFetchRequest.Builder.
-      forReplica(ApiKeys.FETCH.latestVersion, Request.DebuggingConsumerId, maxWait, minBytes, requestMap, topicIds.asJava)
+      forReplica(ApiKeys.FETCH.latestVersion, Request.DebuggingConsumerId, maxWait, minBytes, requestMap)
 
     debug("Issuing fetch request ")
 
diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
index fbb6f79..6efb860 100644
--- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
@@ -350,21 +350,24 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
 
   private def createFetchRequest = {
     val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData]
-    partitionMap.put(tp, new requests.FetchRequest.PartitionData(0, 0, 100, Optional.of(27)))
-    requests.FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 100, Int.MaxValue, partitionMap, getTopicIds().asJava).build()
+    partitionMap.put(tp, new requests.FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID),
+      0, 0, 100, Optional.of(27)))
+    requests.FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 100, Int.MaxValue, partitionMap).build()
   }
 
   private def createFetchRequestWithUnknownTopic(id: Uuid, version: Short) = {
     val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData]
-    partitionMap.put(tp, new requests.FetchRequest.PartitionData(0, 0, 100, Optional.of(27)))
-    requests.FetchRequest.Builder.forConsumer(version, 100, Int.MaxValue, partitionMap, Collections.singletonMap(topic, id)).build()
+    partitionMap.put(tp,
+      new requests.FetchRequest.PartitionData(id, 0, 0, 100, Optional.of(27)))
+    requests.FetchRequest.Builder.forConsumer(version, 100, Int.MaxValue, partitionMap).build()
   }
 
   private def createFetchFollowerRequest = {
     val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData]
-    partitionMap.put(tp, new requests.FetchRequest.PartitionData(0, 0, 100, Optional.of(27)))
+    partitionMap.put(tp, new requests.FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID),
+      0, 0, 100, Optional.of(27)))
     val version = ApiKeys.FETCH.latestVersion
-    requests.FetchRequest.Builder.forReplica(version, 5000, 100, Int.MaxValue, partitionMap, getTopicIds().asJava).build()
+    requests.FetchRequest.Builder.forReplica(version, 5000, 100, Int.MaxValue, partitionMap).build()
   }
 
   private def createListOffsetsRequest = {
@@ -712,7 +715,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
         val describeAcls = topicDescribeAcl(topicResource)
         val isAuthorized = describeAcls == acls
         addAndVerifyAcls(describeAcls, topicResource)
-        sendRequestAndVerifyResponseError(request, resources, isAuthorized = isAuthorized,  topicExists = topicExists, topicNames = topicNames)
+        sendRequestAndVerifyResponseError(request, resources, isAuthorized = isAuthorized, topicExists = topicExists, topicNames = topicNames)
         removeAllClientAcls()
       }
 
diff --git a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
index 1d95ee6..7585862 100644
--- a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
+++ b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala
@@ -16,13 +16,12 @@
  */
 package kafka.server
 
-import java.util
-import java.util.{Collections, Optional}
+import java.util.Optional
 
 import scala.collection.Seq
 import kafka.cluster.Partition
 import kafka.log.LogOffsetSnapshot
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, Uuid}
 import org.apache.kafka.common.errors.{FencedLeaderEpochException, NotLeaderOrFollowerException}
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
 import org.apache.kafka.common.protocol.Errors
@@ -39,8 +38,7 @@ class DelayedFetchTest extends EasyMockSupport {
 
   @Test
   def testFetchWithFencedEpoch(): Unit = {
-    val topicPartition = new TopicPartition("topic", 0)
-    val topicIds = Collections.singletonMap("topic", Uuid.randomUuid())
+    val topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), 0, "topic")
     val fetchOffset = 500L
     val logStartOffset = 0L
     val currentLeaderEpoch = Optional.of[Integer](10)
@@ -48,11 +46,11 @@ class DelayedFetchTest extends EasyMockSupport {
 
     val fetchStatus = FetchPartitionStatus(
       startOffsetMetadata = LogOffsetMetadata(fetchOffset),
-      fetchInfo = new FetchRequest.PartitionData(fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch))
-    val fetchMetadata = buildFetchMetadata(replicaId, topicPartition, topicIds, fetchStatus)
+      fetchInfo = new FetchRequest.PartitionData(Uuid.ZERO_UUID, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch))
+    val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus)
 
     var fetchResultOpt: Option[FetchPartitionData] = None
-    def callback(responses: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
+    def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
       fetchResultOpt = Some(responses.head._2)
     }
 
@@ -66,7 +64,7 @@ class DelayedFetchTest extends EasyMockSupport {
 
     val partition: Partition = mock(classOf[Partition])
 
-    EasyMock.expect(replicaManager.getPartitionOrException(topicPartition))
+    EasyMock.expect(replicaManager.getPartitionOrException(topicIdPartition.topicPartition))
         .andReturn(partition)
     EasyMock.expect(partition.fetchOffsetSnapshot(
         currentLeaderEpoch,
@@ -74,7 +72,7 @@ class DelayedFetchTest extends EasyMockSupport {
         .andThrow(new FencedLeaderEpochException("Requested epoch has been fenced"))
     EasyMock.expect(replicaManager.isAddingReplica(EasyMock.anyObject(), EasyMock.anyInt())).andReturn(false)
 
-    expectReadFromReplica(replicaId, topicPartition, topicIds, fetchStatus.fetchInfo, Errors.FENCED_LEADER_EPOCH)
+    expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, Errors.FENCED_LEADER_EPOCH)
 
     replayAll()
 
@@ -88,8 +86,7 @@ class DelayedFetchTest extends EasyMockSupport {
 
   @Test
   def testNotLeaderOrFollower(): Unit = {
-    val topicPartition = new TopicPartition("topic", 0)
-    val topicIds = Collections.singletonMap("topic", Uuid.randomUuid())
+    val topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), 0, "topic")
     val fetchOffset = 500L
     val logStartOffset = 0L
     val currentLeaderEpoch = Optional.of[Integer](10)
@@ -97,11 +94,11 @@ class DelayedFetchTest extends EasyMockSupport {
 
     val fetchStatus = FetchPartitionStatus(
       startOffsetMetadata = LogOffsetMetadata(fetchOffset),
-      fetchInfo = new FetchRequest.PartitionData(fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch))
-    val fetchMetadata = buildFetchMetadata(replicaId, topicPartition, topicIds, fetchStatus)
+      fetchInfo = new FetchRequest.PartitionData(Uuid.ZERO_UUID, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch))
+    val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus)
 
     var fetchResultOpt: Option[FetchPartitionData] = None
-    def callback(responses: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
+    def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
       fetchResultOpt = Some(responses.head._2)
     }
 
@@ -113,9 +110,9 @@ class DelayedFetchTest extends EasyMockSupport {
       clientMetadata = None,
       responseCallback = callback)
 
-    EasyMock.expect(replicaManager.getPartitionOrException(topicPartition))
-      .andThrow(new NotLeaderOrFollowerException(s"Replica for $topicPartition not available"))
-    expectReadFromReplica(replicaId, topicPartition, topicIds, fetchStatus.fetchInfo, Errors.NOT_LEADER_OR_FOLLOWER)
+    EasyMock.expect(replicaManager.getPartitionOrException(topicIdPartition.topicPartition))
+      .andThrow(new NotLeaderOrFollowerException(s"Replica for $topicIdPartition not available"))
+    expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, Errors.NOT_LEADER_OR_FOLLOWER)
     EasyMock.expect(replicaManager.isAddingReplica(EasyMock.anyObject(), EasyMock.anyInt())).andReturn(false)
 
     replayAll()
@@ -127,8 +124,7 @@ class DelayedFetchTest extends EasyMockSupport {
 
   @Test
   def testDivergingEpoch(): Unit = {
-    val topicPartition = new TopicPartition("topic", 0)
-    val topicIds = Collections.singletonMap("topic", Uuid.randomUuid())
+    val topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), 0, "topic")
     val fetchOffset = 500L
     val logStartOffset = 0L
     val currentLeaderEpoch = Optional.of[Integer](10)
@@ -137,11 +133,11 @@ class DelayedFetchTest extends EasyMockSupport {
 
     val fetchStatus = FetchPartitionStatus(
       startOffsetMetadata = LogOffsetMetadata(fetchOffset),
-      fetchInfo = new FetchRequest.PartitionData(fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch))
-    val fetchMetadata = buildFetchMetadata(replicaId, topicPartition, topicIds, fetchStatus)
+      fetchInfo = new FetchRequest.PartitionData(topicIdPartition.topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch))
+    val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus)
 
     var fetchResultOpt: Option[FetchPartitionData] = None
-    def callback(responses: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
+    def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
       fetchResultOpt = Some(responses.head._2)
     }
 
@@ -154,7 +150,7 @@ class DelayedFetchTest extends EasyMockSupport {
       responseCallback = callback)
 
     val partition: Partition = mock(classOf[Partition])
-    EasyMock.expect(replicaManager.getPartitionOrException(topicPartition)).andReturn(partition)
+    EasyMock.expect(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)).andReturn(partition)
     val endOffsetMetadata = LogOffsetMetadata(messageOffset = 500L, segmentBaseOffset = 0L, relativePositionInSegment = 500)
     EasyMock.expect(partition.fetchOffsetSnapshot(
       currentLeaderEpoch,
@@ -162,12 +158,12 @@ class DelayedFetchTest extends EasyMockSupport {
       .andReturn(LogOffsetSnapshot(0L, endOffsetMetadata, endOffsetMetadata, endOffsetMetadata))
     EasyMock.expect(partition.lastOffsetForLeaderEpoch(currentLeaderEpoch, lastFetchedEpoch.get, fetchOnlyFromLeader = false))
       .andReturn(new EpochEndOffset()
-        .setPartition(topicPartition.partition)
+        .setPartition(topicIdPartition.partition)
         .setErrorCode(Errors.NONE.code)
         .setLeaderEpoch(lastFetchedEpoch.get)
         .setEndOffset(fetchOffset - 1))
     EasyMock.expect(replicaManager.isAddingReplica(EasyMock.anyObject(), EasyMock.anyInt())).andReturn(false)
-    expectReadFromReplica(replicaId, topicPartition, topicIds, fetchStatus.fetchInfo, Errors.NONE)
+    expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, Errors.NONE)
     replayAll()
 
     assertTrue(delayedFetch.tryComplete())
@@ -176,8 +172,7 @@ class DelayedFetchTest extends EasyMockSupport {
   }
 
   private def buildFetchMetadata(replicaId: Int,
-                                 topicPartition: TopicPartition,
-                                 topicIds: util.Map[String, Uuid],
+                                 topicIdPartition: TopicIdPartition,
                                  fetchStatus: FetchPartitionStatus): FetchMetadata = {
     FetchMetadata(fetchMinBytes = 1,
       fetchMaxBytes = maxBytes,
@@ -186,13 +181,11 @@ class DelayedFetchTest extends EasyMockSupport {
       fetchIsolation = FetchLogEnd,
       isFromFollower = true,
       replicaId = replicaId,
-      topicIds = topicIds,
-      fetchPartitionStatus = Seq((topicPartition, fetchStatus)))
+      fetchPartitionStatus = Seq((topicIdPartition, fetchStatus)))
   }
 
   private def expectReadFromReplica(replicaId: Int,
-                                    topicPartition: TopicPartition,
-                                    topicIds: util.Map[String, Uuid],
+                                    topicIdPartition: TopicIdPartition,
                                     fetchPartitionData: FetchRequest.PartitionData,
                                     error: Errors): Unit = {
     EasyMock.expect(replicaManager.readFromLocalLog(
@@ -201,11 +194,10 @@ class DelayedFetchTest extends EasyMockSupport {
       fetchIsolation = FetchLogEnd,
       fetchMaxBytes = maxBytes,
       hardMaxBytesLimit = false,
-      readPartitionInfo = Seq((topicPartition, fetchPartitionData)),
-      topicIds = topicIds,
+      readPartitionInfo = Seq((topicIdPartition, fetchPartitionData)),
       clientMetadata = None,
       quota = replicaQuota))
-      .andReturn(Seq((topicPartition, buildReadResult(error))))
+      .andReturn(Seq((topicIdPartition, buildReadResult(error))))
   }
 
   private def buildReadResult(error: Errors): LogReadResult = {
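
The rewritten tests above lean on TopicIdPartition, which pairs a topic ID with the familiar name/partition tuple. A minimal usage sketch, with all values hypothetical:

    import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}

    val tidp = new TopicIdPartition(Uuid.randomUuid(), 0, "topic")
    // topicPartition exposes the name-based view still used by older code paths,
    // e.g. replicaManager.getPartitionOrException(tidp.topicPartition) above.
    val tp: TopicPartition = tidp.topicPartition
    println(s"id=${tidp.topicId} topic=${tp.topic} partition=${tidp.partition}")
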
diff --git a/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala b/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala
index d10bafd..37ac4a2 100644
--- a/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala
+++ b/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala
@@ -73,7 +73,7 @@ class FetchRequestBetweenDifferentIbpTest extends BaseRequestTest {
     producer.send(record2)
 
     consumer.assign(asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1)))
-    val count = consumer.poll(Duration.ofMillis(5000)).count() + consumer.poll(Duration.ofMillis(5000)).count()
+    val count = consumer.poll(Duration.ofMillis(1500)).count() + consumer.poll(Duration.ofMillis(1500)).count()
     assertEquals(2, count)
   }
 
@@ -109,7 +109,7 @@ class FetchRequestBetweenDifferentIbpTest extends BaseRequestTest {
 
     consumer.assign(asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1)))
 
-    val count = consumer.poll(Duration.ofMillis(5000)).count() + consumer.poll(Duration.ofMillis(5000)).count()
+    val count = consumer.poll(Duration.ofMillis(1500)).count() + consumer.poll(Duration.ofMillis(1500)).count()
     assertEquals(2, count)
 
     // Make controller version2
@@ -128,7 +128,7 @@ class FetchRequestBetweenDifferentIbpTest extends BaseRequestTest {
     // Assign this new topic in addition to the old topics.
     consumer.assign(asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1), new TopicPartition(topic2, 0)))
 
-    val count2 = consumer.poll(Duration.ofMillis(5000)).count() + consumer.poll(Duration.ofMillis(5000)).count()
+    val count2 = consumer.poll(Duration.ofMillis(1500)).count() + consumer.poll(Duration.ofMillis(1500)).count()
     assertEquals(2, count2)
   }
 
diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
index 63f972f..148a903 100644
--- a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala
@@ -20,6 +20,7 @@ package kafka.server
 import java.nio.ByteBuffer
 import java.util.Optional
 import java.util.concurrent.atomic.AtomicInteger
+
 import kafka.cluster.BrokerEndPoint
 import kafka.log.LogAppendInfo
 import kafka.message.NoCompressionCodec
@@ -28,9 +29,7 @@ import kafka.server.AbstractFetcherThread.ReplicaFetch
 import kafka.server.AbstractFetcherThread.ResultWithPartitions
 import kafka.utils.Implicits.MapExtensionMethods
 import kafka.utils.TestUtils
-import org.apache.kafka.common.KafkaException
-import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.Uuid
+import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid}
 import org.apache.kafka.common.errors.{FencedLeaderEpochException, UnknownLeaderEpochException, UnknownTopicIdException}
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
@@ -46,7 +45,6 @@ import org.junit.jupiter.api.{BeforeEach, Test}
 import scala.jdk.CollectionConverters._
 import scala.collection.{Map, Set, mutable}
 import scala.util.Random
-
 import scala.collection.mutable.ArrayBuffer
 import scala.compat.java8.OptionConverters._
 
@@ -150,14 +148,11 @@ class AbstractFetcherThreadTest {
     val partition = new TopicPartition("topic", 0)
     val fetchBackOffMs = 250
 
-    class ErrorMockFetcherThread(fetchBackOffMs: Int)
-      extends MockFetcherThread(fetchBackOffMs =  fetchBackOffMs) {
-
+    val fetcher = new MockFetcherThread(fetchBackOffMs = fetchBackOffMs) {
       override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
-         throw new UnknownTopicIdException("Topic ID was unknown as expected for this test")
+        throw new UnknownTopicIdException("Topic ID was unknown as expected for this test")
       }
     }
-    val fetcher = new ErrorMockFetcherThread(fetchBackOffMs = fetchBackOffMs)
 
     fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
     fetcher.addPartitions(Map(partition -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
@@ -186,6 +181,50 @@ class AbstractFetcherThreadTest {
   }
 
   @Test
+  def testPartitionsInError(): Unit = {
+    val partition1 = new TopicPartition("topic1", 0)
+    val partition2 = new TopicPartition("topic2", 0)
+    val partition3 = new TopicPartition("topic3", 0)
+    val fetchBackOffMs = 250
+
+    val fetcher = new MockFetcherThread(fetchBackOffMs = fetchBackOffMs) {
+      override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
+        Map(partition1 -> new FetchData().setErrorCode(Errors.UNKNOWN_TOPIC_ID.code),
+          partition2 -> new FetchData().setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code),
+          partition3 -> new FetchData().setErrorCode(Errors.NONE.code))
+      }
+    }
+
+    fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition1 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
+    fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition2 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
+    fetcher.setReplicaState(partition3, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition3 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0)))
+
+    val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
+      new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
+    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.setLeaderState(partition1, leaderState)
+    fetcher.setLeaderState(partition2, leaderState)
+    fetcher.setLeaderState(partition3, leaderState)
+
+    fetcher.doWork()
+
+    val partition1FetchState = fetcher.fetchState(partition1)
+    val partition2FetchState = fetcher.fetchState(partition2)
+    val partition3FetchState = fetcher.fetchState(partition3)
+    assertTrue(partition1FetchState.isDefined)
+    assertTrue(partition2FetchState.isDefined)
+    assertTrue(partition3FetchState.isDefined)
+
+    // Only the partitions with errors should be delayed.
+    assertTrue(partition1FetchState.get.isDelayed)
+    assertTrue(partition2FetchState.get.isDelayed)
+    assertFalse(partition3FetchState.get.isDelayed)
+  }
+
+  @Test
   def testFencedTruncation(): Unit = {
     val partition = new TopicPartition("topic", 0)
     val fetcher = new MockFetcherThread
@@ -1098,11 +1137,12 @@ class AbstractFetcherThreadTest {
             state.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava
           else
             Optional.empty[Integer]
-          fetchData.put(partition, new FetchRequest.PartitionData(state.fetchOffset, replicaState.logStartOffset,
+          fetchData.put(partition,
+            new FetchRequest.PartitionData(state.topicId.getOrElse(Uuid.ZERO_UUID), state.fetchOffset, replicaState.logStartOffset,
             1024 * 1024, Optional.of[Integer](state.currentLeaderEpoch), lastFetchedEpoch))
         }
       }
-      val fetchRequest = FetchRequest.Builder.forReplica(version, replicaId, 0, 1, fetchData.asJava, topicIds.asJava)
+      val fetchRequest = FetchRequest.Builder.forReplica(version, replicaId, 0, 1, fetchData.asJava)
       val fetchRequestOpt =
         if (fetchData.isEmpty)
           None
diff --git a/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala b/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala
index ea9f41d..526e6b5 100644
--- a/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala
@@ -79,7 +79,7 @@ class BaseClientQuotaManagerTest {
 
   protected def throttle(quotaManager: ClientQuotaManager, user: String, clientId: String, throttleTimeMs: Int,
                          channelThrottlingCallback: ThrottleCallback): Unit = {
-    val (_, request) = buildRequest(FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 1000, new util.HashMap[TopicPartition, PartitionData], Collections.emptyMap()))
+    val (_, request) = buildRequest(FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 1000, new util.HashMap[TopicPartition, PartitionData]))
     quotaManager.throttle(request, channelThrottlingCallback, throttleTimeMs)
   }
 }
diff --git a/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala
index aacd8c1..8ef7424 100644
--- a/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala
+++ b/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala
@@ -19,15 +19,15 @@ package kafka.server
 import kafka.log.LogConfig
 import kafka.utils.TestUtils
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata}
-import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.record.Record
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
 import org.apache.kafka.common.serialization.StringSerializer
 import org.junit.jupiter.api.AfterEach
-
 import java.util
 import java.util.{Optional, Properties}
+
 import scala.collection.Seq
 import scala.jdk.CollectionConverters._
 
@@ -49,8 +49,7 @@ class BaseFetchRequestTest extends BaseRequestTest {
   protected def createFetchRequest(maxResponseBytes: Int, maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition],
                                    offsetMap: Map[TopicPartition, Long],
                                    version: Short): FetchRequest = {
-    val topicIds = getTopicIds().asJava
-    FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, topicPartitions, offsetMap), topicIds)
+    FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, topicPartitions, offsetMap))
       .setMaxBytes(maxResponseBytes).build()
   }
 
@@ -58,7 +57,8 @@ class BaseFetchRequestTest extends BaseRequestTest {
                                    offsetMap: Map[TopicPartition, Long] = Map.empty): util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = {
     val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
     topicPartitions.foreach { tp =>
-      partitionMap.put(tp, new FetchRequest.PartitionData(offsetMap.getOrElse(tp, 0), 0L, maxPartitionBytes,
+      partitionMap.put(tp,
+        new FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID), offsetMap.getOrElse(tp, 0), 0L, maxPartitionBytes,
         Optional.empty()))
     }
     partitionMap
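
Since the topic ID is now the first constructor argument of FetchRequest.PartitionData, building a consumer fetch by hand looks like the sketch below; the values are stand-ins, and Uuid.ZERO_UUID marks a partition whose ID is unknown, as in createPartitionMap above.

    import java.util
    import java.util.Optional
    import org.apache.kafka.common.{TopicPartition, Uuid}
    import org.apache.kafka.common.requests.FetchRequest

    val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
    partitionMap.put(new TopicPartition("foo", 0),
      new FetchRequest.PartitionData(
        Uuid.ZERO_UUID,     // topic ID (unknown here)
        0L,                 // fetch offset
        0L,                 // log start offset
        1024 * 1024,        // max bytes for this partition
        Optional.empty()))  // current leader epoch, if known
    val request = FetchRequest.Builder
      .forConsumer(12, Int.MaxValue, 0, partitionMap)
      .build(12)
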
diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
index 7038678a..07a20f9 100644
--- a/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala
@@ -21,7 +21,7 @@ import java.util.{Optional, Properties}
 import kafka.log.LogConfig
 import kafka.utils.TestUtils
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
-import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
 import org.apache.kafka.common.serialization.StringSerializer
@@ -71,10 +71,11 @@ class FetchRequestDownConversionConfigTest extends BaseRequestTest {
   }
 
   private def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition],
+                                 topicIds: Map[String, Uuid],
                                  offsetMap: Map[TopicPartition, Long] = Map.empty): util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = {
     val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
     topicPartitions.foreach { tp =>
-      partitionMap.put(tp, new FetchRequest.PartitionData(offsetMap.getOrElse(tp, 0), 0L,
+      partitionMap.put(tp, new FetchRequest.PartitionData(topicIds.getOrElse(tp.topic, Uuid.ZERO_UUID), offsetMap.getOrElse(tp, 0), 0L,
         maxPartitionBytes, Optional.empty()))
     }
     partitionMap
@@ -95,7 +96,7 @@ class FetchRequestDownConversionConfigTest extends BaseRequestTest {
     val topicNames = topicIds.map(_.swap)
     topicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get())
     val fetchRequest = FetchRequest.Builder.forConsumer(1, Int.MaxValue, 0, createPartitionMap(1024,
-      topicPartitions), topicIds.asJava).build(1)
+      topicPartitions, topicIds.toMap)).build(1)
     val fetchResponse = sendFetchRequest(topicMap.head._2, fetchRequest)
     val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1)
     topicPartitions.foreach(tp => assertEquals(Errors.UNSUPPORTED_VERSION, Errors.forCode(fetchResponseData.get(tp).errorCode)))
@@ -112,7 +113,7 @@ class FetchRequestDownConversionConfigTest extends BaseRequestTest {
     val topicNames = topicIds.map(_.swap)
     topicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get())
     val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, Int.MaxValue, 0, createPartitionMap(1024,
-      topicPartitions), topicIds.asJava).build()
+      topicPartitions, topicIds.toMap)).build()
     val fetchResponse = sendFetchRequest(topicMap.head._2, fetchRequest)
     val fetchResponseData = fetchResponse.responseData(topicNames.asJava, ApiKeys.FETCH.latestVersion)
     topicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode)))
@@ -129,7 +130,7 @@ class FetchRequestDownConversionConfigTest extends BaseRequestTest {
     val topicNames = topicIds.map(_.swap)
     topicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get())
     val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, Int.MaxValue, 0, createPartitionMap(1024,
-      topicPartitions), topicIds.asJava).build(12)
+      topicPartitions, topicIds.toMap)).build(12)
     val fetchResponse = sendFetchRequest(topicMap.head._2, fetchRequest)
     val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 12)
     topicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode)))
@@ -157,7 +158,7 @@ class FetchRequestDownConversionConfigTest extends BaseRequestTest {
 
     allTopics.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get())
     val fetchRequest = FetchRequest.Builder.forConsumer(1, Int.MaxValue, 0, createPartitionMap(1024,
-      allTopics), topicIds.asJava).build(1)
+      allTopics, topicIds.toMap)).build(1)
     val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
 
     val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1)
@@ -186,7 +187,7 @@ class FetchRequestDownConversionConfigTest extends BaseRequestTest {
 
     allTopicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic, "key", "value")).get())
     val fetchRequest = FetchRequest.Builder.forReplica(1, 1, Int.MaxValue, 0,
-      createPartitionMap(1024, allTopicPartitions), topicIds.asJava).build()
+      createPartitionMap(1024, allTopicPartitions, topicIds.toMap)).build()
     val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
     val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1)
     allTopicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode)))
diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala
index db57ec9..5bf43b3 100644
--- a/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala
@@ -20,7 +20,7 @@ package kafka.server
 import kafka.log.LogConfig
 import kafka.utils.TestUtils
 import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord}
-import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.{TopicPartition, Uuid}
 import org.apache.kafka.common.requests.FetchRequest.PartitionData
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse}
 import org.junit.jupiter.api.Assertions._
@@ -115,7 +115,7 @@ class FetchRequestMaxBytesTest extends BaseRequestTest {
     val response = sendFetchRequest(0,
       FetchRequest.Builder.forConsumer(3, Int.MaxValue, 0,
         Map(testTopicPartition ->
-          new PartitionData(fetchOffset, 0, Integer.MAX_VALUE, Optional.empty())).asJava, getTopicIds().asJava).build(3))
+          new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, Integer.MAX_VALUE, Optional.empty())).asJava).build(3))
     val records = FetchResponse.recordsOrFail(response.responseData(getTopicNames().asJava, 3).get(testTopicPartition)).records()
     assertNotNull(records)
     val recordsList = records.asScala.toList
diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala
index 2719bae..82c990d 100644
--- a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala
@@ -24,7 +24,7 @@ import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.RecordBatch
 import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.serialization.{ByteArraySerializer, StringSerializer}
-import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
+import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
 
@@ -153,7 +153,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition,
       "key", new String(new Array[Byte](maxPartitionBytes + 1)))).get
     val fetchRequest = FetchRequest.Builder.forConsumer(4, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes,
-      Seq(topicPartition)), topicIds).isolationLevel(IsolationLevel.READ_COMMITTED).build(4)
+      Seq(topicPartition))).isolationLevel(IsolationLevel.READ_COMMITTED).build(4)
     val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
     val partitionData = fetchResponse.responseData(topicNames, 4).get(topicPartition)
     assertEquals(Errors.NONE.code, partitionData.errorCode)
@@ -178,14 +178,14 @@ class FetchRequestTest extends BaseFetchRequestTest {
 
     // Send the fetch request to the non-replica and verify the error code
     val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, Int.MaxValue, 0, createPartitionMap(1024,
-      Seq(topicPartition)), topicIds).build()
+      Seq(topicPartition))).build()
     val fetchResponse = sendFetchRequest(nonReplicaId, fetchRequest)
     val partitionData = fetchResponse.responseData(topicNames, ApiKeys.FETCH.latestVersion).get(topicPartition)
     assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, partitionData.errorCode)
 
     // Repeat with request that does not use topic IDs
     val oldFetchRequest = FetchRequest.Builder.forConsumer(12, Int.MaxValue, 0, createPartitionMap(1024,
-      Seq(topicPartition)), topicIds).build()
+      Seq(topicPartition))).build()
     val oldFetchResponse = sendFetchRequest(nonReplicaId, oldFetchRequest)
     val oldPartitionData = oldFetchResponse.responseData(topicNames, 12).get(topicPartition)
     assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, oldPartitionData.errorCode)
@@ -226,9 +226,10 @@ class FetchRequestTest extends BaseFetchRequestTest {
     // Build a fetch request in the middle of the second epoch, but with the first epoch
     val fetchOffset = secondEpochEndOffset + (secondEpochEndOffset - firstEpochEndOffset) / 2
     val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    partitionMap.put(topicPartition, new FetchRequest.PartitionData(fetchOffset, 0L, 1024,
+    partitionMap.put(topicPartition,
+      new FetchRequest.PartitionData(topicIds.getOrDefault(topic, Uuid.ZERO_UUID), fetchOffset, 0L, 1024,
       Optional.of(secondLeaderEpoch), Optional.of(firstLeaderEpoch)))
-    val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap, topicIds).build()
+    val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap).build()
 
     // Validate the expected truncation
     val fetchResponse = sendFetchRequest(secondLeaderId, fetchRequest)
@@ -262,8 +263,9 @@ class FetchRequestTest extends BaseFetchRequestTest {
       val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
       val topicIds = getTopicIds().asJava
       val topicNames = topicIds.asScala.map(_.swap).asJava
-      partitionMap.put(topicPartition, new FetchRequest.PartitionData(0L, 0L, 1024, leaderEpoch))
-      val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap, topicIds).build()
+      partitionMap.put(topicPartition,
+        new FetchRequest.PartitionData(topicIds.getOrDefault(topic, Uuid.ZERO_UUID), 0L, 0L, 1024, leaderEpoch))
+      val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap).build()
       val fetchResponse = sendFetchRequest(brokerId, fetchRequest)
       val partitionData = fetchResponse.responseData(topicNames, version).get(topicPartition)
       assertEquals(error.code, partitionData.errorCode)
@@ -322,11 +324,11 @@ class FetchRequestTest extends BaseFetchRequestTest {
                                        destinationBrokerId: Int,
                                        version: Short): Unit = {
     val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    partitionMap.put(topicPartition, new FetchRequest.PartitionData(0L, 0L, 1024,
-      Optional.of(leaderEpoch)))
     val topicIds = getTopicIds().asJava
     val topicNames = topicIds.asScala.map(_.swap).asJava
-    val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap, topicIds)
+    partitionMap.put(topicPartition, new FetchRequest.PartitionData(topicIds.getOrDefault(topicPartition.topic, Uuid.ZERO_UUID),
+      0L, 0L, 1024, Optional.of(leaderEpoch)))
+    val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap)
       .metadata(JFetchMetadata.INITIAL)
       .build()
     val fetchResponse = sendFetchRequest(destinationBrokerId, fetchRequest)
@@ -336,8 +338,8 @@ class FetchRequestTest extends BaseFetchRequestTest {
                                     sessionFetchEpoch: Int,
                                     leaderEpoch: Optional[Integer]): Unit = {
       val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-      partitionMap.put(topicPartition, new FetchRequest.PartitionData(0L, 0L, 1024, leaderEpoch))
-      val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap, topicIds)
+      partitionMap.put(topicPartition, new FetchRequest.PartitionData(topicIds.getOrDefault(topicPartition.topic, Uuid.ZERO_UUID), 0L, 0L, 1024, leaderEpoch))
+      val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap)
         .metadata(new JFetchMetadata(sessionId, sessionFetchEpoch))
         .build()
       val fetchResponse = sendFetchRequest(destinationBrokerId, fetchRequest)
@@ -383,7 +385,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
 
     def fetch(version: Short, maxPartitionBytes: Int, closeAfterPartialResponse: Boolean): Option[FetchResponse] = {
       val fetchRequest = FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes,
-        Seq(topicPartition)), topicIds).build(version)
+        Seq(topicPartition))).build(version)
 
       val socket = connect(brokerSocketServer(leaderId))
       try {
@@ -453,7 +455,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
       // all batches we are interested in.
       while (batchesReceived < expectedNumBatches) {
         val fetchRequest = FetchRequest.Builder.forConsumer(requestVersion, Int.MaxValue, 0, createPartitionMap(Int.MaxValue,
-          Seq(topicPartition), Map(topicPartition -> currentFetchOffset)), topicIds).build(requestVersion)
+          Seq(topicPartition), Map(topicPartition -> currentFetchOffset))).build(requestVersion)
         val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
 
         // validate response
@@ -506,10 +508,10 @@ class FetchRequestTest extends BaseFetchRequestTest {
   def testCreateIncrementalFetchWithPartitionsInErrorV12(): Unit = {
     def createFetchRequest(topicPartitions: Seq[TopicPartition],
                            metadata: JFetchMetadata,
-                           toForget: Seq[TopicPartition]): FetchRequest =
+                           toForget: Seq[TopicIdPartition]): FetchRequest =
       FetchRequest.Builder.forConsumer(12, Int.MaxValue, 0,
-        createPartitionMap(Integer.MAX_VALUE, topicPartitions, Map.empty), Map[String, Uuid]().asJava)
-        .toForget(toForget.asJava)
+        createPartitionMap(Integer.MAX_VALUE, topicPartitions, Map.empty))
+        .removed(toForget.asJava)
         .metadata(metadata)
         .build()
     val foo0 = new TopicPartition("foo", 0)
@@ -563,28 +565,44 @@ class FetchRequestTest extends BaseFetchRequestTest {
    */
   @Test
   def testFetchWithPartitionsWithIdError(): Unit = {
-    def createFetchRequest(topicPartitions: Seq[TopicPartition],
+    def createFetchRequest(fetchData: util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData],
                            metadata: JFetchMetadata,
-                           toForget: Seq[TopicPartition],
-                           topicIds: scala.collection.Map[String, Uuid]): FetchRequest = {
-      FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), Int.MaxValue, 0,
-        createPartitionMap(Integer.MAX_VALUE, topicPartitions, Map.empty), topicIds.asJava)
-        .toForget(toForget.asJava)
+                           toForget: Seq[TopicIdPartition]): FetchRequest = {
+      FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), Int.MaxValue, 0, fetchData)
+        .removed(toForget.asJava)
         .metadata(metadata)
         .build()
     }
+
     val foo0 = new TopicPartition("foo", 0)
     val foo1 = new TopicPartition("foo", 1)
     createTopic("foo", Map(0 -> List(0, 1), 1 -> List(0, 2)))
     val topicIds = getTopicIds()
-    val topicIDsWithUnknown = topicIds ++ Map("bar" -> Uuid.randomUuid())
+    val topicIdsWithUnknown = topicIds ++ Map("bar" -> Uuid.randomUuid())
     val bar0 = new TopicPartition("bar", 0)
-    val req1 = createFetchRequest(List(foo0, foo1, bar0), JFetchMetadata.INITIAL, Nil, topicIDsWithUnknown)
+
+    def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition],
+                           offsetMap: Map[TopicPartition, Long]): util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = {
+      val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+      topicPartitions.foreach { tp =>
+        partitionMap.put(tp,
+          new FetchRequest.PartitionData(topicIdsWithUnknown.getOrElse(tp.topic, Uuid.ZERO_UUID), offsetMap.getOrElse(tp, 0),
+            0L, maxPartitionBytes, Optional.empty()))
+      }
+      partitionMap
+    }
+
+    val req1 = createFetchRequest(createPartitionMap(Integer.MAX_VALUE, List(foo0, foo1, bar0), Map.empty), JFetchMetadata.INITIAL, Nil)
     val resp1 = sendFetchRequest(0, req1)
-    assertEquals(Errors.UNKNOWN_TOPIC_ID, resp1.error())
-    val topicNames1 = topicIDsWithUnknown.map(_.swap).asJava
+    assertEquals(Errors.NONE, resp1.error())
+    val topicNames1 = topicIdsWithUnknown.map(_.swap).asJava
     val responseData1 = resp1.responseData(topicNames1, ApiKeys.FETCH.latestVersion())
-    assertEquals(0, responseData1.size())
+    assertTrue(responseData1.containsKey(foo0))
+    assertTrue(responseData1.containsKey(foo1))
+    assertTrue(responseData1.containsKey(bar0))
+    assertEquals(Errors.NONE.code, responseData1.get(foo0).errorCode)
+    assertEquals(Errors.NONE.code, responseData1.get(foo1).errorCode)
+    assertEquals(Errors.UNKNOWN_TOPIC_ID.code, responseData1.get(bar0).errorCode)
   }
 
   @Test
@@ -609,7 +627,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
 
     // fetch request with version below v10: UNSUPPORTED_COMPRESSION_TYPE error occurs
     val req0 = new FetchRequest.Builder(0, 9, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800).build()
 
     val res0 = sendFetchRequest(leaderId, req0)
@@ -618,7 +636,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
 
     // fetch request with version 10: works fine!
     val req1= new FetchRequest.Builder(0, 10, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800).build()
     val res1 = sendFetchRequest(leaderId, req1)
     val data1 = res1.responseData(topicNames, 10).get(topicPartition)
@@ -626,7 +644,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     assertEquals(3, records(data1).size)
 
     val req2 = new FetchRequest.Builder(ApiKeys.FETCH.latestVersion(), ApiKeys.FETCH.latestVersion(), -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800).build()
     val res2 = sendFetchRequest(leaderId, req2)
     val data2 = res2.responseData(topicNames, ApiKeys.FETCH.latestVersion()).get(topicPartition)
@@ -635,12 +653,6 @@ class FetchRequestTest extends BaseFetchRequestTest {
   }
 
   @Test
-  def testPartitionDataEquals(): Unit = {
-    assertEquals(new FetchRequest.PartitionData(300, 0L, 300, Optional.of(300)),
-    new FetchRequest.PartitionData(300, 0L, 300, Optional.of(300)))
-  }
-
-  @Test
   def testZStdCompressedRecords(): Unit = {
     // Producer compressed topic
     val topicConfig = Map(LogConfig.CompressionTypeProp -> ProducerCompressionCodec.name)
@@ -671,7 +683,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     // gzip compressed record is returned with down-conversion.
     // zstd compressed record raises UNSUPPORTED_COMPRESSION_TYPE error.
     val req0 = new FetchRequest.Builder(0, 1, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800)
       .build()
 
@@ -681,7 +693,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     assertEquals(1, records(data0).size)
 
     val req1 = new FetchRequest.Builder(0, 1, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map(topicPartition -> 1L)), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map(topicPartition -> 1L)))
       .setMaxBytes(800).build()
 
     val res1 = sendFetchRequest(leaderId, req1)
@@ -692,7 +704,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     // gzip compressed record is returned with down-conversion.
     // zstd compressed record raises UNSUPPORTED_COMPRESSION_TYPE error.
     val req2 = new FetchRequest.Builder(2, 3, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800).build()
 
     val res2 = sendFetchRequest(leaderId, req2)
@@ -701,7 +713,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     assertEquals(1, records(data2).size)
 
     val req3 = new FetchRequest.Builder(0, 1, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map(topicPartition -> 1L)), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map(topicPartition -> 1L)))
       .setMaxBytes(800).build()
 
     val res3 = sendFetchRequest(leaderId, req3)
@@ -710,7 +722,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
 
     // fetch request with version 10: works fine!
     val req4 = new FetchRequest.Builder(0, 10, -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800).build()
     val res4 = sendFetchRequest(leaderId, req4)
     val data4 = res4.responseData(topicNames, 10).get(topicPartition)
@@ -718,7 +730,7 @@ class FetchRequestTest extends BaseFetchRequestTest {
     assertEquals(3, records(data4).size)
 
     val req5 = new FetchRequest.Builder(0, ApiKeys.FETCH.latestVersion(), -1, Int.MaxValue, 0,
-      createPartitionMap(300, Seq(topicPartition), Map.empty), topicIds)
+      createPartitionMap(300, Seq(topicPartition), Map.empty))
       .setMaxBytes(800).build()
     val res5 = sendFetchRequest(leaderId, req5)
     val data5 = res5.responseData(topicNames, ApiKeys.FETCH.latestVersion()).get(topicPartition)
diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala
index 84b8d35..2f78b9d 100644
--- a/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala
@@ -25,6 +25,7 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
 import org.junit.jupiter.api.Test
 
 import java.util.Properties
+
 import scala.annotation.nowarn
 import scala.collection.Seq
 import scala.jdk.CollectionConverters._
@@ -56,7 +57,7 @@ class FetchRequestWithLegacyMessageFormatTest extends BaseFetchRequestTest {
       "key", new String(new Array[Byte](maxPartitionBytes + 1)))).get
     val fetchVersion: Short = 2
     val fetchRequest = FetchRequest.Builder.forConsumer(fetchVersion, Int.MaxValue, 0,
-      createPartitionMap(maxPartitionBytes, Seq(topicPartition)), topicIds).build(fetchVersion)
+      createPartitionMap(maxPartitionBytes, Seq(topicPartition))).build(fetchVersion)
     val fetchResponse = sendFetchRequest(leaderId, fetchRequest)
     val partitionData = fetchResponse.responseData(topicNames, fetchVersion).get(topicPartition)
     assertEquals(Errors.NONE.code, partitionData.errorCode)
diff --git a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
index d80504bf..538a061 100755
--- a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
+++ b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala
@@ -17,17 +17,19 @@
 package kafka.server
 
 import kafka.utils.MockTime
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.message.FetchResponseData
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.CompressionType
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.record.SimpleRecord
 import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INVALID_SESSION_ID}
-import org.apache.kafka.common.requests.{FetchRequest, FetchMetadata => JFetchMetadata}
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.utils.Utils
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{Test, Timeout}
+import org.junit.jupiter.params.ParameterizedTest
+import org.junit.jupiter.params.provider.{Arguments, MethodSource, ValueSource}
 
 import scala.jdk.CollectionConverters._
 import java.util
@@ -57,14 +59,12 @@ class FetchSessionTest {
     assertEquals(sessionIds.size, cache.size)
   }
 
-  private def dummyCreate(size: Int): (FetchSession.CACHE_MAP, FetchSession.TOPIC_ID_MAP) = {
+  private def dummyCreate(size: Int): FetchSession.CACHE_MAP = {
     val cacheMap = new FetchSession.CACHE_MAP(size)
-    val topicIds = new util.HashMap[String, Uuid]()
-    topicIds.put("topic", Uuid.randomUuid())
     for (i <- 0 until size) {
-      cacheMap.add(new CachedPartition("test", i, topicIds.get("test")))
+      cacheMap.add(new CachedPartition("test", Uuid.randomUuid(), i))
     }
-    (cacheMap, topicIds)
+    cacheMap
   }
 
   @Test
@@ -133,22 +133,21 @@ class FetchSessionTest {
     assertEquals(3, cache.totalPartitions)
   }
 
-  private val EMPTY_PART_LIST = Collections.unmodifiableList(new util.ArrayList[TopicPartition]())
+  private val EMPTY_PART_LIST = Collections.unmodifiableList(new util.ArrayList[TopicIdPartition]())
 
   def createRequest(metadata: JFetchMetadata,
                     fetchData: util.Map[TopicPartition, FetchRequest.PartitionData],
-                    topicIds: util.Map[String, Uuid],
-                    toForget: util.List[TopicPartition], isFromFollower: Boolean): FetchRequest = {
-    new FetchRequest.Builder(ApiKeys.FETCH.latestVersion, ApiKeys.FETCH.latestVersion, if (isFromFollower) 1 else FetchRequest.CONSUMER_REPLICA_ID,
-      0, 0, fetchData, topicIds).metadata(metadata).toForget(toForget).build
+                    toForget: util.List[TopicIdPartition], isFromFollower: Boolean,
+                    version: Short = ApiKeys.FETCH.latestVersion): FetchRequest = {
+    new FetchRequest.Builder(version, version, if (isFromFollower) 1 else FetchRequest.CONSUMER_REPLICA_ID,
+      0, 0, fetchData).metadata(metadata).removed(toForget).build
   }
 
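+  // Builds a v12 request, the latest fetch version that predates topic IDs.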
   def createRequestWithoutTopicIds(metadata: JFetchMetadata,
                     fetchData: util.Map[TopicPartition, FetchRequest.PartitionData],
-                    topicIds: util.Map[String, Uuid],
-                    toForget: util.List[TopicPartition], isFromFollower: Boolean): FetchRequest = {
+                    toForget: util.List[TopicIdPartition], isFromFollower: Boolean): FetchRequest = {
     new FetchRequest.Builder(12, 12, if (isFromFollower) 1 else FetchRequest.CONSUMER_REPLICA_ID,
-      0, 0, fetchData, topicIds).metadata(metadata).toForget(toForget).build
+      0, 0, fetchData).metadata(metadata).removed(toForget).build
   }
 
   @Test
@@ -157,38 +156,38 @@ class FetchSessionTest {
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
 
-    val tp0 = new TopicPartition("foo", 0)
-    val tp1 = new TopicPartition("foo", 1)
-    val tp2 = new TopicPartition("bar", 1)
     val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid()).asJava
+    val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1))
     val topicNames = topicIds.asScala.map(_.swap).asJava
 
-    def cachedLeaderEpochs(context: FetchContext): Map[TopicPartition, Optional[Integer]] = {
-      val mapBuilder = Map.newBuilder[TopicPartition, Optional[Integer]]
-      context.foreachPartition((tp, _, data) => mapBuilder += tp -> data.currentLeaderEpoch)
+    def cachedLeaderEpochs(context: FetchContext): Map[TopicIdPartition, Optional[Integer]] = {
+      val mapBuilder = Map.newBuilder[TopicIdPartition, Optional[Integer]]
+      context.foreachPartition((tp, data) => mapBuilder += tp -> data.currentLeaderEpoch)
       mapBuilder.result()
     }
 
     val requestData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    requestData1.put(tp0, new FetchRequest.PartitionData(0, 0, 100, Optional.empty()))
-    requestData1.put(tp1, new FetchRequest.PartitionData(10, 0, 100, Optional.of(1)))
-    requestData1.put(tp2, new FetchRequest.PartitionData(10, 0, 100, Optional.of(2)))
+    requestData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.empty()))
+    requestData1.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.of(1)))
+    requestData1.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(2)))
 
-    val request1 = createRequest(JFetchMetadata.INITIAL, requestData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, requestData1, EMPTY_PART_LIST, false)
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     val epochs1 = cachedLeaderEpochs(context1)
     assertEquals(Optional.empty(), epochs1(tp0))
     assertEquals(Optional.of(1), epochs1(tp1))
     assertEquals(Optional.of(2), epochs1(tp2))
 
-    val response = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val response = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     response.put(tp0, new FetchResponseData.PartitionData()
       .setPartitionIndex(tp0.partition)
       .setHighWatermark(100)
@@ -209,14 +208,14 @@ class FetchSessionTest {
 
     // With no changes, the cached epochs should remain the same
     val requestData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request2 = createRequest(new JFetchMetadata(sessionId, 1), requestData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(new JFetchMetadata(sessionId, 1), requestData2, EMPTY_PART_LIST, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     val epochs2 = cachedLeaderEpochs(context2)
     assertEquals(Optional.empty(), epochs1(tp0))
@@ -226,18 +225,18 @@ class FetchSessionTest {
 
     // Now verify we can change the leader epoch and the context is updated
     val requestData3 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    requestData3.put(tp0, new FetchRequest.PartitionData(0, 0, 100, Optional.of(6)))
-    requestData3.put(tp1, new FetchRequest.PartitionData(10, 0, 100, Optional.empty()))
-    requestData3.put(tp2, new FetchRequest.PartitionData(10, 0, 100, Optional.of(3)))
+    requestData3.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.of(6)))
+    requestData3.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.empty()))
+    requestData3.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(3)))
 
-    val request3 = createRequest(new JFetchMetadata(sessionId, 2), requestData3, topicIds, EMPTY_PART_LIST, false)
+    val request3 = createRequest(new JFetchMetadata(sessionId, 2), requestData3, EMPTY_PART_LIST, false)
     val context3 = fetchManager.newContext(
       request3.version,
       request3.metadata,
       request3.isFromFollower,
       request3.fetchData(topicNames),
       request3.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     val epochs3 = cachedLeaderEpochs(context3)
     assertEquals(Optional.of(6), epochs3(tp0))
@@ -251,44 +250,44 @@ class FetchSessionTest {
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
 
-    val tp0 = new TopicPartition("foo", 0)
-    val tp1 = new TopicPartition("foo", 1)
-    val tp2 = new TopicPartition("bar", 1)
     val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid()).asJava
     val topicNames = topicIds.asScala.map(_.swap).asJava
+    val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1))
 
-    def cachedLeaderEpochs(context: FetchContext): Map[TopicPartition, Optional[Integer]] = {
-      val mapBuilder = Map.newBuilder[TopicPartition, Optional[Integer]]
-      context.foreachPartition((tp, _, data) => mapBuilder += tp -> data.currentLeaderEpoch)
+    def cachedLeaderEpochs(context: FetchContext): Map[TopicIdPartition, Optional[Integer]] = {
+      val mapBuilder = Map.newBuilder[TopicIdPartition, Optional[Integer]]
+      context.foreachPartition((tp, data) => mapBuilder += tp -> data.currentLeaderEpoch)
       mapBuilder.result()
     }
 
-    def cachedLastFetchedEpochs(context: FetchContext): Map[TopicPartition, Optional[Integer]] = {
-      val mapBuilder = Map.newBuilder[TopicPartition, Optional[Integer]]
-      context.foreachPartition((tp, _, data) => mapBuilder += tp -> data.lastFetchedEpoch)
+    def cachedLastFetchedEpochs(context: FetchContext): Map[TopicIdPartition, Optional[Integer]] = {
+      val mapBuilder = Map.newBuilder[TopicIdPartition, Optional[Integer]]
+      context.foreachPartition((tp, data) => mapBuilder += tp -> data.lastFetchedEpoch)
       mapBuilder.result()
     }
 
     val requestData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    requestData1.put(tp0, new FetchRequest.PartitionData(0, 0, 100, Optional.empty[Integer], Optional.empty[Integer]))
-    requestData1.put(tp1, new FetchRequest.PartitionData(10, 0, 100, Optional.of(1), Optional.empty[Integer]))
-    requestData1.put(tp2, new FetchRequest.PartitionData(10, 0, 100, Optional.of(2), Optional.of(1)))
+    requestData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.empty[Integer], Optional.empty[Integer]))
+    requestData1.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.of(1), Optional.empty[Integer]))
+    requestData1.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(2), Optional.of(1)))
 
-    val request1 = createRequest(JFetchMetadata.INITIAL, requestData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, requestData1, EMPTY_PART_LIST, false)
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.of(1), tp2 -> Optional.of(2)),
       cachedLeaderEpochs(context1))
     assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.empty, tp2 -> Optional.of(1)),
       cachedLastFetchedEpochs(context1))
 
-    val response = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val response = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     response.put(tp0, new FetchResponseData.PartitionData()
       .setPartitionIndex(tp0.partition)
       .setHighWatermark(100)
@@ -309,14 +308,14 @@ class FetchSessionTest {
 
     // With no changes, the cached epochs should remain the same
     val requestData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request2 = createRequest(new JFetchMetadata(sessionId, 1), requestData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(new JFetchMetadata(sessionId, 1), requestData2, EMPTY_PART_LIST, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.of(1), tp2 -> Optional.of(2)), cachedLeaderEpochs(context2))
     assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.empty, tp2 -> Optional.of(1)),
@@ -325,18 +324,18 @@ class FetchSessionTest {
 
     // Now verify we can change the leader epoch and the context is updated
     val requestData3 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    requestData3.put(tp0, new FetchRequest.PartitionData(0, 0, 100, Optional.of(6), Optional.of(5)))
-    requestData3.put(tp1, new FetchRequest.PartitionData(10, 0, 100, Optional.empty[Integer], Optional.empty[Integer]))
-    requestData3.put(tp2, new FetchRequest.PartitionData(10, 0, 100, Optional.of(3), Optional.of(3)))
+    requestData3.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.of(6), Optional.of(5)))
+    requestData3.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.empty[Integer], Optional.empty[Integer]))
+    requestData3.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(3), Optional.of(3)))
 
-    val request3 = createRequest(new JFetchMetadata(sessionId, 2), requestData3, topicIds, EMPTY_PART_LIST, false)
+    val request3 = createRequest(new JFetchMetadata(sessionId, 2), requestData3, EMPTY_PART_LIST, false)
     val context3 = fetchManager.newContext(
       request3.version,
       request3.metadata,
       request3.isFromFollower,
       request3.fetchData(topicNames),
       request3.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(Map(tp0 -> Optional.of(6), tp1 -> Optional.empty, tp2 -> Optional.of(3)),
       cachedLeaderEpochs(context3))
@@ -351,52 +350,56 @@ class FetchSessionTest {
     val fetchManager = new FetchManager(time, cache)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
+    val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 0))
+    val tp3 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1))
 
     // Verify that SESSIONLESS requests get a SessionlessFetchContext
-    val request = createRequest(JFetchMetadata.LEGACY, new util.HashMap[TopicPartition, FetchRequest.PartitionData](), topicIds, EMPTY_PART_LIST, true)
+    val request = createRequest(JFetchMetadata.LEGACY, new util.HashMap[TopicPartition, FetchRequest.PartitionData](), EMPTY_PART_LIST, true)
     val context = fetchManager.newContext(
       request.version,
       request.metadata,
       request.isFromFollower,
       request.fetchData(topicNames),
       request.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[SessionlessFetchContext], context.getClass)
 
     // Create a new fetch session with a FULL fetch request
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData2.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData2.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100,
       Optional.empty()))
-    reqData2.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    reqData2.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100,
       Optional.empty()))
-    val request2 = createRequest(JFetchMetadata.INITIAL, reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(JFetchMetadata.INITIAL, reqData2, EMPTY_PART_LIST, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context2.getClass)
     val reqData2Iter = reqData2.entrySet().iterator()
-    context2.foreachPartition((topicPart, topicId, data) => {
+    context2.foreachPartition((topicIdPart, data) => {
       val entry = reqData2Iter.next()
-      assertEquals(entry.getKey, topicPart)
-      assertEquals(topicIds.get(entry.getKey.topic()), topicId)
+      assertEquals(entry.getKey, topicIdPart.topicPartition)
+      assertEquals(topicIds.get(entry.getKey.topic), topicIdPart.topicId)
       assertEquals(entry.getValue, data)
     })
-    assertEquals(0, context2.getFetchOffset(new TopicPartition("foo", 0)).get)
-    assertEquals(10, context2.getFetchOffset(new TopicPartition("foo", 1)).get)
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData2.put(new TopicPartition("foo", 0),
+    assertEquals(0, context2.getFetchOffset(tp0).get)
+    assertEquals(10, context2.getFetchOffset(tp1).get)
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData2.put(tp0,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData2.put(new TopicPartition("foo", 1),
+    respData2.put(tp1,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
@@ -405,69 +408,69 @@ class FetchSessionTest {
     val resp2 = context2.updateAndGenerateResponseData(respData2)
     assertEquals(Errors.NONE, resp2.error())
     assertTrue(resp2.sessionId() != INVALID_SESSION_ID)
-    assertEquals(respData2, resp2.responseData(topicNames, request2.version))
+    assertEquals(respData2.asScala.map { case (tp, data) => (tp.topicPartition, data) }.toMap.asJava, resp2.responseData(topicNames, request2.version))
 
     // Test trying to create a new session with an invalid epoch
-    val request3 = createRequest(new JFetchMetadata(resp2.sessionId(), 5), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request3 = createRequest(new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false)
     val context3 = fetchManager.newContext(
       request3.version,
       request3.metadata,
       request3.isFromFollower,
       request3.fetchData(topicNames),
       request3.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[SessionErrorContext], context3.getClass)
     assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH,
       context3.updateAndGenerateResponseData(respData2).error())
 
     // Test trying to create a new session with a non-existent session id
-    val request4 = createRequest(new JFetchMetadata(resp2.sessionId() + 1, 1), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request4 = createRequest(new JFetchMetadata(resp2.sessionId() + 1, 1), reqData2, EMPTY_PART_LIST, false)
     val context4 = fetchManager.newContext(
       request4.version,
       request4.metadata,
       request4.isFromFollower,
       request4.fetchData(topicNames),
       request4.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(Errors.FETCH_SESSION_ID_NOT_FOUND,
       context4.updateAndGenerateResponseData(respData2).error())
 
     // Continue the first fetch session we created.
     val reqData5 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request5 = createRequest( new JFetchMetadata(resp2.sessionId(), 1), reqData5, topicIds, EMPTY_PART_LIST, false)
+    val request5 = createRequest(new JFetchMetadata(resp2.sessionId(), 1), reqData5, EMPTY_PART_LIST, false)
     val context5 = fetchManager.newContext(
       request5.version,
       request5.metadata,
       request5.isFromFollower,
       request5.fetchData(topicNames),
       request5.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[IncrementalFetchContext], context5.getClass)
     val reqData5Iter = reqData2.entrySet().iterator()
-    context5.foreachPartition((topicPart, topicId, data) => {
+    context5.foreachPartition((topicIdPart, data) => {
       val entry = reqData5Iter.next()
-      assertEquals(entry.getKey, topicPart)
-      assertEquals(topicIds.get(entry.getKey.topic()), topicId)
+      assertEquals(entry.getKey, topicIdPart.topicPartition)
+      assertEquals(topicIds.get(entry.getKey.topic()), topicIdPart.topicId)
       assertEquals(entry.getValue, data)
     })
-    assertEquals(10, context5.getFetchOffset(new TopicPartition("foo", 1)).get)
+    assertEquals(10, context5.getFetchOffset(tp1).get)
     val resp5 = context5.updateAndGenerateResponseData(respData2)
     assertEquals(Errors.NONE, resp5.error())
     assertEquals(resp2.sessionId(), resp5.sessionId())
     assertEquals(0, resp5.responseData(topicNames, request5.version).size())
 
     // Test setting an invalid fetch session epoch.
-    val request6 = createRequest( new JFetchMetadata(resp2.sessionId(), 5), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request6 = createRequest(new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false)
     val context6 = fetchManager.newContext(
       request6.version,
       request6.metadata,
       request6.isFromFollower,
       request6.fetchData(topicNames),
       request6.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[SessionErrorContext], context6.getClass)
     assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH,
@@ -475,14 +478,14 @@ class FetchSessionTest {
 
     // Test generating a throttled response for the incremental fetch session
     val reqData7 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request7 = createRequest( new JFetchMetadata(resp2.sessionId(), 2), reqData7, topicIds, EMPTY_PART_LIST, false)
+    val request7 = createRequest(new JFetchMetadata(resp2.sessionId(), 2), reqData7, EMPTY_PART_LIST, false)
     val context7 = fetchManager.newContext(
       request7.version,
       request7.metadata,
       request7.isFromFollower,
       request7.fetchData(topicNames),
       request7.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     val resp7 = context7.getThrottledResponse(100)
     assertEquals(Errors.NONE, resp7.error())
@@ -494,29 +497,29 @@ class FetchSessionTest {
     var nextSessionId = prevSessionId
     do {
       val reqData8 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-      reqData8.put(new TopicPartition("bar", 0), new FetchRequest.PartitionData(0, 0, 100,
+      reqData8.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 0, 0, 100,
         Optional.empty()))
-      reqData8.put(new TopicPartition("bar", 1), new FetchRequest.PartitionData(10, 0, 100,
+      reqData8.put(tp3.topicPartition, new FetchRequest.PartitionData(tp3.topicId, 10, 0, 100,
         Optional.empty()))
-      val request8 = createRequest(new JFetchMetadata(prevSessionId, FINAL_EPOCH), reqData8, topicIds, EMPTY_PART_LIST, false)
+      val request8 = createRequest(new JFetchMetadata(prevSessionId, FINAL_EPOCH), reqData8, EMPTY_PART_LIST, false)
       val context8 = fetchManager.newContext(
         request8.version,
         request8.metadata,
         request8.isFromFollower,
         request8.fetchData(topicNames),
         request8.forgottenTopics(topicNames),
-        topicIds
+        topicNames
       )
       assertEquals(classOf[SessionlessFetchContext], context8.getClass)
       assertEquals(0, cache.size)
-      val respData8 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-      respData8.put(new TopicPartition("bar", 0),
+      val respData8 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+      respData8.put(tp2,
         new FetchResponseData.PartitionData()
           .setPartitionIndex(0)
           .setHighWatermark(100)
           .setLastStableOffset(100)
           .setLogStartOffset(100))
-      respData8.put(new TopicPartition("bar", 1),
+      respData8.put(tp3,
         new FetchResponseData.PartitionData()
           .setPartitionIndex(1)
           .setHighWatermark(100)
@@ -528,37 +531,44 @@ class FetchSessionTest {
     } while (nextSessionId == prevSessionId)
   }
 
-  @Test
-  def testIncrementalFetchSession(): Unit = {
+  @ParameterizedTest
+  @ValueSource(booleans = Array(true, false))
+  def testIncrementalFetchSession(usesTopicIds: Boolean): Unit = {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava
+    val topicNames = if (usesTopicIds) Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava else Map[Uuid, String]().asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
+    val version = if (usesTopicIds) ApiKeys.FETCH.latestVersion else 12.toShort
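+    // Without topic IDs (v12), Uuid.ZERO_UUID stands in for the topic IDs below.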
+    val fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID)
+    val barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID)
+    val tp0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(barId, new TopicPartition("bar", 0))
 
     // Create a new fetch session with foo-0 and foo-1
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(tp0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    reqData1.put(tp1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
       Optional.empty()))
-    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false, version)
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(tp0, new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData1.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    respData1.put(tp1, new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
         .setLastStableOffset(10)
@@ -570,36 +580,36 @@ class FetchSessionTest {
 
     // Create an incremental fetch request that removes foo-0 and adds bar-0
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData2.put(new TopicPartition("bar", 0), new FetchRequest.PartitionData(15, 0, 0,
+    reqData2.put(tp2.topicPartition, new FetchRequest.PartitionData(barId, 15, 0, 0,
       Optional.empty()))
-    val removed2 = new util.ArrayList[TopicPartition]
-    removed2.add(new TopicPartition("foo", 0))
-    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, topicIds, removed2, false)
+    val removed2 = new util.ArrayList[TopicIdPartition]
+    removed2.add(tp0)
+    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, removed2, false, version)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[IncrementalFetchContext], context2.getClass)
-    val parts2 = Set(new TopicPartition("foo", 1), new TopicPartition("bar", 0))
+    val parts2 = Set(tp1, tp2)
     val reqData2Iter = parts2.iterator
-    context2.foreachPartition((topicPart, _, _) => {
-      assertEquals(reqData2Iter.next(), topicPart)
+    context2.foreachPartition((topicIdPart, _) => {
+      assertEquals(reqData2Iter.next(), topicIdPart)
     })
-    assertEquals(None, context2.getFetchOffset(new TopicPartition("foo", 0)))
-    assertEquals(10, context2.getFetchOffset(new TopicPartition("foo", 1)).get)
-    assertEquals(15, context2.getFetchOffset(new TopicPartition("bar", 0)).get)
-    assertEquals(None, context2.getFetchOffset(new TopicPartition("bar", 2)))
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData2.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    assertEquals(None, context2.getFetchOffset(tp0))
+    assertEquals(10, context2.getFetchOffset(tp1).get)
+    assertEquals(15, context2.getFetchOffset(tp2).get)
+    assertEquals(None, context2.getFetchOffset(new TopicIdPartition(barId, new TopicPartition("bar", 2))))
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData2.put(tp1, new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
         .setLastStableOffset(10)
         .setLogStartOffset(10))
-    respData2.put(new TopicPartition("bar", 0), new FetchResponseData.PartitionData()
+    respData2.put(tp2, new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(10)
         .setLastStableOffset(10)
@@ -618,17 +628,18 @@ class FetchSessionTest {
     val fetchManager = new FetchManager(time, cache)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
+    val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
 
     // Create a new fetch session with foo-0 and foo-1
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100,
       Optional.empty()))
-    reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    reqData1.put(tp1.topicPartition, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 10, 0, 100,
       Optional.empty()))
-    val request1 = createRequestWithoutTopicIds(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequestWithoutTopicIds(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
     // Simulate unknown topic ID for foo.
     val topicNamesOnlyBar = Collections.singletonMap(topicIds.get("bar"), "bar")
-    val topicIdsOnlyBar = Collections.singletonMap("bar", topicIds.get("bar"))
     // We should not throw an error since we have an older request version.
     val context1 = fetchManager.newContext(
       request1.version,
@@ -636,16 +647,16 @@ class FetchSessionTest {
       request1.isFromFollower,
       request1.fetchData(topicNamesOnlyBar),
       request1.forgottenTopics(topicNamesOnlyBar),
-      topicIdsOnlyBar
+      topicNamesOnlyBar
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(tp0, new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
       .setHighWatermark(100)
       .setLastStableOffset(100)
       .setLogStartOffset(100))
-    respData1.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    respData1.put(tp1, new FetchResponseData.PartitionData()
       .setPartitionIndex(1)
       .setHighWatermark(10)
       .setLastStableOffset(10)
@@ -659,88 +670,139 @@ class FetchSessionTest {
   }
 
   @Test
-  def testIncrementalFetchSessionWithIdsWhenSessionDoesNotUseIds() : Unit = {
+  def testFetchSessionWithUnknownId(): Unit = {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val topicIds = new util.HashMap[String, Uuid]()
-    val topicNames = new util.HashMap[Uuid, String]()
+    val fooId = Uuid.randomUuid()
+    val barId = Uuid.randomUuid()
+    val zarId = Uuid.randomUuid()
+    val topicNames = Map(fooId -> "foo", barId -> "bar", zarId -> "zar").asJava
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
+    val zar0 = new TopicIdPartition(zarId, new TopicPartition("zar", 0))
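+    // The "empty" partitions mirror how the session caches an unresolved topic ID: the ID is kept but the topic name is null.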
+    val emptyFoo0 = new TopicIdPartition(fooId, new TopicPartition(null, 0))
+    val emptyFoo1 = new TopicIdPartition(fooId, new TopicPartition(null, 1))
+    val emptyZar0 = new TopicIdPartition(zarId, new TopicPartition(null, 0))
 
-    // Create a new fetch session with foo-0
+    // Create a new fetch session with foo-0 and foo-1
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(foo0.topicId, 0, 0, 100,
       Optional.empty()))
-    val request1 = createRequestWithoutTopicIds(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
-    // Start a fetch session using a request version that does not use topic IDs.
+    reqData1.put(foo1.topicPartition, new FetchRequest.PartitionData(foo1.topicId, 10, 0, 100,
+      Optional.empty()))
+    reqData1.put(zar0.topicPartition, new FetchRequest.PartitionData(zar0.topicId, 10, 0, 100,
+      Optional.empty()))
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
+    // Simulate unknown topic IDs for foo and zar.
+    val topicNamesOnlyBar = Collections.singletonMap(barId, "bar")
+    // The request version uses topic IDs, so unresolved partitions are cached in the session instead of failing the request.
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
-      request1.fetchData(topicNames),
-      request1.forgottenTopics(topicNames),
-      topicIds
+      request1.fetchData(topicNamesOnlyBar),
+      request1.forgottenTopics(topicNamesOnlyBar),
+      topicNamesOnlyBar
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
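+    // Unknown topic IDs leave the partitions unresolved: request order is preserved and topic names are null.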
+    assertPartitionsOrder(context1, Seq(emptyFoo0, emptyFoo1, emptyZar0))
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(emptyFoo0, new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
-      .setHighWatermark(100)
-      .setLastStableOffset(100)
-      .setLogStartOffset(100))
+      .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code))
+    respData1.put(emptyFoo1, new FetchResponseData.PartitionData()
+      .setPartitionIndex(1)
+      .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code))
+    respData1.put(emptyZar0, new FetchResponseData.PartitionData()
+      .setPartitionIndex(0)
+      .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code))
     val resp1 = context1.updateAndGenerateResponseData(respData1)
+    // On the latest request version, the unknown topic IDs surface as partition-level errors rather than a top-level error.
     assertEquals(Errors.NONE, resp1.error())
     assertTrue(resp1.sessionId() != INVALID_SESSION_ID)
+    assertEquals(
+      Map(
+        foo0.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code,
+        foo1.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code,
+        zar0.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code
+      ),
+      resp1.responseData(topicNames, request1.version).asScala.map { case (tp, resp) =>
+        tp -> resp.errorCode
+      }
+    )
 
-    // Create an incremental fetch request as though no topics changed. However, send a v13 request.
-    // Also simulate the topic ID found on the server.
-    val fooId = Uuid.randomUuid()
-    topicIds.put("foo", fooId)
-    topicNames.put(fooId, "foo")
+    // Create an incremental request where we resolve the partitions
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false)
+    val topicNamesNoZar = Map(fooId -> "foo", barId -> "bar").asJava
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
-      request2.fetchData(topicNames),
-      request2.forgottenTopics(topicNames),
-      topicIds
+      request2.fetchData(topicNamesNoZar),
+      request2.forgottenTopics(topicNamesNoZar),
+      topicNamesNoZar
+    )
+    assertEquals(classOf[IncrementalFetchContext], context2.getClass)
+    // Partitions already in the session are lazily resolved when iterated via foreachPartition; foo's partitions resolve here, while zar remains unknown.
+    assertPartitionsOrder(context2, Seq(foo0, foo1, emptyZar0))
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData2.put(foo0, new FetchResponseData.PartitionData()
+      .setPartitionIndex(0)
+      .setHighWatermark(100)
+      .setLastStableOffset(100)
+      .setLogStartOffset(100))
+    respData2.put(foo1, new FetchResponseData.PartitionData()
+      .setPartitionIndex(1)
+      .setHighWatermark(10)
+      .setLastStableOffset(10)
+      .setLogStartOffset(10))
+    respData2.put(emptyZar0, new FetchResponseData.PartitionData()
+      .setPartitionIndex(0)
+      .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code))
+    val resp2 = context2.updateAndGenerateResponseData(respData2)
+    // The top-level error is NONE; only the still-unresolved zar partition reports UNKNOWN_TOPIC_ID.
+    assertEquals(Errors.NONE, resp2.error())
+    assertTrue(resp2.sessionId() != INVALID_SESSION_ID)
+    assertEquals(3, resp2.responseData(topicNames, request2.version).size)
+    assertEquals(
+      Map(
+        foo0.topicPartition -> Errors.NONE.code,
+        foo1.topicPartition -> Errors.NONE.code,
+        zar0.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code
+      ),
+      resp2.responseData(topicNames, request2.version).asScala.map { case (tp, resp) =>
+        tp -> resp.errorCode
+      }
     )
-
-    assertEquals(classOf[SessionErrorContext], context2.getClass)
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR,
-      context2.updateAndGenerateResponseData(respData2).error())
   }
 
   @Test
-  def testIncrementalFetchSessionWithoutIdsWhenSessionUsesIds() : Unit = {
+  def testIncrementalFetchSessionWithIdsWhenSessionDoesNotUseIds() : Unit = {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val fooId = Uuid.randomUuid()
-    val topicIds = new util.HashMap[String, Uuid]()
     val topicNames = new util.HashMap[Uuid, String]()
-    topicIds.put("foo", fooId)
-    topicNames.put(fooId, "foo")
+    val foo0 = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0))
 
     // Create a new fetch session with foo-0
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 100,
       Optional.empty()))
-    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
-    // Start a fetch session using a request version that uses topic IDs.
+    val request1 = createRequestWithoutTopicIds(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
+    // Start a fetch session using a request version that does not use topic IDs.
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(foo0, new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
       .setHighWatermark(100)
       .setLastStableOffset(100)
@@ -749,43 +811,42 @@ class FetchSessionTest {
     assertEquals(Errors.NONE, resp1.error())
     assertTrue(resp1.sessionId() != INVALID_SESSION_ID)
 
-    // Create an incremental fetch request as though no topics changed. However, send a v12 request.
-    // Also simulate the topic ID not found on the server
-    topicIds.remove("foo")
-    topicNames.remove(fooId)
+    // Create an incremental fetch request as though no topics changed. However, send a v13 request.
+    // Also simulate the topic ID now being known on the server.
+    val fooId = Uuid.randomUuid()
+    topicNames.put(fooId, "foo")
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request2 = createRequestWithoutTopicIds(new JFetchMetadata(resp1.sessionId(), 1), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
 
     assertEquals(classOf[SessionErrorContext], context2.getClass)
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR,
       context2.updateAndGenerateResponseData(respData2).error())
   }
 
   @Test
-  def testIncrementalFetchSessionWithIdsSwitchesIdForTopic() : Unit = {
+  def testIncrementalFetchSessionWithoutIdsWhenSessionUsesIds() : Unit = {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
     val fooId = Uuid.randomUuid()
-    val topicIds = new util.HashMap[String, Uuid]()
     val topicNames = new util.HashMap[Uuid, String]()
-    topicIds.put("foo", fooId)
     topicNames.put(fooId, "foo")
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
 
     // Create a new fetch session with foo-0
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
     // Start a fetch session using a request version that uses topic IDs.
     val context1 = fetchManager.newContext(
       request1.version,
@@ -793,11 +854,11 @@ class FetchSessionTest {
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(foo0, new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
       .setHighWatermark(100)
       .setLastStableOffset(100)
@@ -806,27 +867,22 @@ class FetchSessionTest {
     assertEquals(Errors.NONE, resp1.error())
     assertTrue(resp1.sessionId() != INVALID_SESSION_ID)
 
-    // Create an incremental fetch request adding a new partition to change the topic ID.
-    // Also simulate the topic ID changed on the server.
+    // Create an incremental fetch request as though no topics changed. However, send a v12 request.
+    // Also simulate the topic ID not being found on the server.
     topicNames.remove(fooId)
-    val newFooId = Uuid.randomUuid()
-    topicIds.put("foo", newFooId)
-    topicNames.put(newFooId, "foo")
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData2.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(0, 0, 100,
-      Optional.empty()))
-    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequestWithoutTopicIds(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
 
     assertEquals(classOf[SessionErrorContext], context2.getClass)
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR,
       context2.updateAndGenerateResponseData(respData2).error())
   }
@@ -840,14 +896,16 @@ class FetchSessionTest {
     val fetchManager = new FetchManager(time, cache)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
+    val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1))
 
     // Create a new fetch session with foo-0 and bar-1
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100,
       Optional.empty()))
-    reqData1.put(new TopicPartition("bar", 1), new FetchRequest.PartitionData(10, 0, 100,
+    reqData1.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100,
       Optional.empty()))
-    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
     // Start a fetch session. Simulate unknown partition foo-0.
     val context1 = fetchManager.newContext(
       request1.version,
@@ -855,16 +913,16 @@ class FetchSessionTest {
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("bar", 1), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(tp1, new FetchResponseData.PartitionData()
       .setPartitionIndex(1)
       .setHighWatermark(10)
       .setLastStableOffset(10)
       .setLogStartOffset(10))
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    respData1.put(tp0, new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
       .setHighWatermark(-1)
       .setLastStableOffset(-1)
@@ -877,22 +935,21 @@ class FetchSessionTest {
 
     // Create an incremental fetch request as though no topics changed.
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false)
     // Simulate ID changing on server.
     val topicNamesFooChanged = Map(topicIds.get("bar") -> "bar", Uuid.randomUuid() -> "foo").asJava
-    val topicIdsFooChanged = topicNamesFooChanged.asScala.map(_.swap).asJava
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNamesFooChanged),
       request2.forgottenTopics(topicNamesFooChanged),
-      topicIdsFooChanged
+      topicNamesFooChanged
     )
     assertEquals(classOf[IncrementalFetchContext], context2.getClass)
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     // If the topic ID differs on the broker, it will likely differ in the log as well. Simulate the log check finding an inconsistent ID.
-    respData2.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    respData2.put(tp0, new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
       .setHighWatermark(-1)
       .setLastStableOffset(-1)
@@ -900,11 +957,360 @@ class FetchSessionTest {
       .setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code))
     val resp2 = context2.updateAndGenerateResponseData(respData2)
 
-    assertEquals(Errors.INCONSISTENT_TOPIC_ID, resp2.error)
+    assertEquals(Errors.NONE, resp2.error)
     assertTrue(resp2.sessionId > 0)
     val responseData2 = resp2.responseData(topicNames, request2.version)
-    // We should have no partition responses with this top level error.
-    assertEquals(0, responseData2.size())
+    // We should have the inconsistent topic ID error at the partition level.
+    assertEquals(Errors.INCONSISTENT_TOPIC_ID.code, responseData2.get(tp0.topicPartition).errorCode)
+  }
+
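+  // A successful partition response, used once a partition's topic ID has resolved.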
+  private def noErrorResponse: FetchResponseData.PartitionData = {
+    new FetchResponseData.PartitionData()
+      .setPartitionIndex(1)
+      .setHighWatermark(10)
+      .setLastStableOffset(10)
+      .setLogStartOffset(10)
+  }
+
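+  // A partition response carrying only the given error code.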
+  private def errorResponse(errorCode: Short): FetchResponseData.PartitionData  = {
+    new FetchResponseData.PartitionData()
+      .setPartitionIndex(0)
+      .setHighWatermark(-1)
+      .setLastStableOffset(-1)
+      .setLogStartOffset(-1)
+      .setErrorCode(errorCode)
+  }
+
+  @Test
+  def testResolveUnknownPartitions(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    def newContext(
+      metadata: JFetchMetadata,
+      partitions: Seq[TopicIdPartition],
+      topicNames: Map[Uuid, String] // Topic ID to name mapping known by the broker.
+    ): FetchContext = {
+      val data = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+      partitions.foreach { topicIdPartition =>
+        data.put(
+          topicIdPartition.topicPartition,
+          new FetchRequest.PartitionData(topicIdPartition.topicId, 0, 0, 100, Optional.empty())
+        )
+      }
+
+      val fetchRequest = createRequest(metadata, data, EMPTY_PART_LIST, false)
+
+      fetchManager.newContext(
+        fetchRequest.version,
+        fetchRequest.metadata,
+        fetchRequest.isFromFollower,
+        fetchRequest.fetchData(topicNames.asJava),
+        fetchRequest.forgottenTopics(topicNames.asJava),
+        topicNames.asJava
+      )
+    }
+
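+    // Produces a response for every partition in the context, returning UNKNOWN_TOPIC_ID
+    // for unresolved partitions, and returns the resulting session id.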
+    def updateAndGenerateResponseData(
+      context: FetchContext
+    ): Int = {
+      val data = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+      context.foreachPartition { (topicIdPartition, _) =>
+        data.put(
+          topicIdPartition,
+          if (topicIdPartition.topic == null)
+            errorResponse(Errors.UNKNOWN_TOPIC_ID.code)
+          else
+            noErrorResponse
+        )
+      }
+      context.updateAndGenerateResponseData(data).sessionId
+    }
+
+    val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
+    val bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0))
+    val zar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("zar", 0))
+
+    val fooUnresolved = new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition))
+    val barUnresolved = new TopicIdPartition(bar.topicId, new TopicPartition(null, bar.partition))
+    val zarUnresolved = new TopicIdPartition(zar.topicId, new TopicPartition(null, zar.partition))
+
+    // The metadata cache does not know about the topic.
+    val context1 = newContext(
+      JFetchMetadata.INITIAL,
+      Seq(foo, bar, zar),
+      Map.empty[Uuid, String]
+    )
+
+    // So the context contains unresolved partitions.
+    assertEquals(classOf[FullFetchContext], context1.getClass)
+    assertPartitionsOrder(context1, Seq(fooUnresolved, barUnresolved, zarUnresolved))
+
+    // The response is sent back to create the session.
+    val sessionId = updateAndGenerateResponseData(context1)
+
+    // The metadata cache only knows about foo.
+    val context2 = newContext(
+      new JFetchMetadata(sessionId, 1),
+      Seq.empty,
+      Map(foo.topicId -> foo.topic)
+    )
+
+    // So foo is resolved but not the others.
+    assertEquals(classOf[IncrementalFetchContext], context2.getClass)
+    assertPartitionsOrder(context2, Seq(foo, barUnresolved, zarUnresolved))
+
+    updateAndGenerateResponseData(context2)
+
+    // The metadata cache knows about foo and bar.
+    val context3 = newContext(
+      new JFetchMetadata(sessionId, 2),
+      Seq(bar),
+      Map(foo.topicId -> foo.topic, bar.topicId -> bar.topic)
+    )
+
+    // So foo and bar are resolved.
+    assertEquals(classOf[IncrementalFetchContext], context3.getClass)
+    assertPartitionsOrder(context3, Seq(foo, bar, zarUnresolved))
+
+    updateAndGenerateResponseData(context3)
+
+    // The metadata cache knows about all topics.
+    val context4 = newContext(
+      new JFetchMetadata(sessionId, 3),
+      Seq.empty,
+      Map(foo.topicId -> foo.topic, bar.topicId -> bar.topic, zar.topicId -> zar.topic)
+    )
+
+    // So all topics are resolved.
+    assertEquals(classOf[IncrementalFetchContext], context4.getClass)
+    assertPartitionsOrder(context4, Seq(foo, bar, zar))
+
+    updateAndGenerateResponseData(context4)
+
+    // The metadata cache does not know about the topics anymore (e.g. deleted).
+    val context5 = newContext(
+      new JFetchMetadata(sessionId, 4),
+      Seq.empty,
+      Map.empty
+    )
+
+    // All topics remain resolved.
+    assertEquals(classOf[IncrementalFetchContext], context5.getClass)
+    assertPartitionsOrder(context5, Seq(foo, bar, zar))
+  }
+
+  // This test simulates forgetting a topic partition under every combination of topic ID resolution across the two requests.
+  @ParameterizedTest
+  @MethodSource(Array("idUsageCombinations"))
+  def testToForgetPartitions(fooStartsResolved: Boolean, fooEndsResolved: Boolean): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    def newContext(
+      metadata: JFetchMetadata,
+      partitions: Seq[TopicIdPartition],
+      toForget: Seq[TopicIdPartition],
+      topicNames: Map[Uuid, String] // Topic ID to name mapping known by the broker.
+    ): FetchContext = {
+      val data = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+      partitions.foreach { topicIdPartition =>
+        data.put(
+          topicIdPartition.topicPartition,
+          new FetchRequest.PartitionData(topicIdPartition.topicId, 0, 0, 100, Optional.empty())
+        )
+      }
+
+      val fetchRequest = createRequest(metadata, data, toForget.toList.asJava, false)
+
+      fetchManager.newContext(
+        fetchRequest.version,
+        fetchRequest.metadata,
+        fetchRequest.isFromFollower,
+        fetchRequest.fetchData(topicNames.asJava),
+        fetchRequest.forgottenTopics(topicNames.asJava),
+        topicNames.asJava
+      )
+    }
+
+    def updateAndGenerateResponseData(
+      context: FetchContext
+    ): Int = {
+      val data = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+      context.foreachPartition { (topicIdPartition, _) =>
+        data.put(
+          topicIdPartition,
+          if (topicIdPartition.topic == null)
+            errorResponse(Errors.UNKNOWN_TOPIC_ID.code)
+          else
+            noErrorResponse
+        )
+      }
+      context.updateAndGenerateResponseData(data).sessionId
+    }
+
+    val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
+    val bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0))
+
+    val fooUnresolved = new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition))
+    val barUnresolved = new TopicIdPartition(bar.topicId, new TopicPartition(null, bar.partition))
+
+    // Create a new context where foo's resolution depends on fooStartsResolved and bar is unresolved.
+    val context1Names = if (fooStartsResolved) Map(foo.topicId -> foo.topic) else Map.empty[Uuid, String]
+    val fooContext1 = if (fooStartsResolved) foo else fooUnresolved
+    val context1 = newContext(
+      JFetchMetadata.INITIAL,
+      Seq(fooContext1, bar),
+      Seq.empty,
+      context1Names
+    )
+
+    // So the context contains unresolved bar and a resolved foo iff fooStartsResolved
+    assertEquals(classOf[FullFetchContext], context1.getClass)
+    assertPartitionsOrder(context1, Seq(fooContext1, barUnresolved))
+
+    // The response is sent back to create the session.
+    val sessionId = updateAndGenerateResponseData(context1)
+
+    // Forget foo, but keep bar. Foo's resolution depends on fooEndsResolved and bar stays unresolved.
+    val context2Names = if (fooEndsResolved) Map(foo.topicId -> foo.topic) else Map.empty[Uuid, String]
+    val fooContext2 = if (fooEndsResolved) foo else fooUnresolved
+    val context2 = newContext(
+      new JFetchMetadata(sessionId, 1),
+      Seq.empty,
+      Seq(fooContext2),
+      context2Names
+    )
+
+    // So foo is removed but bar remains.
+    assertEquals(classOf[IncrementalFetchContext], context2.getClass)
+    assertPartitionsOrder(context2, Seq(barUnresolved))
+
+    updateAndGenerateResponseData(context2)
+
+    // Now remove bar
+    val context3 = newContext(
+      new JFetchMetadata(sessionId, 2),
+      Seq.empty,
+      Seq(bar),
+      Map.empty[Uuid, String]
+    )
+
+    // Context is sessionless since it is empty.
+    assertEquals(classOf[SessionlessFetchContext], context3.getClass)
+    assertPartitionsOrder(context3, Seq())
+  }
+
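+  // Exercises response generation across sessionless, full, incremental, and error
+  // contexts with a mix of partitions that are resolved and unresolved on the receiving broker.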
+  @Test
+  def testUpdateAndGenerateResponseData(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    def newContext(
+      metadata: JFetchMetadata,
+      partitions: Seq[TopicIdPartition],
+      topicNames: Map[Uuid, String] // Topic ID to name mapping known by the broker.
+    ): FetchContext = {
+      val data = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+      partitions.foreach { topicIdPartition =>
+        data.put(
+          topicIdPartition.topicPartition,
+          new FetchRequest.PartitionData(topicIdPartition.topicId, 0, 0, 100, Optional.empty())
+        )
+      }
+
+      val fetchRequest = createRequest(metadata, data, EMPTY_PART_LIST, false)
+
+      fetchManager.newContext(
+        fetchRequest.version,
+        fetchRequest.metadata,
+        fetchRequest.isFromFollower,
+        fetchRequest.fetchData(topicNames.asJava),
+        fetchRequest.forgottenTopics(topicNames.asJava),
+        topicNames.asJava
+      )
+    }
+
+    // Give both topics errors so they will stay in the session.
+    def updateAndGenerateResponseData(
+      context: FetchContext
+    ): FetchResponse = {
+      val data = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+      context.foreachPartition { (topicIdPartition, _) =>
+        data.put(
+          topicIdPartition,
+          if (topicIdPartition.topic == null)
+            errorResponse(Errors.UNKNOWN_TOPIC_ID.code)
+          else
+            errorResponse(Errors.UNKNOWN_TOPIC_OR_PARTITION.code)
+        )
+      }
+      context.updateAndGenerateResponseData(data)
+    }
+
+    val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
+    val bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0))
+
+    // Foo will always be resolved and bar will never be resolved on the receiving broker.
+    val receivingBrokerTopicNames = Map(foo.topicId -> foo.topic)
+    // The sender will know both topics' id to name mappings.
+    val sendingTopicNames = Map(foo.topicId -> foo.topic, bar.topicId -> bar.topic)
+
+    def checkResponseData(response: FetchResponse): Unit = {
+      assertEquals(
+        Map(
+          foo.topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION.code,
+          bar.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code,
+        ),
+        response.responseData(sendingTopicNames.asJava, ApiKeys.FETCH.latestVersion).asScala.map { case (tp, resp) =>
+          tp -> resp.errorCode
+        }
+      )
+    }
+
+    // Start with a sessionless context.
+    val context1 = newContext(
+      JFetchMetadata.LEGACY,
+      Seq(foo, bar),
+      receivingBrokerTopicNames
+    )
+    assertEquals(classOf[SessionlessFetchContext], context1.getClass)
+    // Check the response can be read as expected.
+    checkResponseData(updateAndGenerateResponseData(context1))
+
+    // Now create a full context.
+    val context2 = newContext(
+      JFetchMetadata.INITIAL,
+      Seq(foo, bar),
+      receivingBrokerTopicNames
+    )
+    assertEquals(classOf[FullFetchContext], context2.getClass)
+    // We want to get the session ID to build more contexts in this session.
+    val response2 = updateAndGenerateResponseData(context2)
+    val sessionId = response2.sessionId
+    checkResponseData(response2)
+
+    // Now create an incremental context. The erroring partitions were kept in the session, so no partitions need to be re-sent.
+    val context3 = newContext(
+      new JFetchMetadata(sessionId, 1),
+      Seq.empty,
+      receivingBrokerTopicNames
+    )
+    assertEquals(classOf[IncrementalFetchContext], context3.getClass)
+    checkResponseData(updateAndGenerateResponseData(context3))
+
+    // Finally create an error context by using the same epoch
+    val context4 = newContext(
+      new JFetchMetadata(sessionId, 1),
+      Seq.empty,
+      receivingBrokerTopicNames
+    )
+    assertEquals(classOf[SessionErrorContext], context4.getClass)
+    // The response should be empty.
+    assertEquals(Collections.emptyList, updateAndGenerateResponseData(context4).data.responses)
   }
 
   @Test
@@ -913,32 +1319,34 @@ class FetchSessionTest {
     // set maximum entries to 2 to allow for eviction later
     val cache = new FetchSessionCache(2, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val topicNames = Map(Uuid.randomUuid() -> "foo").asJava
-    val topicIds = topicNames.asScala.map(_.swap).asJava
+    val fooId = Uuid.randomUuid()
+    val topicNames = Map(fooId -> "foo").asJava
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
 
     // Create a new fetch session, session 1
     val session1req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session1req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session1req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session1req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    session1req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
       Optional.empty()))
-    val session1request1 = createRequest(JFetchMetadata.INITIAL, session1req, topicIds, EMPTY_PART_LIST, false)
+    val session1request1 = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, false)
     val session1context1 = fetchManager.newContext(
       session1request1.version,
       session1request1.metadata,
       session1request1.isFromFollower,
       session1request1.fetchData(topicNames),
       session1request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session1context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(foo0, new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData1.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    respData1.put(foo1, new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
         .setLastStableOffset(10)
@@ -954,28 +1362,28 @@ class FetchSessionTest {
 
     // Create a second new fetch session
     val session2req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session2req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session2req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session2req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    session2req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
       Optional.empty()))
-    val session2request1 = createRequest(JFetchMetadata.INITIAL, session1req, topicIds, EMPTY_PART_LIST, false)
+    val session2request1 = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, false)
     val session2context = fetchManager.newContext(
       session2request1.version,
       session2request1.metadata,
       session2request1.isFromFollower,
       session2request1.fetchData(topicNames),
       session2request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session2context.getClass)
     val session2RespData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    session2RespData.put(new TopicPartition("foo", 0),
+    session2RespData.put(foo0.topicPartition,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    session2RespData.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    session2RespData.put(foo1.topicPartition, new FetchResponseData.PartitionData()
       .setPartitionIndex(1)
       .setHighWatermark(10)
       .setLastStableOffset(10)
@@ -993,15 +1401,15 @@ class FetchSessionTest {
     // Create an incremental fetch request for session 1
     val session1request2 = createRequest(
       new JFetchMetadata(session1resp.sessionId(), 1),
-      new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData], topicIds,
-      new util.ArrayList[TopicPartition], false)
+      new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData],
+      new util.ArrayList[TopicIdPartition], false)
     val context1v2 = fetchManager.newContext(
       session1request2.version,
       session1request2.metadata,
       session1request2.isFromFollower,
       session1request2.fetchData(topicNames),
       session1request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[IncrementalFetchContext], context1v2.getClass)
 
@@ -1012,27 +1420,27 @@ class FetchSessionTest {
     // the second session should be evicted because the first session was incrementally fetched
     // more recently than the second session was created
     val session3req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session3req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session3req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session3req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(0, 0, 100,
+    session3req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    val session3request1 = createRequest(JFetchMetadata.INITIAL, session3req, topicIds, EMPTY_PART_LIST, false)
+    val session3request1 = createRequest(JFetchMetadata.INITIAL, session3req, EMPTY_PART_LIST, false)
     val session3context = fetchManager.newContext(
       session3request1.version,
       session3request1.metadata,
       session3request1.isFromFollower,
       session3request1.fetchData(topicNames),
       session3request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session3context.getClass)
-    val respData3 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData3.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData3 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData3.put(new TopicIdPartition(fooId, new TopicPartition("foo", 0)), new FetchResponseData.PartitionData()
       .setPartitionIndex(0)
       .setHighWatermark(100)
       .setLastStableOffset(100)
       .setLogStartOffset(100))
-    respData3.put(new TopicPartition("foo", 1),
+    respData3.put(new TopicIdPartition(fooId, new TopicPartition("foo", 1)),
       new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
@@ -1054,32 +1462,34 @@ class FetchSessionTest {
     // set maximum entries to 2 to allow for eviction later
     val cache = new FetchSessionCache(2, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val topicNames = Map(Uuid.randomUuid() -> "foo").asJava
-    val topicIds = topicNames.asScala.map(_.swap).asJava
+    val fooId = Uuid.randomUuid()
+    val topicNames = Map(fooId -> "foo").asJava
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
 
     // Create a new fetch session, session 1
     val session1req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session1req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session1req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session1req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    session1req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
       Optional.empty()))
-    val session1request = createRequest(JFetchMetadata.INITIAL, session1req, topicIds, EMPTY_PART_LIST, true)
+    val session1request = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, true)
     val session1context = fetchManager.newContext(
       session1request.version,
       session1request.metadata,
       session1request.isFromFollower,
       session1request.fetchData(topicNames),
       session1request.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session1context.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(foo0, new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData1.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    respData1.put(foo1, new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
         .setLastStableOffset(10)
@@ -1095,34 +1505,34 @@ class FetchSessionTest {
 
     // Create a second new fetch session, unprivileged
     val session2req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session2req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session2req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session2req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    session2req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
       Optional.empty()))
-    val session2request = createRequest(JFetchMetadata.INITIAL, session1req, topicIds, EMPTY_PART_LIST, false)
+    val session2request = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, false)
     val session2context = fetchManager.newContext(
       session2request.version,
       session2request.metadata,
       session2request.isFromFollower,
       session2request.fetchData(topicNames),
       session2request.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session2context.getClass)
-    val session2RespData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    session2RespData.put(new TopicPartition("foo", 0),
+    val session2RespData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    session2RespData.put(foo0,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    session2RespData.put(new TopicPartition("foo", 1),
+    session2RespData.put(foo1,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
         .setLastStableOffset(10)
         .setLogStartOffset(10))
-    val session2resp = session2context.updateAndGenerateResponseData(respData1)
+    val session2resp = session2context.updateAndGenerateResponseData(session2RespData)
     assertEquals(Errors.NONE, session2resp.error())
     assertTrue(session2resp.sessionId() != INVALID_SESSION_ID)
     assertEquals(2, session2resp.responseData(topicNames, session2request.version).size)
@@ -1135,28 +1545,28 @@ class FetchSessionTest {
 
     // create a session to test that session 1's privileges mean session 1 is retained and session 2 is evicted
     val session3req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session3req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session3req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session3req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(0, 0, 100,
+    session3req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    val session3request = createRequest(JFetchMetadata.INITIAL, session3req, topicIds, EMPTY_PART_LIST, true)
+    val session3request = createRequest(JFetchMetadata.INITIAL, session3req, EMPTY_PART_LIST, true)
     val session3context = fetchManager.newContext(
       session3request.version,
       session3request.metadata,
       session3request.isFromFollower,
       session3request.fetchData(topicNames),
       session3request.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session3context.getClass)
-    val respData3 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData3.put(new TopicPartition("foo", 0),
+    val respData3 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData3.put(foo0,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData3.put(new TopicPartition("foo", 1),
+    respData3.put(foo1,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
@@ -1178,28 +1588,28 @@ class FetchSessionTest {
 
     // create a final session to test whether session1 can be evicted due to age even though it is privileged
     val session4req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    session4req.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    session4req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    session4req.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(0, 0, 100,
+    session4req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    val session4request = createRequest(JFetchMetadata.INITIAL, session4req, topicIds, EMPTY_PART_LIST, true)
+    val session4request = createRequest(JFetchMetadata.INITIAL, session4req, EMPTY_PART_LIST, true)
     val session4context = fetchManager.newContext(
       session4request.version,
       session4request.metadata,
       session4request.isFromFollower,
       session4request.fetchData(topicNames),
       session4request.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], session4context.getClass)
-    val respData4 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData4.put(new TopicPartition("foo", 0),
+    val respData4 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData4.put(foo0,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData4.put(new TopicPartition("foo", 1),
+    respData4.put(foo1,
       new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
@@ -1221,32 +1631,34 @@ class FetchSessionTest {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val topicNames = Map(Uuid.randomUuid() -> "foo").asJava
-    val topicIds = topicNames.asScala.map(_.swap).asJava
+    val fooId = Uuid.randomUuid()
+    val topicNames = Map(fooId -> "foo").asJava
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
 
     // Create a new fetch session with foo-0 and foo-1
     val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100,
+    reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
       Optional.empty()))
-    reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100,
+    reqData1.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
       Optional.empty()))
-    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-    respData1.put(new TopicPartition("foo", 0), new FetchResponseData.PartitionData()
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(foo0, new FetchResponseData.PartitionData()
         .setPartitionIndex(0)
         .setHighWatermark(100)
         .setLastStableOffset(100)
         .setLogStartOffset(100))
-    respData1.put(new TopicPartition("foo", 1), new FetchResponseData.PartitionData()
+    respData1.put(foo1, new FetchResponseData.PartitionData()
         .setPartitionIndex(1)
         .setHighWatermark(10)
         .setLastStableOffset(10)
@@ -1259,20 +1671,20 @@ class FetchSessionTest {
     // Create an incremental fetch request that removes foo-0 and foo-1
     // Verify that the previous fetch session was closed.
     val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    val removed2 = new util.ArrayList[TopicPartition]
-    removed2.add(new TopicPartition("foo", 0))
-    removed2.add(new TopicPartition("foo", 1))
-    val request2 = createRequest( new JFetchMetadata(resp1.sessionId, 1), reqData2, topicIds, removed2, false)
+    val removed2 = new util.ArrayList[TopicIdPartition]
+    removed2.add(foo0)
+    removed2.add(foo1)
+    val request2 = createRequest(new JFetchMetadata(resp1.sessionId, 1), reqData2, removed2, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[SessionlessFetchContext], context2.getClass)
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     val resp2 = context2.updateAndGenerateResponseData(respData2)
     assertEquals(INVALID_SESSION_ID, resp2.sessionId)
     assertTrue(resp2.responseData(topicNames, request2.version).isEmpty)
@@ -1284,27 +1696,27 @@ class FetchSessionTest {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val tp1 = new TopicPartition("foo", 1)
-    val tp2 = new TopicPartition("bar", 2)
     val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava
     val topicIds = topicNames.asScala.map(_.swap).asJava
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 2))
 
     val reqData = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData.put(tp1, new FetchRequest.PartitionData(100, 0, 1000, Optional.of(5), Optional.of(4)))
-    reqData.put(tp2, new FetchRequest.PartitionData(100, 0, 1000, Optional.of(5), Optional.of(4)))
+    reqData.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4)))
+    reqData.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4)))
 
     // Full fetch context returns all partitions in the response
-    val request1 = createRequest(JFetchMetadata.INITIAL, reqData, topicIds, EMPTY_PART_LIST, false)
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData, EMPTY_PART_LIST, false)
     val context1 = fetchManager.newContext(
       request1.version,
       request1.metadata,
       request1.isFromFollower,
       request1.fetchData(topicNames),
       request1.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[FullFetchContext], context1.getClass)
-    val respData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     respData.put(tp1, new FetchResponseData.PartitionData()
         .setPartitionIndex(tp1.partition)
         .setHighWatermark(105)
@@ -1320,24 +1732,24 @@ class FetchSessionTest {
     val resp1 = context1.updateAndGenerateResponseData(respData)
     assertEquals(Errors.NONE, resp1.error)
     assertNotEquals(INVALID_SESSION_ID, resp1.sessionId)
-    assertEquals(Utils.mkSet(tp1, tp2), resp1.responseData(topicNames, request1.version).keySet)
+    assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp1.responseData(topicNames, request1.version).keySet)
 
     // Incremental fetch context returns partitions with divergent epoch even if none
     // of the other conditions for return are met.
-    val request2 = createRequest(new JFetchMetadata(resp1.sessionId, 1), reqData, topicIds, EMPTY_PART_LIST, false)
+    val request2 = createRequest(new JFetchMetadata(resp1.sessionId, 1), reqData, EMPTY_PART_LIST, false)
     val context2 = fetchManager.newContext(
       request2.version,
       request2.metadata,
       request2.isFromFollower,
       request2.fetchData(topicNames),
       request2.forgottenTopics(topicNames),
-      topicIds
+      topicNames
     )
     assertEquals(classOf[IncrementalFetchContext], context2.getClass)
     val resp2 = context2.updateAndGenerateResponseData(respData)
     assertEquals(Errors.NONE, resp2.error)
     assertEquals(resp1.sessionId, resp2.sessionId)
-    assertEquals(Collections.singleton(tp2), resp2.responseData(topicNames, request2.version).keySet)
+    assertEquals(Collections.singleton(tp2.topicPartition), resp2.responseData(topicNames, request2.version).keySet)
 
     // All partitions with divergent epoch should be returned.
     respData.put(tp1, new FetchResponseData.PartitionData()
@@ -1349,7 +1761,7 @@ class FetchSessionTest {
     val resp3 = context2.updateAndGenerateResponseData(respData)
     assertEquals(Errors.NONE, resp3.error)
     assertEquals(resp1.sessionId, resp3.sessionId)
-    assertEquals(Utils.mkSet(tp1, tp2), resp3.responseData(topicNames, request2.version).keySet)
+    assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp3.responseData(topicNames, request2.version).keySet)
 
     // Partitions that meet other conditions should be returned regardless of whether
     // divergingEpoch is set or not.
@@ -1361,7 +1773,7 @@ class FetchSessionTest {
     val resp4 = context2.updateAndGenerateResponseData(respData)
     assertEquals(Errors.NONE, resp4.error)
     assertEquals(resp1.sessionId, resp4.sessionId)
-    assertEquals(Utils.mkSet(tp1, tp2), resp4.responseData(topicNames, request2.version).keySet)
+    assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp4.responseData(topicNames, request2.version).keySet)
   }
 
   @Test
@@ -1369,35 +1781,35 @@ class FetchSessionTest {
     val time = new MockTime()
     val cache = new FetchSessionCache(10, 1000)
     val fetchManager = new FetchManager(time, cache)
-    val tp1 = new TopicPartition("foo", 1)
-    val tp2 = new TopicPartition("bar", 2)
-    val tp3 = new TopicPartition("zar", 3)
     val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid(), "zar" -> Uuid.randomUuid()).asJava
     val topicNames = topicIds.asScala.map(_.swap).asJava
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 2))
+    val tp3 = new TopicIdPartition(topicIds.get("zar"), new TopicPartition("zar", 3))
 
-    val reqData = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-    reqData.put(tp1, new FetchRequest.PartitionData(100, 0, 1000, Optional.of(5), Optional.of(4)))
-    reqData.put(tp2, new FetchRequest.PartitionData(100, 0, 1000, Optional.of(5), Optional.of(4)))
-    reqData.put(tp3, new FetchRequest.PartitionData(100, 0, 1000, Optional.of(5), Optional.of(4)))
+    val reqData = new util.LinkedHashMap[TopicIdPartition, FetchRequest.PartitionData]
+    reqData.put(tp1, new FetchRequest.PartitionData(tp1.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4)))
+    reqData.put(tp2, new FetchRequest.PartitionData(tp2.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4)))
+    reqData.put(tp3, new FetchRequest.PartitionData(tp3.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4)))
 
     // Full fetch context returns all partitions in the response
     val context1 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), JFetchMetadata.INITIAL, false,
-     reqData, Collections.emptyList(), topicIds)
+     reqData, Collections.emptyList(), topicNames)
     assertEquals(classOf[FullFetchContext], context1.getClass)
 
-    val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     respData1.put(tp1, new FetchResponseData.PartitionData()
-      .setPartitionIndex(tp1.partition)
+      .setPartitionIndex(tp1.topicPartition.partition)
       .setHighWatermark(50)
       .setLastStableOffset(50)
       .setLogStartOffset(0))
     respData1.put(tp2, new FetchResponseData.PartitionData()
-      .setPartitionIndex(tp2.partition)
+      .setPartitionIndex(tp2.topicPartition.partition)
       .setHighWatermark(50)
       .setLastStableOffset(50)
       .setLogStartOffset(0))
     respData1.put(tp3, new FetchResponseData.PartitionData()
-      .setPartitionIndex(tp3.partition)
+      .setPartitionIndex(tp3.topicPartition.partition)
       .setHighWatermark(50)
       .setLastStableOffset(50)
       .setLogStartOffset(0))
@@ -1405,58 +1817,130 @@ class FetchSessionTest {
     val resp1 = context1.updateAndGenerateResponseData(respData1)
     assertEquals(Errors.NONE, resp1.error)
     assertNotEquals(INVALID_SESSION_ID, resp1.sessionId)
-    assertEquals(Utils.mkSet(tp1, tp2, tp3), resp1.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet())
+    assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition, tp3.topicPartition), resp1.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet())
 
     // Incremental fetch context returns partitions with changes but only deprioritizes
     // the partitions with records
     val context2 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), new JFetchMetadata(resp1.sessionId, 1), false,
-      reqData, Collections.emptyList(), topicIds)
+      reqData, Collections.emptyList(), topicNames)
     assertEquals(classOf[IncrementalFetchContext], context2.getClass)
 
     // Partitions are ordered in the session as per last response
     assertPartitionsOrder(context2, Seq(tp1, tp2, tp3))
 
     // Response is empty
-    val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     val resp2 = context2.updateAndGenerateResponseData(respData2)
     assertEquals(Errors.NONE, resp2.error)
     assertEquals(resp1.sessionId, resp2.sessionId)
     assertEquals(Collections.emptySet(), resp2.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet)
 
     // All partitions with changes should be returned.
-    val respData3 = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
+    val respData3 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
     respData3.put(tp1, new FetchResponseData.PartitionData()
-      .setPartitionIndex(tp1.partition)
+      .setPartitionIndex(tp1.topicPartition.partition)
       .setHighWatermark(60)
       .setLastStableOffset(50)
       .setLogStartOffset(0))
     respData3.put(tp2, new FetchResponseData.PartitionData()
-      .setPartitionIndex(tp2.partition)
+      .setPartitionIndex(tp2.topicPartition.partition)
       .setHighWatermark(60)
       .setLastStableOffset(50)
       .setLogStartOffset(0)
       .setRecords(MemoryRecords.withRecords(CompressionType.NONE,
         new SimpleRecord(100, null))))
     respData3.put(tp3, new FetchResponseData.PartitionData()
-      .setPartitionIndex(tp3.partition)
+      .setPartitionIndex(tp3.topicPartition.partition)
       .setHighWatermark(50)
       .setLastStableOffset(50)
       .setLogStartOffset(0))
     val resp3 = context2.updateAndGenerateResponseData(respData3)
     assertEquals(Errors.NONE, resp3.error)
     assertEquals(resp1.sessionId, resp3.sessionId)
-    assertEquals(Utils.mkSet(tp1, tp2), resp3.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet)
+    assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp3.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet)
 
     // Only the partitions that returned records in the last response were deprioritized.
     assertPartitionsOrder(context2, Seq(tp1, tp3, tp2))
   }
 
-  private def assertPartitionsOrder(context: FetchContext, partitions: Seq[TopicPartition]): Unit = {
-    val partitionsInContext = ArrayBuffer.empty[TopicPartition]
-    context.foreachPartition { (tp, _, _) =>
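+  // CachedPartition equality is driven by the topic ID when one is present and falls
+  // back to the topic name when the ID is Uuid.ZERO_UUID.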
+  @Test
+  def testCachedPartitionEqualsAndHashCode(): Unit = {
+    val topicId = Uuid.randomUuid()
+    val topicName = "topic"
+    val partition = 0
+
+    val cachedPartitionWithIdAndName = new CachedPartition(topicName, topicId, partition)
+    val cachedPartitionWithIdAndNoName = new CachedPartition(null, topicId, partition)
+    val cachedPartitionWithDifferentIdAndName = new CachedPartition(topicName, Uuid.randomUuid(), partition)
+    val cachedPartitionWithZeroIdAndName = new CachedPartition(topicName, Uuid.ZERO_UUID, partition)
+    val cachedPartitionWithZeroIdAndOtherName = new CachedPartition("otherTopic", Uuid.ZERO_UUID, partition)
+
+    // CachedPartitions with valid topic IDs will compare topic ID and partition but not topic name.
+    assertEquals(cachedPartitionWithIdAndName, cachedPartitionWithIdAndNoName)
+    assertEquals(cachedPartitionWithIdAndName.hashCode, cachedPartitionWithIdAndNoName.hashCode)
+
+    assertNotEquals(cachedPartitionWithIdAndName, cachedPartitionWithDifferentIdAndName)
+    assertNotEquals(cachedPartitionWithIdAndName.hashCode, cachedPartitionWithDifferentIdAndName.hashCode)
+
+    assertNotEquals(cachedPartitionWithIdAndName, cachedPartitionWithZeroIdAndName)
+    assertNotEquals(cachedPartitionWithIdAndName.hashCode, cachedPartitionWithZeroIdAndName.hashCode)
+
+    // CachedPartitions with null names and valid IDs will act just like ones with valid names.
+    assertEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithIdAndName)
+    assertEquals(cachedPartitionWithIdAndNoName.hashCode, cachedPartitionWithIdAndName.hashCode)
+
+    assertNotEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithDifferentIdAndName)
+    assertNotEquals(cachedPartitionWithIdAndNoName.hashCode, cachedPartitionWithDifferentIdAndName.hashCode)
+
+    assertNotEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithZeroIdAndName)
+    assertNotEquals(cachedPartitionWithIdAndNoName.hashCode, cachedPartitionWithZeroIdAndName.hashCode)
+
+    // CachedPartitions with zero UUIDs will compare topic name and partition.
+    assertNotEquals(cachedPartitionWithZeroIdAndName, cachedPartitionWithZeroIdAndOtherName)
+    assertNotEquals(cachedPartitionWithZeroIdAndName.hashCode, cachedPartitionWithZeroIdAndOtherName.hashCode)
+
+    assertEquals(cachedPartitionWithZeroIdAndName, cachedPartitionWithZeroIdAndName)
+    assertEquals(cachedPartitionWithZeroIdAndName.hashCode, cachedPartitionWithZeroIdAndName.hashCode)
+  }
+
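+  // maybeResolveUnknownName should only fill in a null topic name, and only when the
+  // topic ID is present in the supplied mapping.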
+  @Test
+  def testMaybeResolveUnknownName(): Unit = {
+    val namedPartition = new CachedPartition("topic", Uuid.randomUuid(), 0)
+    val nullNamePartition1 = new CachedPartition(null, Uuid.randomUuid(), 0)
+    val nullNamePartition2 = new CachedPartition(null, Uuid.randomUuid(), 0)
+
+    val topicNames = Map(namedPartition.topicId -> "foo", nullNamePartition1.topicId -> "bar").asJava
+
+    // Since the name is not null, we should not change the topic name.
+    // We should never have a scenario where the same ID is used by two topic names, but this case is used to verify that we respect the null check.
+    namedPartition.maybeResolveUnknownName(topicNames)
+    assertEquals("topic", namedPartition.topic)
+
+    // We will resolve this name as it is in the map and the current name is null.
+    nullNamePartition1.maybeResolveUnknownName(topicNames)
+    assertEquals("bar", nullNamePartition1.topic)
+
+    // If the ID is not in the map, then we don't resolve the name.
+    nullNamePartition2.maybeResolveUnknownName(topicNames)
+    assertEquals(null, nullNamePartition2.topic)
+  }
+
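+  // Asserts that the context iterates over exactly the given partitions, in order.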
+  private def assertPartitionsOrder(context: FetchContext, partitions: Seq[TopicIdPartition]): Unit = {
+    val partitionsInContext = ArrayBuffer.empty[TopicIdPartition]
+    context.foreachPartition { (tp, _) =>
       partitionsInContext += tp
     }
     assertEquals(partitions, partitionsInContext.toSeq)
   }
 }
+
+object FetchSessionTest {
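+  // Supplies all four boolean combinations to the @MethodSource-parameterized testToForgetPartitions.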
+  def idUsageCombinations: java.util.stream.Stream[Arguments] = {
+    val data = new java.util.ArrayList[Arguments]()
+    for (startsWithTopicIds <- Array(java.lang.Boolean.TRUE, java.lang.Boolean.FALSE))
+      for (endsWithTopicIds <- Array(java.lang.Boolean.TRUE, java.lang.Boolean.FALSE))
+        data.add(Arguments.of(startsWithTopicIds, endsWithTopicIds))
+    data.stream()
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index fb77d5b..777857f 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -69,7 +69,7 @@ import org.apache.kafka.common.requests.{FetchMetadata => JFetchMetadata, _}
 import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourceType}
 import org.apache.kafka.common.security.auth.{KafkaPrincipal, KafkaPrincipalSerde, SecurityProtocol}
 import org.apache.kafka.common.utils.{ProducerIdAndEpoch, SecurityUtils, Utils}
-import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicPartition, Uuid}
+import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer}
 import org.easymock.EasyMock._
 import org.easymock.{Capture, EasyMock, IAnswer}
@@ -2380,7 +2380,8 @@ class KafkaApisTest {
    */
   @Test
   def testFetchRequestV9WithNoLogConfig(): Unit = {
-    val tp = new TopicPartition("foo", 0)
+    val tidp = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0))
+    val tp = tidp.topicPartition
     addTopicToMetadataCache(tp.topic, numPartitions = 1)
     val hw = 3
     val timestamp = 1000
@@ -2388,38 +2389,39 @@ class KafkaApisTest {
     expect(replicaManager.getLogConfig(EasyMock.eq(tp))).andReturn(None)
 
     replicaManager.fetchMessages(anyLong, anyInt, anyInt, anyInt, anyBoolean,
-      anyObject[Seq[(TopicPartition, FetchRequest.PartitionData)]], anyObject[util.Map[String, Uuid]](), anyObject[ReplicaQuota],
-      anyObject[Seq[(TopicPartition, FetchPartitionData)] => Unit](), anyObject[IsolationLevel],
+      anyObject[Seq[(TopicIdPartition, FetchRequest.PartitionData)]], anyObject[ReplicaQuota],
+      anyObject[Seq[(TopicIdPartition, FetchPartitionData)] => Unit](), anyObject[IsolationLevel],
       anyObject[Option[ClientMetadata]])
     expectLastCall[Unit].andAnswer(new IAnswer[Unit] {
       def answer: Unit = {
-        val callback = getCurrentArguments.apply(8)
-          .asInstanceOf[Seq[(TopicPartition, FetchPartitionData)] => Unit]
+        val callback = getCurrentArguments.apply(7)
+          .asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
         val records = MemoryRecords.withRecords(CompressionType.NONE,
           new SimpleRecord(timestamp, "foo".getBytes(StandardCharsets.UTF_8)))
-        callback(Seq(tp -> FetchPartitionData(Errors.NONE, hw, 0, records,
+        callback(Seq(tidp -> FetchPartitionData(Errors.NONE, hw, 0, records,
           None, None, None, Option.empty, isReassignmentFetch = false)))
       }
     })
 
-    val fetchData = Map(tp -> new FetchRequest.PartitionData(0, 0, 1000,
+    val fetchData = Map(tidp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
+      Optional.empty())).asJava
+    val fetchDataBuilder = Map(tp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000,
       Optional.empty())).asJava
     val fetchMetadata = new JFetchMetadata(0, 0)
     val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100),
-      fetchMetadata, fetchData, false, metadataCache.topicNamesToIds(), false)
+      fetchMetadata, fetchData, false, false)
     expect(fetchManager.newContext(
       anyObject[Short],
       anyObject[JFetchMetadata],
       anyObject[Boolean],
-      anyObject[util.Map[TopicPartition, FetchRequest.PartitionData]],
-      anyObject[util.List[TopicPartition]],
-      anyObject[util.Map[String, Uuid]])).andReturn(fetchContext)
+      anyObject[util.Map[TopicIdPartition, FetchRequest.PartitionData]],
+      anyObject[util.List[TopicIdPartition]],
+      anyObject[util.Map[Uuid, String]])).andReturn(fetchContext)
 
     EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
       anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0)
 
-    val fetchRequest = new FetchRequest.Builder(9, 9, -1, 100, 0, fetchData,
-      metadataCache.topicNamesToIds())
+    val fetchRequest = new FetchRequest.Builder(9, 9, -1, 100, 0, fetchDataBuilder)
       .build()
     val request = buildRequest(fetchRequest)
     val capturedResponse = expectNoThrottling(request)
@@ -2440,6 +2442,62 @@ class KafkaApisTest {
     assertNull(partitionData.abortedTransactions)
   }
 
+  /**
+   * Verifies that partitions with unknown topic ID errors are added to the erroneous set and that no attempt is made to fetch them.
+   */
+  @ParameterizedTest
+  @ValueSource(ints = Array(-1, 0))
+  def testFetchRequestErroneousPartitions(replicaId: Int): Unit = {
+    val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0))
+    val unresolvedFoo = new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition))
+
+    addTopicToMetadataCache(foo.topic, 1, topicId = foo.topicId)
+
+    // We will never return a logConfig when the topic name is null. This is ok since we won't have any records to convert.
+    expect(replicaManager.getLogConfig(EasyMock.eq(unresolvedFoo.topicPartition))).andReturn(None)
+
+    // Simulate unknown topic ID in the context
+    val fetchData = Map(unresolvedFoo ->
+      new FetchRequest.PartitionData(foo.topicId, 0, 0, 1000, Optional.empty())).asJava
+    val fetchDataBuilder = Map(foo.topicPartition -> new FetchRequest.PartitionData(foo.topicId, 0, 0, 1000,
+      Optional.empty())).asJava
+    val fetchMetadata = new JFetchMetadata(0, 0)
+    val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100),
+      fetchMetadata, fetchData, true, replicaId >= 0)
+    // We expect to have the resolved partition, but we will simulate an unknown one with the fetchContext we return.
+    expect(fetchManager.newContext(
+      ApiKeys.FETCH.latestVersion,
+      fetchMetadata,
+      replicaId >= 0,
+      Collections.singletonMap(foo, new FetchRequest.PartitionData(foo.topicId, 0, 0, 1000, Optional.empty())),
+      Collections.emptyList[TopicIdPartition],
+      metadataCache.topicIdsToNames())
+    ).andReturn(fetchContext)
+
+    EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs(
+      anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0)
+
+    // If replicaId is -1 we will build a consumer request. Any non-negative replicaId will build a follower request.
+    val fetchRequest = new FetchRequest.Builder(ApiKeys.FETCH.latestVersion, ApiKeys.FETCH.latestVersion,
+      replicaId, 100, 0, fetchDataBuilder).metadata(fetchMetadata).build()
+    val request = buildRequest(fetchRequest)
+    val capturedResponse = expectNoThrottling(request)
+
+    EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, fetchManager)
+    createKafkaApis().handleFetchRequest(request)
+
+    val response = capturedResponse.getValue.asInstanceOf[FetchResponse]
+    val responseData = response.responseData(metadataCache.topicIdsToNames(), ApiKeys.FETCH.latestVersion)
+    assertTrue(responseData.containsKey(foo.topicPartition))
+
+    val partitionData = responseData.get(foo.topicPartition)
+    assertEquals(Errors.UNKNOWN_TOPIC_ID.code, partitionData.errorCode)
+    assertEquals(-1, partitionData.highWatermark)
+    assertEquals(-1, partitionData.lastStableOffset)
+    assertEquals(-1, partitionData.logStartOffset)
+    assertEquals(MemoryRecords.EMPTY, FetchResponse.recordsOrFail(partitionData))
+  }
+
   @Test
   def testJoinGroupProtocolsOrder(): Unit = {
     val protocols = List(
@@ -2948,39 +3006,41 @@ class KafkaApisTest {
   private def assertReassignmentAndReplicationBytesOutPerSec(isReassigning: Boolean): Unit = {
     val leaderEpoch = 0
     val tp0 = new TopicPartition("tp", 0)
+    val topicId = Uuid.randomUuid()
+    val tidp0 = new TopicIdPartition(topicId, tp0)
 
-    setupBasicMetadataCache(tp0.topic, numPartitions = 1, 1, Uuid.randomUuid())
+    setupBasicMetadataCache(tp0.topic, numPartitions = 1, 1, topicId)
     val hw = 3
 
-    val fetchData = Collections.singletonMap(tp0, new FetchRequest.PartitionData(0, 0, Int.MaxValue, Optional.of(leaderEpoch)))
+    val fetchDataBuilder = Collections.singletonMap(tp0, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.of(leaderEpoch)))
+    val fetchData = Collections.singletonMap(tidp0, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.of(leaderEpoch)))
     val fetchFromFollower = buildRequest(new FetchRequest.Builder(
-      ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(), 1, 1000, 0, fetchData,
-        metadataCache.topicNamesToIds()).build())
+      ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(), 1, 1000, 0, fetchDataBuilder).build())
 
     val records = MemoryRecords.withRecords(CompressionType.NONE,
       new SimpleRecord(1000, "foo".getBytes(StandardCharsets.UTF_8)))
     replicaManager.fetchMessages(anyLong, anyInt, anyInt, anyInt, anyBoolean,
-      anyObject[Seq[(TopicPartition, FetchRequest.PartitionData)]], anyObject[util.Map[String, Uuid]](), anyObject[ReplicaQuota],
-      anyObject[Seq[(TopicPartition, FetchPartitionData)] => Unit](), anyObject[IsolationLevel],
+      anyObject[Seq[(TopicIdPartition, FetchRequest.PartitionData)]], anyObject[ReplicaQuota],
+      anyObject[Seq[(TopicIdPartition, FetchPartitionData)] => Unit](), anyObject[IsolationLevel],
       anyObject[Option[ClientMetadata]])
     expectLastCall[Unit].andAnswer(new IAnswer[Unit] {
       def answer: Unit = {
-        val callback = getCurrentArguments.apply(8).asInstanceOf[Seq[(TopicPartition, FetchPartitionData)] => Unit]
-        callback(Seq(tp0 -> FetchPartitionData(Errors.NONE, hw, 0, records,
+        val callback = getCurrentArguments.apply(7).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]
+        callback(Seq(tidp0 -> FetchPartitionData(Errors.NONE, hw, 0, records,
           None, None, None, Option.empty, isReassignmentFetch = isReassigning)))
       }
     })
 
     val fetchMetadata = new JFetchMetadata(0, 0)
     val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100),
-      fetchMetadata, fetchData, true, metadataCache.topicNamesToIds(), true)
+      fetchMetadata, fetchData, true, true)
     expect(fetchManager.newContext(
       anyObject[Short],
       anyObject[JFetchMetadata],
       anyObject[Boolean],
-      anyObject[util.Map[TopicPartition, FetchRequest.PartitionData]],
-      anyObject[util.List[TopicPartition]],
-      anyObject[util.Map[String, Uuid]])).andReturn(fetchContext)
+      anyObject[util.Map[TopicIdPartition, FetchRequest.PartitionData]],
+      anyObject[util.List[TopicIdPartition]],
+      anyObject[util.Map[Uuid, String]])).andReturn(fetchContext)
 
     expect(replicaQuotaManager.record(anyLong()))
     expect(replicaManager.getLogConfig(EasyMock.eq(tp0))).andReturn(None)
@@ -3468,7 +3528,7 @@ class KafkaApisTest {
   }
 
   private def addTopicToMetadataCache(topic: String, numPartitions: Int, numBrokers: Int = 1, topicId: Uuid = Uuid.ZERO_UUID): Unit = {
-    val updateMetadataRequest = createBasicMetadataRequest(topic, numPartitions, 0, numBrokers)
+    val updateMetadataRequest = createBasicMetadataRequest(topic, numPartitions, 0, numBrokers, topicId)
     MetadataCacheTest.updateCache(metadataCache, updateMetadataRequest)
   }
 
@@ -3530,11 +3590,11 @@ class KafkaApisTest {
   def testSizeOfThrottledPartitions(): Unit = {
     val topicNames = new util.HashMap[Uuid, String]
     val topicIds = new util.HashMap[String, Uuid]()
-    def fetchResponse(data: Map[TopicPartition, String]): FetchResponse = {
-      val responseData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData](
+    def fetchResponse(data: Map[TopicIdPartition, String]): FetchResponse = {
+      val responseData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData](
         data.map { case (tp, raw) =>
           tp -> new FetchResponseData.PartitionData()
-            .setPartitionIndex(tp.partition)
+            .setPartitionIndex(tp.topicPartition.partition)
             .setHighWatermark(105)
             .setLastStableOffset(105)
             .setLogStartOffset(0)
@@ -3542,25 +3602,25 @@ class KafkaApisTest {
       }.toMap.asJava)
 
       data.foreach{case (tp, _) =>
-        val id = Uuid.randomUuid()
-        topicIds.put(tp.topic(), id)
-        topicNames.put(id, tp.topic())
+        topicIds.put(tp.topicPartition.topic, tp.topicId)
+        topicNames.put(tp.topicId, tp.topicPartition.topic)
       }
-      FetchResponse.of(Errors.NONE, 100, 100, responseData, topicIds)
+      FetchResponse.of(Errors.NONE, 100, 100, responseData)
     }
 
-    val throttledPartition = new TopicPartition("throttledData", 0)
+    val throttledPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("throttledData", 0))
     val throttledData = Map(throttledPartition -> "throttledData")
     val expectedSize = FetchResponse.sizeOf(FetchResponseData.HIGHEST_SUPPORTED_VERSION,
-      fetchResponse(throttledData).responseData(topicNames, FetchResponseData.HIGHEST_SUPPORTED_VERSION).entrySet.iterator, topicIds)
+      fetchResponse(throttledData).responseData(topicNames, FetchResponseData.HIGHEST_SUPPORTED_VERSION).entrySet.asScala.map(entry =>
+        (new TopicIdPartition(Uuid.ZERO_UUID, entry.getKey), entry.getValue)).toMap.asJava.entrySet.iterator)
 
-    val response = fetchResponse(throttledData ++ Map(new TopicPartition("nonThrottledData", 0) -> "nonThrottledData"))
+    val response = fetchResponse(throttledData ++ Map(new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("nonThrottledData", 0)) -> "nonThrottledData"))
 
     val quota = Mockito.mock(classOf[ReplicationQuotaManager])
     Mockito.when(quota.isThrottled(ArgumentMatchers.any(classOf[TopicPartition])))
-      .thenAnswer(invocation => throttledPartition == invocation.getArgument(0).asInstanceOf[TopicPartition])
+      .thenAnswer(invocation => throttledPartition.topicPartition == invocation.getArgument(0).asInstanceOf[TopicPartition])
 
-    assertEquals(expectedSize, KafkaApis.sizeOfThrottledPartitions(FetchResponseData.HIGHEST_SUPPORTED_VERSION, response, quota, topicIds))
+    assertEquals(expectedSize, KafkaApis.sizeOfThrottledPartitions(FetchResponseData.HIGHEST_SUPPORTED_VERSION, response, quota))
   }
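
Note that KafkaApis.sizeOfThrottledPartitions above no longer takes a topic-name-to-ID map:
the response data is keyed by TopicIdPartition, and the quota only ever consults the
TopicPartition half of each key. A minimal sketch of the updated call, as exercised above:

    // The version, the response, and the quota are all that is needed now; each
    // TopicIdPartition key in the response already carries its own topic ID.
    val throttledBytes = KafkaApis.sizeOfThrottledPartitions(
      FetchResponseData.HIGHEST_SUPPORTED_VERSION, response, quota)
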
 
   @Test
diff --git a/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala b/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala
index 78f03d9..be62211 100755
--- a/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala
+++ b/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala
@@ -17,7 +17,7 @@
 
 package kafka.server
 
-import kafka.log.{ClientRecordDeletion, UnifiedLog, LogSegment}
+import kafka.log.{ClientRecordDeletion, LogSegment, UnifiedLog}
 import kafka.utils.{MockTime, TestUtils}
 import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic}
 import org.apache.kafka.common.message.ListOffsetsResponseData.{ListOffsetsPartitionResponse, ListOffsetsTopicResponse}
@@ -135,6 +135,7 @@ class LogOffsetTest extends BaseRequestTest {
 
     val topicIds = getTopicIds().asJava
     val topicNames = topicIds.asScala.map(_.swap).asJava
+    val topicId = topicIds.get(topic)
 
     for (_ <- 0 until 20)
       log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0)
@@ -152,8 +153,8 @@ class LogOffsetTest extends BaseRequestTest {
 
     // try to fetch using latest offset
     val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 1,
-      Map(topicPartition -> new FetchRequest.PartitionData(consumerOffsets.head, FetchRequest.INVALID_LOG_START_OFFSET,
-        300 * 1024, Optional.empty())).asJava, topicIds).build()
+      Map(topicPartition -> new FetchRequest.PartitionData(topicId, consumerOffsets.head, FetchRequest.INVALID_LOG_START_OFFSET,
+        300 * 1024, Optional.empty())).asJava).build()
     val fetchResponse = sendFetchRequest(fetchRequest)
     assertFalse(FetchResponse.recordsOrFail(fetchResponse.responseData(topicNames, ApiKeys.FETCH.latestVersion).get(topicPartition)).batches.iterator.hasNext)
   }
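
This LogOffsetTest change shows the new request-building pattern: FetchRequest.Builder.forConsumer
no longer accepts a topic-name-to-ID map, and each FetchRequest.PartitionData instead leads with
the topic's Uuid (Uuid.ZERO_UUID when no ID is known). A minimal sketch with hypothetical
topicId, topicPartition, fetchOffset, and maxBytes values:

    // Hypothetical values for illustration; the topic ID is now the first PartitionData argument.
    val request = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 1,
      Map(topicPartition -> new FetchRequest.PartitionData(topicId, fetchOffset,
        FetchRequest.INVALID_LOG_START_OFFSET, maxBytes, Optional.empty())).asJava).build()
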
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
index c80be99..80d69ec 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala
@@ -31,7 +31,7 @@ import org.apache.kafka.common.message.UpdateMetadataRequestData
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.MemoryRecords
 import org.apache.kafka.common.requests.{FetchRequest, UpdateMetadataRequest}
-import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
+import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid}
 import org.easymock.EasyMock._
 import org.easymock.{Capture, CaptureType, EasyMock, IExpectationSetters}
 import org.junit.jupiter.api.Assertions._
@@ -49,6 +49,7 @@ class ReplicaAlterLogDirsThreadTest {
   private val topicId = Uuid.randomUuid()
   private val topicIds = collection.immutable.Map("topic1" -> topicId)
   private val topicNames = collection.immutable.Map(topicId -> "topic1")
+  private val tid1p0 = new TopicIdPartition(topicId, t1p0)
   private val failedPartitions = new FailedPartitions
 
   private val partitionStates = List(new UpdateMetadataRequestData.UpdateMetadataPartitionState()
@@ -132,7 +133,7 @@ class ReplicaAlterLogDirsThreadTest {
     when(futureLog.logEndOffset).thenReturn(0L)
     when(futureLog.latestEpoch).thenReturn(None)
 
-    val fencedRequestData = new FetchRequest.PartitionData(0L, 0L,
+    val fencedRequestData = new FetchRequest.PartitionData(topicId, 0L, 0L,
       config.replicaFetchMaxBytes, Optional.of(leaderEpoch - 1))
     val fencedResponseData = FetchPartitionData(
       error = Errors.FENCED_LEADER_EPOCH,
@@ -144,7 +145,7 @@ class ReplicaAlterLogDirsThreadTest {
       abortedTransactions = None,
       preferredReadReplica = None,
       isReassignmentFetch = false)
-    mockFetchFromCurrentLog(t1p0, fencedRequestData, config, replicaManager, fencedResponseData)
+    mockFetchFromCurrentLog(tid1p0, fencedRequestData, config, replicaManager, fencedResponseData)
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
     val thread = new ReplicaAlterLogDirsThread(
@@ -172,7 +173,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertEquals(Some(leaderEpoch), thread.fetchState(t1p0).map(_.currentLeaderEpoch))
     assertEquals(1, thread.partitionCount)
 
-    val requestData = new FetchRequest.PartitionData(0L, 0L,
+    val requestData = new FetchRequest.PartitionData(topicId, 0L, 0L,
       config.replicaFetchMaxBytes, Optional.of(leaderEpoch))
     val responseData = FetchPartitionData(
       error = Errors.NONE,
@@ -184,7 +185,7 @@ class ReplicaAlterLogDirsThreadTest {
       abortedTransactions = None,
       preferredReadReplica = None,
       isReassignmentFetch = false)
-    mockFetchFromCurrentLog(t1p0, requestData, config, replicaManager, responseData)
+    mockFetchFromCurrentLog(tid1p0, requestData, config, replicaManager, responseData)
 
     thread.doWork()
 
@@ -230,7 +231,7 @@ class ReplicaAlterLogDirsThreadTest {
     when(futureLog.logEndOffset).thenReturn(0L)
     when(futureLog.latestEpoch).thenReturn(None)
 
-    val requestData = new FetchRequest.PartitionData(0L, 0L,
+    val requestData = new FetchRequest.PartitionData(topicId, 0L, 0L,
       config.replicaFetchMaxBytes, Optional.of(leaderEpoch))
     val responseData = FetchPartitionData(
       error = Errors.NONE,
@@ -242,7 +243,7 @@ class ReplicaAlterLogDirsThreadTest {
       abortedTransactions = None,
       preferredReadReplica = None,
       isReassignmentFetch = false)
-    mockFetchFromCurrentLog(t1p0, requestData, config, replicaManager, responseData)
+    mockFetchFromCurrentLog(tid1p0, requestData, config, replicaManager, responseData)
 
     val endPoint = new BrokerEndPoint(0, "localhost", 1000)
     val thread = new ReplicaAlterLogDirsThread(
@@ -264,27 +265,26 @@ class ReplicaAlterLogDirsThreadTest {
     assertEquals(0, thread.partitionCount)
   }
 
-  private def mockFetchFromCurrentLog(topicPartition: TopicPartition,
+  private def mockFetchFromCurrentLog(topicIdPartition: TopicIdPartition,
                                       requestData: FetchRequest.PartitionData,
                                       config: KafkaConfig,
                                       replicaManager: ReplicaManager,
                                       responseData: FetchPartitionData): Unit = {
-    val callbackCaptor: ArgumentCaptor[Seq[(TopicPartition, FetchPartitionData)] => Unit] =
-      ArgumentCaptor.forClass(classOf[Seq[(TopicPartition, FetchPartitionData)] => Unit])
+    val callbackCaptor: ArgumentCaptor[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] =
+      ArgumentCaptor.forClass(classOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit])
     when(replicaManager.fetchMessages(
       timeout = ArgumentMatchers.eq(0L),
       replicaId = ArgumentMatchers.eq(Request.FutureLocalReplicaId),
       fetchMinBytes = ArgumentMatchers.eq(0),
       fetchMaxBytes = ArgumentMatchers.eq(config.replicaFetchResponseMaxBytes),
       hardMaxBytesLimit = ArgumentMatchers.eq(false),
-      fetchInfos = ArgumentMatchers.eq(Seq(topicPartition -> requestData)),
-      topicIds = ArgumentMatchers.eq(topicIds.asJava),
+      fetchInfos = ArgumentMatchers.eq(Seq(topicIdPartition -> requestData)),
       quota = ArgumentMatchers.eq(UnboundedQuota),
       responseCallback = callbackCaptor.capture(),
       isolationLevel = ArgumentMatchers.eq(IsolationLevel.READ_UNCOMMITTED),
       clientMetadata = ArgumentMatchers.eq(None)
     )).thenAnswer(_ => {
-      callbackCaptor.getValue.apply(Seq((topicPartition, responseData)))
+      callbackCaptor.getValue.apply(Seq((topicIdPartition, responseData)))
     })
   }
 
@@ -446,7 +446,7 @@ class ReplicaAlterLogDirsThreadTest {
     val partitionT1p0: Partition = createMock(classOf[Partition])
     val partitionT1p1: Partition = createMock(classOf[Partition])
     val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager])
-    val responseCallback: Capture[Seq[(TopicPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
+    val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
 
     val partitionT1p0Id = 0
     val partitionT1p1Id = 1
@@ -536,7 +536,7 @@ class ReplicaAlterLogDirsThreadTest {
     val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog])
     val partition: Partition = createMock(classOf[Partition])
     val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager])
-    val responseCallback: Capture[Seq[(TopicPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
+    val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
 
     val partitionId = 0
     val leaderEpoch = 5
@@ -622,7 +622,7 @@ class ReplicaAlterLogDirsThreadTest {
     val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog])
     val partition: Partition = createMock(classOf[Partition])
     val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager])
-    val responseCallback: Capture[Seq[(TopicPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
+    val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
 
     val initialFetchOffset = 100
 
@@ -676,7 +676,7 @@ class ReplicaAlterLogDirsThreadTest {
     val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog])
     val partition: Partition = createMock(classOf[Partition])
     val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager])
-    val responseCallback: Capture[Seq[(TopicPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
+    val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
 
     val partitionId = 0
     val futureReplicaLeaderEpoch = 1
@@ -720,11 +720,10 @@ class ReplicaAlterLogDirsThreadTest {
       EasyMock.anyObject(),
       EasyMock.anyObject(),
       EasyMock.anyObject(),
-      EasyMock.anyObject(),
       EasyMock.capture(responseCallback),
       EasyMock.anyObject(),
       EasyMock.anyObject())
-    ).andAnswer(() => responseCallback.getValue.apply(Seq.empty[(TopicPartition, FetchPartitionData)])).anyTimes()
+    ).andAnswer(() => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)])).anyTimes()
 
     replay(replicaManager, logManager, quotaManager, partition, log, futureLog)
 
@@ -766,7 +765,7 @@ class ReplicaAlterLogDirsThreadTest {
     val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog])
     val partition: Partition = createMock(classOf[Partition])
     val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager])
-    val responseCallback: Capture[Seq[(TopicPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
+    val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]  = EasyMock.newCapture()
 
     val partitionId = 0
     val leaderEpoch = 5
@@ -864,7 +863,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertEquals(0, request.minBytes)
     val fetchInfos = request.fetchData(topicNames.asJava).asScala.toSeq
     assertEquals(1, fetchInfos.length)
-    assertEquals(t1p0, fetchInfos.head._1, "Expected fetch request for first partition")
+    assertEquals(t1p0, fetchInfos.head._1.topicPartition, "Expected fetch request for first partition")
     assertEquals(150, fetchInfos.head._2.fetchOffset)
   }
 
@@ -915,7 +914,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertFalse(partitionsWithError.nonEmpty)
     val fetchInfos = fetchRequest.fetchRequest.build().fetchData(topicNames.asJava).asScala.toSeq
     assertEquals(1, fetchInfos.length)
-    assertEquals(t1p0, fetchInfos.head._1, "Expected fetch request for non-truncating partition")
+    assertEquals(t1p0, fetchInfos.head._1.topicPartition, "Expected fetch request for non-truncating partition")
     assertEquals(150, fetchInfos.head._2.fetchOffset)
 
     // one partition is ready and one is delayed
@@ -929,7 +928,7 @@ class ReplicaAlterLogDirsThreadTest {
     assertFalse(partitionsWithError2.nonEmpty)
     val fetchInfos2 = fetchRequest2.fetchRequest.build().fetchData(topicNames.asJava).asScala.toSeq
     assertEquals(1, fetchInfos2.length)
-    assertEquals(t1p0, fetchInfos2.head._1, "Expected fetch request for non-delayed partition")
+    assertEquals(t1p0, fetchInfos2.head._1.topicPartition, "Expected fetch request for non-delayed partition")
     assertEquals(140, fetchInfos2.head._2.fetchOffset)
 
     // both partitions are delayed
@@ -955,7 +954,7 @@ class ReplicaAlterLogDirsThreadTest {
   }
 
   def stubWithFetchMessages(logT1p0: UnifiedLog, logT1p1: UnifiedLog, futureLog: UnifiedLog, partition: Partition, replicaManager: ReplicaManager,
-                            responseCallback: Capture[Seq[(TopicPartition, FetchPartitionData)] => Unit]): IExpectationSetters[Unit] = {
+                            responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]): IExpectationSetters[Unit] = {
     stub(logT1p0, logT1p1, futureLog, partition, replicaManager)
     expect(replicaManager.fetchMessages(
       EasyMock.anyLong(),
@@ -965,10 +964,9 @@ class ReplicaAlterLogDirsThreadTest {
       EasyMock.anyObject(),
       EasyMock.anyObject(),
       EasyMock.anyObject(),
-      EasyMock.anyObject(),
       EasyMock.capture(responseCallback),
       EasyMock.anyObject(),
       EasyMock.anyObject())
-    ).andAnswer(() => responseCallback.getValue.apply(Seq.empty[(TopicPartition, FetchPartitionData)])).anyTimes()
+    ).andAnswer(() => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)])).anyTimes()
   }
 }
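
Throughout this file, the mocked ReplicaManager.fetchMessages drops its standalone topicIds
argument, and the response callback is now typed over TopicIdPartition. A minimal sketch of the
reshaped callback, assuming fetch results shaped like the captures above:

    // Results arrive keyed by TopicIdPartition, so no separate name-to-ID map is threaded through.
    val responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit = { results =>
      results.foreach { case (tidp, data) => assertEquals(Errors.NONE, data.error) }
    }
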
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
index 0472770..bbd9330 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala
@@ -19,11 +19,12 @@ package kafka.server
 import kafka.api.{ApiVersion, KAFKA_2_6_IV0}
 import kafka.cluster.{BrokerEndPoint, Partition}
 import kafka.log.{LogAppendInfo, LogManager, UnifiedLog}
+import kafka.server.AbstractFetcherThread.ResultWithPartitions
 import kafka.server.QuotaFactory.UnboundedQuota
 import kafka.server.epoch.util.ReplicaFetcherMockBlockingSend
 import kafka.server.metadata.ZkMetadataCache
 import kafka.utils.TestUtils
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.message.{FetchResponseData, UpdateMetadataRequestData}
 import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
@@ -32,7 +33,7 @@ import org.apache.kafka.common.protocol.Errors._
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET}
-import org.apache.kafka.common.requests.UpdateMetadataRequest
+import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, UpdateMetadataRequest}
 import org.apache.kafka.common.utils.SystemTime
 import org.easymock.EasyMock._
 import org.easymock.{Capture, CaptureType}
@@ -40,7 +41,8 @@ import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.{AfterEach, Test}
 
 import java.nio.charset.StandardCharsets
-import java.util.Collections
+import java.util
+import java.util.{Collections, Optional}
 import scala.collection.{Map, mutable}
 import scala.jdk.CollectionConverters._
 
@@ -949,6 +951,77 @@ class ReplicaFetcherThreadTest {
     assertProcessPartitionDataWhen(isReassigning = false)
   }
 
+  @Test
+  def testBuildFetch(): Unit = {
+    val tid1p0 = new TopicIdPartition(topicId1, t1p0)
+    val tid1p1 = new TopicIdPartition(topicId1, t1p1)
+    val tid2p1 = new TopicIdPartition(topicId2, t2p1)
+
+    val props = TestUtils.createBrokerConfig(1, "localhost:1234")
+    val config = KafkaConfig.fromProps(props)
+    val replicaManager: ReplicaManager = mock(classOf[ReplicaManager])
+    val mockBlockingSend: BlockingSend = createMock(classOf[BlockingSend])
+    val replicaQuota: ReplicaQuota = createNiceMock(classOf[ReplicaQuota])
+    val log: UnifiedLog = createNiceMock(classOf[UnifiedLog])
+
+    expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats]))
+    expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes()
+    expect(replicaQuota.isThrottled(anyObject(classOf[TopicPartition]))).andReturn(false).anyTimes()
+    expect(log.logStartOffset).andReturn(0).anyTimes()
+    replay(log, replicaQuota, replicaManager)
+
+    val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions,
+      replicaManager, new Metrics(), new SystemTime(), replicaQuota, Some(mockBlockingSend))
+
+    val leaderEpoch = 1
+
+    val partitionMap = Map(
+        t1p0 -> PartitionFetchState(Some(topicId1), 150, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
+        t1p1 -> PartitionFetchState(Some(topicId1), 155, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
+        t2p1 -> PartitionFetchState(Some(topicId2), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None))
+
+    val ResultWithPartitions(fetchRequestOpt, _) = thread.buildFetch(partitionMap)
+
+    assertTrue(fetchRequestOpt.isDefined)
+    val fetchRequestBuilder = fetchRequestOpt.get.fetchRequest
+
+    val partitionDataMap = partitionMap.map { case (tp, state) =>
+      (tp, new FetchRequest.PartitionData(state.topicId.get, state.fetchOffset, 0L,
+        config.replicaFetchMaxBytes, Optional.of(state.currentLeaderEpoch), Optional.empty()))
+    }
+
+    assertEquals(partitionDataMap.asJava, fetchRequestBuilder.fetchData())
+    assertEquals(0, fetchRequestBuilder.replaced().size)
+    assertEquals(0, fetchRequestBuilder.removed().size)
+
+    val responseData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    responseData.put(tid1p0, new FetchResponseData.PartitionData())
+    responseData.put(tid1p1, new FetchResponseData.PartitionData())
+    responseData.put(tid2p1, new FetchResponseData.PartitionData())
+    val fetchResponse = FetchResponse.of(Errors.NONE, 0, 123, responseData)
+
+    thread.fetchSessionHandler.handleResponse(fetchResponse, ApiKeys.FETCH.latestVersion())
+
+    // Remove t1p0, change the ID for t2p1, and keep t1p1 the same
+    val newTopicId = Uuid.randomUuid()
+    val partitionMap2 = Map(
+      t1p1 -> PartitionFetchState(Some(topicId1), 155, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None),
+      t2p1 -> PartitionFetchState(Some(newTopicId), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None))
+    val ResultWithPartitions(fetchRequestOpt2, _) = thread.buildFetch(partitionMap2)
+
+    // Since t1p1 didn't change, it is omitted from the incremental request data
+    val partitionDataMap2 = partitionMap2.drop(1).map { case (tp, state) =>
+      (tp, new FetchRequest.PartitionData(state.topicId.get, state.fetchOffset, 0L,
+        config.replicaFetchMaxBytes, Optional.of(state.currentLeaderEpoch), Optional.empty()))
+    }
+
+    assertTrue(fetchRequestOpt2.isDefined)
+    val fetchRequestBuilder2 = fetchRequestOpt2.get.fetchRequest
+    assertEquals(partitionDataMap2.asJava, fetchRequestBuilder2.fetchData())
+    assertEquals(Collections.singletonList(tid2p1), fetchRequestBuilder2.replaced())
+    assertEquals(Collections.singletonList(tid1p0), fetchRequestBuilder2.removed())
+  }
+
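
The second buildFetch call above is the heart of the session re-evaluation: an unchanged
partition (t1p1) is omitted from the incremental request, a partition whose topic was deleted
and recreated (t2p1) is resent under its new ID and reported via replaced(), and a dropped
partition (t1p0) is reported via removed(). In sketch form, following the assertions above:

    val incrementalData = fetchRequestBuilder2.fetchData()  // only t2p1, under newTopicId
    val replaced = fetchRequestBuilder2.replaced()          // List(tid2p1): the old ID to forget
    val removed = fetchRequestBuilder2.removed()            // List(tid1p0): no longer fetched
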
   private def newOffsetForLeaderPartitionResult(
    tp: TopicPartition,
    leaderEpoch: Int,
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala
index 714920d..08c0acf 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala
@@ -19,7 +19,7 @@ package kafka.server
 import java.net.InetAddress
 import java.util
 import java.util.concurrent.{CompletableFuture, Executors, LinkedBlockingQueue, TimeUnit}
-import java.util.{Collections, Optional, Properties}
+import java.util.{Optional, Properties}
 
 import kafka.api.LeaderAndIsr
 import kafka.log.{AppendOrigin, LogConfig}
@@ -34,7 +34,7 @@ import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata
 import org.apache.kafka.common.requests.{FetchRequest, ProduceResponse}
 import org.apache.kafka.common.security.auth.KafkaPrincipal
 import org.apache.kafka.common.utils.Time
-import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid}
+import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.image.{MetadataDelta, MetadataImage}
 import org.apache.kafka.metadata.PartitionRegistration
 import org.junit.jupiter.api.Assertions._
@@ -81,6 +81,7 @@ class ReplicaManagerConcurrencyTest {
 
     val topicModel = new TopicModel(Uuid.randomUuid(), "foo", Map(0 -> initialPartitionRegistration))
     val topicPartition = new TopicPartition(topicModel.name, 0)
+    val topicIdPartition = new TopicIdPartition(topicModel.topicId, topicPartition)
     val controller = new ControllerModel(topicModel, channel, replicaManager)
 
     submit(new Clock(time))
@@ -111,8 +112,7 @@ class ReplicaManagerConcurrencyTest {
     val fetcher = new FetcherModel(
       clientId = s"replica-$remoteId",
       replicaId = remoteId,
-      topicModel.topicId,
-      topicPartition,
+      topicIdPartition,
       replicaManager
     )
 
@@ -183,8 +183,7 @@ class ReplicaManagerConcurrencyTest {
   private class FetcherModel(
     clientId: String,
     replicaId: Int,
-    topicId: Uuid,
-    topicPartition: TopicPartition,
+    topicIdPartition: TopicIdPartition,
     replicaManager: ReplicaManager
   ) extends ShutdownableThread(name = clientId, isInterruptible = false) {
     private val random = new Random()
@@ -201,6 +200,7 @@ class ReplicaManagerConcurrencyTest {
 
     override def doWork(): Unit = {
       val partitionData = new FetchRequest.PartitionData(
+        topicIdPartition.topicId,
         fetchOffset,
         -1,
         65536,
@@ -209,11 +209,11 @@ class ReplicaManagerConcurrencyTest {
       )
 
       val future = new CompletableFuture[FetchPartitionData]()
-      def fetchCallback(results: collection.Seq[(TopicPartition, FetchPartitionData)]): Unit = {
+      def fetchCallback(results: collection.Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
         try {
           assertEquals(1, results.size)
-          val (topicPartition, result) = results.head
-          assertEquals(this.topicPartition, topicPartition)
+          val (topicIdPartition, result) = results.head
+          assertEquals(this.topicIdPartition, topicIdPartition)
           assertEquals(Errors.NONE, result.error)
           future.complete(result)
         } catch {
@@ -227,8 +227,7 @@ class ReplicaManagerConcurrencyTest {
         fetchMinBytes = 1,
         fetchMaxBytes = 1024 * 1024,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(topicPartition -> partitionData),
-        topicIds = Collections.singletonMap(topicPartition.topic, topicId),
+        fetchInfos = Seq(topicIdPartition -> partitionData),
         quota = QuotaFactory.UnboundedQuota,
         responseCallback = fetchCallback,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
index 9598430..8be7810 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala
@@ -22,7 +22,7 @@ import java.util.{Collections, Optional, Properties}
 import kafka.cluster.Partition
 import kafka.log.{UnifiedLog, LogManager, LogOffsetSnapshot}
 import kafka.utils._
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord}
 import org.apache.kafka.common.requests.FetchRequest.PartitionData
@@ -43,9 +43,11 @@ class ReplicaManagerQuotasTest {
   val topicPartition2 = new TopicPartition("test-topic", 2)
   val topicId = Uuid.randomUuid()
   val topicIds = Collections.singletonMap("test-topic", topicId)
+  val topicIdPartition1 = new TopicIdPartition(topicId, topicPartition1)
+  val topicIdPartition2 = new TopicIdPartition(topicId, topicPartition2)
   val fetchInfo = Seq(
-    topicPartition1 -> new PartitionData(0, 0, 100, Optional.empty()),
-    topicPartition2 -> new PartitionData(0, 0, 100, Optional.empty()))
+    topicIdPartition1 -> new PartitionData(Uuid.ZERO_UUID, 0, 0, 100, Optional.empty()),
+    topicIdPartition2 -> new PartitionData(Uuid.ZERO_UUID, 0, 0, 100, Optional.empty()))
   var quotaManager: QuotaManagers = _
   var replicaManager: ReplicaManager = _
 
@@ -66,13 +68,12 @@ class ReplicaManagerQuotasTest {
       fetchMaxBytes = Int.MaxValue,
       hardMaxBytesLimit = false,
       readPartitionInfo = fetchInfo,
-      topicIds = topicIds,
       quota = quota,
       clientMetadata = None)
-    assertEquals(1, fetch.find(_._1 == topicPartition1).get._2.info.records.batches.asScala.size,
+    assertEquals(1, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with only one throttled, we should get the first")
 
-    assertEquals(0, fetch.find(_._1 == topicPartition2).get._2.info.records.batches.asScala.size,
+    assertEquals(0, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size,
       "But we shouldn't get the second")
   }
 
@@ -93,12 +94,11 @@ class ReplicaManagerQuotasTest {
       fetchMaxBytes = Int.MaxValue,
       hardMaxBytesLimit = false,
       readPartitionInfo = fetchInfo,
-      topicIds = topicIds,
       quota = quota,
       clientMetadata = None)
-    assertEquals(0, fetch.find(_._1 == topicPartition1).get._2.info.records.batches.asScala.size,
+    assertEquals(0, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with both throttled, we should get no messages")
-    assertEquals(0, fetch.find(_._1 == topicPartition2).get._2.info.records.batches.asScala.size,
+    assertEquals(0, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size,
       "Given two partitions, with both throttled, we should get no messages")
   }
 
@@ -119,12 +119,11 @@ class ReplicaManagerQuotasTest {
       fetchMaxBytes = Int.MaxValue,
       hardMaxBytesLimit = false,
       readPartitionInfo = fetchInfo,
-      topicIds = topicIds,
       quota = quota,
       clientMetadata = None)
-    assertEquals(1, fetch.find(_._1 == topicPartition1).get._2.info.records.batches.asScala.size,
+    assertEquals(1, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with both non-throttled, we should get both messages")
-    assertEquals(1, fetch.find(_._1 == topicPartition2).get._2.info.records.batches.asScala.size,
+    assertEquals(1, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size,
       "Given two partitions, with both non-throttled, we should get both messages")
   }
 
@@ -145,13 +144,12 @@ class ReplicaManagerQuotasTest {
       fetchMaxBytes = Int.MaxValue,
       hardMaxBytesLimit = false,
       readPartitionInfo = fetchInfo,
-      topicIds = topicIds,
       quota = quota,
       clientMetadata = None)
-    assertEquals(1, fetch.find(_._1 == topicPartition1).get._2.info.records.batches.asScala.size,
+    assertEquals(1, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size,
       "Given two partitions, with only one throttled, we should get the first")
 
-    assertEquals(1, fetch.find(_._1 == topicPartition2).get._2.info.records.batches.asScala.size,
+    assertEquals(1, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size,
       "But we should get the second too since it's throttled but in sync")
   }
 
@@ -179,9 +177,9 @@ class ReplicaManagerQuotasTest {
       EasyMock.expect(partition.getReplica(1)).andReturn(None)
       EasyMock.replay(replicaManager, partition)
 
-      val tp = new TopicPartition("t1", 0)
+      val tp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("t1", 0))
       val fetchPartitionStatus = FetchPartitionStatus(LogOffsetMetadata(messageOffset = 50L, segmentBaseOffset = 0L,
-         relativePositionInSegment = 250), new PartitionData(50, 0, 1, Optional.empty()))
+         relativePositionInSegment = 250), new PartitionData(Uuid.ZERO_UUID, 50, 0, 1, Optional.empty()))
       val fetchMetadata = FetchMetadata(fetchMinBytes = 1,
         fetchMaxBytes = 1000,
         hardMaxBytesLimit = true,
@@ -189,7 +187,6 @@ class ReplicaManagerQuotasTest {
         fetchIsolation = FetchLogEnd,
         isFromFollower = true,
         replicaId = 1,
-        topicIds = topicIds,
         fetchPartitionStatus = List((tp, fetchPartitionStatus))
       )
       new DelayedFetch(delayMs = 600, fetchMetadata = fetchMetadata, replicaManager = replicaManager,
@@ -202,7 +199,7 @@ class ReplicaManagerQuotasTest {
     assertFalse(setupDelayedFetch(isReplicaInSync = false).tryComplete(), "Out of sync replica should not complete")
   }
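
The DelayedFetch setup above reflects the same simplification: FetchMetadata no longer carries
a topicIds field, because fetchPartitionStatus is keyed by TopicIdPartition. A minimal sketch of
the slimmed-down construction, reusing the values from the test above:

    val fetchMetadata = FetchMetadata(fetchMinBytes = 1, fetchMaxBytes = 1000,
      hardMaxBytesLimit = true, fetchOnlyLeader = true, fetchIsolation = FetchLogEnd,
      isFromFollower = true, replicaId = 1,
      fetchPartitionStatus = List((tp, fetchPartitionStatus)))  // tp is a TopicIdPartition
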
 
-  def setUpMocks(fetchInfo: Seq[(TopicPartition, PartitionData)], record: SimpleRecord = this.record,
+  def setUpMocks(fetchInfo: Seq[(TopicIdPartition, PartitionData)], record: SimpleRecord = this.record,
                  bothReplicasInSync: Boolean = false): Unit = {
     val scheduler: KafkaScheduler = createNiceMock(classOf[KafkaScheduler])
 
@@ -261,7 +258,7 @@ class ReplicaManagerQuotasTest {
 
     //create the two replicas
     for ((p, _) <- fetchInfo) {
-      val partition = replicaManager.createPartition(p)
+      val partition = replicaManager.createPartition(p.topicPartition)
       log.updateHighWatermark(5)
       partition.leaderReplicaIdOpt = Some(leaderBrokerId)
       partition.setLog(log, isFutureLog = false)
diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
index 046e8fb..9f3776f 100644
--- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala
@@ -50,7 +50,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
 import org.apache.kafka.common.requests._
 import org.apache.kafka.common.security.auth.KafkaPrincipal
 import org.apache.kafka.common.utils.{Time, Utils}
-import org.apache.kafka.common.{IsolationLevel, Node, TopicPartition, Uuid}
+import org.apache.kafka.common.{IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.image.{ClientQuotasImage, ClusterImageTest, ConfigurationsImage, FeaturesImage, MetadataImage, TopicsDelta, TopicsImage}
 import org.apache.kafka.raft.{OffsetAndEpoch => RaftOffsetAndEpoch}
 import org.easymock.EasyMock
@@ -439,13 +439,13 @@ class ReplicaManagerTest {
       }
 
       // fetch as follower to advance the high watermark
-      fetchAsFollower(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(numRecords, 0, 100000, Optional.empty()),
+      fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, numRecords, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_UNCOMMITTED)
 
       // fetch should return empty since LSO should be stuck at 0
-      var consumerFetchResult = fetchAsConsumer(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      var consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED)
       var fetchData = consumerFetchResult.assertFired
       assertEquals(Errors.NONE, fetchData.error)
@@ -454,8 +454,8 @@ class ReplicaManagerTest {
       assertEquals(Some(List.empty[FetchResponseData.AbortedTransaction]), fetchData.abortedTransactions)
 
       // delayed fetch should timeout and return nothing
-      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED, minBytes = 1000)
       assertFalse(consumerFetchResult.isFired)
       timer.advanceClock(1001)
@@ -475,8 +475,8 @@ class ReplicaManagerTest {
 
       // the LSO has advanced, but the appended commit marker has not been replicated, so
       // none of the data from the transaction should be visible yet
-      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED)
 
       fetchData = consumerFetchResult.assertFired
@@ -484,13 +484,13 @@ class ReplicaManagerTest {
       assertTrue(fetchData.records.batches.asScala.isEmpty)
 
       // fetch as follower to advance the high watermark
-      fetchAsFollower(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(numRecords + 1, 0, 100000, Optional.empty()),
+      fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, numRecords + 1, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_UNCOMMITTED)
 
       // now all of the records should be fetchable
-      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED)
 
       fetchData = consumerFetchResult.assertFired
@@ -553,14 +553,14 @@ class ReplicaManagerTest {
         .onFire { response => assertEquals(Errors.NONE, response.error) }
 
       // fetch as follower to advance the high watermark
-      fetchAsFollower(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(numRecords + 1, 0, 100000, Optional.empty()),
+      fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, numRecords + 1, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_UNCOMMITTED)
 
       // Set the minBytes in order force this request to enter purgatory. When it returns, we should still
       // see the newly aborted transaction.
-      val fetchResult = fetchAsConsumer(replicaManager, new TopicPartition(topic, 0),
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      val fetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         isolationLevel = IsolationLevel.READ_COMMITTED, minBytes = 10000)
       assertFalse(fetchResult.isFired)
 
@@ -618,8 +618,8 @@ class ReplicaManagerTest {
       }
 
       // Followers are always allowed to fetch above the high watermark
-      val followerFetchResult = fetchAsFollower(rm, new TopicPartition(topic, 0),
-        new PartitionData(1, 0, 100000, Optional.empty()))
+      val followerFetchResult = fetchAsFollower(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()))
       val followerFetchData = followerFetchResult.assertFired
       assertEquals(Errors.NONE, followerFetchData.error, "Should not give an exception")
       assertTrue(followerFetchData.records.batches.iterator.hasNext, "Should return some data")
@@ -627,8 +627,8 @@ class ReplicaManagerTest {
       // Consumers are not allowed to consume above the high watermark. However, since the
       // high watermark could be stale at the time of the request, we do not return an out of
       // range error and instead return an empty record set.
-      val consumerFetchResult = fetchAsConsumer(rm, new TopicPartition(topic, 0),
-        new PartitionData(1, 0, 100000, Optional.empty()))
+      val consumerFetchResult = fetchAsConsumer(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)),
+        new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()))
       val consumerFetchData = consumerFetchResult.assertFired
       assertEquals(Errors.NONE, consumerFetchData.error, "Should not give an exception")
       assertEquals(MemoryRecords.EMPTY, consumerFetchData.records, "Should return empty response")
@@ -646,6 +646,7 @@ class ReplicaManagerTest {
       brokerId = 0, aliveBrokersIds)
     try {
       val tp = new TopicPartition(topic, 0)
+      val tidp = new TopicIdPartition(topicId, tp)
       val replicas = aliveBrokersIds.toList.map(Int.box).asJava
 
       // Broker 0 becomes leader of the partition
@@ -684,11 +685,11 @@ class ReplicaManagerTest {
 
       // We receive one valid request from the follower and replica state is updated
       var successfulFetch: Option[FetchPartitionData] = None
-      def callback(response: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
-        successfulFetch = response.headOption.filter(_._1 == tp).map(_._2)
+      def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
+        successfulFetch = response.headOption.filter(_._1 == tidp).map(_._2)
       }
 
-      val validFetchPartitionData = new FetchRequest.PartitionData(0L, 0L, maxFetchBytes,
+      val validFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch))
 
       replicaManager.fetchMessages(
@@ -697,8 +698,7 @@ class ReplicaManagerTest {
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> validFetchPartitionData),
-        topicIds = topicIds.asJava,
+        fetchInfos = Seq(tidp -> validFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
@@ -712,7 +712,7 @@ class ReplicaManagerTest {
 
       // Next we receive an invalid request with a higher fetch offset, but an old epoch.
       // We expect that the replica state does not get updated.
-      val invalidFetchPartitionData = new FetchRequest.PartitionData(3L, 0L, maxFetchBytes,
+      val invalidFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 3L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch - 1))
 
       replicaManager.fetchMessages(
@@ -721,8 +721,7 @@ class ReplicaManagerTest {
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> invalidFetchPartitionData),
-        topicIds = topicIds.asJava,
+        fetchInfos = Seq(tidp -> invalidFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
@@ -735,7 +734,7 @@ class ReplicaManagerTest {
 
       // Next we receive an invalid request with a higher fetch offset, but a diverging epoch.
       // We expect that the replica state does not get updated.
-      val divergingFetchPartitionData = new FetchRequest.PartitionData(3L, 0L, maxFetchBytes,
+      val divergingFetchPartitionData = new FetchRequest.PartitionData(tidp.topicId, 3L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch), Optional.of(leaderEpoch - 1))
 
       replicaManager.fetchMessages(
@@ -744,8 +743,7 @@ class ReplicaManagerTest {
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> divergingFetchPartitionData),
-        topicIds = topicIds.asJava,
+        fetchInfos = Seq(tidp -> divergingFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
@@ -770,6 +768,7 @@ class ReplicaManagerTest {
       brokerId = 0, aliveBrokersIds)
     try {
       val tp = new TopicPartition(topic, 0)
+      val tidp = new TopicIdPartition(topicId, tp)
       val replicas = aliveBrokersIds.toList.map(Int.box).asJava
 
       // Broker 0 becomes leader of the partition
@@ -793,52 +792,55 @@ class ReplicaManagerTest {
       assertEquals(Some(topicId), replicaManager.getPartitionOrException(tp).topicId)
 
       // We receive one valid request from the follower and replica state is updated
-      var successfulFetch: Option[FetchPartitionData] = None
-      def callback(response: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
-        successfulFetch = response.headOption.filter { case (topicPartition, _) => topicPartition == tp }.map { case (_, data) => data }
-      }
+      var successfulFetch: Seq[(TopicIdPartition, FetchPartitionData)] = Seq()
 
-      val validFetchPartitionData = new FetchRequest.PartitionData(0L, 0L, maxFetchBytes,
+      val validFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, maxFetchBytes,
         Optional.of(leaderEpoch))
 
       // Fetch messages simulating a different ID than the one in the log.
+      val inconsistentTidp = new TopicIdPartition(Uuid.randomUuid(), tidp.topicPartition)
+      def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
+        successfulFetch = response
+      }
       replicaManager.fetchMessages(
         timeout = 0L,
         replicaId = 1,
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> validFetchPartitionData),
-        topicIds = Collections.singletonMap(topic, Uuid.randomUuid()),
+        fetchInfos = Seq(inconsistentTidp -> validFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
         clientMetadata = None
       )
-      assertTrue(successfulFetch.isDefined)
-      assertEquals(Errors.INCONSISTENT_TOPIC_ID, successfulFetch.get.error)
+      val fetch1 = successfulFetch.headOption.filter(_._1 == inconsistentTidp).map(_._2)
+      assertTrue(fetch1.isDefined)
+      assertEquals(Errors.INCONSISTENT_TOPIC_ID, fetch1.get.error)
 
       // Simulate where the fetch request did not use topic IDs
       // Fetch messages simulating an ID in the log.
       // We should not see topic ID errors.
+      val zeroTidp = new TopicIdPartition(Uuid.ZERO_UUID, tidp.topicPartition)
       replicaManager.fetchMessages(
         timeout = 0L,
         replicaId = 1,
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> validFetchPartitionData),
-        topicIds = Collections.emptyMap(),
+        fetchInfos = Seq(zeroTidp -> validFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
         clientMetadata = None
       )
-      assertTrue(successfulFetch.isDefined)
-      assertEquals(Errors.NONE, successfulFetch.get.error)
+      val fetch2 = successfulFetch.headOption.filter(_._1 == zeroTidp).map(_._2)
+      assertTrue(fetch2.isDefined)
+      assertEquals(Errors.NONE, fetch2.get.error)
 
       // Next create a topic without a topic ID written in the log.
       val tp2 = new TopicPartition("noIdTopic", 0)
+      val tidp2 = new TopicIdPartition(Uuid.randomUuid(), tp2)
 
       // Broker 0 becomes leader of the partition
       val leaderAndIsrPartitionState2 = new LeaderAndIsrPartitionState()
@@ -867,32 +869,33 @@ class ReplicaManagerTest {
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> validFetchPartitionData),
-        topicIds = Collections.singletonMap("noIdTopic", Uuid.randomUuid()),
+        fetchInfos = Seq(tidp2 -> validFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
         clientMetadata = None
       )
-      assertTrue(successfulFetch.isDefined)
-      assertEquals(Errors.NONE, successfulFetch.get.error)
+      val fetch3 = successfulFetch.headOption.filter(_._1 == tidp2).map(_._2)
+      assertTrue(fetch3.isDefined)
+      assertEquals(Errors.NONE, fetch3.get.error)
 
       // Fetch messages simulating the request not containing a topic ID. We should not have an error.
+      val zeroTidp2 = new TopicIdPartition(Uuid.ZERO_UUID, tidp2.topicPartition)
       replicaManager.fetchMessages(
         timeout = 0L,
         replicaId = 1,
         fetchMinBytes = 1,
         fetchMaxBytes = maxFetchBytes,
         hardMaxBytesLimit = false,
-        fetchInfos = Seq(tp -> validFetchPartitionData),
-        topicIds = Collections.emptyMap(),
+        fetchInfos = Seq(zeroTidp2 -> validFetchPartitionData),
         quota = UnboundedQuota,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
         responseCallback = callback,
         clientMetadata = None
       )
-      assertTrue(successfulFetch.isDefined)
-      assertEquals(Errors.NONE, successfulFetch.get.error)
+      val fetch4 = successfulFetch.headOption.filter(_._1 == zeroTidp2).map(_._2)
+      assertTrue(fetch4.isDefined)
+      assertEquals(Errors.NONE, fetch4.get.error)
 
     } finally {
       replicaManager.shutdown(checkpointHW = false)
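
The four fetches above pin down the ID-checking semantics: a concrete topic ID that disagrees
with the ID recorded in the log yields INCONSISTENT_TOPIC_ID, while Uuid.ZERO_UUID (a request
built without topic IDs) or a log with no ID on file skips the check. A minimal sketch of that
rule as a hypothetical helper, not the broker's exact code:

    def checkTopicId(requestTopicId: Uuid, logTopicId: Option[Uuid]): Errors =
      if (requestTopicId != Uuid.ZERO_UUID && logTopicId.exists(_ != requestTopicId))
        Errors.INCONSISTENT_TOPIC_ID
      else
        Errors.NONE
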
@@ -911,12 +914,15 @@ class ReplicaManagerTest {
       // Create 2 partitions, assign replica 0 as the leader for both a different follower (1 and 2) for each
       val tp0 = new TopicPartition(topic, 0)
       val tp1 = new TopicPartition(topic, 1)
+      val topicId = Uuid.randomUuid()
+      val tidp0 = new TopicIdPartition(topicId, tp0)
+      val tidp1 = new TopicIdPartition(topicId, tp1)
       val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
       replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
       replicaManager.createPartition(tp1).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
       val partition0Replicas = Seq[Integer](0, 1).asJava
       val partition1Replicas = Seq[Integer](0, 2).asJava
-      val topicIds = Map(tp0.topic -> Uuid.randomUuid(), tp1.topic -> Uuid.randomUuid()).asJava
+      val topicIds = Map(tp0.topic -> topicId, tp1.topic -> topicId).asJava
       val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
         Seq(
           new LeaderAndIsrPartitionState()
@@ -954,12 +960,12 @@ class ReplicaManagerTest {
         }
       }
 
-      def fetchCallback(responseStatus: Seq[(TopicPartition, FetchPartitionData)]) = {
+      def fetchCallback(responseStatus: Seq[(TopicIdPartition, FetchPartitionData)]) = {
         val responseStatusMap = responseStatus.toMap
         assertEquals(2, responseStatus.size)
-        assertEquals(Set(tp0, tp1), responseStatusMap.keySet)
+        assertEquals(Set(tidp0, tidp1), responseStatusMap.keySet)
 
-        val tp0Status = responseStatusMap.get(tp0)
+        val tp0Status = responseStatusMap.get(tidp0)
         assertTrue(tp0Status.isDefined)
         // the response contains high watermark on the leader before it is updated based
         // on this fetch request
@@ -968,7 +974,7 @@ class ReplicaManagerTest {
         assertEquals(Errors.NONE, tp0Status.get.error)
         assertTrue(tp0Status.get.records.batches.iterator.hasNext)
 
-        val tp1Status = responseStatusMap.get(tp1)
+        val tp1Status = responseStatusMap.get(tidp1)
         assertTrue(tp1Status.isDefined)
         assertEquals(0, tp1Status.get.highWatermark)
         assertEquals(Some(0), tp0Status.get.lastStableOffset)
@@ -983,9 +989,8 @@ class ReplicaManagerTest {
         fetchMaxBytes = Int.MaxValue,
         hardMaxBytesLimit = false,
         fetchInfos = Seq(
-          tp0 -> new PartitionData(1, 0, 100000, Optional.empty()),
-          tp1 -> new PartitionData(1, 0, 100000, Optional.empty())),
-        topicIds = topicIds,
+          tidp0 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()),
+          tidp1 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty())),
         quota = UnboundedQuota,
         responseCallback = fetchCallback,
         isolationLevel = IsolationLevel.READ_UNCOMMITTED,
@@ -1126,6 +1131,7 @@ class ReplicaManagerTest {
       val brokerList = Seq[Integer](0, 1).asJava
 
       val tp0 = new TopicPartition(topic, 0)
+      val tidp0 = new TopicIdPartition(topicId, tp0)
 
       initializeLogAndTopicId(replicaManager, tp0, topicId)
 
@@ -1148,8 +1154,8 @@ class ReplicaManagerTest {
       val metadata: ClientMetadata = new DefaultClientMetadata("rack-a", "client-id",
         InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default")
 
-      val consumerResult = fetchAsConsumer(replicaManager, tp0,
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      val consumerResult = fetchAsConsumer(replicaManager, tidp0,
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         clientMetadata = Some(metadata))
 
       // Fetch from follower succeeds
@@ -1183,6 +1189,7 @@ class ReplicaManagerTest {
       val brokerList = Seq[Integer](0, 1).asJava
 
       val tp0 = new TopicPartition(topic, 0)
+      val tidp0 = new TopicIdPartition(topicId, tp0)
 
       initializeLogAndTopicId(replicaManager, tp0, topicId)
 
@@ -1205,8 +1212,8 @@ class ReplicaManagerTest {
       val metadata: ClientMetadata = new DefaultClientMetadata("rack-a", "client-id",
         InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default")
 
-      val consumerResult = fetchAsConsumer(replicaManager, tp0,
-        new PartitionData(0, 0, 100000, Optional.empty()),
+      val consumerResult = fetchAsConsumer(replicaManager, tidp0,
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()),
         clientMetadata = Some(metadata))
 
       // Fetch from follower succeeds
@@ -1239,6 +1246,7 @@ class ReplicaManagerTest {
     val brokerList = Seq[Integer](0, 1).asJava
 
     val tp0 = new TopicPartition(topic, 0)
+    val tidp0 = new TopicIdPartition(topicId, tp0)
 
     initializeLogAndTopicId(replicaManager, tp0, topicId)
 
@@ -1264,8 +1272,8 @@ class ReplicaManagerTest {
 
     // Increment the hw in the leader by fetching from the last offset
     val fetchOffset = simpleRecords.size
-    var followerResult = fetchAsFollower(replicaManager, tp0,
-      new PartitionData(fetchOffset, 0, 100000, Optional.empty()),
+    var followerResult = fetchAsFollower(replicaManager, tidp0,
+      new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, 100000, Optional.empty()),
       clientMetadata = None)
     assertTrue(followerResult.isFired)
     assertEquals(0, followerResult.assertFired.highWatermark)
@@ -1274,8 +1282,8 @@ class ReplicaManagerTest {
 
     // Fetch from the same offset; no new data is expected, so the fetch request
     // should go to the purgatory
-    followerResult = fetchAsFollower(replicaManager, tp0,
-      new PartitionData(fetchOffset, 0, 100000, Optional.empty()),
+    followerResult = fetchAsFollower(replicaManager, tidp0,
+      new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, 100000, Optional.empty()),
       clientMetadata = None, minBytes = 1000)
     assertFalse(followerResult.isFired, "Request completed immediately unexpectedly")
 
@@ -1332,6 +1340,7 @@ class ReplicaManagerTest {
 
     try {
       val tp0 = new TopicPartition(topic, 0)
+      val tidp0 = new TopicIdPartition(topicId, tp0)
       val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
       replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
       val partition0Replicas = Seq[Integer](0, 1).asJava
@@ -1352,16 +1361,16 @@ class ReplicaManagerTest {
 
       // Fetch from follower, with non-empty ClientMetadata (FetchRequest v11+)
       val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "")
-      var partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+      var partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.of(0))
-      var fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, Some(clientMetadata))
+      var fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata))
       assertNotNull(fetchResult.get)
       assertEquals(Errors.NONE, fetchResult.get.error)
 
       // Fetch from follower, with empty ClientMetadata (which implies an older version)
-      partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+      partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.of(0))
-      fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, None)
+      fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None)
       assertNotNull(fetchResult.get)
       assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error)
     } finally {
@@ -1377,6 +1386,7 @@ class ReplicaManagerTest {
     val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1))
 
     val tp0 = new TopicPartition(topic, 0)
+    val tidp0 = new TopicIdPartition(topicId, tp0)
     val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
     replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
     val partition0Replicas = Seq[Integer](0, 1).asJava
@@ -1401,15 +1411,15 @@ class ReplicaManagerTest {
       assertEquals(expected, replicaManager.brokerTopicStats.topicStats(topic).totalFetchRequestRate.count)
     }
 
-    val partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+    val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.empty())
 
-    val nonPurgatoryFetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, None, timeout = 0)
+    val nonPurgatoryFetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 0)
     assertNotNull(nonPurgatoryFetchResult.get)
     assertEquals(Errors.NONE, nonPurgatoryFetchResult.get.error)
     assertMetricCount(1)
 
-    val purgatoryFetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, None, timeout = 10)
+    val purgatoryFetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10)
     assertNull(purgatoryFetchResult.get)
     mockTimer.advanceClock(11)
     assertNotNull(purgatoryFetchResult.get)
@@ -1424,6 +1434,7 @@ class ReplicaManagerTest {
 
     try {
       val tp0 = new TopicPartition(topic, 0)
+      val tidp0 = new TopicIdPartition(topicId, tp0)
       val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
       replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
       val partition0Replicas = Seq[Integer](0, 1).asJava
@@ -1443,9 +1454,9 @@ class ReplicaManagerTest {
         Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
       replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
 
-      val partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+      val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.empty())
-      val fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, None, timeout = 10)
+      val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10)
       assertNull(fetchResult.get)
 
       // Become a follower and ensure that the delayed fetch returns immediately
@@ -1480,6 +1491,7 @@ class ReplicaManagerTest {
 
     try {
       val tp0 = new TopicPartition(topic, 0)
+      val tidp0 = new TopicIdPartition(topicId, tp0)
       val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
       replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
       val partition0Replicas = Seq[Integer](0, 1).asJava
@@ -1500,9 +1512,9 @@ class ReplicaManagerTest {
       replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
 
       val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "")
-      val partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+      val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
         Optional.of(1))
-      val fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, Some(clientMetadata), timeout = 10)
+      val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata), timeout = 10)
       assertNull(fetchResult.get)
 
       // Become a follower and ensure that the delayed fetch returns immediately
@@ -1535,6 +1547,7 @@ class ReplicaManagerTest {
     val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), aliveBrokerIds = Seq(0, 1))
 
     val tp0 = new TopicPartition(topic, 0)
+    val tidp0 = new TopicIdPartition(topicId, tp0)
     val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
     replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
     val partition0Replicas = Seq[Integer](0, 1).asJava
@@ -1555,15 +1568,15 @@ class ReplicaManagerTest {
     replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
 
     val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "")
-    var partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+    var partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.of(1))
-    var fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, Some(clientMetadata))
+    var fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata))
     assertNotNull(fetchResult.get)
     assertEquals(Errors.NONE, fetchResult.get.error)
 
-    partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+    partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.empty())
-    fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, Some(clientMetadata))
+    fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata))
     assertNotNull(fetchResult.get)
     assertEquals(Errors.NONE, fetchResult.get.error)
   }
@@ -1578,6 +1591,7 @@ class ReplicaManagerTest {
     val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1))
 
     val tp0 = new TopicPartition(topic, 0)
+    val tidp0 = new TopicIdPartition(topicId, tp0)
     val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
     replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
     val partition0Replicas = Seq[Integer](0, 1).asJava
@@ -1597,9 +1611,9 @@ class ReplicaManagerTest {
       Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
     replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ())
 
-    val partitionData = new FetchRequest.PartitionData(0L, 0L, 100,
+    val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100,
       Optional.of(1))
-    val fetchResult = sendConsumerFetch(replicaManager, tp0, partitionData, None, timeout = 10)
+    val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10)
     assertNull(fetchResult.get)
     Mockito.when(replicaManager.metadataCache.contains(ArgumentMatchers.eq(tp0))).thenReturn(true)
 
@@ -1685,13 +1699,13 @@ class ReplicaManagerTest {
   }
 
   private def sendConsumerFetch(replicaManager: ReplicaManager,
-                                topicPartition: TopicPartition,
+                                topicIdPartition: TopicIdPartition,
                                 partitionData: FetchRequest.PartitionData,
                                 clientMetadataOpt: Option[ClientMetadata],
                                 timeout: Long = 0L): AtomicReference[FetchPartitionData] = {
     val fetchResult = new AtomicReference[FetchPartitionData]()
-    def callback(response: Seq[(TopicPartition, FetchPartitionData)]): Unit = {
-      fetchResult.set(response.toMap.apply(topicPartition))
+    def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = {
+      fetchResult.set(response.toMap.apply(topicIdPartition))
     }
     replicaManager.fetchMessages(
       timeout = timeout,
@@ -1699,8 +1713,7 @@ class ReplicaManagerTest {
       fetchMinBytes = 1,
       fetchMaxBytes = 100,
       hardMaxBytesLimit = false,
-      fetchInfos = Seq(topicPartition -> partitionData),
-      topicIds = topicIds.asJava,
+      fetchInfos = Seq(topicIdPartition -> partitionData),
       quota = UnboundedQuota,
       isolationLevel = IsolationLevel.READ_UNCOMMITTED,
       responseCallback = callback,
@@ -1943,7 +1956,7 @@ class ReplicaManagerTest {
   }
 
   private def fetchAsConsumer(replicaManager: ReplicaManager,
-                              partition: TopicPartition,
+                              partition: TopicIdPartition,
                               partitionData: PartitionData,
                               minBytes: Int = 0,
                               isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED,
@@ -1952,7 +1965,7 @@ class ReplicaManagerTest {
   }
 
   private def fetchAsFollower(replicaManager: ReplicaManager,
-                              partition: TopicPartition,
+                              partition: TopicIdPartition,
                               partitionData: PartitionData,
                               minBytes: Int = 0,
                               isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED,
@@ -1962,13 +1975,13 @@ class ReplicaManagerTest {
 
   private def fetchMessages(replicaManager: ReplicaManager,
                             replicaId: Int,
-                            partition: TopicPartition,
+                            partition: TopicIdPartition,
                             partitionData: PartitionData,
                             minBytes: Int,
                             isolationLevel: IsolationLevel,
                             clientMetadata: Option[ClientMetadata]): CallbackResult[FetchPartitionData] = {
     val result = new CallbackResult[FetchPartitionData]()
-    def fetchCallback(responseStatus: Seq[(TopicPartition, FetchPartitionData)]) = {
+    def fetchCallback(responseStatus: Seq[(TopicIdPartition, FetchPartitionData)]) = {
       assertEquals(1, responseStatus.size)
       val (topicPartition, fetchData) = responseStatus.head
       assertEquals(partition, topicPartition)
@@ -1982,7 +1995,6 @@ class ReplicaManagerTest {
       fetchMaxBytes = Int.MaxValue,
       hardMaxBytesLimit = false,
       fetchInfos = Seq(partition -> partitionData),
-      topicIds = topicIds.asJava,
       quota = UnboundedQuota,
       responseCallback = fetchCallback,
       isolationLevel = isolationLevel,
@@ -2934,6 +2946,8 @@ class ReplicaManagerTest {
       // Make the local replica the leader
       val leaderTopicsDelta = topicsCreateDelta(localId, true)
       val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply())
+      val topicId = leaderMetadataImage.topics().topicsByName.get("foo").id
+      val topicIdPartition = new TopicIdPartition(topicId, topicPartition)
       replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage)
 
       // Check the state of that partition and fetcher
@@ -2949,8 +2963,8 @@ class ReplicaManagerTest {
       fetchMessages(
         replicaManager,
         otherId,
-        topicPartition,
-        new PartitionData(numOfRecords, 0, Int.MaxValue, Optional.empty()),
+        topicIdPartition,
+        new PartitionData(Uuid.ZERO_UUID, numOfRecords, 0, Int.MaxValue, Optional.empty()),
         Int.MaxValue,
         IsolationLevel.READ_UNCOMMITTED,
         None
@@ -3009,6 +3023,8 @@ class ReplicaManagerTest {
       // Change the local replica to leader
       val leaderTopicsDelta = topicsChangeDelta(followerMetadataImage.topics(), localId, true)
       val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply())
+      val topicId = leaderMetadataImage.topics().topicsByName.get("foo").id
+      val topicIdPartition = new TopicIdPartition(topicId, topicPartition)
       replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage)
 
       // Send a produce request and advance the highwatermark
@@ -3016,8 +3032,8 @@ class ReplicaManagerTest {
       fetchMessages(
         replicaManager,
         otherId,
-        topicPartition,
-        new PartitionData(numOfRecords, 0, Int.MaxValue, Optional.empty()),
+        topicIdPartition,
+        new PartitionData(Uuid.ZERO_UUID, numOfRecords, 0, Int.MaxValue, Optional.empty()),
         Int.MaxValue,
         IsolationLevel.READ_UNCOMMITTED,
         None
@@ -3273,6 +3289,8 @@ class ReplicaManagerTest {
       // Make the local replica the leader
       val leaderTopicsDelta = topicsCreateDelta(localId, true)
       val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply())
+      val topicId = leaderMetadataImage.topics().topicsByName.get("foo").id
+      val topicIdPartition = new TopicIdPartition(topicId, topicPartition)
       replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage)
 
       // Check the state of that partition and fetcher
@@ -3287,8 +3305,8 @@ class ReplicaManagerTest {
       val fetchCallback = fetchMessages(
         replicaManager,
         otherId,
-        topicPartition,
-        new PartitionData(0, 0, Int.MaxValue, Optional.empty()),
+        topicIdPartition,
+        new PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.empty()),
         Int.MaxValue,
         IsolationLevel.READ_UNCOMMITTED,
         None
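
A note on the pattern: every ReplicaManagerTest hunk above applies the same transformation. Fetch helpers and callbacks are keyed by TopicIdPartition, and the topic ID now travels inside FetchRequest.PartitionData instead of in a separate topicIds map. A minimal, self-contained Java sketch of that pairing (the class name, topic name, ID, and offsets are illustrative placeholders, not values from the tests):

    import java.util.Optional;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.requests.FetchRequest;

    public class TopicIdPartitionSketch {
        public static void main(String[] args) {
            // Pair the topic ID with the name and partition once...
            Uuid topicId = Uuid.randomUuid();  // placeholder; normally resolved from metadata
            TopicIdPartition tidp0 = new TopicIdPartition(topicId, new TopicPartition("topic", 0));

            // ...then carry the ID inside the per-partition fetch data.
            FetchRequest.PartitionData data = new FetchRequest.PartitionData(
                tidp0.topicId(), 0L, 0L, 100000, Optional.empty());
            System.out.println(tidp0 + " -> " + data);
        }
    }

Uuid.ZERO_UUID fills the ID slot whenever the ID is unknown, which is how the version 12 and below paths above stay well formed.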
diff --git a/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala b/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala
index 56359df..ddbc987 100644
--- a/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala
+++ b/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala
@@ -226,8 +226,8 @@ class RequestQuotaTest extends BaseRequestTest {
 
         case ApiKeys.FETCH =>
           val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
-          partitionMap.put(tp, new FetchRequest.PartitionData(0, 0, 100, Optional.of(15)))
-          FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 0, partitionMap, getTopicIds().asJava)
+          partitionMap.put(tp, new FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID), 0, 0, 100, Optional.of(15)))
+          FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 0, partitionMap)
 
         case ApiKeys.METADATA =>
           new MetadataRequest.Builder(List(topic).asJava, true)
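
The RequestQuotaTest hunk shows the reshaped builder call: FetchRequest.Builder.forConsumer no longer takes a trailing Map<String, Uuid>, because each PartitionData carries its own ID. A self-contained sketch of the new call shape (class name and values are illustrative):

    import java.util.LinkedHashMap;
    import java.util.Optional;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.FetchRequest;

    public class ForConsumerSketch {
        public static void main(String[] args) {
            LinkedHashMap<TopicPartition, FetchRequest.PartitionData> partitionMap =
                new LinkedHashMap<>();
            Uuid topicId = Uuid.randomUuid();  // falls back to Uuid.ZERO_UUID when unknown
            partitionMap.put(new TopicPartition("topic", 0),
                new FetchRequest.PartitionData(topicId, 0, 0, 100, Optional.of(15)));
            // No trailing topicIds map anymore.
            FetchRequest request = FetchRequest.Builder
                .forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, partitionMap)
                .build(ApiKeys.FETCH.latestVersion());
            System.out.println(request);
        }
    }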
diff --git a/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala b/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala
index 224474e..a49fcb0 100644
--- a/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala
+++ b/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala
@@ -22,7 +22,7 @@ import java.util.{Arrays, LinkedHashMap, Optional, Properties}
 import kafka.api.KAFKA_2_7_IV0
 import kafka.network.SocketServer
 import kafka.utils.TestUtils
-import org.apache.kafka.common.{TopicPartition, Uuid}
+import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid}
 import org.apache.kafka.common.message.DeleteTopicsRequestData
 import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
@@ -71,14 +71,16 @@ class TopicIdWithOldInterBrokerProtocolTest extends BaseRequestTest {
     val maxPartitionBytes = 190
     val topicIds = Map("topic1" -> Uuid.randomUuid())
     val topicNames = topicIds.map(_.swap)
+    val tidp0 = new TopicIdPartition(topicIds(topic1), tp0)
 
     val leadersMap = createTopic(topic1, replicaAssignment)
-    val req = createFetchRequest(maxResponseBytes, maxPartitionBytes, Seq(tp0),  Map.empty, topicIds, ApiKeys.FETCH.latestVersion())
+    val req = createFetchRequest(maxResponseBytes, maxPartitionBytes, Seq(tidp0), Map.empty, ApiKeys.FETCH.latestVersion())
     val resp = sendFetchRequest(leadersMap(0), req)
 
     val responseData = resp.responseData(topicNames.asJava, ApiKeys.FETCH.latestVersion())
-    assertEquals(Errors.UNKNOWN_TOPIC_ID.code, resp.error().code())
-    assertEquals(0, responseData.size());
+    assertEquals(Errors.NONE.code, resp.error().code())
+    assertEquals(1, responseData.size())
+    assertEquals(Errors.UNKNOWN_TOPIC_ID.code, responseData.get(tp0).errorCode)
   }
 
   @Test
@@ -90,9 +92,10 @@ class TopicIdWithOldInterBrokerProtocolTest extends BaseRequestTest {
     val maxPartitionBytes = 190
     val topicIds = Map("topic1" -> Uuid.randomUuid())
     val topicNames = topicIds.map(_.swap)
+    val tidp0 = new TopicIdPartition(topicIds(topic1), tp0)
 
     val leadersMap = createTopic(topic1, replicaAssignment)
-    val req = createFetchRequest(maxResponseBytes, maxPartitionBytes, Seq(tp0), Map.empty, topicIds, 12)
+    val req = createFetchRequest(maxResponseBytes, maxPartitionBytes, Seq(tidp0), Map.empty, 12)
     val resp = sendFetchRequest(leadersMap(0), req)
 
     assertEquals(Errors.NONE, resp.error())
@@ -141,19 +144,18 @@ class TopicIdWithOldInterBrokerProtocolTest extends BaseRequestTest {
     connectAndReceive[MetadataResponse](request, destination = destination.getOrElse(anySocketServer))
   }
 
-  private def createFetchRequest(maxResponseBytes: Int, maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition],
+  private def createFetchRequest(maxResponseBytes: Int, maxPartitionBytes: Int, topicPartitions: Seq[TopicIdPartition],
                                  offsetMap: Map[TopicPartition, Long],
-                                 topicIds: Map[String, Uuid],
                                  version: Short): FetchRequest = {
-    FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, topicPartitions, offsetMap), topicIds.asJava)
+    FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, topicPartitions, offsetMap))
       .setMaxBytes(maxResponseBytes).build()
   }
 
-  private def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition],
+  private def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicIdPartition],
                                  offsetMap: Map[TopicPartition, Long]): LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = {
     val partitionMap = new LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
     topicPartitions.foreach { tp =>
-      partitionMap.put(tp, new FetchRequest.PartitionData(offsetMap.getOrElse(tp, 0), 0L, maxPartitionBytes,
+      partitionMap.put(tp.topicPartition, new FetchRequest.PartitionData(tp.topicId, offsetMap.getOrElse(tp.topicPartition, 0), 0L, maxPartitionBytes,
         Optional.empty()))
     }
     partitionMap
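
The assertion change in this file captures the user-visible behavior of the patch: a fetch naming an unknown topic ID now gets a top-level NONE with UNKNOWN_TOPIC_ID reported on the affected partition, rather than a top-level error for the whole response. A hypothetical helper mirroring that assertion flow (the class, method, and parameter names are illustrative, not part of the test):

    import java.util.LinkedHashMap;
    import java.util.Map;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.FetchResponse;

    public class UnknownTopicIdSketch {
        // Checks that the unknown ID surfaces per partition, not at the top level.
        static void assertPartitionLevelUnknownId(FetchResponse resp,
                                                  Map<Uuid, String> topicNames,
                                                  TopicPartition tp0) {
            LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData =
                resp.responseData(topicNames, ApiKeys.FETCH.latestVersion());
            if (resp.error() != Errors.NONE)
                throw new AssertionError("expected top-level NONE, got " + resp.error());
            if (responseData.get(tp0).errorCode() != Errors.UNKNOWN_TOPIC_ID.code())
                throw new AssertionError("expected partition-level UNKNOWN_TOPIC_ID");
        }
    }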
diff --git a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
index 4abbb9d..8f3fcff 100644
--- a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
+++ b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala
@@ -27,7 +27,7 @@ import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.AbstractRequest.Builder
 import org.apache.kafka.common.requests.{AbstractRequest, FetchResponse, OffsetsForLeaderEpochResponse, FetchMetadata => JFetchMetadata}
 import org.apache.kafka.common.utils.{SystemTime, Time}
-import org.apache.kafka.common.{Node, TopicPartition, Uuid}
+import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid}
 
 import scala.collection.Map
 
@@ -100,15 +100,13 @@ class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, Epoc
 
       case ApiKeys.FETCH =>
         fetchCount += 1
-        val partitionData = new util.LinkedHashMap[TopicPartition, FetchResponseData.PartitionData]
-        val topicIdsForRequest = new util.HashMap[String, Uuid]()
-        fetchPartitionData.foreach { case (tp, data) => partitionData.put(tp, data) }
-        topicIds.foreach { case (name, id) => topicIdsForRequest.put(name, id)}
+        val partitionData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+        fetchPartitionData.foreach { case (tp, data) => partitionData.put(new TopicIdPartition(topicIds.getOrElse(tp.topic(), Uuid.ZERO_UUID), tp), data) }
         fetchPartitionData = Map.empty
         topicIds = Map.empty
         FetchResponse.of(Errors.NONE, 0,
           if (partitionData.isEmpty) JFetchMetadata.INVALID_SESSION_ID else 1,
-          partitionData, topicIdsForRequest)
+          partitionData)
 
       case _ =>
         throw new UnsupportedOperationException
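
As the mock sender above shows, FetchResponse.of is now keyed by TopicIdPartition and drops its parallel name-to-ID map. A minimal sketch of building a response under the new signature (topic name and ID are placeholders):

    import java.util.LinkedHashMap;
    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.FetchResponse;

    public class FetchResponseOfSketch {
        public static void main(String[] args) {
            LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> partitionData =
                new LinkedHashMap<>();
            Uuid id = Uuid.randomUuid();  // placeholder for the topic's real ID
            partitionData.put(
                new TopicIdPartition(id, new TopicPartition("foo", 0)),
                new FetchResponseData.PartitionData()
                    .setPartitionIndex(0)
                    .setLastStableOffset(0)
                    .setLogStartOffset(0));
            // The IDs live in the map keys, so no extra topicIds argument is needed.
            FetchResponse response = FetchResponse.of(Errors.NONE, 0, 1, partitionData);
            System.out.println(response.data().responses().size());
        }
    }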
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java
index 86fd7c3..a428f91 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java
@@ -62,8 +62,6 @@ public class FetchRequestBenchmark {
 
     Map<TopicPartition, FetchRequest.PartitionData> fetchData;
 
-    Map<String, Uuid> topicIds;
-
     Map<Uuid, String> topicNames;
 
     RequestHeader header;
@@ -77,24 +75,22 @@ public class FetchRequestBenchmark {
     @Setup(Level.Trial)
     public void setup() {
         this.fetchData = new HashMap<>();
-        this.topicIds = new HashMap<>();
         this.topicNames = new HashMap<>();
         for (int topicIdx = 0; topicIdx < topicCount; topicIdx++) {
             String topic = Uuid.randomUuid().toString();
             Uuid id = Uuid.randomUuid();
-            topicIds.put(topic, id);
             topicNames.put(id, topic);
             for (int partitionId = 0; partitionId < partitionCount; partitionId++) {
                 FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData(
-                    0, 0, 4096, Optional.empty());
+                    id, 0, 0, 4096, Optional.empty());
                 fetchData.put(new TopicPartition(topic, partitionId), partitionData);
             }
         }
 
         this.header = new RequestHeader(ApiKeys.FETCH, ApiKeys.FETCH.latestVersion(), "jmh-benchmark", 100);
-        this.consumerRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, fetchData, topicIds)
+        this.consumerRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, fetchData)
             .build(ApiKeys.FETCH.latestVersion());
-        this.replicaRequest = FetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData, topicIds)
+        this.replicaRequest = FetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData)
             .build(ApiKeys.FETCH.latestVersion());
         this.requestBuffer = this.consumerRequest.serialize();
 
@@ -107,7 +103,7 @@ public class FetchRequestBenchmark {
 
     @Benchmark
     public int testFetchRequestForConsumer() {
-        FetchRequest fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, fetchData, topicIds)
+        FetchRequest fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, fetchData)
             .build(ApiKeys.FETCH.latestVersion());
         return fetchRequest.fetchData(topicNames).size();
     }
@@ -115,7 +111,7 @@ public class FetchRequestBenchmark {
     @Benchmark
     public int testFetchRequestForReplica() {
         FetchRequest fetchRequest = FetchRequest.Builder.forReplica(
-            ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData, topicIds)
+            ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData)
                 .build(ApiKeys.FETCH.latestVersion());
         return fetchRequest.fetchData(topicNames).size();
     }
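
Both benchmark methods above end with fetchRequest.fetchData(topicNames): since v13 the fetch request carries only topic IDs on the wire, so resolving a deserialized request back to named partitions needs the ID-to-name view of the metadata. A self-contained sketch of that round trip (topic name, ID, and sizes are illustrative):

    import java.util.HashMap;
    import java.util.LinkedHashMap;
    import java.util.Map;
    import java.util.Optional;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.FetchRequest;

    public class FetchDataResolutionSketch {
        public static void main(String[] args) {
            Uuid id = Uuid.randomUuid();
            Map<Uuid, String> topicNames = new HashMap<>();
            topicNames.put(id, "foo");  // the ID-to-name view of the metadata cache

            LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData =
                new LinkedHashMap<>();
            fetchData.put(new TopicPartition("foo", 0),
                new FetchRequest.PartitionData(id, 0, 0, 4096, Optional.empty()));

            FetchRequest request = FetchRequest.Builder
                .forReplica(ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData)
                .build(ApiKeys.FETCH.latestVersion());

            // Resolve IDs back to named partitions before using the request.
            System.out.println(request.fetchData(topicNames).size());
        }
    }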
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java
index 7795218..d8512bd 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java
@@ -18,6 +18,7 @@
 package org.apache.kafka.jmh.common;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.message.FetchResponseData;
 import org.apache.kafka.common.network.Send;
@@ -63,7 +64,7 @@ public class FetchResponseBenchmark {
     @Param({"3", "10", "20"})
     private int partitionCount;
 
-    LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> responseData;
+    LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> responseData;
 
     Map<String, Uuid> topicIds;
 
@@ -96,18 +97,18 @@ public class FetchResponseBenchmark {
                                 .setLastStableOffset(0)
                                 .setLogStartOffset(0)
                                 .setRecords(records);
-                responseData.put(new TopicPartition(topic, partitionId), partitionData);
+                responseData.put(new TopicIdPartition(id, new TopicPartition(topic, partitionId)), partitionData);
             }
         }
 
         this.header = new ResponseHeader(100, ApiKeys.FETCH.responseHeaderVersion(ApiKeys.FETCH.latestVersion()));
-        this.fetchResponse = FetchResponse.of(Errors.NONE, 0, 0, responseData, topicIds);
+        this.fetchResponse = FetchResponse.of(Errors.NONE, 0, 0, responseData);
         this.fetchResponseData = this.fetchResponse.data();
     }
 
     @Benchmark
     public int testConstructFetchResponse() {
-        FetchResponse fetchResponse = FetchResponse.of(Errors.NONE, 0, 0, responseData, topicIds);
+        FetchResponse fetchResponse = FetchResponse.of(Errors.NONE, 0, 0, responseData);
         return fetchResponse.data().responses().size();
     }
 
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
index def0ad0..7f03788 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java
@@ -52,6 +52,7 @@ import kafka.utils.Pool;
 import kafka.utils.TestUtils;
 import kafka.zk.KafkaZkClient;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.message.FetchResponseData;
 import org.apache.kafka.common.message.LeaderAndIsrRequestData;
@@ -152,7 +153,7 @@ public class ReplicaFetcherThreadBenchmark {
             setKeepPartitionMetadataFile(true).
             build();
 
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> initialFetched = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> initialFetched = new LinkedHashMap<>();
         HashMap<String, Uuid> topicIds = new HashMap<>();
         scala.collection.mutable.Map<TopicPartition, InitialFetchState> initialFetchStates = new scala.collection.mutable.HashMap<>();
         List<UpdateMetadataRequestData.UpdateMetadataPartitionState> updatePartitionState = new ArrayList<>();
@@ -191,7 +192,7 @@ public class ReplicaFetcherThreadBenchmark {
                     return null;
                 }
             };
-            initialFetched.put(tp, new FetchResponseData.PartitionData()
+            initialFetched.put(new TopicIdPartition(topicId.get(), tp), new FetchResponseData.PartitionData()
                     .setPartitionIndex(tp.partition())
                     .setLastStableOffset(0)
                     .setLogStartOffset(0)
@@ -234,7 +235,7 @@ public class ReplicaFetcherThreadBenchmark {
         // so that we do not measure this time as part of the steady state work
         fetcher.doWork();
         // handle response to engage the incremental fetch session handler
-        fetcher.fetchSessionHandler().handleResponse(FetchResponse.of(Errors.NONE, 0, 999, initialFetched, topicIds), ApiKeys.FETCH.latestVersion());
+        fetcher.fetchSessionHandler().handleResponse(FetchResponse.of(Errors.NONE, 0, 999, initialFetched), ApiKeys.FETCH.latestVersion());
     }
 
     @TearDown(Level.Trial)
diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java
index 8d2fcb9..26216b9 100644
--- a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java
@@ -19,6 +19,7 @@ package org.apache.kafka.jmh.fetchsession;
 
 import org.apache.kafka.clients.FetchSessionHandler;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.TopicIdPartition;
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.message.FetchResponseData;
 import org.apache.kafka.common.protocol.ApiKeys;
@@ -78,28 +79,27 @@ public class FetchSessionBenchmark {
         Uuid id = Uuid.randomUuid();
         topicIds.put("foo", id);
 
-        LinkedHashMap<TopicPartition, FetchResponseData.PartitionData> respMap = new LinkedHashMap<>();
+        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respMap = new LinkedHashMap<>();
         for (int i = 0; i < partitionCount; i++) {
             TopicPartition tp = new TopicPartition("foo", i);
-            FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData(0, 0, 200,
-                    Optional.empty());
+            FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData(id, 0, 0, 200, Optional.empty());
             fetches.put(tp, partitionData);
-            builder.add(tp, topicIds.get(tp.topic()), partitionData);
-            respMap.put(tp, new FetchResponseData.PartitionData()
+            builder.add(tp, partitionData);
+            respMap.put(new TopicIdPartition(id, tp), new FetchResponseData.PartitionData()
                             .setPartitionIndex(tp.partition())
                             .setLastStableOffset(0)
                             .setLogStartOffset(0));
         }
         builder.build();
         // build and handle an initial response so that the next fetch will be incremental
-        handler.handleResponse(FetchResponse.of(Errors.NONE, 0, 1, respMap, topicIds), ApiKeys.FETCH.latestVersion());
+        handler.handleResponse(FetchResponse.of(Errors.NONE, 0, 1, respMap), ApiKeys.FETCH.latestVersion());
 
         int counter = 0;
         for (TopicPartition topicPartition: new ArrayList<>(fetches.keySet())) {
             if (updatedPercentage != 0 && counter % (100 / updatedPercentage) == 0) {
                 // reorder in fetch session, and update log start offset
                 fetches.remove(topicPartition);
-                fetches.put(topicPartition, new FetchRequest.PartitionData(50, 40, 200,
+                fetches.put(topicPartition, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 50, 40, 200,
                         Optional.empty()));
             }
             counter++;
@@ -115,9 +115,10 @@ public class FetchSessionBenchmark {
         else
             builder = handler.newBuilder();
 
+        // Should we keep the topic ID lookup to mimic how adding really works?
         for (Map.Entry<TopicPartition, FetchRequest.PartitionData> entry: fetches.entrySet()) {
             TopicPartition topicPartition = entry.getKey();
-            builder.add(topicPartition, topicIds.get(topicPartition.topic()), entry.getValue());
+            builder.add(topicPartition, entry.getValue());
         }
 
         builder.build();
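
To close the loop on the session handler change: FetchSessionHandler.Builder.add() drops its Uuid parameter and reads the topic ID from the PartitionData itself, which is what the benchmark loop above exercises. A minimal sketch of the new call shape (the node ID, topic, and offsets are arbitrary):

    import java.util.Optional;
    import org.apache.kafka.clients.FetchSessionHandler;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.common.requests.FetchRequest;
    import org.apache.kafka.common.utils.LogContext;

    public class SessionBuilderSketch {
        public static void main(String[] args) {
            FetchSessionHandler handler = new FetchSessionHandler(new LogContext(), 1);
            FetchSessionHandler.Builder builder = handler.newBuilder();
            Uuid id = Uuid.randomUuid();  // arbitrary; a real caller takes this from metadata
            // The ID rides along in PartitionData; add() no longer takes it separately.
            builder.add(new TopicPartition("foo", 0),
                new FetchRequest.PartitionData(id, 0, 0, 200, Optional.empty()));
            FetchSessionHandler.FetchRequestData data = builder.build();
            System.out.println(data.toSend().size());
        }
    }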