Posted to jira@kafka.apache.org by "dajac (via GitHub)" <gi...@apache.org> on 2023/02/17 15:48:34 UTC

[GitHub] [kafka] dajac commented on a diff in pull request #12990: KAFKA-14451: Rack-aware consumer partition assignment for RangeAssignor (KIP-881)

dajac commented on code in PR #12990:
URL: https://github.com/apache/kafka/pull/12990#discussion_r1109783046


##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -63,9 +76,19 @@
  * <li><code>I0: [t0p0, t0p1, t1p0, t1p1]</code>
  * <li><code>I1: [t0p2, t1p2]</code>
  * </ul>
+ * <p>
+ * Rack-aware assignment is used if both consumer and partition replica racks are available and
+ * some partitions have replicas only on a subset of racks. We attempt to match consumer racks with
+ * partition replica racks on a best-effort basis, prioritizing balanced assignment over rack alignment.
+ * Topics with equal partition count and same set of subscribers prioritize co-partitioning guarantee
+ * over rack alignment. In this case, aligning partition replicas of these topics on the same racks
+ * will improve locality for consumers. For example, if partitions 0 of all topics have a replica on
+ * rack 'a', partition 1 on rack 'b' etc., partition 0 of all topics can be assigned to a consumer
+ * on rack 'a', partition 1 to a consumer on rack 'b' and so on.

Review Comment:
   > Topics with equal partition count and same set of subscribers prioritize co-partitioning guarantee over rack alignment.
   
   I would like to ensure that we are on the same page on this point. My understanding is that the current implementation guarantees co-partitioning iff the topics have the same number of partitions. Am I correct?
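
   To make the question concrete, here is a minimal sketch of how I read the grouping in `assignWithRackMatching`: topics are bucketed first by their subscriber list and then by partition count, and only buckets holding more than one topic are treated as co-partitioned. The `Topic` class and the `coPartitionedGroups` helper are hypothetical names used for illustration only, and plain loops stand in for the `Collectors.groupingBy` calls in the diff.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration, not the PR's code.
class CoPartitionGroupingSketch {

    // Simplified stand-in for the per-topic state kept by the assignor.
    static final class Topic {
        final String name;
        final int numPartitions;
        final List<String> consumers;

        Topic(String name, int numPartitions, List<String> consumers) {
            this.name = name;
            this.numPartitions = numPartitions;
            this.consumers = consumers;
        }
    }

    // Returns the groups of topic names that share both the same subscriber list
    // and the same partition count; single-topic groups are left out.
    static List<List<String>> coPartitionedGroups(List<Topic> topics) {
        Map<List<String>, Map<Integer, List<String>>> grouped = new LinkedHashMap<>();
        for (Topic t : topics) {
            grouped.computeIfAbsent(t.consumers, k -> new LinkedHashMap<>())
                   .computeIfAbsent(t.numPartitions, k -> new ArrayList<>())
                   .add(t.name);
        }
        List<List<String>> result = new ArrayList<>();
        for (Map<Integer, List<String>> byCount : grouped.values())
            for (List<String> group : byCount.values())
                if (group.size() > 1)
                    result.add(group);
        return result;
    }

    public static void main(String[] args) {
        List<String> both = Arrays.asList("c0", "c1");
        List<Topic> topics = Arrays.asList(
                new Topic("t0", 3, both),
                new Topic("t1", 3, both),
                new Topic("t2", 4, both),                  // different partition count
                new Topic("t3", 3, Arrays.asList("c0")));  // different subscribers
        System.out.println(coPartitionedGroups(topics));
    }
}
```

   Under this reading, a topic with a different partition count or a different subscriber list ends up in its own bucket and is assigned independently of the others.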



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {

Review Comment:
   nit: Would it make sense to use `forEach` here? I always find the `getKey` and `getValue` annoying.
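
   For illustration, a minimal sketch of what the `forEach` form could look like. It is deliberately simplified: the subscription is reduced to a list of topic names and the value stored per topic is the member id rather than a `MemberInfo`, so `ForEachSketch` and its signature are assumptions, not the actual method in `RangeAssignor`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified stand-in for consumersPerTopic.
class ForEachSketch {

    static Map<String, List<String>> consumersPerTopic(Map<String, List<String>> topicsByConsumer) {
        Map<String, List<String>> topicToConsumers = new HashMap<>();
        // Map.forEach passes key and value to the lambda directly,
        // so there is no Map.Entry and no getKey()/getValue().
        topicsByConsumer.forEach((consumerId, topics) -> {
            for (String topic : topics)
                topicToConsumers.computeIfAbsent(topic, t -> new ArrayList<>()).add(consumerId);
        });
        return topicToConsumers;
    }

    public static void main(String[] args) {
        Map<String, List<String>> subscriptions = new LinkedHashMap<>();
        subscriptions.put("c0", Arrays.asList("t0", "t1"));
        subscriptions.put("c1", Arrays.asList("t0"));
        System.out.println(consumersPerTopic(subscriptions));
    }
}
```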



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));
+
+        if (useRackAware)
+            assignment.values().forEach(list -> list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
+                                                    Map<String, Subscription> subscriptions) {
+        return assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }
+
+    private void assignRanges(TopicAssignmentState assignmentState,
+                              BiFunction<String, TopicPartition, Boolean> mayAssign,
+                              Map<String, List<TopicPartition>> assignment) {
+        for (String consumer : assignmentState.consumers) {
+            if (assignmentState.unassignedPartitions.isEmpty())
+                break;
+            List<TopicPartition> assignablePartitions = assignmentState.unassignedPartitions.stream()
+                    .filter(tp -> mayAssign.apply(consumer, tp))
+                    .collect(Collectors.toList());
 
-            Integer numPartitionsForTopic = partitionsPerTopic.get(topic);
-            if (numPartitionsForTopic == null)
+            int maxAssignable = Math.min(assignmentState.maxAssignable(consumer), assignablePartitions.size());
+            if (maxAssignable <= 0)
                 continue;
 
-            Collections.sort(consumersForTopic);
+            assign(consumer, assignablePartitions.subList(0, maxAssignable), assignmentState, assignment);
+        }
+    }
+
+    private void assignWithRackMatching(Collection<TopicAssignmentState> assignmentStates,
+                                        Map<String, List<TopicPartition>> assignment) {
 
-            int numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size();
-            int consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size();
+        assignmentStates.stream().collect(Collectors.groupingBy(t -> t.consumers)).forEach((consumers, states) -> {
+            states.stream().collect(Collectors.groupingBy(t -> t.partitionRacks.size())).forEach((numPartitions, coPartitionedStates) -> {
+                if (coPartitionedStates.size() > 1)
+                    assignCoPartitionedWithRackMatching(consumers, numPartitions, states, assignment);
+                else {
+                    TopicAssignmentState state = coPartitionedStates.get(0);
+                    if (state.needsRackAwareAssignment)
+                        assignRanges(state, state::racksMatch, assignment);
+                }
+            });
+        });
+    }
+
+    private void assignCoPartitionedWithRackMatching(List<String> consumers,
+                                                     int numPartitions,
+                                                     Collection<TopicAssignmentState> assignmentStates,
+                                                     Map<String, List<TopicPartition>> assignment) {
+
+        List<String> remainingConsumers = new LinkedList<>(consumers);
+        for (int i = 0; i < numPartitions; i++) {
+            int p = i;
 
-            List<TopicPartition> partitions = AbstractPartitionAssignor.partitions(topic, numPartitionsForTopic);
-            for (int i = 0, n = consumersForTopic.size(); i < n; i++) {
-                int start = numPartitionsPerConsumer * i + Math.min(i, consumersWithExtraPartition);
-                int length = numPartitionsPerConsumer + (i + 1 > consumersWithExtraPartition ? 0 : 1);
-                assignment.get(consumersForTopic.get(i).memberId).addAll(partitions.subList(start, start + length));
+            Optional<String> matchingConsumer = remainingConsumers.stream()
+                    .filter(c -> assignmentStates.stream().allMatch(t -> t.racksMatch(c, new TopicPartition(t.topic, p)) && t.maxAssignable(c) > 0))
+                    .findFirst();
+            if (matchingConsumer.isPresent()) {

Review Comment:
   What do we do with the partition if we don't find a consumer?



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));
+
+        if (useRackAware)
+            assignment.values().forEach(list -> list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
+                                                    Map<String, Subscription> subscriptions) {
+        return assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }
+
+    private void assignRanges(TopicAssignmentState assignmentState,
+                              BiFunction<String, TopicPartition, Boolean> mayAssign,
+                              Map<String, List<TopicPartition>> assignment) {
+        for (String consumer : assignmentState.consumers) {
+            if (assignmentState.unassignedPartitions.isEmpty())
+                break;
+            List<TopicPartition> assignablePartitions = assignmentState.unassignedPartitions.stream()
+                    .filter(tp -> mayAssign.apply(consumer, tp))
+                    .collect(Collectors.toList());
 
-            Integer numPartitionsForTopic = partitionsPerTopic.get(topic);
-            if (numPartitionsForTopic == null)
+            int maxAssignable = Math.min(assignmentState.maxAssignable(consumer), assignablePartitions.size());
+            if (maxAssignable <= 0)
                 continue;
 
-            Collections.sort(consumersForTopic);
+            assign(consumer, assignablePartitions.subList(0, maxAssignable), assignmentState, assignment);
+        }
+    }
+
+    private void assignWithRackMatching(Collection<TopicAssignmentState> assignmentStates,
+                                        Map<String, List<TopicPartition>> assignment) {
 
-            int numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size();
-            int consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size();
+        assignmentStates.stream().collect(Collectors.groupingBy(t -> t.consumers)).forEach((consumers, states) -> {
+            states.stream().collect(Collectors.groupingBy(t -> t.partitionRacks.size())).forEach((numPartitions, coPartitionedStates) -> {
+                if (coPartitionedStates.size() > 1)
+                    assignCoPartitionedWithRackMatching(consumers, numPartitions, states, assignment);
+                else {
+                    TopicAssignmentState state = coPartitionedStates.get(0);
+                    if (state.needsRackAwareAssignment)
+                        assignRanges(state, state::racksMatch, assignment);
+                }
+            });
+        });
+    }
+
+    private void assignCoPartitionedWithRackMatching(List<String> consumers,
+                                                     int numPartitions,
+                                                     Collection<TopicAssignmentState> assignmentStates,
+                                                     Map<String, List<TopicPartition>> assignment) {
+
+        List<String> remainingConsumers = new LinkedList<>(consumers);
+        for (int i = 0; i < numPartitions; i++) {
+            int p = i;
 
-            List<TopicPartition> partitions = AbstractPartitionAssignor.partitions(topic, numPartitionsForTopic);
-            for (int i = 0, n = consumersForTopic.size(); i < n; i++) {
-                int start = numPartitionsPerConsumer * i + Math.min(i, consumersWithExtraPartition);
-                int length = numPartitionsPerConsumer + (i + 1 > consumersWithExtraPartition ? 0 : 1);
-                assignment.get(consumersForTopic.get(i).memberId).addAll(partitions.subList(start, start + length));
+            Optional<String> matchingConsumer = remainingConsumers.stream()
+                    .filter(c -> assignmentStates.stream().allMatch(t -> t.racksMatch(c, new TopicPartition(t.topic, p)) && t.maxAssignable(c) > 0))
+                    .findFirst();
+            if (matchingConsumer.isPresent()) {
+                String consumer = matchingConsumer.get();
+                assignmentStates.forEach(t -> assign(consumer, Collections.singletonList(new TopicPartition(t.topic, p)), t, assignment));
+
+                if (assignmentStates.stream().noneMatch(t -> t.maxAssignable(consumer) > 0)) {
+                    remainingConsumers.remove(consumer);
+                    if (remainingConsumers.isEmpty())
+                        break;
+                }
             }
         }
-        return assignment;
+    }
+
+    private void assign(String consumer, List<TopicPartition> partitions, TopicAssignmentState assignmentState, Map<String, List<TopicPartition>> assignment) {
+        assignment.get(consumer).addAll(partitions);
+        assignmentState.onAssigned(consumer, partitions);
+    }
+
+    private class TopicAssignmentState {
+        private final String topic;
+        private final List<String> consumers;
+        private final boolean needsRackAwareAssignment;
+        private final Map<TopicPartition, Set<String>> partitionRacks;
+        private final Map<String, String> consumerRacks;

Review Comment:
   It feels a bit weird to have this mapping here because it will be the same in every `TopicAssignmentState` if the subscriptions are the same for all consumers. Is there a reason why you decided to have it here?



##########
core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala:
##########
@@ -1951,20 +1954,47 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   }
 
   @Test
-  def testConsumerRackIdPropagatedToPartitionAssignor(): Unit = {
-    consumerConfig.setProperty(ConsumerConfig.CLIENT_RACK_CONFIG, "rack-a")
-    consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classOf[RackAwareAssignor].getName)
-    val consumer = createConsumer()
-    consumer.subscribe(Set(topic).asJava)
-    awaitAssignment(consumer, Set(tp, tp2))
-  }
-}
+  def testRackAwareRangeAssignor(): Unit = {

Review Comment:
   I was wondering if there is any value in having an integration test with fetch-from-follower (FFF) enabled as well. I am really not sure. What do you think?



##########
core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala:
##########
@@ -1951,20 +1954,47 @@ class PlaintextConsumerTest extends BaseConsumerTest {
   }
 
   @Test
-  def testConsumerRackIdPropagatedToPartitionAssignor(): Unit = {
-    consumerConfig.setProperty(ConsumerConfig.CLIENT_RACK_CONFIG, "rack-a")
-    consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classOf[RackAwareAssignor].getName)
-    val consumer = createConsumer()
-    consumer.subscribe(Set(topic).asJava)
-    awaitAssignment(consumer, Set(tp, tp2))
-  }
-}
+  def testRackAwareRangeAssignor(): Unit = {
+    val partitionList = servers.indices.toList
+
+    val topicWithAllPartitionsOnAllRacks = "topicWithAllPartitionsOnAllRacks"
+    createTopic(topicWithAllPartitionsOnAllRacks, servers.size, servers.size)
+
+    // Racks are in order of broker ids, assign leaders in reverse order
+    val topicWithSingleRackPartitions = "topicWithSingleRackPartitions"
+    createTopicWithAssignment(topicWithSingleRackPartitions, partitionList.map(i => (i, Seq(servers.size - i - 1))).toMap)
+
+    // Create consumers with instance ids in ascending order, with racks in the same order.
+    consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classOf[RangeAssignor].getName)
+    val consumers = servers.map { server =>
+      consumerConfig.setProperty(ConsumerConfig.CLIENT_RACK_CONFIG, server.config.rack.orNull)
+      consumerConfig.setProperty(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, s"instance-${server.config.brokerId}")
+      createConsumer()
+    }
+
+    val executor = Executors.newFixedThreadPool(consumers.size)
+    def waitForAssignments(assignments: List[Set[TopicPartition]]): Unit = {
+      val futures = consumers.zipWithIndex.map { case (consumer, i) =>
+        executor.submit(() => awaitAssignment(consumer, assignments(i)), 0)
+      }
+      futures.foreach(future => assertEquals(0, future.get(20, TimeUnit.SECONDS)))
+    }
 
-class RackAwareAssignor extends RoundRobinAssignor {
-  override def assign(partitionsPerTopic: util.Map[String, Integer], subscriptions: util.Map[String, ConsumerPartitionAssignor.Subscription]): util.Map[String, util.List[TopicPartition]] = {
-    assertEquals(1, subscriptions.size())
-    assertEquals(Optional.of("rack-a"), subscriptions.values.asScala.head.rackId)
-    super.assign(partitionsPerTopic, subscriptions)
+    try {
+      // Rack-based assignment results in partitions assigned in reverse order since partition racks are in the reverse order.
+      consumers.foreach(_.subscribe(Collections.singleton(topicWithSingleRackPartitions)))
+      waitForAssignments(partitionList.reverse.map(p => Set(new TopicPartition(topicWithSingleRackPartitions, p))))
+
+      // Non-rack-aware assignment results in ordered partitions.
+      consumers.foreach(_.subscribe(Collections.singleton(topicWithAllPartitionsOnAllRacks)))
+      waitForAssignments(partitionList.map(p => Set(new TopicPartition(topicWithAllPartitionsOnAllRacks, p))))
+
+      // Rack-aware assignment with co-partitioning results in reverse assignment for both topics.
+      consumers.foreach(_.subscribe(Set(topicWithSingleRackPartitions, topicWithAllPartitionsOnAllRacks).asJava))
+      waitForAssignments(partitionList.reverse.map(p => Set(new TopicPartition(topicWithAllPartitionsOnAllRacks, p), new TopicPartition(topicWithSingleRackPartitions, p))))

Review Comment:
   When we have a consumer in the same zone as a leader, does the assignment algorithm guarantee that the leader in the same zone is selected all the time?



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));

Review Comment:
   At first, it is not so clear why both `assignWithRackMatching` and `assignRanges` must be called. It may be worth adding a comment to explain the assignment logic here or some javadoc for `assignPartitions`.
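
   As a way of checking my reading of the two calls, here is a minimal single-topic sketch: a first pass hands out only rack-matching partitions, and a second pass with an always-true predicate picks up whatever is left, so a partition without a matching consumer is still assigned. The per-consumer quota below (an even split, with the first consumers taking the remainder) is an assumption standing in for `maxAssignable`, and `TwoPassSketch` and `assignPass` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiPredicate;

// Hypothetical illustration of the two-pass idea, not the PR's code.
class TwoPassSketch {

    // partitionRacks.get(p) is the set of racks hosting a replica of partition p.
    static Map<String, List<Integer>> assign(List<String> consumers,
                                             Map<String, String> consumerRacks,
                                             List<Set<String>> partitionRacks) {
        int n = partitionRacks.size();
        Map<String, Integer> quota = new HashMap<>();
        for (int i = 0; i < consumers.size(); i++)
            quota.put(consumers.get(i), n / consumers.size() + (i < n % consumers.size() ? 1 : 0));

        List<Integer> unassigned = new ArrayList<>();
        for (int p = 0; p < n; p++)
            unassigned.add(p);

        Map<String, List<Integer>> assignment = new LinkedHashMap<>();
        consumers.forEach(c -> assignment.put(c, new ArrayList<>()));

        // Pass 1: only partitions with a replica on the consumer's rack.
        assignPass(consumers, quota, unassigned, assignment,
                (c, p) -> partitionRacks.get(p).contains(consumerRacks.get(c)));
        // Pass 2: anything still unassigned, regardless of rack.
        assignPass(consumers, quota, unassigned, assignment, (c, p) -> true);

        // The two passes can interleave partitions, so restore partition order.
        assignment.values().forEach(Collections::sort);
        return assignment;
    }

    private static void assignPass(List<String> consumers,
                                   Map<String, Integer> quota,
                                   List<Integer> unassigned,
                                   Map<String, List<Integer>> assignment,
                                   BiPredicate<String, Integer> mayAssign) {
        for (String c : consumers) {
            List<Integer> assignable = new ArrayList<>();
            for (Integer p : unassigned)
                if (mayAssign.test(c, p))
                    assignable.add(p);
            int count = Math.min(quota.get(c), assignable.size());
            List<Integer> chosen = new ArrayList<>(assignable.subList(0, count));
            assignment.get(c).addAll(chosen);
            unassigned.removeAll(chosen);
            quota.merge(c, -count, Integer::sum); // reduce the remaining quota
        }
    }

    public static void main(String[] args) {
        Map<String, String> racks = new HashMap<>();
        racks.put("c0", "a");
        racks.put("c1", "b");
        racks.put("c2", "c");
        List<Set<String>> partitionRacks = Arrays.asList(
                Collections.singleton("c"), Collections.singleton("b"), Collections.singleton("a"));
        System.out.println(assign(Arrays.asList("c0", "c1", "c2"), racks, partitionRacks));
    }
}
```

   If this matches the intent, a short comment above the two calls saying that the second one is the fallback for partitions left over by the rack-matching pass would probably be enough.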



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));
+
+        if (useRackAware)
+            assignment.values().forEach(list -> list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
+                                                    Map<String, Subscription> subscriptions) {
+        return assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }

Review Comment:
   This makes sense. Unfortunately, we did not make the class final.



##########
core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala:
##########
@@ -29,24 +35,21 @@ import org.apache.kafka.common.utils.Utils
 import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor}
 import org.junit.jupiter.api.Assertions._
 import org.junit.jupiter.api.Test
-
-import scala.jdk.CollectionConverters._
-import scala.collection.mutable.Buffer
-import kafka.server.QuotaType
-import kafka.server.KafkaServer
-import org.apache.kafka.clients.admin.NewPartitions
-import org.apache.kafka.clients.admin.NewTopic
-import org.apache.kafka.common.config.TopicConfig
 import org.junit.jupiter.params.ParameterizedTest
 import org.junit.jupiter.params.provider.ValueSource
 
-import java.util.concurrent.TimeUnit
-import java.util.concurrent.locks.ReentrantLock
 import scala.collection.mutable
+import scala.collection.mutable.Buffer
+import scala.jdk.CollectionConverters._
 
 /* We have some tests in this class instead of `BaseConsumerTest` in order to keep the build time under control. */
 class PlaintextConsumerTest extends BaseConsumerTest {
 
+  override def modifyConfigs(props: collection.Seq[Properties]): Unit = {
+    super.modifyConfigs(props)
+    props.zipWithIndex.foreach{ case (p, i) => p.setProperty(KafkaConfig.RackProp, i.toString) }

Review Comment:
   nit: I think that we usually put a space before `{`.



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));
+
+        if (useRackAware)
+            assignment.values().forEach(list -> list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
+                                                    Map<String, Subscription> subscriptions) {
+        return assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }
+
+    private void assignRanges(TopicAssignmentState assignmentState,
+                              BiFunction<String, TopicPartition, Boolean> mayAssign,
+                              Map<String, List<TopicPartition>> assignment) {
+        for (String consumer : assignmentState.consumers) {
+            if (assignmentState.unassignedPartitions.isEmpty())
+                break;
+            List<TopicPartition> assignablePartitions = assignmentState.unassignedPartitions.stream()
+                    .filter(tp -> mayAssign.apply(consumer, tp))
+                    .collect(Collectors.toList());
 
-            Integer numPartitionsForTopic = partitionsPerTopic.get(topic);
-            if (numPartitionsForTopic == null)
+            int maxAssignable = Math.min(assignmentState.maxAssignable(consumer), assignablePartitions.size());
+            if (maxAssignable <= 0)
                 continue;
 
-            Collections.sort(consumersForTopic);
+            assign(consumer, assignablePartitions.subList(0, maxAssignable), assignmentState, assignment);
+        }
+    }
+
+    private void assignWithRackMatching(Collection<TopicAssignmentState> assignmentStates,
+                                        Map<String, List<TopicPartition>> assignment) {
 
-            int numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size();
-            int consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size();
+        assignmentStates.stream().collect(Collectors.groupingBy(t -> t.consumers)).forEach((consumers, states) -> {
+            states.stream().collect(Collectors.groupingBy(t -> t.partitionRacks.size())).forEach((numPartitions, coPartitionedStates) -> {
+                if (coPartitionedStates.size() > 1)
+                    assignCoPartitionedWithRackMatching(consumers, numPartitions, states, assignment);
+                else {
+                    TopicAssignmentState state = coPartitionedStates.get(0);
+                    if (state.needsRackAwareAssignment)
+                        assignRanges(state, state::racksMatch, assignment);
+                }
+            });
+        });
+    }
+
+    private void assignCoPartitionedWithRackMatching(List<String> consumers,
+                                                     int numPartitions,
+                                                     Collection<TopicAssignmentState> assignmentStates,
+                                                     Map<String, List<TopicPartition>> assignment) {
+
+        List<String> remainingConsumers = new LinkedList<>(consumers);
+        for (int i = 0; i < numPartitions; i++) {

Review Comment:
   I wonder if going over the partitions in index order is the best way. Imagine a case where the last partition ends up having only one rack available. This must be rare, I agree. However, you could end up in a situation where that partition cannot be assigned to a consumer on the correct rack because all the consumers on that rack have already reached their maximum capacity.
   
   When I was thinking about it, I was wondering if it would be better to start with the most constrained partitions: the partitions with only one rack available, then the partitions with two racks, etc. Did you consider something like this? There is perhaps a downside that I haven't thought about. One concern is that it may be less deterministic, and determinism is important.
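
   To make the idea concrete, here is a rough sketch of such an ordering. The class and method names are hypothetical and nothing here comes from the PR; it only assumes that the set of replica racks is known per partition. Sorting by rack count and breaking ties by partition index would keep the order deterministic for a given input.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical helper, not part of the PR: orders partition indices so that
// the partitions with the fewest replica racks are considered first.
final class MostConstrainedFirst {

    // partitionRacks.get(p) is the set of racks hosting a replica of partition p.
    static List<Integer> order(List<Set<String>> partitionRacks) {
        List<Integer> partitions = new ArrayList<>();
        for (int p = 0; p < partitionRacks.size(); p++)
            partitions.add(p);
        // Primary key: number of racks, ascending. Tie-break: partition index,
        // so the resulting order is fully determined by the input.
        partitions.sort(Comparator
                .<Integer>comparingInt(p -> partitionRacks.get(p).size())
                .thenComparingInt(p -> p));
        return partitions;
    }

    // Small convenience for building a rack set in the example below.
    static Set<String> racks(String... ids) {
        return new HashSet<>(Arrays.asList(ids));
    }

    public static void main(String[] args) {
        List<Set<String>> partitionRacks = Arrays.asList(
                racks("a", "b", "c"),  // partition 0
                racks("a", "b"),       // partition 1
                racks("c"),            // partition 2
                racks("b", "c"));      // partition 3
        System.out.println(order(partitionRacks)); // [2, 1, 3, 0]
    }
}
```

   The assignment loop would then walk this list instead of `0..numPartitions-1`. Whether that still yields range-like, contiguous assignments is an open question in this sketch.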



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));
+
+        if (useRackAware)
+            assignment.values().forEach(list -> list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
+                                                    Map<String, Subscription> subscriptions) {
+        return assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }
+
+    private void assignRanges(TopicAssignmentState assignmentState,
+                              BiFunction<String, TopicPartition, Boolean> mayAssign,
+                              Map<String, List<TopicPartition>> assignment) {
+        for (String consumer : assignmentState.consumers) {
+            if (assignmentState.unassignedPartitions.isEmpty())
+                break;
+            List<TopicPartition> assignablePartitions = assignmentState.unassignedPartitions.stream()
+                    .filter(tp -> mayAssign.apply(consumer, tp))
+                    .collect(Collectors.toList());
 
-            Integer numPartitionsForTopic = partitionsPerTopic.get(topic);
-            if (numPartitionsForTopic == null)
+            int maxAssignable = Math.min(assignmentState.maxAssignable(consumer), assignablePartitions.size());
+            if (maxAssignable <= 0)
                 continue;
 
-            Collections.sort(consumersForTopic);
+            assign(consumer, assignablePartitions.subList(0, maxAssignable), assignmentState, assignment);
+        }
+    }
+
+    private void assignWithRackMatching(Collection<TopicAssignmentState> assignmentStates,
+                                        Map<String, List<TopicPartition>> assignment) {
 
-            int numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size();
-            int consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size();
+        assignmentStates.stream().collect(Collectors.groupingBy(t -> t.consumers)).forEach((consumers, states) -> {
+            states.stream().collect(Collectors.groupingBy(t -> t.partitionRacks.size())).forEach((numPartitions, coPartitionedStates) -> {
+                if (coPartitionedStates.size() > 1)
+                    assignCoPartitionedWithRackMatching(consumers, numPartitions, states, assignment);
+                else {
+                    TopicAssignmentState state = coPartitionedStates.get(0);
+                    if (state.needsRackAwareAssignment)
+                        assignRanges(state, state::racksMatch, assignment);
+                }
+            });
+        });
+    }
+
+    private void assignCoPartitionedWithRackMatching(List<String> consumers,
+                                                     int numPartitions,
+                                                     Collection<TopicAssignmentState> assignmentStates,
+                                                     Map<String, List<TopicPartition>> assignment) {
+
+        List<String> remainingConsumers = new LinkedList<>(consumers);
+        for (int i = 0; i < numPartitions; i++) {
+            int p = i;

Review Comment:
   nit: I suppose that we could rename `i` to `p` and remove this one, couldn't we?
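
   One thing to check before renaming: the fuller hunk quoted further down in this message uses `p` inside stream lambdas, and Java lambdas can only capture effectively final locals. The counter of a classic `for` loop is reassigned by `i++`, so it cannot be captured directly. The sketch below uses hypothetical names (it is not code from the PR) to show the copy pattern next to an `IntStream` alternative that avoids the copy.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical illustration, not code from the PR.
final class LoopIndexCapture {

    // Classic for loop: `i` is reassigned by `i++`, so it is not effectively
    // final and a lambda cannot capture it. The per-iteration copy `p` can be.
    static List<String> withCopy(List<String> topics, int numPartitions) {
        List<String> result = new ArrayList<>();
        for (int i = 0; i < numPartitions; i++) {
            int p = i;
            result.addAll(topics.stream()
                    .map(t -> t + "-" + p)
                    .collect(Collectors.toList()));
        }
        return result;
    }

    // IntStream alternative: `p` is a lambda parameter, so no copy is needed,
    // but `break` is not available inside forEach.
    static List<String> withIntStream(List<String> topics, int numPartitions) {
        List<String> result = new ArrayList<>();
        IntStream.range(0, numPartitions).forEach(p ->
                topics.forEach(t -> result.add(t + "-" + p)));
        return result;
    }

    public static void main(String[] args) {
        List<String> topics = Arrays.asList("t1", "t2");
        System.out.println(withCopy(topics, 2));      // [t1-0, t2-0, t1-1, t2-1]
        System.out.println(withIntStream(topics, 2)); // [t1-0, t2-0, t1-1, t2-1]
    }
}
```

   Since the loop in the PR also uses `break`, the `IntStream` form would not be a drop-in replacement there.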



##########
clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java:
##########
@@ -302,10 +339,151 @@ public void testStaticMemberRangeAssignmentPersistentAfterMemberIdChanges() {
         assertEquals(staticAssignment, newStaticAssignment);
     }
 
-    static Map<String, List<TopicPartition>> checkStaticAssignment(AbstractPartitionAssignor assignor,
-                                                                   Map<String, Integer> partitionsPerTopic,
-                                                                   Map<String, Subscription> consumers) {
-        Map<String, List<TopicPartition>> assignmentByMemberId = assignor.assign(partitionsPerTopic, consumers);
+    @Test
+    public void testRackAwareAssignmentWithUniformSubscription() {
+        Map<String, Integer> topics = mkMap(mkEntry("t1", 6), mkEntry("t2", 7), mkEntry("t3", 2));
+        List<String> allTopics = asList("t1", "t2", "t3");
+        List<List<String>> consumerTopics = asList(allTopics, allTopics, allTopics);
+        List<String> nonRackAwareAssignment = asList(
+                "t1-0, t1-1, t2-0, t2-1, t2-2, t3-0",
+                "t1-2, t1-3, t2-3, t2-4, t3-1",
+                "t1-4, t1-5, t2-5, t2-6"
+        );
+
+        // Verify combinations where rack-aware logic is not used.
+        verifyNonRackAwareAssignment(topics, consumerTopics, nonRackAwareAssignment);

Review Comment:
   nit: Should we inline `nonRackAwareAssignment`, as is done for the other ones?



##########
clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java:
##########
@@ -76,43 +99,185 @@ private Map<String, List<MemberInfo>> consumersPerTopic(Map<String, Subscription
         Map<String, List<MemberInfo>> topicToConsumers = new HashMap<>();
         for (Map.Entry<String, Subscription> subscriptionEntry : consumerMetadata.entrySet()) {
             String consumerId = subscriptionEntry.getKey();
-            MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId());
-            for (String topic : subscriptionEntry.getValue().topics()) {
+            Subscription subscription = subscriptionEntry.getValue();
+            MemberInfo memberInfo = new MemberInfo(consumerId, subscription.groupInstanceId(), subscription.rackId());
+            for (String topic : subscription.topics()) {
                 put(topicToConsumers, topic, memberInfo);
             }
         }
         return topicToConsumers;
     }
 
     @Override
-    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
-                                                    Map<String, Subscription> subscriptions) {
+    public Map<String, List<TopicPartition>> assignPartitions(Map<String, List<PartitionInfo>> partitionsPerTopic,
+                                                              Map<String, Subscription> subscriptions) {
         Map<String, List<MemberInfo>> consumersPerTopic = consumersPerTopic(subscriptions);
+        List<TopicAssignmentState> topicAssignmentStates = partitionsPerTopic.entrySet().stream()
+                .filter(e -> !e.getValue().isEmpty())
+                .map(e -> new TopicAssignmentState(e.getKey(), e.getValue(), consumersPerTopic.get(e.getKey())))
+                .collect(Collectors.toList());
 
         Map<String, List<TopicPartition>> assignment = new HashMap<>();
         for (String memberId : subscriptions.keySet())
             assignment.put(memberId, new ArrayList<>());
 
-        for (Map.Entry<String, List<MemberInfo>> topicEntry : consumersPerTopic.entrySet()) {
-            String topic = topicEntry.getKey();
-            List<MemberInfo> consumersForTopic = topicEntry.getValue();
+        boolean useRackAware = topicAssignmentStates.stream().anyMatch(t -> t.needsRackAwareAssignment);
+        if (useRackAware)
+            assignWithRackMatching(topicAssignmentStates, assignment);
+
+        topicAssignmentStates.forEach(t -> assignRanges(t, (c, tp) -> true, assignment));
+
+        if (useRackAware)
+            assignment.values().forEach(list -> list.sort(PARTITION_COMPARATOR));
+        return assignment;
+    }
+
+    // This method is not used, but retained for compatibility with any custom assignors that extend this class.
+    @Override
+    public Map<String, List<TopicPartition>> assign(Map<String, Integer> partitionsPerTopic,
+                                                    Map<String, Subscription> subscriptions) {
+        return assignPartitions(partitionInfosWithoutRacks(partitionsPerTopic), subscriptions);
+    }
+
+    private void assignRanges(TopicAssignmentState assignmentState,
+                              BiFunction<String, TopicPartition, Boolean> mayAssign,
+                              Map<String, List<TopicPartition>> assignment) {
+        for (String consumer : assignmentState.consumers) {
+            if (assignmentState.unassignedPartitions.isEmpty())
+                break;
+            List<TopicPartition> assignablePartitions = assignmentState.unassignedPartitions.stream()
+                    .filter(tp -> mayAssign.apply(consumer, tp))
+                    .collect(Collectors.toList());
 
-            Integer numPartitionsForTopic = partitionsPerTopic.get(topic);
-            if (numPartitionsForTopic == null)
+            int maxAssignable = Math.min(assignmentState.maxAssignable(consumer), assignablePartitions.size());
+            if (maxAssignable <= 0)
                 continue;
 
-            Collections.sort(consumersForTopic);
+            assign(consumer, assignablePartitions.subList(0, maxAssignable), assignmentState, assignment);
+        }
+    }
+
+    private void assignWithRackMatching(Collection<TopicAssignmentState> assignmentStates,
+                                        Map<String, List<TopicPartition>> assignment) {
 
-            int numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size();
-            int consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size();
+        assignmentStates.stream().collect(Collectors.groupingBy(t -> t.consumers)).forEach((consumers, states) -> {
+            states.stream().collect(Collectors.groupingBy(t -> t.partitionRacks.size())).forEach((numPartitions, coPartitionedStates) -> {
+                if (coPartitionedStates.size() > 1)
+                    assignCoPartitionedWithRackMatching(consumers, numPartitions, states, assignment);
+                else {
+                    TopicAssignmentState state = coPartitionedStates.get(0);
+                    if (state.needsRackAwareAssignment)
+                        assignRanges(state, state::racksMatch, assignment);
+                }
+            });
+        });
+    }
+
+    private void assignCoPartitionedWithRackMatching(List<String> consumers,
+                                                     int numPartitions,
+                                                     Collection<TopicAssignmentState> assignmentStates,
+                                                     Map<String, List<TopicPartition>> assignment) {
+
+        List<String> remainingConsumers = new LinkedList<>(consumers);
+        for (int i = 0; i < numPartitions; i++) {
+            int p = i;
 
-            List<TopicPartition> partitions = AbstractPartitionAssignor.partitions(topic, numPartitionsForTopic);
-            for (int i = 0, n = consumersForTopic.size(); i < n; i++) {
-                int start = numPartitionsPerConsumer * i + Math.min(i, consumersWithExtraPartition);
-                int length = numPartitionsPerConsumer + (i + 1 > consumersWithExtraPartition ? 0 : 1);
-                assignment.get(consumersForTopic.get(i).memberId).addAll(partitions.subList(start, start + length));
+            Optional<String> matchingConsumer = remainingConsumers.stream()
+                    .filter(c -> assignmentStates.stream().allMatch(t -> t.racksMatch(c, new TopicPartition(t.topic, p)) && t.maxAssignable(c) > 0))
+                    .findFirst();
+            if (matchingConsumer.isPresent()) {
+                String consumer = matchingConsumer.get();
+                assignmentStates.forEach(t -> assign(consumer, Collections.singletonList(new TopicPartition(t.topic, p)), t, assignment));
+
+                if (assignmentStates.stream().noneMatch(t -> t.maxAssignable(consumer) > 0)) {
+                    remainingConsumers.remove(consumer);
+                    if (remainingConsumers.isEmpty())
+                        break;

Review Comment:
   What do we do with the partition if we end up here?
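
   From reading `assignPartitions` in the hunk above, my guess is that anything left unassigned here is picked up by the later `assignRanges(t, (c, tp) -> true, assignment)` call. The sketch below is a simplified, hypothetical model of that two-pass shape (one rack per partition, a fixed quota per consumer, made-up names), just to confirm that this is the intent.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;

// Hypothetical, simplified model of a two-pass assignment; not the PR's code.
final class TwoPassAssignment {

    // Gives each consumer, in order, partitions accepted by mayAssign until
    // the consumer holds `quota` partitions or no candidates remain.
    static void assignPass(List<String> consumers, List<Integer> unassigned, int quota,
                           BiPredicate<String, Integer> mayAssign,
                           Map<String, List<Integer>> assignment) {
        for (String consumer : consumers) {
            List<Integer> owned = assignment.get(consumer);
            Iterator<Integer> it = unassigned.iterator();
            while (owned.size() < quota && it.hasNext()) {
                Integer partition = it.next();
                if (mayAssign.test(consumer, partition)) {
                    owned.add(partition);
                    it.remove();
                }
            }
        }
    }

    // consumerRacks: consumer id -> rack (iteration order is the consumer order).
    // partitionRacks: index p holds the single rack assumed for partition p.
    static Map<String, List<Integer>> assign(Map<String, String> consumerRacks,
                                             List<String> partitionRacks, int quota) {
        List<String> consumers = new ArrayList<>(consumerRacks.keySet());
        List<Integer> unassigned = new ArrayList<>();
        Map<String, List<Integer>> assignment = new LinkedHashMap<>();
        for (int p = 0; p < partitionRacks.size(); p++)
            unassigned.add(p);
        for (String consumer : consumers)
            assignment.put(consumer, new ArrayList<>());

        // Pass 1: only assign where the consumer rack matches the partition rack.
        assignPass(consumers, unassigned, quota,
                (c, p) -> consumerRacks.get(c).equals(partitionRacks.get(p)), assignment);
        // Pass 2: hand out whatever is left, ignoring racks.
        assignPass(consumers, unassigned, quota, (c, p) -> true, assignment);

        assignment.values().forEach(list -> Collections.sort(list));
        return assignment;
    }

    public static void main(String[] args) {
        Map<String, String> consumerRacks = new LinkedHashMap<>();
        consumerRacks.put("c1", "a");
        consumerRacks.put("c2", "b");
        List<String> partitionRacks = java.util.Arrays.asList("a", "a", "a", "b");
        // Partition 2 finds no rack-matching consumer with spare capacity in
        // pass 1, so it goes to c2 in pass 2.
        System.out.println(assign(consumerRacks, partitionRacks, 2)); // {c1=[0, 1], c2=[2, 3]}
    }
}
```

   If that reading is right, a short comment next to the `break` saying so would help.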


