You are viewing a plain text version of this content. The canonical link for it is available in the original (HTML) version of this message.
Posted to commits@kafka.apache.org by jg...@apache.org on 2022/08/24 20:20:08 UTC

[kafka] branch 3.3 updated: MINOR: A few cleanups for DescribeQuorum APIs (#12548)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 3.3
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/3.3 by this push:
     new 12759d656c5 MINOR: A few cleanups for DescribeQuorum APIs (#12548)
12759d656c5 is described below

commit 12759d656c5d077bffa50f2db0be78f49e390ac8
Author: Jason Gustafson <ja...@confluent.io>
AuthorDate: Wed Aug 24 13:12:14 2022 -0700

    MINOR: A few cleanups for DescribeQuorum APIs (#12548)
    
    A few small cleanups in the `DescribeQuorum` API and handling logic:
    
    - Change field types in `QuorumInfo`:
      - `leaderId`: `Integer` -> `int`
      - `leaderEpoch`: `Integer` -> `long` (to allow for type expansion in the future)
      - `highWatermark`: `Long` -> `long`
    - Use field names `lastFetchTimestamp` and `lastCaughtUpTimestamp` consistently
    - Move construction of `DescribeQuorumResponseData.PartitionData` into `LeaderState`
    - Consolidate fetch time/offset update logic into `LeaderState.ReplicaState.updateFollowerState`
    
    Reviewers: Luke Chen <sh...@gmail.com>, José Armando García Sancio <js...@users.noreply.github.com>
---
 .../kafka/clients/admin/KafkaAdminClient.java      |  19 +-
 .../org/apache/kafka/clients/admin/QuorumInfo.java |  68 +++--
 .../common/requests/DescribeQuorumResponse.java    |  22 +-
 .../kafka/clients/admin/KafkaAdminClientTest.java  |   6 +-
 .../scala/kafka/admin/MetadataQuorumCommand.scala  |  10 +-
 .../kafka/server/KRaftClusterTest.scala            |   8 +-
 .../org/apache/kafka/raft/KafkaRaftClient.java     |  13 +-
 .../java/org/apache/kafka/raft/LeaderState.java    | 233 +++++++-------
 .../org/apache/kafka/raft/LeaderStateTest.java     | 337 +++++++++++++++------
 .../apache/kafka/raft/RaftClientTestContext.java   |  15 +-
 .../kafka/raft/internals/KafkaRaftMetricsTest.java |   4 +-
 11 files changed, 456 insertions(+), 279 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java b/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java
index e5df779b616..1b837fda223 100644
--- a/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java
+++ b/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java
@@ -4357,12 +4357,21 @@ public class KafkaAdminClient extends AdminClient {
             }
 
             private QuorumInfo createQuorumResult(final DescribeQuorumResponseData.PartitionData partition) {
+                List<QuorumInfo.ReplicaState> voters = partition.currentVoters().stream()
+                    .map(this::translateReplicaState)
+                    .collect(Collectors.toList());
+
+                List<QuorumInfo.ReplicaState> observers = partition.observers().stream()
+                    .map(this::translateReplicaState)
+                    .collect(Collectors.toList());
+
                 return new QuorumInfo(
-                        partition.leaderId(),
-                        partition.leaderEpoch(),
-                        partition.highWatermark(),
-                        partition.currentVoters().stream().map(v -> translateReplicaState(v)).collect(Collectors.toList()),
-                        partition.observers().stream().map(o -> translateReplicaState(o)).collect(Collectors.toList()));
+                    partition.leaderId(),
+                    partition.leaderEpoch(),
+                    partition.highWatermark(),
+                    voters,
+                    observers
+                );
             }
 
             @Override
diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/QuorumInfo.java b/clients/src/main/java/org/apache/kafka/clients/admin/QuorumInfo.java
index 3a0b6cf6f74..f9e4f8c11c9 100644
--- a/clients/src/main/java/org/apache/kafka/clients/admin/QuorumInfo.java
+++ b/clients/src/main/java/org/apache/kafka/clients/admin/QuorumInfo.java
@@ -24,13 +24,19 @@ import java.util.OptionalLong;
  * This class is used to describe the state of the quorum received in DescribeQuorumResponse.
  */
 public class QuorumInfo {
-    private final Integer leaderId;
-    private final Integer leaderEpoch;
-    private final Long highWatermark;
+    private final int leaderId;
+    private final long leaderEpoch;
+    private final long highWatermark;
     private final List<ReplicaState> voters;
     private final List<ReplicaState> observers;
 
-    QuorumInfo(Integer leaderId, Integer leaderEpoch, Long highWatermark, List<ReplicaState> voters, List<ReplicaState> observers) {
+    QuorumInfo(
+        int leaderId,
+        long leaderEpoch,
+        long highWatermark,
+        List<ReplicaState> voters,
+        List<ReplicaState> observers
+    ) {
         this.leaderId = leaderId;
         this.leaderEpoch = leaderEpoch;
         this.highWatermark = highWatermark;
@@ -38,15 +44,15 @@ public class QuorumInfo {
         this.observers = observers;
     }
 
-    public Integer leaderId() {
+    public int leaderId() {
         return leaderId;
     }
 
-    public Integer leaderEpoch() {
+    public long leaderEpoch() {
         return leaderEpoch;
     }
 
-    public Long highWatermark() {
+    public long highWatermark() {
         return highWatermark;
     }
 
@@ -63,20 +69,24 @@ public class QuorumInfo {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         QuorumInfo that = (QuorumInfo) o;
-        return leaderId.equals(that.leaderId)
-            && voters.equals(that.voters)
-            && observers.equals(that.observers);
+        return leaderId == that.leaderId
+            && leaderEpoch == that.leaderEpoch
+            && highWatermark == that.highWatermark
+            && Objects.equals(voters, that.voters)
+            && Objects.equals(observers, that.observers);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(leaderId, voters, observers);
+        return Objects.hash(leaderId, leaderEpoch, highWatermark, voters, observers);
     }
 
     @Override
     public String toString() {
         return "QuorumInfo(" +
             "leaderId=" + leaderId +
+            ", leaderEpoch=" + leaderEpoch +
+            ", highWatermark=" + highWatermark +
             ", voters=" + voters +
             ", observers=" + observers +
             ')';
@@ -85,8 +95,8 @@ public class QuorumInfo {
     public static class ReplicaState {
         private final int replicaId;
         private final long logEndOffset;
-        private final OptionalLong lastFetchTimeMs;
-        private final OptionalLong lastCaughtUpTimeMs;
+        private final OptionalLong lastFetchTimestamp;
+        private final OptionalLong lastCaughtUpTimestamp;
 
         ReplicaState() {
             this(0, 0, OptionalLong.empty(), OptionalLong.empty());
@@ -95,13 +105,13 @@ public class QuorumInfo {
         ReplicaState(
             int replicaId,
             long logEndOffset,
-            OptionalLong lastFetchTimeMs,
-            OptionalLong lastCaughtUpTimeMs
+            OptionalLong lastFetchTimestamp,
+            OptionalLong lastCaughtUpTimestamp
         ) {
             this.replicaId = replicaId;
             this.logEndOffset = logEndOffset;
-            this.lastFetchTimeMs = lastFetchTimeMs;
-            this.lastCaughtUpTimeMs = lastCaughtUpTimeMs;
+            this.lastFetchTimestamp = lastFetchTimestamp;
+            this.lastCaughtUpTimestamp = lastCaughtUpTimestamp;
         }
 
         /**
@@ -121,19 +131,21 @@ public class QuorumInfo {
         }
 
         /**
-         * Return the lastFetchTime in milliseconds for this replica.
+         * Return the last millisecond timestamp that the leader received a
+         * fetch from this replica.
          * @return The value of the lastFetchTime if known, empty otherwise
          */
-        public OptionalLong lastFetchTimeMs() {
-            return lastFetchTimeMs;
+        public OptionalLong lastFetchTimestamp() {
+            return lastFetchTimestamp;
         }
 
         /**
-         * Return the lastCaughtUpTime in milliseconds for this replica.
+         * Return the last millisecond timestamp at which this replica was known to be
+         * caught up with the leader.
          * @return The value of the lastCaughtUpTime if known, empty otherwise
          */
-        public OptionalLong lastCaughtUpTimeMs() {
-            return lastCaughtUpTimeMs;
+        public OptionalLong lastCaughtUpTimestamp() {
+            return lastCaughtUpTimestamp;
         }
 
         @Override
@@ -143,13 +155,13 @@ public class QuorumInfo {
             ReplicaState that = (ReplicaState) o;
             return replicaId == that.replicaId
                 && logEndOffset == that.logEndOffset
-                && lastFetchTimeMs.equals(that.lastFetchTimeMs)
-                && lastCaughtUpTimeMs.equals(that.lastCaughtUpTimeMs);
+                && lastFetchTimestamp.equals(that.lastFetchTimestamp)
+                && lastCaughtUpTimestamp.equals(that.lastCaughtUpTimestamp);
         }
 
         @Override
         public int hashCode() {
-            return Objects.hash(replicaId, logEndOffset, lastFetchTimeMs, lastCaughtUpTimeMs);
+            return Objects.hash(replicaId, logEndOffset, lastFetchTimestamp, lastCaughtUpTimestamp);
         }
 
         @Override
@@ -157,8 +169,8 @@ public class QuorumInfo {
             return "ReplicaState(" +
                 "replicaId=" + replicaId +
                 ", logEndOffset=" + logEndOffset +
-                ", lastFetchTimeMs=" + lastFetchTimeMs +
-                ", lastCaughtUpTimeMs=" + lastCaughtUpTimeMs +
+                ", lastFetchTimestamp=" + lastFetchTimestamp +
+                ", lastCaughtUpTimestamp=" + lastCaughtUpTimestamp +
                 ')';
         }
     }
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java
index 06ae681bc5c..9f58e52970c 100644
--- a/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java
+++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java
@@ -18,7 +18,6 @@ package org.apache.kafka.common.requests;
 
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.message.DescribeQuorumResponseData;
-import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.ByteBufferAccessor;
 import org.apache.kafka.common.protocol.Errors;
@@ -26,7 +25,6 @@ import org.apache.kafka.common.protocol.Errors;
 import java.nio.ByteBuffer;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 
 /**
@@ -85,23 +83,15 @@ public class DescribeQuorumResponse extends AbstractResponse {
     }
 
 
-    public static DescribeQuorumResponseData singletonResponse(TopicPartition topicPartition,
-                                                               int leaderId,
-                                                               int leaderEpoch,
-                                                               long highWatermark,
-                                                               List<ReplicaState> voterStates,
-                                                               List<ReplicaState> observerStates) {
+    public static DescribeQuorumResponseData singletonResponse(
+        TopicPartition topicPartition,
+        DescribeQuorumResponseData.PartitionData partitionData
+    ) {
         return new DescribeQuorumResponseData()
             .setTopics(Collections.singletonList(new DescribeQuorumResponseData.TopicData()
                 .setTopicName(topicPartition.topic())
-                .setPartitions(Collections.singletonList(new DescribeQuorumResponseData.PartitionData()
-                    .setPartitionIndex(topicPartition.partition())
-                    .setErrorCode(Errors.NONE.code())
-                    .setLeaderId(leaderId)
-                    .setLeaderEpoch(leaderEpoch)
-                    .setHighWatermark(highWatermark)
-                    .setCurrentVoters(voterStates)
-                    .setObservers(observerStates)))));
+                .setPartitions(Collections.singletonList(partitionData
+                    .setPartitionIndex(topicPartition.partition())))));
     }
 
     public static DescribeQuorumResponse parse(ByteBuffer buffer, short version) {
diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java
index 193457655a5..471551255e2 100644
--- a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java
@@ -607,7 +607,7 @@ public class KafkaAdminClientTest {
                 .setErrorCode(error.code()));
     }
 
-    private static QuorumInfo defaultQuorumInfo(Boolean emptyOptionals) {
+    private static QuorumInfo defaultQuorumInfo(boolean emptyOptionals) {
         return new QuorumInfo(1, 1, 1L,
                 singletonList(new QuorumInfo.ReplicaState(1, 100,
                         emptyOptionals ? OptionalLong.empty() : OptionalLong.of(1000),
@@ -637,8 +637,8 @@ public class KafkaAdminClientTest {
             replica.setLastCaughtUpTimestamp(emptyOptionals ? -1 : 1000);
             partitions.add(new DescribeQuorumResponseData.PartitionData().setPartitionIndex(partitionIndex)
                     .setLeaderId(1)
-                    .setLeaderEpoch(0)
-                    .setHighWatermark(0)
+                    .setLeaderEpoch(1)
+                    .setHighWatermark(1)
                     .setCurrentVoters(singletonList(replica))
                     .setObservers(singletonList(replica))
                     .setErrorCode(partitionLevelError.code()));
diff --git a/core/src/main/scala/kafka/admin/MetadataQuorumCommand.scala b/core/src/main/scala/kafka/admin/MetadataQuorumCommand.scala
index b6e4e1597b5..c92988d97fa 100644
--- a/core/src/main/scala/kafka/admin/MetadataQuorumCommand.scala
+++ b/core/src/main/scala/kafka/admin/MetadataQuorumCommand.scala
@@ -127,13 +127,13 @@ object MetadataQuorumCommand {
         Array(info.replicaId,
               info.logEndOffset,
               leader.logEndOffset - info.logEndOffset,
-              info.lastFetchTimeMs.orElse(-1),
-              info.lastCaughtUpTimeMs.orElse(-1),
+              info.lastFetchTimestamp.orElse(-1),
+              info.lastCaughtUpTimestamp.orElse(-1),
               status
         ).map(_.toString)
       }
     prettyPrintTable(
-      Array("NodeId", "LogEndOffset", "Lag", "LastFetchTimeMs", "LastCaughtUpTimeMs", "Status"),
+      Array("NodeId", "LogEndOffset", "Lag", "LastFetchTimestamp", "LastCaughtUpTimestamp", "Status"),
       (convertQuorumInfo(Seq(leader), "Leader")
         ++ convertQuorumInfo(quorumInfo.voters.asScala.filter(_.replicaId != leaderId).toSeq, "Follower")
         ++ convertQuorumInfo(quorumInfo.observers.asScala.toSeq, "Observer")).asJava,
@@ -152,8 +152,8 @@ object MetadataQuorumCommand {
     val maxFollowerLagTimeMs =
       if (leader == maxLagFollower) {
         0
-      } else if (leader.lastCaughtUpTimeMs.isPresent && maxLagFollower.lastCaughtUpTimeMs.isPresent) {
-        leader.lastCaughtUpTimeMs.getAsLong - maxLagFollower.lastCaughtUpTimeMs.getAsLong
+      } else if (leader.lastCaughtUpTimestamp.isPresent && maxLagFollower.lastCaughtUpTimestamp.isPresent) {
+        leader.lastCaughtUpTimestamp.getAsLong - maxLagFollower.lastCaughtUpTimestamp.getAsLong
       } else {
         -1
       }
diff --git a/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala b/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala
index a16cf821d4d..c550553917b 100644
--- a/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala
+++ b/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala
@@ -810,16 +810,16 @@ class KRaftClusterTest {
         quorumInfo.voters.forEach { voter =>
           assertTrue(0 < voter.logEndOffset,
             s"logEndOffset for voter with ID ${voter.replicaId} was ${voter.logEndOffset}")
-          assertNotEquals(OptionalLong.empty(), voter.lastFetchTimeMs)
-          assertNotEquals(OptionalLong.empty(), voter.lastCaughtUpTimeMs)
+          assertNotEquals(OptionalLong.empty(), voter.lastFetchTimestamp)
+          assertNotEquals(OptionalLong.empty(), voter.lastCaughtUpTimestamp)
         }
 
         assertEquals(cluster.brokers.asScala.keySet, quorumInfo.observers.asScala.map(_.replicaId).toSet)
         quorumInfo.observers.forEach { observer =>
           assertTrue(0 < observer.logEndOffset,
             s"logEndOffset for observer with ID ${observer.replicaId} was ${observer.logEndOffset}")
-          assertNotEquals(OptionalLong.empty(), observer.lastFetchTimeMs)
-          assertNotEquals(OptionalLong.empty(), observer.lastCaughtUpTimeMs)
+          assertNotEquals(OptionalLong.empty(), observer.lastFetchTimestamp)
+          assertNotEquals(OptionalLong.empty(), observer.lastCaughtUpTimestamp)
         }
       } finally {
         admin.close()
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 69d2025b6bb..dab0bb33926 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -275,7 +275,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
     ) {
         final LogOffsetMetadata endOffsetMetadata = log.endOffset();
 
-        if (state.updateLocalState(currentTimeMs, endOffsetMetadata)) {
+        if (state.updateLocalState(endOffsetMetadata)) {
             onUpdateLeaderHighWatermark(state, currentTimeMs);
         }
 
@@ -1014,7 +1014,7 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
             if (validOffsetAndEpoch.kind() == ValidOffsetAndEpoch.Kind.VALID) {
                 LogFetchInfo info = log.read(fetchOffset, Isolation.UNCOMMITTED);
 
-                if (state.updateReplicaState(replicaId, currentTimeMs, info.startOffsetMetadata, log.endOffset().offset)) {
+                if (state.updateReplicaState(replicaId, currentTimeMs, info.startOffsetMetadata)) {
                     onUpdateLeaderHighWatermark(state, currentTimeMs);
                 }
 
@@ -1176,12 +1176,9 @@ public class KafkaRaftClient<T> implements RaftClient<T> {
         }
 
         LeaderState<T> leaderState = quorum.leaderStateOrThrow();
-        return DescribeQuorumResponse.singletonResponse(log.topicPartition(),
-            leaderState.localId(),
-            leaderState.epoch(),
-            leaderState.highWatermark().isPresent() ? leaderState.highWatermark().get().offset : -1,
-            leaderState.quorumResponseVoterStates(currentTimeMs),
-            leaderState.quorumResponseObserverStates(currentTimeMs)
+        return DescribeQuorumResponse.singletonResponse(
+            log.topicPartition(),
+            leaderState.describeQuorum(currentTimeMs)
         );
     }
 
diff --git a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
index 0b8ebad8bda..3b5b6c11e6a 100644
--- a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
@@ -17,23 +17,21 @@
 package org.apache.kafka.raft;
 
 import org.apache.kafka.common.message.DescribeQuorumResponseData;
-import org.apache.kafka.common.utils.LogContext;
-import org.apache.kafka.raft.internals.BatchAccumulator;
-import org.slf4j.Logger;
-
 import org.apache.kafka.common.message.LeaderChangeMessage;
 import org.apache.kafka.common.message.LeaderChangeMessage.Voter;
+import org.apache.kafka.common.protocol.Errors;
 import org.apache.kafka.common.record.ControlRecordUtils;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.raft.internals.BatchAccumulator;
+import org.slf4j.Logger;
 
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.OptionalLong;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -147,7 +145,7 @@ public class LeaderState<T> implements EpochState {
         return nonAcknowledging;
     }
 
-    private boolean updateHighWatermark() {
+    private boolean maybeUpdateHighWatermark() {
         // Find the largest offset which is replicated to a majority of replicas (the leader counts)
         List<ReplicaState> followersByDescendingFetchOffset = followersByDescendingFetchOffset();
 
@@ -173,9 +171,8 @@ public class LeaderState<T> implements EpochState {
                         || (highWatermarkUpdateOffset == currentHighWatermarkMetadata.offset &&
                             !highWatermarkUpdateMetadata.metadata.equals(currentHighWatermarkMetadata.metadata))) {
                         highWatermark = highWatermarkUpdateOpt;
-                        log.trace(
-                            "High watermark updated to {} based on indexOfHw {} and voters {}",
-                            highWatermark,
+                        logHighWatermarkUpdate(
+                            highWatermarkUpdateMetadata,
                             indexOfHw,
                             followersByDescendingFetchOffset
                         );
@@ -191,9 +188,8 @@ public class LeaderState<T> implements EpochState {
                     }
                 } else {
                     highWatermark = highWatermarkUpdateOpt;
-                    log.trace(
-                        "High watermark set to {} based on indexOfHw {} and voters {}",
-                        highWatermark,
+                    logHighWatermarkUpdate(
+                        highWatermarkUpdateMetadata,
                         indexOfHw,
                         followersByDescendingFetchOffset
                     );
@@ -204,50 +200,79 @@ public class LeaderState<T> implements EpochState {
         return false;
     }
 
+    private void logHighWatermarkUpdate(
+        LogOffsetMetadata newHighWatermark,
+        int indexOfHw,
+        List<ReplicaState> followersByDescendingFetchOffset
+    ) {
+        log.trace(
+            "High watermark set to {} based on indexOfHw {} and voters {}",
+            newHighWatermark,
+            indexOfHw,
+            followersByDescendingFetchOffset
+        );
+    }
+
     /**
      * Update the local replica state.
      *
-     * See {@link #updateReplicaState(int, long, LogOffsetMetadata, long)}
+     * @param endOffsetMetadata updated log end offset of local replica
+     * @return true if the high watermark is updated as a result of this call
      */
-    public boolean updateLocalState(long fetchTimestamp, LogOffsetMetadata logOffsetMetadata) {
-        return updateReplicaState(localId, fetchTimestamp, logOffsetMetadata, logOffsetMetadata.offset);
+    public boolean updateLocalState(
+        LogOffsetMetadata endOffsetMetadata
+    ) {
+        ReplicaState state = getOrCreateReplicaState(localId);
+        state.endOffset.ifPresent(currentEndOffset -> {
+            if (currentEndOffset.offset > endOffsetMetadata.offset) {
+                throw new IllegalStateException("Detected non-monotonic update of local " +
+                    "end offset: " + currentEndOffset.offset + " -> " + endOffsetMetadata.offset);
+            }
+        });
+        state.updateLeaderState(endOffsetMetadata);
+        return maybeUpdateHighWatermark();
     }
 
     /**
      * Update the replica state in terms of fetch time and log end offsets.
      *
      * @param replicaId replica id
-     * @param fetchTimestamp fetch timestamp
-     * @param logOffsetMetadata new log offset and metadata
-     * @param leaderLogEndOffset current log end offset of the leader
-     * @return true if the high watermark is updated too
+     * @param currentTimeMs current time in milliseconds
+     * @param fetchOffsetMetadata new log offset and metadata
+     * @return true if the high watermark is updated as a result of this call
      */
     public boolean updateReplicaState(
         int replicaId,
-        long fetchTimestamp,
-        LogOffsetMetadata logOffsetMetadata,
-        long leaderLogEndOffset
+        long currentTimeMs,
+        LogOffsetMetadata fetchOffsetMetadata
     ) {
         // Ignore fetches from negative replica id, as it indicates
         // the fetch is from non-replica. For example, a consumer.
         if (replicaId < 0) {
             return false;
+        } else if (replicaId == localId) {
+            throw new IllegalStateException("Remote replica ID " + replicaId + " matches the local leader ID");
         }
 
-        ReplicaState state = getReplicaState(replicaId);
+        ReplicaState state = getOrCreateReplicaState(replicaId);
+
+        state.endOffset.ifPresent(currentEndOffset -> {
+            if (currentEndOffset.offset > fetchOffsetMetadata.offset) {
+                log.warn("Detected non-monotonic update of fetch offset from nodeId {}: {} -> {}",
+                    state.nodeId, currentEndOffset.offset, fetchOffsetMetadata.offset);
+            }
+        });
 
-        // Only proceed with updating the states if the offset update is valid
-        verifyEndOffsetUpdate(state, logOffsetMetadata);
+        Optional<LogOffsetMetadata> leaderEndOffsetOpt =
+            voterStates.get(localId).endOffset;
 
-        // Update the Last CaughtUp Time
-        if (logOffsetMetadata.offset >= leaderLogEndOffset) {
-            state.updateLastCaughtUpTimestamp(fetchTimestamp);
-        } else if (logOffsetMetadata.offset >= state.lastFetchLeaderLogEndOffset.orElse(-1L)) {
-            state.updateLastCaughtUpTimestamp(state.lastFetchTimestamp.orElse(-1L));
-        }
+        state.updateFollowerState(
+            currentTimeMs,
+            fetchOffsetMetadata,
+            leaderEndOffsetOpt
+        );
 
-        state.updateFetchTimestamp(fetchTimestamp, leaderLogEndOffset);
-        return updateEndOffset(state, logOffsetMetadata);
+        return isVoter(state.nodeId) && maybeUpdateHighWatermark();
     }
 
     public List<Integer> nonLeaderVotersByDescendingFetchOffset() {
@@ -263,31 +288,6 @@ public class LeaderState<T> implements EpochState {
             .collect(Collectors.toList());
     }
 
-    private void verifyEndOffsetUpdate(
-        ReplicaState state,
-        LogOffsetMetadata endOffsetMetadata
-    ) {
-        state.endOffset.ifPresent(currentEndOffset -> {
-            if (currentEndOffset.offset > endOffsetMetadata.offset) {
-                if (state.nodeId == localId) {
-                    throw new IllegalStateException("Detected non-monotonic update of local " +
-                        "end offset: " + currentEndOffset.offset + " -> " + endOffsetMetadata.offset);
-                } else {
-                    log.warn("Detected non-monotonic update of fetch offset from nodeId {}: {} -> {}",
-                        state.nodeId, currentEndOffset.offset, endOffsetMetadata.offset);
-                }
-            }
-        });
-    }
-    private boolean updateEndOffset(
-        ReplicaState state,
-        LogOffsetMetadata endOffsetMetadata
-    ) {
-        state.endOffset = Optional.of(endOffsetMetadata);
-        state.hasAcknowledgedLeader = true;
-        return isVoter(state.nodeId) && updateHighWatermark();
-    }
-
     public void addAcknowledgementFrom(int remoteNodeId) {
         ReplicaState voterState = ensureValidVoter(remoteNodeId);
         voterState.hasAcknowledgedLeader = true;
@@ -304,7 +304,7 @@ public class LeaderState<T> implements EpochState {
         return epochStartOffset;
     }
 
-    private ReplicaState getReplicaState(int remoteNodeId) {
+    private ReplicaState getOrCreateReplicaState(int remoteNodeId) {
         ReplicaState state = voterStates.get(remoteNodeId);
         if (state == null) {
             observerStates.putIfAbsent(remoteNodeId, new ReplicaState(remoteNodeId, false));
@@ -313,43 +313,52 @@ public class LeaderState<T> implements EpochState {
         return state;
     }
 
-    List<DescribeQuorumResponseData.ReplicaState> quorumResponseVoterStates(long currentTimeMs) {
-        return quorumResponseReplicaStates(voterStates.values(), localId, currentTimeMs);
+    public DescribeQuorumResponseData.PartitionData describeQuorum(long currentTimeMs) {
+        clearInactiveObservers(currentTimeMs);
+
+        return new DescribeQuorumResponseData.PartitionData()
+            .setErrorCode(Errors.NONE.code())
+            .setLeaderId(localId)
+            .setLeaderEpoch(epoch)
+            .setHighWatermark(highWatermark().map(offsetMetadata -> offsetMetadata.offset).orElse(-1L))
+            .setCurrentVoters(describeReplicaStates(voterStates, currentTimeMs))
+            .setObservers(describeReplicaStates(observerStates, currentTimeMs));
     }
 
-    List<DescribeQuorumResponseData.ReplicaState> quorumResponseObserverStates(long currentTimeMs) {
-        clearInactiveObservers(currentTimeMs);
-        return quorumResponseReplicaStates(observerStates.values(), localId, currentTimeMs);
+    private List<DescribeQuorumResponseData.ReplicaState> describeReplicaStates(
+        Map<Integer, ReplicaState> state,
+        long currentTimeMs
+    ) {
+        return state.values().stream()
+            .map(replicaState -> describeReplicaState(replicaState, currentTimeMs))
+            .collect(Collectors.toList());
     }
 
-    private static  List<DescribeQuorumResponseData.ReplicaState> quorumResponseReplicaStates(
-        Collection<ReplicaState> state,
-        int leaderId,
+    private DescribeQuorumResponseData.ReplicaState describeReplicaState(
+        ReplicaState replicaState,
         long currentTimeMs
     ) {
-        return state.stream().map(s -> {
-            final long lastCaughtUpTimestamp;
-            final long lastFetchTimestamp;
-            if (s.nodeId == leaderId) {
-                lastCaughtUpTimestamp = currentTimeMs;
-                lastFetchTimestamp = currentTimeMs;
-            } else {
-                lastCaughtUpTimestamp = s.lastCaughtUpTimestamp.orElse(-1);
-                lastFetchTimestamp = s.lastFetchTimestamp.orElse(-1);
-            }
-            return new DescribeQuorumResponseData.ReplicaState()
-                    .setReplicaId(s.nodeId)
-                    .setLogEndOffset(s.endOffset.map(md -> md.offset).orElse(-1L))
-                    .setLastCaughtUpTimestamp(lastCaughtUpTimestamp)
-                    .setLastFetchTimestamp(lastFetchTimestamp);
-        }).collect(Collectors.toList());
+        final long lastCaughtUpTimestamp;
+        final long lastFetchTimestamp;
+        if (replicaState.nodeId == localId) {
+            lastCaughtUpTimestamp = currentTimeMs;
+            lastFetchTimestamp = currentTimeMs;
+        } else {
+            lastCaughtUpTimestamp = replicaState.lastCaughtUpTimestamp;
+            lastFetchTimestamp = replicaState.lastFetchTimestamp;
+        }
+        return new DescribeQuorumResponseData.ReplicaState()
+            .setReplicaId(replicaState.nodeId)
+            .setLogEndOffset(replicaState.endOffset.map(md -> md.offset).orElse(-1L))
+            .setLastCaughtUpTimestamp(lastCaughtUpTimestamp)
+            .setLastFetchTimestamp(lastFetchTimestamp);
+
     }
 
     private void clearInactiveObservers(final long currentTimeMs) {
-        observerStates.entrySet().removeIf(
-            integerReplicaStateEntry ->
-                currentTimeMs - integerReplicaStateEntry.getValue().lastFetchTimestamp.orElse(-1)
-                    >= OBSERVER_SESSION_TIMEOUT_MS);
+        observerStates.entrySet().removeIf(integerReplicaStateEntry ->
+            currentTimeMs - integerReplicaStateEntry.getValue().lastFetchTimestamp >= OBSERVER_SESSION_TIMEOUT_MS
+        );
     }
 
     private boolean isVoter(int remoteNodeId) {
@@ -359,31 +368,49 @@ public class LeaderState<T> implements EpochState {
     private static class ReplicaState implements Comparable<ReplicaState> {
         final int nodeId;
         Optional<LogOffsetMetadata> endOffset;
-        OptionalLong lastFetchTimestamp;
-        OptionalLong lastFetchLeaderLogEndOffset;
-        OptionalLong lastCaughtUpTimestamp;
+        long lastFetchTimestamp;
+        long lastFetchLeaderLogEndOffset;
+        long lastCaughtUpTimestamp;
         boolean hasAcknowledgedLeader;
 
         public ReplicaState(int nodeId, boolean hasAcknowledgedLeader) {
             this.nodeId = nodeId;
             this.endOffset = Optional.empty();
-            this.lastFetchTimestamp = OptionalLong.empty();
-            this.lastFetchLeaderLogEndOffset = OptionalLong.empty();
-            this.lastCaughtUpTimestamp = OptionalLong.empty();
+            this.lastFetchTimestamp = -1;
+            this.lastFetchLeaderLogEndOffset = -1;
+            this.lastCaughtUpTimestamp = -1;
             this.hasAcknowledgedLeader = hasAcknowledgedLeader;
         }
 
-        void updateFetchTimestamp(long currentFetchTimeMs, long leaderLogEndOffset) {
-            // To be resilient to system time shifts we do not strictly
-            // require the timestamp be monotonically increasing.
-            lastFetchTimestamp = OptionalLong.of(Math.max(lastFetchTimestamp.orElse(-1L), currentFetchTimeMs));
-            lastFetchLeaderLogEndOffset = OptionalLong.of(leaderLogEndOffset);
+        void updateLeaderState(
+            LogOffsetMetadata endOffsetMetadata
+        ) {
+            // For the leader, we only update the end offset. The remaining fields
+            // (such as the caught up time) are determined implicitly.
+            this.endOffset = Optional.of(endOffsetMetadata);
         }
 
-        void updateLastCaughtUpTimestamp(long lastCaughtUpTime) {
-            // This value relies on the fetch timestamp which does not
-            // require monotonicity
-            lastCaughtUpTimestamp = OptionalLong.of(Math.max(lastCaughtUpTimestamp.orElse(-1L), lastCaughtUpTime));
+        void updateFollowerState(
+            long currentTimeMs,
+            LogOffsetMetadata fetchOffsetMetadata,
+            Optional<LogOffsetMetadata> leaderEndOffsetOpt
+        ) {
+            // Update the `lastCaughtUpTimestamp` before we update the `lastFetchTimestamp`.
+            // This allows us to use the previous value for `lastFetchTimestamp` if the
+            // follower was able to catch up to `lastFetchLeaderLogEndOffset` on this fetch.
+            leaderEndOffsetOpt.ifPresent(leaderEndOffset -> {
+                if (fetchOffsetMetadata.offset >= leaderEndOffset.offset) {
+                    lastCaughtUpTimestamp = Math.max(lastCaughtUpTimestamp, currentTimeMs);
+                } else if (lastFetchLeaderLogEndOffset > 0
+                    && fetchOffsetMetadata.offset >= lastFetchLeaderLogEndOffset) {
+                    lastCaughtUpTimestamp = Math.max(lastCaughtUpTimestamp, lastFetchTimestamp);
+                }
+                lastFetchLeaderLogEndOffset = leaderEndOffset.offset;
+            });
+
+            lastFetchTimestamp = Math.max(lastFetchTimestamp, currentTimeMs);
+            endOffset = Optional.of(fetchOffsetMetadata);
+            hasAcknowledgedLeader = true;
         }
 
         @Override
diff --git a/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java b/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java
index fa54d5cbc6b..bb44fea2ac0 100644
--- a/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java
@@ -21,7 +21,6 @@ import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.raft.internals.BatchAccumulator;
-
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
@@ -29,16 +28,13 @@ import org.mockito.Mockito;
 
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
-import java.util.stream.Collectors;
 
-import static java.util.Collections.emptyList;
 import static java.util.Collections.emptySet;
 import static java.util.Collections.singleton;
-import static org.apache.kafka.common.utils.Utils.mkEntry;
-import static org.apache.kafka.common.utils.Utils.mkMap;
 import static org.apache.kafka.common.utils.Utils.mkSet;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -103,12 +99,12 @@ public class LeaderStateTest {
     public void testUpdateHighWatermarkQuorumSizeOne() {
         LeaderState<?> state = newLeaderState(singleton(localId), 15L);
         assertEquals(Optional.empty(), state.highWatermark());
-        assertFalse(state.updateLocalState(0, new LogOffsetMetadata(15L)));
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(15L)));
         assertEquals(emptySet(), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
-        assertTrue(state.updateLocalState(0, new LogOffsetMetadata(16L)));
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(16L)));
         assertEquals(Optional.of(new LogOffsetMetadata(16L)), state.highWatermark());
-        assertTrue(state.updateLocalState(0, new LogOffsetMetadata(20)));
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(20)));
         assertEquals(Optional.of(new LogOffsetMetadata(20L)), state.highWatermark());
     }
 
@@ -116,10 +112,10 @@ public class LeaderStateTest {
     public void testNonMonotonicLocalEndOffsetUpdate() {
         LeaderState<?> state = newLeaderState(singleton(localId), 15L);
         assertEquals(Optional.empty(), state.highWatermark());
-        assertTrue(state.updateLocalState(0, new LogOffsetMetadata(16L)));
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(16L)));
         assertEquals(Optional.of(new LogOffsetMetadata(16L)), state.highWatermark());
         assertThrows(IllegalStateException.class,
-            () -> state.updateLocalState(0, new LogOffsetMetadata(15L)));
+            () -> state.updateLocalState(new LogOffsetMetadata(15L)));
     }
 
     @Test
@@ -128,49 +124,51 @@ public class LeaderStateTest {
         int node2 = 2;
         int currentTime = 1000;
         int fetchTime = 0;
-        int caughtupTime = -1;
+        int caughtUpTime = -1;
         LeaderState<?> state = newLeaderState(mkSet(localId, node1, node2), 10L);
         assertEquals(Optional.empty(), state.highWatermark());
-        assertFalse(state.updateLocalState(++fetchTime, new LogOffsetMetadata(10L)));
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(10L)));
         assertEquals(mkSet(node1, node2), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
 
         // Node 1 falls behind
-        assertFalse(state.updateLocalState(++fetchTime, new LogOffsetMetadata(10L)));
-        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(10L), 11L));
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtupTime, state.quorumResponseVoterStates(currentTime).get(node1).lastCaughtUpTimestamp());
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(11L)));
+        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(10L)));
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeVoterState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 catches up to leader
-        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(11L), 11L));
-        caughtupTime = fetchTime;
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtupTime, state.quorumResponseVoterStates(currentTime).get(node1).lastCaughtUpTimestamp());
+        assertTrue(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(11L)));
+        caughtUpTime = fetchTime;
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeVoterState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 falls behind
-        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(50L), 100L));
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtupTime, state.quorumResponseVoterStates(currentTime).get(node1).lastCaughtUpTimestamp());
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(100L)));
+        assertTrue(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(50L)));
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeVoterState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 catches up to the last fetch offset
         int prevFetchTime = fetchTime;
-        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(102L), 200L));
-        caughtupTime = prevFetchTime;
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtupTime, state.quorumResponseVoterStates(currentTime).get(node1).lastCaughtUpTimestamp());
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(200L)));
+        assertTrue(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(100L)));
+        caughtUpTime = prevFetchTime;
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeVoterState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node2 has never caught up to leader
-        assertEquals(-1L, state.quorumResponseVoterStates(currentTime).get(node2).lastCaughtUpTimestamp());
-        assertTrue(state.updateReplicaState(node2, ++fetchTime, new LogOffsetMetadata(202L), 300L));
-        assertEquals(-1L, state.quorumResponseVoterStates(currentTime).get(node2).lastCaughtUpTimestamp());
-        assertFalse(state.updateReplicaState(node2, ++fetchTime, new LogOffsetMetadata(250L), 300L));
-        assertEquals(-1L, state.quorumResponseVoterStates(currentTime).get(node2).lastCaughtUpTimestamp());
+        assertEquals(-1L, describeVoterState(state, node2, currentTime).lastCaughtUpTimestamp());
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(300L)));
+        assertTrue(state.updateReplicaState(node2, ++fetchTime, new LogOffsetMetadata(200L)));
+        assertEquals(-1L, describeVoterState(state, node2, currentTime).lastCaughtUpTimestamp());
+        assertTrue(state.updateReplicaState(node2, ++fetchTime, new LogOffsetMetadata(250L)));
+        assertEquals(-1L, describeVoterState(state, node2, currentTime).lastCaughtUpTimestamp());
     }
 
     @Test
     public void testLastCaughtUpTimeObserver() {
-        int node1Index = 0;
-        int node1Id = 1;
+        int node1 = 1;
         int currentTime = 1000;
         int fetchTime = 0;
         int caughtUpTime = -1;
@@ -179,42 +177,44 @@ public class LeaderStateTest {
         assertEquals(emptySet(), state.nonAcknowledgingVoters());
 
         // Node 1 falls behind
-        assertTrue(state.updateLocalState(++fetchTime, new LogOffsetMetadata(10L)));
-        assertFalse(state.updateReplicaState(node1Id, ++fetchTime, new LogOffsetMetadata(10L), 11L));
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtUpTime, state.quorumResponseObserverStates(currentTime).get(node1Index).lastCaughtUpTimestamp());
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(11L)));
+        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(10L)));
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeObserverState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 catches up to leader
-        assertFalse(state.updateReplicaState(node1Id, ++fetchTime, new LogOffsetMetadata(11L), 11L));
+        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(11L)));
         caughtUpTime = fetchTime;
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtUpTime, state.quorumResponseObserverStates(currentTime).get(node1Index).lastCaughtUpTimestamp());
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeObserverState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 falls behind
-        assertFalse(state.updateReplicaState(node1Id, ++fetchTime, new LogOffsetMetadata(50L), 100L));
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtUpTime, state.quorumResponseObserverStates(currentTime).get(node1Index).lastCaughtUpTimestamp());
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(100L)));
+        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(50L)));
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeObserverState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 catches up to the last fetch offset
         int prevFetchTime = fetchTime;
