You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/05/30 16:10:42 UTC

[kafka] branch 2.3 updated: KAFKA-8429; Handle offset change when OffsetForLeaderEpoch inflight (#6811)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.3
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.3 by this push:
     new ba1f856  KAFKA-8429; Handle offset change when OffsetForLeaderEpoch inflight (#6811)
ba1f856 is described below

commit ba1f8561bfffa678a49731550c3b32a341de4572
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Thu May 30 08:50:45 2019 -0700

    KAFKA-8429; Handle offset change when OffsetForLeaderEpoch inflight (#6811)
    
    It is possible for the offset of a partition to be changed while we are in the middle of validation. If the OffsetForLeaderEpoch request is in-flight and the offset changes, we need to redo the validation after it returns. We had a check for this situation previously, but it only checked whether the current leader epoch had changed. This patch fixes the check so that it compares the full fetch position, and moves the validation into `SubscriptionState`, where it can be protected with a lock.
    
    Additionally, this patch adds test cases for the SubscriptionState validation API. It also fixes a small bug in the handling of broker downgrades: validation should be skipped if the latest metadata does not include leader epoch information.
    
    Reviewers: David Arthur <mu...@gmail.com>
---
 .../kafka/clients/consumer/KafkaConsumer.java      |   4 +-
 .../consumer/internals/ConsumerCoordinator.java    |   4 +-
 .../kafka/clients/consumer/internals/Fetcher.java  |  55 ++---
 .../consumer/internals/SubscriptionState.java      |  98 ++++++---
 .../clients/consumer/internals/FetcherTest.java    |  56 ++++-
 .../consumer/internals/SubscriptionStateTest.java  | 240 ++++++++++++++++++++-
 6 files changed, 368 insertions(+), 89 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 9c1a45c..01b8989 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -1547,7 +1547,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     offset,
                     Optional.empty(), // This will ensure we skip validation
                     this.metadata.leaderAndEpoch(partition));
-            this.subscriptions.seek(partition, newPosition);
+            this.subscriptions.seekUnvalidated(partition, newPosition);
         } finally {
             release();
         }
@@ -1583,7 +1583,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
                     offsetAndMetadata.leaderEpoch(),
                     currentLeaderAndEpoch);
             this.updateLastSeenEpochIfNewer(partition, offsetAndMetadata);
-            this.subscriptions.seekAndValidate(partition, newPosition);
+            this.subscriptions.seekUnvalidated(partition, newPosition);
         } finally {
             release();
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index d88c47b..e71711f 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -520,11 +520,11 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             final ConsumerMetadata.LeaderAndEpoch leaderAndEpoch = metadata.leaderAndEpoch(tp);
             final SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition(
                     offsetAndMetadata.offset(), offsetAndMetadata.leaderEpoch(),
-                    new ConsumerMetadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.empty()));
+                    leaderAndEpoch);
 
             log.info("Setting offset for partition {} to the committed offset {}", tp, position);
             entry.getValue().leaderEpoch().ifPresent(epoch -> this.metadata.updateLastSeenEpochIfNewer(entry.getKey(), epoch));
-            this.subscriptions.seekAndValidate(tp, position);
+            this.subscriptions.seekUnvalidated(tp, position);
         }
         return true;
     }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index d839791..e638963 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -19,7 +19,6 @@ package org.apache.kafka.clients.consumer.internals;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.FetchSessionHandler;
-import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MetadataCache;
 import org.apache.kafka.clients.NodeApiVersions;
 import org.apache.kafka.clients.StaleMetadataException;
