You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/05/29 23:09:56 UTC

[kafka] branch 2.3 updated: KAFKA-7703; position() may return a wrong offset after seekToEnd (#6407)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.3
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.3 by this push:
     new b7cf712  KAFKA-7703; position() may return a wrong offset after seekToEnd (#6407)
b7cf712 is described below

commit b7cf712f90d7a9af8718ef361db0299cdc0b0e7e
Author: Viktor Somogyi <vi...@gmail.com>
AuthorDate: Thu May 30 00:59:08 2019 +0200

    KAFKA-7703; position() may return a wrong offset after seekToEnd (#6407)
    
    When poll is called and resets the offsets to the beginning, and is then followed by a seekToEnd and a position call, the "reset to earliest" request issued by poll can override the "reset to latest" request initiated by seekToEnd through a subtle race:
    
    1. Both requests have been issued and have returned to the client side (the listOffsetResponse has been received).
    2. In Fetcher.resetOffsetIfNeeded(TopicPartition, Long, OffsetData), the thread scheduler may run the heartbeat thread first with the "reset to earliest" call, which sets the position in SubscriptionState to the earliest offset.
    3. The thread scheduler then resumes the application thread with the "reset to latest" call, which is discarded because the "reset to earliest" call has already set the position - to the wrong offset.
    4. The blocking position call returns the earliest offset instead of the latest, even though the latest offset was expected.
    
    The fix makes SubscriptionState synchronized so that we can verify that the reset is expected while holding the lock.
    
    Reviewers: Jason Gustafson <ja...@confluent.io>
---
 .../kafka/clients/consumer/KafkaConsumer.java      |  12 +-
 .../kafka/clients/consumer/MockConsumer.java       |   8 +-
 .../consumer/internals/ConsumerCoordinator.java    |   5 +-
 .../kafka/clients/consumer/internals/Fetcher.java  |  34 ++-
 .../consumer/internals/SubscriptionState.java      | 266 ++++++++++++---------
 .../clients/consumer/internals/FetcherTest.java    |  81 ++++++-
 6 files changed, 247 insertions(+), 159 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
index 7f63e8b..9c1a45c 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java
@@ -888,7 +888,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
     public Set<TopicPartition> assignment() {
         acquireAndEnsureOpen();
         try {
-            return Collections.unmodifiableSet(new HashSet<>(this.subscriptions.assignedPartitions()));
+            return Collections.unmodifiableSet(this.subscriptions.assignedPartitions());
         } finally {
             release();
         }
@@ -1605,10 +1605,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         acquireAndEnsureOpen();
         try {
             Collection<TopicPartition> parts = partitions.size() == 0 ? this.subscriptions.assignedPartitions() : partitions;
-            for (TopicPartition tp : parts) {
-                log.info("Seeking to beginning of partition {}", tp);
-                subscriptions.requestOffsetReset(tp, OffsetResetStrategy.EARLIEST);
-            }
+            subscriptions.requestOffsetReset(parts, OffsetResetStrategy.EARLIEST);
         } finally {
             release();
         }
@@ -1633,10 +1630,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
         acquireAndEnsureOpen();
         try {
             Collection<TopicPartition> parts = partitions.size() == 0 ? this.subscriptions.assignedPartitions() : partitions;
-            for (TopicPartition tp : parts) {
-                log.info("Seeking to end of partition {}", tp);
-                subscriptions.requestOffsetReset(tp, OffsetResetStrategy.LATEST);
-            }
+            subscriptions.requestOffsetReset(parts, OffsetResetStrategy.LATEST);
         } finally {
             release();
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
index c8c2e72..660a112 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java
@@ -213,7 +213,7 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
     public synchronized void addRecord(ConsumerRecord<K, V> record) {
         ensureNotClosed();
         TopicPartition tp = new TopicPartition(record.topic(), record.partition());
-        Set<TopicPartition> currentAssigned = new HashSet<>(this.subscriptions.assignedPartitions());
+        Set<TopicPartition> currentAssigned = this.subscriptions.assignedPartitions();
         if (!currentAssigned.contains(tp))
             throw new IllegalStateException("Cannot add records for a partition that is not assigned to the consumer");
         List<ConsumerRecord<K, V>> recs = this.records.computeIfAbsent(tp, k -> new ArrayList<>());
@@ -312,8 +312,7 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
     @Override
     public synchronized void seekToBeginning(Collection<TopicPartition> partitions) {
         ensureNotClosed();
-        for (TopicPartition tp : partitions)
-            subscriptions.requestOffsetReset(tp, OffsetResetStrategy.EARLIEST);
+        subscriptions.requestOffsetReset(partitions, OffsetResetStrategy.EARLIEST);
     }
 
     public synchronized void updateBeginningOffsets(Map<TopicPartition, Long> newOffsets) {
@@ -323,8 +322,7 @@ public class MockConsumer<K, V> implements Consumer<K, V> {
     @Override
     public synchronized void seekToEnd(Collection<TopicPartition> partitions) {
         ensureNotClosed();
-        for (TopicPartition tp : partitions)
-            subscriptions.requestOffsetReset(tp, OffsetResetStrategy.LATEST);
+        subscriptions.requestOffsetReset(partitions, OffsetResetStrategy.LATEST);
     }
 
     // needed for cases where you make a second call to endOffsets
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index b03af74..d88c47b 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -265,7 +265,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
             return;
         }
 
-        Set<TopicPartition> assignedPartitions = new HashSet<>(subscriptions.assignedPartitions());
+        Set<TopicPartition> assignedPartitions = subscriptions.assignedPartitions();
 
         // The leader may have assigned partitions which match our subscription pattern, but which
         // were not explicitly requested, so we update the joined subscription here.
@@ -472,8 +472,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
 
         // execute the user's callback before rebalance
         ConsumerRebalanceListener listener = subscriptions.rebalanceListener();
-        // copy since about to be handed to user code
-        Set<TopicPartition> revoked = new HashSet<>(subscriptions.assignedPartitions());
+        Set<TopicPartition> revoked = subscriptions.assignedPartitions();
         log.info("Revoking previously assigned partitions {}", revoked);
         try {
             listener.onPartitionsRevoked(revoked);
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 59bc14c..d839791 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -420,6 +420,15 @@ public class Fetcher<K, V> implements Closeable {
             return null;
     }
 
+    private OffsetResetStrategy timestampToOffsetResetStrategy(long timestamp) {
+        if (timestamp == ListOffsetRequest.EARLIEST_TIMESTAMP)
+            return OffsetResetStrategy.EARLIEST;
+        else if (timestamp == ListOffsetRequest.LATEST_TIMESTAMP)
+            return OffsetResetStrategy.LATEST;
+        else
+            return null;
+    }
+
     /**
      * Reset offsets for all assigned partitions that require it.
      *
@@ -664,22 +673,11 @@ public class Fetcher<K, V> implements Closeable {
         return emptyList();
     }
 
-    private void resetOffsetIfNeeded(TopicPartition partition, Long requestedResetTimestamp, ListOffsetData offsetData) {
-        // we might lose the assignment while fetching the offset, or the user might seek to a different offset,
-        // so verify it is still assigned and still in need of the requested reset
-        if (!subscriptions.isAssigned(partition)) {
-            log.debug("Skipping reset of partition {} since it is no longer assigned", partition);
-        } else if (!subscriptions.isOffsetResetNeeded(partition)) {
-            log.debug("Skipping reset of partition {} since reset is no longer needed", partition);
-        } else if (!requestedResetTimestamp.equals(offsetResetStrategyTimestamp(partition))) {
-            log.debug("Skipping reset of partition {} since an alternative reset has been requested", partition);
-        } else {
-            SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition(
-                    offsetData.offset, offsetData.leaderEpoch, metadata.leaderAndEpoch(partition));
-            log.info("Resetting offset for partition {} to offset {}.", partition, position);
-            offsetData.leaderEpoch.ifPresent(epoch -> metadata.updateLastSeenEpochIfNewer(partition, epoch));
-            subscriptions.seek(partition, position);
-        }
+    private void resetOffsetIfNeeded(TopicPartition partition, OffsetResetStrategy requestedResetStrategy, ListOffsetData offsetData) {
+        SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition(
+                offsetData.offset, offsetData.leaderEpoch, metadata.leaderAndEpoch(partition));
+        offsetData.leaderEpoch.ifPresent(epoch -> metadata.updateLastSeenEpochIfNewer(partition, epoch));
+        subscriptions.maybeSeek(partition, position.offset, requestedResetStrategy);
     }
 
     private void resetOffsetsAsync(Map<TopicPartition, Long> partitionResetTimestamps) {
@@ -703,7 +701,7 @@ public class Fetcher<K, V> implements Closeable {
                         TopicPartition partition = fetchedOffset.getKey();
                         ListOffsetData offsetData = fetchedOffset.getValue();
                         ListOffsetRequest.PartitionData requestedReset = resetTimestamps.get(partition);
-                        resetOffsetIfNeeded(partition, requestedReset.timestamp, offsetData);
+                        resetOffsetIfNeeded(partition, timestampToOffsetResetStrategy(requestedReset.timestamp), offsetData);
                     }
                 }
 
@@ -1729,7 +1727,7 @@ public class Fetcher<K, V> implements Closeable {
         private void maybeUpdateAssignment(SubscriptionState subscription) {
             int newAssignmentId = subscription.assignmentId();
             if (this.assignmentId != newAssignmentId) {
-                Set<TopicPartition> newAssignedPartitions = new HashSet<>(subscription.assignedPartitions());
+                Set<TopicPartition> newAssignedPartitions = subscription.assignedPartitions();
                 for (TopicPartition tp : this.assignedPartitions) {
                     if (!newAssignedPartitions.contains(tp)) {
                         metrics.removeSensor(partitionLagMetricName(tp));
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
index 69bc6e4..8cba206 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java
@@ -38,7 +38,6 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
 import java.util.regex.Pattern;
@@ -61,10 +60,7 @@ import java.util.stream.Collectors;
  * Note that pause state as well as fetch/consumed positions are not preserved when partition
  * assignment is changed whether directly by the user or through a group rebalance.
  *
- * Thread Safety: this class is generally not thread-safe. It should only be accessed in the
- * consumer's calling thread. The only exception is {@link ConsumerMetadata} which accesses
- * the subscription state needed to build and handle Metadata requests. The thread-safe methods
- * are documented below.
+ * Thread Safety: this class is thread-safe.
  */
 public class SubscriptionState {
     private static final String SUBSCRIPTION_EXCEPTION_MESSAGE =
@@ -77,10 +73,10 @@ public class SubscriptionState {
     }
 
     /* the type of subscription */
-    private volatile SubscriptionType subscriptionType;
+    private SubscriptionType subscriptionType;
 
     /* the pattern user has requested */
-    private volatile Pattern subscribedPattern;
+    private Pattern subscribedPattern;
 
     /* the list of topics the user has requested */
     private Set<String> subscription;
@@ -88,7 +84,7 @@ public class SubscriptionState {
     /* The list of topics the group has subscribed to. This may include some topics which are not part
      * of `subscription` for the leader of a group since it is responsible for detecting metadata changes
      * which require a group rebalance. */
-    private final Set<String> groupSubscription;
+    private Set<String> groupSubscription;
 
     /* the partitions that are currently assigned, note that the order of partition matters (see FetchBuilder for more details) */
     private final PartitionStates<TopicPartitionState> assignment;
@@ -104,9 +100,9 @@ public class SubscriptionState {
     public SubscriptionState(LogContext logContext, OffsetResetStrategy defaultResetStrategy) {
         this.log = logContext.logger(this.getClass());
         this.defaultResetStrategy = defaultResetStrategy;
-        this.subscription = Collections.emptySet();
+        this.subscription = new HashSet<>();
         this.assignment = new PartitionStates<>();
-        this.groupSubscription = ConcurrentHashMap.newKeySet();
+        this.groupSubscription = new HashSet<>();
         this.subscribedPattern = null;
         this.subscriptionType = SubscriptionType.NONE;
     }
@@ -117,7 +113,7 @@ public class SubscriptionState {
      *
      * @return The current assignment Id
      */
-    public int assignmentId() {
+    synchronized int assignmentId() {
         return assignmentId;
     }
 
@@ -134,18 +130,19 @@ public class SubscriptionState {
             throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE);
     }
 
-    public boolean subscribe(Set<String> topics, ConsumerRebalanceListener listener) {
-        if (listener == null)
-            throw new IllegalArgumentException("RebalanceListener cannot be null");
-
+    public synchronized boolean subscribe(Set<String> topics, ConsumerRebalanceListener listener) {
+        registerRebalanceListener(listener);
         setSubscriptionType(SubscriptionType.AUTO_TOPICS);
-
-        this.rebalanceListener = listener;
-
         return changeSubscription(topics);
     }
 
-    public boolean subscribeFromPattern(Set<String> topics) {
+    public synchronized void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
+        registerRebalanceListener(listener);
+        setSubscriptionType(SubscriptionType.AUTO_PATTERN);
+        this.subscribedPattern = pattern;
+    }
+
+    public synchronized boolean subscribeFromPattern(Set<String> topics) {
         if (subscriptionType != SubscriptionType.AUTO_PATTERN)
             throw new IllegalArgumentException("Attempt to subscribe from pattern while subscription type set to " +
                     subscriptionType);
@@ -157,8 +154,9 @@ public class SubscriptionState {
         if (subscription.equals(topicsToSubscribe))
             return false;
 
-        this.subscription = topicsToSubscribe;
-        this.groupSubscription.addAll(topicsToSubscribe);
+        subscription = topicsToSubscribe;
+        groupSubscription = new HashSet<>(groupSubscription);
+        groupSubscription.addAll(topicsToSubscribe);
         return true;
     }
 
@@ -167,17 +165,18 @@ public class SubscriptionState {
      * that it receives metadata updates for all topics that the group is interested in.
      * @param topics The topics to add to the group subscription
      */
-    public boolean groupSubscribe(Collection<String> topics) {
+    synchronized boolean groupSubscribe(Collection<String> topics) {
         if (!partitionsAutoAssigned())
             throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE);
-        return this.groupSubscription.addAll(topics);
+        groupSubscription = new HashSet<>(groupSubscription);
+        return groupSubscription.addAll(topics);
     }
 
     /**
      * Reset the group's subscription to only contain topics subscribed by this consumer.
      */
-    public void resetGroupSubscription() {
-        this.groupSubscription.retainAll(subscription);
+    synchronized void resetGroupSubscription() {
+        groupSubscription = subscription;
     }
 
     /**
@@ -185,7 +184,7 @@ public class SubscriptionState {
      * note this is different from {@link #assignFromSubscribed(Collection)}
      * whose input partitions are provided from the subscribed topics.
      */
-    public boolean assignFromUser(Set<TopicPartition> partitions) {
+    public synchronized boolean assignFromUser(Set<TopicPartition> partitions) {
         setSubscriptionType(SubscriptionType.USER_ASSIGNED);
 
         if (this.assignment.partitionSet().equals(partitions))
@@ -212,29 +211,28 @@ public class SubscriptionState {
      *
      * @return true if assignments matches subscription, otherwise false
      */
-    public boolean assignFromSubscribed(Collection<TopicPartition> assignments) {
+    public synchronized boolean assignFromSubscribed(Collection<TopicPartition> assignments) {
         if (!this.partitionsAutoAssigned())
             throw new IllegalArgumentException("Attempt to dynamically assign partitions while manual assignment in use");
 
-        Predicate<TopicPartition> predicate = topicPartition -> {
+        boolean assignmentMatchedSubscription = true;
+        for (TopicPartition topicPartition : assignments) {
             if (this.subscribedPattern != null) {
-                boolean match = this.subscribedPattern.matcher(topicPartition.topic()).matches();
-                if (!match) {
+                assignmentMatchedSubscription = this.subscribedPattern.matcher(topicPartition.topic()).matches();
+                if (!assignmentMatchedSubscription) {
                     log.info("Assigned partition {} for non-subscribed topic regex pattern; subscription pattern is {}",
                             topicPartition,
                             this.subscribedPattern);
+                    break;
                 }
-                return match;
             } else {
-                boolean match = this.subscription.contains(topicPartition.topic());
-                if (!match) {
+                assignmentMatchedSubscription = this.subscription.contains(topicPartition.topic());
+                if (!assignmentMatchedSubscription) {
                     log.info("Assigned partition {} for non-subscribed topic; subscription is {}", topicPartition, this.subscription);
+                    break;
                 }
-                return match;
             }
-        };
-
-        boolean assignmentMatchedSubscription = assignments.stream().allMatch(predicate);
+        }
 
         if (assignmentMatchedSubscription) {
             Map<TopicPartition, TopicPartitionState> assignedPartitionStates = partitionToStateMap(
@@ -246,31 +244,27 @@ public class SubscriptionState {
         return assignmentMatchedSubscription;
     }
 
-    public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) {
+    private void registerRebalanceListener(ConsumerRebalanceListener listener) {
         if (listener == null)
             throw new IllegalArgumentException("RebalanceListener cannot be null");
-
-        setSubscriptionType(SubscriptionType.AUTO_PATTERN);
-
         this.rebalanceListener = listener;
-        this.subscribedPattern = pattern;
     }
 
     /**
-     * Check whether pattern subscription is in use. This is thread-safe.
+     * Check whether pattern subscription is in use.
      *
      */
-    public boolean hasPatternSubscription() {
+    synchronized boolean hasPatternSubscription() {
         return this.subscriptionType == SubscriptionType.AUTO_PATTERN;
     }
 
-    public boolean hasNoSubscriptionOrUserAssignment() {
+    public synchronized boolean hasNoSubscriptionOrUserAssignment() {
         return this.subscriptionType == SubscriptionType.NONE;
     }
 
-    public void unsubscribe() {
+    public synchronized void unsubscribe() {
         this.subscription = Collections.emptySet();
-        this.groupSubscription.clear();
+        this.groupSubscription = Collections.emptySet();
         this.assignment.clear();
         this.subscribedPattern = null;
         this.subscriptionType = SubscriptionType.NONE;
@@ -280,18 +274,16 @@ public class SubscriptionState {
     /**
      * Check whether a topic matches a subscribed pattern.
      *
-     * This is thread-safe, but it may not always reflect the most recent subscription pattern.
-     *
      * @return true if pattern subscription is in use and the topic matches the subscribed pattern, false otherwise
      */
-    public boolean matchesSubscribedPattern(String topic) {
+    synchronized boolean matchesSubscribedPattern(String topic) {
         Pattern pattern = this.subscribedPattern;
         if (hasPatternSubscription() && pattern != null)
             return pattern.matcher(topic).matches();
         return false;
     }
 
-    public Set<String> subscription() {
+    public synchronized Set<String> subscription() {
         if (partitionsAutoAssigned())
             return this.subscription;
         return Collections.emptySet();
@@ -309,19 +301,14 @@ public class SubscriptionState {
      * can do the partition assignment (which requires at least partition counts for all topics
      * to be assigned).
      *
-     * Note this is thread-safe since the Set is backed by a ConcurrentMap.
-     *
      * @return The union of all subscribed topics in the group if this member is the leader
      *   of the current generation; otherwise it returns the same set as {@link #subscription()}
      */
-    public Set<String> groupSubscription() {
+    synchronized Set<String> groupSubscription() {
         return this.groupSubscription;
     }
 
-    /**
-     * Note this is thread-safe since the Set is backed by a ConcurrentMap.
-     */
-    public boolean isGroupSubscribed(String topic) {
+    synchronized boolean isGroupSubscribed(String topic) {
         return groupSubscription.contains(topic);
     }
 
@@ -332,69 +319,87 @@ public class SubscriptionState {
         return state;
     }
 
-    public void seek(TopicPartition tp, FetchPosition position) {
+    private TopicPartitionState assignedStateOrNull(TopicPartition tp) {
+        return this.assignment.stateValue(tp);
+    }
+
+    public synchronized void seek(TopicPartition tp, FetchPosition position) {
         assignedState(tp).seek(position);
     }
 
-    public void seekAndValidate(TopicPartition tp, FetchPosition position) {
+    public synchronized void seekAndValidate(TopicPartition tp, FetchPosition position) {
         assignedState(tp).seekAndValidate(position);
     }
 
     public void seek(TopicPartition tp, long offset) {
-        seek(tp, new FetchPosition(offset, Optional.empty(), new Metadata.LeaderAndEpoch(Node.noNode(), Optional.empty())));
+        seek(tp, new FetchPosition(offset));
+    }
+
+    synchronized void maybeSeek(TopicPartition tp, long offset, OffsetResetStrategy requestedResetStrategy) {
+        TopicPartitionState state = assignedStateOrNull(tp);
+        if (state == null) {
+            log.debug("Skipping reset of partition {} since it is no longer assigned", tp);
+        } else if (!state.awaitingReset()) {
+            log.debug("Skipping reset of partition {} since reset is no longer needed", tp);
+        } else if (requestedResetStrategy != state.resetStrategy) {
+            log.debug("Skipping reset of partition {} since an alternative reset has been requested", tp);
+        } else {
+            log.info("Resetting offset for partition {} to offset {}.", tp, offset);
+            state.seek(new FetchPosition(offset));
+        }
     }
 
     /**
-     * @return an unmodifiable view of the currently assigned partitions
+     * @return a modifiable copy of the currently assigned partitions
      */
-    public Set<TopicPartition> assignedPartitions() {
-        return this.assignment.partitionSet();
+    public synchronized Set<TopicPartition> assignedPartitions() {
+        return new HashSet<>(this.assignment.partitionSet());
     }
 
     /**
      * Provides the number of assigned partitions in a thread safe manner.
      * @return the number of assigned partitions.
      */
-    public int numAssignedPartitions() {
+    synchronized int numAssignedPartitions() {
         return this.assignment.size();
     }
 
-    public List<TopicPartition> fetchablePartitions(Predicate<TopicPartition> isAvailable) {
+    synchronized List<TopicPartition> fetchablePartitions(Predicate<TopicPartition> isAvailable) {
         return assignment.stream()
                 .filter(tpState -> isAvailable.test(tpState.topicPartition()) && tpState.value().isFetchable())
                 .map(PartitionStates.PartitionState::topicPartition)
                 .collect(Collectors.toList());
     }
 
-    public boolean partitionsAutoAssigned() {
+    synchronized boolean partitionsAutoAssigned() {
         return this.subscriptionType == SubscriptionType.AUTO_TOPICS || this.subscriptionType == SubscriptionType.AUTO_PATTERN;
     }
 
-    public void position(TopicPartition tp, FetchPosition position) {
+    public synchronized void position(TopicPartition tp, FetchPosition position) {
         assignedState(tp).position(position);
     }
 
-    public boolean maybeValidatePosition(TopicPartition tp, Metadata.LeaderAndEpoch leaderAndEpoch) {
+    synchronized boolean maybeValidatePosition(TopicPartition tp, Metadata.LeaderAndEpoch leaderAndEpoch) {
         return assignedState(tp).maybeValidatePosition(leaderAndEpoch);
     }
 
-    public boolean awaitingValidation(TopicPartition tp) {
+    synchronized boolean awaitingValidation(TopicPartition tp) {
         return assignedState(tp).awaitingValidation();
     }
 
-    public void completeValidation(TopicPartition tp) {
+    public synchronized void completeValidation(TopicPartition tp) {
         assignedState(tp).validate();
     }
 
-    public FetchPosition validPosition(TopicPartition tp) {
+    public synchronized FetchPosition validPosition(TopicPartition tp) {
         return assignedState(tp).validPosition();
     }
 
-    public FetchPosition position(TopicPartition tp) {
-        return assignedState(tp).position();
+    synchronized public FetchPosition position(TopicPartition tp) {
+        return assignedState(tp).position;
     }
 
-    public Long partitionLag(TopicPartition tp, IsolationLevel isolationLevel) {
+    synchronized Long partitionLag(TopicPartition tp, IsolationLevel isolationLevel) {
         TopicPartitionState topicPartitionState = assignedState(tp);
         if (isolationLevel == IsolationLevel.READ_COMMITTED)
             return topicPartitionState.lastStableOffset == null ? null : topicPartitionState.lastStableOffset - topicPartitionState.position.offset;
@@ -402,21 +407,21 @@ public class SubscriptionState {
             return topicPartitionState.highWatermark == null ? null : topicPartitionState.highWatermark - topicPartitionState.position.offset;
     }
 
-    public Long partitionLead(TopicPartition tp) {
+    synchronized Long partitionLead(TopicPartition tp) {
         TopicPartitionState topicPartitionState = assignedState(tp);
         return topicPartitionState.logStartOffset == null ? null : topicPartitionState.position.offset - topicPartitionState.logStartOffset;
     }
 
-    public void updateHighWatermark(TopicPartition tp, long highWatermark) {
-        assignedState(tp).highWatermark = highWatermark;
+    synchronized void updateHighWatermark(TopicPartition tp, long highWatermark) {
+        assignedState(tp).highWatermark(highWatermark);
     }
 
-    public void updateLogStartOffset(TopicPartition tp, long logStartOffset) {
-        assignedState(tp).logStartOffset = logStartOffset;
+    synchronized void updateLogStartOffset(TopicPartition tp, long logStartOffset) {
+        assignedState(tp).logStartOffset(logStartOffset);
     }
 
-    public void updateLastStableOffset(TopicPartition tp, long lastStableOffset) {
-        assignedState(tp).lastStableOffset = lastStableOffset;
+    synchronized void updateLastStableOffset(TopicPartition tp, long lastStableOffset) {
+        assignedState(tp).lastStableOffset(lastStableOffset);
     }
 
     /**
@@ -427,7 +432,7 @@ public class SubscriptionState {
      * @param preferredReadReplicaId The preferred read replica
      * @param timeMs The time at which this preferred replica is no longer valid
      */
-    public void updatePreferredReadReplica(TopicPartition tp, int preferredReadReplicaId, Supplier<Long> timeMs) {
+    public synchronized void updatePreferredReadReplica(TopicPartition tp, int preferredReadReplicaId, Supplier<Long> timeMs) {
         assignedState(tp).updatePreferredReadReplica(preferredReadReplicaId, timeMs);
     }
 
@@ -438,7 +443,7 @@ public class SubscriptionState {
      * @param timeMs The current time
      * @return Returns the current preferred read replica, if it has been set and if it has not expired.
      */
-    public Optional<Integer> preferredReadReplica(TopicPartition tp, long timeMs) {
+    public synchronized Optional<Integer> preferredReadReplica(TopicPartition tp, long timeMs) {
         return assignedState(tp).preferredReadReplica(timeMs);
     }
 
@@ -448,11 +453,11 @@ public class SubscriptionState {
      * @param tp The topic partition
      * @return true if the preferred read replica was set, false otherwise.
      */
-    public Optional<Integer> clearPreferredReadReplica(TopicPartition tp) {
+    public synchronized Optional<Integer> clearPreferredReadReplica(TopicPartition tp) {
         return assignedState(tp).clearPreferredReadReplica();
     }
 
-    public Map<TopicPartition, OffsetAndMetadata> allConsumed() {
+    public synchronized Map<TopicPartition, OffsetAndMetadata> allConsumed() {
         Map<TopicPartition, OffsetAndMetadata> allConsumed = new HashMap<>();
         assignment.stream().forEach(state -> {
             TopicPartitionState partitionState = state.value();
@@ -463,41 +468,48 @@ public class SubscriptionState {
         return allConsumed;
     }
 
-    public void requestOffsetReset(TopicPartition partition, OffsetResetStrategy offsetResetStrategy) {
+    public synchronized void requestOffsetReset(TopicPartition partition, OffsetResetStrategy offsetResetStrategy) {
         assignedState(partition).reset(offsetResetStrategy);
     }
 
+    public synchronized void requestOffsetReset(Collection<TopicPartition> partitions, OffsetResetStrategy offsetResetStrategy) {
+        partitions.forEach(tp -> {
+            log.info("Seeking to {} offset of partition {}", offsetResetStrategy, tp);
+            assignedState(tp).reset(offsetResetStrategy);
+        });
+    }
+
     public void requestOffsetReset(TopicPartition partition) {
         requestOffsetReset(partition, defaultResetStrategy);
     }
 
-    public void setNextAllowedRetry(Set<TopicPartition> partitions, long nextAllowResetTimeMs) {
+    synchronized void setNextAllowedRetry(Set<TopicPartition> partitions, long nextAllowResetTimeMs) {
         for (TopicPartition partition : partitions) {
             assignedState(partition).setNextAllowedRetry(nextAllowResetTimeMs);
         }
     }
 
-    public boolean hasDefaultOffsetResetPolicy() {
+    boolean hasDefaultOffsetResetPolicy() {
         return defaultResetStrategy != OffsetResetStrategy.NONE;
     }
 
-    public boolean isOffsetResetNeeded(TopicPartition partition) {
+    public synchronized boolean isOffsetResetNeeded(TopicPartition partition) {
         return assignedState(partition).awaitingReset();
     }
 
-    public OffsetResetStrategy resetStrategy(TopicPartition partition) {
-        return assignedState(partition).resetStrategy;
+    public synchronized OffsetResetStrategy resetStrategy(TopicPartition partition) {
+        return assignedState(partition).resetStrategy();
     }
 
-    public boolean hasAllFetchPositions() {
+    public synchronized boolean hasAllFetchPositions() {
         return assignment.stream().allMatch(state -> state.value().hasValidPosition());
     }
 
-    public Set<TopicPartition> missingFetchPositions() {
+    Set<TopicPartition> missingFetchPositions() {
         return collectPartitions(state -> !state.hasPosition(), Collectors.toSet());
     }
 
-    private <T extends Collection<TopicPartition>> T collectPartitions(Predicate<TopicPartitionState> filter, Collector<TopicPartition, ?, T> collector) {
+    private synchronized <T extends Collection<TopicPartition>> T collectPartitions(Predicate<TopicPartitionState> filter, Collector<TopicPartition, ?, T> collector) {
         return assignment.stream()
                 .filter(state -> filter.test(state.value()))
                 .map(PartitionStates.PartitionState::topicPartition)
@@ -505,7 +517,7 @@ public class SubscriptionState {
     }
 
 
-    public void resetMissingPositions() {
+    public synchronized void resetMissingPositions() {
         final Set<TopicPartition> partitionsWithNoOffsets = new HashSet<>();
         assignment.stream().forEach(state -> {
             TopicPartition tp = state.topicPartition();
@@ -514,7 +526,7 @@ public class SubscriptionState {
                 if (defaultResetStrategy == OffsetResetStrategy.NONE)
                     partitionsWithNoOffsets.add(tp);
                 else
-                    partitionState.reset(defaultResetStrategy);
+                    requestOffsetReset(tp);
             }
         });
 
@@ -522,50 +534,53 @@ public class SubscriptionState {
             throw new NoOffsetForPartitionException(partitionsWithNoOffsets);
     }
 
-    public Set<TopicPartition> partitionsNeedingReset(long nowMs) {
+    Set<TopicPartition> partitionsNeedingReset(long nowMs) {
         return collectPartitions(state -> state.awaitingReset() && !state.awaitingRetryBackoff(nowMs),
                 Collectors.toSet());
     }
 
-    public Set<TopicPartition> partitionsNeedingValidation(long nowMs) {
+    Set<TopicPartition> partitionsNeedingValidation(long nowMs) {
         return collectPartitions(state -> state.awaitingValidation() && !state.awaitingRetryBackoff(nowMs),
                 Collectors.toSet());
     }
 
-    public boolean isAssigned(TopicPartition tp) {
+    public synchronized boolean isAssigned(TopicPartition tp) {
         return assignment.contains(tp);
     }
 
-    public boolean isPaused(TopicPartition tp) {
-        return isAssigned(tp) && assignedState(tp).paused;
+    public synchronized boolean isPaused(TopicPartition tp) {
+        TopicPartitionState assignedOrNull = assignedStateOrNull(tp);
+        return assignedOrNull != null && assignedOrNull.isPaused();
     }
 
-    public boolean isFetchable(TopicPartition tp) {
-        return isAssigned(tp) && assignedState(tp).isFetchable();
+    synchronized boolean isFetchable(TopicPartition tp) {
+        TopicPartitionState assignedOrNull = assignedStateOrNull(tp);
+        return assignedOrNull != null && assignedOrNull.isFetchable();
     }
 
-    public boolean hasValidPosition(TopicPartition tp) {
-        return isAssigned(tp) && assignedState(tp).hasValidPosition();
+    public synchronized boolean hasValidPosition(TopicPartition tp) {
+        TopicPartitionState assignedOrNull = assignedStateOrNull(tp);
+        return assignedOrNull != null && assignedOrNull.hasValidPosition();
     }
 
-    public void pause(TopicPartition tp) {
+    public synchronized void pause(TopicPartition tp) {
         assignedState(tp).pause();
     }
 
-    public void resume(TopicPartition tp) {
+    public synchronized void resume(TopicPartition tp) {
         assignedState(tp).resume();
     }
 
-    public void requestFailed(Set<TopicPartition> partitions, long nextRetryTimeMs) {
+    synchronized void requestFailed(Set<TopicPartition> partitions, long nextRetryTimeMs) {
         for (TopicPartition partition : partitions)
             assignedState(partition).requestFailed(nextRetryTimeMs);
     }
 
-    public void movePartitionToEnd(TopicPartition tp) {
+    synchronized void movePartitionToEnd(TopicPartition tp) {
         assignment.moveToEnd(tp);
     }
 
-    public ConsumerRebalanceListener rebalanceListener() {
+    public synchronized ConsumerRebalanceListener rebalanceListener() {
         return rebalanceListener;
     }
 
@@ -747,10 +762,6 @@ public class SubscriptionState {
             }
         }
 
-        private FetchPosition position() {
-            return position;
-        }
-
         private void pause() {
             this.paused = true;
         }
@@ -763,6 +774,21 @@ public class SubscriptionState {
             return !paused && hasValidPosition();
         }
 
+        private void highWatermark(Long highWatermark) {
+            this.highWatermark = highWatermark;
+        }
+
+        private void logStartOffset(Long logStartOffset) {
+            this.logStartOffset = logStartOffset;
+        }
+
+        private void lastStableOffset(Long lastStableOffset) {
+            this.lastStableOffset = lastStableOffset;
+        }
+
+        private OffsetResetStrategy resetStrategy() {
+            return resetStrategy;
+        }
     }
 
     /**
@@ -869,8 +895,12 @@ public class SubscriptionState {
      */
     public static class FetchPosition {
         public final long offset;
-        public final Optional<Integer> offsetEpoch;
-        public final Metadata.LeaderAndEpoch currentLeader;
+        final Optional<Integer> offsetEpoch;
+        final Metadata.LeaderAndEpoch currentLeader;
+
+        FetchPosition(long offset) {
+            this(offset, Optional.empty(), new Metadata.LeaderAndEpoch(Node.noNode(), Optional.empty()));
+        }
 
         public FetchPosition(long offset, Optional<Integer> offsetEpoch, Metadata.LeaderAndEpoch currentLeader) {
             this.offset = offset;
@@ -882,7 +912,7 @@ public class SubscriptionState {
          * Test if it is "safe" to fetch from a given leader and epoch. This effectively is testing if
          * {@link Metadata.LeaderAndEpoch} known to the subscription is equal to the one supplied by the caller.
          */
-        public boolean safeToFetchFrom(Metadata.LeaderAndEpoch leaderAndEpoch) {
+        boolean safeToFetchFrom(Metadata.LeaderAndEpoch leaderAndEpoch) {
             return !currentLeader.leader.isEmpty() && currentLeader.equals(leaderAndEpoch);
         }
 
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 754b46e..1e64e3b 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -109,6 +109,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
@@ -130,6 +131,8 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.spy;
 
 @SuppressWarnings("deprecation")
 public class FetcherTest {
@@ -1483,6 +1486,51 @@ public class FetcherTest {
         assertEquals(237L, subscriptions.position(tp0).offset);
     }
 
+    @Test(timeout = 10000)
+    public void testEarlierOffsetResetArrivesLate() throws InterruptedException {
+        LogContext lc = new LogContext();
+        buildFetcher(spy(new SubscriptionState(lc, OffsetResetStrategy.EARLIEST)), lc);
+        assignFromUser(singleton(tp0));
+
+        ExecutorService es = Executors.newSingleThreadExecutor();
+        CountDownLatch latchLatestStart = new CountDownLatch(1);
+        CountDownLatch latchEarliestStart = new CountDownLatch(1);
+        CountDownLatch latchEarliestDone = new CountDownLatch(1);
+        CountDownLatch latchEarliestFinish = new CountDownLatch(1);
+        try {
+            doAnswer(invocation -> {
+                latchLatestStart.countDown();
+                latchEarliestStart.await();
+                Object result = invocation.callRealMethod();
+                latchEarliestDone.countDown();
+                return result;
+            }).when(subscriptions).maybeSeek(tp0, 0L, OffsetResetStrategy.EARLIEST);
+
+            es.submit(() -> {
+                subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST);
+                fetcher.resetOffsetsIfNeeded();
+                consumerClient.pollNoWakeup();
+                client.respond(listOffsetResponse(Errors.NONE, 1L, 0L));
+                consumerClient.pollNoWakeup();
+                latchEarliestFinish.countDown();
+            }, Void.class);
+
+            latchLatestStart.await();
+            subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST);
+            fetcher.resetOffsetsIfNeeded();
+            consumerClient.pollNoWakeup();
+            client.respond(listOffsetResponse(Errors.NONE, 1L, 10L));
+            latchEarliestStart.countDown();
+            latchEarliestDone.await();
+            consumerClient.pollNoWakeup();
+            latchEarliestFinish.await();
+            assertEquals(10, subscriptions.position(tp0).offset);
+        } finally {
+            es.shutdown();
+            es.awaitTermination(10000, TimeUnit.MILLISECONDS);
+        }
+    }
+
     @Test
     public void testChangeResetWithInFlightReset() {
         buildFetcher();
@@ -2816,7 +2864,8 @@ public class FetcherTest {
         for (int i = 0; i < numPartitions; i++)
             topicPartitions.add(new TopicPartition(topicName, i));
 
-        buildDependencies(new MetricConfig(), OffsetResetStrategy.EARLIEST, Long.MAX_VALUE);
+        LogContext logContext = new LogContext();
+        buildDependencies(new MetricConfig(), Long.MAX_VALUE, new SubscriptionState(logContext, OffsetResetStrategy.EARLIEST), logContext);
 
         fetcher = new Fetcher<byte[], byte[]>(
                 new LogContext(),
@@ -3616,7 +3665,26 @@ public class FetcherTest {
                                      int maxPollRecords,
                                      IsolationLevel isolationLevel,
                                      long metadataExpireMs) {
-        buildDependencies(metricConfig, offsetResetStrategy, metadataExpireMs);
+        LogContext logContext = new LogContext();
+        SubscriptionState subscriptionState = new SubscriptionState(logContext, offsetResetStrategy);
+        buildFetcher(metricConfig, keyDeserializer, valueDeserializer, maxPollRecords, isolationLevel, metadataExpireMs,
+                subscriptionState, logContext);
+    }
+
+    private void buildFetcher(SubscriptionState subscriptionState, LogContext logContext) {
+        buildFetcher(new MetricConfig(), new ByteArrayDeserializer(), new ByteArrayDeserializer(), Integer.MAX_VALUE,
+                IsolationLevel.READ_UNCOMMITTED, Long.MAX_VALUE, subscriptionState, logContext);
+    }
+
+    private <K, V> void buildFetcher(MetricConfig metricConfig,
+                                     Deserializer<K> keyDeserializer,
+                                     Deserializer<V> valueDeserializer,
+                                     int maxPollRecords,
+                                     IsolationLevel isolationLevel,
+                                     long metadataExpireMs,
+                                     SubscriptionState subscriptionState,
+                                     LogContext logContext) {
+        buildDependencies(metricConfig, metadataExpireMs, subscriptionState, logContext);
         fetcher = new Fetcher<>(
                 new LogContext(),
                 consumerClient,
@@ -3640,11 +3708,12 @@ public class FetcherTest {
                 apiVersions);
     }
 
-
-    private void buildDependencies(MetricConfig metricConfig, OffsetResetStrategy offsetResetStrategy, long metadataExpireMs) {
-        LogContext logContext = new LogContext();
+    private void buildDependencies(MetricConfig metricConfig,
+                                   long metadataExpireMs,
+                                   SubscriptionState subscriptionState,
+                                   LogContext logContext) {
         time = new MockTime(1);
-        subscriptions = new SubscriptionState(logContext, offsetResetStrategy);
+        subscriptions = subscriptionState;
         metadata = new ConsumerMetadata(0, metadataExpireMs, false, false,
                 subscriptions, logContext, new ClusterResourceListeners());
         client = new MockClient(time, metadata);