-        assertFalse(state.updateReplicaState(node1Id, ++fetchTime, new LogOffsetMetadata(102L), 200L));
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(200L)));
+        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(102L)));
         caughtUpTime = prevFetchTime;
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtUpTime, state.quorumResponseObserverStates(currentTime).get(node1Index).lastCaughtUpTimestamp());
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeObserverState(state, node1, currentTime).lastCaughtUpTimestamp());
 
         // Node 1 catches up to leader
-        assertFalse(state.updateReplicaState(node1Id, ++fetchTime, new LogOffsetMetadata(202L), 200L));
+        assertFalse(state.updateReplicaState(node1, ++fetchTime, new LogOffsetMetadata(200L)));
         caughtUpTime = fetchTime;
-        assertEquals(currentTime, state.quorumResponseVoterStates(currentTime).get(localId).lastCaughtUpTimestamp());
-        assertEquals(caughtUpTime, state.quorumResponseObserverStates(currentTime).get(node1Index).lastCaughtUpTimestamp());
+        assertEquals(currentTime, describeVoterState(state, localId, currentTime).lastCaughtUpTimestamp());
+        assertEquals(caughtUpTime, describeObserverState(state, node1, currentTime).lastCaughtUpTimestamp());
     }
 
     @Test
     public void testIdempotentEndOffsetUpdate() {
         LeaderState<?> state = newLeaderState(singleton(localId), 15L);
         assertEquals(Optional.empty(), state.highWatermark());
-        assertTrue(state.updateLocalState(0, new LogOffsetMetadata(16L)));
-        assertFalse(state.updateLocalState(0, new LogOffsetMetadata(16L)));
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(16L)));
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(16L)));
         assertEquals(Optional.of(new LogOffsetMetadata(16L)), state.highWatermark());
     }
 
