You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2022/12/13 18:45:15 UTC

[kafka] branch trunk updated: MINOR: Pass snapshot ID directly in `RaftClient.createSnapshot` (#12981)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 26a4d420726 MINOR: Pass snapshot ID directly in `RaftClient.createSnapshot` (#12981)
26a4d420726 is described below

commit 26a4d420726f4c63da6ea5a045694a6c6ac87207
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Tue Dec 13 10:44:56 2022 -0800

    MINOR: Pass snapshot ID directly in `RaftClient.createSnapshot` (#12981)
    
    Let `RaftClient.createSnapshot` take the snapshotId directly instead of the committed offset/epoch (which may not exist).
    
    Reviewers: José Armando García Sancio <js...@apache.org>
---
 .../src/main/scala/kafka/server/BrokerServer.scala |  4 +-
 .../apache/kafka/controller/QuorumController.java  |  9 +++-
 .../org/apache/kafka/metalog/LocalLogManager.java  |  8 ++--
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  5 +--
 .../java/org/apache/kafka/raft/RaftClient.java     | 23 ++++++-----
 .../org/apache/kafka/raft/ReplicatedCounter.java   |  3 +-
 .../kafka/snapshot/RecordsSnapshotWriter.java      | 46 +++++++++++++++------
 .../org/apache/kafka/snapshot/SnapshotWriter.java  |  2 +-
 .../kafka/raft/KafkaRaftClientSnapshotTest.java    | 48 +++++++++++-----------
 .../kafka/snapshot/SnapshotWriterReaderTest.java   |  8 ++--
 10 files changed, 90 insertions(+), 66 deletions(-)

diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala
index a83da6e0258..623338fd2f1 100644
--- a/core/src/main/scala/kafka/server/BrokerServer.scala
+++ b/core/src/main/scala/kafka/server/BrokerServer.scala
@@ -43,6 +43,7 @@ import org.apache.kafka.common.utils.{LogContext, Time, Utils}
 import org.apache.kafka.common.{ClusterResource, Endpoint}
 import org.apache.kafka.metadata.authorizer.ClusterMetadataAuthorizer
 import org.apache.kafka.metadata.{BrokerState, VersionRange}
+import org.apache.kafka.raft
 import org.apache.kafka.raft.{RaftClient, RaftConfig}
 import org.apache.kafka.server.authorizer.Authorizer
 import org.apache.kafka.server.common.ApiMessageAndVersion
@@ -59,7 +60,8 @@ class BrokerSnapshotWriterBuilder(raftClient: RaftClient[ApiMessageAndVersion])
   override def build(committedOffset: Long,
                      committedEpoch: Int,
                      lastContainedLogTime: Long): Option[SnapshotWriter[ApiMessageAndVersion]] = {
-    raftClient.createSnapshot(committedOffset, committedEpoch, lastContainedLogTime).asScala
+    val snapshotId = new raft.OffsetAndEpoch(committedOffset + 1, committedEpoch)
+    raftClient.createSnapshot(snapshotId, lastContainedLogTime).asScala
   }
 }
 
diff --git a/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
index 37e96b294d7..3dba5b401d6 100644
--- a/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
+++ b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
@@ -533,9 +533,14 @@ public final class QuorumController implements Controller {
                     )
                 );
             }
