Posted to commits@kafka.apache.org by da...@apache.org on 2022/12/12 09:00:38 UTC

[kafka] branch 3.4 updated: KAFKA-14379: Consumer should refresh preferred read replica on update metadata (#12956)

This is an automated email from the ASF dual-hosted git repository.

dajac pushed a commit to branch 3.4
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.4 by this push:
     new b296e167cab KAFKA-14379: Consumer should refresh preferred read replica on update metadata (#12956)
b296e167cab is described below

commit b296e167cabd7484232b9e3b50fbd41eacbd9c22
Author: Artem Livshits <84...@users.noreply.github.com>
AuthorDate: Mon Dec 12 00:55:22 2022 -0800

    KAFKA-14379: Consumer should refresh preferred read replica on update metadata (#12956)
    
    The consumer (fetcher) used to refresh the preferred read replica only
    under three conditions:
    
    1. the consumer receives an OFFSET_OUT_OF_RANGE error
    2. the follower does not exist in the client's metadata (i.e., offline)
    3. after metadata.max.age.ms (5 min default)
    
    For other errors, it would continue to reach out to the possibly
    unavailable follower, and only after 5 minutes would it refresh the
    preferred read replica and go back to the leader.
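
    For context only (not part of this patch), a minimal sketch of a
    consumer that opts into fetch-from-follower via client.rack; the class
    name, addresses, group id, and rack id below are placeholders:

        import java.util.Properties;
        import org.apache.kafka.clients.consumer.ConsumerConfig;
        import org.apache.kafka.clients.consumer.KafkaConsumer;
        import org.apache.kafka.common.serialization.ByteArrayDeserializer;

        public class FollowerFetchConsumerExample {
            public static void main(String[] args) {
                Properties props = new Properties();
                // Placeholder bootstrap address and group id.
                props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
                props.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group");
                // client.rack opts the consumer into preferred read replica
                // (fetch-from-follower) selection by the leader.
                props.put(ConsumerConfig.CLIENT_RACK_CONFIG, "rack-a");
                // Before this fix, a stale preferred read replica could be retried
                // for up to metadata.max.age.ms (default 5 minutes).
                props.put(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "300000");
                try (KafkaConsumer<byte[], byte[]> consumer =
                         new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer())) {
                    // Subscribe and poll as usual; fetches may be served by the
                    // preferred read replica chosen by the leader.
                }
            }
        }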
    
    Another problem is that the client might have stale metadata and not
    send fetches to the preferred replica, even after the leader redirects
    it to the preferred replica.
    
    A specific example is a partition reassignment: the consumer gets
    NOT_LEADER_OR_FOLLOWER, which triggers a metadata update, but the
    preferred read replica is not refreshed because the follower is still
    online. The consumer keeps reaching out to the old follower until the
    preferred read replica expires.
    
    The consumer can instead refresh its preferred read replica whenever it
    makes a metadata update request, so that when it receives, for example,
    NOT_LEADER_OR_FOLLOWER it can find the new preferred read replica
    without waiting for the expiration.
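
    A condensed sketch of that idea; the actual change is the
    requestMetadataUpdate helper added to Fetcher.java in the diff below:

        // Condensed from the Fetcher.java hunks below: every fetch-error-driven
        // metadata refresh also drops the cached preferred read replica.
        private void requestMetadataUpdate(TopicPartition topicPartition) {
            // Ask for fresh metadata so a new preferred read replica can be
            // learned from the leader ...
            this.metadata.requestUpdate();
            // ... and stop fetching from the possibly stale follower right away,
            // instead of waiting up to metadata.max.age.ms for the cache to expire.
            this.subscriptions.clearPreferredReadReplica(topicPartition);
        }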
    
    Generally, we rely on the leader to choose the correct preferred read
    replica and have the consumer fail fast on errors (clear the preferred
    read replica cache) and go back to the leader.
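
    Two supporting changes in the diff back this up: on a fetch disconnect
    the preferred read replica is cleared for every partition in the failed
    fetch session, and Cluster#nodeIfOnline now also requires that the node
    is still in the partition's replica set (condensed sketch below, same
    behavior as the Cluster.java hunk):

        public Optional<Node> nodeIfOnline(TopicPartition partition, int id) {
            Node node = nodeById(id);
            PartitionInfo partitionInfo = partition(partition);
            // A node only qualifies as an online read replica if it is known,
            // not marked offline, and still part of the partition's replica set
            // (e.g., not removed by a reassignment).
            if (node != null && partitionInfo != null &&
                !Arrays.asList(partitionInfo.offlineReplicas()).contains(node) &&
                Arrays.asList(partitionInfo.replicas()).contains(node)) {
                return Optional.of(node);
            }
            return Optional.empty();
        }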
    
    Co-authored-by: Jeff Kim <je...@confluent.io>
    
    Reviewers: David Jacot <dj...@confluent.io>, Jason Gustafson <ja...@confluent.io>
---
 .../apache/kafka/clients/FetchSessionHandler.java  |  7 ++
 .../kafka/clients/consumer/internals/Fetcher.java  | 18 ++--
 .../main/java/org/apache/kafka/common/Cluster.java |  6 +-
 .../org/apache/kafka/clients/MetadataTest.java     | 20 +++++
 .../clients/consumer/internals/FetcherTest.java    | 85 ++++++++++++++++---
 .../kafka/common/requests/RequestTestUtils.java    | 29 ++++++-
 .../server/FetchFromFollowerIntegrationTest.scala  | 97 ++++++++++++++++------
 7 files changed, 218 insertions(+), 44 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
index 739a47631e0..44354947539 100644
--- a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
+++ b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java
@@ -602,4 +602,11 @@ public class FetchSessionHandler {
         log.info("Error sending fetch request {} to node {}:", nextMetadata, node, t);
         nextMetadata = nextMetadata.nextCloseExisting();
     }
+
+    /**
+     * Get the fetch request session's partitions.
+     */
+    public Set<TopicPartition> sessionTopicPartitions() {
+        return sessionPartitions.keySet();
+    }
 }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
index 73ffd217efe..c93b675f755 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java
@@ -346,6 +346,7 @@ public class Fetcher<K, V> implements Closeable {
                             FetchSessionHandler handler = sessionHandler(fetchTarget.id());
                             if (handler != null) {
                                 handler.handleError(e);
+                                handler.sessionTopicPartitions().forEach(subscriptions::clearPreferredReadReplica);
                             }
                         } finally {
                             nodesWithPendingFetchRequests.remove(fetchTarget.id());
@@ -1154,7 +1155,9 @@ public class Fetcher<K, V> implements Closeable {
             } else {
                 log.trace("Not fetching from {} for partition {} since it is marked offline or is missing from our metadata," +
                           " using the leader instead.", nodeId, partition);
-                subscriptions.clearPreferredReadReplica(partition);
+                // Note that this condition may happen due to stale metadata, so we clear preferred replica and
+                // refresh metadata.
+                requestMetadataUpdate(partition);
                 return leaderReplica;
             }
         } else {
@@ -1335,16 +1338,16 @@ public class Fetcher<K, V> implements Closeable {
                        error == Errors.FENCED_LEADER_EPOCH ||
                        error == Errors.OFFSET_NOT_AVAILABLE) {
                 log.debug("Error in fetch for partition {}: {}", tp, error.exceptionName());
-                this.metadata.requestUpdate();
+                requestMetadataUpdate(tp);
             } else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) {
                 log.warn("Received unknown topic or partition error in fetch for partition {}", tp);
-                this.metadata.requestUpdate();
+                requestMetadataUpdate(tp);
             } else if (error == Errors.UNKNOWN_TOPIC_ID) {
                 log.warn("Received unknown topic ID error in fetch for partition {}", tp);
-                this.metadata.requestUpdate();
+                requestMetadataUpdate(tp);
             } else if (error == Errors.INCONSISTENT_TOPIC_ID) {
                 log.warn("Received inconsistent topic ID error in fetch for partition {}", tp);
-                this.metadata.requestUpdate();
+                requestMetadataUpdate(tp);
             } else if (error == Errors.OFFSET_OUT_OF_RANGE) {
                 Optional<Integer> clearedReplicaId = subscriptions.clearPreferredReadReplica(tp);
                 if (!clearedReplicaId.isPresent()) {
@@ -1944,4 +1947,9 @@ public class Fetcher<K, V> implements Closeable {
         return partitions.stream().map(TopicPartition::topic).collect(Collectors.toSet());
     }
 
+    private void requestMetadataUpdate(TopicPartition topicPartition) {
+        this.metadata.requestUpdate();
+        this.subscriptions.clearPreferredReadReplica(topicPartition);
+    }
+
 }
diff --git a/clients/src/main/java/org/apache/kafka/common/Cluster.java b/clients/src/main/java/org/apache/kafka/common/Cluster.java
index 96e310df9a5..84b77ef5f40 100644
--- a/clients/src/main/java/org/apache/kafka/common/Cluster.java
+++ b/clients/src/main/java/org/apache/kafka/common/Cluster.java
@@ -253,7 +253,11 @@ public final class Cluster {
     public Optional<Node> nodeIfOnline(TopicPartition partition, int id) {
         Node node = nodeById(id);
         PartitionInfo partitionInfo = partition(partition);
-        if (node != null && partitionInfo != null && !Arrays.asList(partitionInfo.offlineReplicas()).contains(node)) {
+
+        if (node != null && partitionInfo != null &&
+            !Arrays.asList(partitionInfo.offlineReplicas()).contains(node) &&
+            Arrays.asList(partitionInfo.replicas()).contains(node)) {
+
             return Optional.of(node);
         } else {
             return Optional.empty();
diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
index e900aa44e78..30e7c3ab186 100644
--- a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java
@@ -791,6 +791,26 @@ public class MetadataTest {
         assertEquals(metadata.fetch().nodeById(1).id(), 1);
     }
 
+    @Test
+    public void testNodeIfOnlineWhenNotInReplicaSet() {
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        partitionCounts.put("topic-1", 1);
+        Node node0 = new Node(0, "localhost", 9092);
+
+        MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 2, Collections.emptyMap(), partitionCounts, _tp -> 99,
+            (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) ->
+                new MetadataResponse.PartitionMetadata(error, partition, Optional.of(node0.id()), leaderEpoch,
+                    Collections.singletonList(node0.id()), Collections.emptyList(),
+                        Collections.emptyList()), ApiKeys.METADATA.latestVersion(), Collections.emptyMap());
+        metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 0L);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L);
+
+        TopicPartition tp = new TopicPartition("topic-1", 0);
+
+        assertEquals(1, metadata.fetch().nodeById(1).id());
+        assertFalse(metadata.fetch().nodeIfOnline(tp, 1).isPresent());
+    }
+
     @Test
     public void testNodeIfOnlineNonExistentTopicPartition() {
         MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith(2, Collections.emptyMap());
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
index 8c45839283b..0e14355ecc8 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java
@@ -4692,12 +4692,12 @@ public class FetcherTest {
                 Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED, Duration.ofMinutes(5).toMillis());
 
         subscriptions.assignFromUser(singleton(tp0));
-        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds, false));
         subscriptions.seek(tp0, 0);
 
         // Node preferred replica before first fetch response
         Node selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
-        assertEquals(selected.id(), -1);
+        assertEquals(-1, selected.id());
 
         assertEquals(1, fetcher.sendFetches());
         assertFalse(fetcher.hasCompletedFetches());
@@ -4711,9 +4711,9 @@ public class FetcherTest {
         Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
         assertTrue(partitionRecords.containsKey(tp0));
 
-        // verify
+        // Verify
         selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
-        assertEquals(selected.id(), 1);
+        assertEquals(1, selected.id());
 
 
         assertEquals(1, fetcher.sendFetches());
@@ -4726,7 +4726,75 @@ public class FetcherTest {
         assertTrue(fetcher.hasCompletedFetches());
         fetchedRecords();
         selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
-        assertEquals(selected.id(), -1);
+        assertEquals(-1, selected.id());
+    }
+
+    @Test
+    public void testFetchDisconnectedShouldClearPreferredReadReplica() {
+        buildFetcher(new MetricConfig(), OffsetResetStrategy.EARLIEST, new BytesDeserializer(), new BytesDeserializer(),
+                Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED, Duration.ofMinutes(5).toMillis());
+
+        subscriptions.assignFromUser(singleton(tp0));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds, false));
+        subscriptions.seek(tp0, 0);
+        assertEquals(1, fetcher.sendFetches());
+
+        // Set preferred read replica to node=1
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L,
+                FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1)));
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+
+        // Verify
+        Node selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
+        assertEquals(1, selected.id());
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        // Disconnect - preferred read replica should be cleared.
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0), true);
+
+        consumerClient.poll(time.timer(0));
+        assertFalse(fetcher.hasCompletedFetches());
+        fetchedRecords();
+        selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
+        assertEquals(-1, selected.id());
+    }
+
+    @Test
+    public void testFetchErrorShouldClearPreferredReadReplica() {
+        buildFetcher(new MetricConfig(), OffsetResetStrategy.EARLIEST, new BytesDeserializer(), new BytesDeserializer(),
+                Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED, Duration.ofMinutes(5).toMillis());
+
+        subscriptions.assignFromUser(singleton(tp0));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds, false));
+        subscriptions.seek(tp0, 0);
+        assertEquals(1, fetcher.sendFetches());
+
+        // Set preferred read replica to node=1
+        client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L,
+                FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1)));
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+
+        // Verify
+        Node selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
+        assertEquals(1, selected.id());
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        // Error - preferred read replica should be cleared. An actual error response will contain -1 as the
+        // preferred read replica. In the test we want to ensure that we are handling the error.
+        client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.EMPTY, Errors.NOT_LEADER_OR_FOLLOWER, -1L,
+                FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1)));
+
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+        selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds());
+        assertEquals(-1, selected.id());
     }
 
     @Test
@@ -4735,7 +4803,7 @@ public class FetcherTest {
                 Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED, Duration.ofMinutes(5).toMillis());
 
         subscriptions.assignFromUser(singleton(tp0));
-        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds));
+        client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds, false));
 
         subscriptions.seek(tp0, 0);
 
@@ -5157,11 +5225,6 @@ public class FetcherTest {
                 subscriptionState, logContext);
     }
 
-    private void buildFetcher(SubscriptionState subscriptionState, LogContext logContext) {
-        buildFetcher(new MetricConfig(), new ByteArrayDeserializer(), new ByteArrayDeserializer(), Integer.MAX_VALUE,
-                IsolationLevel.READ_UNCOMMITTED, Long.MAX_VALUE, subscriptionState, logContext);
-    }
-
     private <K, V> void buildFetcher(MetricConfig metricConfig,
                                      Deserializer<K> keyDeserializer,
                                      Deserializer<V> valueDeserializer,
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java
index d50e1b90913..ebc74c807c4 100644
--- a/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java
+++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java
@@ -37,6 +37,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 
 public class RequestTestUtils {
 
@@ -177,6 +178,16 @@ public class RequestTestUtils {
                 topicIds);
     }
 
+    public static MetadataResponse metadataUpdateWithIds(final int numNodes,
+                                                         final Map<String, Integer> topicPartitionCounts,
+                                                         final Function<TopicPartition, Integer> epochSupplier,
+                                                         final Map<String, Uuid> topicIds,
+                                                         final Boolean leaderOnly) {
+        return metadataUpdateWith("kafka-cluster", numNodes, Collections.emptyMap(),
+                topicPartitionCounts, epochSupplier, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(),
+                topicIds, leaderOnly);
+    }
+
     public static MetadataResponse metadataUpdateWithIds(final String clusterId,
                                                          final int numNodes,
                                                          final Map<String, Errors> topicErrors,
@@ -195,6 +206,20 @@ public class RequestTestUtils {
                                                       final PartitionMetadataSupplier partitionSupplier,
                                                       final short responseVersion,
                                                       final Map<String, Uuid> topicIds) {
+        return metadataUpdateWith(clusterId, numNodes, topicErrors,
+                topicPartitionCounts, epochSupplier, partitionSupplier,
+                responseVersion, topicIds, true);
+    }
+
+    public static MetadataResponse metadataUpdateWith(final String clusterId,
+                                                      final int numNodes,
+                                                      final Map<String, Errors> topicErrors,
+                                                      final Map<String, Integer> topicPartitionCounts,
+                                                      final Function<TopicPartition, Integer> epochSupplier,
+                                                      final PartitionMetadataSupplier partitionSupplier,
+                                                      final short responseVersion,
+                                                      final Map<String, Uuid> topicIds,
+                                                      final Boolean leaderOnly) {
         final List<Node> nodes = new ArrayList<>(numNodes);
         for (int i = 0; i < numNodes; i++)
             nodes.add(new Node(i, "localhost", 1969 + i));
@@ -208,10 +233,10 @@ public class RequestTestUtils {
             for (int i = 0; i < numPartitions; i++) {
                 TopicPartition tp = new TopicPartition(topic, i);
                 Node leader = nodes.get(i % nodes.size());
-                List<Integer> replicaIds = Collections.singletonList(leader.id());
+                List<Integer> replicaIds = leaderOnly ? Collections.singletonList(leader.id()) : nodes.stream().map(Node::id).collect(Collectors.toList());
                 partitionMetadata.add(partitionSupplier.supply(
                         Errors.NONE, tp, Optional.of(leader.id()), Optional.ofNullable(epochSupplier.apply(tp)),
-                        replicaIds, replicaIds, replicaIds));
+                        replicaIds, replicaIds, Collections.emptyList()));
             }
 
             topicMetadata.add(new MetadataResponse.TopicMetadata(Errors.NONE, topic, topicIds.getOrDefault(topic, Uuid.ZERO_UUID),
diff --git a/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala b/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
index a1e0d20b4e8..e0d7b358786 100644
--- a/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/server/FetchFromFollowerIntegrationTest.scala
@@ -18,9 +18,11 @@ package integration.kafka.server
 
 import kafka.server.{BaseFetchRequestTest, KafkaConfig}
 import kafka.utils.{TestInfoUtils, TestUtils}
+import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.{ApiKeys, Errors}
 import org.apache.kafka.common.requests.FetchResponse
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
 import org.junit.jupiter.api.Assertions.assertEquals
 import org.junit.jupiter.api.{Test, Timeout}
 import org.junit.jupiter.params.ParameterizedTest
@@ -103,8 +105,71 @@ class FetchFromFollowerIntegrationTest extends BaseFetchRequestTest {
 
     TestUtils.generateAndProduceMessages(brokers, topic, numMessages = 10)
 
+    assertEquals(1, getPreferredReplica)
+
+    // Shutdown follower broker.
+    brokers(followerBrokerId).shutdown()
+    val topicPartition = new TopicPartition(topic, 0)
+    TestUtils.waitUntilTrue(() => {
+      val endpoints = brokers(leaderBrokerId).metadataCache.getPartitionReplicaEndpoints(topicPartition, listenerName)
+      !endpoints.contains(followerBrokerId)
+    }, "follower is still reachable.")
+
+    assertEquals(-1, getPreferredReplica)
+  }
+
+  @Test
+  def testFetchFromFollowerWithRoll(): Unit = {
+    // Create a topic with 2 replicas where broker 0 is the leader and 1 is the follower.
+    val admin = createAdminClient()
+    TestUtils.createTopicWithAdmin(
+      admin,
+      topic,
+      brokers,
+      replicaAssignment = Map(0 -> Seq(leaderBrokerId, followerBrokerId))
+    )
+
+    // Create consumer with client.rack = follower id.
+    val consumerProps = new Properties
+    consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers())
+    consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "test-group")
+    consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
+    consumerProps.put(ConsumerConfig.CLIENT_RACK_CONFIG, followerBrokerId.toString)
+    val consumer = new KafkaConsumer(consumerProps, new ByteArrayDeserializer, new ByteArrayDeserializer)
+    try {
+      consumer.subscribe(List(topic).asJava)
+
+      // Wait until preferred replica is set to follower.
+      TestUtils.waitUntilTrue(() => {
+        getPreferredReplica == 1
+      }, "Preferred replica is not set")
+
+      // Produce and consume.
+      TestUtils.generateAndProduceMessages(brokers, topic, numMessages = 1)
+      TestUtils.pollUntilAtLeastNumRecords(consumer, 1)
+
+      // Shutdown follower, produce and consume should work.
+      brokers(followerBrokerId).shutdown()
+      TestUtils.generateAndProduceMessages(brokers, topic, numMessages = 1)
+      TestUtils.pollUntilAtLeastNumRecords(consumer, 1)
+
+      // Start the follower and wait until preferred replica is set to follower.
+      brokers(followerBrokerId).startup()
+      TestUtils.waitUntilTrue(() => {
+        getPreferredReplica == 1
+      }, "Preferred replica is not set")
+
+      // Produce and consume should still work.
+      TestUtils.generateAndProduceMessages(brokers, topic, numMessages = 1)
+      TestUtils.pollUntilAtLeastNumRecords(consumer, 1)
+    } finally {
+      consumer.close()
+    }
+  }
+
+  private def getPreferredReplica: Int = {
     val topicPartition = new TopicPartition(topic, 0)
-    val offsetMap = Map(topicPartition -> 10L)
+    val offsetMap = Map(topicPartition -> 0L)
 
     val request = createConsumerFetchRequest(
       maxResponseBytes = 1000,
@@ -112,35 +177,17 @@ class FetchFromFollowerIntegrationTest extends BaseFetchRequestTest {
       Seq(topicPartition),
       offsetMap,
       ApiKeys.FETCH.latestVersion,
-      maxWaitMs = 20000,
+      maxWaitMs = 500,
       minBytes = 1,
       rackId = followerBrokerId.toString
     )
-    var response = connectAndReceive[FetchResponse](request, brokers(leaderBrokerId).socketServer)
-    assertEquals(Errors.NONE, response.error)
-    assertEquals(Map(Errors.NONE -> 2).asJava, response.errorCounts)
-    validatePreferredReadReplica(response, preferredReadReplica = 1)
-
-    // Shutdown follower broker. Consumer will reach out to leader after metadata.max.age.ms
-    brokers(followerBrokerId).shutdown()
-    TestUtils.waitUntilTrue(() => {
-      val endpoints = brokers(leaderBrokerId).metadataCache.getPartitionReplicaEndpoints(topicPartition, listenerName)
-      !endpoints.contains(followerBrokerId)
-    }, "follower is still reachable.")
-
-    response = connectAndReceive[FetchResponse](request, brokers(leaderBrokerId).socketServer)
+    val response = connectAndReceive[FetchResponse](request, brokers(leaderBrokerId).socketServer)
     assertEquals(Errors.NONE, response.error)
     assertEquals(Map(Errors.NONE -> 2).asJava, response.errorCounts)
-    validatePreferredReadReplica(response, preferredReadReplica = -1)
-  }
-
-  private def validatePreferredReadReplica(response: FetchResponse, preferredReadReplica: Int): Unit = {
     assertEquals(1, response.data.responses.size)
-    response.data.responses.forEach { topicResponse =>
-      assertEquals(1, topicResponse.partitions.size)
-      topicResponse.partitions.forEach { partitionResponse =>
-        assertEquals(preferredReadReplica, partitionResponse.preferredReadReplica)
-      }
-    }
+    val topicResponse = response.data.responses.get(0)
+    assertEquals(1, topicResponse.partitions.size)
+
+    topicResponse.partitions.get(0).preferredReadReplica
   }
 }