@@ -224,11 +224,11 @@ public class LeaderStateTest {
         assertEquals(Optional.empty(), state.highWatermark());
 
         LogOffsetMetadata initialHw = new LogOffsetMetadata(16L, Optional.of(new MockOffsetMetadata("bar")));
-        assertTrue(state.updateLocalState(0, initialHw));
+        assertTrue(state.updateLocalState(initialHw));
         assertEquals(Optional.of(initialHw), state.highWatermark());
 
         LogOffsetMetadata updateHw = new LogOffsetMetadata(16L, Optional.of(new MockOffsetMetadata("baz")));
-        assertTrue(state.updateLocalState(0, updateHw));
+        assertTrue(state.updateLocalState(updateHw));
         assertEquals(Optional.of(updateHw), state.highWatermark());
     }
 
@@ -236,15 +236,15 @@ public class LeaderStateTest {
     public void testUpdateHighWatermarkQuorumSizeTwo() {
         int otherNodeId = 1;
         LeaderState<?> state = newLeaderState(mkSet(localId, otherNodeId), 10L);
-        assertFalse(state.updateLocalState(0, new LogOffsetMetadata(13L)));
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(13L)));
         assertEquals(singleton(otherNodeId), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
-        assertFalse(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(10L), 11L));
+        assertFalse(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(10L)));
         assertEquals(emptySet(), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
-        assertTrue(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(11L), 12L));
+        assertTrue(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(11L)));
         assertEquals(Optional.of(new LogOffsetMetadata(11L)), state.highWatermark());
