You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/04/21 23:24:42 UTC

[kafka] branch trunk updated: KAFKA-7747; Check for truncation after leader changes [KIP-320] (#6371)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 409fabc  KAFKA-7747; Check for truncation after leader changes [KIP-320] (#6371)
409fabc is described below

commit 409fabc5610443f36574bdea2e2994b6c20e2829
Author: David Arthur <mu...@gmail.com>
AuthorDate: Sun Apr 21 19:24:18 2019 -0400

    KAFKA-7747; Check for truncation after leader changes [KIP-320] (#6371)
    
    After the client detects a leader change we need to check the offset of the current leader for truncation. These changes were part of KIP-320: https://cwiki.apache.org/confluence/display/KAFKA/KIP-320%3A+Allow+fetchers+to+detect+and+handle+log+truncation.
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
---
 .../java/org/apache/kafka/clients/Metadata.java    |  58 +++-
 .../kafka/clients/consumer/KafkaConsumer.java      |  33 +-
 .../clients/consumer/LogTruncationException.java   |  50 +++
 .../kafka/clients/consumer/MockConsumer.java       |  22 +-
 .../clients/consumer/internals/AsyncClient.java    |  75 +++++
 .../consumer/internals/ConsumerCoordinator.java    |  12 +-
 .../kafka/clients/consumer/internals/Fetcher.java  | 254 +++++++++++----
 .../internals/OffsetsForLeaderEpochClient.java     | 128 ++++++++
 .../consumer/internals/SubscriptionState.java      | 348 ++++++++++++++++++---
 .../java/org/apache/kafka/common/utils/Utils.java  |  12 +
 .../org/apache/kafka/clients/MetadataTest.java     |  32 +-
 .../kafka/clients/consumer/KafkaConsumerTest.java  |  20 ++
 .../internals/ConsumerCoordinatorTest.java         |  40 ++-
 .../clients/consumer/internals/FetcherTest.java    | 291 +++++++++++++----
 .../internals/OffsetForLeaderEpochClientTest.java  | 166 ++++++++++
 .../consumer/internals/SubscriptionStateTest.java  |  17 +-
 .../test/java/org/apache/kafka/test/TestUtils.java |  27 +-
 .../SmokeTestDriverIntegrationTest.java            |   3 +
 tests/kafkatest/services/console_consumer.py       |   7 +-
 tests/kafkatest/services/kafka/kafka.py            |  12 +
 tests/kafkatest/services/verifiable_consumer.py    |  46 ++-
 tests/kafkatest/tests/client/truncation_test.py    | 150 +++++++++
 tests/kafkatest/tests/verifiable_consumer_test.py  |  11 +-
 23 files changed, 1574 insertions(+), 240 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/Metadata.java b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
index dfd461a..ef01b4b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/Metadata.java
+++ b/clients/src/main/java/org/apache/kafka/clients/Metadata.java
@@ -18,6 +18,7 @@ package org.apache.kafka.clients;
 
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.AuthenticationException;
@@ -314,6 +315,8 @@ public class Metadata implements Closeable {
                 if (metadata.isInternal())
                     internalTopics.add(metadata.topic());
                 for (MetadataResponse.PartitionMetadata partitionMetadata : metadata.partitionMetadata()) {
+
+                    // Even if the partition's metadata includes an error, we need to handle the update to catch new epochs
                     updatePartitionInfo(metadata.topic(), partitionMetadata, partitionInfo -> {
                         int epoch = partitionMetadata.leaderEpoch().orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH);
                         partitions.add(new MetadataCache.PartitionInfoAndEpoch(partitionInfo, epoch));
@@ -358,8 +361,8 @@ public class Metadata implements Closeable {
                 }
             }
         } else {
-            // Old cluster format (no epochs)
-            lastSeenLeaderEpochs.clear();
+            // Handle old cluster formats as well as error responses where leader and epoch are missing
+            lastSeenLeaderEpochs.remove(tp);
             partitionInfoConsumer.accept(MetadataResponse.partitionMetaToInfo(topic, partitionMetadata));
         }
     }
@@ -444,4 +447,55 @@ public class Metadata implements Closeable {
         }
     }
 
+    public synchronized LeaderAndEpoch leaderAndEpoch(TopicPartition tp) {
+        return partitionInfoIfCurrent(tp)
+                .map(infoAndEpoch -> {
+                    Node leader = infoAndEpoch.partitionInfo().leader();
+                    return new LeaderAndEpoch(leader == null ? Node.noNode() : leader, Optional.of(infoAndEpoch.epoch()));
+                })
+                .orElse(new LeaderAndEpoch(Node.noNode(), lastSeenLeaderEpoch(tp)));
+    }
+
+    public static class LeaderAndEpoch {
+
+        public static final LeaderAndEpoch NO_LEADER_OR_EPOCH = new LeaderAndEpoch(Node.noNode(), Optional.empty());
+
+        public final Node leader;
+        public final Optional<Integer> epoch;
+
+        public LeaderAndEpoch(Node leader, Optional<Integer> epoch) {
+            this.leader = Objects.requireNonNull(leader);
+            this.epoch = Objects.requireNonNull(epoch);
+        }
+
+        public static LeaderAndEpoch noLeaderOrEpoch() {
+            return NO_LEADER_OR_EPOCH;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            LeaderAndEpoch that = (LeaderAndEpoch) o;
+
+            if (!leader.equals(that.leader)) return false;
+            return epoch.equals(that.epoch);
+        }
+
+        @Override
+        public int hashCode() {
+            int result = leader.hashCode();
+            result = 31 * result + epoch.hashCode();
+            return result;
+        }
+
+        @Override
+        public String toString() {
+            return "LeaderAndEpoch{" +
+                    "leader=" + leader +
+                    ", epoch=" + epoch.map(Number::toString).orElse("absent") +
+                    '}';
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 4cee56a..839057c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -19,6 +19,7 @@ package org.apache.kafka.clients.consumer;
 import org.apache.kafka.clients.ApiVersions;
 import org.apache.kafka.clients.ClientDnsLookup;
 import org.apache.kafka.clients.ClientUtils;
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.consumer.internals.ConsumerCoordinator;
 import org.apache.kafka.clients.consumer.internals.ConsumerInterceptors;
@@ -69,6 +70,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Properties;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
@@ -1508,7 +1510,20 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      */
     @Override
     public void seek(TopicPartition partition, long offset) {
-        seek(partition, new OffsetAndMetadata(offset, null));
+        if (offset < 0)
+            throw new IllegalArgumentException("seek offset must not be a negative number");
+
+        acquireAndEnsureOpen();
+        try {
+            log.info("Seeking to offset {} for partition {}", offset, partition);
+            SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
+                    offset,
+                    Optional.empty(), // This will ensure we skip validation
+                    this.metadata.leaderAndEpoch(partition));
+            this.subscriptions.seek(partition, newPosition);
+        } finally {
+            release();
+        }
     }
 
     /**
@@ -1535,8 +1550,13 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
             } else {
                 log.info("Seeking to offset {} for partition {}", offset, partition);
             }
+            Metadata.LeaderAndEpoch currentLeaderAndEpoch = this.metadata.leaderAndEpoch(partition);
+            SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
+                    offsetAndMetadata.offset(),
+                    offsetAndMetadata.leaderEpoch(),
+                    currentLeaderAndEpoch);
             this.updateLastSeenEpochIfNewer(partition, offsetAndMetadata);
-            this.subscriptions.seek(partition, offset);
+            this.subscriptions.seekAndValidate(partition, newPosition);
         } finally {
             release();
         }
@@ -1658,9 +1678,9 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
 
             Timer timer = time.timer(timeout);
             do {
-                Long offset = this.subscriptions.position(partition);
-                if (offset != null)
-                    return offset;
+                SubscriptionState.FetchPosition position = this.subscriptions.validPosition(partition);
+                if (position != null)
+                    return position.offset;
 
                 updateFetchPositions(timer);
                 client.poll(timer);
@@ -2196,6 +2216,9 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
      * @return true iff the operation completed without timing out
      */
     private boolean updateFetchPositions(final Timer timer) {
+        // If any partitions have been truncated due to a leader change, we need to validate the offsets
+        fetcher.validateOffsetsIfNeeded();
+
         cachedSubscriptionHashAllFetchPositions = subscriptions.hasAllFetchPositions();
         if (cachedSubscriptionHashAllFetchPositions) return true;
 
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/LogTruncationException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/LogTruncationException.java
new file mode 100644
index 0000000..f8af50d
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/LogTruncationException.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.utils.Utils;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * In the event of unclean leader election, the log will be truncated,
+ * previously committed data will be lost, and new data will be written
+ * over these offsets. When this happens, the consumer will detect the
+ * truncation and raise this exception (if no automatic reset policy
+ * has been defined) with the first offset to diverge from what the
+ * consumer read.
+ */
+public class LogTruncationException extends OffsetOutOfRangeException {
+
+    private final Map<TopicPartition, OffsetAndMetadata> divergentOffsets;
+
+    public LogTruncationException(Map<TopicPartition, OffsetAndMetadata> divergentOffsets) {
+        super(Utils.transformMap(divergentOffsets, Function.identity(), OffsetAndMetadata::offset));
+        this.divergentOffsets = Collections.unmodifiableMap(divergentOffsets);
+    }
+
+    /**
+     * Get the offsets for the partitions which were truncated. This is the first offset which is known to diverge
+     * from what the consumer read.
+     */
+    public Map<TopicPartition, OffsetAndMetadata> divergentOffsets() {
+        return divergentOffsets;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
index 614ec9b..c8c2e72 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
@@ -16,11 +16,13 @@
  */
 package org.apache.kafka.clients.consumer;
 
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.internals.SubscriptionState;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.Metric;
 import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.WakeupException;
@@ -189,13 +191,17 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
             if (!subscriptions.isPaused(entry.getKey())) {
                 final List<ConsumerRecord<K, V>> recs = entry.getValue();
                 for (final ConsumerRecord<K, V> rec : recs) {
-                    if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > subscriptions.position(entry.getKey())) {
-                        throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), subscriptions.position(entry.getKey())));
+                    long position = subscriptions.position(entry.getKey()).offset;
+
+                    if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) {
+                        throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position));
                     }
 
-                    if (assignment().contains(entry.getKey()) && rec.offset() >= subscriptions.position(entry.getKey())) {
+                    if (assignment().contains(entry.getKey()) && rec.offset() >= position) {
                         results.computeIfAbsent(entry.getKey(), partition -> new ArrayList<>()).add(rec);
-                        subscriptions.position(entry.getKey(), rec.offset() + 1);
+                        SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
+                                rec.offset() + 1, rec.leaderEpoch(), new Metadata.LeaderAndEpoch(Node.noNode(), rec.leaderEpoch()));
+                        subscriptions.position(entry.getKey(), newPosition);
                     }
                 }
             }
@@ -290,12 +296,12 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
         ensureNotClosed();
         if (!this.subscriptions.isAssigned(partition))
             throw new IllegalArgumentException("You can only check the position for partitions assigned to this consumer.");
-        Long offset = this.subscriptions.position(partition);
-        if (offset == null) {
+        SubscriptionState.FetchPosition position = this.subscriptions.position(partition);
+        if (position == null) {
             updateFetchPosition(partition);
-            offset = this.subscriptions.position(partition);
+            position = this.subscriptions.position(partition);
         }
-        return offset;
+        return position.offset;
     }
 
     @Override
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncClient.java
new file mode 100644
index 0000000..8b35499
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncClient.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.ClientResponse;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.AbstractResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.slf4j.Logger;
+
+public abstract class AsyncClient<T1, Req extends AbstractRequest, Resp extends AbstractResponse, T2> {
+
+    private final Logger log;
+    private final ConsumerNetworkClient client;
+
+    AsyncClient(ConsumerNetworkClient client, LogContext logContext) {
+        this.client = client;
+        this.log = logContext.logger(getClass());
+    }
+
+    public RequestFuture<T2> sendAsyncRequest(Node node, T1 requestData) {
+        AbstractRequest.Builder<Req> requestBuilder = prepareRequest(node, requestData);
+
+        return client.send(node, requestBuilder).compose(new RequestFutureAdapter<ClientResponse, T2>() {
+            @Override
+            @SuppressWarnings("unchecked")
+            public void onSuccess(ClientResponse value, RequestFuture<T2> future) {
+                Resp resp;
+                try {
+                    resp = (Resp) value.responseBody();
+                } catch (ClassCastException cce) {
+                    log.error("Could not cast response body", cce);
+                    future.raise(cce);
+                    return;
+                }
+                log.trace("Received {} {} from broker {}", resp.getClass().getSimpleName(), resp, node);
+                try {
+                    future.complete(handleResponse(node, requestData, resp));
+                } catch (RuntimeException e) {
+                    if (!future.isDone()) {
+                        future.raise(e);
+                    }
+                }
+            }
+
+            @Override
+            public void onFailure(RuntimeException e, RequestFuture<T2> future1) {
+                future1.raise(e);
+            }
+        });
+    }
+
+    protected Logger logger() {
+        return log;
+    }
+
+    protected abstract AbstractRequest.Builder<Req> prepareRequest(Node node, T1 requestData);
+
+    protected abstract T2 handleResponse(Node node, T1 requestData, Resp response);
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index b31bf44..64a97ab 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -60,6 +60,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -507,10 +508,15 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         for (final Map.Entry<TopicPartition, OffsetAndMetadata> entry : offsets.entrySet()) {
             final TopicPartition tp = entry.getKey();
-            final long offset = entry.getValue().offset();
-            log.info("Setting offset for partition {} to the committed offset {}", tp, offset);
+            final OffsetAndMetadata offsetAndMetadata = entry.getValue();
+            final ConsumerMetadata.LeaderAndEpoch leaderAndEpoch = metadata.leaderAndEpoch(tp);
+            final SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition(
+                    offsetAndMetadata.offset(), offsetAndMetadata.leaderEpoch(),
+                    new ConsumerMetadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.empty()));
+
+            log.info("Setting offset for partition {} to the committed offset {}", tp, position);
             entry.getValue().leaderEpoch().ifPresent(epoch -> this.metadata.updateLastSeenEpochIfNewer(entry.getKey(), epoch));
-            this.subscriptions.seek(tp, offset);
+            this.subscriptions.seekAndValidate(tp, position);
         }
         return true;
     }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 93d476f..870a8b7 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -18,10 +18,13 @@ package org.apache.kafka.clients.consumer.internals;
 
 import org.apache.kafka.clients.ClientResponse;
 import org.apache.kafka.clients.FetchSessionHandler;
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MetadataCache;
 import org.apache.kafka.clients.StaleMetadataException;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.LogTruncationException;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
 import org.apache.kafka.clients.consumer.OffsetOutOfRangeException;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
