Posted to commits@kafka.apache.org by jg...@apache.org on 2019/06/12 15:50:36 UTC

[kafka] branch 2.3 updated: KAFKA-8500; Static member rejoin should always update member.id (#6899)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch 2.3
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/2.3 by this push:
     new c796518  KAFKA-8500; Static member rejoin should always update member.id (#6899)
c796518 is described below

commit c7965187a6953435b13727b764d8c9d6dbe9b210
Author: Boyang Chen <bo...@confluent.io>
AuthorDate: Wed Jun 12 08:41:58 2019 -0700

    KAFKA-8500; Static member rejoin should always update member.id (#6899)
    
    This PR fixes a bug in static group membership. Previously we limited the `member.id` replacement in JoinGroup to cases where the group was in the Stable state. This was error-prone and could allow duplicate consumers reading from the same topic. For example, imagine a case where two unknown members join in the `PreparingRebalance` stage at the same time.
    
    The PR fixes the following things:
    
    1. Replace the `member.id` whenever a known static member rejoins the group with an unknown member.id.
    2. Immediately fence any ongoing join/sync group callback so that the duplicate member terminates early (a sketch of this flow follows the list).
    3. Explicitly handle the Dead/Empty cases as exceptional.
    4. Return the old leader id when the static leader rejoins, to avoid triggering a trivial member assignment.
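    
    A minimal, self-contained Scala sketch of the fence-on-replace idea in points 1 and 2 (hypothetical names, modeled on the `replaceGroupInstance` change in the diff below; an illustration, not Kafka's actual classes):
    
        object StaticFenceSketch {
          import scala.collection.mutable
    
          sealed trait SyncError
          case object FencedInstanceId extends SyncError
    
          final class Member(var memberId: String) {
            // A non-null callback means a consumer is still blocked on SyncGroup.
            var awaitingSyncCallback: SyncError => Unit = null
            def isAwaitingSync: Boolean = awaitingSyncCallback != null
          }
    
          // Invoke-once: deliver the result, clear the callback, report whether it fired.
          def maybeInvokeSyncCallback(member: Member, error: SyncError): Boolean =
            if (member.isAwaitingSync) {
              member.awaitingSyncCallback(error)
              member.awaitingSyncCallback = null
              true
            } else false
    
          // Rebind the static instance to a new member.id; whoever is still waiting
          // under the old id is a duplicate consumer and is failed immediately.
          def replaceGroupInstance(members: mutable.Map[String, Member],
                                   oldMemberId: String,
                                   newMemberId: String): Member = {
            val member = members.remove(oldMemberId)
              .getOrElse(throw new IllegalArgumentException(s"Cannot replace non-existing member id $oldMemberId"))
            maybeInvokeSyncCallback(member, FencedInstanceId)
            member.memberId = newMemberId
            members.put(newMemberId, member)
            member
          }
    
          def main(args: Array[String]): Unit = {
            val members = mutable.Map("old-id" -> new Member("old-id"))
            members("old-id").awaitingSyncCallback = e => println(s"duplicate member sees: $e")
            replaceGroupInstance(members, "old-id", "new-id") // prints: duplicate member sees: FencedInstanceId
            println(members.keys.mkString(", "))              // new-id
          }
        }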
    
    Reviewers: Guozhang Wang <wa...@gmail.com>, Jason Gustafson <ja...@confluent.io>
---
 .../kafka/coordinator/group/GroupCoordinator.scala | 110 ++++++-------
 .../kafka/coordinator/group/GroupMetadata.scala    |  34 +++-
 .../kafka/coordinator/group/MemberMetadata.scala   |   3 +-
 core/src/main/scala/kafka/server/KafkaApis.scala   |  12 +-
 .../group/GroupCoordinatorConcurrencyTest.scala    |   6 +-
 .../coordinator/group/GroupCoordinatorTest.scala   | 173 +++++++++++++++++++--
 .../coordinator/group/GroupMetadataTest.scala      |  95 ++++++++++-
 tests/kafkatest/tests/client/consumer_test.py      |   6 +
 8 files changed, 352 insertions(+), 87 deletions(-)

diff --git a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
index 52f1a98..07d7435 100644
--- a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
+++ b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala
@@ -58,7 +58,7 @@ class GroupCoordinator(val brokerId: Int,
   import GroupCoordinator._
 
   type JoinCallback = JoinGroupResult => Unit
-  type SyncCallback = (Array[Byte], Errors) => Unit
+  type SyncCallback = SyncGroupResult => Unit
 
   this.logIdent = "[GroupCoordinator " + brokerId + "]: "
 
@@ -179,31 +179,38 @@ class GroupCoordinator(val brokerId: Int,
 
         if (group.hasStaticMember(groupInstanceId)) {
           val oldMemberId = group.getStaticMemberId(groupInstanceId)
+          info(s"Static member $groupInstanceId with unknown member id rejoins, assigning new member id $newMemberId, while " +
+            s"old member $oldMemberId will be removed.")
 
-          if (group.is(Stable)) {
-            info(s"Static member $groupInstanceId with unknown member id rejoins, assigning new member id $newMemberId, while " +
-              s"old member $oldMemberId will be removed. No rebalance will be triggered.")
+          val currentLeader = group.leaderOrNull
+          val member = group.replaceGroupInstance(oldMemberId, newMemberId, groupInstanceId)
+          // Heartbeat of old member id will expire without effect since the group no longer contains that member id.
+          // New heartbeat shall be scheduled with new member id.
+          completeAndScheduleNextHeartbeatExpiration(group, member)
 
-            val oldMember = group.replaceGroupInstance(oldMemberId, newMemberId, groupInstanceId)
+          val knownStaticMember = group.get(newMemberId)
+          group.updateMember(knownStaticMember, protocols, responseCallback)
 
-            // Heartbeat of old member id will expire without affection since the group no longer contains that member id.
-            // New heartbeat shall be scheduled with new member id.
-            completeAndScheduleNextHeartbeatExpiration(group, oldMember)
-
-            responseCallback(JoinGroupResult(
-              members = if (group.isLeader(newMemberId)) {
-                group.currentMemberMetadata
-              } else {
-                List.empty
-              },
-              memberId = newMemberId,
-              generationId = group.generationId,
-              subProtocol = group.protocolOrNull,
-              leaderId = group.leaderOrNull,
-              error = Errors.NONE))
-          } else {
-            val knownStaticMember = group.get(oldMemberId)
-            updateMemberAndRebalance(group, knownStaticMember, protocols, responseCallback)
+          group.currentState match {
+            case Stable | CompletingRebalance =>
+              info(s"Static member joins during ${group.currentState} stage will not trigger rebalance.")
+              group.maybeInvokeJoinCallback(member, JoinGroupResult(
+                members = List.empty,
+                memberId = newMemberId,
+                generationId = group.generationId,
+                subProtocol = group.protocolOrNull,
+                // We want to avoid the current leader performing a trivial assignment while the
+                // group is in the stable/awaiting sync stage, because the new assignment from the
+                // leader's next sync call won't be broadcast by a stable/awaiting sync group. This is
+                // guaranteed by always returning the old leader id, so that the current leader won't
+                // assume itself to be the leader based on the returned message: the new member.id
+                // won't match the returned leader id, and therefore no assignment will be performed.
+                leaderId = currentLeader,
+                error = Errors.NONE))
+            case Empty | Dead =>
+              throw new IllegalStateException(s"Group ${group.groupId} was not supposed to be " +
+                s"in the state ${group.currentState} when the unknown static member $groupInstanceId rejoins.")
+            case PreparingRebalance =>
           }
         } else if (requireKnownMemberId) {
             // If member id required (dynamic membership), register the member in the pending member list
@@ -214,7 +221,6 @@ class GroupCoordinator(val brokerId: Int,
         } else {
           addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, newMemberId, groupInstanceId,
             clientId, clientHost, protocolType, protocols, group, responseCallback)
-
         }
       }
     }
@@ -328,13 +334,13 @@ class GroupCoordinator(val brokerId: Int,
         // group will need to start over at JoinGroup. By returning rebalance in progress, the consumer
         // will attempt to rejoin without needing to rediscover the coordinator. Note that we cannot
         // return COORDINATOR_LOAD_IN_PROGRESS since older clients do not expect the error.
-        responseCallback(Array.empty, Errors.REBALANCE_IN_PROGRESS)
+        responseCallback(SyncGroupResult(Array.empty, Errors.REBALANCE_IN_PROGRESS))
 
-      case Some(error) => responseCallback(Array.empty, error)
+      case Some(error) => responseCallback(SyncGroupResult(Array.empty, error))
 
       case None =>
         groupManager.getGroup(groupId) match {
-          case None => responseCallback(Array.empty, Errors.UNKNOWN_MEMBER_ID)
+          case None => responseCallback(SyncGroupResult(Array.empty, Errors.UNKNOWN_MEMBER_ID))
           case Some(group) => doSyncGroup(group, generation, memberId, groupInstanceId, groupAssignment, responseCallback)
         }
     }
@@ -352,20 +358,20 @@ class GroupCoordinator(val brokerId: Int,
         // from the coordinator metadata; this is likely that the group has migrated to some other
         // coordinator OR the group is in a transient unstable phase. Let the member retry
         // finding the correct coordinator and rejoin.
-        responseCallback(Array.empty, Errors.COORDINATOR_NOT_AVAILABLE)
+        responseCallback(SyncGroupResult(Array.empty, Errors.COORDINATOR_NOT_AVAILABLE))
       } else if (group.isStaticMemberFenced(memberId, groupInstanceId)) {
-        responseCallback(Array.empty, Errors.FENCED_INSTANCE_ID)
+        responseCallback(SyncGroupResult(Array.empty, Errors.FENCED_INSTANCE_ID))
       } else if (!group.has(memberId)) {
-        responseCallback(Array.empty, Errors.UNKNOWN_MEMBER_ID)
+        responseCallback(SyncGroupResult(Array.empty, Errors.UNKNOWN_MEMBER_ID))
       } else if (generationId != group.generationId) {
-        responseCallback(Array.empty, Errors.ILLEGAL_GENERATION)
+        responseCallback(SyncGroupResult(Array.empty, Errors.ILLEGAL_GENERATION))
       } else {
         group.currentState match {
           case Empty =>
-            responseCallback(Array.empty, Errors.UNKNOWN_MEMBER_ID)
+            responseCallback(SyncGroupResult(Array.empty, Errors.UNKNOWN_MEMBER_ID))
 
           case PreparingRebalance =>
-            responseCallback(Array.empty, Errors.REBALANCE_IN_PROGRESS)
+            responseCallback(SyncGroupResult(Array.empty, Errors.REBALANCE_IN_PROGRESS))
 
           case CompletingRebalance =>
             group.get(memberId).awaitingSyncCallback = responseCallback
@@ -399,7 +405,7 @@ class GroupCoordinator(val brokerId: Int,
           case Stable =>
             // if the group is stable, we just return the current assignment
             val memberMetadata = group.get(memberId)
-            responseCallback(memberMetadata.assignment, Errors.NONE)
+            responseCallback(SyncGroupResult(memberMetadata.assignment, Errors.NONE))
             completeAndScheduleNextHeartbeatExpiration(group, group.get(memberId))
         }
       }
@@ -463,7 +469,7 @@ class GroupCoordinator(val brokerId: Int,
                   case Empty =>
                     group.transitionTo(Dead)
                     groupsEligibleForDeletion :+= group
-                  case _ =>
+                  case Stable | PreparingRebalance | CompletingRebalance =>
                     groupErrors += groupId -> Errors.NON_EMPTY_GROUP
                 }
               }
@@ -708,10 +714,7 @@ class GroupCoordinator(val brokerId: Int,
 
         case Stable | CompletingRebalance =>
           for (member <- group.allMemberMetadata) {
-            if (member.awaitingSyncCallback != null) {
-              member.awaitingSyncCallback(Array.empty[Byte], Errors.NOT_COORDINATOR)
-              member.awaitingSyncCallback = null
-            }
+            group.maybeInvokeSyncCallback(member, SyncGroupResult(Array.empty, Errors.NOT_COORDINATOR))
             heartbeatPurgatory.checkAndComplete(MemberKey(member.groupId, member.memberId))
           }
       }
@@ -746,16 +749,13 @@ class GroupCoordinator(val brokerId: Int,
 
   private def resetAndPropagateAssignmentError(group: GroupMetadata, error: Errors) {
     assert(group.is(CompletingRebalance))
-    group.allMemberMetadata.foreach(_.assignment = Array.empty[Byte])
+    group.allMemberMetadata.foreach(_.assignment = Array.empty)
     propagateAssignment(group, error)
   }
 
   private def propagateAssignment(group: GroupMetadata, error: Errors) {
     for (member <- group.allMemberMetadata) {
-      if (member.awaitingSyncCallback != null) {
-        member.awaitingSyncCallback(member.assignment, error)
-        member.awaitingSyncCallback = null
-
+      if (group.maybeInvokeSyncCallback(member, SyncGroupResult(member.assignment, error))) {
         // reset the session timeout for members after propagating the member's assignment.
         // This is because if any member's session expired while we were still awaiting either
         // the leader sync group or the storage callback, its expiration will be ignored and no
@@ -765,16 +765,6 @@ class GroupCoordinator(val brokerId: Int,
     }
   }
 
-  private def joinError(memberId: String, error: Errors): JoinGroupResult = {
-    JoinGroupResult(
-      members = List.empty,
-      memberId = memberId,
-      generationId = GroupCoordinator.NoGeneration,
-      subProtocol = GroupCoordinator.NoProtocol,
-      leaderId = GroupCoordinator.NoLeader,
-      error = error)
-  }
-
   /**
    * Complete existing DelayedHeartbeats for the given member and schedule the next one
    */
@@ -1084,6 +1074,15 @@ object GroupCoordinator {
     new GroupCoordinator(config.brokerId, groupConfig, offsetConfig, groupMetadataManager, heartbeatPurgatory, joinPurgatory, time)
   }
 
+  def joinError(memberId: String, error: Errors): JoinGroupResult = {
+    JoinGroupResult(
+      members = List.empty,
+      memberId = memberId,
+      generationId = GroupCoordinator.NoGeneration,
+      subProtocol = GroupCoordinator.NoProtocol,
+      leaderId = GroupCoordinator.NoLeader,
+      error = error)
+  }
 }
 
 case class GroupConfig(groupMinSessionTimeoutMs: Int,
@@ -1097,3 +1096,6 @@ case class JoinGroupResult(members: List[JoinGroupResponseMember],
                            subProtocol: String,
                            leaderId: String,
                            error: Errors)
+
+case class SyncGroupResult(memberAssignment: Array[Byte],
+                           error: Errors)
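
The SyncCallback type change above (a `SyncGroupResult => Unit` in place of the old `(Array[Byte], Errors) => Unit` pair) ripples through every call site in this patch. As a hedged illustration, the adapter below is hypothetical and not part of the commit; it shows how an old tuple-style callback maps onto the new shape:

    import org.apache.kafka.common.protocol.Errors

    object SyncCallbackSketch {
      // Shape of the case class added at the end of GroupCoordinator.scala above.
      case class SyncGroupResult(memberAssignment: Array[Byte], error: Errors)

      type OldSyncCallback = (Array[Byte], Errors) => Unit
      type SyncCallback = SyncGroupResult => Unit

      // The mechanical adaptation that each call site in this commit performs inline.
      def adapt(old: OldSyncCallback): SyncCallback =
        result => old(result.memberAssignment, result.error)

      def main(args: Array[String]): Unit = {
        val legacy: OldSyncCallback = (assignment, error) =>
          println(s"${assignment.length} bytes, error=$error")
        adapt(legacy)(SyncGroupResult(Array.empty[Byte], Errors.REBALANCE_IN_PROGRESS))
      }
    }
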
diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala
index 21aae42..58a68a2 100644
--- a/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala
@@ -23,6 +23,7 @@ import kafka.common.OffsetAndMetadata
 import kafka.utils.{CoreUtils, Logging, nonthreadsafe}
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember
+import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.utils.Time
 
 import scala.collection.{Seq, immutable, mutable}
@@ -162,7 +163,7 @@ case class GroupSummary(state: String,
   * being materialized.
   */
 case class CommitRecordMetadataAndOffset(appendedBatchOffset: Option[Long], offsetAndMetadata: OffsetAndMetadata) {
-  def olderThan(that: CommitRecordMetadataAndOffset) : Boolean = appendedBatchOffset.get < that.appendedBatchOffset.get
+  def olderThan(that: CommitRecordMetadataAndOffset): Boolean = appendedBatchOffset.get < that.appendedBatchOffset.get
 }
 
 /**
@@ -290,6 +291,19 @@ private[group] class GroupMetadata(val groupId: String, initialState: GroupState
     val oldMember = members.remove(oldMemberId)
       .getOrElse(throw new IllegalArgumentException(s"Cannot replace non-existing member id $oldMemberId"))
 
+    // Fence the potential duplicate member immediately if it is awaiting a join/sync callback.
+    maybeInvokeJoinCallback(oldMember, JoinGroupResult(
+      members = List.empty,
+      memberId = oldMemberId,
+      generationId = GroupCoordinator.NoGeneration,
+      subProtocol = GroupCoordinator.NoProtocol,
+      leaderId = GroupCoordinator.NoLeader,
+      error = Errors.FENCED_INSTANCE_ID))
+
+    maybeInvokeSyncCallback(oldMember, SyncGroupResult(
+      Array.empty, Errors.FENCED_INSTANCE_ID
+    ))
+
     oldMember.memberId = newMemberId
     members.put(newMemberId, oldMember)
 
@@ -425,7 +439,7 @@ private[group] class GroupMetadata(val groupId: String, initialState: GroupState
   }
 
   def maybeInvokeJoinCallback(member: MemberMetadata,
-                              joinGroupResult: JoinGroupResult) : Unit = {
+                              joinGroupResult: JoinGroupResult): Unit = {
     if (member.isAwaitingJoin) {
       member.awaitingJoinCallback(joinGroupResult)
       member.awaitingJoinCallback = null
@@ -433,6 +447,20 @@ private[group] class GroupMetadata(val groupId: String, initialState: GroupState
     }
   }
 
+  /**
+    * @return true if the sync callback was actually invoked.
+    */
+  def maybeInvokeSyncCallback(member: MemberMetadata,
+                              syncGroupResult: SyncGroupResult): Boolean = {
+    if (member.isAwaitingSync) {
+      member.awaitingSyncCallback(syncGroupResult)
+      member.awaitingSyncCallback = null
+      true
+    } else {
+      false
+    }
+  }
+
   def initNextGeneration() = {
     if (members.nonEmpty) {
       generationId += 1
@@ -600,7 +628,7 @@ private[group] class GroupMetadata(val groupId: String, initialState: GroupState
     }.toMap
   }
 
-  def removeExpiredOffsets(currentTimestamp: Long, offsetRetentionMs: Long) : Map[TopicPartition, OffsetAndMetadata] = {
+  def removeExpiredOffsets(currentTimestamp: Long, offsetRetentionMs: Long): Map[TopicPartition, OffsetAndMetadata] = {
 
     def getExpiredOffsets(baseTimestamp: CommitRecordMetadataAndOffset => Long): Map[TopicPartition, OffsetAndMetadata] = {
       offsets.filter {
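
The Boolean returned by the new maybeInvokeSyncCallback is what lets propagateAssignment (in GroupCoordinator.scala above) reschedule a heartbeat only for members whose sync callback actually fired. A minimal sketch of that contract, with hypothetical stand-in types rather than Kafka's:

    object PropagateSketch {
      final class Member(val memberId: String) {
        var assignment: Array[Byte] = Array.empty[Byte]
        var awaitingSyncCallback: Array[Byte] => Unit = null
        def isAwaitingSync: Boolean = awaitingSyncCallback != null
      }

      // Invoke-once, as in GroupMetadata.maybeInvokeSyncCallback above.
      def maybeInvokeSyncCallback(member: Member, assignment: Array[Byte]): Boolean =
        if (member.isAwaitingSync) {
          member.awaitingSyncCallback(assignment)
          member.awaitingSyncCallback = null
          true
        } else false

      // Mirrors propagateAssignment: only members whose callback actually fired
      // get their heartbeat expiration rescheduled.
      def propagateAssignment(members: Seq[Member], scheduleHeartbeat: Member => Unit): Unit =
        for (member <- members)
          if (maybeInvokeSyncCallback(member, member.assignment))
            scheduleHeartbeat(member)

      def main(args: Array[String]): Unit = {
        val waiting = new Member("m1")
        waiting.awaitingSyncCallback = a => println(s"m1 got ${a.length} bytes")
        val idle = new Member("m2") // no pending sync request; must not be touched
        propagateAssignment(Seq(waiting, idle), m => println(s"heartbeat rescheduled for ${m.memberId}"))
      }
    }
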
diff --git a/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala b/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala
index a090d97..83ff709 100644
--- a/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala
+++ b/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala
@@ -66,13 +66,14 @@ private[group] class MemberMetadata(var memberId: String,
 
   var assignment: Array[Byte] = Array.empty[Byte]
   var awaitingJoinCallback: JoinGroupResult => Unit = null
-  var awaitingSyncCallback: (Array[Byte], Errors) => Unit = null
+  var awaitingSyncCallback: SyncGroupResult => Unit = null
   var latestHeartbeat: Long = -1
   var isLeaving: Boolean = false
   var isNew: Boolean = false
   val isStaticMember: Boolean = groupInstanceId.isDefined
 
   def isAwaitingJoin = awaitingJoinCallback != null
+  def isAwaitingSync = awaitingSyncCallback != null
 
   /**
    * Get metadata corresponding to the provided protocol.
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala
index d362d64..c883345 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -30,7 +30,7 @@ import kafka.api.{ApiVersion, KAFKA_0_11_0_IV0, KAFKA_2_3_IV0}
 import kafka.cluster.Partition
 import kafka.common.OffsetAndMetadata
 import kafka.controller.KafkaController
-import kafka.coordinator.group.{GroupCoordinator, JoinGroupResult}
+import kafka.coordinator.group.{GroupCoordinator, JoinGroupResult, SyncGroupResult}
 import kafka.coordinator.transaction.{InitProducerIdResult, TransactionCoordinator}
 import kafka.message.ZStdCompressionCodec
 import kafka.network.RequestChannel
@@ -1393,12 +1393,12 @@ class KafkaApis(val requestChannel: RequestChannel,
   def handleSyncGroupRequest(request: RequestChannel.Request) {
     val syncGroupRequest = request.body[SyncGroupRequest]
 
-    def sendResponseCallback(memberState: Array[Byte], error: Errors) {
+    def sendResponseCallback(syncGroupResult: SyncGroupResult) {
       sendResponseMaybeThrottle(request, requestThrottleMs =>
         new SyncGroupResponse(
           new SyncGroupResponseData()
-            .setErrorCode(error.code)
-            .setAssignment(memberState)
+            .setErrorCode(syncGroupResult.error.code)
+            .setAssignment(syncGroupResult.memberAssignment)
             .setThrottleTimeMs(requestThrottleMs)
         ))
     }
@@ -1407,9 +1407,9 @@ class KafkaApis(val requestChannel: RequestChannel,
       // Only enable static membership when IBP >= 2.3, because it is not safe for the broker to use the static member logic
       // until we are sure that all brokers support it. If static group being loaded by an older coordinator, it will discard
       // the group.instance.id field, so static members could accidentally become "dynamic", which leads to wrong states.
-      sendResponseCallback(Array[Byte](), Errors.UNSUPPORTED_VERSION)
+      sendResponseCallback(SyncGroupResult(Array[Byte](), Errors.UNSUPPORTED_VERSION))
     } else if (!authorize(request.session, Read, Resource(Group, syncGroupRequest.data.groupId, LITERAL))) {
-      sendResponseCallback(Array[Byte](), Errors.GROUP_AUTHORIZATION_FAILED)
+      sendResponseCallback(SyncGroupResult(Array[Byte](), Errors.GROUP_AUTHORIZATION_FAILED))
     } else {
       val assignmentMap = immutable.Map.newBuilder[String, Array[Byte]]
       syncGroupRequest.data.assignments.asScala.foreach { assignment =>
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
index 3da5a0c..1cee665 100644
--- a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala
@@ -173,8 +173,8 @@ class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest
 
   class SyncGroupOperation extends GroupOperation[SyncGroupCallbackParams, SyncGroupCallback] {
     override def responseCallback(responsePromise: Promise[SyncGroupCallbackParams]): SyncGroupCallback = {
-      val callback: SyncGroupCallback = (assignment, error) =>
-        responsePromise.success((assignment, error))
+      val callback: SyncGroupCallback = syncGroupResult =>
+        responsePromise.success(syncGroupResult.memberAssignment, syncGroupResult.error)
       callback
     }
     override def runWithCallback(member: GroupMember, responseCallback: SyncGroupCallback): Unit = {
@@ -280,7 +280,7 @@ object GroupCoordinatorConcurrencyTest {
 
   type JoinGroupCallback = JoinGroupResult => Unit
   type SyncGroupCallbackParams = (Array[Byte], Errors)
-  type SyncGroupCallback = (Array[Byte], Errors) => Unit
+  type SyncGroupCallback = SyncGroupResult => Unit
   type HeartbeatCallbackParams = Errors
   type HeartbeatCallback = Errors => Unit
   type CommitOffsetCallbackParams = Map[TopicPartition, Errors]
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala
index 5bbaf5d..3e753c6 100644
--- a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala
@@ -47,7 +47,7 @@ import scala.concurrent.{Await, Future, Promise, TimeoutException}
 class GroupCoordinatorTest {
   type JoinGroupCallback = JoinGroupResult => Unit
   type SyncGroupCallbackParams = (Array[Byte], Errors)
-  type SyncGroupCallback = (Array[Byte], Errors) => Unit
+  type SyncGroupCallback = SyncGroupResult => Unit
   type HeartbeatCallbackParams = Errors
   type HeartbeatCallback = Errors => Unit
   type CommitOffsetCallbackParams = Map[TopicPartition, Errors]
@@ -59,7 +59,7 @@ class GroupCoordinatorTest {
   val ClientHost = "localhost"
   val GroupMinSessionTimeout = 10
   val GroupMaxSessionTimeout = 10 * 60 * 1000
-  val GroupMaxSize = 3
+  val GroupMaxSize = 4
   val DefaultRebalanceTimeout = 500
   val DefaultSessionTimeout = 500
   val GroupInitialRebalanceDelay = 50
@@ -146,7 +146,7 @@ class GroupCoordinatorTest {
     // SyncGroup
     var syncGroupResponse: Option[Errors] = None
     groupCoordinator.handleSyncGroup(otherGroupId, 1, memberId, None, Map.empty[String, Array[Byte]],
-      (_, error)=> syncGroupResponse = Some(error))
+      syncGroupResult => syncGroupResponse = Some(syncGroupResult.error))
     assertEquals(Some(Errors.REBALANCE_IN_PROGRESS), syncGroupResponse)
 
     // OffsetCommit
@@ -439,6 +439,151 @@ class GroupCoordinatorTest {
   }
 
   @Test
+  def staticMemberFenceDuplicateRejoinedFollower() {
+    val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId)
+
+    EasyMock.reset(replicaManager)
+    // A third member joining will trigger a rebalance.
+    sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+    timer.advanceClock(1)
+    assertTrue(getGroup(groupId).is(PreparingRebalance))
+
+    EasyMock.reset(replicaManager)
+    timer.advanceClock(1)
+    // The old follower rejoin will match the current member.id.
+    val oldFollowerJoinGroupFuture =
+      sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocols, groupInstanceId = followerInstanceId)
+
+    EasyMock.reset(replicaManager)
+    timer.advanceClock(1)
+    // A duplicate follower join with an unknown member id will trigger member.id replacement.
+    val duplicateFollowerJoinFuture =
+      sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, groupInstanceId = followerInstanceId)
+
+    timer.advanceClock(1)
+    // The old member shall be fenced immediately once the duplicate follower joins.
+    val oldFollowerJoinGroupResult = Await.result(oldFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(oldFollowerJoinGroupResult,
+      Errors.FENCED_INSTANCE_ID,
+      -1,
+      Set.empty,
+      groupId,
+      PreparingRebalance)
+    verifyDelayedTaskNotCompleted(duplicateFollowerJoinFuture)
+  }
+
+  @Test
+  def staticMemberFenceDuplicateSyncingFollowerAfterMemberIdChanged() {
+    val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId)
+
+    EasyMock.reset(replicaManager)
+    // A known leader rejoin will trigger a rebalance.
+    val leaderJoinGroupFuture =
+      sendJoinGroup(groupId, rebalanceResult.leaderId, protocolType, protocols, groupInstanceId = leaderInstanceId)
+    timer.advanceClock(1)
+    assertTrue(getGroup(groupId).is(PreparingRebalance))
+
+    EasyMock.reset(replicaManager)
+    timer.advanceClock(1)
+    // The old follower rejoin will match the current member.id.
+    val oldFollowerJoinGroupFuture =
+      sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocols, groupInstanceId = followerInstanceId)
+
+    timer.advanceClock(1)
+    val leaderJoinGroupResult = Await.result(leaderJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(leaderJoinGroupResult,
+      Errors.NONE,
+      rebalanceResult.generation + 1,
+      Set(leaderInstanceId, followerInstanceId),
+      groupId,
+      CompletingRebalance)
+    assertEquals(leaderJoinGroupResult.leaderId, leaderJoinGroupResult.memberId)
+    assertEquals(rebalanceResult.leaderId, leaderJoinGroupResult.leaderId)
+
+    // The old member shall get a successful join group response.
+    val oldFollowerJoinGroupResult = Await.result(oldFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(oldFollowerJoinGroupResult,
+      Errors.NONE,
+      rebalanceResult.generation + 1,
+      Set.empty,
+      groupId,
+      CompletingRebalance,
+      expectedLeaderId = leaderJoinGroupResult.memberId)
+
+    EasyMock.reset(replicaManager)
+    val oldFollowerSyncGroupFuture = sendSyncGroupFollower(groupId, oldFollowerJoinGroupResult.generationId,
+      oldFollowerJoinGroupResult.memberId, followerInstanceId)
+
+    // A duplicate follower join with an unknown member id will trigger member.id replacement.
+    EasyMock.reset(replicaManager)
+    val duplicateFollowerJoinFuture =
+      sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, groupInstanceId = followerInstanceId)
+    timer.advanceClock(1)
+
+    // The old follower sync callback will return a fenced exception while the broker replaces the member identity.
+    val oldFollowerSyncGroupResult = Await.result(oldFollowerSyncGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    assertEquals(oldFollowerSyncGroupResult._2, Errors.FENCED_INSTANCE_ID)
+
+    // The duplicate follower will get the same response as the old follower.
+    val duplicateFollowerJoinGroupResult = Await.result(duplicateFollowerJoinFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(duplicateFollowerJoinGroupResult,
+      Errors.NONE,
+      rebalanceResult.generation + 1,
+      Set.empty,
+      groupId,
+      CompletingRebalance,
+      expectedLeaderId = leaderJoinGroupResult.memberId)
+  }
+
+  @Test
+  def staticMemberFenceDuplicateRejoiningFollowerAfterMemberIdChanged() {
+    val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId)
+
+    EasyMock.reset(replicaManager)
+    // A known leader rejoin will trigger a rebalance.
+    val leaderJoinGroupFuture =
+      sendJoinGroup(groupId, rebalanceResult.leaderId, protocolType, protocols, groupInstanceId = leaderInstanceId)
+    timer.advanceClock(1)
+    assertTrue(getGroup(groupId).is(PreparingRebalance))
+
+    EasyMock.reset(replicaManager)
+    // A duplicate follower join will trigger member.id replacement.
+    val duplicateFollowerJoinGroupFuture =
+      sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, groupInstanceId = followerInstanceId)
+
+    EasyMock.reset(replicaManager)
+    timer.advanceClock(1)
+    // The old follower rejoin will fail because the member.id was already updated.
+    val oldFollowerJoinGroupFuture =
+      sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocols, groupInstanceId = followerInstanceId)
+
+    val leaderRejoinGroupResult = Await.result(leaderJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(leaderRejoinGroupResult,
+      Errors.NONE,
+      rebalanceResult.generation + 1,
+      Set(leaderInstanceId, followerInstanceId),
+      groupId,
+      CompletingRebalance)
+
+    val duplicateFollowerJoinGroupResult = Await.result(duplicateFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(duplicateFollowerJoinGroupResult,
+      Errors.NONE,
+      rebalanceResult.generation + 1,
+      Set.empty,
+      groupId,
+      CompletingRebalance)
+    assertNotEquals(rebalanceResult.followerId, duplicateFollowerJoinGroupResult.memberId)
+
+    val oldFollowerJoinGroupResult = Await.result(oldFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS))
+    checkJoinGroupResult(oldFollowerJoinGroupResult,
+      Errors.FENCED_INSTANCE_ID,
+      -1,
+      Set.empty,
+      groupId,
+      CompletingRebalance)
+  }
+
+  @Test
   def staticMemberRejoinWithKnownMemberId() {
     var joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, groupInstanceId, protocolType, protocols)
     assertEquals(Errors.NONE, joinGroupResult.error)
@@ -464,31 +609,31 @@ class GroupCoordinatorTest {
   def staticMemberRejoinWithLeaderIdAndUnknownMemberId() {
     val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId)
 
-    // A static leader rejoin with unknown id will not trigger rebalance.
+    // A static leader rejoin with unknown id will not trigger rebalance, and no assignment will be returned.
     EasyMock.reset(replicaManager)
     val joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, leaderInstanceId, protocolType, protocolSuperset, clockAdvance = 1)
 
     checkJoinGroupResult(joinGroupResult,
       Errors.NONE,
       rebalanceResult.generation, // The group should be at the same generation
-      Set(leaderInstanceId, followerInstanceId),
+      Set.empty,
       groupId,
-      Stable)
+      Stable,
+      rebalanceResult.leaderId)
 
     EasyMock.reset(replicaManager)
     val oldLeaderJoinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, leaderInstanceId, protocolType, protocolSuperset, clockAdvance = 1)
     assertEquals(Errors.FENCED_INSTANCE_ID, oldLeaderJoinGroupResult.error)
 
     EasyMock.reset(replicaManager)
-    assertNotEquals(rebalanceResult.leaderId, joinGroupResult.leaderId)
     // Old leader will get fenced.
     val oldLeaderSyncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, rebalanceResult.leaderId, Map.empty, leaderInstanceId)
     assertEquals(Errors.FENCED_INSTANCE_ID, oldLeaderSyncGroupResult._2)
 
+    // Calling sync with the old leader id will fail because that id is no longer valid and has been replaced.
     EasyMock.reset(replicaManager)
     val newLeaderSyncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, joinGroupResult.leaderId, Map.empty)
-    assertEquals(Errors.NONE, newLeaderSyncGroupResult._2)
-    assertEquals(rebalanceResult.leaderAssignment, newLeaderSyncGroupResult._1)
+    assertEquals(Errors.UNKNOWN_MEMBER_ID, newLeaderSyncGroupResult._2)
   }
 
   @Test
@@ -599,7 +744,7 @@ class GroupCoordinatorTest {
     val leaderRejoinGroupFuture = sendJoinGroup(groupId, rebalanceResult.leaderId, protocolType, protocolSuperset, leaderInstanceId)
     // Rebalance complete immediately after follower rejoin.
     EasyMock.reset(replicaManager)
-    val followerRejoinWithFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocolSuperset, followerInstanceId)
+    val followerRejoinWithFuture = sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocolSuperset, followerInstanceId)
 
     timer.advanceClock(1)
 
@@ -656,6 +801,8 @@ class GroupCoordinatorTest {
       groupId,
       Stable)
 
+    assertNotEquals(rebalanceResult.followerId, joinGroupResult.memberId)
+
     EasyMock.reset(replicaManager)
     val syncGroupResult = syncGroupFollower(groupId, rebalanceResult.generation, joinGroupResult.memberId)
     assertEquals(Errors.NONE, syncGroupResult._2)
@@ -1000,7 +1147,7 @@ class GroupCoordinatorTest {
                                    expectedGroupState: GroupState,
                                    expectedLeaderId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID,
                                    expectedMemberId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID) {
-    assertEquals(Errors.NONE, joinGroupResult.error)
+    assertEquals(expectedError, joinGroupResult.error)
     assertEquals(expectedGeneration, joinGroupResult.generationId)
     assertEquals(expectedGroupInstanceIds.size, joinGroupResult.members.size)
     val resultedGroupInstanceIds = joinGroupResult.members.map(member => Some(member.groupInstanceId())).toSet
@@ -2512,8 +2659,8 @@ class GroupCoordinatorTest {
   private def setupSyncGroupCallback: (Future[SyncGroupCallbackParams], SyncGroupCallback) = {
     val responsePromise = Promise[SyncGroupCallbackParams]
     val responseFuture = responsePromise.future
-    val responseCallback: SyncGroupCallback = (assignment, error) =>
-      responsePromise.success((assignment, error))
+    val responseCallback: SyncGroupCallback = syncGroupResult =>
+      responsePromise.success(syncGroupResult.memberAssignment, syncGroupResult.error)
     (responseFuture, responseCallback)
   }
 
diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala
index a3a9008..177bef7 100644
--- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala
@@ -19,6 +19,7 @@ package kafka.coordinator.group
 
 import kafka.common.OffsetAndMetadata
 import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.utils.Time
 import org.junit.Assert._
 import org.junit.{Before, Test}
@@ -30,16 +31,20 @@ class GroupMetadataTest {
   private val protocolType = "consumer"
   private val groupId = "groupId"
   private val groupInstanceId = Some("groupInstanceId")
+  private val memberId = "memberId"
   private val clientId = "clientId"
   private val clientHost = "clientHost"
   private val rebalanceTimeoutMs = 60000
   private val sessionTimeoutMs = 10000
 
   private var group: GroupMetadata = null
+  private var member: MemberMetadata = null
 
   @Before
   def setUp() {
     group = new GroupMetadata("groupId", Empty, Time.SYSTEM)
+    member = new MemberMetadata(memberId, groupId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
+      protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
   }
 
   @Test
@@ -240,10 +245,6 @@ class GroupMetadataTest {
     // by default, the group supports everything
     assertTrue(group.supportsProtocols(protocolType, Set("roundrobin", "range")))
 
-    val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs,
-      sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))
-
     group.add(member)
     group.transitionTo(PreparingRebalance)
     assertTrue(group.supportsProtocols(protocolType, Set("roundrobin", "foo")))
@@ -263,9 +264,7 @@ class GroupMetadataTest {
 
   @Test
   def testInitNextGeneration() {
-    val memberId = "memberId"
-    val member = new MemberMetadata(memberId, groupId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs,
-      protocolType, List(("roundrobin", Array.empty[Byte])))
+    member.supportedProtocols = List(("roundrobin", Array.empty[Byte]))
 
     group.transitionTo(PreparingRebalance)
     group.add(member, _ => ())
@@ -465,6 +464,88 @@ class GroupMetadataTest {
     assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId))
   }
 
+  @Test(expected = classOf[IllegalArgumentException])
+  def testReplaceGroupInstanceWithEmptyGroupInstanceId(): Unit = {
+    group.add(member)
+    group.addStaticMember(groupInstanceId, memberId)
+    assertTrue(group.isLeader(memberId))
+    assertEquals(memberId, group.getStaticMemberId(groupInstanceId))
+
+    val newMemberId = "newMemberId"
+    group.replaceGroupInstance(memberId, newMemberId, Option.empty)
+  }
+
+  @Test(expected = classOf[IllegalArgumentException])
+  def testReplaceGroupInstanceWithNonExistingMember(): Unit = {
+    val newMemberId = "newMemberId"
+    group.replaceGroupInstance(memberId, newMemberId, groupInstanceId)
+  }
+
+  @Test
+  def testReplaceGroupInstance(): Unit = {
+    var joinAwaitingMemberFenced = false
+    group.add(member, joinGroupResult => {
+      joinAwaitingMemberFenced = joinGroupResult.error == Errors.FENCED_INSTANCE_ID
+    })
+    var syncAwaitingMemberFenced = false
+    member.awaitingSyncCallback = syncGroupResult => {
+      syncAwaitingMemberFenced = syncGroupResult.error == Errors.FENCED_INSTANCE_ID
+    }
+    group.addStaticMember(groupInstanceId, memberId)
+    assertTrue(group.isLeader(memberId))
+    assertEquals(memberId, group.getStaticMemberId(groupInstanceId))
+
+    val newMemberId = "newMemberId"
+    group.replaceGroupInstance(memberId, newMemberId, groupInstanceId)
+    assertTrue(group.isLeader(newMemberId))
+    assertEquals(newMemberId, group.getStaticMemberId(groupInstanceId))
+    assertTrue(joinAwaitingMemberFenced)
+    assertTrue(syncAwaitingMemberFenced)
+    assertFalse(member.isAwaitingJoin)
+    assertFalse(member.isAwaitingSync)
+  }
+
+  @Test
+  def testInvokeJoinCallback(): Unit = {
+    var invoked = false
+    group.add(member, _ => {
+      invoked = true
+    })
+
+    assertTrue(group.hasAllMembersJoined)
+    group.maybeInvokeJoinCallback(member, GroupCoordinator.joinError(member.memberId, Errors.NONE))
+    assertTrue(invoked)
+    assertFalse(member.isAwaitingJoin)
+  }
+
+  @Test
+  def testNotInvokeJoinCallback(): Unit = {
+    group.add(member)
+
+    assertFalse(member.isAwaitingJoin)
+    group.maybeInvokeJoinCallback(member, GroupCoordinator.joinError(member.memberId, Errors.NONE))
+    assertFalse(member.isAwaitingJoin)
+  }
+
+  @Test
+  def testInvokeSyncCallback(): Unit = {
+    group.add(member)
+    member.awaitingSyncCallback = _ => {}
+
+    val invoked = group.maybeInvokeSyncCallback(member, SyncGroupResult(Array.empty, Errors.NONE))
+    assertTrue(invoked)
+    assertFalse(member.isAwaitingSync)
+  }
+
+  @Test
+  def testNotInvokeSyncCallback(): Unit = {
+    group.add(member)
+
+    val invoked = group.maybeInvokeSyncCallback(member, SyncGroupResult(Array.empty, Errors.NONE))
+    assertFalse(invoked)
+    assertFalse(member.isAwaitingSync)
+  }
+
   private def assertState(group: GroupMetadata, targetState: GroupState) {
     val states: Set[GroupState] = Set(Stable, PreparingRebalance, CompletingRebalance, Dead)
     val otherStates = states - targetState
diff --git a/tests/kafkatest/tests/client/consumer_test.py b/tests/kafkatest/tests/client/consumer_test.py
index be15e6a..131123f 100644
--- a/tests/kafkatest/tests/client/consumer_test.py
+++ b/tests/kafkatest/tests/client/consumer_test.py
@@ -265,6 +265,12 @@ class OffsetValidationTest(VerifiableConsumerTest):
                                "normal consumer group and %d from conflict consumer group" % \
                                (len(consumer.nodes), len(consumer.joined_nodes()), len(conflict_consumer.joined_nodes()))
                        )
+            wait_until(lambda: len(consumer.dead_nodes()) + len(conflict_consumer.dead_nodes()) == len(conflict_consumer.nodes),
+                       timeout_sec=self.session_timeout_sec,
+                       err_msg="Timed out waiting for fenced consumers to die, expected total %d dead, but only see %d dead in"
+                               "normal consumer group and %d dead in conflict consumer group" % \
+                               (len(conflict_consumer.nodes), len(consumer.dead_nodes()), len(conflict_consumer.dead_nodes()))
+                       )
 
     @cluster(num_nodes=7)
     @matrix(clean_shutdown=[True], enable_autocommit=[True, False])