You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@kafka.apache.org by gu...@apache.org on 2015/11/12 01:08:24 UTC
kafka git commit: KAFKA-2763: better stream task assignment
Repository: kafka
Updated Branches:
refs/heads/trunk c6b8de4e6 -> 124f73b17
KAFKA-2763: better stream task assignment
guozhangwang
When a rebalance happens, each consumer reports the following information to the coordinator.
* Client UUID (a unique id assigned to an instance of KafkaStreaming)
* Task ids of previously running tasks
* Task ids of valid local states on the client's state directory
TaskAssignor does the following:
* Assign a task to a client which was running it previously. If there is no such client, assign a task to a client which has its valid local state.
* Try to balance the load among stream threads.
* A client may have more than one stream thread. The assignor tries to assign tasks to each client in proportion to its number of threads.
Author: Yasuhiro Matsuda <ya...@confluent.io>
Reviewers: Guozhang Wang
Closes #497 from ymatsuda/task_assignment
Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/124f73b1
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/124f73b1
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/124f73b1
Branch: refs/heads/trunk
Commit: 124f73b1747a574982e9ca491712e6758ddbacea
Parents: c6b8de4
Author: Yasuhiro Matsuda <ya...@confluent.io>
Authored: Wed Nov 11 16:14:27 2015 -0800
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Wed Nov 11 16:14:27 2015 -0800
----------------------------------------------------------------------
.../apache/kafka/streams/KafkaStreaming.java | 5 +-
.../apache/kafka/streams/StreamingConfig.java | 8 +-
.../streams/processor/PartitionGrouper.java | 4 +
.../apache/kafka/streams/processor/TaskId.java | 23 +-
.../KafkaStreamingPartitionAssignor.java | 187 ++++++++----
.../processor/internals/StreamThread.java | 59 +++-
.../internals/assignment/AssignmentInfo.java | 125 ++++++++
.../internals/assignment/ClientState.java | 72 +++++
.../internals/assignment/SubscriptionInfo.java | 128 ++++++++
.../assignment/TaskAssignmentException.java | 32 ++
.../internals/assignment/TaskAssignor.java | 195 +++++++++++++
.../KafkaStreamingPartitionAssignorTest.java | 283 ++++++++++++++++++
.../processor/internals/StreamThreadTest.java | 33 ++-
.../assignment/AssginmentInfoTest.java | 45 +++
.../assignment/SubscriptionInfoTest.java | 46 +++
.../internals/assignment/TaskAssignorTest.java | 289 +++++++++++++++++++
16 files changed, 1464 insertions(+), 70 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
index d274fb9..fc1fdae 100644
--- a/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
+++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreaming.java
@@ -29,6 +29,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
+import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@@ -85,11 +86,13 @@ public class KafkaStreaming {
private final StreamThread[] threads;
private String clientId;
+ private final UUID uuid;
private final Metrics metrics;
public KafkaStreaming(TopologyBuilder builder, StreamingConfig config) throws Exception {
// create the metrics
this.time = new SystemTime();
+ this.uuid = UUID.randomUUID();
MetricConfig metricConfig = new MetricConfig().samples(config.getInt(StreamingConfig.METRICS_NUM_SAMPLES_CONFIG))
.timeWindow(config.getLong(StreamingConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG),
@@ -104,7 +107,7 @@ public class KafkaStreaming {
this.threads = new StreamThread[config.getInt(StreamingConfig.NUM_STREAM_THREADS_CONFIG)];
for (int i = 0; i < this.threads.length; i++) {
- this.threads[i] = new StreamThread(builder, config, this.clientId, this.metrics, this.time);
+ this.threads[i] = new StreamThread(builder, config, this.clientId, this.uuid, this.metrics, this.time);
}
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
index 88bd844..693cb0c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
+++ b/streams/src/main/java/org/apache/kafka/streams/StreamingConfig.java
@@ -27,8 +27,8 @@ import org.apache.kafka.common.config.ConfigDef.Type;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.Serializer;
import org.apache.kafka.streams.processor.DefaultPartitionGrouper;
-import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.internals.KafkaStreamingPartitionAssignor;
+import org.apache.kafka.streams.processor.internals.StreamThread;
import java.util.Map;
@@ -205,16 +205,16 @@ public class StreamingConfig extends AbstractConfig {
}
public static class InternalConfig {
- public static final String PARTITION_GROUPER_INSTANCE = "__partition.grouper.instance__";
+ public static final String STREAM_THREAD_INSTANCE = "__stream.thread.instance__";
}
public StreamingConfig(Map<?, ?> props) {
super(CONFIG, props);
}
- public Map<String, Object> getConsumerConfigs(PartitionGrouper partitionGrouper) {
+ public Map<String, Object> getConsumerConfigs(StreamThread streamThread) {
Map<String, Object> props = getConsumerConfigs();
- props.put(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper);
+ props.put(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, streamThread);
props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, KafkaStreamingPartitionAssignor.class.getName());
return props;
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
index 026ec89..00b56b3 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/PartitionGrouper.java
@@ -50,4 +50,8 @@ public abstract class PartitionGrouper {
return partitionAssignor.taskIds(partition);
}
+ public Set<TaskId> standbyTasks() {
+ return partitionAssignor.standbyTasks();
+ }
+
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
index 3d474fe..5344f6c 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java
@@ -17,7 +17,9 @@
package org.apache.kafka.streams.processor;
-public class TaskId {
+import java.nio.ByteBuffer;
+
+public class TaskId implements Comparable<TaskId> {
public final int topicGroupId;
public final int partition;
@@ -45,6 +47,15 @@ public class TaskId {
}
}
+ public void writeTo(ByteBuffer buf) {
+ buf.putInt(topicGroupId);
+ buf.putInt(partition);
+ }
+
+ public static TaskId readFrom(ByteBuffer buf) {
+ return new TaskId(buf.getInt(), buf.getInt());
+ }
+
@Override
public boolean equals(Object o) {
if (o instanceof TaskId) {
@@ -61,6 +72,16 @@ public class TaskId {
return (int) (n % 0xFFFFFFFFL);
}
+ @Override
+ public int compareTo(TaskId other) {
+ return
+ this.topicGroupId < other.topicGroupId ? -1 :
+ (this.topicGroupId > other.topicGroupId ? 1 :
+ (this.partition < other.partition ? -1 :
+ (this.partition > other.partition ? 1 :
+ 0)));
+ }
+
public static class TaskIdFormatException extends RuntimeException {
}
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
index f7b14ad..35ba0ec 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignor.java
@@ -23,37 +23,49 @@ import org.apache.kafka.common.Configurable;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.StreamingConfig;
-import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.ClientState;
+import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.streams.processor.internals.assignment.TaskAssignmentException;
+import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.UUID;
public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Configurable {
private static final Logger log = LoggerFactory.getLogger(KafkaStreamingPartitionAssignor.class);
- private PartitionGrouper partitionGrouper;
+ private StreamThread streamThread;
private Map<TopicPartition, Set<TaskId>> partitionToTaskIds;
+ private Set<TaskId> standbyTasks;
@Override
public void configure(Map<String, ?> configs) {
- Object o = configs.get(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE);
- if (o == null)
- throw new KafkaException("PartitionGrouper is not specified");
+ Object o = configs.get(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE);
+ if (o == null) {
+ KafkaException ex = new KafkaException("StreamThread is not specified");
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
- if (!PartitionGrouper.class.isInstance(o))
- throw new KafkaException(o.getClass().getName() + " is not an instance of " + PartitionGrouper.class.getName());
+ if (!(o instanceof StreamThread)) {
+ KafkaException ex = new KafkaException(o.getClass().getName() + " is not an instance of " + StreamThread.class.getName());
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
- partitionGrouper = (PartitionGrouper) o;
- partitionGrouper.partitionAssignor(this);
+ streamThread = (StreamThread) o;
+ streamThread.partitionGrouper.partitionAssignor(this);
}
@Override
@@ -63,38 +75,110 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi
@Override
public Subscription subscription(Set<String> topics) {
- return new Subscription(new ArrayList<>(topics));
+ // Adds the following information to subscription
+ // 1. Client UUID (a unique id assigned to an instance of KafkaStreaming)
+ // 2. Task ids of previously running tasks
+ // 3. Task ids of valid local states on the client's state directory.
+
+ Set<TaskId> prevTasks = streamThread.prevTasks();
+ Set<TaskId> standbyTasks = streamThread.cachedTasks();
+ standbyTasks.removeAll(prevTasks);
+ SubscriptionInfo data = new SubscriptionInfo(streamThread.clientUUID, prevTasks, standbyTasks);
+
+ return new Subscription(new ArrayList<>(topics), data.encode());
}
@Override
public Map<String, Assignment> assign(Cluster metadata, Map<String, Subscription> subscriptions) {
- Map<TaskId, Set<TopicPartition>> partitionGroups = partitionGrouper.partitionGroups(metadata);
+ // This assigns tasks to consumer clients in two steps.
+ // 1. using TaskAssignor tasks are assigned to streaming clients.
+ // - Assign a task to a client which was running it previously.
+ // If there is no such client, assign a task to a client which has its valid local state.
+ // - A client may have more than one stream threads.
+ // The assignor tries to assign tasks to a client proportionally to the number of threads.
+ // - We try not to assign the same set of tasks to two different clients
+ // We do the assignment in one-pass. The result may not satisfy above all.
+ // 2. within each client, tasks are assigned to consumer clients in round-robin manner.
+
+ Map<UUID, Set<String>> consumersByClient = new HashMap<>();
+ Map<UUID, ClientState<TaskId>> states = new HashMap<>();
+
+ // Decode subscription info
+ for (Map.Entry<String, Subscription> entry : subscriptions.entrySet()) {
+ String consumerId = entry.getKey();
+ Subscription subscription = entry.getValue();
+
+ SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData());
+
+ Set<String> consumers = consumersByClient.get(info.clientUUID);
+ if (consumers == null) {
+ consumers = new HashSet<>();
+ consumersByClient.put(info.clientUUID, consumers);
+ }
+ consumers.add(consumerId);
+
+ ClientState<TaskId> state = states.get(info.clientUUID);
+ if (state == null) {
+ state = new ClientState<>();
+ states.put(info.clientUUID, state);
+ }
+
+ state.prevActiveTasks.addAll(info.prevTasks);
+ state.prevAssignedTasks.addAll(info.prevTasks);
+ state.prevAssignedTasks.addAll(info.standbyTasks);
+ state.capacity = state.capacity + 1d;
+ }
- String[] clientIds = subscriptions.keySet().toArray(new String[subscriptions.size()]);
- TaskId[] taskIds = partitionGroups.keySet().toArray(new TaskId[partitionGroups.size()]);
+ // Get partition groups from the partition grouper
+ Map<TaskId, Set<TopicPartition>> partitionGroups = streamThread.partitionGrouper.partitionGroups(metadata);
+ states = TaskAssignor.assign(states, partitionGroups.keySet(), 0); // TODO: enable standby tasks
Map<String, Assignment> assignment = new HashMap<>();
- for (int i = 0; i < clientIds.length; i++) {
- List<TopicPartition> partitions = new ArrayList<>();
- List<TaskId> ids = new ArrayList<>();
- for (int j = i; j < taskIds.length; j += clientIds.length) {
- TaskId taskId = taskIds[j];
- for (TopicPartition partition : partitionGroups.get(taskId)) {
- partitions.add(partition);
- ids.add(taskId);
- }
+ for (Map.Entry<UUID, Set<String>> entry : consumersByClient.entrySet()) {
+ UUID uuid = entry.getKey();
+ Set<String> consumers = entry.getValue();
+ ClientState<TaskId> state = states.get(uuid);
+
+ ArrayList<TaskId> taskIds = new ArrayList<>(state.assignedTasks.size());
+ final int numActiveTasks = state.activeTasks.size();
+ for (TaskId id : state.activeTasks) {
+ taskIds.add(id);
}
- ByteBuffer buf = ByteBuffer.allocate(4 + ids.size() * 8);
- //version
- buf.putInt(1);
- // encode task ids
- for (TaskId id : ids) {
- buf.putInt(id.topicGroupId);
- buf.putInt(id.partition);
+ for (TaskId id : state.assignedTasks) {
+ if (!state.activeTasks.contains(id))
+ taskIds.add(id);
+ }
+
+ final int numConsumers = consumers.size();
+ List<TaskId> active = new ArrayList<>();
+ Set<TaskId> standby = new HashSet<>();
+
+ int i = 0;
+ for (String consumer : consumers) {
+ List<TopicPartition> partitions = new ArrayList<>();
+
+ final int numTaskIds = taskIds.size();
+ for (int j = i; j < numTaskIds; j += numConsumers) {
+ TaskId taskId = taskIds.get(j);
+ if (j < numActiveTasks) {
+ for (TopicPartition partition : partitionGroups.get(taskId)) {
+ partitions.add(partition);
+ active.add(taskId);
+ }
+ } else {
+ // no partition to a standby task
+ standby.add(taskId);
+ }
+ }
+
+ AssignmentInfo data = new AssignmentInfo(active, standby);
+ assignment.put(consumer, new Assignment(partitions, data.encode()));
+ i++;
+
+ active.clear();
+ standby.clear();
}
- buf.rewind();
- assignment.put(clientIds[i], new Assignment(partitions, buf));
}
return assignment;
@@ -103,27 +187,29 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi
@Override
public void onAssignment(Assignment assignment) {
List<TopicPartition> partitions = assignment.partitions();
- ByteBuffer data = assignment.userData();
- data.rewind();
+
+ AssignmentInfo info = AssignmentInfo.decode(assignment.userData());
+ this.standbyTasks = info.standbyTasks;
Map<TopicPartition, Set<TaskId>> partitionToTaskIds = new HashMap<>();
+ Iterator<TaskId> iter = info.activeTasks.iterator();
+ for (TopicPartition partition : partitions) {
+ Set<TaskId> taskIds = partitionToTaskIds.get(partition);
+ if (taskIds == null) {
+ taskIds = new HashSet<>();
+ partitionToTaskIds.put(partition, taskIds);
+ }
- // check version
- int version = data.getInt();
- if (version == 1) {
- for (TopicPartition partition : partitions) {
- Set<TaskId> taskIds = partitionToTaskIds.get(partition);
- if (taskIds == null) {
- taskIds = new HashSet<>();
- partitionToTaskIds.put(partition, taskIds);
- }
- // decode a task id
- taskIds.add(new TaskId(data.getInt(), data.getInt()));
+ if (iter.hasNext()) {
+ taskIds.add(iter.next());
+ } else {
+ TaskAssignmentException ex = new TaskAssignmentException(
+ "failed to find a task id for the partition=" + partition.toString() +
+ ", partitions=" + partitions.size() + ", assignmentInfo=" + info.toString()
+ );
+ log.error(ex.getMessage(), ex);
+ throw ex;
}
- } else {
- KafkaException ex = new KafkaException("unknown assignment data version: " + version);
- log.error(ex.getMessage(), ex);
- throw ex;
}
this.partitionToTaskIds = partitionToTaskIds;
}
@@ -132,4 +218,7 @@ public class KafkaStreamingPartitionAssignor implements PartitionAssignor, Confi
return partitionToTaskIds.get(partition);
}
+ public Set<TaskId> standbyTasks() {
+ return standbyTasks;
+ }
}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index ba81421..06e5951 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -59,6 +59,7 @@ import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
@@ -67,16 +68,18 @@ public class StreamThread extends Thread {
private static final Logger log = LoggerFactory.getLogger(StreamThread.class);
private static final AtomicInteger STREAMING_THREAD_ID_SEQUENCE = new AtomicInteger(1);
- private final AtomicBoolean running;
+ public final PartitionGrouper partitionGrouper;
+ public final UUID clientUUID;
protected final StreamingConfig config;
protected final TopologyBuilder builder;
- protected final PartitionGrouper partitionGrouper;
protected final Producer<byte[], byte[]> producer;
protected final Consumer<byte[], byte[]> consumer;
protected final Consumer<byte[], byte[]> restoreConsumer;
+ private final AtomicBoolean running;
private final Map<TaskId, StreamTask> tasks;
+ private final Set<TaskId> prevTasks;
private final String clientId;
private final Time time;
private final File stateDir;
@@ -108,9 +111,10 @@ public class StreamThread extends Thread {
public StreamThread(TopologyBuilder builder,
StreamingConfig config,
String clientId,
+ UUID clientUUID,
Metrics metrics,
Time time) throws Exception {
- this(builder, config, null , null, null, clientId, metrics, time);
+ this(builder, config, null , null, null, clientId, clientUUID, metrics, time);
}
StreamThread(TopologyBuilder builder,
@@ -119,6 +123,7 @@ public class StreamThread extends Thread {
Consumer<byte[], byte[]> consumer,
Consumer<byte[], byte[]> restoreConsumer,
String clientId,
+ UUID clientUUID,
Metrics metrics,
Time time) throws Exception {
super("StreamThread-" + STREAMING_THREAD_ID_SEQUENCE.getAndIncrement());
@@ -126,6 +131,7 @@ public class StreamThread extends Thread {
this.config = config;
this.builder = builder;
this.clientId = clientId;
+ this.clientUUID = clientUUID;
this.partitionGrouper = config.getConfiguredInstance(StreamingConfig.PARTITION_GROUPER_CLASS_CONFIG, PartitionGrouper.class);
this.partitionGrouper.topicGroups(builder.topicGroups());
@@ -136,6 +142,7 @@ public class StreamThread extends Thread {
// initialize the task list
this.tasks = new HashMap<>();
+ this.prevTasks = new HashSet<>();
// read in task specific config values
this.stateDir = new File(this.config.getString(StreamingConfig.STATE_DIR_CONFIG));
@@ -164,7 +171,7 @@ public class StreamThread extends Thread {
private Consumer<byte[], byte[]> createConsumer() {
log.info("Creating consumer client for stream thread [" + this.getName() + "]");
- return new KafkaConsumer<>(config.getConsumerConfigs(partitionGrouper),
+ return new KafkaConsumer<>(config.getConsumerConfigs(this),
new ByteArrayDeserializer(),
new ByteArrayDeserializer());
}
@@ -415,6 +422,43 @@ public class StreamThread extends Thread {
}
}
+ /**
+ * Returns ids of tasks that were being executed before the rebalance.
+ */
+ public Set<TaskId> prevTasks() {
+ return prevTasks;
+ }
+
+ /**
+ * Returns ids of tasks whose states are kept on the local storage.
+ */
+ public Set<TaskId> cachedTasks() {
+ // A client could contain some inactive tasks whose states are still kept on the local storage in the following scenarios:
+ // 1) the client is actively maintaining standby tasks by maintaining their states from the change log.
+ // 2) the client has just got some tasks migrated out of itself to other clients while these task states
+ // have not been cleaned up yet (this can happen in a rolling bounce upgrade, for example).
+
+ HashSet<TaskId> tasks = new HashSet<>();
+
+ File[] stateDirs = stateDir.listFiles();
+ if (stateDirs != null) {
+ for (File dir : stateDirs) {
+ try {
+ TaskId id = TaskId.parse(dir.getName());
+ // if the checkpoint file exists, the state is valid.
+ if (new File(dir, ProcessorStateManager.CHECKPOINT_FILE_NAME).exists())
+ tasks.add(id);
+
+ } catch (TaskId.TaskIdFormatException e) {
+ // there may be some unknown files that sits in the same directory,
+ // we should ignore these files instead trying to delete them as well
+ }
+ }
+ }
+
+ return tasks;
+ }
+
protected StreamTask createStreamTask(TaskId id, Collection<TopicPartition> partitionsForTask) {
sensors.taskCreationSensor.record();
@@ -465,11 +509,10 @@ public class StreamThread extends Thread {
}
sensors.taskDestructionSensor.record();
}
- tasks.clear();
- }
+ prevTasks.clear();
+ prevTasks.addAll(tasks.keySet());
- public PartitionGrouper partitionGrouper() {
- return partitionGrouper;
+ tasks.clear();
}
private void ensureCopartitioning(Collection<Set<String>> copartitionGroups) {
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
new file mode 100644
index 0000000..d82dd7d
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class AssignmentInfo {
+
+ private static final Logger log = LoggerFactory.getLogger(AssignmentInfo.class);
+
+ public final int version;
+ public final List<TaskId> activeTasks; // each element corresponds to a partition
+ public final Set<TaskId> standbyTasks;
+
+ public AssignmentInfo(List<TaskId> activeTasks, Set<TaskId> standbyTasks) {
+ this(1, activeTasks, standbyTasks);
+ }
+
+ protected AssignmentInfo(int version, List<TaskId> activeTasks, Set<TaskId> standbyTasks) {
+ this.version = version;
+ this.activeTasks = activeTasks;
+ this.standbyTasks = standbyTasks;
+ }
+
+ public ByteBuffer encode() {
+ if (version == 1) {
+ ByteBuffer buf = ByteBuffer.allocate(4 + 4 + activeTasks.size() * 8 + 4 + standbyTasks.size() * 8);
+ // Encode version
+ buf.putInt(1);
+ // Encode active tasks
+ buf.putInt(activeTasks.size());
+ for (TaskId id : activeTasks) {
+ id.writeTo(buf);
+ }
+ // Encode standby tasks
+ buf.putInt(standbyTasks.size());
+ for (TaskId id : standbyTasks) {
+ id.writeTo(buf);
+ }
+ buf.rewind();
+
+ return buf;
+
+ } else {
+ TaskAssignmentException ex = new TaskAssignmentException("unable to encode assignment data: version=" + version);
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
+ }
+
+ public static AssignmentInfo decode(ByteBuffer data) {
+ // ensure we are at the beginning of the ByteBuffer
+ data.rewind();
+
+ // Decode version
+ int version = data.getInt();
+ if (version == 1) {
+ // Decode active tasks
+ int count = data.getInt();
+ List<TaskId> activeTasks = new ArrayList<>(count);
+ for (int i = 0; i < count; i++) {
+ activeTasks.add(TaskId.readFrom(data));
+ }
+ // Decode standby tasks
+ count = data.getInt();
+ Set<TaskId> standbyTasks = new HashSet<>(count);
+ for (int i = 0; i < count; i++) {
+ standbyTasks.add(TaskId.readFrom(data));
+ }
+
+ return new AssignmentInfo(activeTasks, standbyTasks);
+
+ } else {
+ TaskAssignmentException ex = new TaskAssignmentException("unknown assignment data version: " + version);
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return version ^ activeTasks.hashCode() ^ standbyTasks.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof AssignmentInfo) {
+ AssignmentInfo other = (AssignmentInfo) o;
+ return this.version == other.version &&
+ this.activeTasks.equals(other.activeTasks) &&
+ this.standbyTasks.equals(other.standbyTasks);
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "[version=" + version + ", active tasks=" + activeTasks.size() + ", standby tasks=" + standbyTasks.size() + "]";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
new file mode 100644
index 0000000..a0f6179
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashSet;
+import java.util.Set;
+
+public class ClientState<T> {
+
+ public final static double COST_ACTIVE = 0.1;
+ public final static double COST_STANDBY = 0.2;
+ public final static double COST_LOAD = 0.5;
+
+ public final Set<T> activeTasks;
+ public final Set<T> assignedTasks;
+ public final Set<T> prevActiveTasks;
+ public final Set<T> prevAssignedTasks;
+
+ public double capacity;
+ public double cost;
+
+ public ClientState() {
+ this(0d);
+ }
+
+ public ClientState(double capacity) {
+ this(new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), new HashSet<T>(), capacity);
+ }
+
+ private ClientState(Set<T> activeTasks, Set<T> assignedTasks, Set<T> prevActiveTasks, Set<T> prevAssignedTasks, double capacity) {
+ this.activeTasks = activeTasks;
+ this.assignedTasks = assignedTasks;
+ this.prevActiveTasks = prevActiveTasks;
+ this.prevAssignedTasks = prevAssignedTasks;
+ this.capacity = capacity;
+ this.cost = 0d;
+ }
+
+ public ClientState<T> copy() {
+ return new ClientState<>(new HashSet<>(activeTasks), new HashSet<>(assignedTasks),
+ new HashSet<>(prevActiveTasks), new HashSet<>(prevAssignedTasks), capacity);
+ }
+
+ public void assign(T taskId, boolean active) {
+ if (active)
+ activeTasks.add(taskId);
+
+ assignedTasks.add(taskId);
+
+ double cost = COST_LOAD;
+ cost = prevAssignedTasks.remove(taskId) ? COST_STANDBY : cost;
+ cost = prevActiveTasks.remove(taskId) ? COST_ACTIVE : cost;
+
+ this.cost += cost;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
new file mode 100644
index 0000000..54042b9
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.UUID;
+
+public class SubscriptionInfo {
+
+ private static final Logger log = LoggerFactory.getLogger(SubscriptionInfo.class);
+
+ public final int version;
+ public final UUID clientUUID;
+ public final Set<TaskId> prevTasks;
+ public final Set<TaskId> standbyTasks;
+
+ public SubscriptionInfo(UUID clientUUID, Set<TaskId> prevTasks, Set<TaskId> standbyTasks) {
+ this(1, clientUUID, prevTasks, standbyTasks);
+ }
+
+ private SubscriptionInfo(int version, UUID clientUUID, Set<TaskId> prevTasks, Set<TaskId> standbyTasks) {
+ this.version = version;
+ this.clientUUID = clientUUID;
+ this.prevTasks = prevTasks;
+ this.standbyTasks = standbyTasks;
+ }
+
+ public ByteBuffer encode() {
+ if (version == 1) {
+ ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + prevTasks.size() * 8 + 4 + standbyTasks.size() * 8);
+ // version
+ buf.putInt(1);
+ // encode client UUID
+ buf.putLong(clientUUID.getMostSignificantBits());
+ buf.putLong(clientUUID.getLeastSignificantBits());
+ // encode ids of previously running tasks
+ buf.putInt(prevTasks.size());
+ for (TaskId id : prevTasks) {
+ id.writeTo(buf);
+ }
+ // encode ids of cached tasks
+ buf.putInt(standbyTasks.size());
+ for (TaskId id : standbyTasks) {
+ id.writeTo(buf);
+ }
+ buf.rewind();
+
+ return buf;
+
+ } else {
+ TaskAssignmentException ex = new TaskAssignmentException("unable to encode subscription data: version=" + version);
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
+ }
+
+ public static SubscriptionInfo decode(ByteBuffer data) {
+ // ensure we are at the beginning of the ByteBuffer
+ data.rewind();
+
+ // Decode version
+ int version = data.getInt();
+ if (version == 1) {
+ // Decode client UUID
+ UUID clientUUID = new UUID(data.getLong(), data.getLong());
+ // Decode previously active tasks
+ Set<TaskId> prevTasks = new HashSet<>();
+ int numPrevs = data.getInt();
+ for (int i = 0; i < numPrevs; i++) {
+ TaskId id = TaskId.readFrom(data);
+ prevTasks.add(id);
+ }
+ // Decode previously cached tasks
+ Set<TaskId> standbyTasks = new HashSet<>();
+ int numCached = data.getInt();
+ for (int i = 0; i < numCached; i++) {
+ standbyTasks.add(TaskId.readFrom(data));
+ }
+
+ return new SubscriptionInfo(version, clientUUID, prevTasks, standbyTasks);
+
+ } else {
+ TaskAssignmentException ex = new TaskAssignmentException("unable to decode subscription data: version=" + version);
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ return version ^ clientUUID.hashCode() ^ prevTasks.hashCode() ^ standbyTasks.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof SubscriptionInfo) {
+ SubscriptionInfo other = (SubscriptionInfo) o;
+ return this.version == other.version &&
+ this.clientUUID.equals(other.clientUUID) &&
+ this.prevTasks.equals(other.prevTasks) &&
+ this.standbyTasks.equals(other.standbyTasks);
+ } else {
+ return false;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java
new file mode 100644
index 0000000..839a6c2
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignmentException.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.common.KafkaException;
+
+/**
+ * The run time exception class for stream task assignments
+ */
+public class TaskAssignmentException extends KafkaException {
+
+ private final static long serialVersionUID = 1L;
+
+ public TaskAssignmentException(String message) {
+ super(message);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
----------------------------------------------------------------------
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
new file mode 100644
index 0000000..d1e0782
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
@@ -0,0 +1,195 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+public class TaskAssignor<C, T extends Comparable<T>> {
+
+ private static final Logger log = LoggerFactory.getLogger(TaskAssignor.class);
+
+ public static <C, T extends Comparable<T>> Map<C, ClientState<T>> assign(Map<C, ClientState<T>> states, Set<T> tasks, int numStandbyReplicas) {
+ long seed = 0L;
+ for (C client : states.keySet()) {
+ seed += client.hashCode();
+ }
+
+ TaskAssignor<C, T> assignor = new TaskAssignor<>(states, tasks, seed);
+ assignor.assignTasks();
+ if (numStandbyReplicas > 0)
+ assignor.assignStandbyTasks(numStandbyReplicas);
+
+ return assignor.states;
+ }
+
+ private final Random rand;
+ private final Map<C, ClientState<T>> states;
+ private final Set<TaskPair<T>> taskPairs;
+ private final int maxNumTaskPairs;
+ private final ArrayList<T> tasks;
+
+ private TaskAssignor(Map<C, ClientState<T>> states, Set<T> tasks, long randomSeed) {
+ this.rand = new Random(randomSeed);
+ this.states = new HashMap<>();
+ for (Map.Entry<C, ClientState<T>> entry : states.entrySet()) {
+ this.states.put(entry.getKey(), entry.getValue().copy());
+ }
+ this.tasks = new ArrayList<>(tasks);
+
+ int numTasks = tasks.size();
+ this.maxNumTaskPairs = numTasks * (numTasks - 1) / 2;
+ this.taskPairs = new HashSet<>(this.maxNumTaskPairs);
+ }
+
+ public void assignTasks() {
+ assignTasks(true);
+ }
+
+ public void assignStandbyTasks(int numStandbyReplicas) {
+ int numReplicas = Math.min(numStandbyReplicas, states.size() - 1);
+ for (int i = 0; i < numReplicas; i++) {
+ assignTasks(false);
+ }
+ }
+
+ private void assignTasks(boolean active) {
+ Collections.shuffle(this.tasks, rand);
+
+ for (T task : tasks) {
+ ClientState<T> state = findClientFor(task);
+
+ if (state != null) {
+ state.assign(task, active);
+ } else {
+ TaskAssignmentException ex = new TaskAssignmentException("failed to find an assignable client");
+ log.error(ex.getMessage(), ex);
+ throw ex;
+ }
+ }
+ }
+
+ private ClientState<T> findClientFor(T task) {
+ boolean checkTaskPairs = taskPairs.size() < maxNumTaskPairs;
+
+ ClientState<T> state = findClientByAdditionCost(task, checkTaskPairs);
+
+ if (state == null && checkTaskPairs)
+ state = findClientByAdditionCost(task, false);
+
+ if (state != null)
+ addTaskPairs(task, state);
+
+ return state;
+ }
+
+ private ClientState<T> findClientByAdditionCost(T task, boolean checkTaskPairs) {
+ ClientState<T> candidate = null;
+ double candidateAdditionCost = 0d;
+
+ for (ClientState<T> state : states.values()) {
+ if (!state.assignedTasks.contains(task)) {
+ // if checkTaskPairs flag is on, skip this client if this task doesn't introduce a new task combination
+ if (checkTaskPairs && !state.assignedTasks.isEmpty() && !hasNewTaskPair(task, state))
+ continue;
+
+ double additionCost = computeAdditionCost(task, state);
+ if (candidate == null ||
+ (additionCost < candidateAdditionCost ||
+ (additionCost == candidateAdditionCost && state.cost < candidate.cost))) {
+ candidate = state;
+ candidateAdditionCost = additionCost;
+ }
+ }
+ }
+
+ return candidate;
+ }
+
+ private void addTaskPairs(T task, ClientState<T> state) {
+ for (T other : state.assignedTasks) {
+ taskPairs.add(pair(task, other));
+ }
+ }
+
+ private boolean hasNewTaskPair(T task, ClientState<T> state) {
+ for (T other : state.assignedTasks) {
+ if (!taskPairs.contains(pair(task, other)))
+ return true;
+ }
+ return false;
+ }
+
+ private double computeAdditionCost(T task, ClientState<T> state) {
+ double cost = Math.floor((double) state.assignedTasks.size() / state.capacity);
+
+ if (state.prevAssignedTasks.contains(task)) {
+ if (state.prevActiveTasks.contains(task)) {
+ cost += ClientState.COST_ACTIVE;
+ } else {
+ cost += ClientState.COST_STANDBY;
+ }
+ } else {
+ cost += ClientState.COST_LOAD;
+ }
+
+ return cost;
+ }
+
+ private TaskPair<T> pair(T task1, T task2) {
+ if (task1.compareTo(task2) < 0) {
+ return new TaskPair<>(task1, task2);
+ } else {
+ return new TaskPair<>(task2, task1);
+ }
+ }
+
+ private static class TaskPair<T> {
+ public final T task1;
+ public final T task2;
+
+ public TaskPair(T task1, T task2) {
+ this.task1 = task1;
+ this.task2 = task2;
+ }
+
+ @Override
+ public int hashCode() {
+ return task1.hashCode() ^ task2.hashCode();
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof TaskPair) {
+ TaskPair<T> other = (TaskPair<T>) o;
+ return this.task1.equals(other.task1) && this.task2.equals(other.task2);
+ }
+ return false;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java
new file mode 100644
index 0000000..86434fb
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/KafkaStreamingPartitionAssignorTest.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.clients.consumer.internals.PartitionAssignor;
+import org.apache.kafka.clients.producer.MockProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.serialization.ByteArraySerializer;
+import org.apache.kafka.common.utils.SystemTime;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.StreamingConfig;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.TopologyBuilder;
+import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo;
+import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
+import org.apache.kafka.test.MockProcessorSupplier;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.junit.Assert.assertEquals;
+
+public class KafkaStreamingPartitionAssignorTest {
+
+ // topic partitions referenced as fixtures by the tests below
+ private TopicPartition t1p0 = new TopicPartition("topic1", 0);
+ private TopicPartition t1p1 = new TopicPartition("topic1", 1);
+ private TopicPartition t1p2 = new TopicPartition("topic1", 2);
+ private TopicPartition t2p0 = new TopicPartition("topic2", 0);
+ private TopicPartition t2p1 = new TopicPartition("topic2", 1);
+ private TopicPartition t2p2 = new TopicPartition("topic2", 2);
+ private TopicPartition t2p3 = new TopicPartition("topic2", 3);
+
+ // test cluster metadata: three partitions each for topic1 and topic2
+ // (t2p3 is only used by testOnAssignment, which does not consult metadata)
+ private List<PartitionInfo> infos = Arrays.asList(
+ new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]),
+ new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]),
+ new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]),
+ new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]),
+ new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]),
+ new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0])
+ );
+
+ private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos, Collections.<String>emptySet());
+
+ // hand-encodes version-1 subscription userData (random client UUID, no previous
+ // or cached tasks); not referenced by the tests in this class
+ private ByteBuffer subscriptionUserData() {
+ UUID uuid = UUID.randomUUID();
+ ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + 4);
+ // version
+ buf.putInt(1);
+ // encode client clientUUID
+ buf.putLong(uuid.getMostSignificantBits());
+ buf.putLong(uuid.getLeastSignificantBits());
+ // previously running tasks
+ buf.putInt(0);
+ // cached tasks
+ buf.putInt(0);
+ buf.rewind();
+ return buf;
+ }
+
+ // task ids for topic group 0, partitions 0-3
+ private final TaskId task0 = new TaskId(0, 0);
+ private final TaskId task1 = new TaskId(0, 1);
+ private final TaskId task2 = new TaskId(0, 2);
+ private final TaskId task3 = new TaskId(0, 3);
+
+ // minimal streaming config sufficient to construct StreamThread instances
+ private Properties configProps() {
+ return new Properties() {
+ {
+ setProperty(StreamingConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer");
+ setProperty(StreamingConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+ setProperty(StreamingConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer");
+ setProperty(StreamingConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+ setProperty(StreamingConfig.TIMESTAMP_EXTRACTOR_CLASS_CONFIG, "org.apache.kafka.test.MockTimestampExtractor");
+ setProperty(StreamingConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171");
+ setProperty(StreamingConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3");
+ }
+ };
+ }
+
+ // StreamTask subclass that records whether commit() was called; not referenced
+ // by the tests in this class
+ private static class TestStreamTask extends StreamTask {
+ public boolean committed = false;
+
+ public TestStreamTask(TaskId id,
+ Consumer<byte[], byte[]> consumer,
+ Producer<byte[], byte[]> producer,
+ Consumer<byte[], byte[]> restoreConsumer,
+ Collection<TopicPartition> partitions,
+ ProcessorTopology topology,
+ StreamingConfig config) {
+ super(id, consumer, producer, restoreConsumer, partitions, topology, config, null);
+ }
+
+ @Override
+ public void commit() {
+ super.commit();
+ committed = true;
+ }
+ }
+
+ private ByteArraySerializer serializer = new ByteArraySerializer();
+
+ // the subscription should list all source topics and carry SubscriptionInfo with
+ // the client UUID, its previous tasks, and standby tasks (cached minus previous)
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testSubscription() throws Exception {
+ StreamingConfig config = new StreamingConfig(configProps());
+
+ MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer);
+ MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+ MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST);
+
+ TopologyBuilder builder = new TopologyBuilder();
+ builder.addSource("source1", "topic1");
+ builder.addSource("source2", "topic2");
+ builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
+
+ final Set<TaskId> prevTasks = Utils.mkSet(
+ new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1));
+ final Set<TaskId> cachedTasks = Utils.mkSet(
+ new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1),
+ new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2));
+
+ UUID uuid = UUID.randomUUID();
+ // override the task discovery methods so the thread reports fixed prev/cached tasks
+ StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()) {
+ @Override
+ public Set<TaskId> prevTasks() {
+ return prevTasks;
+ }
+ @Override
+ public Set<TaskId> cachedTasks() {
+ return cachedTasks;
+ }
+ };
+
+ KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
+ partitionAssignor.configure(
+ Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread)
+ );
+
+ PartitionAssignor.Subscription subscription = partitionAssignor.subscription(Utils.mkSet("topic1", "topic2"));
+
+ assertEquals(Utils.mkList("topic1", "topic2"), subscription.topics());
+
+ // standby tasks are the cached tasks that were not previously running
+ Set<TaskId> standbyTasks = new HashSet<>(cachedTasks);
+ standbyTasks.removeAll(prevTasks);
+
+ SubscriptionInfo info = new SubscriptionInfo(uuid, prevTasks, standbyTasks);
+ assertEquals(info.encode(), subscription.userData());
+ }
+
+ // three consumers on two clients: each consumer should get one task's partitions,
+ // and the AssignmentInfo userData should mirror the assigned partitions
+ @Test
+ public void testAssign() throws Exception {
+ StreamingConfig config = new StreamingConfig(configProps());
+
+ MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer);
+ MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+ MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST);
+
+ TopologyBuilder builder = new TopologyBuilder();
+ builder.addSource("source1", "topic1");
+ builder.addSource("source2", "topic2");
+ builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
+
+ final Set<TaskId> prevTasks10 = Utils.mkSet(task0);
+ final Set<TaskId> prevTasks11 = Utils.mkSet(task1);
+ final Set<TaskId> prevTasks20 = Utils.mkSet(task2);
+ final Set<TaskId> standbyTasks10 = Utils.mkSet(task1);
+ final Set<TaskId> standbyTasks11 = Utils.mkSet(task2);
+ final Set<TaskId> standbyTasks20 = Utils.mkSet(task0);
+
+ UUID uuid1 = UUID.randomUUID();
+ UUID uuid2 = UUID.randomUUID();
+
+ StreamThread thread10 = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid1, new Metrics(), new SystemTime());
+
+ KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
+ partitionAssignor.configure(
+ Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread10)
+ );
+
+ // consumer10 and consumer11 belong to client uuid1; consumer20 belongs to uuid2
+ Map<String, PartitionAssignor.Subscription> subscriptions = new HashMap<>();
+ subscriptions.put("consumer10",
+ new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid1, prevTasks10, standbyTasks10).encode()));
+ subscriptions.put("consumer11",
+ new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid1, prevTasks11, standbyTasks11).encode()));
+ subscriptions.put("consumer20",
+ new PartitionAssignor.Subscription(Utils.mkList("topic1", "topic2"), new SubscriptionInfo(uuid2, prevTasks20, standbyTasks20).encode()));
+
+ Map<String, PartitionAssignor.Assignment> assignments = partitionAssignor.assign(metadata, subscriptions);
+
+ // check assigned partitions
+
+ // tasks 0 and 1 go to client uuid1 (in either consumer order); task 2 to uuid2
+ assertEquals(Utils.mkSet(Utils.mkSet(t1p0, t2p0), Utils.mkSet(t1p1, t2p1)),
+ Utils.mkSet(new HashSet<>(assignments.get("consumer10").partitions()), new HashSet<>(assignments.get("consumer11").partitions())));
+ assertEquals(Utils.mkSet(t1p2, t2p2), new HashSet<>(assignments.get("consumer20").partitions()));
+
+ // check assignment info
+
+ // each consumer's AssignmentInfo active tasks must match its assigned partitions
+ List<TaskId> activeTasks = new ArrayList<>();
+ for (TopicPartition partition : assignments.get("consumer10").partitions()) {
+ activeTasks.add(new TaskId(0, partition.partition()));
+ }
+ assertEquals(activeTasks, AssignmentInfo.decode(assignments.get("consumer10").userData()).activeTasks);
+
+ activeTasks.clear();
+ for (TopicPartition partition : assignments.get("consumer11").partitions()) {
+ activeTasks.add(new TaskId(0, partition.partition()));
+ }
+ assertEquals(activeTasks, AssignmentInfo.decode(assignments.get("consumer11").userData()).activeTasks);
+
+ activeTasks.clear();
+ for (TopicPartition partition : assignments.get("consumer20").partitions()) {
+ activeTasks.add(new TaskId(0, partition.partition()));
+ }
+ assertEquals(activeTasks, AssignmentInfo.decode(assignments.get("consumer20").userData()).activeTasks);
+ }
+
+ // onAssignment should decode the AssignmentInfo userData and expose active task
+ // ids per partition plus the standby task set
+ @Test
+ public void testOnAssignment() throws Exception {
+ StreamingConfig config = new StreamingConfig(configProps());
+
+ MockProducer<byte[], byte[]> producer = new MockProducer<>(true, serializer, serializer);
+ MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+ MockConsumer<byte[], byte[]> mockRestoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST);
+
+ TopologyBuilder builder = new TopologyBuilder();
+ builder.addSource("source1", "topic1");
+ builder.addSource("source2", "topic2");
+ builder.addProcessor("processor", new MockProcessorSupplier(), "source1", "source2");
+
+ UUID uuid = UUID.randomUUID();
+
+ StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime());
+
+ KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
+ partitionAssignor.configure(
+ Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread)
+ );
+
+ List<TaskId> activeTaskList = Utils.mkList(task0, task3);
+ Set<TaskId> standbyTasks = Utils.mkSet(task1, task2);
+ AssignmentInfo info = new AssignmentInfo(activeTaskList, standbyTasks);
+ PartitionAssignor.Assignment assignment = new PartitionAssignor.Assignment(Utils.mkList(t1p0, t2p3), info.encode());
+ partitionAssignor.onAssignment(assignment);
+
+ assertEquals(Utils.mkSet(task0), partitionAssignor.taskIds(t1p0));
+ assertEquals(Utils.mkSet(task3), partitionAssignor.taskIds(t2p3));
+ assertEquals(standbyTasks, partitionAssignor.standbyTasks());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 909df13..54d0a18 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -38,13 +38,13 @@ import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.SystemTime;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.streams.StreamingConfig;
-import org.apache.kafka.streams.processor.PartitionGrouper;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.TopologyBuilder;
import org.apache.kafka.test.MockProcessorSupplier;
import org.junit.Test;
import java.io.File;
+import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Collection;
@@ -55,9 +55,12 @@ import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
+import java.util.UUID;
public class StreamThreadTest {
+ private UUID uuid = UUID.randomUUID();
+
private TopicPartition t1p1 = new TopicPartition("topic1", 1);
private TopicPartition t1p2 = new TopicPartition("topic1", 2);
private TopicPartition t2p1 = new TopicPartition("topic2", 1);
@@ -79,7 +82,24 @@ public class StreamThreadTest {
private Cluster metadata = new Cluster(Arrays.asList(Node.noNode()), infos, Collections.<String>emptySet());
- PartitionAssignor.Subscription subscription = new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3"));
+ private final PartitionAssignor.Subscription subscription =
+ new PartitionAssignor.Subscription(Arrays.asList("topic1", "topic2", "topic3"), subscriptionUserData());
+
+ private ByteBuffer subscriptionUserData() {
+ UUID uuid = UUID.randomUUID();
+ ByteBuffer buf = ByteBuffer.allocate(4 + 16 + 4 + 4);
+ // version
+ buf.putInt(1);
+ // encode client clientUUID
+ buf.putLong(uuid.getMostSignificantBits());
+ buf.putLong(uuid.getLeastSignificantBits());
+ // previously running tasks
+ buf.putInt(0);
+ // cached tasks
+ buf.putInt(0);
+ buf.rewind();
+ return buf;
+ }
// task0 is unused
private final TaskId task1 = new TaskId(0, 1);
@@ -139,7 +159,7 @@ public class StreamThreadTest {
builder.addSource("source3", "topic3");
builder.addProcessor("processor", new MockProcessorSupplier(), "source2", "source3");
- StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), new SystemTime()) {
+ StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), new SystemTime()) {
@Override
protected StreamTask createStreamTask(TaskId id, Collection<TopicPartition> partitionsForTask) {
ProcessorTopology topology = builder.build(id.topicGroupId);
@@ -259,7 +279,7 @@ public class StreamThreadTest {
TopologyBuilder builder = new TopologyBuilder();
builder.addSource("source1", "topic1");
- StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), mockTime) {
+ StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), mockTime) {
@Override
public void maybeClean() {
super.maybeClean();
@@ -381,7 +401,7 @@ public class StreamThreadTest {
TopologyBuilder builder = new TopologyBuilder();
builder.addSource("source1", "topic1");
- StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", new Metrics(), mockTime) {
+ StreamThread thread = new StreamThread(builder, config, producer, consumer, mockRestoreConsumer, "test", uuid, new Metrics(), mockTime) {
@Override
public void maybeCommit() {
super.maybeCommit();
@@ -448,12 +468,11 @@ public class StreamThreadTest {
}
private void initPartitionGrouper(StreamThread thread) {
- PartitionGrouper partitionGrouper = thread.partitionGrouper();
KafkaStreamingPartitionAssignor partitionAssignor = new KafkaStreamingPartitionAssignor();
partitionAssignor.configure(
- Collections.singletonMap(StreamingConfig.InternalConfig.PARTITION_GROUPER_INSTANCE, partitionGrouper)
+ Collections.singletonMap(StreamingConfig.InternalConfig.STREAM_THREAD_INSTANCE, thread)
);
Map<String, PartitionAssignor.Assignment> assignments =
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java
new file mode 100644
index 0000000..58e0af9
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssginmentInfoTest.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+
+public class AssginmentInfoTest {
+
+ @Test
+ public void testEncodeDecode() {
+ List<TaskId> activeTasks =
+ Arrays.asList(new TaskId(0, 0), new TaskId(0, 0), new TaskId(0, 1), new TaskId(1, 0));
+ Set<TaskId> standbyTasks =
+ new HashSet<>(Arrays.asList(new TaskId(1, 1), new TaskId(2, 0)));
+
+ AssignmentInfo info = new AssignmentInfo(activeTasks, standbyTasks);
+ AssignmentInfo decoded = AssignmentInfo.decode(info.encode());
+
+ assertEquals(info, decoded);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
new file mode 100644
index 0000000..acc9a9d
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import org.apache.kafka.streams.processor.TaskId;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.junit.Assert.assertEquals;
+
+// Round-trip test for SubscriptionInfo serialization: the per-consumer
+// subscription metadata (client UUID, previously-active task ids, standby
+// task ids) must compare equal after encode() then decode()
+// (relies on SubscriptionInfo implementing equals()).
+public class SubscriptionInfoTest {
+
+ @Test
+ public void testEncodeDecode() {
+ // Random UUID each run; the round trip must hold for any client id.
+ UUID clientUUID = UUID.randomUUID();
+ Set<TaskId> activeTasks =
+ new HashSet<>(Arrays.asList(new TaskId(0, 0), new TaskId(0, 1), new TaskId(1, 0)));
+ Set<TaskId> standbyTasks =
+ new HashSet<>(Arrays.asList(new TaskId(1, 1), new TaskId(2, 0)));
+
+ SubscriptionInfo info = new SubscriptionInfo(clientUUID, activeTasks, standbyTasks);
+ SubscriptionInfo decoded = SubscriptionInfo.decode(info.encode());
+
+ // Equality of the decoded copy implies every field survived the round trip.
+ assertEquals(info, decoded);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/kafka/blob/124f73b1/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
----------------------------------------------------------------------
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
new file mode 100644
index 0000000..28364ab
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorTest.java
@@ -0,0 +1,289 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import static org.apache.kafka.common.utils.Utils.mkList;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+// Tests for TaskAssignor.assign(states, tasks, numStandbyReplicas), covering:
+// - load balance without standby replicas (per-client active-task counts stay
+//   within one of each other, totals add up to the task count),
+// - load balance with 1..3 standby replicas (assignedTasks = active + standby;
+//   totals add up to tasks * (replicas + 1)),
+// - stickiness (tasks tend to return to the client that previously ran them).
+// All clients are created with equal capacity (1d), so the balance bounds
+// below assume uniform weighting.
+public class TaskAssignorTest {
+
+ @Test
+ public void testAssignWithoutStandby() {
+ // Six clients, each with capacity 1.
+ HashMap<Integer, ClientState<Integer>> states = new HashMap<>();
+ for (int i = 0; i < 6; i++) {
+ states.put(i, new ClientState<Integer>(1d));
+ }
+ Set<Integer> tasks;
+ Map<Integer, ClientState<Integer>> assignments;
+ int numActiveTasks;
+ int numAssignedTasks;
+
+ // # of clients and # of tasks are equal.
+ // Expect exactly one task per client; with 0 standby replicas the
+ // assigned set should match the active set in size.
+ tasks = mkSet(0, 1, 2, 3, 4, 5);
+ assignments = TaskAssignor.assign(states, tasks, 0);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertEquals(1, assignment.activeTasks.size());
+ assertEquals(1, assignment.assignedTasks.size());
+ }
+ // Every task is assigned exactly once across all clients.
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size(), numAssignedTasks);
+
+ // # of clients < # of tasks
+ // 8 tasks over 6 clients: each client gets 1 or 2 tasks.
+ tasks = mkSet(0, 1, 2, 3, 4, 5, 6, 7);
+ assignments = TaskAssignor.assign(states, tasks, 0);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(1 <= assignment.activeTasks.size());
+ assertTrue(2 >= assignment.activeTasks.size());
+ assertTrue(1 <= assignment.assignedTasks.size());
+ assertTrue(2 >= assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size(), numAssignedTasks);
+
+ // # of clients > # of tasks
+ // 4 tasks over 6 clients: each client gets 0 or 1 task.
+ tasks = mkSet(0, 1, 2, 3);
+ assignments = TaskAssignor.assign(states, tasks, 0);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(0 <= assignment.activeTasks.size());
+ assertTrue(1 >= assignment.activeTasks.size());
+ assertTrue(0 <= assignment.assignedTasks.size());
+ assertTrue(1 >= assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size(), numAssignedTasks);
+ }
+
+ @Test
+ public void testAssignWithStandby() {
+ // Six clients, each with capacity 1.
+ HashMap<Integer, ClientState<Integer>> states = new HashMap<>();
+ for (int i = 0; i < 6; i++) {
+ states.put(i, new ClientState<Integer>(1d));
+ }
+ Set<Integer> tasks;
+ Map<Integer, ClientState<Integer>> assignments;
+ int numActiveTasks;
+ int numAssignedTasks;
+
+ // # of clients and # of tasks are equal.
+ tasks = mkSet(0, 1, 2, 3, 4, 5);
+
+ // 1 standby replicas.
+ // Each client owns 1 active task plus 1 standby, so assignedTasks == 2;
+ // total assignments are tasks * (1 active + 1 standby).
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ assignments = TaskAssignor.assign(states, tasks, 1);
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertEquals(1, assignment.activeTasks.size());
+ assertEquals(2, assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size() * 2, numAssignedTasks);
+
+ // # of clients < # of tasks
+ tasks = mkSet(0, 1, 2, 3, 4, 5, 6, 7);
+
+ // 1 standby replicas.
+ // 16 task copies (8 active + 8 standby) over 6 clients: 2 or 3 each.
+ assignments = TaskAssignor.assign(states, tasks, 1);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(1 <= assignment.activeTasks.size());
+ assertTrue(2 >= assignment.activeTasks.size());
+ assertTrue(2 <= assignment.assignedTasks.size());
+ assertTrue(3 >= assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size() * 2, numAssignedTasks);
+
+ // # of clients > # of tasks
+ tasks = mkSet(0, 1, 2, 3);
+
+ // 1 standby replicas.
+ // 8 task copies over 6 clients: 1 or 2 each.
+ assignments = TaskAssignor.assign(states, tasks, 1);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(0 <= assignment.activeTasks.size());
+ assertTrue(1 >= assignment.activeTasks.size());
+ assertTrue(1 <= assignment.assignedTasks.size());
+ assertTrue(2 >= assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size() * 2, numAssignedTasks);
+
+ // # of clients >> # of tasks
+ tasks = mkSet(0, 1);
+
+ // 1 standby replicas.
+ // 4 task copies over 6 clients: some clients get nothing.
+ assignments = TaskAssignor.assign(states, tasks, 1);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(0 <= assignment.activeTasks.size());
+ assertTrue(1 >= assignment.activeTasks.size());
+ assertTrue(0 <= assignment.assignedTasks.size());
+ assertTrue(1 >= assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size() * 2, numAssignedTasks);
+
+ // 2 standby replicas.
+ // 6 task copies over 6 clients: exactly one copy per client.
+ assignments = TaskAssignor.assign(states, tasks, 2);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(0 <= assignment.activeTasks.size());
+ assertTrue(1 >= assignment.activeTasks.size());
+ // NOTE(review): assertEquals(1, assignment.assignedTasks.size()) would
+ // report failures more clearly than assertTrue on an == expression.
+ assertTrue(1 == assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size() * 3, numAssignedTasks);
+
+ // 3 standby replicas.
+ // 8 task copies over 6 clients: 1 or 2 each.
+ assignments = TaskAssignor.assign(states, tasks, 3);
+ numActiveTasks = 0;
+ numAssignedTasks = 0;
+ for (ClientState<Integer> assignment : assignments.values()) {
+ numActiveTasks += assignment.activeTasks.size();
+ numAssignedTasks += assignment.assignedTasks.size();
+ assertTrue(0 <= assignment.activeTasks.size());
+ assertTrue(1 >= assignment.activeTasks.size());
+ assertTrue(1 <= assignment.assignedTasks.size());
+ assertTrue(2 >= assignment.assignedTasks.size());
+ }
+ assertEquals(tasks.size(), numActiveTasks);
+ assertEquals(tasks.size() * 4, numAssignedTasks);
+ }
+
+ @Test
+ public void testStickiness() {
+ List<Integer> tasks;
+ Map<Integer, ClientState<Integer>> states;
+ Map<Integer, ClientState<Integer>> assignments;
+ int i;
+
+ // # of clients and # of tasks are equal.
+ // Each client previously ran exactly one task; with a balanced 1:1
+ // mapping available, the assignor should reproduce the previous
+ // assignment exactly.
+ tasks = mkList(0, 1, 2, 3, 4, 5);
+ // Shuffle so stickiness is checked independently of insertion order.
+ // NOTE(review): unseeded shuffle makes runs nondeterministic; a fixed
+ // seed would make failures reproducible.
+ Collections.shuffle(tasks);
+ states = new HashMap<>();
+ i = 0;
+ for (int task : tasks) {
+ ClientState<Integer> state = new ClientState<>(1d);
+ state.prevActiveTasks.add(task);
+ state.prevAssignedTasks.add(task);
+ states.put(i++, state);
+ }
+ assignments = TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5), 0);
+ for (int client : states.keySet()) {
+ Set<Integer> oldActive = states.get(client).prevActiveTasks;
+ Set<Integer> oldAssigned = states.get(client).prevAssignedTasks;
+ Set<Integer> newActive = assignments.get(client).activeTasks;
+ Set<Integer> newAssigned = assignments.get(client).assignedTasks;
+
+ assertEquals(oldActive, newActive);
+ assertEquals(oldAssigned, newAssigned);
+ }
+
+ // # of clients > # of tasks
+ // Two clients (marked -1) have no previous tasks; clients that did run
+ // a task should keep it, and the idle clients stay idle.
+ tasks = mkList(0, 1, 2, 3, -1, -1);
+ Collections.shuffle(tasks);
+ states = new HashMap<>();
+ i = 0;
+ for (int task : tasks) {
+ ClientState<Integer> state = new ClientState<>(1d);
+ if (task >= 0) {
+ state.prevActiveTasks.add(task);
+ state.prevAssignedTasks.add(task);
+ }
+ states.put(i++, state);
+ }
+ assignments = TaskAssignor.assign(states, mkSet(0, 1, 2, 3), 0);
+ for (int client : states.keySet()) {
+ Set<Integer> oldActive = states.get(client).prevActiveTasks;
+ Set<Integer> oldAssigned = states.get(client).prevAssignedTasks;
+ Set<Integer> newActive = assignments.get(client).activeTasks;
+ Set<Integer> newAssigned = assignments.get(client).assignedTasks;
+
+ assertEquals(oldActive, newActive);
+ assertEquals(oldAssigned, newAssigned);
+ }
+
+ // # of clients < # of tasks
+ // Each client previously ran a pair of tasks. Full stickiness is not
+ // guaranteed here, so only a non-empty overlap between old and new
+ // assignments is required per client.
+ List<Set<Integer>> taskSets = mkList(mkSet(0, 1), mkSet(2, 3), mkSet(4, 5), mkSet(6, 7), mkSet(8, 9), mkSet(10, 11));
+ Collections.shuffle(taskSets);
+ states = new HashMap<>();
+ i = 0;
+ for (Set<Integer> taskSet : taskSets) {
+ ClientState<Integer> state = new ClientState<>(1d);
+ state.prevActiveTasks.addAll(taskSet);
+ state.prevAssignedTasks.addAll(taskSet);
+ states.put(i++, state);
+ }
+ assignments = TaskAssignor.assign(states, mkSet(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), 0);
+ for (int client : states.keySet()) {
+ Set<Integer> oldActive = states.get(client).prevActiveTasks;
+ Set<Integer> oldAssigned = states.get(client).prevAssignedTasks;
+ Set<Integer> newActive = assignments.get(client).activeTasks;
+ Set<Integer> newAssigned = assignments.get(client).assignedTasks;
+
+ // At least one previously-active task must remain active here.
+ Set<Integer> intersection = new HashSet<>();
+
+ intersection.addAll(oldActive);
+ intersection.retainAll(newActive);
+ assertTrue(intersection.size() > 0);
+
+ // Likewise for the overall assigned set.
+ intersection.clear();
+ intersection.addAll(oldAssigned);
+ intersection.retainAll(newAssigned);
+ assertTrue(intersection.size() > 0);
+ }
+ }
+
+}