@@ -133,6 +136,8 @@ public class Fetcher<K, V> implements Closeable {
     private final IsolationLevel isolationLevel;
     private final Map<Integer, FetchSessionHandler> sessionHandlers;
     private final AtomicReference<RuntimeException> cachedListOffsetsException = new AtomicReference<>();
+    private final AtomicReference<RuntimeException> cachedOffsetForLeaderException = new AtomicReference<>();
+    private final OffsetsForLeaderEpochClient offsetsForLeaderEpochClient;
 
     private PartitionRecords nextInLineRecords = null;
 
@@ -174,17 +179,18 @@ public class Fetcher<K, V> implements Closeable {
         this.requestTimeoutMs = requestTimeoutMs;
         this.isolationLevel = isolationLevel;
         this.sessionHandlers = new HashMap<>();
+        this.offsetsForLeaderEpochClient = new OffsetsForLeaderEpochClient(client, logContext);
     }
 
     /**
      * Represents data about an offset returned by a broker.
      */
-    private static class OffsetData {
+    private static class ListOffsetData {
         final long offset;
         final Long timestamp; //  null if the broker does not support returning timestamps
         final Optional<Integer> leaderEpoch; // empty if the leader epoch is not known
 
-        OffsetData(long offset, Long timestamp, Optional<Integer> leaderEpoch) {
+        ListOffsetData(long offset, Long timestamp, Optional<Integer> leaderEpoch) {
             this.offset = offset;
             this.timestamp = timestamp;
             this.leaderEpoch = leaderEpoch;
@@ -411,22 +417,45 @@ public class Fetcher<K, V> implements Closeable {
         resetOffsetsAsync(offsetResetTimestamps);
     }
 
+    /**
+     *  Validate offsets for all assigned partitions for which a leader change has been detected.
+     */
+    public void validateOffsetsIfNeeded() {
+        RuntimeException exception = cachedOffsetForLeaderException.getAndSet(null);
+        if (exception != null)
+            throw exception;
+
+        // Validate each partition against the current leader and epoch
+        subscriptions.assignedPartitions().forEach(topicPartition -> {
+            ConsumerMetadata.LeaderAndEpoch leaderAndEpoch = metadata.leaderAndEpoch(topicPartition);
+            subscriptions.maybeValidatePosition(topicPartition, leaderAndEpoch);
+        });
+
+        // Collect positions needing validation, with backoff
+        Map<TopicPartition, SubscriptionState.FetchPosition> partitionsToValidate = subscriptions
+                .partitionsNeedingValidation(time.milliseconds())
+                .stream()
+                .collect(Collectors.toMap(Function.identity(), subscriptions::position));
+
+        validateOffsetsAsync(partitionsToValidate);
+    }
+
     public Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(Map<TopicPartition, Long> timestampsToSearch,
                                                                    Timer timer) {
         metadata.addTransientTopics(topicsForPartitions(timestampsToSearch.keySet()));
 
         try {
-            Map<TopicPartition, OffsetData> fetchedOffsets = fetchOffsetsByTimes(timestampsToSearch,
+            Map<TopicPartition, ListOffsetData> fetchedOffsets = fetchOffsetsByTimes(timestampsToSearch,
                     timer, true).fetchedOffsets;
 
             HashMap<TopicPartition, OffsetAndTimestamp> offsetsByTimes = new HashMap<>(timestampsToSearch.size());
             for (Map.Entry<TopicPartition, Long> entry : timestampsToSearch.entrySet())
                 offsetsByTimes.put(entry.getKey(), null);
 
-            for (Map.Entry<TopicPartition, OffsetData> entry : fetchedOffsets.entrySet()) {
+            for (Map.Entry<TopicPartition, ListOffsetData> entry : fetchedOffsets.entrySet()) {
                 // 'entry.getValue().timestamp' will not be null since we are guaranteed
                 // to work with a v1 (or later) ListOffset request
-                OffsetData offsetData = entry.getValue();
+                ListOffsetData offsetData = entry.getValue();
                 offsetsByTimes.put(entry.getKey(), new OffsetAndTimestamp(offsetData.offset, offsetData.timestamp,
                         offsetData.leaderEpoch));
             }
@@ -570,14 +599,19 @@ public class Fetcher<K, V> implements Closeable {
             log.debug("Not returning fetched records for assigned partition {} since it is no longer fetchable",
                     partitionRecords.partition);
         } else {
-            long position = subscriptions.position(partitionRecords.partition);
-            if (partitionRecords.nextFetchOffset == position) {
+            SubscriptionState.FetchPosition position = subscriptions.position(partitionRecords.partition);
+            if (partitionRecords.nextFetchOffset == position.offset) {
                 List<ConsumerRecord<K, V>> partRecords = partitionRecords.fetchRecords(maxRecords);
 
-                long nextOffset = partitionRecords.nextFetchOffset;
-                log.trace("Returning fetched records at offset {} for assigned partition {} and update " +
-                        "position to {}", position, partitionRecords.partition, nextOffset);
-                subscriptions.position(partitionRecords.partition, nextOffset);
+                if (partitionRecords.nextFetchOffset > position.offset) {
+                    SubscriptionState.FetchPosition nextPosition = new SubscriptionState.FetchPosition(
+                            partitionRecords.nextFetchOffset,
+                            partitionRecords.lastEpoch,
+                            position.currentLeader);
+                    log.trace("Returning fetched records at offset {} for assigned partition {} and update " +
+                            "position to {}", position, partitionRecords.partition, nextPosition);
+                    subscriptions.position(partitionRecords.partition, nextPosition);
+                }
 
                 Long partitionLag = subscriptions.partitionLag(partitionRecords.partition, isolationLevel);
                 if (partitionLag != null)
@@ -601,7 +635,7 @@ public class Fetcher<K, V> implements Closeable {
         return emptyList();
     }
 
-    private void resetOffsetIfNeeded(TopicPartition partition, Long requestedResetTimestamp, OffsetData offsetData) {
+    private void resetOffsetIfNeeded(TopicPartition partition, Long requestedResetTimestamp, ListOffsetData offsetData) {
         // we might lose the assignment while fetching the offset, or the user might seek to a different offset,
         // so verify it is still assigned and still in need of the requested reset
         if (!subscriptions.isAssigned(partition)) {
@@ -611,9 +645,11 @@ public class Fetcher<K, V> implements Closeable {
         } else if (!requestedResetTimestamp.equals(offsetResetStrategyTimestamp(partition))) {
             log.debug("Skipping reset of partition {} since an alternative reset has been requested", partition);
         } else {
-            log.info("Resetting offset for partition {} to offset {}.", partition, offsetData.offset);
+            SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition(
+                    offsetData.offset, offsetData.leaderEpoch, metadata.leaderAndEpoch(partition));
+            log.info("Resetting offset for partition {} to offset {}.", partition, position);
             offsetData.leaderEpoch.ifPresent(epoch -> metadata.updateLastSeenEpochIfNewer(partition, epoch));
-            subscriptions.seek(partition, offsetData.offset);
+            subscriptions.seek(partition, position);
         }
     }
 
@@ -623,20 +659,20 @@ public class Fetcher<K, V> implements Closeable {
         for (Map.Entry<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> entry : timestampsToSearchByNode.entrySet()) {
             Node node = entry.getKey();
             final Map<TopicPartition, ListOffsetRequest.PartitionData> resetTimestamps = entry.getValue();
-            subscriptions.setResetPending(resetTimestamps.keySet(), time.milliseconds() + requestTimeoutMs);
+            subscriptions.setNextAllowedRetry(resetTimestamps.keySet(), time.milliseconds() + requestTimeoutMs);
 
             RequestFuture<ListOffsetResult> future = sendListOffsetRequest(node, resetTimestamps, false);
             future.addListener(new RequestFutureListener<ListOffsetResult>() {
                 @Override
                 public void onSuccess(ListOffsetResult result) {
                     if (!result.partitionsToRetry.isEmpty()) {
-                        subscriptions.resetFailed(result.partitionsToRetry, time.milliseconds() + retryBackoffMs);
+                        subscriptions.requestFailed(result.partitionsToRetry, time.milliseconds() + retryBackoffMs);
                         metadata.requestUpdate();
                     }
 
-                    for (Map.Entry<TopicPartition, OffsetData> fetchedOffset : result.fetchedOffsets.entrySet()) {
+                    for (Map.Entry<TopicPartition, ListOffsetData> fetchedOffset : result.fetchedOffsets.entrySet()) {
                         TopicPartition partition = fetchedOffset.getKey();
-                        OffsetData offsetData = fetchedOffset.getValue();
+                        ListOffsetData offsetData = fetchedOffset.getValue();
                         ListOffsetRequest.PartitionData requestedReset = resetTimestamps.get(partition);
                         resetOffsetIfNeeded(partition, requestedReset.timestamp, offsetData);
                     }
@@ -644,7 +680,7 @@ public class Fetcher<K, V> implements Closeable {
 
                 @Override
                 public void onFailure(RuntimeException e) {
-                    subscriptions.resetFailed(resetTimestamps.keySet(), time.milliseconds() + retryBackoffMs);
+                    subscriptions.requestFailed(resetTimestamps.keySet(), time.milliseconds() + retryBackoffMs);
                     metadata.requestUpdate();
 
                     if (!(e instanceof RetriableException) && !cachedListOffsetsException.compareAndSet(null, e))
@@ -655,6 +691,90 @@ public class Fetcher<K, V> implements Closeable {
     }
 
     /**
+     * For each partition which needs validation, make an asynchronous request to get the end-offsets for the partition
+     * with the epoch less than or equal to the epoch the partition last saw.
+     *
+     * Requests are grouped by Node for efficiency.
+     */
+    private void validateOffsetsAsync(Map<TopicPartition, SubscriptionState.FetchPosition> partitionsToValidate) {
+        final Map<Node, Map<TopicPartition, SubscriptionState.FetchPosition>> regrouped =
+                regroupFetchPositionsByLeader(partitionsToValidate);
+
+        regrouped.forEach((node, dataMap) -> {
+            if (node.isEmpty()) {
+                metadata.requestUpdate();
+                return;
+            }
+
+            subscriptions.setNextAllowedRetry(dataMap.keySet(), time.milliseconds() + requestTimeoutMs);
+            // Only the partitions led by this node (dataMap) are cached here and sent to it.
+            final Map<TopicPartition, Metadata.LeaderAndEpoch> cachedLeaderAndEpochs = dataMap.entrySet()
+                    .stream()
+                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().currentLeader));
+
+            RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future = offsetsForLeaderEpochClient.sendAsyncRequest(node, dataMap);
+            future.addListener(new RequestFutureListener<OffsetsForLeaderEpochClient.OffsetForEpochResult>() {
+                @Override
+                public void onSuccess(OffsetsForLeaderEpochClient.OffsetForEpochResult offsetsResult) {
+                    Map<TopicPartition, OffsetAndMetadata> truncationWithoutResetPolicy = new HashMap<>();
+                    if (!offsetsResult.partitionsToRetry().isEmpty()) {
+                        subscriptions.setNextAllowedRetry(offsetsResult.partitionsToRetry(), time.milliseconds() + retryBackoffMs);
+                        metadata.requestUpdate();
+                    }
+
+                    // For each OffsetsForLeader response, check if the end-offset is lower than our current offset
+                    // for the partition. If so, it means we have experienced log truncation and need to reposition
+                    // that partition's offset.
+                    offsetsResult.endOffsets().forEach((respTopicPartition, respEndOffset) -> {
+                        if (!subscriptions.isAssigned(respTopicPartition)) {
+                            log.debug("Ignoring OffsetsForLeader response for partition {} which is not currently assigned.", respTopicPartition);
+                            return;
+                        }
+
+                        if (subscriptions.awaitingValidation(respTopicPartition)) {
+                            SubscriptionState.FetchPosition currentPosition = subscriptions.position(respTopicPartition);
+                            Metadata.LeaderAndEpoch currentLeader = currentPosition.currentLeader;
+                            if (!currentLeader.equals(cachedLeaderAndEpochs.get(respTopicPartition))) {
+                                return;
+                            }
+
+                            if (respEndOffset.endOffset() < currentPosition.offset) {
+                                if (subscriptions.hasDefaultOffsetResetPolicy()) {
+                                    SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition(
+                                            respEndOffset.endOffset(), Optional.of(respEndOffset.leaderEpoch()), currentLeader);
+                                    log.info("Truncation detected for partition {}, resetting offset to {}", respTopicPartition, newPosition);
+                                    subscriptions.seek(respTopicPartition, newPosition);
+                                } else {
+                                    log.warn("Truncation detected for partition {}, but no reset policy is set", respTopicPartition);
+                                    truncationWithoutResetPolicy.put(respTopicPartition, new OffsetAndMetadata(
+                                            respEndOffset.endOffset(), Optional.of(respEndOffset.leaderEpoch()), null));
+                                }
+                            } else {
+                                // Offset is fine, clear the validation state
+                                subscriptions.validate(respTopicPartition);
+                            }
+                        }
+                    });
+                    // NOTE(review): thrown from inside a RequestFutureListener; confirm it reaches the consumer caller.
+                    if (!truncationWithoutResetPolicy.isEmpty()) {
+                        throw new LogTruncationException(truncationWithoutResetPolicy);
+                    }
+                }
+
+                @Override
+                public void onFailure(RuntimeException e) {
+                    subscriptions.requestFailed(dataMap.keySet(), time.milliseconds() + retryBackoffMs);
+                    metadata.requestUpdate();
+
+                    if (!(e instanceof RetriableException) && !cachedOffsetForLeaderException.compareAndSet(null, e)) {
+                        log.error("Discarding error in OffsetsForLeaderEpoch because another error is pending", e);
+                    }
+                }
+            });
+        });
+    }
+
+    /**
      * Search the offsets by target times for the specified partitions.
      *
      * @param timestampsToSearch the mapping between partitions and target time
@@ -671,7 +791,7 @@ public class Fetcher<K, V> implements Closeable {
             return RequestFuture.failure(new StaleMetadataException());
 
         final RequestFuture<ListOffsetResult> listOffsetRequestsFuture = new RequestFuture<>();
-        final Map<TopicPartition, OffsetData> fetchedTimestampOffsets = new HashMap<>();
+        final Map<TopicPartition, ListOffsetData> fetchedTimestampOffsets = new HashMap<>();
         final AtomicInteger remainingResponses = new AtomicInteger(timestampsToSearchByNode.size());
 
         for (Map.Entry<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> entry : timestampsToSearchByNode.entrySet()) {
@@ -712,8 +832,9 @@ public class Fetcher<K, V> implements Closeable {
      *                          that need metadata update or re-connect to the leader.
      */
     private Map<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> groupListOffsetRequests(
-            Map<TopicPartition, Long> timestampsToSearch, Set<TopicPartition> partitionsToRetry) {
-        final Map<Node, Map<TopicPartition, ListOffsetRequest.PartitionData>> timestampsToSearchByNode = new HashMap<>();
+            Map<TopicPartition, Long> timestampsToSearch,
+            Set<TopicPartition> partitionsToRetry) {
+        final Map<TopicPartition, ListOffsetRequest.PartitionData> partitionDataMap = new HashMap<>();
         for (Map.Entry<TopicPartition, Long> entry: timestampsToSearch.entrySet()) {
             TopicPartition tp  = entry.getKey();
             Optional<MetadataCache.PartitionInfoAndEpoch> currentInfo = metadata.partitionInfoIfCurrent(tp);
@@ -735,15 +856,11 @@ public class Fetcher<K, V> implements Closeable {
                         currentInfo.get().partitionInfo().leader(), tp);
                 partitionsToRetry.add(tp);
             } else {
-                Node node = currentInfo.get().partitionInfo().leader();
-                Map<TopicPartition, ListOffsetRequest.PartitionData> topicData =
-                        timestampsToSearchByNode.computeIfAbsent(node, n -> new HashMap<>());
-                ListOffsetRequest.PartitionData partitionData = new ListOffsetRequest.PartitionData(
-                        entry.getValue(), Optional.of(currentInfo.get().epoch()));
-                topicData.put(entry.getKey(), partitionData);
+                partitionDataMap.put(tp,
+                        new ListOffsetRequest.PartitionData(entry.getValue(), Optional.of(currentInfo.get().epoch())));
             }
         }
-        return timestampsToSearchByNode;
+        return regroupPartitionMapByNode(partitionDataMap);
     }
 
     /**
@@ -788,7 +905,7 @@ public class Fetcher<K, V> implements Closeable {
     private void handleListOffsetResponse(Map<TopicPartition, ListOffsetRequest.PartitionData> timestampsToSearch,
                                           ListOffsetResponse listOffsetResponse,
                                           RequestFuture<ListOffsetResult> future) {
-        Map<TopicPartition, OffsetData> fetchedOffsets = new HashMap<>();
+        Map<TopicPartition, ListOffsetData> fetchedOffsets = new HashMap<>();
         Set<TopicPartition> partitionsToRetry = new HashSet<>();
         Set<String> unauthorizedTopics = new HashSet<>();
 
@@ -812,7 +929,7 @@ public class Fetcher<K, V> implements Closeable {
                     log.debug("Handling v0 ListOffsetResponse response for {}. Fetched offset {}",
                             topicPartition, offset);
                     if (offset != ListOffsetResponse.UNKNOWN_OFFSET) {
-                        OffsetData offsetData = new OffsetData(offset, null, Optional.empty());
+                        ListOffsetData offsetData = new ListOffsetData(offset, null, Optional.empty());
                         fetchedOffsets.put(topicPartition, offsetData);
                     }
                 } else {
@@ -820,7 +937,7 @@ public class Fetcher<K, V> implements Closeable {
                     log.debug("Handling ListOffsetResponse response for {}. Fetched offset {}, timestamp {}",
                             topicPartition, partitionData.offset, partitionData.timestamp);
                     if (partitionData.offset != ListOffsetResponse.UNKNOWN_OFFSET) {
-                        OffsetData offsetData = new OffsetData(partitionData.offset, partitionData.timestamp,
+                        ListOffsetData offsetData = new ListOffsetData(partitionData.offset, partitionData.timestamp,
                                 partitionData.leaderEpoch);
                         fetchedOffsets.put(topicPartition, offsetData);
                     }
@@ -861,11 +978,11 @@ public class Fetcher<K, V> implements Closeable {
             future.complete(new ListOffsetResult(fetchedOffsets, partitionsToRetry));
     }
 
-    private static class ListOffsetResult {
-        private final Map<TopicPartition, OffsetData> fetchedOffsets;
+    static class ListOffsetResult {
+        private final Map<TopicPartition, ListOffsetData> fetchedOffsets;
         private final Set<TopicPartition> partitionsToRetry;
 
-        public ListOffsetResult(Map<TopicPartition, OffsetData> fetchedOffsets, Set<TopicPartition> partitionsNeedingRetry) {
+        public ListOffsetResult(Map<TopicPartition, ListOffsetData> fetchedOffsets, Set<TopicPartition> partitionsNeedingRetry) {
             this.fetchedOffsets = fetchedOffsets;
             this.partitionsToRetry = partitionsNeedingRetry;
         }
@@ -878,15 +995,13 @@ public class Fetcher<K, V> implements Closeable {
 
     private List<TopicPartition> fetchablePartitions() {
         Set<TopicPartition> exclude = new HashSet<>();
-        List<TopicPartition> fetchable = subscriptions.fetchablePartitions();
         if (nextInLineRecords != null && !nextInLineRecords.isFetched) {
             exclude.add(nextInLineRecords.partition);
         }
         for (CompletedFetch completedFetch : completedFetches) {
             exclude.add(completedFetch.partition);
         }
-        fetchable.removeAll(exclude);
-        return fetchable;
+        return subscriptions.fetchablePartitions(tp -> !exclude.contains(tp));
     }
 
     /**
@@ -895,12 +1010,17 @@ public class Fetcher<K, V> implements Closeable {
      */
     private Map<Node, FetchSessionHandler.FetchRequestData> prepareFetchRequests() {
         Map<Node, FetchSessionHandler.Builder> fetchable = new LinkedHashMap<>();
+
+        // Ensure the position has an up-to-date leader
+        subscriptions.assignedPartitions().forEach(
+            tp -> subscriptions.maybeValidatePosition(tp, metadata.leaderAndEpoch(tp)));
+
         for (TopicPartition partition : fetchablePartitions()) {
-            Node node = metadata.partitionInfoIfCurrent(partition)
-                    .map(MetadataCache.PartitionInfoAndEpoch::partitionInfo)
-                    .map(PartitionInfo::leader)
-                    .orElse(null);
-            if (node == null) {
+            SubscriptionState.FetchPosition position = this.subscriptions.position(partition);
+            Metadata.LeaderAndEpoch leaderAndEpoch = position.currentLeader;
+            Node node = leaderAndEpoch.leader;
+
+            if (node == null || node.isEmpty()) {
                 metadata.requestUpdate();
             } else if (client.isUnavailable(node)) {
                 client.maybeThrowAuthFailure(node);
@@ -912,26 +1032,26 @@ public class Fetcher<K, V> implements Closeable {
                 log.trace("Skipping fetch for partition {} because there is an in-flight request to {}", partition, node);
             } else {
                 // if there is a leader and no in-flight requests, issue a new fetch
-                FetchSessionHandler.Builder builder = fetchable.get(node);
+                FetchSessionHandler.Builder builder = fetchable.get(leaderAndEpoch.leader);
                 if (builder == null) {
-                    FetchSessionHandler handler = sessionHandler(node.id());
+                    int id = leaderAndEpoch.leader.id();
+                    FetchSessionHandler handler = sessionHandler(id);
                     if (handler == null) {
-                        handler = new FetchSessionHandler(logContext, node.id());
-                        sessionHandlers.put(node.id(), handler);
+                        handler = new FetchSessionHandler(logContext, id);
+                        sessionHandlers.put(id, handler);
                     }
                     builder = handler.newBuilder();
-                    fetchable.put(node, builder);
+                    fetchable.put(leaderAndEpoch.leader, builder);
                 }
 
-                long position = this.subscriptions.position(partition);
-                Optional<Integer> leaderEpoch = this.metadata.lastSeenLeaderEpoch(partition);
-                builder.add(partition, new FetchRequest.PartitionData(position, FetchRequest.INVALID_LOG_START_OFFSET,
-                    this.fetchSize, leaderEpoch));
+                builder.add(partition, new FetchRequest.PartitionData(position.offset,
+                        FetchRequest.INVALID_LOG_START_OFFSET, this.fetchSize, leaderAndEpoch.epoch));
 
-                log.debug("Added {} fetch request for partition {} at offset {} to node {}", isolationLevel,
-                    partition, position, node);
+                log.debug("Added {} fetch request for partition {} at position {} to node {}", isolationLevel,
+                    partition, position, leaderAndEpoch.leader);
             }
         }
+
         Map<Node, FetchSessionHandler.FetchRequestData> reqs = new LinkedHashMap<>();
         for (Map.Entry<Node, FetchSessionHandler.Builder> entry : fetchable.entrySet()) {
             reqs.put(entry.getKey(), entry.getValue().build());
@@ -939,6 +1059,21 @@ public class Fetcher<K, V> implements Closeable {
         return reqs;
     }
 
+    private Map<Node, Map<TopicPartition, SubscriptionState.FetchPosition>> regroupFetchPositionsByLeader(
+            Map<TopicPartition, SubscriptionState.FetchPosition> partitionMap) {
+        return partitionMap.entrySet()
+                .stream()
+                .collect(Collectors.groupingBy(entry -> entry.getValue().currentLeader.leader,
+                        Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
+    }
+
+    private <T> Map<Node, Map<TopicPartition, T>> regroupPartitionMapByNode(Map<TopicPartition, T> partitionMap) {
+        return partitionMap.entrySet()
+                .stream()
+                .collect(Collectors.groupingBy(entry -> metadata.fetch().leaderFor(entry.getKey()),
+                        Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)));
+    }
+
     /**
      * The callback for fetch completion
      */
@@ -957,8 +1092,8 @@ public class Fetcher<K, V> implements Closeable {
             } else if (error == Errors.NONE) {
                 // we are interested in this fetch only if the beginning offset matches the
                 // current consumed position
-                Long position = subscriptions.position(tp);
-                if (position == null || position != fetchOffset) {
+                SubscriptionState.FetchPosition position = subscriptions.position(tp);
+                if (position == null || position.offset != fetchOffset) {
                     log.debug("Discarding stale fetch response for partition {} since its offset {} does not match " +
                             "the expected offset {}", tp, fetchOffset, position);
                     return null;
@@ -1011,7 +1146,7 @@ public class Fetcher<K, V> implements Closeable {
                 log.warn("Received unknown topic or partition error in fetch for partition {}", tp);
                 this.metadata.requestUpdate();
             } else if (error == Errors.OFFSET_OUT_OF_RANGE) {
-                if (fetchOffset != subscriptions.position(tp)) {
+                if (fetchOffset != subscriptions.position(tp).offset) {
                     log.debug("Discarding stale fetch response for partition {} since the fetched offset {} " +
                             "does not match the current offset {}", tp, fetchOffset, subscriptions.position(tp));
                 } else if (subscriptions.hasDefaultOffsetResetPolicy()) {
@@ -1138,6 +1273,7 @@ public class Fetcher<K, V> implements Closeable {
         private Record lastRecord;
         private CloseableIterator<Record> records;
         private long nextFetchOffset;
+        private Optional<Integer> lastEpoch;
         private boolean isFetched = false;
         private Exception cachedRecordException = null;
         private boolean corruptLastRecord = false;
@@ -1149,6 +1285,7 @@ public class Fetcher<K, V> implements Closeable {
             this.completedFetch = completedFetch;
             this.batches = batches;
             this.nextFetchOffset = completedFetch.fetchedOffset;
+            this.lastEpoch = Optional.empty();
             this.abortedProducerIds = new HashSet<>();
             this.abortedTransactions = abortedTransactions(completedFetch.partitionData);
         }
@@ -1214,6 +1351,9 @@ public class Fetcher<K, V> implements Closeable {
                     }
 
                     currentBatch = batches.next();
+                    lastEpoch = currentBatch.partitionLeaderEpoch() == RecordBatch.NO_PARTITION_LEADER_EPOCH ?
+                            Optional.empty() : Optional.of(currentBatch.partitionLeaderEpoch());
+
                     maybeEnsureValid(currentBatch);
 
                     if (isolationLevel == IsolationLevel.READ_COMMITTED && currentBatch.hasProducerId()) {
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/OffsetsForLeaderEpochClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/OffsetsForLeaderEpochClient.java
new file mode 100644
index 0000000..9ffedd1
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/OffsetsForLeaderEpochClient.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.TopicAuthorizationException;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.AbstractRequest;
+import org.apache.kafka.common.requests.EpochEndOffset;
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochRequest;
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse;
+import org.apache.kafka.common.utils.LogContext;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Convenience client for sending asynchronous OffsetsForLeaderEpoch requests to a single node and sorting
+ * the per-partition results into end offsets and partitions that should be retried.
+ */
+public class OffsetsForLeaderEpochClient extends AsyncClient<
+        Map<TopicPartition, SubscriptionState.FetchPosition>,
+        OffsetsForLeaderEpochRequest,
+        OffsetsForLeaderEpochResponse,
+        OffsetsForLeaderEpochClient.OffsetForEpochResult> {
+
+    OffsetsForLeaderEpochClient(ConsumerNetworkClient client, LogContext logContext) {
+        super(client, logContext);
+    }
+
+    @Override
+    protected AbstractRequest.Builder<OffsetsForLeaderEpochRequest> prepareRequest(
+            Node node, Map<TopicPartition, SubscriptionState.FetchPosition> requestData) {
+        Map<TopicPartition, OffsetsForLeaderEpochRequest.PartitionData> partitionData = new HashMap<>(requestData.size());
+        // Positions without an offset epoch are left out of the request. NOTE(review): handleResponse then sees
+        // them as missing and retries them; confirm callers only pass positions that carry an epoch.
+        requestData.forEach((topicPartition, fetchPosition) -> fetchPosition.offsetEpoch.ifPresent(
+            fetchEpoch -> partitionData.put(topicPartition,
+                new OffsetsForLeaderEpochRequest.PartitionData(fetchPosition.currentLeader.epoch, fetchEpoch))));
+
+        return new OffsetsForLeaderEpochRequest.Builder(ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion(), partitionData);
+    }
+
+    @Override
+    protected OffsetForEpochResult handleResponse(
+            Node node,
+            Map<TopicPartition, SubscriptionState.FetchPosition> requestData,
+            OffsetsForLeaderEpochResponse response) {
+
+        Set<TopicPartition> partitionsToRetry = new HashSet<>();
+        Set<String> unauthorizedTopics = new HashSet<>();
+        Map<TopicPartition, EpochEndOffset> endOffsets = new HashMap<>();
+
+        for (TopicPartition topicPartition : requestData.keySet()) {
+            EpochEndOffset epochEndOffset = response.responses().get(topicPartition);
+            if (epochEndOffset == null) {
+                logger().warn("Missing partition {} from response, retrying", topicPartition);
+                partitionsToRetry.add(topicPartition);
+                continue;
+            }
+            Errors error = epochEndOffset.error();
+            if (error == Errors.NONE) {
+                logger().debug("Handling OffsetsForLeaderEpoch response for {}. Got offset {} for epoch {}",
+                        topicPartition, epochEndOffset.endOffset(), epochEndOffset.leaderEpoch());
+                endOffsets.put(topicPartition, epochEndOffset);
+            } else if (error == Errors.NOT_LEADER_FOR_PARTITION ||
+                    error == Errors.REPLICA_NOT_AVAILABLE ||
+                    error == Errors.KAFKA_STORAGE_ERROR ||
+                    error == Errors.OFFSET_NOT_AVAILABLE ||
+                    error == Errors.LEADER_NOT_AVAILABLE ||
+                    error == Errors.FENCED_LEADER_EPOCH ||
+                    error == Errors.UNKNOWN_LEADER_EPOCH) {
+                logger().debug("Attempt to fetch offsets for partition {} failed due to {}, retrying.",
+                        topicPartition, error);
+                partitionsToRetry.add(topicPartition);
+            } else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) {
+                logger().warn("Received unknown topic or partition error in OffsetsForLeaderEpoch request for partition {}", topicPartition);
+                partitionsToRetry.add(topicPartition);
+            } else if (error == Errors.TOPIC_AUTHORIZATION_FAILED) {
+                unauthorizedTopics.add(topicPartition.topic());
+            } else {
+                logger().warn("Attempt to fetch offsets for partition {} failed due to: {}, retrying.", topicPartition, error.message());
+                partitionsToRetry.add(topicPartition);
+            }
+        }
+
+        if (!unauthorizedTopics.isEmpty())
+            throw new TopicAuthorizationException(unauthorizedTopics);
+        else
+            return new OffsetForEpochResult(endOffsets, partitionsToRetry);
+    }
+
+    public static class OffsetForEpochResult {
+        private final Map<TopicPartition, EpochEndOffset> endOffsets;
+        private final Set<TopicPartition> partitionsToRetry;
+
+        OffsetForEpochResult(Map<TopicPartition, EpochEndOffset> endOffsets, Set<TopicPartition> partitionsNeedingRetry) {
+            this.endOffsets = endOffsets;
+            this.partitionsToRetry = partitionsNeedingRetry;
+        }
+
+        public Map<TopicPartition, EpochEndOffset> endOffsets() {
+            return endOffsets;
+        }
+
+        public Set<TopicPartition> partitionsToRetry() {
+            return partitionsToRetry;
+        }
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index c28ac87..3909421 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -16,22 +16,27 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.NoOffsetForPartitionException;
 import org.apache.kafka.clients.consumer.OffsetAndMetadata;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.internals.PartitionStates;
 import org.apache.kafka.common.requests.IsolationLevel;
 import org.apache.kafka.common.utils.LogContext;
 import org.slf4j.Logger;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Predicate;
@@ -45,7 +50,7 @@ import java.util.stream.Collectors;
  * or with {@link #assignFromSubscribed(Collection)} (automatic assignment from subscription).
  *
  * Once assigned, the partition is not considered "fetchable" until its initial position has
- * been set with {@link #seek(TopicPartition, long)}. Fetchable partitions track a fetch
+ * been set with {@link #seek(TopicPartition, FetchPosition)}. Fetchable partitions track a fetch
  * position which is used to set the offset of the next fetch, and a consumed position
  * which is the last offset that has been returned to the user. You can suspend fetching
  * from a partition through {@link #pause(TopicPartition)} without affecting the fetched/consumed
@@ -326,8 +331,16 @@ public class SubscriptionState {
         return state;
     }
 
+    public void seek(TopicPartition tp, FetchPosition position) {
+        assignedState(tp).seek(position);
+    }
+
+    public void seekAndValidate(TopicPartition tp, FetchPosition position) {
+        assignedState(tp).seekAndValidate(position);
+    }
+
     public void seek(TopicPartition tp, long offset) {
-        assignedState(tp).seek(offset);
+        seek(tp, new FetchPosition(offset, Optional.empty(), new Metadata.LeaderAndEpoch(Node.noNode(), Optional.empty())));
     }
 
     /**
@@ -345,33 +358,52 @@ public class SubscriptionState {
         return this.assignment.size();
     }
 
-    public List<TopicPartition> fetchablePartitions() {
-        return collectPartitions(TopicPartitionState::isFetchable, Collectors.toList());
+    public List<TopicPartition> fetchablePartitions(Predicate<TopicPartition> isAvailable) {
+        return assignment.stream()
+                .filter(tpState -> isAvailable.test(tpState.topicPartition()) && tpState.value().isFetchable())
+                .map(PartitionStates.PartitionState::topicPartition)
+                .collect(Collectors.toList());
     }
 
     public boolean partitionsAutoAssigned() {
         return this.subscriptionType == SubscriptionType.AUTO_TOPICS || this.subscriptionType == SubscriptionType.AUTO_PATTERN;
     }
 
-    public void position(TopicPartition tp, long offset) {
-        assignedState(tp).position(offset);
+    public void position(TopicPartition tp, FetchPosition position) {
+        assignedState(tp).position(position);
+    }
+
+    public boolean maybeValidatePosition(TopicPartition tp, Metadata.LeaderAndEpoch leaderAndEpoch) {
+        return assignedState(tp).maybeValidatePosition(leaderAndEpoch);
+    }
+
+    public boolean awaitingValidation(TopicPartition tp) {
+        return assignedState(tp).awaitingValidation();
+    }
+
+    public void validate(TopicPartition tp) {
+        assignedState(tp).validate();
+    }
+
+    public FetchPosition validPosition(TopicPartition tp) {
+        return assignedState(tp).validPosition();
     }
 
-    public Long position(TopicPartition tp) {
-        return assignedState(tp).position;
+    public FetchPosition position(TopicPartition tp) {
+        return assignedState(tp).position();
     }
 
     public Long partitionLag(TopicPartition tp, IsolationLevel isolationLevel) {
         TopicPartitionState topicPartitionState = assignedState(tp);
         if (isolationLevel == IsolationLevel.READ_COMMITTED)
-            return topicPartitionState.lastStableOffset == null ? null : topicPartitionState.lastStableOffset - topicPartitionState.position;
+            return topicPartitionState.lastStableOffset == null ? null : topicPartitionState.lastStableOffset - topicPartitionState.position.offset;
         else
-            return topicPartitionState.highWatermark == null ? null : topicPartitionState.highWatermark - topicPartitionState.position;
+            return topicPartitionState.highWatermark == null ? null : topicPartitionState.highWatermark - topicPartitionState.position.offset;
     }
 
     public Long partitionLead(TopicPartition tp) {
         TopicPartitionState topicPartitionState = assignedState(tp);
-        return topicPartitionState.logStartOffset == null ? null : topicPartitionState.position - topicPartitionState.logStartOffset;
+        return topicPartitionState.logStartOffset == null ? null : topicPartitionState.position.offset - topicPartitionState.logStartOffset;
     }
 
     public void updateHighWatermark(TopicPartition tp, long highWatermark) {
@@ -389,8 +421,10 @@ public class SubscriptionState {
     public Map<TopicPartition, OffsetAndMetadata> allConsumed() {
         Map<TopicPartition, OffsetAndMetadata> allConsumed = new HashMap<>();
         assignment.stream().forEach(state -> {
-            if (state.value().hasValidPosition())
-                allConsumed.put(state.topicPartition(), new OffsetAndMetadata(state.value().position));
+            TopicPartitionState partitionState = state.value();
+            if (partitionState.hasValidPosition())
+                allConsumed.put(state.topicPartition(), new OffsetAndMetadata(partitionState.position.offset,
+                        partitionState.position.offsetEpoch, ""));
         });
         return allConsumed;
     }
@@ -403,9 +437,9 @@ public class SubscriptionState {
         requestOffsetReset(partition, defaultResetStrategy);
     }
 
-    public void setResetPending(Set<TopicPartition> partitions, long nextAllowResetTimeMs) {
+    public void setNextAllowedRetry(Set<TopicPartition> partitions, long nextAllowResetTimeMs) {
         for (TopicPartition partition : partitions) {
-            assignedState(partition).setResetPending(nextAllowResetTimeMs);
+            assignedState(partition).setNextAllowedRetry(nextAllowResetTimeMs);
         }
     }
 
@@ -426,7 +460,7 @@ public class SubscriptionState {
     }
 
     public Set<TopicPartition> missingFetchPositions() {
-        return collectPartitions(TopicPartitionState::isMissingPosition, Collectors.toSet());
+        return collectPartitions(state -> !state.hasPosition(), Collectors.toSet());
     }
 
     private <T extends Collection<TopicPartition>> T collectPartitions(Predicate<TopicPartitionState> filter, Collector<TopicPartition, ?, T> collector) {
@@ -436,12 +470,13 @@ public class SubscriptionState {
                 .collect(collector);
     }
 
+
     public void resetMissingPositions() {
         final Set<TopicPartition> partitionsWithNoOffsets = new HashSet<>();
         assignment.stream().forEach(state -> {
             TopicPartition tp = state.topicPartition();
             TopicPartitionState partitionState = state.value();
-            if (partitionState.isMissingPosition()) {
+            if (!partitionState.hasPosition()) {
                 if (defaultResetStrategy == OffsetResetStrategy.NONE)
                     partitionsWithNoOffsets.add(tp);
                 else
@@ -454,7 +489,12 @@ public class SubscriptionState {
     }
 
     public Set<TopicPartition> partitionsNeedingReset(long nowMs) {
-        return collectPartitions(state -> state.awaitingReset() && state.isResetAllowed(nowMs),
+        return collectPartitions(state -> state.awaitingReset() && !state.awaitingRetryBackoff(nowMs),
+                Collectors.toSet());
+    }
+
+    public Set<TopicPartition> partitionsNeedingValidation(long nowMs) {
+        return collectPartitions(state -> state.awaitingValidation() && !state.awaitingRetryBackoff(nowMs),
                 Collectors.toSet());
     }
 
@@ -482,9 +522,9 @@ public class SubscriptionState {
         assignedState(tp).resume();
     }
 
-    public void resetFailed(Set<TopicPartition> partitions, long nextRetryTimeMs) {
+    public void requestFailed(Set<TopicPartition> partitions, long nextRetryTimeMs) {
         for (TopicPartition partition : partitions)
-            assignedState(partition).resetFailed(nextRetryTimeMs);
+            assignedState(partition).requestFailed(nextRetryTimeMs);
     }
 
     public void movePartitionToEnd(TopicPartition tp) {
@@ -503,68 +543,148 @@ public class SubscriptionState {
     }
 
     private static class TopicPartitionState {
-        private Long position; // last consumed position
+
+        private FetchState fetchState;
+        private FetchPosition position; // last consumed position
+
         private Long highWatermark; // the high watermark from last fetch
         private Long logStartOffset; // the log start offset
         private Long lastStableOffset;
         private boolean paused;  // whether this partition has been paused by the user
         private OffsetResetStrategy resetStrategy;  // the strategy to use if the offset needs resetting
-        private Long nextAllowedRetryTimeMs;
+        private Long nextRetryTimeMs;
+
 
         TopicPartitionState() {
             this.paused = false;
+            this.fetchState = FetchStates.INITIALIZING;
             this.position = null;
             this.highWatermark = null;
             this.logStartOffset = null;
             this.lastStableOffset = null;
             this.resetStrategy = null;
-            this.nextAllowedRetryTimeMs = null;
+            this.nextRetryTimeMs = null;
+        }
+
+        private void transitionState(FetchState newState, Runnable runIfTransitioned) {
+            FetchState nextState = this.fetchState.transitionTo(newState);
+            if (nextState.equals(newState)) {
+                this.fetchState = nextState;
+                runIfTransitioned.run();
+            }
         }
 
         private void reset(OffsetResetStrategy strategy) {
-            this.resetStrategy = strategy;
-            this.position = null;
-            this.nextAllowedRetryTimeMs = null;
+            transitionState(FetchStates.AWAIT_RESET, () -> {
+                this.resetStrategy = strategy;
+                this.nextRetryTimeMs = null;
+            });
         }
 
-        private boolean isResetAllowed(long nowMs) {
-            return nextAllowedRetryTimeMs == null || nowMs >= nextAllowedRetryTimeMs;
+        private boolean maybeValidatePosition(Metadata.LeaderAndEpoch currentLeaderAndEpoch) {
+            if (this.fetchState.equals(FetchStates.AWAIT_RESET)) {
+                return false;
+            }
+
+            if (currentLeaderAndEpoch.equals(Metadata.LeaderAndEpoch.noLeaderOrEpoch())) {
+                // Ignore empty LeaderAndEpochs
+                return false;
+            }
+
+            if (position != null && !position.safeToFetchFrom(currentLeaderAndEpoch)) {
+                FetchPosition newPosition = new FetchPosition(position.offset, position.offsetEpoch, currentLeaderAndEpoch);
+                validatePosition(newPosition);
+            }
+            return this.fetchState.equals(FetchStates.AWAIT_VALIDATION);
+        }
+
+        private void validatePosition(FetchPosition position) {
+            if (position.offsetEpoch.isPresent()) {
+                transitionState(FetchStates.AWAIT_VALIDATION, () -> {
+                    this.position = position;
+                    this.nextRetryTimeMs = null;
+                });
+            } else {
+                // If we have no epoch information for the current position, then we can skip validation
+                transitionState(FetchStates.FETCHING, () -> {
+                    this.position = position;
+                    this.nextRetryTimeMs = null;
+                });
+            }
+        }
+
+        /**
+         * Clear the awaiting validation state and enter fetching, provided this partition has a position.
+         */
+        private void validate() {
+            if (hasPosition()) {
+                transitionState(FetchStates.FETCHING, () -> {
+                    this.nextRetryTimeMs = null;
+                });
+            }
+        }
+
+        private boolean awaitingValidation() {
+            return fetchState.equals(FetchStates.AWAIT_VALIDATION);
+        }
+
+        private boolean awaitingRetryBackoff(long nowMs) {
+            return nextRetryTimeMs != null && nowMs < nextRetryTimeMs;
         }
 
         private boolean awaitingReset() {
-            return resetStrategy != null;
+            return fetchState.equals(FetchStates.AWAIT_RESET);
         }
 
-        private void setResetPending(long nextAllowedRetryTimeMs) {
-            this.nextAllowedRetryTimeMs = nextAllowedRetryTimeMs;
+        private void setNextAllowedRetry(long nextAllowedRetryTimeMs) {
+            this.nextRetryTimeMs = nextAllowedRetryTimeMs;
         }
 
-        private void resetFailed(long nextAllowedRetryTimeMs) {
-            this.nextAllowedRetryTimeMs = nextAllowedRetryTimeMs;
+        private void requestFailed(long nextAllowedRetryTimeMs) {
+            this.nextRetryTimeMs = nextAllowedRetryTimeMs;
         }
 
         private boolean hasValidPosition() {
-            return position != null;
+            return fetchState.hasValidPosition();
         }
 
-        private boolean isMissingPosition() {
-            return !hasValidPosition() && !awaitingReset();
+        private boolean hasPosition() {
+            return fetchState.hasPosition();
         }
 
         private boolean isPaused() {
             return paused;
         }
 
-        private void seek(long offset) {
-            this.position = offset;
-            this.resetStrategy = null;
-            this.nextAllowedRetryTimeMs = null;
+        private void seek(FetchPosition position) {
+            transitionState(FetchStates.FETCHING, () -> {
+                this.position = position;
+                this.resetStrategy = null;
+                this.nextRetryTimeMs = null;
+            });
+        }
+
+        private void seekAndValidate(FetchPosition fetchPosition) {
+            seek(fetchPosition);
+            validatePosition(fetchPosition);
         }
 
-        private void position(long offset) {
+        private void position(FetchPosition position) {
             if (!hasValidPosition())
                 throw new IllegalStateException("Cannot set a new position without a valid current position");
-            this.position = offset;
+            this.position = position;
+        }
+
+        private FetchPosition validPosition() {
+            if (hasValidPosition()) {
+                return position;
+            } else {
+                return null;
+            }
+        }
+
+        private FetchPosition position() {
+            return position;
         }
 
         private void pause() {
@@ -581,5 +701,149 @@ public class SubscriptionState {
 
     }
 
+    /**
+     * The fetch state of a partition. This class is used to determine valid state transitions and to expose some of
+     * the behavior of the current fetch state. Actual state variables are stored in the {@link TopicPartitionState}.
+     */
+    interface FetchState {
+        default FetchState transitionTo(FetchState newState) {
+            if (validTransitions().contains(newState)) {
+                return newState;
+            } else {
+                return this;
+            }
+        }
+
+        Collection<FetchState> validTransitions();
+
+        boolean hasPosition();
+
+        boolean hasValidPosition();
+    }
+
+    /**
+     * An enumeration of all the possible fetch states. The state transitions are encoded in the values returned by
+     * {@link FetchState#validTransitions}.
+     */
+    enum FetchStates implements FetchState {
+        INITIALIZING() {
+            @Override
+            public Collection<FetchState> validTransitions() {
+                return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET, FetchStates.AWAIT_VALIDATION);
+            }
 
+            @Override
+            public boolean hasPosition() {
+                return false;
+            }
+
+            @Override
+            public boolean hasValidPosition() {
+                return false;
+            }
+        },
+
+        FETCHING() {
+            @Override
+            public Collection<FetchState> validTransitions() {
+                return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET, FetchStates.AWAIT_VALIDATION);
+            }
+
+            @Override
+            public boolean hasPosition() {
+                return true;
+            }
+
+            @Override
+            public boolean hasValidPosition() {
+                return true;
+            }
+        },
+
+        AWAIT_RESET() {
+            @Override
+            public Collection<FetchState> validTransitions() {
+                return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET);
+            }
+
+            @Override
+            public boolean hasPosition() {
+                return true;
+            }
+
+            @Override
+            public boolean hasValidPosition() {
+                return false;
+            }
+        },
+
+        AWAIT_VALIDATION() {
+            @Override
+            public Collection<FetchState> validTransitions() {
+                return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET, FetchStates.AWAIT_VALIDATION);
+            }
+
+            @Override
+            public boolean hasPosition() {
+                return true;
+            }
+
+            @Override
+            public boolean hasValidPosition() {
+                return false;
+            }
+        }
+    }
+
+    /**
+     * Represents the position of a partition subscription.
+     *
+     * This includes the offset and epoch from the last record in
+     * the batch from a FetchResponse. It also includes the leader and leader epoch known when the batch was consumed.
+     *
+     * If an offset epoch is present, the position is validated against the current leader before it is considered valid.
+     */
+    public static class FetchPosition {
+        public final long offset;
+        public final Optional<Integer> offsetEpoch;
+        public final Metadata.LeaderAndEpoch currentLeader;
+
+        public FetchPosition(long offset, Optional<Integer> offsetEpoch, Metadata.LeaderAndEpoch currentLeader) {
+            this.offset = offset;
+            this.offsetEpoch = Objects.requireNonNull(offsetEpoch);
+            this.currentLeader = Objects.requireNonNull(currentLeader);
+        }
+
+        /**
+         * Test whether it is "safe" to fetch from the given leader and epoch. This holds only when the
+         * {@link Metadata.LeaderAndEpoch} stored in this position has a non-empty leader and equals the one supplied.
+         */
+        public boolean safeToFetchFrom(Metadata.LeaderAndEpoch leaderAndEpoch) {
+            return !currentLeader.leader.isEmpty() && currentLeader.equals(leaderAndEpoch);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            FetchPosition that = (FetchPosition) o;
+            return offset == that.offset &&
+                    offsetEpoch.equals(that.offsetEpoch) &&
+                    currentLeader.equals(that.currentLeader);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(offset, offsetEpoch, currentLeader);
+        }
+
+        @Override
+        public String toString() {
+            return "FetchPosition{" +
+                    "offset=" + offset +
+                    ", offsetEpoch=" + offsetEpoch +
+                    ", currentLeader=" + currentLeader +
+                    '}';
+        }
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 5d2a5cf..b5f7ab2 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -1038,4 +1038,16 @@ public final class Utils {
         }
         return result;
     }
+
+    public static <K1, V1, K2, V2> Map<K2, V2> transformMap(
+            Map<? extends K1, ? extends V1> map,
+            Function<K1, K2> keyMapper,
+            Function<V1, V2> valueMapper) {
+        return map.entrySet().stream().collect(
+            Collectors.toMap(
+                entry -> keyMapper.apply(entry.getKey()),
+                entry -> valueMapper.apply(entry.getValue())
+            )
+        );
+    }
 }
diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
index 39e8c3d..0e3d191 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
@@ -35,7 +35,6 @@ import java.net.InetSocketAddress;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.Optional;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -205,9 +204,7 @@ public class MetadataTest {
 
         // First epoch seen, accept it
         {
-            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts,
-                (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                        new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(100), replicas, isr, offlineReplicas));
+            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 100);
             metadata.update(metadataResponse, 10L);
             assertNotNull(metadata.fetch().partition(tp));
             assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100);
@@ -215,9 +212,9 @@ public class MetadataTest {
 
         // Fake an empty ISR, but with an older epoch, should reject it
         {
-            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts,
+            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 99,
                 (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                        new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(99), replicas, Collections.emptyList(), offlineReplicas));
+                        new MetadataResponse.PartitionMetadata(error, partition, leader, leaderEpoch, replicas, Collections.emptyList(), offlineReplicas));
             metadata.update(metadataResponse, 20L);
             assertEquals(metadata.fetch().partition(tp).inSyncReplicas().length, 1);
             assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100);
@@ -225,9 +222,9 @@ public class MetadataTest {
 
         // Fake an empty ISR, with same epoch, accept it
         {
-            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts,
+            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 100,
                 (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                        new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(100), replicas, Collections.emptyList(), offlineReplicas));
+                        new MetadataResponse.PartitionMetadata(error, partition, leader, leaderEpoch, replicas, Collections.emptyList(), offlineReplicas));
             metadata.update(metadataResponse, 20L);
             assertEquals(metadata.fetch().partition(tp).inSyncReplicas().length, 0);
             assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100);
@@ -235,8 +232,7 @@ public class MetadataTest {
 
         // Empty metadata response, should not keep old partition but should keep the last-seen epoch
         {
-            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith(
-                    "dummy", 1, Collections.emptyMap(), Collections.emptyMap(), MetadataResponse.PartitionMetadata::new);
+            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.emptyMap());
             metadata.update(metadataResponse, 20L);
             assertNull(metadata.fetch().partition(tp));
             assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100);
@@ -244,9 +240,7 @@ public class MetadataTest {
 
         // Back in the metadata, with old epoch, should not get added
         {
-            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts,
-                (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                        new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(99), replicas, isr, offlineReplicas));
+            MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 99);
             metadata.update(metadataResponse, 10L);
             assertNull(metadata.fetch().partition(tp));
             assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100);
@@ -284,9 +278,7 @@ public class MetadataTest {
         assertTrue(metadata.updateLastSeenEpochIfNewer(tp, 99));
 
         // Update epoch to 100
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts,
-            (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                    new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(100), replicas, isr, offlineReplicas));
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 100);
         metadata.update(metadataResponse, 10L);
         assertNotNull(metadata.fetch().partition(tp));
         assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100);
@@ -308,9 +300,7 @@ public class MetadataTest {
         assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 101);
 
         // Metadata with equal or newer epoch is accepted
-        metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts,
-            (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                    new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(101), replicas, isr, offlineReplicas));
+        metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 101);
         metadata.update(metadataResponse, 30L);
         assertNotNull(metadata.fetch().partition(tp));
         assertEquals(metadata.fetch().partitionCountForTopic("topic-1").longValue(), 5);
@@ -321,9 +311,7 @@ public class MetadataTest {
     @Test
     public void testNoEpoch() {
         metadata.update(emptyMetadataResponse(), 0L);
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1),
-            (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.empty(), replicas, isr, offlineReplicas));
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1));
         metadata.update(metadataResponse, 10L);
 
         TopicPartition tp = new TopicPartition("topic-1", 0);
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
index e17a6ef..f31ca52 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java
@@ -639,6 +639,26 @@ public class KafkaConsumerTest {
     }
 
     @Test
+    public void testOffsetIsValidAfterSeek() {
+        Time time = new MockTime();
+        SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.LATEST);
+        ConsumerMetadata metadata = createMetadata(subscription);
+        MockClient client = new MockClient(time, metadata);
+
+        initMetadata(client, Collections.singletonMap(topic, 1));
+        Node node = metadata.fetch().nodes().get(0);
+
+        PartitionAssignor assignor = new RoundRobinAssignor();
+
+        KafkaConsumer<String, String> consumer = newConsumer(time, client, subscription, metadata, assignor,
+                true, groupId);
+        consumer.assign(singletonList(tp0));
+        consumer.seek(tp0, 20L);
+        consumer.poll(Duration.ZERO);
+        assertEquals(subscription.validPosition(tp0).offset, 20L);
+    }
+
+    @Test
     public void testCommitsFetchedDuringAssign() {
         long offset1 = 10000;
         long offset2 = 20000;
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
index 49bbcec..3cdbbfa 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java
@@ -1722,7 +1722,31 @@ public class ConsumerCoordinatorTest {
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
-        assertEquals(100L, subscriptions.position(t1p).longValue());
+        assertEquals(100L, subscriptions.position(t1p).offset);
+    }
+
+    @Test
+    public void testRefreshOffsetWithValidation() {
+        client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE));
+
+        subscriptions.assignFromUser(singleton(t1p));
+
+        // Initial leader epoch of 4
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("kafka-cluster", 1,
+                Collections.emptyMap(), singletonMap(topic1, 1), tp -> 4);
+        client.updateMetadata(metadataResponse);
+
+        // Load offsets from previous epoch
+        client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L, 3));
+        coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE));
+
+        // Offset gets loaded, but requires validation
+        assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
+        assertFalse(subscriptions.hasAllFetchPositions());
+        assertTrue(subscriptions.awaitingValidation(t1p));
+        assertEquals(subscriptions.position(t1p).offset, 100L);
+        assertNull(subscriptions.validPosition(t1p));
     }
 
     @Test
@@ -1756,7 +1780,7 @@ public class ConsumerCoordinatorTest {
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
-        assertEquals(100L, subscriptions.position(t1p).longValue());
+        assertEquals(100L, subscriptions.position(t1p).offset);
     }
 
     @Test
@@ -1797,7 +1821,7 @@ public class ConsumerCoordinatorTest {
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
-        assertEquals(100L, subscriptions.position(t1p).longValue());
+        assertEquals(100L, subscriptions.position(t1p).offset);
     }
 
     @Test
@@ -1825,7 +1849,7 @@ public class ConsumerCoordinatorTest {
 
         assertEquals(Collections.emptySet(), subscriptions.missingFetchPositions());
         assertTrue(subscriptions.hasAllFetchPositions());
-        assertEquals(500L, subscriptions.position(t1p).longValue());
+        assertEquals(500L, subscriptions.position(t1p).offset);
         assertTrue(coordinator.coordinatorUnknown());
     }
 
@@ -2023,7 +2047,7 @@ public class ConsumerCoordinatorTest {
         time.sleep(autoCommitIntervalMs); // sleep for a while to ensure auto commit does happen
         coordinator.maybeAutoCommitOffsetsAsync(time.milliseconds());
         assertFalse(coordinator.coordinatorUnknown());
-        assertEquals(100L, subscriptions.position(t1p).longValue());
+        assertEquals(100L, subscriptions.position(t1p).offset);
     }
 
     private ConsumerCoordinator prepareCoordinatorForCloseTest(final boolean useGroupManagement,
@@ -2208,6 +2232,12 @@ public class ConsumerCoordinatorTest {
         return new OffsetFetchResponse(Errors.NONE, singletonMap(tp, data));
     }
 
+    private OffsetFetchResponse offsetFetchResponse(TopicPartition tp, Errors partitionLevelError, String metadata, long offset, int epoch) {
+        OffsetFetchResponse.PartitionData data = new OffsetFetchResponse.PartitionData(offset,
+                Optional.of(epoch), metadata, partitionLevelError);
+        return new OffsetFetchResponse(Errors.NONE, singletonMap(tp, data));
+    }
+
     private OffsetCommitCallback callback(final AtomicBoolean success) {
         return new OffsetCommitCallback() {
             @Override
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 68a467f..6a0a4f3 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -21,6 +21,7 @@ import org.apache.kafka.clients.ClientDnsLookup;
 import org.apache.kafka.clients.ClientRequest;
 import org.apache.kafka.clients.ClientUtils;
 import org.apache.kafka.clients.FetchSessionHandler;
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.MockClient;
 import org.apache.kafka.clients.NetworkClient;
 import org.apache.kafka.clients.NodeApiVersions;
@@ -66,6 +67,7 @@ import org.apache.kafka.common.record.SimpleRecord;
 import org.apache.kafka.common.record.TimestampType;
 import org.apache.kafka.common.requests.AbstractRequest;
 import org.apache.kafka.common.requests.ApiVersionsResponse;
+import org.apache.kafka.common.requests.EpochEndOffset;
 import org.apache.kafka.common.requests.FetchRequest;
 import org.apache.kafka.common.requests.FetchResponse;
 import org.apache.kafka.common.requests.IsolationLevel;
@@ -73,6 +75,7 @@ import org.apache.kafka.common.requests.ListOffsetRequest;
 import org.apache.kafka.common.requests.ListOffsetResponse;
 import org.apache.kafka.common.requests.MetadataRequest;
 import org.apache.kafka.common.requests.MetadataResponse;
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse;
 import org.apache.kafka.common.requests.ResponseHeader;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.Deserializer;
@@ -115,6 +118,7 @@ import java.util.stream.Collectors;
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonMap;
 import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID;
+import static org.apache.kafka.test.TestUtils.assertOptional;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -206,7 +210,7 @@ public class FetcherTest {
 
         List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(tp0);
         assertEquals(3, records.size());
-        assertEquals(4L, subscriptions.position(tp0).longValue()); // this is the next fetching position
+        assertEquals(4L, subscriptions.position(tp0).offset); // this is the next fetching position
         long offset = 1;
         for (ConsumerRecord<byte[], byte[]> record : records) {
             assertEquals(offset, record.offset());
@@ -370,7 +374,7 @@ public class FetcherTest {
 
         List<ConsumerRecord<byte[], byte[]>> records = partitionRecords.get(tp0);
         assertEquals(1, records.size());
-        assertEquals(2L, subscriptions.position(tp0).longValue());
+        assertEquals(2L, subscriptions.position(tp0).offset);
 
         ConsumerRecord<byte[], byte[]> record = records.get(0);
         assertArrayEquals("key".getBytes(), record.key());
@@ -438,7 +442,7 @@ public class FetcherTest {
                 fail("fetchedRecords should have raised");
             } catch (SerializationException e) {
                 // the position should not advance since no data has been returned
-                assertEquals(1, subscriptions.position(tp0).longValue());
+                assertEquals(1, subscriptions.position(tp0).offset);
             }
         }
     }
@@ -496,7 +500,7 @@ public class FetcherTest {
 
         // the first fetchedRecords() should return the first valid message
         assertEquals(1, fetcher.fetchedRecords().get(tp0).size());
-        assertEquals(1, subscriptions.position(tp0).longValue());
+        assertEquals(1, subscriptions.position(tp0).offset);
 
         ensureBlockOnRecord(1L);
         seekAndConsumeRecord(buffer, 2L);
@@ -518,7 +522,7 @@ public class FetcherTest {
                 fetcher.fetchedRecords();
                 fail("fetchedRecords should have raised KafkaException");
             } catch (KafkaException e) {
-                assertEquals(blockedOffset, subscriptions.position(tp0).longValue());
+                assertEquals(blockedOffset, subscriptions.position(tp0).offset);
             }
         }
     }
@@ -536,7 +540,7 @@ public class FetcherTest {
         List<ConsumerRecord<byte[], byte[]>> records = recordsByPartition.get(tp0);
         assertEquals(1, records.size());
         assertEquals(toOffset, records.get(0).offset());
-        assertEquals(toOffset + 1, subscriptions.position(tp0).longValue());
+        assertEquals(toOffset + 1, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -574,7 +578,7 @@ public class FetcherTest {
                 fetcher.fetchedRecords();
                 fail("fetchedRecords should have raised KafkaException");
             } catch (KafkaException e) {
-                assertEquals(0, subscriptions.position(tp0).longValue());
+                assertEquals(0, subscriptions.position(tp0).offset);
             }
         }
     }
@@ -604,7 +608,7 @@ public class FetcherTest {
             fail("fetchedRecords should have raised");
         } catch (KafkaException e) {
             // the position should not advance since no data has been returned
-            assertEquals(0, subscriptions.position(tp0).longValue());
+            assertEquals(0, subscriptions.position(tp0).offset);
         }
     }
 
@@ -669,7 +673,7 @@ public class FetcherTest {
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
         records = recordsByPartition.get(tp0);
         assertEquals(2, records.size());
-        assertEquals(3L, subscriptions.position(tp0).longValue());
+        assertEquals(3L, subscriptions.position(tp0).offset);
         assertEquals(1, records.get(0).offset());
         assertEquals(2, records.get(1).offset());
 
@@ -678,7 +682,7 @@ public class FetcherTest {
         recordsByPartition = fetchedRecords();
         records = recordsByPartition.get(tp0);
         assertEquals(1, records.size());
-        assertEquals(4L, subscriptions.position(tp0).longValue());
+        assertEquals(4L, subscriptions.position(tp0).offset);
         assertEquals(3, records.get(0).offset());
 
         assertTrue(fetcher.sendFetches() > 0);
@@ -686,7 +690,7 @@ public class FetcherTest {
         recordsByPartition = fetchedRecords();
         records = recordsByPartition.get(tp0);
         assertEquals(2, records.size());
-        assertEquals(6L, subscriptions.position(tp0).longValue());
+        assertEquals(6L, subscriptions.position(tp0).offset);
         assertEquals(4, records.get(0).offset());
         assertEquals(5, records.get(1).offset());
     }
@@ -712,7 +716,7 @@ public class FetcherTest {
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
         records = recordsByPartition.get(tp0);
         assertEquals(2, records.size());
-        assertEquals(3L, subscriptions.position(tp0).longValue());
+        assertEquals(3L, subscriptions.position(tp0).offset);
         assertEquals(1, records.get(0).offset());
         assertEquals(2, records.get(1).offset());
 
@@ -726,7 +730,7 @@ public class FetcherTest {
         assertNull(fetchedRecords.get(tp0));
         records = fetchedRecords.get(tp1);
         assertEquals(2, records.size());
-        assertEquals(6L, subscriptions.position(tp1).longValue());
+        assertEquals(6L, subscriptions.position(tp1).offset);
         assertEquals(4, records.get(0).offset());
         assertEquals(5, records.get(1).offset());
     }
@@ -755,7 +759,7 @@ public class FetcherTest {
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
         consumerRecords = recordsByPartition.get(tp0);
         assertEquals(3, consumerRecords.size());
-        assertEquals(31L, subscriptions.position(tp0).longValue()); // this is the next fetching position
+        assertEquals(31L, subscriptions.position(tp0).offset); // this is the next fetching position
 
         assertEquals(15L, consumerRecords.get(0).offset());
         assertEquals(20L, consumerRecords.get(1).offset());
@@ -780,7 +784,7 @@ public class FetcherTest {
             } catch (RecordTooLargeException e) {
                 assertTrue(e.getMessage().startsWith("There are some messages at [Partition=Offset]: "));
                 // the position should not advance since no data has been returned
-                assertEquals(0, subscriptions.position(tp0).longValue());
+                assertEquals(0, subscriptions.position(tp0).offset);
             }
         } finally {
             client.setNodeApiVersions(NodeApiVersions.create());
@@ -803,7 +807,7 @@ public class FetcherTest {
         } catch (KafkaException e) {
             assertTrue(e.getMessage().startsWith("Failed to make progress reading messages"));
             // the position should not advance since no data has been returned
-            assertEquals(0, subscriptions.position(tp0).longValue());
+            assertEquals(0, subscriptions.position(tp0).offset);
         }
     }
 
@@ -944,15 +948,11 @@ public class FetcherTest {
     public void testEpochSetInFetchRequest() {
         buildFetcher();
         subscriptions.assignFromUser(singleton(tp0));
-        client.updateMetadata(initialUpdateResponse);
-
-        // Metadata update with leader epochs
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap(topicName, 4),
-            (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                    new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(99), replicas, Collections.emptyList(), offlineReplicas));
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1,
+                Collections.emptyMap(), Collections.singletonMap(topicName, 4), tp -> 99);
         client.updateMetadata(metadataResponse);
 
-        subscriptions.seek(tp0, 0);
+        subscriptions.seek(tp0, 10);
         assertEquals(1, fetcher.sendFetches());
 
         // Check for epoch in outgoing request
@@ -961,7 +961,7 @@ public class FetcherTest {
                 FetchRequest fetchRequest = (FetchRequest) body;
                 fetchRequest.fetchData().values().forEach(partitionData -> {
                     assertTrue("Expected Fetcher to set leader epoch in request", partitionData.currentLeaderEpoch.isPresent());
-                    assertEquals("Expected leader epoch to match epoch from metadata update", partitionData.currentLeaderEpoch.get().longValue(), 99);
+                    assertEquals("Expected leader epoch to match epoch from metadata update", 99, partitionData.currentLeaderEpoch.get().longValue());
                 });
                 return true;
             } else {
@@ -984,7 +984,8 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertTrue(subscriptions.isOffsetResetNeeded(tp0));
-        assertEquals(null, subscriptions.position(tp0));
+        assertNull(subscriptions.validPosition(tp0));
+        assertNotNull(subscriptions.position(tp0));
     }
 
     @Test
@@ -1001,7 +1002,7 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         assertEquals(0, fetcher.fetchedRecords().size());
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
-        assertEquals(1, subscriptions.position(tp0).longValue());
+        assertEquals(1, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1065,8 +1066,8 @@ public class FetcherTest {
         List<ConsumerRecord<byte[], byte[]>> allFetchedRecords = new ArrayList<>();
         fetchRecordsInto(allFetchedRecords);
 
-        assertEquals(1, subscriptions.position(tp0).longValue());
-        assertEquals(4, subscriptions.position(tp1).longValue());
+        assertEquals(1, subscriptions.position(tp0).offset);
+        assertEquals(4, subscriptions.position(tp1).offset);
         assertEquals(3, allFetchedRecords.size());
 
         OffsetOutOfRangeException e = assertThrows(OffsetOutOfRangeException.class, () ->
@@ -1075,8 +1076,8 @@ public class FetcherTest {
         assertEquals(singleton(tp0), e.offsetOutOfRangePartitions().keySet());
         assertEquals(1L, e.offsetOutOfRangePartitions().get(tp0).longValue());
 
-        assertEquals(1, subscriptions.position(tp0).longValue());
-        assertEquals(4, subscriptions.position(tp1).longValue());
+        assertEquals(1, subscriptions.position(tp0).offset);
+        assertEquals(4, subscriptions.position(tp1).offset);
         assertEquals(3, allFetchedRecords.size());
     }
 
@@ -1118,8 +1119,8 @@ public class FetcherTest {
         for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
             fetchedRecords.addAll(records);
 
-        assertEquals(fetchedRecords.size(), subscriptions.position(tp1) - 1);
-        assertEquals(4, subscriptions.position(tp1).longValue());
+        assertEquals(fetchedRecords.size(), subscriptions.position(tp1).offset - 1);
+        assertEquals(4, subscriptions.position(tp1).offset);
         assertEquals(3, fetchedRecords.size());
 
         List<OffsetOutOfRangeException> oorExceptions = new ArrayList<>();
@@ -1142,7 +1143,7 @@ public class FetcherTest {
             fetchedRecords.addAll(records);
 
         // Should not have received an Exception for tp2.
-        assertEquals(6, subscriptions.position(tp2).longValue());
+        assertEquals(6, subscriptions.position(tp2).offset);
         assertEquals(5, fetchedRecords.size());
 
         int numExceptionsExpected = 3;
@@ -1206,7 +1207,7 @@ public class FetcherTest {
         // disconnects should have no affect on subscription state
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(0, subscriptions.position(tp0).longValue());
+        assertEquals(0, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1218,7 +1219,7 @@ public class FetcherTest {
         fetcher.resetOffsetsIfNeeded();
         assertFalse(client.hasInFlightRequests());
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1233,7 +1234,7 @@ public class FetcherTest {
         consumerClient.pollNoWakeup();
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1250,7 +1251,7 @@ public class FetcherTest {
         consumerClient.pollNoWakeup();
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     /**
@@ -1290,7 +1291,7 @@ public class FetcherTest {
         assertTrue(subscriptions.hasValidPosition(tp0));
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(subscriptions.position(tp0).longValue(), 5L);
+        assertEquals(subscriptions.position(tp0).offset, 5L);
     }
 
     @Test
@@ -1319,7 +1320,7 @@ public class FetcherTest {
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1346,7 +1347,7 @@ public class FetcherTest {
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1362,7 +1363,7 @@ public class FetcherTest {
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1392,7 +1393,7 @@ public class FetcherTest {
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1428,7 +1429,7 @@ public class FetcherTest {
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1476,7 +1477,7 @@ public class FetcherTest {
 
         assertFalse(client.hasPendingResponses());
         assertFalse(client.hasInFlightRequests());
-        assertEquals(237L, subscriptions.position(tp0).longValue());
+        assertEquals(237L, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1524,7 +1525,7 @@ public class FetcherTest {
 
         assertFalse(client.hasInFlightRequests());
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
-        assertEquals(5L, subscriptions.position(tp0).longValue());
+        assertEquals(5L, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1562,7 +1563,7 @@ public class FetcherTest {
 
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertTrue(subscriptions.isFetchable(tp0));
-        assertEquals(5, subscriptions.position(tp0).longValue());
+        assertEquals(5, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1580,7 +1581,7 @@ public class FetcherTest {
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertFalse(subscriptions.isFetchable(tp0)); // because tp is paused
         assertTrue(subscriptions.hasValidPosition(tp0));
-        assertEquals(10, subscriptions.position(tp0).longValue());
+        assertEquals(10, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -1610,7 +1611,7 @@ public class FetcherTest {
         assertFalse(subscriptions.isOffsetResetNeeded(tp0));
         assertFalse(subscriptions.isFetchable(tp0)); // because tp is paused
         assertTrue(subscriptions.hasValidPosition(tp0));
-        assertEquals(10, subscriptions.position(tp0).longValue());
+        assertEquals(10, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -2197,9 +2198,8 @@ public class FetcherTest {
         client.updateMetadata(initialUpdateResponse);
 
         // Metadata update with leader epochs
-        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap(topicName, 4),
-            (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
-                new MetadataResponse.PartitionMetadata(error, partition, leader, Optional.of(99), replicas, Collections.emptyList(), offlineReplicas));
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1,
+                Collections.emptyMap(), Collections.singletonMap(topicName, 4), tp -> 99);
         client.updateMetadata(metadataResponse);
 
         // Request latest offset
@@ -2571,7 +2571,7 @@ public class FetcherTest {
         }
 
         // The next offset should point to the next batch
-        assertEquals(4L, subscriptions.position(tp0).longValue());
+        assertEquals(4L, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -2602,7 +2602,7 @@ public class FetcherTest {
         assertTrue(allFetchedRecords.isEmpty());
 
         // The next offset should point to the next batch
-        assertEquals(lastOffset + 1, subscriptions.position(tp0).longValue());
+        assertEquals(lastOffset + 1, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -2734,7 +2734,7 @@ public class FetcherTest {
 
         // Ensure that we don't return any of the aborted records, but yet advance the consumer position.
         assertFalse(fetchedRecords.containsKey(tp0));
-        assertEquals(currentOffset, (long) subscriptions.position(tp0));
+        assertEquals(currentOffset, subscriptions.position(tp0).offset);
     }
 
     @Test
@@ -2743,8 +2743,8 @@ public class FetcherTest {
 
         List<ConsumerRecord<byte[], byte[]>> records;
         assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1)));
-        subscriptions.seek(tp0, 0);
-        subscriptions.seek(tp1, 1);
+        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.empty(), metadata.leaderAndEpoch(tp0)));
+        subscriptions.seek(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.leaderAndEpoch(tp1)));
 
         // Fetch some records and establish an incremental fetch session.
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions1 = new LinkedHashMap<>();
@@ -2762,8 +2762,8 @@ public class FetcherTest {
         assertFalse(fetchedRecords.containsKey(tp1));
         records = fetchedRecords.get(tp0);
         assertEquals(2, records.size());
-        assertEquals(3L, subscriptions.position(tp0).longValue());
-        assertEquals(1L, subscriptions.position(tp1).longValue());
+        assertEquals(3L, subscriptions.position(tp0).offset);
+        assertEquals(1L, subscriptions.position(tp1).offset);
         assertEquals(1, records.get(0).offset());
         assertEquals(2, records.get(1).offset());
 
@@ -2774,7 +2774,7 @@ public class FetcherTest {
         records = fetchedRecords.get(tp0);
         assertEquals(1, records.size());
         assertEquals(3, records.get(0).offset());
-        assertEquals(4L, subscriptions.position(tp0).longValue());
+        assertEquals(4L, subscriptions.position(tp0).offset);
 
         // The second response contains no new records.
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions2 = new LinkedHashMap<>();
@@ -2784,8 +2784,8 @@ public class FetcherTest {
         consumerClient.poll(time.timer(0));
         fetchedRecords = fetchedRecords();
         assertTrue(fetchedRecords.isEmpty());
-        assertEquals(4L, subscriptions.position(tp0).longValue());
-        assertEquals(1L, subscriptions.position(tp1).longValue());
+        assertEquals(4L, subscriptions.position(tp0).offset);
+        assertEquals(1L, subscriptions.position(tp1).offset);
 
         // The third response contains some new records for tp0.
         LinkedHashMap<TopicPartition, FetchResponse.PartitionData<MemoryRecords>> partitions3 = new LinkedHashMap<>();
@@ -2799,8 +2799,8 @@ public class FetcherTest {
         assertFalse(fetchedRecords.containsKey(tp1));
         records = fetchedRecords.get(tp0);
         assertEquals(2, records.size());
-        assertEquals(6L, subscriptions.position(tp0).longValue());
-        assertEquals(1L, subscriptions.position(tp1).longValue());
+        assertEquals(6L, subscriptions.position(tp0).offset);
+        assertEquals(1L, subscriptions.position(tp1).offset);
         assertEquals(4, records.get(0).offset());
         assertEquals(5, records.get(1).offset());
     }
@@ -3098,6 +3098,171 @@ public class FetcherTest {
         assertNull(offsetAndTimestampMap.get(tp0));
     }
 
+    @Test
+    public void testSubscriptionPositionUpdatedWithEpoch() {
+        // Create some records that include a leader epoch (1)
+        MemoryRecordsBuilder builder = MemoryRecords.builder(
+                ByteBuffer.allocate(1024),
+                RecordBatch.CURRENT_MAGIC_VALUE,
+                CompressionType.NONE,
+                TimestampType.CREATE_TIME,
+                0L,
+                RecordBatch.NO_TIMESTAMP,
+                RecordBatch.NO_PRODUCER_ID,
+                RecordBatch.NO_PRODUCER_EPOCH,
+                RecordBatch.NO_SEQUENCE,
+                false,
+                1 // leader epoch stamped on the batch
+        );
+        builder.appendWithOffset(0L, 0L, "key".getBytes(), "value-1".getBytes());
+        builder.appendWithOffset(1L, 0L, "key".getBytes(), "value-2".getBytes());
+        builder.appendWithOffset(2L, 0L, "key".getBytes(), "value-3".getBytes());
+        MemoryRecords records = builder.build();
+
+        buildFetcher();
+        assignFromUser(singleton(tp0));
+
+        // Initialize the leader epoch in metadata to 1
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        partitionCounts.put(tp0.topic(), 4);
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> 1);
+        metadata.update(metadataResponse, 0L);
+
+        // Seek to offset 0 (plain offset seek, no explicit epoch supplied)
+        subscriptions.seek(tp0, 0);
+
+        // Do a normal fetch
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        consumerClient.pollNoWakeup();
+        assertTrue(fetcher.hasCompletedFetches());
+
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
+        assertTrue(partitionRecords.containsKey(tp0));
+        // The position should move past the last record (offset 2) and adopt the batch's epoch (1)
+        assertEquals(3L, subscriptions.position(tp0).offset);
+        assertOptional(subscriptions.position(tp0).offsetEpoch, value -> assertEquals(1, value.intValue()));
+    }
+
+    @Test
+    public void testOffsetValidationFencing() {
+        buildFetcher();
+        assignFromUser(singleton(tp0));
+
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        partitionCounts.put(tp0.topic(), 4);
+
+        final int epochOne = 1;
+        final int epochTwo = 2;
+        final int epochThree = 3;
+
+        // Start with metadata, epoch=1
+        metadata.update(TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> epochOne), 0L);
+
+        // Seek with a position and leader+epoch
+        Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.leaderAndEpoch(tp0).leader, Optional.of(epochOne));
+        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch));
+
+        // Update metadata to epoch=2, enter validation
+        metadata.update(TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> epochTwo), 0L);
+        fetcher.validateOffsetsIfNeeded();
+        assertTrue(subscriptions.awaitingValidation(tp0));
+
+        // Validate, advance the position with epoch=2 (as a fetch would), then observe leader epoch=3
+        subscriptions.validate(tp0);
+        SubscriptionState.FetchPosition nextPosition = new SubscriptionState.FetchPosition(
+                10,
+                Optional.of(epochTwo),
+                new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochTwo)));
+        subscriptions.position(tp0, nextPosition);
+        subscriptions.maybeValidatePosition(tp0, new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochThree)));
+
+        // Complete the in-flight epoch=2 validation; it is stale because the leader epoch moved to 3
+        Map<TopicPartition, EpochEndOffset> endOffsetMap = new HashMap<>();
+        endOffsetMap.put(tp0, new EpochEndOffset(Errors.NONE, epochTwo, 10L));
+        OffsetsForLeaderEpochResponse resp = new OffsetsForLeaderEpochResponse(endOffsetMap);
+        client.prepareResponse(resp);
+        consumerClient.pollNoWakeup();
+        assertTrue("Expected validation to fail since leader epoch changed", subscriptions.awaitingValidation(tp0));
+
+        // Next round of validation (against epoch=3) should succeed in validating the position
+        fetcher.validateOffsetsIfNeeded();
+        endOffsetMap.clear();
+        endOffsetMap.put(tp0, new EpochEndOffset(Errors.NONE, epochThree, 10L));
+        resp = new OffsetsForLeaderEpochResponse(endOffsetMap);
+        client.prepareResponse(resp);
+        consumerClient.pollNoWakeup();
+        assertFalse("Expected validation to succeed with latest epoch", subscriptions.awaitingValidation(tp0));
+    }
+
+    @Test
+    public void testTruncationDetected() {
+        // Create some records that include a leader epoch (1)
+        MemoryRecordsBuilder builder = MemoryRecords.builder(
+                ByteBuffer.allocate(1024),
+                RecordBatch.CURRENT_MAGIC_VALUE,
+                CompressionType.NONE,
+                TimestampType.CREATE_TIME,
+                0L,
+                RecordBatch.NO_TIMESTAMP,
+                RecordBatch.NO_PRODUCER_ID,
+                RecordBatch.NO_PRODUCER_EPOCH,
+                RecordBatch.NO_SEQUENCE,
+                false,
+                1 // record epoch is earlier than the leader epoch on the client
+        );
+        builder.appendWithOffset(0L, 0L, "key".getBytes(), "value-1".getBytes());
+        builder.appendWithOffset(1L, 0L, "key".getBytes(), "value-2".getBytes());
+        builder.appendWithOffset(2L, 0L, "key".getBytes(), "value-3".getBytes());
+        MemoryRecords records = builder.build();
+
+        buildFetcher();
+        assignFromUser(singleton(tp0));
+
+        // Initialize the leader epoch in metadata to 2
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        partitionCounts.put(tp0.topic(), 4);
+        MetadataResponse metadataResponse = TestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> 2);
+        metadata.update(metadataResponse, 0L);
+
+        // Seek to offset 0 with epoch 1, which is older than the metadata epoch (2)
+        Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.leaderAndEpoch(tp0).leader, Optional.of(1));
+        subscriptions.seek(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), leaderAndEpoch));
+
+        // Check for truncation, this should cause tp0 to go into validation
+        fetcher.validateOffsetsIfNeeded();
+
+        // No fetches sent since we entered validation
+        assertEquals(0, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+        assertTrue(subscriptions.awaitingValidation(tp0));
+
+        // Prepare OffsetForEpoch response then check that we update the subscription position correctly.
+        Map<TopicPartition, EpochEndOffset> endOffsetMap = new HashMap<>();
+        endOffsetMap.put(tp0, new EpochEndOffset(Errors.NONE, 1, 10L));
+        OffsetsForLeaderEpochResponse resp = new OffsetsForLeaderEpochResponse(endOffsetMap);
+        client.prepareResponse(resp);
+        consumerClient.pollNoWakeup();
+
+        assertFalse(subscriptions.awaitingValidation(tp0));
+
+        // Fetch again, now it works
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0));
+        consumerClient.pollNoWakeup();
+        assertTrue(fetcher.hasCompletedFetches());
+
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
+        assertTrue(partitionRecords.containsKey(tp0));
+        // The position should move past the last record (offset 2) and adopt the batch's epoch (1)
+        assertEquals(3L, subscriptions.position(tp0).offset);
+        assertOptional(subscriptions.position(tp0).offsetEpoch, value -> assertEquals(1, value.intValue()));
+    }
+
     private MockClient.RequestMatcher listOffsetRequestMatcher(final long timestamp) {
         // matches any list offset request with the provided timestamp
         return new MockClient.RequestMatcher() {
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/OffsetForLeaderEpochClientTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/OffsetForLeaderEpochClientTest.java
new file mode 100644
index 0000000..ee00e48
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/OffsetForLeaderEpochClientTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.consumer.internals;
+
+import org.apache.kafka.clients.Metadata;
+import org.apache.kafka.clients.MockClient;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.TopicAuthorizationException;
+import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.requests.EpochEndOffset;
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Unit tests for {@link OffsetsForLeaderEpochClient}: each test sends a request through a
+ * {@link MockClient}-backed {@link ConsumerNetworkClient}, prepares an OffsetsForLeaderEpoch
+ * response, and checks how that response is reflected in the returned {@link RequestFuture}.
+ */
+public class OffsetForLeaderEpochClientTest {
+
+    private ConsumerNetworkClient consumerClient;
+    private SubscriptionState subscriptions;
+    private Metadata metadata;
+    private MockClient client;
+    private Time time;
+
+    private final TopicPartition tp0 = new TopicPartition("topic", 0);
+
+    @Test
+    public void testEmptyResponse() {
+        // No partitions are requested, so an empty response completes with nothing to retry and no offsets.
+        OffsetsForLeaderEpochClient offsetClient = newOffsetClient();
+        RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future =
+                offsetClient.sendAsyncRequest(Node.noNode(), Collections.emptyMap());
+
+        OffsetsForLeaderEpochResponse resp = new OffsetsForLeaderEpochResponse(Collections.emptyMap());
+        client.prepareResponse(resp);
+        consumerClient.pollNoWakeup();
+
+        OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value();
+        assertTrue(result.partitionsToRetry().isEmpty());
+        assertTrue(result.endOffsets().isEmpty());
+    }
+
+    @Test
+    public void testUnexpectedEmptyResponse() {
+        // tp0 is requested but missing from the response, so the result must have partitions to retry.
+        Map<TopicPartition, SubscriptionState.FetchPosition> positionMap = new HashMap<>();
+        positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1),
+                new Metadata.LeaderAndEpoch(Node.noNode(), Optional.of(1))));
+
+        OffsetsForLeaderEpochClient offsetClient = newOffsetClient();
+        RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future =
+                offsetClient.sendAsyncRequest(Node.noNode(), positionMap);
+
+        OffsetsForLeaderEpochResponse resp = new OffsetsForLeaderEpochResponse(Collections.emptyMap());
+        client.prepareResponse(resp);
+        consumerClient.pollNoWakeup();
+
+        OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value();
+        assertFalse(result.partitionsToRetry().isEmpty());
+        assertTrue(result.endOffsets().isEmpty());
+    }
+
+    @Test
+    public void testOkResponse() {
+        Map<TopicPartition, SubscriptionState.FetchPosition> positionMap = new HashMap<>();
+        positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1),
+                new Metadata.LeaderAndEpoch(Node.noNode(), Optional.of(1))));
+
+        OffsetsForLeaderEpochClient offsetClient = newOffsetClient();
+        RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future =
+                offsetClient.sendAsyncRequest(Node.noNode(), positionMap);
+
+        Map<TopicPartition, EpochEndOffset> endOffsetMap = new HashMap<>();
+        endOffsetMap.put(tp0, new EpochEndOffset(Errors.NONE, 1, 10L));
+        client.prepareResponse(new OffsetsForLeaderEpochResponse(endOffsetMap));
+        consumerClient.pollNoWakeup();
+
+        OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value();
+        assertTrue(result.partitionsToRetry().isEmpty());
+        assertTrue(result.endOffsets().containsKey(tp0));
+        assertEquals(Errors.NONE, result.endOffsets().get(tp0).error());
+        assertEquals(1, result.endOffsets().get(tp0).leaderEpoch());
+        assertEquals(10L, result.endOffsets().get(tp0).endOffset());
+    }
+
+    @Test
+    public void testUnauthorizedTopic() {
+        Map<TopicPartition, SubscriptionState.FetchPosition> positionMap = new HashMap<>();
+        positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1),
+                new Metadata.LeaderAndEpoch(Node.noNode(), Optional.of(1))));
+
+        OffsetsForLeaderEpochClient offsetClient = newOffsetClient();
+        RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future =
+                offsetClient.sendAsyncRequest(Node.noNode(), positionMap);
+
+        Map<TopicPartition, EpochEndOffset> endOffsetMap = new HashMap<>();
+        endOffsetMap.put(tp0, new EpochEndOffset(Errors.TOPIC_AUTHORIZATION_FAILED, -1, -1));
+        client.prepareResponse(new OffsetsForLeaderEpochResponse(endOffsetMap));
+        consumerClient.pollNoWakeup();
+
+        // An authorization error fails the whole future with a TopicAuthorizationException naming the topic.
+        assertTrue(future.failed());
+        assertEquals(TopicAuthorizationException.class, future.exception().getClass());
+        assertTrue(((TopicAuthorizationException) future.exception()).unauthorizedTopics().contains(tp0.topic()));
+    }
+
+    @Test
+    public void testRetriableError() {
+        Map<TopicPartition, SubscriptionState.FetchPosition> positionMap = new HashMap<>();
+        positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1),
+                new Metadata.LeaderAndEpoch(Node.noNode(), Optional.of(1))));
+
+        OffsetsForLeaderEpochClient offsetClient = newOffsetClient();
+        RequestFuture<OffsetsForLeaderEpochClient.OffsetForEpochResult> future =
+                offsetClient.sendAsyncRequest(Node.noNode(), positionMap);
+
+        Map<TopicPartition, EpochEndOffset> endOffsetMap = new HashMap<>();
+        endOffsetMap.put(tp0, new EpochEndOffset(Errors.LEADER_NOT_AVAILABLE, -1, -1));
+        client.prepareResponse(new OffsetsForLeaderEpochResponse(endOffsetMap));
+        consumerClient.pollNoWakeup();
+
+        // LEADER_NOT_AVAILABLE does not fail the future; tp0 is listed for retry and has no end offset.
+        assertFalse(future.failed());
+        OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value();
+        assertTrue(result.partitionsToRetry().contains(tp0));
+        assertFalse(result.endOffsets().containsKey(tp0));
+    }
+
+    /** Rebuilds the mock network dependencies and returns a client that sends through them. */
+    private OffsetsForLeaderEpochClient newOffsetClient() {
+        buildDependencies(OffsetResetStrategy.EARLIEST);
+        return new OffsetsForLeaderEpochClient(consumerClient, new LogContext());
+    }
+
+    private void buildDependencies(OffsetResetStrategy offsetResetStrategy) {
+        LogContext logContext = new LogContext();
+        time = new MockTime(1);
+        subscriptions = new SubscriptionState(logContext, offsetResetStrategy);
+        metadata = new ConsumerMetadata(0, Long.MAX_VALUE, false,
+                subscriptions, logContext, new ClusterResourceListeners());
+        client = new MockClient(time, metadata);
+        consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time,
+                100, 1000, Integer.MAX_VALUE);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
index 5d4d113..701866d 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java
@@ -16,8 +16,10 @@
  */
 package org.apache.kafka.clients.consumer.internals;
 
+import org.apache.kafka.clients.Metadata;
 import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.Node;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.Utils;
@@ -27,12 +29,14 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.Optional;
 import java.util.Set;
 import java.util.regex.Pattern;
 
 import static java.util.Collections.singleton;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 
 public class SubscriptionStateTest {
@@ -46,6 +50,7 @@ public class SubscriptionStateTest {
     private final TopicPartition tp1 = new TopicPartition(topic, 1);
     private final TopicPartition t1p0 = new TopicPartition(topic1, 0);
     private final MockRebalanceListener rebalanceListener = new MockRebalanceListener();
+    private final Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(Node.noNode(), Optional.empty());
 
     @Test
     public void partitionAssignment() {
@@ -55,7 +60,7 @@ public class SubscriptionStateTest {
         assertFalse(state.hasAllFetchPositions());
         state.seek(tp0, 1);
         assertTrue(state.isFetchable(tp0));
-        assertEquals(1L, state.position(tp0).longValue());
+        assertEquals(1L, state.position(tp0).offset);
         state.assignFromUser(Collections.<TopicPartition>emptySet());
         assertTrue(state.assignedPartitions().isEmpty());
         assertEquals(0, state.numAssignedPartitions());
@@ -167,11 +172,11 @@ public class SubscriptionStateTest {
     public void partitionReset() {
         state.assignFromUser(singleton(tp0));
         state.seek(tp0, 5);
-        assertEquals(5L, (long) state.position(tp0));
+        assertEquals(5L, state.position(tp0).offset);
         state.requestOffsetReset(tp0);
         assertFalse(state.isFetchable(tp0));
         assertTrue(state.isOffsetResetNeeded(tp0));
-        assertEquals(null, state.position(tp0));
+        assertNotNull(state.position(tp0));
 
         // seek should clear the reset and make the partition fetchable
         state.seek(tp0, 0);
@@ -188,7 +193,7 @@ public class SubscriptionStateTest {
         assertTrue(state.partitionsAutoAssigned());
         assertTrue(state.assignFromSubscribed(singleton(tp0)));
         state.seek(tp0, 1);
-        assertEquals(1L, state.position(tp0).longValue());
+        assertEquals(1L, state.position(tp0).offset);
         assertTrue(state.assignFromSubscribed(singleton(tp1)));
         assertTrue(state.isAssigned(tp1));
         assertFalse(state.isAssigned(tp0));
@@ -212,7 +217,7 @@ public class SubscriptionStateTest {
     public void invalidPositionUpdate() {
         state.subscribe(singleton(topic), rebalanceListener);
         assertTrue(state.assignFromSubscribed(singleton(tp0)));
-        state.position(tp0, 0);
+        state.position(tp0, new SubscriptionState.FetchPosition(0, Optional.empty(), leaderAndEpoch));
     }
 
     @Test
@@ -230,7 +235,7 @@ public class SubscriptionStateTest {
 
     @Test(expected = IllegalStateException.class)
     public void cantChangePositionForNonAssignedPartition() {
-        state.position(tp0, 1);
+        state.position(tp0, new SubscriptionState.FetchPosition(1, Optional.empty(), leaderAndEpoch)); // nothing in this test assigns tp0, so setting its position is expected to throw
     }
 
     @Test(expected = IllegalStateException.class)
diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
index f7a37ba..58d900d 100644
--- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java
@@ -21,6 +21,7 @@ import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.Cluster;
 import org.apache.kafka.common.Node;
 import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.internals.Topic;
 import org.apache.kafka.common.network.NetworkReceive;
 import org.apache.kafka.common.protocol.ApiKeys;
@@ -52,6 +53,8 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -113,20 +116,29 @@ public class TestUtils {
     public static MetadataResponse metadataUpdateWith(final String clusterId,
                                                       final int numNodes,
                                                       final Map<String, Integer> topicPartitionCounts) {
-        return metadataUpdateWith(clusterId, numNodes, Collections.emptyMap(), topicPartitionCounts, MetadataResponse.PartitionMetadata::new);
+        return metadataUpdateWith(clusterId, numNodes, Collections.emptyMap(), topicPartitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new); // tp -> null: every partition gets Optional.empty() as its leader epoch
     }
 
     public static MetadataResponse metadataUpdateWith(final String clusterId,
                                                       final int numNodes,
                                                       final Map<String, Errors> topicErrors,
                                                       final Map<String, Integer> topicPartitionCounts) {
-        return metadataUpdateWith(clusterId, numNodes, topicErrors, topicPartitionCounts, MetadataResponse.PartitionMetadata::new);
+        return metadataUpdateWith(clusterId, numNodes, topicErrors, topicPartitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new); // tp -> null: every partition gets Optional.empty() as its leader epoch
     }
 
     public static MetadataResponse metadataUpdateWith(final String clusterId,
                                                       final int numNodes,
                                                       final Map<String, Errors> topicErrors,
                                                       final Map<String, Integer> topicPartitionCounts,
+                                                      final Function<TopicPartition, Integer> epochSupplier) { // may return null, which becomes Optional.empty() for that partition's leader epoch
+        return metadataUpdateWith(clusterId, numNodes, topicErrors, topicPartitionCounts, epochSupplier, MetadataResponse.PartitionMetadata::new);
+    }
+
+    public static MetadataResponse metadataUpdateWith(final String clusterId,
+                                                      final int numNodes,
+                                                      final Map<String, Errors> topicErrors,
+                                                      final Map<String, Integer> topicPartitionCounts,
+                                                      final Function<TopicPartition, Integer> epochSupplier,
                                                       final PartitionMetadataSupplier partitionSupplier) {
         final List<Node> nodes = new ArrayList<>(numNodes);
         for (int i = 0; i < numNodes; i++)
@@ -139,10 +151,11 @@ public class TestUtils {
 
             List<MetadataResponse.PartitionMetadata> partitionMetadata = new ArrayList<>(numPartitions);
             for (int i = 0; i < numPartitions; i++) {
+                TopicPartition tp = new TopicPartition(topic, i);
                 Node leader = nodes.get(i % nodes.size());
                 List<Node> replicas = Collections.singletonList(leader);
                 partitionMetadata.add(partitionSupplier.supply(
-                        Errors.NONE, i, leader, Optional.empty(), replicas, replicas, replicas));
+                        Errors.NONE, i, leader, Optional.ofNullable(epochSupplier.apply(tp)), replicas, replicas, replicas));
             }
 
             topicMetadata.add(new MetadataResponse.TopicMetadata(Errors.NONE, topic,
@@ -443,4 +456,16 @@ public class TestUtils {
     public static ApiKeys apiKeyFrom(NetworkReceive networkReceive) {
         return RequestHeader.parse(networkReceive.payload().duplicate()).apiKey();
     }
+
+    /**
+     * Calls {@code fail("Missing value from Optional")} when {@code optional} is empty;
+     * otherwise passes the contained value to {@code assertion} so the caller can check it.
+     */
+    public static <T> void assertOptional(Optional<T> optional, Consumer<T> assertion) {
+        if (!optional.isPresent()) {
+            fail("Missing value from Optional");
+        } else {
+            assertion.accept(optional.get());
+        }
+    }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
index 7b896ec..6426168 100644
--- a/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
@@ -21,9 +21,11 @@ import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
 import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.tests.SmokeTestClient;
 import org.apache.kafka.streams.tests.SmokeTestDriver;
+import org.apache.kafka.test.IntegrationTest;
 import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
+import org.junit.experimental.categories.Category;
 
 import java.time.Duration;
 import java.util.ArrayList;
@@ -34,6 +36,7 @@ import java.util.Set;
 import static org.apache.kafka.streams.tests.SmokeTestDriver.generate;
 import static org.apache.kafka.streams.tests.SmokeTestDriver.verify;
 
+@Category(IntegrationTest.class)
 public class SmokeTestDriverIntegrationTest {
     @ClassRule
     public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(3);
diff --git a/tests/kafkatest/services/console_consumer.py b/tests/kafkatest/services/console_consumer.py
index dfbec9f..d85c9dc 100644
--- a/tests/kafkatest/services/console_consumer.py
+++ b/tests/kafkatest/services/console_consumer.py
@@ -62,7 +62,7 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
                  client_id="console-consumer", print_key=False, jmx_object_names=None, jmx_attributes=None,
                  enable_systest_events=False, stop_timeout_sec=35, print_timestamp=False,
                  isolation_level="read_uncommitted", jaas_override_variables=None,
-                 kafka_opts_override="", client_prop_file_override=""):
+                 kafka_opts_override="", client_prop_file_override="", consumer_properties={}):
         """
         Args:
             context:                    standard context
@@ -120,6 +120,7 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
         self.jaas_override_variables = jaas_override_variables or {}
         self.kafka_opts_override = kafka_opts_override
         self.client_prop_file_override = client_prop_file_override
+        self.consumer_properties = consumer_properties
 
 
     def prop_file(self, node):
@@ -205,6 +206,10 @@ class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService)
             assert node.version >= V_0_10_0_0
             cmd += " --enable-systest-events"
 
+        if self.consumer_properties is not None:
+            for k, v in self.consumer_properties.items():
+                cmd += " --consumer-property %s=%s" % (k, v)  # console consumer's per-property flag; the leading space keeps it separate from the previous argument
+
         cmd += " 2>> %(stderr)s | tee -a %(stdout)s &" % args
         return cmd
 
diff --git a/tests/kafkatest/services/kafka/kafka.py b/tests/kafkatest/services/kafka/kafka.py
index 72d84a0..6c0920e 100644
--- a/tests/kafkatest/services/kafka/kafka.py
+++ b/tests/kafkatest/services/kafka/kafka.py
@@ -405,6 +405,25 @@ class KafkaService(KafkaPathResolverMixin, JmxMixin, Service):
         self.logger.info("Running alter message format command...\n%s" % cmd)
         node.account.ssh(cmd)
 
+    def set_unclean_leader_election(self, topic, value=True, node=None):
+        """Set unclean.leader.election.enable on one topic by running kafka-configs.sh over ssh.
+
+        :param topic: name of the topic whose config is altered
+        :param value: config value to write, rendered as str(value).lower(); the log line says
+                      "Enabling" only when value is True and "Disabling" otherwise
+        :param node: node to run the command from; defaults to self.nodes[0]
+        """
+        target = self.nodes[0] if node is None else node
+        if value is True:
+            announcement = "Enabling unclean leader election for topic %s"
+        else:
+            announcement = "Disabling unclean leader election for topic %s"
+        self.logger.info(announcement, topic)
+        cmd = "%s --zookeeper %s --entity-name %s --entity-type topics --alter --add-config unclean.leader.election.enable=%s" % \
+              (self.path.script("kafka-configs.sh", target), self.zk_connect_setting(), topic, str(value).lower())
+        self.logger.info("Running alter unclean leader command...\n%s" % cmd)
+        target.account.ssh(cmd)
+
     def parse_describe_topic(self, topic_description):
         """Parse output of kafka-topics.sh --describe (or describe_topic() method above), which is a string of form
         PartitionCount:2\tReplicationFactor:2\tConfigs:
diff --git a/tests/kafkatest/services/verifiable_consumer.py b/tests/kafkatest/services/verifiable_consumer.py
index 6e5d50e..b9aa6a7 100644
--- a/tests/kafkatest/services/verifiable_consumer.py
+++ b/tests/kafkatest/services/verifiable_consumer.py
@@ -15,7 +15,6 @@
 
 import json
 import os
-import signal
 
 from ducktape.services.background_thread import BackgroundThreadService
 
@@ -34,7 +33,7 @@ class ConsumerState:
 
 class ConsumerEventHandler(object):
 
-    def __init__(self, node):
+    def __init__(self, node, verify_offsets):
         self.node = node
         self.state = ConsumerState.Dead
         self.revoked_count = 0
@@ -43,6 +42,7 @@ class ConsumerEventHandler(object):
         self.position = {}
         self.committed = {}
         self.total_consumed = 0
+        self.verify_offsets = verify_offsets
 
     def handle_shutdown_complete(self):
         self.state = ConsumerState.Dead
@@ -72,7 +72,7 @@ class ConsumerEventHandler(object):
                     (offset, self.position[tp], str(tp))
                 self.committed[tp] = offset
 
-    def handle_records_consumed(self, event):
+    def handle_records_consumed(self, event, logger):
         assert self.state == ConsumerState.Joined, \
             "Consumed records should only be received when joined (current state: %s)" % str(self.state)
 
@@ -85,12 +85,18 @@ class ConsumerEventHandler(object):
             assert tp in self.assignment, \
                 "Consumed records for partition %s which is not assigned (current assignment: %s)" % \
                 (str(tp), str(self.assignment))
-            assert tp not in self.position or self.position[tp] == min_offset, \
-                "Consumed from an unexpected offset (%d, %d) for partition %s" % \
-                (self.position[tp], min_offset, str(tp))
-            self.position[tp] = max_offset + 1 
-
-        self.total_consumed += event["count"]
+            if tp not in self.position or self.position[tp] == min_offset:
+                self.position[tp] = max_offset + 1
+            else:
+                msg = "Consumed from an unexpected offset (%d, %d) for partition %s" % \
+                      (self.position.get(tp), min_offset, str(tp))
+                if self.verify_offsets:
+                    raise AssertionError(msg)
+                else:
+                    if tp in self.position:
+                        self.position[tp] = max_offset + 1
+                    logger.warn(msg)
+        self.total_consumed += event["count"]  # "count" is an event-level field, so add it once per event, not once per partition
 
     def handle_partitions_revoked(self, event):
         self.revoked_count += 1
@@ -162,7 +168,7 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
                  max_messages=-1, session_timeout_sec=30, enable_autocommit=False,
                  assignment_strategy="org.apache.kafka.clients.consumer.RangeAssignor",
                  version=DEV_BRANCH, stop_timeout_sec=30, log_level="INFO", jaas_override_variables=None,
-                 on_record_consumed=None):
+                 on_record_consumed=None, reset_policy="latest", verify_offsets=True):
         """
         :param jaas_override_variables: A dict of variables to be used in the jaas.conf template file
         """
@@ -172,6 +178,7 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
         self.kafka = kafka
         self.topic = topic
         self.group_id = group_id
+        self.reset_policy = reset_policy
         self.max_messages = max_messages
         self.session_timeout_sec = session_timeout_sec
         self.enable_autocommit = enable_autocommit
@@ -179,6 +186,7 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
         self.prop_file = ""
         self.stop_timeout_sec = stop_timeout_sec
         self.on_record_consumed = on_record_consumed
+        self.verify_offsets = verify_offsets
 
         self.event_handlers = {}
         self.global_position = {}
@@ -194,7 +202,7 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
     def _worker(self, idx, node):
         with self.lock:
             if node not in self.event_handlers:
-                self.event_handlers[node] = ConsumerEventHandler(node)
+                self.event_handlers[node] = ConsumerEventHandler(node, self.verify_offsets)
             handler = self.event_handlers[node]
 
         node.account.ssh("mkdir -p %s" % VerifiableConsumer.PERSISTENT_ROOT, allow_fail=False)
@@ -228,7 +236,7 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
                         handler.handle_offsets_committed(event, node, self.logger)
                         self._update_global_committed(event)
                     elif name == "records_consumed":
-                        handler.handle_records_consumed(event)
+                        handler.handle_records_consumed(event, self.logger)
                         self._update_global_position(event, node)
                     elif name == "record_data" and self.on_record_consumed:
                         self.on_record_consumed(event, node)
@@ -244,9 +252,13 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
             tp = TopicPartition(consumed_partition["topic"], consumed_partition["partition"])
             if tp in self.global_committed:
                 # verify that the position never gets behind the current commit.
-                assert self.global_committed[tp] <= consumed_partition["minOffset"], \
-                    "Consumed position %d is behind the current committed offset %d for partition %s" % \
-                    (consumed_partition["minOffset"], self.global_committed[tp], str(tp))
+                if self.global_committed[tp] > consumed_partition["minOffset"]:
+                    msg = "Consumed position %d is behind the current committed offset %d for partition %s" % \
+                          (consumed_partition["minOffset"], self.global_committed[tp], str(tp))
+                    if self.verify_offsets:
+                        raise AssertionError(msg)
+                    else:
+                        self.logger.warn(msg)
 
             # the consumer cannot generally guarantee that the position increases monotonically
             # without gaps in the face of hard failures, so we only log a warning when this happens
@@ -274,8 +286,8 @@ class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, Backgrou
         cmd += self.impl.exec_cmd(node)
         if self.on_record_consumed:
             cmd += " --verbose"
-        cmd += " --group-id %s --topic %s --broker-list %s --session-timeout %s --assignment-strategy %s %s" % \
-               (self.group_id, self.topic, self.kafka.bootstrap_servers(self.security_config.security_protocol),
+        cmd += " --reset-policy %s --group-id %s --topic %s --broker-list %s --session-timeout %s --assignment-strategy %s %s" % \
+               (self.reset_policy, self.group_id, self.topic, self.kafka.bootstrap_servers(self.security_config.security_protocol),
                self.session_timeout_sec*1000, self.assignment_strategy, "--enable-autocommit" if self.enable_autocommit else "")
                
         if self.max_messages > 0:
diff --git a/tests/kafkatest/tests/client/truncation_test.py b/tests/kafkatest/tests/client/truncation_test.py
new file mode 100644
index 0000000..8269de7
--- /dev/null
+++ b/tests/kafkatest/tests/client/truncation_test.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ducktape.mark.resource import cluster
+from ducktape.utils.util import wait_until
+
+from kafkatest.tests.verifiable_consumer_test import VerifiableConsumerTest
+from kafkatest.services.kafka import TopicPartition
+from kafkatest.services.verifiable_consumer import VerifiableConsumer
+
+
+class TruncationTest(VerifiableConsumerTest):
+    """
+    System test for consumer-side log truncation detection (KIP-320).
+
+    A single-partition topic with replication factor 2 is truncated by forcing an unclean
+    leader election; the test checks that the consumer detects the truncation rather than
+    resetting to the beginning of the log.
+    """
+    TOPIC = "test_topic"
+    NUM_PARTITIONS = 1
+    TOPICS = {
+        TOPIC: {
+            'partitions': NUM_PARTITIONS,
+            'replication-factor': 2
+        }
+    }
+    GROUP_ID = "truncation-test"
+
+    def __init__(self, test_context):
+        super(TruncationTest, self).__init__(test_context, num_consumers=1, num_producers=1,
+                                             num_zk=1, num_brokers=3, topics=self.TOPICS)
+        # Consumer total seen by the previous _none_consumed() call; seed it before polling.
+        self.last_total = 0
+        # Offsets / values delivered to consumers created via setup_consumer(), in arrival order.
+        self.all_offsets_consumed = []
+        self.all_values_consumed = []
+
+    def setup_consumer(self, topic, **kwargs):
+        """
+        Create a consumer via the base class and record every record it consumes.
+
+        Extra keyword arguments (e.g. ``reset_policy``, ``verify_offsets``) are forwarded to
+        ``VerifiableConsumerTest.setup_consumer``. The offset and value of each consumed record
+        are appended to ``self.all_offsets_consumed`` and ``self.all_values_consumed``.
+        """
+        consumer = super(TruncationTest, self).setup_consumer(topic, **kwargs)
+        self.mark_for_collect(consumer, 'verifiable_consumer_stdout')
+
+        def record_consumed(event, node):
+            self.all_offsets_consumed.append(event['offset'])
+            self.all_values_consumed.append(event['value'])
+        consumer.on_record_consumed = record_consumed
+
+        return consumer
+
+    def _none_consumed(self, consumer):
+        """
+        Return True if ``consumer`` has consumed nothing since the previous call.
+
+        Intended as a ``wait_until`` predicate for "the consumer has caught up". Each call
+        stores the consumer's current total in ``self.last_total``, so callers must seed
+        ``self.last_total`` with ``consumer.total_consumed()`` before polling.
+        """
+        new_total = consumer.total_consumed()
+        if new_total == self.last_total:
+            return True
+        self.last_total = new_total
+        return False
+
+    @cluster(num_nodes=7)
+    def test_offset_truncate(self):
+        """
+        Verify that the consumer detects log truncation caused by an unclean leader election.
+
+        Setup: single Kafka cluster with one producer writing messages to a single topic with one
+        partition (replication factor 2), and a single consumer reading from that topic.
+
+        - Start a slow producer and a consumer (reset policy "earliest"), and wait for the group.
+        - Stop one ISR member, let the consumer advance on the remaining leader, then stop it too.
+        - Restart the stale replica and enable unclean leader election so that it becomes leader,
+          truncating records the consumer has already seen; then restart the old leader.
+        - Verify the consumer did not reset to the beginning of the log (no duplicate values),
+          and that a fresh consumer group reading from the beginning sees fewer records.
+        """
+        tp = TopicPartition(self.TOPIC, 0)
+
+        producer = self.setup_producer(self.TOPIC, throughput=10)
+        producer.start()
+        self.await_produced_messages(producer, min_messages=10)
+
+        # Offset verification is off: after truncation the committed offset may be ahead of the
+        # consumed position, which VerifiableConsumer would otherwise report as a failure.
+        consumer = self.setup_consumer(self.TOPIC, reset_policy="earliest", verify_offsets=False)
+        consumer.start()
+        self.await_all_members(consumer)
+
+        # Reduce ISR to one node
+        isr = self.kafka.isr_idx_list(self.TOPIC, 0)
+        node1 = self.kafka.get_node(isr[0])
+        self.kafka.stop_node(node1)
+        self.logger.info("Reduced ISR to one node, consumer is at %s", consumer.current_position(tp))
+
+        # Ensure remaining ISR member has a little bit of data
+        current_total = consumer.total_consumed()
+        wait_until(lambda: consumer.total_consumed() > current_total + 10,
+                   timeout_sec=30,
+                   err_msg="Timed out waiting for consumer to move ahead by 10 messages")
+
+        # Kill last ISR member
+        node2 = self.kafka.get_node(isr[1])
+        self.kafka.stop_node(node2)
+        self.logger.info("No members in ISR, consumer is at %s", consumer.current_position(tp))
+
+        # Keep consuming until we've caught up to HW
+        self.last_total = consumer.total_consumed()
+        wait_until(lambda: self._none_consumed(consumer),
+                   timeout_sec=30,
+                   err_msg="Timed out waiting for the consumer to catch up")
+
+        # node1 missed the writes made while it was down, so its log is behind node2's.
+        self.kafka.start_node(node1)
+        self.logger.info("Out of sync replica is online, but not electable. Consumer is at  %s", consumer.current_position(tp))
+
+        pre_truncation_pos = consumer.current_position(tp)
+
+        # Allowing the out-of-sync replica to become leader truncates the records it never received.
+        self.kafka.set_unclean_leader_election(self.TOPIC)
+        self.logger.info("New unclean leader, consumer is at %s", consumer.current_position(tp))
+
+        # Wait for truncation to be detected
+        self.kafka.start_node(node2)
+        wait_until(lambda: consumer.current_position(tp) >= pre_truncation_pos,
+                   timeout_sec=30,
+                   err_msg="Timed out waiting for truncation")
+
+        # Make sure we didn't reset to beginning of log: a reset would redeliver records, which
+        # shows up as duplicate values (assumes the producer never repeats a value).
+        total_records_consumed = len(self.all_values_consumed)
+        assert total_records_consumed == len(set(self.all_values_consumed)), "Received duplicate records"
+
+        consumer.stop()
+        producer.stop()
+
+        # Re-consume all the records from the beginning with a new consumer group
+        consumer2 = VerifiableConsumer(self.test_context, 1, self.kafka, self.TOPIC, group_id="group2",
+                                       reset_policy="earliest", verify_offsets=True)
+
+        consumer2.start()
+        self.await_all_members(consumer2)
+
+        wait_until(lambda: consumer2.total_consumed() > 0,
+                   timeout_sec=30,
+                   err_msg="Timed out waiting for consumer to consume at least one message")
+
+        self.last_total = consumer2.total_consumed()
+        wait_until(lambda: self._none_consumed(consumer2),
+                   timeout_sec=30,
+                   err_msg="Timed out waiting for the consumer to fully consume data")
+
+        second_total_consumed = consumer2.total_consumed()
+        assert second_total_consumed < total_records_consumed, "Expected fewer records with new consumer since we truncated"
+        self.logger.info("Second consumer saw only %s, meaning %s were truncated",
+                         second_total_consumed, total_records_consumed - second_total_consumed)
diff --git a/tests/kafkatest/tests/verifiable_consumer_test.py b/tests/kafkatest/tests/verifiable_consumer_test.py
index 2ba2a61..539a0f3 100644
--- a/tests/kafkatest/tests/verifiable_consumer_test.py
+++ b/tests/kafkatest/tests/verifiable_consumer_test.py
@@ -16,8 +16,6 @@
 from ducktape.utils.util import wait_until
 
 from kafkatest.tests.kafka_test import KafkaTest
-from kafkatest.services.zookeeper import ZookeeperService
-from kafkatest.services.kafka import KafkaService
 from kafkatest.services.verifiable_producer import VerifiableProducer
 from kafkatest.services.verifiable_consumer import VerifiableConsumer
 from kafkatest.services.kafka import TopicPartition
@@ -55,15 +53,16 @@ class VerifiableConsumerTest(KafkaTest):
         """Override this since we're adding services outside of the constructor"""
         return super(VerifiableConsumerTest, self).min_cluster_size() + self.num_consumers + self.num_producers
 
-    def setup_consumer(self, topic, enable_autocommit=False, assignment_strategy="org.apache.kafka.clients.consumer.RangeAssignor"):
+    def setup_consumer(self, topic, enable_autocommit=False,
+                       assignment_strategy="org.apache.kafka.clients.consumer.RangeAssignor", **kwargs):
         return VerifiableConsumer(self.test_context, self.num_consumers, self.kafka,
                                   topic, self.group_id, session_timeout_sec=self.session_timeout_sec,
                                   assignment_strategy=assignment_strategy, enable_autocommit=enable_autocommit,
-                                  log_level="TRACE")
+                                  log_level="TRACE", **kwargs)  # extra kwargs, e.g. reset_policy / verify_offsets
 
-    def setup_producer(self, topic, max_messages=-1):
+    def setup_producer(self, topic, max_messages=-1, throughput=500):  # default matches the old hard-coded 500
         return VerifiableProducer(self.test_context, self.num_producers, self.kafka, topic,
-                                  max_messages=max_messages, throughput=500,
+                                  max_messages=max_messages, throughput=throughput,
                                   request_timeout_sec=self.PRODUCER_REQUEST_TIMEOUT_SEC,
                                   log_level="DEBUG")