You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@kafka.apache.org by jg...@apache.org on 2019/04/18 15:35:48 UTC

[kafka] branch trunk updated: KAFKA-7026; Sticky Assignor Partition Assignment Improvement (KIP-341) (#5291)

This is an automated email from the ASF dual-hosted git repository.

jgus pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 3e8a10e  KAFKA-7026; Sticky Assignor Partition Assignment Improvement (KIP-341) (#5291)
3e8a10e is described below

commit 3e8a10e7d9a9d5d3d29c8793e30d8401be1588ac
Author: Vahid Hashemian <va...@gmail.com>
AuthorDate: Thu Apr 18 20:05:24 2019 +0430

    KAFKA-7026; Sticky Assignor Partition Assignment Improvement (KIP-341) (#5291)
    
    This patch implements KIP-341, which protects the sticky assignor against consumers that join with a stale assignment. More details can be found in the proposal: https://cwiki.apache.org/confluence/display/KAFKA/KIP-341%3A+Update+Sticky+Assignor%27s+User+Data+Protocol.
    
    Reviewers: Steven Aerts <st...@gmail.com>, Jason Gustafson <ja...@confluent.io>
---
 .../kafka/clients/consumer/StickyAssignor.java     | 209 +++++++++++---
 .../consumer/internals/AbstractCoordinator.java    |  12 +-
 .../consumer/internals/ConsumerCoordinator.java    |   2 +-
 .../consumer/internals/ConsumerProtocol.java       |   1 +
 .../consumer/internals/PartitionAssignor.java      |  10 +-
 .../kafka/clients/consumer/StickyAssignorTest.java | 317 +++++++++++++++++++--
 6 files changed, 478 insertions(+), 73 deletions(-)

diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java
index ee537eb..9575ba6 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java
@@ -39,7 +39,9 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Optional;
 import java.util.Set;
+import java.util.TreeMap;
 import java.util.TreeSet;
 
 /**
@@ -185,24 +187,49 @@ public class StickyAssignor extends AbstractPartitionAssignor {
 
     // these schemas are used for preserving consumer's previously assigned partitions
     // list and sending it as user data to the leader during a rebalance
-    private static final String TOPIC_PARTITIONS_KEY_NAME = "previous_assignment";
-    private static final String TOPIC_KEY_NAME = "topic";
-    private static final String PARTITIONS_KEY_NAME = "partitions";
-    private static final Schema TOPIC_ASSIGNMENT = new Schema(
+    static final String TOPIC_PARTITIONS_KEY_NAME = "previous_assignment";
+    static final String TOPIC_KEY_NAME = "topic";
+    static final String PARTITIONS_KEY_NAME = "partitions";
+    private static final String GENERATION_KEY_NAME = "generation";
+    private static final int DEFAULT_GENERATION = -1;
+    static final Schema TOPIC_ASSIGNMENT = new Schema(
             new Field(TOPIC_KEY_NAME, Type.STRING),
             new Field(PARTITIONS_KEY_NAME, new ArrayOf(Type.INT32)));
-    private static final Schema STICKY_ASSIGNOR_USER_DATA = new Schema(
+    static final Schema STICKY_ASSIGNOR_USER_DATA_V0 = new Schema(
             new Field(TOPIC_PARTITIONS_KEY_NAME, new ArrayOf(TOPIC_ASSIGNMENT)));
+    private static final Schema STICKY_ASSIGNOR_USER_DATA_V1 = new Schema(
+            new Field(TOPIC_PARTITIONS_KEY_NAME, new ArrayOf(TOPIC_ASSIGNMENT)),
+            new Field(GENERATION_KEY_NAME, Type.INT32));
 
     private List<TopicPartition> memberAssignment = null;
     private PartitionMovements partitionMovements;
+    private int generation = DEFAULT_GENERATION; // consumer group generation
+
+    static final class ConsumerUserData {
+        final List<TopicPartition> partitions;
+        final Optional<Integer> generation;
+        ConsumerUserData(List<TopicPartition> partitions, Optional<Integer> generation) {
+            this.partitions = partitions;
+            this.generation = generation;
+        }
+    }
+
+    static final class ConsumerGenerationPair {
+        final String consumer;
+        final int generation;
+        ConsumerGenerationPair(String consumer, int generation) {
+            this.consumer = consumer;
+            this.generation = generation;
+        }
+    }
 
     public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
                                                     Map<String, Subscription> subscriptions) {
         Map<String, List<TopicPartition>> currentAssignment = new HashMap<>();
+        Map<TopicPartition, ConsumerGenerationPair> prevAssignment = new HashMap<>();
         partitionMovements = new PartitionMovements();
 
-        prepopulateCurrentAssignments(subscriptions, currentAssignment);
+        prepopulateCurrentAssignments(subscriptions, currentAssignment, prevAssignment);
         boolean isFreshAssignment = currentAssignment.isEmpty();
 
         // a mapping of all topic partitions to all consumers that can be assigned to them
@@ -213,12 +240,12 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         // initialize partition2AllPotentialConsumers and consumer2AllPotentialPartitions in the following two for loops
         for (Entry<String, Integer> entry: partitionsPerTopic.entrySet()) {
             for (int i = 0; i < entry.getValue(); ++i)
-                partition2AllPotentialConsumers.put(new TopicPartition(entry.getKey(), i), new ArrayList<String>());
+                partition2AllPotentialConsumers.put(new TopicPartition(entry.getKey(), i), new ArrayList<>());
         }
 
         for (Entry<String, Subscription> entry: subscriptions.entrySet()) {
             String consumer = entry.getKey();
-            consumer2AllPotentialPartitions.put(consumer, new ArrayList<TopicPartition>());
+            consumer2AllPotentialPartitions.put(consumer, new ArrayList<>());
             entry.getValue().topics().stream().filter(topic -> partitionsPerTopic.get(topic) != null).forEach(topic -> {
                 for (int i = 0; i < partitionsPerTopic.get(topic); ++i) {
                     TopicPartition topicPartition = new TopicPartition(topic, i);
@@ -229,7 +256,7 @@ public class StickyAssignor extends AbstractPartitionAssignor {
 
             // add this consumer to currentAssignment (with an empty topic partition assignment) if it does not already exist
             if (!currentAssignment.containsKey(consumer))
-                currentAssignment.put(consumer, new ArrayList<TopicPartition>());
+                currentAssignment.put(consumer, new ArrayList<>());
         }
 
         // a mapping of partition to current consumer
@@ -239,7 +266,7 @@ public class StickyAssignor extends AbstractPartitionAssignor {
                 currentPartitionConsumer.put(topicPartition, entry.getKey());
 
         List<TopicPartition> sortedPartitions = sortPartitions(
-                currentAssignment, isFreshAssignment, partition2AllPotentialConsumers, consumer2AllPotentialPartitions);
+                currentAssignment, prevAssignment.keySet(), isFreshAssignment, partition2AllPotentialConsumers, consumer2AllPotentialPartitions);
 
         // all partitions that need to be assigned (initially set to all partitions but adjusted in the following loop)
         List<TopicPartition> unassignedPartitions = new ArrayList<>(sortedPartitions);
@@ -278,23 +305,68 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         TreeSet<String> sortedCurrentSubscriptions = new TreeSet<>(new SubscriptionComparator(currentAssignment));
         sortedCurrentSubscriptions.addAll(currentAssignment.keySet());
 
-        balance(currentAssignment, sortedPartitions, unassignedPartitions, sortedCurrentSubscriptions,
+        balance(currentAssignment, prevAssignment, sortedPartitions, unassignedPartitions, sortedCurrentSubscriptions,
                 consumer2AllPotentialPartitions, partition2AllPotentialConsumers, currentPartitionConsumer);
         return currentAssignment;
     }
 
     private void prepopulateCurrentAssignments(Map<String, Subscription> subscriptions,
-                                               Map<String, List<TopicPartition>> currentAssignment) {
-        for (Map.Entry<String, Subscription> subscriptionEntry : subscriptions.entrySet()) {
+                                               Map<String, List<TopicPartition>> currentAssignment,
+                                               Map<TopicPartition, ConsumerGenerationPair> prevAssignment) {
+        // we need to process subscriptions' user data with each consumer's reported generation in mind
+        // higher generations overwrite lower generations in case of a conflict
+        // note that a conflict could exists only if user data is for different generations
+
+        // for each partition we create a sorted map of its consumers by generation
+        Map<TopicPartition, TreeMap<Integer, String>> sortedPartitionConsumersByGeneration = new HashMap<>();
+        for (Map.Entry<String, Subscription> subscriptionEntry: subscriptions.entrySet()) {
+            String consumer = subscriptionEntry.getKey();
             ByteBuffer userData = subscriptionEntry.getValue().userData();
-            if (userData != null && userData.hasRemaining())
-                currentAssignment.put(subscriptionEntry.getKey(), deserializeTopicPartitionAssignment(userData));
+            if (userData == null || !userData.hasRemaining()) continue;
+            ConsumerUserData consumerUserData = deserializeTopicPartitionAssignment(userData);
+
+            for (TopicPartition partition: consumerUserData.partitions) {
+                if (sortedPartitionConsumersByGeneration.containsKey(partition)) {
+                    Map<Integer, String> consumers = sortedPartitionConsumersByGeneration.get(partition);
+                    if (consumerUserData.generation.isPresent() && consumers.containsKey(consumerUserData.generation.get())) {
+                        // same partition is assigned to two consumers during the same rebalance.
+                        // log a warning and skip this record
+                        log.warn("Partition '{}' is assigned to multiple consumers following sticky assignment generation {}.",
+                                partition, consumerUserData.generation);
+                    } else
+                        consumers.put(consumerUserData.generation.orElse(DEFAULT_GENERATION), consumer);
+                } else {
+                    TreeMap<Integer, String> sortedConsumers = new TreeMap<>();
+                    sortedConsumers.put(consumerUserData.generation.orElse(DEFAULT_GENERATION), consumer);
+                    sortedPartitionConsumersByGeneration.put(partition, sortedConsumers);
+                }
+            }
+        }
+
+        // prevAssignment holds the prior ConsumerGenerationPair (before current) of each partition
+        // current and previous consumers are the last two consumers of each partition in the above sorted map
+        for (Map.Entry<TopicPartition, TreeMap<Integer, String>> partitionConsumersEntry: sortedPartitionConsumersByGeneration.entrySet()) {
+            TopicPartition partition = partitionConsumersEntry.getKey();
+            TreeMap<Integer, String> consumers = partitionConsumersEntry.getValue();
+            Iterator<Integer> it = consumers.descendingKeySet().iterator();
+
+            // let's process the current (most recent) consumer first
+            String consumer = consumers.get(it.next());
+            currentAssignment.computeIfAbsent(consumer, k -> new ArrayList<>());
+            currentAssignment.get(consumer).add(partition);
+
+            // now update previous assignment if any
+            if (it.hasNext()) {
+                int generation = it.next();
+                prevAssignment.put(partition, new ConsumerGenerationPair(consumers.get(generation), generation));
+            }
         }
     }
 
     @Override
-    public void onAssignment(Assignment assignment) {
+    public void onAssignment(Assignment assignment, int generation) {
         memberAssignment = assignment.partitions();
+        this.generation = generation;
     }
 
     @Override
@@ -302,7 +374,8 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         if (memberAssignment == null)
             return new Subscription(new ArrayList<>(topics));
 
-        return new Subscription(new ArrayList<>(topics), serializeTopicPartitionAssignment(memberAssignment));
+        return new Subscription(new ArrayList<>(topics),
+                serializeTopicPartitionAssignment(new ConsumerUserData(memberAssignment, Optional.of(generation))));
     }
 
     @Override
@@ -310,6 +383,10 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         return "sticky";
     }
 
+    int generation() {
+        return generation;
+    }
+
     /**
      * determine if the current assignment is a balanced one
      *
@@ -395,12 +472,16 @@ public class StickyAssignor extends AbstractPartitionAssignor {
      * that causes minimal partition movement among consumers (hence honoring maximal stickiness)
      *
      * @param currentAssignment the calculated assignment so far
+     * @param partitionsWithADifferentPreviousAssignment partitions that had a different consumer before (for every
+     *                                                   such partition there should also be a mapping in
+     *                                                   @currentAssignment to a different consumer)
      * @param isFreshAssignment whether this is a new assignment, or a reassignment of an existing one
      * @param partition2AllPotentialConsumers a mapping of partitions to their potential consumers
      * @param consumer2AllPotentialPartitions a mapping of consumers to potential partitions they can consumer from
      * @return sorted list of valid partitions
      */
     private List<TopicPartition> sortPartitions(Map<String, List<TopicPartition>> currentAssignment,
+                                                Set<TopicPartition> partitionsWithADifferentPreviousAssignment,
                                                 boolean isFreshAssignment,
                                                 Map<TopicPartition, List<String>> partition2AllPotentialConsumers,
                                                 Map<String, List<TopicPartition>> consumer2AllPotentialPartitions) {
@@ -421,11 +502,28 @@ public class StickyAssignor extends AbstractPartitionAssignor {
             }
             TreeSet<String> sortedConsumers = new TreeSet<>(new SubscriptionComparator(assignments));
             sortedConsumers.addAll(assignments.keySet());
+            // at this point, sortedConsumers contains an ascending-sorted list of consumers based on
+            // how many valid partitions are currently assigned to them
 
             while (!sortedConsumers.isEmpty()) {
+                // take the consumer with the most partitions
                 String consumer = sortedConsumers.pollLast();
+                // currently assigned partitions to this consumer
                 List<TopicPartition> remainingPartitions = assignments.get(consumer);
-                if (!remainingPartitions.isEmpty()) {
+                // partitions that were assigned to a different consumer last time
+                List<TopicPartition> prevPartitions = new ArrayList<>(partitionsWithADifferentPreviousAssignment);
+                // from partitions that had a different consumer before, keep only those that are
+                // assigned to this consumer now
+                prevPartitions.retainAll(remainingPartitions);
+                if (!prevPartitions.isEmpty()) {
+                    // if there is a partition of this consumer that was assigned to another consumer before
+                    // mark it as good options for reassignment
+                    TopicPartition partition = prevPartitions.remove(0);
+                    remainingPartitions.remove(partition);
+                    sortedPartitions.add(partition);
+                    sortedConsumers.add(consumer);
+                } else if (!remainingPartitions.isEmpty()) {
+                    // otherwise, mark any other one of the current partitions as a reassignment candidate
                     sortedPartitions.add(remainingPartitions.remove(0));
                     sortedConsumers.add(consumer);
                 }
@@ -459,17 +557,13 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         if (!hasIdenticalListElements(partition2AllPotentialConsumers.values()))
             return false;
 
-        if (!hasIdenticalListElements(consumer2AllPotentialPartitions.values()))
-            return false;
-
-        return true;
+        return hasIdenticalListElements(consumer2AllPotentialPartitions.values());
     }
 
     /**
-     * @return the consumer to which the given partition is assigned. The assignment should improve the overall balance
-     * of the partition assignments to consumers.
+     * The assignment should improve the overall balance of the partition assignments to consumers.
      */
-    private String assignPartition(TopicPartition partition,
+    private void assignPartition(TopicPartition partition,
                                    TreeSet<String> sortedCurrentSubscriptions,
                                    Map<String, List<TopicPartition>> currentAssignment,
                                    Map<String, List<TopicPartition>> consumer2AllPotentialPartitions,
@@ -480,10 +574,9 @@ public class StickyAssignor extends AbstractPartitionAssignor {
                 currentAssignment.get(consumer).add(partition);
                 currentPartitionConsumer.put(partition, consumer);
                 sortedCurrentSubscriptions.add(consumer);
-                return consumer;
+                break;
             }
         }
-        return null;
     }
 
     private boolean canParticipateInReassignment(TopicPartition partition,
@@ -519,6 +612,7 @@ public class StickyAssignor extends AbstractPartitionAssignor {
      * Balance the current assignment using the data structures created in the assign(...) method above.
      */
     private void balance(Map<String, List<TopicPartition>> currentAssignment,
+                         Map<TopicPartition, ConsumerGenerationPair> prevAssignment,
                          List<TopicPartition> sortedPartitions,
                          List<TopicPartition> unassignedPartitions,
                          TreeSet<String> sortedCurrentSubscriptions,
@@ -558,7 +652,7 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         Map<String, List<TopicPartition>> preBalanceAssignment = deepCopy(currentAssignment);
         Map<TopicPartition, String> preBalancePartitionConsumers = new HashMap<>(currentPartitionConsumer);
 
-        reassignmentPerformed = performReassignments(sortedPartitions, currentAssignment, sortedCurrentSubscriptions,
+        reassignmentPerformed = performReassignments(sortedPartitions, currentAssignment, prevAssignment, sortedCurrentSubscriptions,
                 consumer2AllPotentialPartitions, partition2AllPotentialConsumers, currentPartitionConsumer);
 
         // if we are not preserving existing assignments and we have made changes to the current assignment
@@ -581,6 +675,7 @@ public class StickyAssignor extends AbstractPartitionAssignor {
 
     private boolean performReassignments(List<TopicPartition> reassignablePartitions,
                                          Map<String, List<TopicPartition>> currentAssignment,
+                                         Map<TopicPartition, ConsumerGenerationPair> prevAssignment,
                                          TreeSet<String> sortedCurrentSubscriptions,
                                          Map<String, List<TopicPartition>> consumer2AllPotentialPartitions,
                                          Map<TopicPartition, List<String>> partition2AllPotentialConsumers,
@@ -606,6 +701,14 @@ public class StickyAssignor extends AbstractPartitionAssignor {
                 if (consumer == null)
                     log.error("Expected partition '{}' to be assigned to a consumer", partition);
 
+                if (prevAssignment.containsKey(partition) &&
+                        currentAssignment.get(consumer).size() > currentAssignment.get(prevAssignment.get(partition).consumer).size() + 1) {
+                    reassignPartition(partition, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer, prevAssignment.get(partition).consumer);
+                    reassignmentPerformed = true;
+                    modified = true;
+                    continue;
+                }
+
                 // check if a better-suited consumer exist for the partition; if so, reassign it
                 for (String otherConsumer: partition2AllPotentialConsumers.get(partition)) {
                     if (currentAssignment.get(consumer).size() > currentAssignment.get(otherConsumer).size() + 1) {
@@ -626,8 +729,6 @@ public class StickyAssignor extends AbstractPartitionAssignor {
                                    TreeSet<String> sortedCurrentSubscriptions,
                                    Map<TopicPartition, String> currentPartitionConsumer,
                                    Map<String, List<TopicPartition>> consumer2AllPotentialPartitions) {
-        String consumer = currentPartitionConsumer.get(partition);
-
         // find the new consumer
         String newConsumer = null;
         for (String anotherConsumer: sortedCurrentSubscriptions) {
@@ -639,11 +740,18 @@ public class StickyAssignor extends AbstractPartitionAssignor {
 
         assert newConsumer != null;
 
+        reassignPartition(partition, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer, newConsumer);
+    }
+
+    private void reassignPartition(TopicPartition partition,
+                                   Map<String, List<TopicPartition>> currentAssignment,
+                                   TreeSet<String> sortedCurrentSubscriptions,
+                                   Map<TopicPartition, String> currentPartitionConsumer,
+                                   String newConsumer) {
+        String consumer = currentPartitionConsumer.get(partition);
         // find the correct partition movement considering the stickiness requirement
         TopicPartition partitionToBeMoved = partitionMovements.getTheActualPartitionToBeMoved(partition, consumer, newConsumer);
         processPartitionMovement(partitionToBeMoved, newConsumer, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer);
-
-        return;
     }
 
     private void processPartitionMovement(TopicPartition partition,
@@ -669,24 +777,39 @@ public class StickyAssignor extends AbstractPartitionAssignor {
         return partitionMovements.isSticky();
     }
 
-    static ByteBuffer serializeTopicPartitionAssignment(List<TopicPartition> partitions) {
-        Struct struct = new Struct(STICKY_ASSIGNOR_USER_DATA);
+    static ByteBuffer serializeTopicPartitionAssignment(ConsumerUserData consumerUserData) {
+        Struct struct = new Struct(STICKY_ASSIGNOR_USER_DATA_V1);
         List<Struct> topicAssignments = new ArrayList<>();
-        for (Map.Entry<String, List<Integer>> topicEntry : CollectionUtils.groupPartitionsByTopic(partitions).entrySet()) {
+        for (Map.Entry<String, List<Integer>> topicEntry : CollectionUtils.groupPartitionsByTopic(consumerUserData.partitions).entrySet()) {
             Struct topicAssignment = new Struct(TOPIC_ASSIGNMENT);
             topicAssignment.set(TOPIC_KEY_NAME, topicEntry.getKey());
             topicAssignment.set(PARTITIONS_KEY_NAME, topicEntry.getValue().toArray());
             topicAssignments.add(topicAssignment);
         }
         struct.set(TOPIC_PARTITIONS_KEY_NAME, topicAssignments.toArray());
-        ByteBuffer buffer = ByteBuffer.allocate(STICKY_ASSIGNOR_USER_DATA.sizeOf(struct));
-        STICKY_ASSIGNOR_USER_DATA.write(buffer, struct);
+        if (consumerUserData.generation.isPresent())
+            struct.set(GENERATION_KEY_NAME, consumerUserData.generation.get());
+        ByteBuffer buffer = ByteBuffer.allocate(STICKY_ASSIGNOR_USER_DATA_V1.sizeOf(struct));
+        STICKY_ASSIGNOR_USER_DATA_V1.write(buffer, struct);
         buffer.flip();
         return buffer;
     }
 
-    private static List<TopicPartition> deserializeTopicPartitionAssignment(ByteBuffer buffer) {
-        Struct struct = STICKY_ASSIGNOR_USER_DATA.read(buffer);
+    private static ConsumerUserData deserializeTopicPartitionAssignment(ByteBuffer buffer) {
+        Struct struct;
+        ByteBuffer copy = buffer.duplicate();
+        try {
+            struct = STICKY_ASSIGNOR_USER_DATA_V1.read(buffer);
+        } catch (Exception e1) {
+            try {
+                // fall back to older schema
+                struct = STICKY_ASSIGNOR_USER_DATA_V0.read(copy);
+            } catch (Exception e2) {
+                // ignore the consumer's previous assignment if it cannot be parsed
+                return new ConsumerUserData(Collections.emptyList(), Optional.of(DEFAULT_GENERATION));
+            }
+        }
+
         List<TopicPartition> partitions = new ArrayList<>();
         for (Object structObj : struct.getArray(TOPIC_PARTITIONS_KEY_NAME)) {
             Struct assignment = (Struct) structObj;
@@ -696,7 +819,9 @@ public class StickyAssignor extends AbstractPartitionAssignor {
                 partitions.add(new TopicPartition(topic, partition));
             }
         }
-        return partitions;
+        // make sure this is backward compatible
+        Optional<Integer> generation = struct.hasField(GENERATION_KEY_NAME) ? Optional.of(struct.getInt(GENERATION_KEY_NAME)) : Optional.empty();
+        return new ConsumerUserData(partitions, generation);
     }
 
     /**
@@ -794,11 +919,11 @@ public class StickyAssignor extends AbstractPartitionAssignor {
 
             String topic = partition.topic();
             if (!partitionMovementsByTopic.containsKey(topic))
-                partitionMovementsByTopic.put(topic, new HashMap<ConsumerPair, Set<TopicPartition>>());
+                partitionMovementsByTopic.put(topic, new HashMap<>());
 
             Map<ConsumerPair, Set<TopicPartition>> partitionMovementsForThisTopic = partitionMovementsByTopic.get(topic);
             if (!partitionMovementsForThisTopic.containsKey(pair))
-                partitionMovementsForThisTopic.put(pair, new HashSet<TopicPartition>());
+                partitionMovementsForThisTopic.put(pair, new HashSet<>());
 
             partitionMovementsForThisTopic.get(pair).add(partition);
         }
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
index 4ff4e19..9261966 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java
@@ -566,7 +566,7 @@ public abstract class AbstractCoordinator implements Closeable {
                 // and send another join group request in next cycle.
                 synchronized (AbstractCoordinator.this) {
                     AbstractCoordinator.this.generation = new Generation(OffsetCommitRequest.DEFAULT_GENERATION_ID,
-                        joinResponse.data().memberId(), null);
+                            joinResponse.data().memberId(), null);
                     AbstractCoordinator.this.rejoinNeeded = true;
                     AbstractCoordinator.this.state = MemberState.UNJOINED;
                 }
@@ -654,7 +654,7 @@ public abstract class AbstractCoordinator implements Closeable {
         FindCoordinatorRequest.Builder requestBuilder =
                 new FindCoordinatorRequest.Builder(FindCoordinatorRequest.CoordinatorType.GROUP, this.groupId);
         return client.send(node, requestBuilder)
-                     .compose(new FindCoordinatorResponseHandler());
+                .compose(new FindCoordinatorResponseHandler());
     }
 
     private class FindCoordinatorResponseHandler extends RequestFutureAdapter<ClientResponse, Void> {
@@ -940,8 +940,8 @@ public abstract class AbstractCoordinator implements Closeable {
 
             this.heartbeatLatency = metrics.sensor("heartbeat-latency");
             this.heartbeatLatency.add(metrics.metricName("heartbeat-response-time-max",
-                this.metricGrpName,
-                "The max time taken to receive a response to a heartbeat request"), new Max());
+                    this.metricGrpName,
+                    "The max time taken to receive a response to a heartbeat request"), new Max());
             this.heartbeatLatency.add(createMeter(metrics, metricGrpName, "heartbeat", "heartbeats"));
 
             this.joinLatency = metrics.sensor("join-latency");
@@ -1148,8 +1148,8 @@ public abstract class AbstractCoordinator implements Closeable {
             if (o == null || getClass() != o.getClass()) return false;
             final Generation that = (Generation) o;
             return generationId == that.generationId &&
-                Objects.equals(memberId, that.memberId) &&
-                Objects.equals(protocol, that.protocol);
+                    Objects.equals(memberId, that.memberId) &&
+                    Objects.equals(protocol, that.protocol);
         }
 
         @Override
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
index 2b949a3..b31bf44 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java
@@ -262,7 +262,7 @@ public final class ConsumerCoordinator extends AbstractCoordinator {
         maybeUpdateJoinedSubscription(assignedPartitions);
 
         // give the assignor a chance to update internal state based on the received assignment
-        assignor.onAssignment(assignment);
+        assignor.onAssignment(assignment, generation);
 
         // reschedule the auto commit starting from now
         if (autoCommitEnabled)
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java
index 8a4aef8..7bef8f7 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java
@@ -45,6 +45,7 @@ import java.util.Map;
  *   TopicPartitions => [Topic Partitions]
  *     Topic         => String
  *     Partitions    => [int32]
+ *   UserData        => Bytes
  * </pre>
  *
  * The current implementation assumes that future versions will not break compatibility. When
diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PartitionAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PartitionAssignor.java
index 4a7c7a8..43fdaf3 100644
--- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PartitionAssignor.java
+++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/PartitionAssignor.java
@@ -58,13 +58,21 @@ public interface PartitionAssignor {
      */
     Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions);
 
-
     /**
      * Callback which is invoked when a group member receives its assignment from the leader.
      * @param assignment The local member's assignment as provided by the leader in {@link #assign(Cluster, Map)}
      */
     void onAssignment(Assignment assignment);
 
+    /**
+     * Callback which is invoked when a group member receives its assignment from the leader, together with the group generation.
+     * @param assignment The local member's assignment as provided by the leader in {@link #assign(Cluster, Map)}
+     * @param generation The consumer group generation associated with this partition assignment; the default implementation ignores it
+     */
+    default void onAssignment(Assignment assignment, int generation) {
+        onAssignment(assignment); // default: drop the generation and delegate to the single-argument callback so existing implementations still receive the assignment
+    }
+
 
     /**
      * Unique name for this assignor (e.g. "range" or "roundrobin" or "sticky")
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java
index 32ba16a..a1fe0cd 100644
--- a/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java
@@ -19,6 +19,7 @@ package org.apache.kafka.clients.consumer;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -27,11 +28,14 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Optional;
 import java.util.Random;
 import java.util.Set;
 
+import org.apache.kafka.clients.consumer.StickyAssignor.ConsumerUserData;
 import org.apache.kafka.clients.consumer.internals.PartitionAssignor.Subscription;
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.protocol.types.Struct;
 import org.apache.kafka.common.utils.CollectionUtils;
 import org.apache.kafka.common.utils.Utils;
 import org.junit.Test;
@@ -46,7 +50,7 @@ public class StickyAssignorTest {
 
         Map<String, Integer> partitionsPerTopic = new HashMap<>();
         Map<String, Subscription> subscriptions =
-                Collections.singletonMap(consumerId, new Subscription(Collections.<String>emptyList()));
+                Collections.singletonMap(consumerId, new Subscription(Collections.emptyList()));
 
         Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
         assertEquals(Collections.singleton(consumerId), assignment.keySet());
@@ -235,10 +239,11 @@ public class StickyAssignorTest {
 
         String consumer2 = "consumer2";
         subscriptions.put(consumer1,
-                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer1))));
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer1), Optional.of(assignor.generation())))));
         subscriptions.put(consumer2, new Subscription(topics(topic)));
         assignment = assignor.assign(partitionsPerTopic, subscriptions);
-        assertEquals(partitions(tp(topic, 1), tp(topic, 2)), assignment.get(consumer1));
+        assertEquals(partitions(tp(topic, 2), tp(topic, 1)), assignment.get(consumer1));
         assertEquals(partitions(tp(topic, 0)), assignment.get(consumer2));
 
         verifyValidityAndBalance(subscriptions, assignment);
@@ -247,7 +252,8 @@ public class StickyAssignorTest {
 
         subscriptions.remove(consumer1);
         subscriptions.put(consumer2,
-                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer2))));
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer2), Optional.of(assignor.generation())))));
         assignment = assignor.assign(partitionsPerTopic, subscriptions);
         assertTrue(assignment.get(consumer2).contains(tp(topic, 0)));
         assertTrue(assignment.get(consumer2).contains(tp(topic, 1)));
@@ -318,9 +324,11 @@ public class StickyAssignorTest {
         String topic2 = "topic2";
         partitionsPerTopic.put(topic2, 3);
         subscriptions.put(consumer1,
-                new Subscription(topics(topic, topic2), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer1))));
+                new Subscription(topics(topic, topic2), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer1), Optional.of(assignor.generation())))));
         subscriptions.put(consumer2,
-                new Subscription(topics(topic, topic2), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer2))));
+                new Subscription(topics(topic, topic2), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer2), Optional.of(assignor.generation())))));
         assignment = assignor.assign(partitionsPerTopic, subscriptions);
         // verify balance
         verifyValidityAndBalance(subscriptions, assignment);
@@ -335,9 +343,11 @@ public class StickyAssignorTest {
 
         partitionsPerTopic.remove(topic);
         subscriptions.put(consumer1,
-                new Subscription(topics(topic2), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer1))));
+                new Subscription(topics(topic2), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer1), Optional.of(assignor.generation())))));
         subscriptions.put(consumer2,
-                new Subscription(topics(topic2), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer2))));
+                new Subscription(topics(topic2), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer2), Optional.of(assignor.generation())))));
         assignment = assignor.assign(partitionsPerTopic, subscriptions);
         // verify balance
         verifyValidityAndBalance(subscriptions, assignment);
@@ -360,7 +370,7 @@ public class StickyAssignorTest {
 
         Map<String, Subscription> subscriptions = new HashMap<>();
         for (int i = 1; i < 20; i++) {
-            List<String> topics = new ArrayList<String>();
+            List<String> topics = new ArrayList<>();
             for (int j = 1; j <= i; j++)
                 topics.add(getTopicName(j, 20));
             subscriptions.put(getConsumerName(i, 20), new Subscription(topics));
@@ -372,7 +382,8 @@ public class StickyAssignorTest {
         for (int i = 1; i < 20; i++) {
             String consumer = getConsumerName(i, 20);
             subscriptions.put(consumer,
-                    new Subscription(subscriptions.get(consumer).topics(), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer))));
+                    new Subscription(subscriptions.get(consumer).topics(),
+                            StickyAssignor.serializeTopicPartitionAssignment(new ConsumerUserData(assignment.get(consumer), Optional.of(assignor.generation())))));
         }
         subscriptions.remove("consumer10");
 
@@ -409,7 +420,7 @@ public class StickyAssignorTest {
 
         Map<String, Subscription> subscriptions = new HashMap<>();
         for (int i = 1; i < 9; i++) {
-            List<String> topics = new ArrayList<String>();
+            List<String> topics = new ArrayList<>();
             for (int j = 1; j <= partitionsPerTopic.size(); j++)
                 topics.add(getTopicName(j, 15));
             subscriptions.put(getConsumerName(i, 9), new Subscription(topics));
@@ -421,7 +432,8 @@ public class StickyAssignorTest {
         for (int i = 1; i < 9; i++) {
             String consumer = getConsumerName(i, 9);
             subscriptions.put(consumer,
-                    new Subscription(subscriptions.get(consumer).topics(), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer))));
+                    new Subscription(subscriptions.get(consumer).topics(),
+                            StickyAssignor.serializeTopicPartitionAssignment(new ConsumerUserData(assignment.get(consumer), Optional.of(assignor.generation())))));
         }
         subscriptions.remove(getConsumerName(5, 9));
 
@@ -442,7 +454,7 @@ public class StickyAssignorTest {
 
         Map<String, Subscription> subscriptions = new HashMap<>();
         for (int i = 0; i < consumerCount; i++) {
-            List<String> topics = new ArrayList<String>();
+            List<String> topics = new ArrayList<>();
             for (int j = 0; j < rand.nextInt(20); j++)
                 topics.add(getTopicName(rand.nextInt(topicCount), topicCount));
             subscriptions.put(getConsumerName(i, consumerCount), new Subscription(topics));
@@ -454,7 +466,8 @@ public class StickyAssignorTest {
         for (int i = 1; i < consumerCount; i++) {
             String consumer = getConsumerName(i, consumerCount);
             subscriptions.put(consumer,
-                    new Subscription(subscriptions.get(consumer).topics(), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer))));
+                    new Subscription(subscriptions.get(consumer).topics(),
+                            StickyAssignor.serializeTopicPartitionAssignment(new ConsumerUserData(assignment.get(consumer), Optional.of(assignor.generation())))));
         }
         for (int i = 0; i < 50; ++i) {
             String c = getConsumerName(rand.nextInt(consumerCount), consumerCount);
@@ -474,7 +487,7 @@ public class StickyAssignorTest {
 
         Map<String, Subscription> subscriptions = new HashMap<>();
         for (int i = 0; i < 3; i++) {
-            List<String> topics = new ArrayList<String>();
+            List<String> topics = new ArrayList<>();
             for (int j = i; j <= 3 * i - 2; j++)
                 topics.add(getTopicName(j, 5));
             subscriptions.put(getConsumerName(i, 3), new Subscription(topics));
@@ -526,7 +539,7 @@ public class StickyAssignorTest {
                 List<String> sub = Utils.sorted(getRandomSublist(topics));
                 String consumer = getConsumerName(i, maxNumConsumers);
                 subscriptions.put(consumer,
-                        new Subscription(sub, StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer))));
+                        new Subscription(sub, StickyAssignor.serializeTopicPartitionAssignment(new ConsumerUserData(assignment.get(consumer), Optional.of(assignor.generation())))));
             }
 
             assignment = assignor.assign(partitionsPerTopic, subscriptions);
@@ -544,13 +557,16 @@ public class StickyAssignorTest {
         Map<String, Subscription> subscriptions = new HashMap<>();
         subscriptions.put("consumer01",
                 new Subscription(topics("topic01", "topic02"),
-                        StickyAssignor.serializeTopicPartitionAssignment(partitions(tp("topic01", 0)))));
+                        StickyAssignor.serializeTopicPartitionAssignment(
+                                new ConsumerUserData(partitions(tp("topic01", 0)), Optional.of(assignor.generation())))));
         subscriptions.put("consumer02",
                 new Subscription(topics("topic01", "topic02", "topic03", "topic04"),
-                        StickyAssignor.serializeTopicPartitionAssignment(partitions(tp("topic02", 0), tp("topic03", 0)))));
+                        StickyAssignor.serializeTopicPartitionAssignment(
+                                new ConsumerUserData(partitions(tp("topic02", 0), tp("topic03", 0)), Optional.of(assignor.generation())))));
         subscriptions.put("consumer03",
                 new Subscription(topics("topic02", "topic03", "topic04", "topic05", "topic06"),
-                        StickyAssignor.serializeTopicPartitionAssignment(partitions(tp("topic04", 0), tp("topic05", 0), tp("topic06", 0)))));
+                        StickyAssignor.serializeTopicPartitionAssignment(
+                                new ConsumerUserData(partitions(tp("topic04", 0), tp("topic05", 0), tp("topic06", 0)), Optional.of(assignor.generation())))));
 
         Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
         verifyValidityAndBalance(subscriptions, assignment);
@@ -584,13 +600,16 @@ public class StickyAssignorTest {
         subscriptions.remove("consumer01");
         subscriptions.put("consumer02",
                 new Subscription(topics("topic01"),
-                        StickyAssignor.serializeTopicPartitionAssignment(assignment.get("consumer02"))));
+                        StickyAssignor.serializeTopicPartitionAssignment(
+                                new ConsumerUserData(assignment.get("consumer02"), Optional.of(assignor.generation())))));
         subscriptions.put("consumer03",
                 new Subscription(topics("topic01"),
-                        StickyAssignor.serializeTopicPartitionAssignment(assignment.get("consumer03"))));
+                        StickyAssignor.serializeTopicPartitionAssignment(
+                                new ConsumerUserData(assignment.get("consumer03"), Optional.of(assignor.generation())))));
         subscriptions.put("consumer04",
                 new Subscription(topics("topic01"),
-                        StickyAssignor.serializeTopicPartitionAssignment(assignment.get("consumer04"))));
+                        StickyAssignor.serializeTopicPartitionAssignment(
+                                new ConsumerUserData(assignment.get("consumer04"), Optional.of(assignor.generation())))));
 
         assignment = assignor.assign(partitionsPerTopic, subscriptions);
         verifyValidityAndBalance(subscriptions, assignment);
@@ -631,13 +650,265 @@ public class StickyAssignorTest {
         Map<String, Subscription> subscriptions = new HashMap<>();
         subscriptions.put(consumer, new Subscription(topics(topic)));
         Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
-        subscriptions.put(consumer, new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(assignment.get(consumer))));
+        subscriptions.put(consumer,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(assignment.get(consumer), Optional.of(1)))));
 
         assignment = assignor.assign(Collections.emptyMap(), subscriptions);
         assertEquals(assignment.size(), 1);
         assertTrue(assignment.get(consumer).isEmpty());
     }
 
+    @Test
+    public void testAssignmentWithMultipleGenerations1() { // a consumer rejoining with stale generation-1 data must not duplicate partitions reported in generation 2
+        String topic = "topic";
+        String consumer1 = "consumer1";
+        String consumer2 = "consumer2";
+        String consumer3 = "consumer3";
+
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic, 6);
+        Map<String, Subscription> subscriptions = new HashMap<>();
+        subscriptions.put(consumer1, new Subscription(topics(topic)));
+        subscriptions.put(consumer2, new Subscription(topics(topic)));
+        subscriptions.put(consumer3, new Subscription(topics(topic)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions); // round 1: three consumers, no user data yet
+        List<TopicPartition> r1partitions1 = assignment.get(consumer1);
+        List<TopicPartition> r1partitions2 = assignment.get(consumer2);
+        List<TopicPartition> r1partitions3 = assignment.get(consumer3);
+        assertTrue(r1partitions1.size() == 2 && r1partitions2.size() == 2 && r1partitions3.size() == 2);
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+
+        subscriptions.put(consumer1, // round 2: consumer1 and consumer2 report their round-1 partitions as generation 1
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r1partitions1, Optional.of(1)))));
+        subscriptions.put(consumer2,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r1partitions2, Optional.of(1)))));
+        subscriptions.remove(consumer3); // consumer3 leaves, so its partitions are redistributed
+        assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> r2partitions1 = assignment.get(consumer1);
+        List<TopicPartition> r2partitions2 = assignment.get(consumer2);
+        assertTrue(r2partitions1.size() == 3 && r2partitions2.size() == 3);
+        assertTrue(r2partitions1.containsAll(r1partitions1));
+        assertTrue(r2partitions2.containsAll(r1partitions2));
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+
+        assertTrue(!Collections.disjoint(r2partitions2, r1partitions3)); // precondition: consumer2 now owns some of consumer3's round-1 partitions
+        subscriptions.remove(consumer1); // round 3: consumer1 leaves; consumer2 reports generation 2, consumer3 rejoins with its stale generation-1 data
+        subscriptions.put(consumer2,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r2partitions2, Optional.of(2)))));
+        subscriptions.put(consumer3,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r1partitions3, Optional.of(1)))));
+        assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> r3partitions2 = assignment.get(consumer2);
+        List<TopicPartition> r3partitions3 = assignment.get(consumer3);
+        assertTrue(r3partitions2.size() == 3 && r3partitions3.size() == 3);
+        assertTrue(Collections.disjoint(r3partitions2, r3partitions3)); // the overlapping stale claim must not yield a duplicate assignment
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+    }
+
+    @Test
+    public void testAssignmentWithMultipleGenerations2() { // consumers absent for one generation get their round-1 partitions back when they rejoin
+        String topic = "topic";
+        String consumer1 = "consumer1";
+        String consumer2 = "consumer2";
+        String consumer3 = "consumer3";
+
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic, 6);
+        Map<String, Subscription> subscriptions = new HashMap<>();
+        subscriptions.put(consumer1, new Subscription(topics(topic)));
+        subscriptions.put(consumer2, new Subscription(topics(topic)));
+        subscriptions.put(consumer3, new Subscription(topics(topic)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions); // round 1: three consumers, no user data yet
+        List<TopicPartition> r1partitions1 = assignment.get(consumer1);
+        List<TopicPartition> r1partitions2 = assignment.get(consumer2);
+        List<TopicPartition> r1partitions3 = assignment.get(consumer3);
+        assertTrue(r1partitions1.size() == 2 && r1partitions2.size() == 2 && r1partitions3.size() == 2);
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+
+        subscriptions.remove(consumer1); // round 2: only consumer2 remains, reporting its round-1 partitions as generation 1
+        subscriptions.put(consumer2,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r1partitions2, Optional.of(1)))));
+        subscriptions.remove(consumer3);
+        assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> r2partitions2 = assignment.get(consumer2);
+        assertEquals(6, r2partitions2.size());
+        assertTrue(r2partitions2.containsAll(r1partitions2));
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+
+        subscriptions.put(consumer1, // round 3: all rejoin; consumer2 reports all six partitions as generation 2, the others their generation-1 data
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r1partitions1, Optional.of(1)))));
+        subscriptions.put(consumer2,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r2partitions2, Optional.of(2)))));
+        subscriptions.put(consumer3,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(r1partitions3, Optional.of(1)))));
+        assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> r3partitions1 = assignment.get(consumer1);
+        List<TopicPartition> r3partitions2 = assignment.get(consumer2);
+        List<TopicPartition> r3partitions3 = assignment.get(consumer3);
+        assertTrue(r3partitions1.size() == 2 && r3partitions2.size() == 2 && r3partitions3.size() == 2);
+        assertEquals(r1partitions1, r3partitions1); // each consumer ends up with exactly its round-1 partition list
+        assertEquals(r1partitions2, r3partitions2);
+        assertEquals(r1partitions3, r3partitions3);
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+    }
+
+    @Test
+    public void testAssignmentWithConflictingPreviousGenerations() { // overlapping prior claims from different generations must resolve into a valid, balanced assignment
+        String topic = "topic";
+        String consumer1 = "consumer1";
+        String consumer2 = "consumer2";
+        String consumer3 = "consumer3";
+
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic, 6);
+        Map<String, Subscription> subscriptions = new HashMap<>();
+        subscriptions.put(consumer1, new Subscription(topics(topic))); // NOTE(review): these three entries are overwritten below before assign() is called
+        subscriptions.put(consumer2, new Subscription(topics(topic)));
+        subscriptions.put(consumer3, new Subscription(topics(topic)));
+
+        TopicPartition tp0 = new TopicPartition(topic, 0);
+        TopicPartition tp1 = new TopicPartition(topic, 1);
+        TopicPartition tp2 = new TopicPartition(topic, 2);
+        TopicPartition tp3 = new TopicPartition(topic, 3);
+        TopicPartition tp4 = new TopicPartition(topic, 4);
+        TopicPartition tp5 = new TopicPartition(topic, 5);
+
+        List<TopicPartition> c1partitions0 = partitions(tp0, tp1, tp4); // reported as generation 1
+        List<TopicPartition> c2partitions0 = partitions(tp0, tp2, tp3); // generation 1; shares tp0 with consumer1 and tp3 with consumer3
+        List<TopicPartition> c3partitions0 = partitions(tp3, tp4, tp5); // generation 2; shares tp3 and tp4 with generation-1 claims
+        subscriptions.put(consumer1,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(c1partitions0, Optional.of(1)))));
+        subscriptions.put(consumer2,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(c2partitions0, Optional.of(1)))));
+        subscriptions.put(consumer3,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(c3partitions0, Optional.of(2)))));
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> c1partitions = assignment.get(consumer1);
+        List<TopicPartition> c2partitions = assignment.get(consumer2);
+        List<TopicPartition> c3partitions = assignment.get(consumer3);
+
+        assertTrue(c1partitions.size() == 2 && c2partitions.size() == 2 && c3partitions.size() == 2);
+        assertTrue(c1partitions0.containsAll(c1partitions)); // every consumer keeps only partitions it previously claimed
+        assertTrue(c2partitions0.containsAll(c2partitions));
+        assertTrue(c3partitions0.containsAll(c3partitions));
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+    }
+
+    @Test
+    public void testSchemaBackwardCompatibility() { // user data written with the legacy V0 schema must still be decoded and respected
+        String topic = "topic";
+        String consumer1 = "consumer1";
+        String consumer2 = "consumer2";
+        String consumer3 = "consumer3";
+
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic, 3);
+        Map<String, Subscription> subscriptions = new HashMap<>();
+        subscriptions.put(consumer1, new Subscription(topics(topic))); // NOTE(review): consumer1/consumer2 entries are overwritten below before assign()
+        subscriptions.put(consumer2, new Subscription(topics(topic)));
+        subscriptions.put(consumer3, new Subscription(topics(topic))); // consumer3 joins without any user data
+
+        TopicPartition tp0 = new TopicPartition(topic, 0);
+        TopicPartition tp1 = new TopicPartition(topic, 1);
+        TopicPartition tp2 = new TopicPartition(topic, 2);
+
+        List<TopicPartition> c1partitions0 = partitions(tp0, tp2);
+        List<TopicPartition> c2partitions0 = partitions(tp1);
+        subscriptions.put(consumer1, // current schema, generation 1
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(c1partitions0, Optional.of(1)))));
+        subscriptions.put(consumer2, // legacy V0 schema
+                new Subscription(topics(topic), serializeTopicPartitionAssignmentToOldSchema(c2partitions0)));
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> c1partitions = assignment.get(consumer1);
+        List<TopicPartition> c2partitions = assignment.get(consumer2);
+        List<TopicPartition> c3partitions = assignment.get(consumer3);
+
+        assertTrue(c1partitions.size() == 1 && c2partitions.size() == 1 && c3partitions.size() == 1);
+        assertTrue(c1partitions0.containsAll(c1partitions)); // both schemas' prior claims are kept where balance allows
+        assertTrue(c2partitions0.containsAll(c2partitions));
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+    }
+
+    @Test
+    public void testConflictingPreviousAssignments() { // two consumers claiming the same partitions in the same generation
+        String topic = "topic";
+        String consumer1 = "consumer1";
+        String consumer2 = "consumer2";
+
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic, 2);
+        Map<String, Subscription> subscriptions = new HashMap<>();
+        subscriptions.put(consumer1, new Subscription(topics(topic))); // NOTE(review): both entries are overwritten below before assign()
+        subscriptions.put(consumer2, new Subscription(topics(topic)));
+
+        TopicPartition tp0 = new TopicPartition(topic, 0);
+        TopicPartition tp1 = new TopicPartition(topic, 1);
+
+        // both c1 and c2 claim partitions 0 and 1 in generation 1
+        List<TopicPartition> c1partitions0 = partitions(tp0, tp1);
+        List<TopicPartition> c2partitions0 = partitions(tp0, tp1);
+        subscriptions.put(consumer1,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(c1partitions0, Optional.of(1)))));
+        subscriptions.put(consumer2,
+                new Subscription(topics(topic), StickyAssignor.serializeTopicPartitionAssignment(
+                        new ConsumerUserData(c2partitions0, Optional.of(1)))));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, subscriptions);
+        List<TopicPartition> c1partitions = assignment.get(consumer1);
+        List<TopicPartition> c2partitions = assignment.get(consumer2);
+
+        assertTrue(c1partitions.size() == 1 && c2partitions.size() == 1); // the conflict is resolved into one partition each
+        verifyValidityAndBalance(subscriptions, assignment);
+        assertTrue(isFullyBalanced(assignment));
+        assertTrue(assignor.isSticky());
+    }
+
+    private static ByteBuffer serializeTopicPartitionAssignmentToOldSchema(List<TopicPartition> partitions) { // encodes user data with the V0 schema, populating only its topic-partitions field
+        Struct struct = new Struct(StickyAssignor.STICKY_ASSIGNOR_USER_DATA_V0);
+        List<Struct> topicAssignments = new ArrayList<>();
+        for (Map.Entry<String, List<Integer>> topicEntry : CollectionUtils.groupPartitionsByTopic(partitions).entrySet()) { // one TOPIC_ASSIGNMENT struct per topic
+            Struct topicAssignment = new Struct(StickyAssignor.TOPIC_ASSIGNMENT);
+            topicAssignment.set(StickyAssignor.TOPIC_KEY_NAME, topicEntry.getKey());
+            topicAssignment.set(StickyAssignor.PARTITIONS_KEY_NAME, topicEntry.getValue().toArray());
+            topicAssignments.add(topicAssignment);
+        }
+        struct.set(StickyAssignor.TOPIC_PARTITIONS_KEY_NAME, topicAssignments.toArray());
+        ByteBuffer buffer = ByteBuffer.allocate(StickyAssignor.STICKY_ASSIGNOR_USER_DATA_V0.sizeOf(struct));
+        StickyAssignor.STICKY_ASSIGNOR_USER_DATA_V0.write(buffer, struct);
+        buffer.flip(); // rewind so the caller reads the bytes just written
+        return buffer;
+    }
+
     private String getTopicName(int i, int maxNum) {
         return getCanonicalName("t", i, maxNum);
     }