-        assertTrue(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(13L), 14L));
+        assertTrue(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(13L)));
         assertEquals(Optional.of(new LogOffsetMetadata(13L)), state.highWatermark());
     }
 
@@ -253,22 +253,22 @@ public class LeaderStateTest {
         int node1 = 1;
         int node2 = 2;
         LeaderState<?> state = newLeaderState(mkSet(localId, node1, node2), 10L);
-        assertFalse(state.updateLocalState(0, new LogOffsetMetadata(15L)));
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(15L)));
         assertEquals(mkSet(node1, node2), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
-        assertFalse(state.updateReplicaState(node1, 0, new LogOffsetMetadata(10L), 11L));
+        assertFalse(state.updateReplicaState(node1, 0, new LogOffsetMetadata(10L)));
         assertEquals(singleton(node2), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
-        assertFalse(state.updateReplicaState(node2, 0, new LogOffsetMetadata(10L), 11L));
+        assertFalse(state.updateReplicaState(node2, 0, new LogOffsetMetadata(10L)));
         assertEquals(emptySet(), state.nonAcknowledgingVoters());
         assertEquals(Optional.empty(), state.highWatermark());
-        assertTrue(state.updateReplicaState(node2, 0, new LogOffsetMetadata(15L), 16L));
+        assertTrue(state.updateReplicaState(node2, 0, new LogOffsetMetadata(15L)));
         assertEquals(Optional.of(new LogOffsetMetadata(15L)), state.highWatermark());
-        assertFalse(state.updateLocalState(0, new LogOffsetMetadata(20L)));
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(20L)));
         assertEquals(Optional.of(new LogOffsetMetadata(15L)), state.highWatermark());
