Posted to commits@kafka.apache.org by jg...@apache.org on 2021/02/19 22:57:51 UTC

[kafka] branch 2.8 updated: KAFKA-10817; Add clusterId validation to raft Fetch handling (#10129)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.8
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.8 by this push:
     new 5c921fc  KAFKA-10817; Add clusterId validation to raft Fetch handling (#10129)
5c921fc is described below

commit 5c921fcbf450ab07c6d3adb5c21b81e8775e5640
Author: David Jacot <dj...@confluent.io>
AuthorDate: Fri Feb 19 23:43:14 2021 +0100

    KAFKA-10817; Add clusterId validation to raft Fetch handling (#10129)
    
    This patch adds clusterId validation in the `Fetch` API as documented in KIP-595. A new error code `INCONSISTENT_CLUSTER_ID` is returned if the request clusterId does not match the value on the server. If no clusterId is provided, the request is treated as valid.
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
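
    As a rough, standalone sketch of the rule described above (the real
    check is KafkaRaftClient#hasValidClusterId in the diff below; the
    cluster id value here is hypothetical):

        public class ClusterIdCheckSketch {
            // Mirrors the patch's rule: a null clusterId skips validation,
            // any non-null value must match the server's cluster id exactly.
            static boolean hasValidClusterId(String serverClusterId, String requestClusterId) {
                if (requestClusterId == null) {
                    return true;
                }
                return serverClusterId.equals(requestClusterId);
            }

            public static void main(String[] args) {
                String serverClusterId = "mkVZkr0nTnic0zDDrbBZDA"; // hypothetical
                System.out.println(hasValidClusterId(serverClusterId, null));            // true: not provided
                System.out.println(hasValidClusterId(serverClusterId, serverClusterId)); // true: matches
                System.out.println(hasValidClusterId(serverClusterId, "other-id"));      // false -> INCONSISTENT_CLUSTER_ID
            }
        }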
---
 .../errors/InconsistentClusterIdException.java     | 28 +++++++++++
 .../org/apache/kafka/common/protocol/Errors.java   |  4 +-
 .../resources/common/message/FetchRequest.json     |  3 +-
 core/src/main/scala/kafka/raft/RaftManager.scala   |  1 +
 .../org/apache/kafka/raft/KafkaRaftClient.java     | 18 +++++++
 .../kafka/raft/KafkaRaftClientSnapshotTest.java    | 24 ++++-----
 .../org/apache/kafka/raft/KafkaRaftClientTest.java | 58 ++++++++++++++++------
 .../apache/kafka/raft/RaftClientTestContext.java   | 55 +++++++++++++++++---
 .../apache/kafka/raft/RaftEventSimulationTest.java |  3 ++
 9 files changed, 158 insertions(+), 36 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InconsistentClusterIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentClusterIdException.java
new file mode 100644
index 0000000..62fed41
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentClusterIdException.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.errors;
+
+public class InconsistentClusterIdException extends ApiException {
+
+    public InconsistentClusterIdException(String message) {
+        super(message);
+    }
+
+    public InconsistentClusterIdException(String message, Throwable throwable) {
+        super(message, throwable);
+    }
+}
\ No newline at end of file
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
index 34c4206..5c2ca7d 100644
--- a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java
@@ -49,6 +49,7 @@ import org.apache.kafka.common.errors.IllegalSaslStateException;
 import org.apache.kafka.common.errors.InconsistentGroupProtocolException;
 import org.apache.kafka.common.errors.InconsistentTopicIdException;
 import org.apache.kafka.common.errors.InconsistentVoterSetException;
+import org.apache.kafka.common.errors.InconsistentClusterIdException;
 import org.apache.kafka.common.errors.InvalidCommitOffsetSizeException;
 import org.apache.kafka.common.errors.InvalidConfigurationException;
 import org.apache.kafka.common.errors.InvalidFetchSessionEpochException;
@@ -356,7 +357,8 @@ public enum Errors {
         PositionOutOfRangeException::new),
     UNKNOWN_TOPIC_ID(100, "This server does not host this topic ID.", UnknownTopicIdException::new),
     DUPLICATE_BROKER_REGISTRATION(101, "This broker ID is already in use.", DuplicateBrokerRegistrationException::new),
-    INCONSISTENT_TOPIC_ID(102, "The log's topic ID did not match the topic ID in the request", InconsistentTopicIdException::new);
+    INCONSISTENT_TOPIC_ID(102, "The log's topic ID did not match the topic ID in the request", InconsistentTopicIdException::new),
+    INCONSISTENT_CLUSTER_ID(103, "The clusterId in the request does not match that found on the server", InconsistentClusterIdException::new);
 
     private static final Logger log = LoggerFactory.getLogger(Errors.class);
 
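A minimal sketch of how the new error code surfaces through the pre-existing
public Errors API (Errors.forCode and Errors.exception are long-standing
methods, not part of this patch):

    import org.apache.kafka.common.protocol.Errors;

    public class ErrorCodeSketch {
        public static void main(String[] args) {
            // Error code 103 now round-trips through the enum and maps
            // to the new exception type with the default message above.
            Errors error = Errors.forCode((short) 103);
            System.out.println(error); // INCONSISTENT_CLUSTER_ID
            System.out.println(error.exception().getClass().getSimpleName());
            // -> InconsistentClusterIdException
        }
    }
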
diff --git a/clients/src/main/resources/common/message/FetchRequest.json b/clients/src/main/resources/common/message/FetchRequest.json
index ab4c95f..6594773 100644
--- a/clients/src/main/resources/common/message/FetchRequest.json
+++ b/clients/src/main/resources/common/message/FetchRequest.json
@@ -50,7 +50,8 @@
   "validVersions": "0-12",
   "flexibleVersions": "12+",
   "fields": [
-    { "name": "ClusterId", "type": "string", "versions": "12+", "nullableVersions": "12+", "default": "null", "taggedVersions": "12+", "tag": 0,
+    { "name": "ClusterId", "type": "string", "versions": "12+", "nullableVersions": "12+", "default": "null",
+      "taggedVersions": "12+", "tag": 0, "ignorable": true,
       "about": "The clusterId if known. This is used to validate metadata fetches prior to broker registration." },
     { "name": "ReplicaId", "type": "int32", "versions": "0+",
       "about": "The broker ID of the follower, of -1 if this request is from a consumer." },
diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala b/core/src/main/scala/kafka/raft/RaftManager.scala
index 6a74c27..7bf34b0 100644
--- a/core/src/main/scala/kafka/raft/RaftManager.scala
+++ b/core/src/main/scala/kafka/raft/RaftManager.scala
@@ -200,6 +200,7 @@ class KafkaRaftManager[T](
       metrics,
       expirationService,
       logContext,
+      metaProperties.clusterId.toString,
       OptionalInt.of(config.nodeId),
       raftConfig
     )
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 49aabd3..946fec1 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -141,6 +141,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
     private final LogContext logContext;
     private final Time time;
     private final int fetchMaxWaitMs;
+    private final String clusterId;
     private final OptionalInt nodeId;
     private final NetworkChannel channel;
     private final ReplicatedLog log;
@@ -177,6 +178,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         Metrics metrics,
         ExpirationService expirationService,
         LogContext logContext,
+        String clusterId,
         OptionalInt nodeId,
         RaftConfig raftConfig
     ) {
@@ -190,6 +192,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
             metrics,
             expirationService,
             FETCH_MAX_WAIT_MS,
+            clusterId,
             nodeId,
             logContext,
             new Random(),
@@ -207,6 +210,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         Metrics metrics,
         ExpirationService expirationService,
         int fetchMaxWaitMs,
+        String clusterId,
         OptionalInt nodeId,
         LogContext logContext,
         Random random,
@@ -221,6 +225,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         this.fetchPurgatory = new ThresholdPurgatory<>(expirationService);
         this.appendPurgatory = new ThresholdPurgatory<>(expirationService);
         this.time = time;
+        this.clusterId = clusterId;
         this.nodeId = nodeId;
         this.metrics = metrics;
         this.fetchMaxWaitMs = fetchMaxWaitMs;
@@ -933,6 +938,14 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         );
     }
 
+    private boolean hasValidClusterId(FetchRequestData request) {
+        // We don't enforce the cluster id if it is not provided.
+        if (request.clusterId() == null) {
+            return true;
+        }
+        return clusterId.equals(request.clusterId());
+    }
+
     /**
      * Handle a Fetch request. The fetch offset and last fetched epoch are always
      * validated against the current log. In the case that they do not match, the response will
@@ -952,6 +965,10 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
     ) {
         FetchRequestData request = (FetchRequestData) requestMetadata.data;
 
+        if (!hasValidClusterId(request)) {
+            return completedFuture(new FetchResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()));
+        }
+
         if (!hasValidTopicPartition(request, log.topicPartition())) {
             // Until we support multi-raft, we treat topic partition mismatches as invalid requests
             return completedFuture(new FetchResponseData().setErrorCode(Errors.INVALID_REQUEST.code()));
@@ -1759,6 +1776,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         });
         return request
             .setMaxWaitMs(fetchMaxWaitMs)
+            .setClusterId(clusterId.toString())
             .setReplicaId(quorum.localIdOrSentinel());
     }
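Note that on a mismatch the handler fails the entire Fetch with a top-level
error code rather than a partition-level one, which is why the test helpers
below gain an overload that inspects only FetchResponseData#errorCode. A
minimal sketch of that response shape:

    import org.apache.kafka.common.message.FetchResponseData;
    import org.apache.kafka.common.protocol.Errors;

    public class TopLevelErrorSketch {
        public static void main(String[] args) {
            // Short-circuit response: top-level error, no partition data.
            FetchResponseData response = new FetchResponseData()
                .setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code());
            System.out.println(Errors.forCode(response.errorCode())); // INCONSISTENT_CLUSTER_ID
            System.out.println(response.responses().isEmpty());       // true
        }
    }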
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
index 9ebb776..2b7cea5 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
@@ -74,7 +74,7 @@ final public class KafkaRaftClientSnapshotTest {
         // Advance the highWatermark
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, localLogEndOffset, epoch, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(localLogEndOffset, context.client.highWatermark().getAsLong());
 
         OffsetAndEpoch snapshotId = new OffsetAndEpoch(localLogEndOffset, epoch);
@@ -89,7 +89,7 @@ final public class KafkaRaftClientSnapshotTest {
         // Send Fetch request less than start offset
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0, epoch, 0));
         context.pollUntilResponse();
-        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchResponse();
+        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchPartitionResponse();
         assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
         assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
         assertEquals(localId, partitionResponse.currentLeader().leaderId());
@@ -119,7 +119,7 @@ final public class KafkaRaftClientSnapshotTest {
         long localLogEndOffset = context.log.endOffset().offset;
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, localLogEndOffset, epoch, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(localLogEndOffset, context.client.highWatermark().getAsLong());
 
         // Create a snapshot at the high watermark
@@ -135,7 +135,7 @@ final public class KafkaRaftClientSnapshotTest {
         // It is an invalid request to send a last fetched epoch greater than the current epoch
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset + 1, epoch + 1, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
     }
 
     @Test
@@ -162,7 +162,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, syncNodeId, context.log.endOffset().offset, epoch, 0)
         );
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(context.log.endOffset().offset, context.client.highWatermark().getAsLong());
 
         // Create a snapshot at the high watermark
@@ -176,7 +176,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset + 1, oldestSnapshotId.epoch + 1, 0)
         );
         context.pollUntilResponse();
-        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchResponse();
+        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchPartitionResponse();
         assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
         assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
         assertEquals(localId, partitionResponse.currentLeader().leaderId());
@@ -209,7 +209,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, syncNodeId, context.log.endOffset().offset, epoch, 0)
         );
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(context.log.endOffset().offset, context.client.highWatermark().getAsLong());
 
         // Create a snapshot at the high watermark
@@ -223,7 +223,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset, oldestSnapshotId.epoch, 0)
         );
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
     }
 
     @Test
@@ -251,7 +251,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, syncNodeId, context.log.endOffset().offset, epoch, 0)
         );
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(context.log.endOffset().offset, context.client.highWatermark().getAsLong());
 
         // Create a snapshot at the high watermark
@@ -265,7 +265,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset, oldestSnapshotId.epoch + 1, 0)
         );
         context.pollUntilResponse();
-        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchResponse();
+        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchPartitionResponse();
         assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
         assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
         assertEquals(localId, partitionResponse.currentLeader().leaderId());
@@ -298,7 +298,7 @@ final public class KafkaRaftClientSnapshotTest {
             context.fetchRequest(epoch, syncNodeId, context.log.endOffset().offset, epoch, 0)
         );
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(context.log.endOffset().offset, context.client.highWatermark().getAsLong());
 
         // Create a snapshot at the high watermark
@@ -318,7 +318,7 @@ final public class KafkaRaftClientSnapshotTest {
             )
         );
         context.pollUntilResponse();
-        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchResponse();
+        FetchResponseData.FetchablePartitionResponse partitionResponse = context.assertSentFetchPartitionResponse();
         assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
         assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
         assertEquals(localId, partitionResponse.currentLeader().leaderId());
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
index fb188f1..8093c1f 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -233,7 +233,7 @@ public class KafkaRaftClientTest {
         // when transitioning to resigned, all requests in the fetchPurgatory will fail
         context.client.shutdown(1000);
         context.client.poll();
-        context.assertSentFetchResponse(Errors.BROKER_NOT_AVAILABLE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.BROKER_NOT_AVAILABLE, epoch, OptionalInt.of(localId));
         context.assertResignedLeader(epoch, localId);
 
         // shutting down finished
@@ -836,7 +836,7 @@ public class KafkaRaftClientTest {
         // note that offset 0 would be a control message for becoming the leader
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0L, epoch, 500));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(OptionalLong.of(0L), context.client.highWatermark());
 
         List<String> records = Arrays.asList("a", "b", "c");
@@ -847,7 +847,7 @@ public class KafkaRaftClientTest {
         // Let the follower send a fetch, it should advance the high watermark
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(OptionalLong.of(1L), context.client.highWatermark());
         assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
 
@@ -1075,27 +1075,55 @@ public class KafkaRaftClientTest {
         context.deliverRequest(context.fetchRequest(
             epoch, otherNodeId, -5L, 0, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
 
         context.deliverRequest(context.fetchRequest(
             epoch, otherNodeId, 0L, -1, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
 
         context.deliverRequest(context.fetchRequest(
             epoch, otherNodeId, 0L, epoch + 1, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
 
         context.deliverRequest(context.fetchRequest(
             epoch + 1, otherNodeId, 0L, 0, 0));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.UNKNOWN_LEADER_EPOCH, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.UNKNOWN_LEADER_EPOCH, epoch, OptionalInt.of(localId));
 
         context.deliverRequest(context.fetchRequest(
             epoch, otherNodeId, 0L, 0, -1));
         context.pollUntilResponse();
-        context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+    }
+
+    @Test
+    public void testFetchRequestClusterIdValidation() throws Exception {
+        int localId = 0;
+        int otherNodeId = 1;
+        int epoch = 5;
+        Set<Integer> voters = Utils.mkSet(localId, otherNodeId);
+
+        RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
+
+        // null cluster id is accepted
+        context.deliverRequest(context.fetchRequest(
+            epoch, null, otherNodeId, -5L, 0, 0));
+        context.pollUntilResponse();
+        context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId));
+
+        // empty cluster id is rejected
+        context.deliverRequest(context.fetchRequest(
+            epoch, "", otherNodeId, -5L, 0, 0));
+        context.pollUntilResponse();
+        context.assertSentFetchPartitionResponse(Errors.INCONSISTENT_CLUSTER_ID);
+
+        // invalid cluster id is rejected
+        context.deliverRequest(context.fetchRequest(
+            epoch, "invalid-uuid", otherNodeId, -5L, 0, 0));
+        context.pollUntilResponse();
+        context.assertSentFetchPartitionResponse(Errors.INCONSISTENT_CLUSTER_ID);
     }
 
     @Test
@@ -1169,7 +1197,7 @@ public class KafkaRaftClientTest {
         // After expiration of the max wait time, the fetch returns an empty record set
         context.time.sleep(maxWaitTimeMs);
         context.client.poll();
-        MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         assertEquals(0, fetchedRecords.sizeInBytes());
     }
 
@@ -1192,7 +1220,7 @@ public class KafkaRaftClientTest {
         context.client.scheduleAppend(epoch, Arrays.asList(appendRecords));
         context.client.poll();
 
-        MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         RaftClientTestContext.assertMatchingRecords(appendRecords, fetchedRecords);
     }
 
@@ -1222,7 +1250,7 @@ public class KafkaRaftClientTest {
         context.assertSentBeginQuorumEpochResponse(Errors.NONE, epoch + 1, OptionalInt.of(voter3));
 
         // The fetch should be satisfied immediately and return an error
-        MemoryRecords fetchedRecords = context.assertSentFetchResponse(
+        MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(
             Errors.NOT_LEADER_OR_FOLLOWER, epoch + 1, OptionalInt.of(voter3));
         assertEquals(0, fetchedRecords.sizeInBytes());
     }
@@ -1492,7 +1520,7 @@ public class KafkaRaftClientTest {
         context.pollUntilResponse();
 
         long highWatermark = 1L;
-        context.assertSentFetchResponse(highWatermark, epoch);
+        context.assertSentFetchPartitionResponse(highWatermark, epoch);
 
         context.deliverRequest(DescribeQuorumRequest.singletonRequest(context.metadataPartition));
 
@@ -1710,7 +1738,7 @@ public class KafkaRaftClientTest {
         context.client.poll();
 
         // The BeginEpoch request eventually times out. We should not send another one.
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
         context.time.sleep(context.requestTimeoutMs());
 
         context.client.poll();
@@ -1749,7 +1777,7 @@ public class KafkaRaftClientTest {
         context.deliverRequest(context.fetchRequest(1, otherNodeId, 0L, 0, 500));
         context.pollUntilResponse();
 
-        MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, 1, OptionalInt.of(localId));
+        MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(Errors.NONE, 1, OptionalInt.of(localId));
         List<MutableRecordBatch> batches = Utils.toList(fetchedRecords.batchIterator());
         assertEquals(2, batches.size());
 
@@ -2159,7 +2187,7 @@ public class KafkaRaftClientTest {
         context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 9L, epoch, 500));
         context.pollUntilResponse();
         assertEquals(OptionalLong.of(9L), context.client.highWatermark());
-        context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
+        context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId));
 
         // Now we receive a vote request which transitions us to the 'unattached' state
         context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 9L));
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
index 12fda05..e57995c 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.raft;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.message.BeginQuorumEpochRequestData;
 import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
@@ -95,6 +96,7 @@ public final class RaftClientTestContext {
     private int appendLingerMs;
 
     private final QuorumStateStore quorumStateStore;
+    private final Uuid clusterId;
     private final OptionalInt localId;
     public final KafkaRaftClient<String> client;
     final Metrics metrics;
@@ -127,6 +129,7 @@ public final class RaftClientTestContext {
         private final Set<Integer> voters;
         private final OptionalInt localId;
 
+        private Uuid clusterId = Uuid.randomUuid();
         private int requestTimeoutMs = DEFAULT_REQUEST_TIMEOUT_MS;
         private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS;
         private int appendLingerMs = DEFAULT_APPEND_LINGER_MS;
@@ -192,6 +195,11 @@ public final class RaftClientTestContext {
             return this;
         }
 
+        Builder withClusterId(Uuid clusterId) {
+            this.clusterId = clusterId;
+            return this;
+        }
+
         public RaftClientTestContext build() throws IOException {
             Metrics metrics = new Metrics(time);
             MockNetworkChannel channel = new MockNetworkChannel(voters);
@@ -213,6 +221,7 @@ public final class RaftClientTestContext {
                 metrics,
                 new MockExpirationService(time),
                 FETCH_MAX_WAIT_MS,
+                clusterId.toString(),
                 localId,
                 logContext,
                 random,
@@ -223,6 +232,7 @@ public final class RaftClientTestContext {
             client.initialize();
 
             RaftClientTestContext context = new RaftClientTestContext(
+                clusterId,
                 localId,
                 client,
                 log,
@@ -244,6 +254,7 @@ public final class RaftClientTestContext {
     }
 
     private RaftClientTestContext(
+        Uuid clusterId,
         OptionalInt localId,
         KafkaRaftClient<String> client,
         MockLog log,
@@ -255,6 +266,7 @@ public final class RaftClientTestContext {
         Metrics metrics,
         MockListener listener
     ) {
+        this.clusterId = clusterId;
         this.localId = localId;
         this.client = client;
         this.log = log;
@@ -594,7 +606,7 @@ public final class RaftClientTestContext {
         return raftMessage.correlationId();
     }
 
-    FetchResponseData.FetchablePartitionResponse assertSentFetchResponse() {
+    FetchResponseData.FetchablePartitionResponse assertSentFetchPartitionResponse() {
         List<RaftResponse.Outbound> sentMessages = drainSentResponses(ApiKeys.FETCH);
         assertEquals(
             1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
@@ -609,13 +621,23 @@ public final class RaftClientTestContext {
         return response.responses().get(0).partitionResponses().get(0);
     }
 
+    void assertSentFetchPartitionResponse(Errors error) {
+        List<RaftResponse.Outbound> sentMessages = drainSentResponses(ApiKeys.FETCH);
+        assertEquals(
+            1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
+        RaftResponse.Outbound raftMessage = sentMessages.get(0);
+        assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey());
+        FetchResponseData response = (FetchResponseData) raftMessage.data();
+        assertEquals(error, Errors.forCode(response.errorCode()));
+    }
+
 
-    MemoryRecords assertSentFetchResponse(
+    MemoryRecords assertSentFetchPartitionResponse(
         Errors error,
         int epoch,
         OptionalInt leaderId
     ) {
-        FetchResponseData.FetchablePartitionResponse partitionResponse = assertSentFetchResponse();
+        FetchResponseData.FetchablePartitionResponse partitionResponse = assertSentFetchPartitionResponse();
         assertEquals(error, Errors.forCode(partitionResponse.errorCode()));
         assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch());
         assertEquals(leaderId.orElse(-1), partitionResponse.currentLeader().leaderId());
@@ -626,11 +648,11 @@ public final class RaftClientTestContext {
         return (MemoryRecords) partitionResponse.recordSet();
     }
 
-    MemoryRecords assertSentFetchResponse(
+    MemoryRecords assertSentFetchPartitionResponse(
         long highWatermark,
         int leaderEpoch
     ) {
-        FetchResponseData.FetchablePartitionResponse partitionResponse = assertSentFetchResponse();
+        FetchResponseData.FetchablePartitionResponse partitionResponse = assertSentFetchPartitionResponse();
         assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode()));
         assertEquals(leaderEpoch, partitionResponse.currentLeader().leaderEpoch());
         assertEquals(highWatermark, partitionResponse.highWatermark());
@@ -670,7 +692,7 @@ public final class RaftClientTestContext {
         deliverRequest(fetchRequest(1, laggingFollower, 0L, 0, 0));
 
         pollUntilResponse();
-        assertSentFetchResponse(0L, epoch);
+        assertSentFetchPartitionResponse(0L, epoch);
 
         // Append some records, so that the close follower will be able to advance further.
         client.scheduleAppend(epoch, Arrays.asList("foo", "bar"));
@@ -679,7 +701,7 @@ public final class RaftClientTestContext {
         deliverRequest(fetchRequest(epoch, closeFollower, 1L, epoch, 0));
 
         pollUntilResponse();
-        assertSentFetchResponse(1L, epoch);
+        assertSentFetchPartitionResponse(1L, epoch);
     }
 
     List<RaftRequest.Outbound> collectEndQuorumRequests(int epoch, Set<Integer> destinationIdSet) {
@@ -867,6 +889,24 @@ public final class RaftClientTestContext {
         int lastFetchedEpoch,
         int maxWaitTimeMs
     ) {
+        return fetchRequest(
+            epoch,
+            clusterId.toString(),
+            replicaId,
+            fetchOffset,
+            lastFetchedEpoch,
+            maxWaitTimeMs
+        );
+    }
+
+    FetchRequestData fetchRequest(
+        int epoch,
+        String clusterId,
+        int replicaId,
+        long fetchOffset,
+        int lastFetchedEpoch,
+        int maxWaitTimeMs
+    ) {
         FetchRequestData request = RaftUtil.singletonFetchRequest(metadataPartition, fetchPartition -> {
             fetchPartition
                 .setCurrentLeaderEpoch(epoch)
@@ -875,6 +915,7 @@ public final class RaftClientTestContext {
         });
         return request
             .setMaxWaitMs(maxWaitTimeMs)
+            .setClusterId(clusterId)
             .setReplicaId(replicaId);
     }
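
The second fetchRequest overload above lets a test inject an arbitrary (or
null) clusterId for negative cases such as testFetchRequestClusterIdValidation,
while the original signature keeps stamping the context's own cluster id so
existing tests run unchanged.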
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
index f94b85a..a1af912 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.raft;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.memory.MemoryPool;
 import org.apache.kafka.common.metrics.Metrics;
 import org.apache.kafka.common.protocol.ObjectSerializationCache;
@@ -562,6 +563,7 @@ public class RaftEventSimulationTest {
         final Random random;
         final AtomicInteger correlationIdCounter = new AtomicInteger();
         final MockTime time = new MockTime();
+        final Uuid clusterId = Uuid.randomUuid();
         final Set<Integer> voters = new HashSet<>();
         final Map<Integer, PersistentState> nodes = new HashMap<>();
         final Map<Integer, RaftNode> running = new HashMap<>();
@@ -758,6 +760,7 @@ public class RaftEventSimulationTest {
                 metrics,
                 new MockExpirationService(time),
                 FETCH_MAX_WAIT_MS,
+                clusterId.toString(),
                 OptionalInt.of(nodeId),
                 logContext,
                 random,