@@ -466,7 +465,7 @@ public class Fetcher<K, V> implements Closeable {
         // Validate each partition against the current leader and epoch
         subscriptions.assignedPartitions().forEach(topicPartition -> {
             ConsumerMetadata.LeaderAndEpoch leaderAndEpoch = metadata.leaderAndEpoch(topicPartition);
-            subscriptions.maybeValidatePosition(topicPartition, leaderAndEpoch);
+            subscriptions.maybeValidatePositionForCurrentLeader(topicPartition, leaderAndEpoch);
         });
 
         // Collect positions needing validation, with backoff
@@ -677,7 +676,7 @@ public class Fetcher<K, V> implements Closeable {
         SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition(
                 offsetData.offset, offsetData.leaderEpoch, metadata.leaderAndEpoch(partition));
         offsetData.leaderEpoch.ifPresent(epoch -> metadata.updateLastSeenEpochIfNewer(partition, epoch));
-        subscriptions.maybeSeek(partition, position.offset, requestedResetStrategy);
+        subscriptions.maybeSeekUnvalidated(partition, position.offset, requestedResetStrategy);
     }
 
     private void resetOffsetsAsync(Map<TopicPartition, Long> partitionResetTimestamps) {
@@ -735,16 +734,12 @@ public class Fetcher<K, V> implements Closeable {
         final Map<Node, Map<TopicPartition, SubscriptionState.FetchPosition>> regrouped =
                 regroupFetchPositionsByLeader(partitionsToValidate);
 
-        regrouped.forEach((node, dataMap) -> {
+        regrouped.forEach((node, fetchPostitions) -> {
             if (node.isEmpty()) {
                 metadata.requestUpdate();
                 return;
             }
 
-            final Map<TopicPartition, Metadata.LeaderAndEpoch> cachedLeaderAndEpochs = partitionsToValidate.entrySet()
-                    .stream()
-                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().currentLeader));
-
             NodeApiVersions nodeApiVersions = apiVersions.get(node.idString());
             if (nodeApiVersions == null) {
                 client.tryConnect(node);
@@ -754,14 +749,14 @@ public class Fetcher<K, V> implements Closeable {
             if (!hasUsableOffsetForLeaderEpochVersion(nodeApiVersions)) {
                 log.debug("Skipping validation of fetch offsets for partitions {} since the broker does not " +
                                 "support the required protocol version (introduced in Kafka 2.3)",
-                        cachedLeaderAndEpochs.keySet());
-                for (TopicPartition partition : cachedLeaderAndEpochs.keySet()) {
+                        fetchPostitions.keySet());
+                for (TopicPartition partition : fetchPostitions.keySet()) {
                     subscriptions.completeValidation(partition);
                 }
                 return;
             }
 
-            subscriptions.setNextAllowedRetry(dataMap.keySet(), time.milliseconds() + requestTimeoutMs);
+            subscriptions.setNextAllowedRetry(fetchPostitions.keySet(), time.milliseconds() + requestTimeoutMs);
 
             RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future = offsetsForLeaderEpochClient.sendAsyncRequest(node, partitionsToValidate);
             future.addListener(new RequestFutureListener<OffsetsForLeaderEpochClient.OffsetForEpochResult>() {
@@ -777,34 +772,12 @@ public class Fetcher<K, V> implements Closeable {
                     // for the partition. If so, it means we have experienced log truncation and need to reposition
                     // that partition's offset.
                     offsetsResult.endOffsets().forEach((respTopicPartition, respEndOffset) -> {
-                        if (!subscriptions.isAssigned(respTopicPartition)) {
-                            log.debug("Ignoring OffsetsForLeader response for partition {} which is not currently assigned.", respTopicPartition);
-                            return;
-                        }
-
-                        if (subscriptions.awaitingValidation(respTopicPartition)) {
-                            SubscriptionState.FetchPosition currentPosition = subscriptions.position(respTopicPartition);
-                            Metadata.LeaderAndEpoch currentLeader = currentPosition.currentLeader;
-                            if (!currentLeader.equals(cachedLeaderAndEpochs.get(respTopicPartition))) {
-                                return;
-                            }
-
-                            if (respEndOffset.endOffset() < currentPosition.offset) {
-                                if (subscriptions.hasDefaultOffsetResetPolicy()) {
-                                    SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
-                                            respEndOffset.endOffset(), Optional.of(respEndOffset.leaderEpoch()), currentLeader);
-                                    log.info("Truncation detected for partition {}, resetting offset to {}", respTopicPartition, newPosition);
-                                    subscriptions.seek(respTopicPartition, newPosition);
-                                } else {
-                                    log.warn("Truncation detected for partition {}, but no reset policy is set", respTopicPartition);
-                                    truncationWithoutResetPolicy.put(respTopicPartition, new OffsetAndMetadata(
-                                            respEndOffset.endOffset(), Optional.of(respEndOffset.leaderEpoch()), null));
-                                }
-                            } else {
-                                // Offset is fine, clear the validation state
-                                subscriptions.completeValidation(respTopicPartition);
-                            }
-                        }
+                        SubscriptionState.FetchPosition requestPosition = fetchPostitions.get(respTopicPartition);
+                        Optional<OffsetAndMetadata> divergentOffsetOpt = subscriptions.maybeCompleteValidation(
+                                respTopicPartition, requestPosition, respEndOffset);
+                        divergentOffsetOpt.ifPresent(divergentOffset -> {
+                            truncationWithoutResetPolicy.put(respTopicPartition, divergentOffset);
+                        });
                     });
 
                     if (!truncationWithoutResetPolicy.isEmpty()) {
@@ -814,7 +787,7 @@ public class Fetcher<K, V> implements Closeable {
 
                 @Override
                 public void onFailure(RuntimeException e) {
-                    subscriptions.requestFailed(dataMap.keySet(), time.milliseconds() + retryBackoffMs);
+                    subscriptions.requestFailed(fetchPostitions.keySet(), time.milliseconds() + retryBackoffMs);
                     metadata.requestUpdate();
 
                     if (!(e instanceof RetriableException) && !cachedOffsetForLeaderException.compareAndSet(null, e)) {
@@ -1084,7 +1057,7 @@ public class Fetcher<K, V> implements Closeable {
 
         // Ensure the position has an up-to-date leader
         subscriptions.assignedPartitions().forEach(
-            tp -> subscriptions.maybeValidatePosition(tp, metadata.leaderAndEpoch(tp)));
+            tp -> subscriptions.maybeValidatePositionForCurrentLeader(tp, metadata.leaderAndEpoch(tp)));
 
         long currentTimeMs = time.milliseconds();
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index 8cba206..281578b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -24,6 +24,7 @@ import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.internals.PartitionStates;
+import org.apache.kafka.common.requests.EpochEndOffset;
 import org.apache.kafka.common.requests.IsolationLevel;
 import org.apache.kafka.common.utils.LogContext;
 import org.slf4j.Logger;
@@ -50,7 +51,7 @@ import java.util.stream.Collectors;
  * or with {@link #assignFromSubscribed(Collection)} (automatic assignment from subscription).
  *
  * Once assigned, the partition is not considered "fetchable" until its initial position has
- * been set with {@link #seek(TopicPartition, FetchPosition)}. Fetchable partitions track a fetch
+ * been set with {@link #seekValidated(TopicPartition, FetchPosition)}. Fetchable partitions track a fetch
  * position which is used to set the offset of the next fetch, and a consumed position
  * which is the last offset that has been returned to the user. You can suspend fetching
  * from a partition through {@link #pause(TopicPartition)} without affecting the fetched/consumed
@@ -289,7 +290,7 @@ public class SubscriptionState {
         return Collections.emptySet();
     }
 
-    public Set<TopicPartition> pausedPartitions() {
+    public synchronized Set<TopicPartition> pausedPartitions() {
         return collectPartitions(TopicPartitionState::isPaused, Collectors.toSet());
     }
 
@@ -323,19 +324,19 @@ public class SubscriptionState {
         return this.assignment.stateValue(tp);
     }
 
-    public synchronized void seek(TopicPartition tp, FetchPosition position) {
-        assignedState(tp).seek(position);
+    public synchronized void seekValidated(TopicPartition tp, FetchPosition position) {
+        assignedState(tp).seekValidated(position);
     }
 
-    public synchronized void seekAndValidate(TopicPartition tp, FetchPosition position) {
-        assignedState(tp).seekAndValidate(position);
+    public void seek(TopicPartition tp, long offset) {
+        seekValidated(tp, new FetchPosition(offset));
     }
 
-    public void seek(TopicPartition tp, long offset) {
-        seek(tp, new FetchPosition(offset));
+    public void seekUnvalidated(TopicPartition tp, FetchPosition position) {
+        assignedState(tp).seekUnvalidated(position);
     }
 
-    synchronized void maybeSeek(TopicPartition tp, long offset, OffsetResetStrategy requestedResetStrategy) {
+    synchronized void maybeSeekUnvalidated(TopicPartition tp, long offset, OffsetResetStrategy requestedResetStrategy) {
         TopicPartitionState state = assignedStateOrNull(tp);
         if (state == null) {
             log.debug("Skipping reset of partition {} since it is no longer assigned", tp);
@@ -345,7 +346,7 @@ public class SubscriptionState {
             log.debug("Skipping reset of partition {} since an alternative reset has been requested", tp);
         } else {
             log.info("Resetting offset for partition {} to offset {}.", tp, offset);
-            state.seek(new FetchPosition(offset));
+            state.seekUnvalidated(new FetchPosition(offset));
         }
     }
 
@@ -379,23 +380,64 @@ public class SubscriptionState {
         assignedState(tp).position(position);
     }
 
-    synchronized boolean maybeValidatePosition(TopicPartition tp, Metadata.LeaderAndEpoch leaderAndEpoch) {
+    public synchronized boolean maybeValidatePositionForCurrentLeader(TopicPartition tp, Metadata.LeaderAndEpoch leaderAndEpoch) {
         return assignedState(tp).maybeValidatePosition(leaderAndEpoch);
     }
 
-    synchronized boolean awaitingValidation(TopicPartition tp) {
+    /**
+     * Attempt to complete validation with the end offset returned from the OffsetForLeaderEpoch request.
+     * @return The diverging offset if truncation was detected and no reset policy is defined.
+     */
+    public synchronized Optional<OffsetAndMetadata> maybeCompleteValidation(TopicPartition tp,
+                                                                            FetchPosition requestPosition,
+                                                                            EpochEndOffset epochEndOffset) {
+        TopicPartitionState state = assignedStateOrNull(tp);
+        if (state == null) {
+            log.debug("Skipping completed validation for partition {} which is not currently assigned.", tp);
+        } else if (!state.awaitingValidation()) {
+            log.debug("Skipping completed validation for partition {} which is no longer expecting validation.", tp);
+        } else {
+            SubscriptionState.FetchPosition currentPosition = state.position;
+            if (!currentPosition.equals(requestPosition)) {
+                log.debug("Skipping completed validation for partition {} since the current position {} " +
+                                "no longer matches the position {} when the request was sent",
+                        tp, currentPosition, requestPosition);
+            } else if (epochEndOffset.endOffset() < currentPosition.offset) {
+                if (hasDefaultOffsetResetPolicy()) {
+                    SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
+                            epochEndOffset.endOffset(), Optional.of(epochEndOffset.leaderEpoch()),
+                            currentPosition.currentLeader);
+                    log.info("Truncation detected for partition {} at offset {}, resetting offset to " +
+                            "the first offset known to diverge {}", tp, currentPosition, newPosition);
+                    state.seekValidated(newPosition);
+                } else {
+                    log.warn("Truncation detected for partition {} at offset {} (the end offset from the " +
+                                    "broker is {}), but no reset policy is set",
+                            tp, currentPosition, epochEndOffset);
+                    return Optional.of(new OffsetAndMetadata(epochEndOffset.endOffset(),
+                            Optional.of(epochEndOffset.leaderEpoch()), null));
+                }
+            } else {
+                state.completeValidation();
+            }
+        }
+
+        return Optional.empty();
+    }
+
+    public synchronized boolean awaitingValidation(TopicPartition tp) {
         return assignedState(tp).awaitingValidation();
     }
 
     public synchronized void completeValidation(TopicPartition tp) {
-        assignedState(tp).validate();
+        assignedState(tp).completeValidation();
     }
 
     public synchronized FetchPosition validPosition(TopicPartition tp) {
         return assignedState(tp).validPosition();
     }
 
-    synchronized public FetchPosition position(TopicPartition tp) {
+    public synchronized FetchPosition position(TopicPartition tp) {
         return assignedState(tp).position;
     }
 
@@ -505,11 +547,11 @@ public class SubscriptionState {
         return assignment.stream().allMatch(state -> state.value().hasValidPosition());
     }
 
-    Set<TopicPartition> missingFetchPositions() {
+    public synchronized Set<TopicPartition> missingFetchPositions() {
         return collectPartitions(state -> !state.hasPosition(), Collectors.toSet());
     }
 
-    private synchronized <T extends Collection<TopicPartition>> T collectPartitions(Predicate<TopicPartitionState> filter, Collector<TopicPartition, ?, T> collector) {
+    private <T extends Collection<TopicPartition>> T collectPartitions(Predicate<TopicPartitionState> filter, Collector<TopicPartition, ?, T> collector) {
         return assignment.stream()
                 .filter(state -> filter.test(state.value()))
                 .map(PartitionStates.PartitionState::topicPartition)
@@ -534,12 +576,12 @@ public class SubscriptionState {
             throw new NoOffsetForPartitionException(partitionsWithNoOffsets);
     }
 
-    Set<TopicPartition> partitionsNeedingReset(long nowMs) {
+    public synchronized Set<TopicPartition> partitionsNeedingReset(long nowMs) {
         return collectPartitions(state -> state.awaitingReset() && !state.awaitingRetryBackoff(nowMs),
                 Collectors.toSet());
     }
 
-    Set<TopicPartition> partitionsNeedingValidation(long nowMs) {
+    public synchronized Set<TopicPartition> partitionsNeedingValidation(long nowMs) {
         return collectPartitions(state -> state.awaitingValidation() && !state.awaitingRetryBackoff(nowMs),
                 Collectors.toSet());
     }
@@ -669,7 +711,7 @@ public class SubscriptionState {
                 return false;
             }
 
-            if (position != null && !position.safeToFetchFrom(currentLeaderAndEpoch)) {
+            if (position != null && !position.currentLeader.equals(currentLeaderAndEpoch)) {
                 FetchPosition newPosition = new FetchPosition(position.offset, position.offsetEpoch, currentLeaderAndEpoch);
                 validatePosition(newPosition);
                 preferredReadReplica = null;
@@ -678,7 +720,7 @@ public class SubscriptionState {
         }
 
         private void validatePosition(FetchPosition position) {
-            if (position.offsetEpoch.isPresent()) {
+            if (position.offsetEpoch.isPresent() && position.currentLeader.epoch.isPresent()) {
                 transitionState(FetchStates.AWAIT_VALIDATION, () -> {
                     this.position = position;
                     this.nextRetryTimeMs = null;
@@ -695,7 +737,7 @@ public class SubscriptionState {
         /**
          * Clear the awaiting validation state and enter fetching.
          */
-        private void validate() {
+        private void completeValidation() {
             if (hasPosition()) {
                 transitionState(FetchStates.FETCHING, () -> {
                     this.nextRetryTimeMs = null;
@@ -735,7 +777,7 @@ public class SubscriptionState {
             return paused;
         }
 
-        private void seek(FetchPosition position) {
+        private void seekValidated(FetchPosition position) {
             transitionState(FetchStates.FETCHING, () -> {
                 this.position = position;
                 this.resetStrategy = null;
@@ -743,8 +785,8 @@ public class SubscriptionState {
             });
         }
 
-        private void seekAndValidate(FetchPosition fetchPosition) {
-            seek(fetchPosition);
+        private void seekUnvalidated(FetchPosition fetchPosition) {
+            seekValidated(fetchPosition);
             validatePosition(fetchPosition);
         }
 
@@ -908,14 +950,6 @@ public class SubscriptionState {
             this.currentLeader = Objects.requireNonNull(currentLeader);
         }
 
-        /**
-         * Test if it is "safe" to fetch from a given leader and epoch. This effectively is testing if
-         * {@link Metadata.LeaderAndEpoch} known to the subscription is equal to the one supplied by the caller.
-         */
-        boolean safeToFetchFrom(Metadata.LeaderAndEpoch leaderAndEpoch) {
-            return !currentLeader.leader.isEmpty() && currentLeader.equals(leaderAndEpoch);
-        }
-
         @Override
         public boolean equals(Object o) {
             if (this == o) return true;
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 1e64e3b..44c00c4 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -75,6 +75,7 @@ import org.apache.kafka.common.requests.ListOffsetRequest;
 import org.apache.kafka.common.requests.ListOffsetResponse;
 import org.apache.kafka.common.requests.MetadataRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochRequest;
 import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse;
 import org.apache.kafka.common.requests.ResponseHeader;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
@@ -1504,7 +1505,7 @@ public class FetcherTest {
                 Object result = invocation.callRealMethod();
                 latchEarliestDone.countDown();
                 return result;
-            }).when(subscriptions).maybeSeek(tp0, 0L, OffsetResetStrategy.EARLIEST);
+            }).when(subscriptions).maybeSeekUnvalidated(tp0, 0L, OffsetResetStrategy.EARLIEST);
 
             es.submit(() -> {
                 subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST);
@@ -2795,8 +2796,8 @@ public class FetcherTest {
 
         List<ConsumerRecord<byte[], byte[]>> records;
         assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1)));
-        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.empty(), metadata.leaderAndEpoch(tp0)));
-        subscriptions.seek(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.leaderAndEpoch(tp1)));
+        subscriptions.seekValidated(tp0, new SubscriptionState.FetchPosition(0, Optional.empty(), metadata.leaderAndEpoch(tp0)));
+        subscriptions.seekValidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.leaderAndEpoch(tp1)));
 
         // Fetch some records and establish an incremental fetch session.
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions1 = new LinkedHashMap<>();
@@ -3276,7 +3277,7 @@ public class FetcherTest {
         // Seek with a position and leader+epoch
         Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(
                 metadata.leaderAndEpoch(tp0).leader, Optional.of(epochOne));
-        subscriptions.seekAndValidate(tp0, new SubscriptionState.FetchPosition(20L, Optional.of(epochOne), leaderAndEpoch));
+        subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(20L, Optional.of(epochOne), leaderAndEpoch));
         assertFalse(client.isConnected(node.idString()));
         assertTrue(subscriptions.awaitingValidation(tp0));
 
@@ -3325,7 +3326,7 @@ public class FetcherTest {
         // Seek with a position and leader+epoch
         Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(
                 metadata.leaderAndEpoch(tp0).leader, Optional.of(epochOne));
-        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch));
+        subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch));
 
         // Update metadata to epoch=2, enter validation
         metadata.update(TestUtils.metadataUpdateWith("dummy", 1,
@@ -3337,6 +3338,45 @@ public class FetcherTest {
     }
 
     @Test
+    public void testOffsetValidationHandlesSeekWithInflightOffsetForLeaderRequest() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
+
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        partitionCounts.put(tp0.topic(), 4);
+
+        final int epochOne = 1;
+
+        metadata.update(TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> epochOne), 0L);
+
+        // Offset validation requires OffsetForLeaderEpoch request v3 or higher
+        Node node = metadata.fetch().nodes().get(0);
+        apiVersions.update(node.idString(), NodeApiVersions.create());
+
+        Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.leaderAndEpoch(tp0).leader, Optional.of(epochOne));
+        subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch));
+
+        fetcher.validateOffsetsIfNeeded();
+        consumerClient.poll(time.timer(Duration.ZERO));
+        assertTrue(subscriptions.awaitingValidation(tp0));
+        assertTrue(client.hasInFlightRequests());
+
+        // While the OffsetForLeaderEpoch request is in-flight, we seek to a different offset.
+        subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(5, Optional.of(epochOne), leaderAndEpoch));
+        assertTrue(subscriptions.awaitingValidation(tp0));
+
+        client.respond(request -> {
+            OffsetsForLeaderEpochRequest epochRequest = (OffsetsForLeaderEpochRequest) request;
+            OffsetsForLeaderEpochRequest.PartitionData partitionData = epochRequest.epochsByTopicPartition().get(tp0);
+            return partitionData.currentLeaderEpoch.equals(Optional.of(epochOne)) && partitionData.leaderEpoch == epochOne;
+        }, new OffsetsForLeaderEpochResponse(singletonMap(tp0, new EpochEndOffset(0, 0L))));
+        consumerClient.poll(time.timer(Duration.ZERO));
+
+        // The response should be ignored since we were validating a different position.
+        assertTrue(subscriptions.awaitingValidation(tp0));
+    }
+
+    @Test
     public void testOffsetValidationFencing() {
         buildFetcher();
         assignFromUser(singleton(tp0));
@@ -3357,7 +3397,7 @@ public class FetcherTest {
 
         // Seek with a position and leader+epoch
         Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.leaderAndEpoch(tp0).leader, Optional.of(epochOne));
-        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch));
+        subscriptions.seekValidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch));
 
         // Update metadata to epoch=2, enter validation
         metadata.update(TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> epochTwo), 0L);
@@ -3371,7 +3411,7 @@ public class FetcherTest {
                 Optional.of(epochTwo),
                 new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochTwo)));
         subscriptions.position(tp0, nextPosition);
-        subscriptions.maybeValidatePosition(tp0, new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochThree)));
+        subscriptions.maybeValidatePositionForCurrentLeader(tp0, new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochThree)));
 
         // Prepare offset list response from async validation with epoch=2
         Map<TopicPartition, EpochEndOffset> endOffsetMap = new HashMap<>();
@@ -3427,7 +3467,7 @@ public class FetcherTest {
 
         // Seek
         Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.leaderAndEpoch(tp0).leader, Optional.of(1));
-        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), leaderAndEpoch));
+        subscriptions.seekValidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), leaderAndEpoch));
 
         // Check for truncation, this should cause tp0 to go into validation
         fetcher.validateOffsetsIfNeeded();
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
index 528c2b9..484b9de 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
@@ -18,9 +18,11 @@ package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.requests.EpochEndOffset;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.test.TestUtils;
@@ -42,9 +44,7 @@ import static org.junit.Assert.assertTrue;
 
 public class SubscriptionStateTest {
 
-    private final SubscriptionState state = new SubscriptionState(
-            new LogContext(),
-            OffsetResetStrategy.EARLIEST);
+    private SubscriptionState state = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
     private final String topic = "test";
     private final String topic1 = "test1";
     private final TopicPartition tp0 = new TopicPartition(topic, 0);
@@ -340,13 +340,245 @@ public class SubscriptionStateTest {
         assertFalse(state.preferredReadReplica(tp0, 31L).isPresent());
     }
 
+    @Test
+    public void testSeekUnvalidatedWithNoOffsetEpoch() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        // Seek with no offset epoch requires no validation no matter what the current leader is
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.empty(),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(5))));
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+
+        assertFalse(state.maybeValidatePositionForCurrentLeader(tp0, new Metadata.LeaderAndEpoch(broker1, Optional.empty())));
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+
+        assertFalse(state.maybeValidatePositionForCurrentLeader(tp0, new Metadata.LeaderAndEpoch(broker1, Optional.of(10))));
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+    }
+
+    @Test
+    public void testSeekUnvalidatedWithNoEpochClearsAwaitingValidation() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        // Seek with an offset epoch requires validation; a later seek with no offset epoch should clear it
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.of(2),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(5))));
+        assertFalse(state.hasValidPosition(tp0));
+        assertTrue(state.awaitingValidation(tp0));
+
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.empty(),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(5))));
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+    }
+
+    @Test
+    public void testSeekUnvalidatedWithOffsetEpoch() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.of(2),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(5))));
+        assertFalse(state.hasValidPosition(tp0));
+        assertTrue(state.awaitingValidation(tp0));
+
+        // Update using the current leader and epoch
+        assertTrue(state.maybeValidatePositionForCurrentLeader(tp0, new Metadata.LeaderAndEpoch(broker1, Optional.of(5))));
+        assertFalse(state.hasValidPosition(tp0));
+        assertTrue(state.awaitingValidation(tp0));
+
+        // Update with a newer leader and epoch
+        assertTrue(state.maybeValidatePositionForCurrentLeader(tp0, new Metadata.LeaderAndEpoch(broker1, Optional.of(15))));
+        assertFalse(state.hasValidPosition(tp0));
+        assertTrue(state.awaitingValidation(tp0));
+
+        // If the updated leader has no epoch information, then skip validation and begin fetching
+        assertFalse(state.maybeValidatePositionForCurrentLeader(tp0, new Metadata.LeaderAndEpoch(broker1, Optional.empty())));
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+    }
+
+    @Test
+    public void testSeekValidatedShouldClearAwaitingValidation() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(10))));
+        assertFalse(state.hasValidPosition(tp0));
+        assertTrue(state.awaitingValidation(tp0));
+        assertEquals(10L, state.position(tp0).offset);
+
+        state.seekValidated(tp0, new SubscriptionState.FetchPosition(8L, Optional.of(4),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(10))));
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+        assertEquals(8L, state.position(tp0).offset);
+    }
+
+    @Test
+    public void testCompleteValidationShouldClearAwaitingValidation() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(10))));
+        assertFalse(state.hasValidPosition(tp0));
+        assertTrue(state.awaitingValidation(tp0));
+        assertEquals(10L, state.position(tp0).offset);
+
+        state.completeValidation(tp0);
+        assertTrue(state.hasValidPosition(tp0));
+        assertFalse(state.awaitingValidation(tp0));
+        assertEquals(10L, state.position(tp0).offset);
+    }
+
+    @Test
+    public void testOffsetResetWhileAwaitingValidation() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5),
+                new Metadata.LeaderAndEpoch(broker1, Optional.of(10))));
+        assertTrue(state.awaitingValidation(tp0));
+
+        state.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST);
+        assertFalse(state.awaitingValidation(tp0));
+        assertTrue(state.isOffsetResetNeeded(tp0));
+    }
+
+    @Test
+    public void testMaybeCompleteValidation() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        int currentEpoch = 10;
+        long initialOffset = 10L;
+        int initialOffsetEpoch = 5;
+
+        SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset,
+                Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        state.seekUnvalidated(tp0, initialPosition);
+        assertTrue(state.awaitingValidation(tp0));
+
+        Optional<OffsetAndMetadata> divergentOffsetMetadataOpt = state.maybeCompleteValidation(tp0, initialPosition,
+                new EpochEndOffset(initialOffsetEpoch, initialOffset + 5));
+        assertEquals(Optional.empty(), divergentOffsetMetadataOpt);
+        assertFalse(state.awaitingValidation(tp0));
+        assertEquals(initialPosition, state.position(tp0));
+    }
+
+    @Test
+    public void testMaybeCompleteValidationAfterPositionChange() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        int currentEpoch = 10;
+        long initialOffset = 10L;
+        int initialOffsetEpoch = 5;
+        long updateOffset = 20L;
+        int updateOffsetEpoch = 8;
+
+        SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset,
+                Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        state.seekUnvalidated(tp0, initialPosition);
+        assertTrue(state.awaitingValidation(tp0));
+
+        SubscriptionState.FetchPosition updatePosition = new SubscriptionState.FetchPosition(updateOffset,
+                Optional.of(updateOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        state.seekUnvalidated(tp0, updatePosition);
+
+        Optional<OffsetAndMetadata> divergentOffsetMetadataOpt = state.maybeCompleteValidation(tp0, initialPosition,
+                new EpochEndOffset(initialOffsetEpoch, initialOffset + 5));
+        assertEquals(Optional.empty(), divergentOffsetMetadataOpt);
+        assertTrue(state.awaitingValidation(tp0));
+        assertEquals(updatePosition, state.position(tp0));
+    }
+
+    @Test
+    public void testMaybeCompleteValidationAfterOffsetReset() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        int currentEpoch = 10;
+        long initialOffset = 10L;
+        int initialOffsetEpoch = 5;
+
+        SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset,
+                Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        state.seekUnvalidated(tp0, initialPosition);
+        assertTrue(state.awaitingValidation(tp0));
+
+        state.requestOffsetReset(tp0);
+
+        Optional<OffsetAndMetadata> divergentOffsetMetadataOpt = state.maybeCompleteValidation(tp0, initialPosition,
+                new EpochEndOffset(initialOffsetEpoch, initialOffset + 5));
+        assertEquals(Optional.empty(), divergentOffsetMetadataOpt);
+        assertFalse(state.awaitingValidation(tp0));
+        assertTrue(state.isOffsetResetNeeded(tp0));
+    }
+
+    @Test
+    public void testTruncationDetectionWithResetPolicy() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        int currentEpoch = 10;
+        long initialOffset = 10L;
+        int initialOffsetEpoch = 5;
+        long divergentOffset = 5L;
+        int divergentOffsetEpoch = 7;
+
+        SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset,
+                Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        state.seekUnvalidated(tp0, initialPosition);
+        assertTrue(state.awaitingValidation(tp0));
+
+        Optional<OffsetAndMetadata> divergentOffsetMetadata = state.maybeCompleteValidation(tp0, initialPosition,
+                new EpochEndOffset(divergentOffsetEpoch, divergentOffset));
+        assertEquals(Optional.empty(), divergentOffsetMetadata);
+        assertFalse(state.awaitingValidation(tp0));
+
+        SubscriptionState.FetchPosition updatedPosition = new SubscriptionState.FetchPosition(divergentOffset,
+                Optional.of(divergentOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        assertEquals(updatedPosition, state.position(tp0));
+    }
+
+    @Test
+    public void testTruncationDetectionWithoutResetPolicy() {
+        Node broker1 = new Node(1, "localhost", 9092);
+        state = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);
+        state.assignFromUser(Collections.singleton(tp0));
+
+        int currentEpoch = 10;
+        long initialOffset = 10L;
+        int initialOffsetEpoch = 5;
+        long divergentOffset = 5L;
+        int divergentOffsetEpoch = 7;
+
+        SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset,
+                Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(broker1, Optional.of(currentEpoch)));
+        state.seekUnvalidated(tp0, initialPosition);
+        assertTrue(state.awaitingValidation(tp0));
+
+        Optional<OffsetAndMetadata> divergentOffsetMetadata = state.maybeCompleteValidation(tp0, initialPosition,
+                new EpochEndOffset(divergentOffsetEpoch, divergentOffset));
+        assertEquals(Optional.of(new OffsetAndMetadata(divergentOffset, Optional.of(divergentOffsetEpoch), "")),
+                divergentOffsetMetadata);
+        assertTrue(state.awaitingValidation(tp0));
+    }
+
     private static class MockRebalanceListener implements ConsumerRebalanceListener {
         public Collection<TopicPartition> revoked;
         public Collection<TopicPartition> assigned;
         public int revokedCount = 0;
         public int assignedCount = 0;
 
-
         @Override
         public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
             this.assigned = partitions;