-        assertTrue(state.updateReplicaState(node1, 0, new LogOffsetMetadata(20L), 21L));
+        assertTrue(state.updateReplicaState(node1, 0, new LogOffsetMetadata(20L)));
         assertEquals(Optional.of(new LogOffsetMetadata(20L)), state.highWatermark());
-        assertFalse(state.updateReplicaState(node2, 0, new LogOffsetMetadata(20L), 21L));
+        assertFalse(state.updateReplicaState(node2, 0, new LogOffsetMetadata(20L)));
         assertEquals(Optional.of(new LogOffsetMetadata(20L)), state.highWatermark());
     }
 
@@ -277,14 +277,14 @@ public class LeaderStateTest {
         MockTime time = new MockTime();
         int node1 = 1;
         LeaderState<?> state = newLeaderState(mkSet(localId, node1), 0L);
-        state.updateLocalState(time.milliseconds(), new LogOffsetMetadata(10L));
-        state.updateReplicaState(node1, time.milliseconds(), new LogOffsetMetadata(10L), 11L);
+        state.updateLocalState(new LogOffsetMetadata(10L));
+        state.updateReplicaState(node1, time.milliseconds(), new LogOffsetMetadata(10L));
         assertEquals(Optional.of(new LogOffsetMetadata(10L)), state.highWatermark());
 
         // Follower crashes and disk is lost. It fetches an earlier offset to rebuild state.
         // The leader will report an error in the logs, but will not let the high watermark rewind
-        assertFalse(state.updateReplicaState(node1, time.milliseconds(), new LogOffsetMetadata(5L), 11L));
-        assertEquals(5L, state.quorumResponseVoterStates(time.milliseconds()).get(node1).logEndOffset());
+        assertFalse(state.updateReplicaState(node1, time.milliseconds(), new LogOffsetMetadata(5L)));
+        assertEquals(5L, describeVoterState(state, node1, time.milliseconds()).logEndOffset());
         assertEquals(Optional.of(new LogOffsetMetadata(10L)), state.highWatermark());
     }
 