+
+            OffsetAndEpoch snapshotId = new OffsetAndEpoch(
+                committedOffset + 1,
+                committedEpoch
+            );
+
             Optional<SnapshotWriter<ApiMessageAndVersion>> writer = raftClient.createSnapshot(
-                committedOffset,
-                committedEpoch,
+                snapshotId,
                 committedTimestamp
             );
             if (writer.isPresent()) {
diff --git a/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java
index 20c5c324dd2..3bc07c06af0 100644
--- a/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java
+++ b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java
@@ -40,8 +40,8 @@ import org.apache.kafka.snapshot.RawSnapshotReader;
 import org.apache.kafka.snapshot.RawSnapshotWriter;
 import org.apache.kafka.snapshot.RecordsSnapshotReader;
 import org.apache.kafka.snapshot.RecordsSnapshotWriter;
-import org.apache.kafka.snapshot.SnapshotWriter;
 import org.apache.kafka.snapshot.SnapshotReader;
+import org.apache.kafka.snapshot.SnapshotWriter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -51,8 +51,8 @@ import java.util.HashMap;
 import java.util.IdentityHashMap;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map.Entry;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.NavigableMap;
 import java.util.Objects;
 import java.util.Optional;
@@ -774,11 +774,9 @@ public final class LocalLogManager implements RaftClient<ApiMessageAndVersion>,
 
     @Override
     public Optional<SnapshotWriter<ApiMessageAndVersion>> createSnapshot(
-        long committedOffset,
-        int committedEpoch,
+        OffsetAndEpoch snapshotId,
         long lastContainedLogTimestamp
     ) {
-        OffsetAndEpoch snapshotId = new OffsetAndEpoch(committedOffset + 1, committedEpoch);
         return RecordsSnapshotWriter.createWithHeader(
             () -> createNewSnapshot(snapshotId),
             1024,
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 1cd6058eb0c..f21c31eeab9 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -2342,12 +2342,11 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
 
     @Override
     public Optional<SnapshotWriter<T>> createSnapshot(
-        long committedOffset,
-        int committedEpoch,
+        OffsetAndEpoch snapshotId,
         long lastContainedLogTime
     ) {
         return RecordsSnapshotWriter.createWithHeader(
-                () -> log.createNewSnapshot(new OffsetAndEpoch(committedOffset + 1, committedEpoch)),
+                () -> log.createNewSnapshot(snapshotId),
                 MAX_BATCH_SIZE_BYTES,
                 memoryPool,
                 time,
diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
index 727bd1bbf80..95805672311 100644
--- a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java
@@ -216,28 +216,29 @@ public interface RaftClient<T> extends AutoCloseable {
     /**
      * Create a writable snapshot file for a committed offset and epoch.
      *
-     * The RaftClient assumes that the snapshot returned will contain the records up to and
-     * including the committed offset and epoch. See {@link SnapshotWriter} for details on
-     * how to use this object. If a snapshot already exists then returns an
-     * {@link Optional#empty()}.
+     * The RaftClient assumes that the snapshot returned will contain the records up to, but not
+     * including the committed offset and epoch. If no records have been committed, it is possible
+     * to generate an empty snapshot using 0 for both the offset and epoch.
+     *
+     * See {@link SnapshotWriter} for details on how to use this object. If a snapshot already
+     * exists then returns an {@link Optional#empty()}.
      *
-     * @param committedEpoch the epoch of the committed offset
-     * @param committedOffset the last committed offset that will be included in the snapshot
+     * @param snapshotId The ID of the new snapshot, which includes the (exclusive) last committed offset
+     *                   and the last committed epoch.
      * @param lastContainedLogTime The append time of the highest record contained in this snapshot
      * @return a writable snapshot if it doesn't already exists
      * @throws IllegalArgumentException if the committed offset is greater than the high-watermark
      *         or less than the log start offset.
      */
-    Optional<SnapshotWriter<T>> createSnapshot(long committedOffset, int committedEpoch, long lastContainedLogTime);
-
+    Optional<SnapshotWriter<T>> createSnapshot(OffsetAndEpoch snapshotId, long lastContainedLogTime);
 
     /**
-     * The snapshot id for the lastest snapshot.
+     * The snapshot id for the latest snapshot.
      *
-     * Returns the snapshot id of the latest snapshot, if it exists. If a snapshot doesn't exists, returns an
+     * Returns the snapshot id of the latest snapshot, if it exists. If a snapshot doesn't exist, returns an
      * {@link Optional#empty()}.
      *
-     * @return the id of the latest snaphost, if it exists
+     * @return the id of the latest snapshot, if it exists
      */
     Optional<OffsetAndEpoch> latestSnapshotId();
 }
diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java
index 27b3be163e0..65a3b6526e9 100644
--- a/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java
+++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java
@@ -115,8 +115,7 @@ public class ReplicatedCounter implements RaftClient.Listener<Integer> {
                     lastOffsetSnapshotted
                 );
                 Optional<SnapshotWriter<Integer>> snapshot = client.createSnapshot(
-                    lastCommittedOffset,
-                    lastCommittedEpoch,
+                    new OffsetAndEpoch(lastCommittedOffset + 1, lastCommittedEpoch),
                     lastCommittedTimestamp);
                 if (snapshot.isPresent()) {
                     try {
diff --git a/raft/src/main/java/org/apache/kafka/snapshot/RecordsSnapshotWriter.java b/raft/src/main/java/org/apache/kafka/snapshot/RecordsSnapshotWriter.java
index 859a0259445..05d3fde09d0 100644
--- a/raft/src/main/java/org/apache/kafka/snapshot/RecordsSnapshotWriter.java
+++ b/raft/src/main/java/org/apache/kafka/snapshot/RecordsSnapshotWriter.java
@@ -118,19 +118,39 @@ final public class RecordsSnapshotWriter<T> implements SnapshotWriter<T> {
         CompressionType compressionType,
         RecordSerde<T> serde
     ) {
-        return supplier.get().map(snapshot -> {
-            RecordsSnapshotWriter<T> writer = new RecordsSnapshotWriter<>(
-                    snapshot,
-                    maxBatchSize,
-                    memoryPool,
-                    snapshotTime,
-                    lastContainedLogTimestamp,
-                    compressionType,
-                    serde);
-            writer.initializeSnapshotWithHeader();
-
-            return writer;
-        });
+        return supplier.get().map(writer ->
+            createWithHeader(
+                writer,
+                maxBatchSize,
+                memoryPool,
+                snapshotTime,
+                lastContainedLogTimestamp,
+                compressionType,
+                serde
+            )
+        );
+    }
+
+    public static <T> SnapshotWriter<T> createWithHeader(
+        RawSnapshotWriter rawSnapshotWriter,
+        int maxBatchSize,
+        MemoryPool memoryPool,
+        Time snapshotTime,
+        long lastContainedLogTimestamp,
+        CompressionType compressionType,
+        RecordSerde<T> serde
+    ) {
+        RecordsSnapshotWriter<T> writer = new RecordsSnapshotWriter<>(
+            rawSnapshotWriter,
+            maxBatchSize,
+            memoryPool,
+            snapshotTime,
+            lastContainedLogTimestamp,
+            compressionType,
+            serde
+        );
+        writer.initializeSnapshotWithHeader();
+        return writer;
     }
 
     @Override
diff --git a/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java
index 77b29d94498..537335c058a 100644
--- a/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java
+++ b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java
@@ -32,7 +32,7 @@ import java.util.List;
  * topic partition from offset 0 up to but not including the end offset in the snapshot
  * id.
  *
- * @see org.apache.kafka.raft.KafkaRaftClient#createSnapshot(long, int, long)
+ * @see org.apache.kafka.raft.RaftClient#createSnapshot(OffsetAndEpoch, long)
  */
 public interface SnapshotWriter<T> extends AutoCloseable {
     /**
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
index ecef1ffa933..23fe3fd0694 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java
@@ -232,7 +232,7 @@ final public class KafkaRaftClientSnapshotTest {
 
         // Generate a new snapshot
         OffsetAndEpoch secondSnapshotId = new OffsetAndEpoch(localLogEndOffset, epoch);
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(secondSnapshotId.offset() - 1, secondSnapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(secondSnapshotId, 0).get()) {
             assertEquals(secondSnapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -278,7 +278,7 @@ final public class KafkaRaftClientSnapshotTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         OffsetAndEpoch snapshotId = new OffsetAndEpoch(localLogEndOffset, epoch);
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId.offset() - 1, snapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId, 0).get()) {
             assertEquals(snapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -318,7 +318,7 @@ final public class KafkaRaftClientSnapshotTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         // Create a snapshot at the high watermark
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId.offset() - 1, oldestSnapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId, 0).get()) {
             assertEquals(oldestSnapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -357,7 +357,7 @@ final public class KafkaRaftClientSnapshotTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         // Create a snapshot at the high watermark
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId.offset() - 1, oldestSnapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId, 0).get()) {
             assertEquals(oldestSnapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -400,7 +400,7 @@ final public class KafkaRaftClientSnapshotTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         // Create a snapshot at the high watermark
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId.offset() - 1, oldestSnapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId, 0).get()) {
             assertEquals(oldestSnapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -438,7 +438,7 @@ final public class KafkaRaftClientSnapshotTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         // Create a snapshot at the high watermark
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId.offset() - 1, oldestSnapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId, 0).get()) {
             assertEquals(oldestSnapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -482,7 +482,7 @@ final public class KafkaRaftClientSnapshotTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         // Create a snapshot at the high watermark
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId.offset() - 1, oldestSnapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(oldestSnapshotId, 0).get()) {
             assertEquals(oldestSnapshotId, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -572,7 +572,7 @@ final public class KafkaRaftClientSnapshotTest {
 
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId.offset() - 1, snapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId, 0).get()) {
             assertEquals(snapshotId, snapshot.snapshotId());
             snapshot.append(records);
             snapshot.freeze();
@@ -621,7 +621,7 @@ final public class KafkaRaftClientSnapshotTest {
 
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId.offset() - 1, snapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId, 0).get()) {
             assertEquals(snapshotId, snapshot.snapshotId());
             snapshot.append(records);
             snapshot.freeze();
@@ -730,7 +730,7 @@ final public class KafkaRaftClientSnapshotTest {
 
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId.offset() - 1, snapshotId.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(snapshotId, 0).get()) {
             assertEquals(snapshotId, snapshot.snapshotId());
             snapshot.append(records);
             snapshot.freeze();
@@ -1554,7 +1554,7 @@ final public class KafkaRaftClientSnapshotTest {
         int epoch = 2;
 
         List<String> appendRecords = Arrays.asList("a", "b", "c");
-        OffsetAndEpoch invalidSnapshotId1 = new OffsetAndEpoch(3, epoch);
+        OffsetAndEpoch invalidSnapshotId1 = new OffsetAndEpoch(4, epoch);
 
         RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
                 .appendToLog(epoch, appendRecords)
@@ -1567,7 +1567,7 @@ final public class KafkaRaftClientSnapshotTest {
         // When leader creating snapshot:
         // 1.1 high watermark cannot be empty
         assertEquals(OptionalLong.empty(), context.client.highWatermark());
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId1.offset(), invalidSnapshotId1.epoch(), 0));
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId1, 0));
 
         // 1.2 high watermark must larger than or equal to the snapshotId's endOffset
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
@@ -1578,18 +1578,18 @@ final public class KafkaRaftClientSnapshotTest {
         context.client.poll();
         assertEquals(context.log.endOffset().offset, context.client.highWatermark().getAsLong() + newRecords.size());
 
-        OffsetAndEpoch invalidSnapshotId2 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() + 1, currentEpoch);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId2.offset(), invalidSnapshotId2.epoch(), 0));
+        OffsetAndEpoch invalidSnapshotId2 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() + 2, currentEpoch);
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId2, 0));
 
         // 2 the quorum epoch must larger than or equal to the snapshotId's epoch
-        OffsetAndEpoch invalidSnapshotId3 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() - 2, currentEpoch + 1);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId3.offset(), invalidSnapshotId3.epoch(), 0));
+        OffsetAndEpoch invalidSnapshotId3 = new OffsetAndEpoch(context.client.highWatermark().getAsLong(), currentEpoch + 1);
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId3, 0));
 
         // 3 the snapshotId should be validated against endOffsetForEpoch
         OffsetAndEpoch endOffsetForEpoch = context.log.endOffsetForEpoch(epoch);
         assertEquals(epoch, endOffsetForEpoch.epoch());
-        OffsetAndEpoch invalidSnapshotId4 = new OffsetAndEpoch(endOffsetForEpoch.offset() + 1, epoch);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId4.offset(), invalidSnapshotId4.epoch(), 0));
+        OffsetAndEpoch invalidSnapshotId4 = new OffsetAndEpoch(endOffsetForEpoch.offset() + 2, epoch);
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId4, 0));
     }
 
     @Test
@@ -1608,8 +1608,8 @@ final public class KafkaRaftClientSnapshotTest {
         // When follower creating snapshot:
         // 1) The high watermark cannot be empty
         assertEquals(OptionalLong.empty(), context.client.highWatermark());
-        OffsetAndEpoch invalidSnapshotId1 = new OffsetAndEpoch(0, 0);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId1.offset(), invalidSnapshotId1.epoch(), 0));
+        OffsetAndEpoch invalidSnapshotId1 = new OffsetAndEpoch(1, 0);
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId1, 0));
 
         // Poll for our first fetch request
         context.pollUntilRequest();
@@ -1627,11 +1627,11 @@ final public class KafkaRaftClientSnapshotTest {
         // 2) The high watermark must be larger than or equal to the snapshotId's endOffset
         int currentEpoch = context.currentEpoch();
         OffsetAndEpoch invalidSnapshotId2 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() + 1, currentEpoch);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId2.offset(), invalidSnapshotId2.epoch(), 0));
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId2, 0));
 
         // 3) The quorum epoch must be larger than or equal to the snapshotId's epoch
-        OffsetAndEpoch invalidSnapshotId3 = new OffsetAndEpoch(context.client.highWatermark().getAsLong(), currentEpoch + 1);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId3.offset(), invalidSnapshotId3.epoch(), 0));
+        OffsetAndEpoch invalidSnapshotId3 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() + 1, currentEpoch + 1);
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId3, 0));
 
         // The high watermark advances to be larger than log.endOffsetForEpoch(3), to test the case 3
         context.pollUntilRequest();
@@ -1650,7 +1650,7 @@ final public class KafkaRaftClientSnapshotTest {
         OffsetAndEpoch endOffsetForEpoch = context.log.endOffsetForEpoch(3);
         assertEquals(3, endOffsetForEpoch.epoch());
         OffsetAndEpoch invalidSnapshotId4 = new OffsetAndEpoch(endOffsetForEpoch.offset() + 1, epoch);
-        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId4.offset(), invalidSnapshotId4.epoch(), 0));
+        assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId4, 0));
     }
 
     private static FetchSnapshotRequestData fetchSnapshotRequest(
diff --git a/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java
index c0fd2dc3f97..bd8258cb708 100644
--- a/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java
+++ b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java
@@ -64,7 +64,7 @@ final public class SnapshotWriterReaderTest {
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
         // Create an empty snapshot and freeze it immediately
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id.offset() - 1, id.epoch(), magicTimestamp).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id, magicTimestamp).get()) {
             assertEquals(id, snapshot.snapshotId());
             snapshot.freeze();
         }
@@ -97,7 +97,7 @@ final public class SnapshotWriterReaderTest {
 
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id.offset() - 1, id.epoch(), magicTimestamp).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id, magicTimestamp).get()) {
             assertEquals(id, snapshot.snapshotId());
             expected.forEach(batch -> assertDoesNotThrow(() -> snapshot.append(batch)));
             snapshot.freeze();
@@ -129,7 +129,7 @@ final public class SnapshotWriterReaderTest {
 
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id.offset() - 1, id.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id, 0).get()) {
             assertEquals(id, snapshot.snapshotId());
             expected.forEach(batch -> {
                 assertDoesNotThrow(() -> snapshot.append(batch));
@@ -157,7 +157,7 @@ final public class SnapshotWriterReaderTest {
 
         context.advanceLocalLeaderHighWatermarkToLogEndOffset();
 
-        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id.offset() - 1, id.epoch(), 0).get()) {
+        try (SnapshotWriter<String> snapshot = context.client.createSnapshot(id, 0).get()) {
             assertEquals(id, snapshot.snapshotId());
             expected.forEach(batch -> {
                 assertDoesNotThrow(() -> snapshot.append(batch));