@@ -302,21 +302,102 @@ public class LeaderStateTest {
     }
 
     @Test
-    public void testGetVoterStates() {
-        int node1 = 1;
-        int node2 = 2;
+    public void testDescribeQuorumWithSingleVoter() {
+        MockTime time = new MockTime();
         long leaderStartOffset = 10L;
         long leaderEndOffset = 15L;
 
-        LeaderState<?> state = setUpLeaderAndFollowers(node1, node2, leaderStartOffset, leaderEndOffset);
+        LeaderState<?> state = newLeaderState(mkSet(localId), leaderStartOffset);
+
+        // Until we have updated local state, high watermark should be uninitialized
+        assertEquals(Optional.empty(), state.highWatermark());
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(time.milliseconds());
+        assertEquals(-1, partitionData.highWatermark());
+        assertEquals(localId, partitionData.leaderId());
+        assertEquals(epoch, partitionData.leaderEpoch());
+        assertEquals(Collections.emptyList(), partitionData.observers());
+        assertEquals(1, partitionData.currentVoters().size());
+        assertEquals(new DescribeQuorumResponseData.ReplicaState()
+                .setReplicaId(localId)
+                .setLogEndOffset(-1)
+                .setLastFetchTimestamp(time.milliseconds())
+                .setLastCaughtUpTimestamp(time.milliseconds()),
+            partitionData.currentVoters().get(0));
+
+
+        // Now update the high watermark and verify the describe output
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(leaderEndOffset)));
+        assertEquals(Optional.of(new LogOffsetMetadata(leaderEndOffset)), state.highWatermark());
+
+        time.sleep(500);
+
+        partitionData = state.describeQuorum(time.milliseconds());
+        assertEquals(leaderEndOffset, partitionData.highWatermark());
+        assertEquals(localId, partitionData.leaderId());
+        assertEquals(epoch, partitionData.leaderEpoch());
+        assertEquals(Collections.emptyList(), partitionData.observers());
+        assertEquals(1, partitionData.currentVoters().size());
+        assertEquals(new DescribeQuorumResponseData.ReplicaState()
+                .setReplicaId(localId)
+                .setLogEndOffset(leaderEndOffset)
+                .setLastFetchTimestamp(time.milliseconds())
+                .setLastCaughtUpTimestamp(time.milliseconds()),
+            partitionData.currentVoters().get(0));
+    }
+
+    @Test
+    public void testDescribeQuorumWithMultipleVoters() {
+        MockTime time = new MockTime();
+        int activeFollowerId = 1;
+        int inactiveFollowerId = 2;
+        long leaderStartOffset = 10L;
+        long leaderEndOffset = 15L;
+
+        LeaderState<?> state = newLeaderState(mkSet(localId, activeFollowerId, inactiveFollowerId), leaderStartOffset);
+        assertFalse(state.updateLocalState(new LogOffsetMetadata(leaderEndOffset)));
+        assertEquals(Optional.empty(), state.highWatermark());
 
-        assertEquals(mkMap(
-            mkEntry(localId, leaderEndOffset),
-            mkEntry(node1, leaderStartOffset),
-            mkEntry(node2, leaderEndOffset)
-        ), state.quorumResponseVoterStates(0)
-            .stream()
-            .collect(Collectors.toMap(DescribeQuorumResponseData.ReplicaState::replicaId, DescribeQuorumResponseData.ReplicaState::logEndOffset)));
+        long activeFollowerFetchTimeMs = time.milliseconds();
+        assertTrue(state.updateReplicaState(activeFollowerId, activeFollowerFetchTimeMs, new LogOffsetMetadata(leaderEndOffset)));
+        assertEquals(Optional.of(new LogOffsetMetadata(leaderEndOffset)), state.highWatermark());
+
+        time.sleep(500);
+
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(time.milliseconds());
+        assertEquals(leaderEndOffset, partitionData.highWatermark());
+        assertEquals(localId, partitionData.leaderId());
+        assertEquals(epoch, partitionData.leaderEpoch());
+        assertEquals(Collections.emptyList(), partitionData.observers());
+
+        List<DescribeQuorumResponseData.ReplicaState> voterStates = partitionData.currentVoters();
+        assertEquals(3, voterStates.size());
+
+        DescribeQuorumResponseData.ReplicaState leaderState =
+            findReplicaOrFail(localId, partitionData.currentVoters());
+        assertEquals(new DescribeQuorumResponseData.ReplicaState()
+                .setReplicaId(localId)
+                .setLogEndOffset(leaderEndOffset)
+                .setLastFetchTimestamp(time.milliseconds())
+                .setLastCaughtUpTimestamp(time.milliseconds()),
+            leaderState);
+
+        DescribeQuorumResponseData.ReplicaState activeFollowerState =
+            findReplicaOrFail(activeFollowerId, partitionData.currentVoters());
+        assertEquals(new DescribeQuorumResponseData.ReplicaState()
+                .setReplicaId(activeFollowerId)
+                .setLogEndOffset(leaderEndOffset)
+                .setLastFetchTimestamp(activeFollowerFetchTimeMs)
+                .setLastCaughtUpTimestamp(activeFollowerFetchTimeMs),
+            activeFollowerState);
+
+        DescribeQuorumResponseData.ReplicaState inactiveFollowerState =
+            findReplicaOrFail(inactiveFollowerId, partitionData.currentVoters());
+        assertEquals(new DescribeQuorumResponseData.ReplicaState()
+                .setReplicaId(inactiveFollowerId)
+                .setLogEndOffset(-1)
+                .setLastFetchTimestamp(-1)
+                .setLastCaughtUpTimestamp(-1),
+            inactiveFollowerState);
     }
 
     private LeaderState<?> setUpLeaderAndFollowers(int follower1,
@@ -324,37 +405,60 @@ public class LeaderStateTest {
                                                    long leaderStartOffset,
                                                    long leaderEndOffset) {
         LeaderState<?> state = newLeaderState(mkSet(localId, follower1, follower2), leaderStartOffset);
-        state.updateLocalState(0, new LogOffsetMetadata(leaderEndOffset));
+        state.updateLocalState(new LogOffsetMetadata(leaderEndOffset));
         assertEquals(Optional.empty(), state.highWatermark());
-        state.updateReplicaState(follower1, 0, new LogOffsetMetadata(leaderStartOffset), leaderEndOffset);
-        state.updateReplicaState(follower2, 0, new LogOffsetMetadata(leaderEndOffset), leaderEndOffset);
+        state.updateReplicaState(follower1, 0, new LogOffsetMetadata(leaderStartOffset));
+        state.updateReplicaState(follower2, 0, new LogOffsetMetadata(leaderEndOffset));
         return state;
     }
 
     @Test
-    public void testGetObserverStatesWithObserver() {
+    public void testDescribeQuorumWithObservers() {
+        MockTime time = new MockTime();
         int observerId = 10;
         long epochStartOffset = 10L;
 
         LeaderState<?> state = newLeaderState(mkSet(localId), epochStartOffset);
-        long timestamp = 20L;
-        assertFalse(state.updateReplicaState(observerId, timestamp, new LogOffsetMetadata(epochStartOffset), epochStartOffset + 10));
-
-        assertEquals(Collections.singletonMap(observerId, epochStartOffset),
-                state.quorumResponseObserverStates(timestamp)
-                    .stream()
-                    .collect(Collectors.toMap(DescribeQuorumResponseData.ReplicaState::replicaId, DescribeQuorumResponseData.ReplicaState::logEndOffset)));
+        assertTrue(state.updateLocalState(new LogOffsetMetadata(epochStartOffset + 1)));
+        assertEquals(Optional.of(new LogOffsetMetadata(epochStartOffset + 1)), state.highWatermark());
+
+        time.sleep(500);
+        long observerFetchTimeMs = time.milliseconds();
+        assertFalse(state.updateReplicaState(observerId, observerFetchTimeMs, new LogOffsetMetadata(epochStartOffset + 1)));
+
+        time.sleep(500);
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(time.milliseconds());
+        assertEquals(epochStartOffset + 1, partitionData.highWatermark());
+        assertEquals(localId, partitionData.leaderId());
+        assertEquals(epoch, partitionData.leaderEpoch());
+
+        assertEquals(1, partitionData.currentVoters().size());
+        assertEquals(localId, partitionData.currentVoters().get(0).replicaId());
+
+        List<DescribeQuorumResponseData.ReplicaState> observerStates = partitionData.observers();
+        assertEquals(1, observerStates.size());
+
+        DescribeQuorumResponseData.ReplicaState observerState = observerStates.get(0);
+        assertEquals(new DescribeQuorumResponseData.ReplicaState()
+                .setReplicaId(observerId)
+                .setLogEndOffset(epochStartOffset + 1)
+                .setLastFetchTimestamp(observerFetchTimeMs)
+                .setLastCaughtUpTimestamp(observerFetchTimeMs),
+            observerState);
     }
 
     @Test
     public void testNoOpForNegativeRemoteNodeId() {
-        int observerId = -1;
+        MockTime time = new MockTime();
+        int replicaId = -1;
         long epochStartOffset = 10L;
 
         LeaderState<?> state = newLeaderState(mkSet(localId), epochStartOffset);
-        assertFalse(state.updateReplicaState(observerId, 0, new LogOffsetMetadata(epochStartOffset), epochStartOffset + 10));
+        assertFalse(state.updateReplicaState(replicaId, 0, new LogOffsetMetadata(epochStartOffset)));
 
-        assertEquals(emptyList(), state.quorumResponseObserverStates(10));
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(time.milliseconds());
+        List<DescribeQuorumResponseData.ReplicaState> observerStates = partitionData.observers();
+        assertEquals(Collections.emptyList(), observerStates);
     }
 
     @Test
@@ -364,14 +468,17 @@ public class LeaderStateTest {
         long epochStartOffset = 10L;
         LeaderState<?> state = newLeaderState(mkSet(localId), epochStartOffset);
 
-        state.updateReplicaState(observerId, time.milliseconds(), new LogOffsetMetadata(epochStartOffset), epochStartOffset + 10);
-        assertEquals(singleton(observerId),
-                state.quorumResponseObserverStates(time.milliseconds())
-                    .stream().map(o -> o.replicaId())
-                    .collect(Collectors.toSet()));
+        state.updateReplicaState(observerId, time.milliseconds(), new LogOffsetMetadata(epochStartOffset));
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(time.milliseconds());
+        List<DescribeQuorumResponseData.ReplicaState> observerStates = partitionData.observers();
+        assertEquals(1, observerStates.size());
+
+        DescribeQuorumResponseData.ReplicaState observerState = observerStates.get(0);
+        assertEquals(observerId, observerState.replicaId());
 
         time.sleep(LeaderState.OBSERVER_SESSION_TIMEOUT_MS);
-        assertEquals(emptyList(), state.quorumResponseObserverStates(time.milliseconds()));
+        partitionData = state.describeQuorum(time.milliseconds());
+        assertEquals(Collections.emptyList(), partitionData.observers());
     }
 
     @ParameterizedTest
@@ -405,4 +512,34 @@ public class LeaderStateTest {
         }
     }
 
+    private DescribeQuorumResponseData.ReplicaState describeVoterState(
+        LeaderState<?> state,
+        int voterId,
+        long currentTimeMs
+    ) {
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(currentTimeMs);
+        return findReplicaOrFail(voterId, partitionData.currentVoters()); // AssertionError if not a voter
+    }
+
+    private DescribeQuorumResponseData.ReplicaState describeObserverState(
+        LeaderState<?> state,
+        int observerId,
+        long currentTimeMs
+    ) {
+        DescribeQuorumResponseData.PartitionData partitionData = state.describeQuorum(currentTimeMs);
+        return findReplicaOrFail(observerId, partitionData.observers()); // AssertionError if not an observer
+    }
+
+    private DescribeQuorumResponseData.ReplicaState findReplicaOrFail(
+        int replicaId,
+        List<DescribeQuorumResponseData.ReplicaState> replicas
+    ) {
+        for (DescribeQuorumResponseData.ReplicaState candidate : replicas) {
+            if (candidate.replicaId() == replicaId) {
+                return candidate; // first match in list order wins
+            }
+        }
+        throw new AssertionError("Failed to find expected replica state for replica " + replicaId);
+    }
+
 }
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
index b825fc8867a..3af4ba75dfd 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
@@ -458,13 +458,18 @@ public final class RaftClientTestContext {
         List<ReplicaState> observerStates
     ) {
         DescribeQuorumResponseData response = collectDescribeQuorumResponse();
+
+        DescribeQuorumResponseData.PartitionData partitionData = new DescribeQuorumResponseData.PartitionData()
+            .setErrorCode(Errors.NONE.code())
+            .setLeaderId(leaderId)
+            .setLeaderEpoch(leaderEpoch)
+            .setHighWatermark(highWatermark)
+            .setCurrentVoters(voterStates)
+            .setObservers(observerStates);
         DescribeQuorumResponseData expectedResponse = DescribeQuorumResponse.singletonResponse(
             metadataPartition,
-            leaderId,
-            leaderEpoch,
-            highWatermark,
-            voterStates,
-            observerStates);
+            partitionData
+        );
         assertEquals(expectedResponse, response);
     }
 
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java
index cc2700bb17a..d362afc574f 100644
--- a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java
@@ -102,8 +102,8 @@ public class KafkaRaftMetricsTest {
         assertEquals((double) 1, getMetric(metrics, "current-epoch").metricValue());
         assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue());
 
-        state.leaderStateOrThrow().updateLocalState(0, new LogOffsetMetadata(5L));
-        state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L), 6L);
+        state.leaderStateOrThrow().updateLocalState(new LogOffsetMetadata(5L));
+        state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L));
         assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue());
 
         state.transitionToFollower(2, 1);