You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@kafka.apache.org by gu...@apache.org on 2015/11/05 19:07:41 UTC

kafka git commit: KAFKA-2748: Ensure sink tasks commit offsets upon rebalance and rewind to the last committed offsets if the SinkTask flush fails.

Repository: kafka
Updated Branches:
  refs/heads/trunk d23785ff2 -> a4551773c


KAFKA-2748: Ensure sink tasks commit offsets upon rebalance and rewind to the last committed offsets if the SinkTask flush fails.

Also fix the incorrect consumer group ID setting, which was giving each task its
own consumer group instead of a single group shared by the entire sink connector.

Author: Ewen Cheslack-Postava <me...@ewencp.org>

Reviewers: Guozhang Wang

Closes #431 from ewencp/kafka-2748-sink-task-rebalance-commit


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/a4551773
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/a4551773
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/a4551773

Branch: refs/heads/trunk
Commit: a4551773ca8cb3c1b75cea873715f30053d88f47
Parents: d23785f
Author: Ewen Cheslack-Postava <me...@ewencp.org>
Authored: Thu Nov 5 10:13:23 2015 -0800
Committer: Guozhang Wang <wa...@gmail.com>
Committed: Thu Nov 5 10:13:23 2015 -0800

----------------------------------------------------------------------
 .../kafka/copycat/runtime/WorkerSinkTask.java   |  74 +++++++----
 .../copycat/runtime/WorkerSinkTaskThread.java   |   4 +-
 .../copycat/runtime/WorkerSinkTaskTest.java     | 131 +++++++++++++------
 3 files changed, 140 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/a4551773/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTask.java
----------------------------------------------------------------------
diff --git a/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTask.java b/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTask.java
index 439a1f5..55a67c0 100644
--- a/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTask.java
+++ b/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTask.java
@@ -62,6 +62,7 @@ class WorkerSinkTask implements WorkerTask {
     private WorkerSinkTaskThread workThread;
     private KafkaConsumer<byte[], byte[]> consumer;
     private WorkerSinkTaskContext context;
+    private Map<TopicPartition, OffsetAndMetadata> lastCommittedOffsets;
 
     public WorkerSinkTask(ConnectorTaskId id, SinkTask task, WorkerConfig workerConfig,
                           Converter keyConverter, Converter valueConverter, Time time) {
@@ -75,11 +76,17 @@ class WorkerSinkTask implements WorkerTask {
 
     @Override
     public void start(Properties props) {
-        consumer = createConsumer(props);
+        consumer = createConsumer();
         context = new WorkerSinkTaskContext(consumer);
 
         // Ensure we're in the group so that if start() wants to rewind offsets, it will have an assignment of partitions
         // to work with. Any rewinding will be handled immediately when polling starts.
+        String topicsStr = props.getProperty(SinkTask.TOPICS_CONFIG);
+        if (topicsStr == null || topicsStr.isEmpty())
+            throw new CopycatException("Sink tasks require a list of topics.");
+        String[] topics = topicsStr.split(",");
+        log.debug("Task {} subscribing to topics {}", id, topics);
+        consumer.subscribe(Arrays.asList(topics), new HandleRebalance());
         consumer.poll(0);
 
         task.initialize(context);
@@ -92,7 +99,6 @@ class WorkerSinkTask implements WorkerTask {
     @Override
     public void stop() {
         // Offset commit is handled upon exit in work thread
-        task.stop();
         if (workThread != null)
             workThread.startGracefulShutdown();
         consumer.wakeup();
@@ -100,17 +106,18 @@ class WorkerSinkTask implements WorkerTask {
 
     @Override
     public boolean awaitStop(long timeoutMs) {
+        boolean success = true;
         if (workThread != null) {
             try {
-                boolean success = workThread.awaitShutdown(timeoutMs, TimeUnit.MILLISECONDS);
+                success = workThread.awaitShutdown(timeoutMs, TimeUnit.MILLISECONDS);
                 if (!success)
                     workThread.forceShutdown();
-                return success;
             } catch (InterruptedException e) {
-                return false;
+                success = false;
             }
         }
-        return true;
+        task.stop();
+        return success;
     }
 
     @Override
@@ -143,27 +150,32 @@ class WorkerSinkTask implements WorkerTask {
      * Starts an offset commit by flushing outstanding messages from the task and then starting
      * the write commit. This should only be invoked by the WorkerSinkTaskThread.
      **/
-    public void commitOffsets(long now, boolean sync, final int seqno, boolean flush) {
+    public void commitOffsets(boolean sync, final int seqno) {
         log.info("{} Committing offsets", this);
-        HashMap<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        final HashMap<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
         for (TopicPartition tp : consumer.assignment()) {
-            offsets.put(tp, new OffsetAndMetadata(consumer.position(tp)));
+            long pos = consumer.position(tp);
+            offsets.put(tp, new OffsetAndMetadata(pos));
+            log.trace("{} committing {} offset {}", id, tp, pos);
         }
-        // We only don't flush the task in one case: when shutting down, the task has already been
-        // stopped and all data should have already been flushed
-        if (flush) {
-            try {
-                task.flush(offsets);
-            } catch (Throwable t) {
-                log.error("Commit of {} offsets failed due to exception while flushing: {}", this, t);
-                workThread.onCommitCompleted(t, seqno);
-                return;
+
+        try {
+            task.flush(offsets);
+        } catch (Throwable t) {
+            log.error("Commit of {} offsets failed due to exception while flushing: {}", this, t);
+            log.error("Rewinding offsets to last committed offsets");
+            for (Map.Entry<TopicPartition, OffsetAndMetadata> entry : lastCommittedOffsets.entrySet()) {
+                log.debug("{} Rewinding topic partition {} to offset {}", id, entry.getKey(), entry.getValue().offset());
+                consumer.seek(entry.getKey(), entry.getValue().offset());
             }
+            workThread.onCommitCompleted(t, seqno);
+            return;
         }
 
         if (sync) {
             try {
                 consumer.commitSync(offsets);
+                lastCommittedOffsets = offsets;
             } catch (KafkaException e) {
                 workThread.onCommitCompleted(e, seqno);
             }
@@ -171,6 +183,7 @@ class WorkerSinkTask implements WorkerTask {
             OffsetCommitCallback cb = new OffsetCommitCallback() {
                 @Override
                 public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception error) {
+                    lastCommittedOffsets = offsets;
                     workThread.onCommitCompleted(error, seqno);
                 }
             };
@@ -186,16 +199,11 @@ class WorkerSinkTask implements WorkerTask {
         return workerConfig;
     }
 
-    private KafkaConsumer<byte[], byte[]> createConsumer(Properties taskProps) {
-        String topicsStr = taskProps.getProperty(SinkTask.TOPICS_CONFIG);
-        if (topicsStr == null || topicsStr.isEmpty())
-            throw new CopycatException("Sink tasks require a list of topics.");
-        String[] topics = topicsStr.split(",");
-
+    private KafkaConsumer<byte[], byte[]> createConsumer() {
         // Include any unknown worker configs so consumer configs can be set globally on the worker
         // and through to the task
         Properties props = workerConfig.unusedProperties();
-        props.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "copycat-" + id.toString());
+        props.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "copycat-" + id.connector());
         props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG,
                 Utils.join(workerConfig.getList(WorkerConfig.BOOTSTRAP_SERVERS_CONFIG), ","));
         props.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false");
@@ -210,9 +218,6 @@ class WorkerSinkTask implements WorkerTask {
             throw new CopycatException("Failed to create consumer", t);
         }
 
-        log.debug("Task {} subscribing to topics {}", id, topics);
-        newConsumer.subscribe(Arrays.asList(topics), new HandleRebalance());
-
         return newConsumer;
     }
 
@@ -264,12 +269,23 @@ class WorkerSinkTask implements WorkerTask {
     private class HandleRebalance implements ConsumerRebalanceListener {
         @Override
         public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
-            task.onPartitionsAssigned(partitions);
+            lastCommittedOffsets = new HashMap<>();
+            for (TopicPartition tp : partitions) {
+                long pos = consumer.position(tp);
+                lastCommittedOffsets.put(tp, new OffsetAndMetadata(pos));
+                log.trace("{} assigned topic partition {} with offset {}", id, tp, pos);
+            }
+            // Instead of invoking the assignment callback on initialization, we guarantee the consumer is ready upon
+            // task start. Since this callback gets invoked during that initial setup before we've started the task, we
+            // need to guard against invoking the user's callback method during that period.
+            if (workThread != null)
+                task.onPartitionsAssigned(partitions);
         }
 
         @Override
         public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
             task.onPartitionsRevoked(partitions);
+            commitOffsets(true, -1);
         }
     }
 }

http://git-wip-us.apache.org/repos/asf/kafka/blob/a4551773/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskThread.java
----------------------------------------------------------------------
diff --git a/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskThread.java b/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskThread.java
index 0e28c97..486407d 100644
--- a/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskThread.java
+++ b/copycat/runtime/src/main/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskThread.java
@@ -54,7 +54,7 @@ class WorkerSinkTaskThread extends ShutdownableThread {
             iteration();
         }
         // Make sure any uncommitted data has committed
-        task.commitOffsets(task.time().milliseconds(), true, -1, false);
+        task.commitOffsets(true, -1);
     }
 
     public void iteration() {
@@ -67,7 +67,7 @@ class WorkerSinkTaskThread extends ShutdownableThread {
                 commitSeqno += 1;
                 commitStarted = now;
             }
-            task.commitOffsets(now, false, commitSeqno, true);
+            task.commitOffsets(false, commitSeqno);
             nextCommit += task.workerConfig().getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG);
         }
 

http://git-wip-us.apache.org/repos/asf/kafka/blob/a4551773/copycat/runtime/src/test/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskTest.java
----------------------------------------------------------------------
diff --git a/copycat/runtime/src/test/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskTest.java b/copycat/runtime/src/test/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskTest.java
index acc1179..28e9e2e 100644
--- a/copycat/runtime/src/test/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskTest.java
+++ b/copycat/runtime/src/test/java/org/apache/kafka/copycat/runtime/WorkerSinkTaskTest.java
@@ -17,6 +17,7 @@
 
 package org.apache.kafka.copycat.runtime;
 
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
@@ -28,6 +29,7 @@ import org.apache.kafka.copycat.data.Schema;
 import org.apache.kafka.copycat.data.SchemaAndValue;
 import org.apache.kafka.copycat.errors.CopycatException;
 import org.apache.kafka.copycat.runtime.standalone.StandaloneConfig;
+import org.apache.kafka.copycat.sink.SinkConnector;
 import org.apache.kafka.copycat.sink.SinkRecord;
 import org.apache.kafka.copycat.sink.SinkTask;
 import org.apache.kafka.copycat.storage.Converter;
@@ -55,6 +57,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Properties;
+import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -83,6 +86,11 @@ public class WorkerSinkTaskTest extends ThreadedTest {
     private static final TopicPartition TOPIC_PARTITION3 = new TopicPartition(TOPIC, PARTITION3);
     private static final TopicPartition UNASSIGNED_TOPIC_PARTITION = new TopicPartition(TOPIC, 200);
 
+    private static final Properties TASK_PROPS = new Properties();
+    static {
+        TASK_PROPS.put(SinkConnector.TOPICS_CONFIG, TOPIC);
+    }
+
     private ConnectorTaskId taskId = new ConnectorTaskId("job", 0);
     private Time time;
     @Mock private SinkTask sinkTask;
@@ -94,6 +102,7 @@ public class WorkerSinkTaskTest extends ThreadedTest {
     private WorkerSinkTask workerTask;
     @Mock private KafkaConsumer<byte[], byte[]> consumer;
     private WorkerSinkTaskThread workerThread;
+    private Capture<ConsumerRebalanceListener> rebalanceListener = EasyMock.newCapture();
 
     private long recordsReturned;
 
@@ -119,20 +128,19 @@ public class WorkerSinkTaskTest extends ThreadedTest {
 
     @Test
     public void testPollsInBackground() throws Exception {
-        Properties taskProps = new Properties();
-
-        expectInitializeTask(taskProps);
+        expectInitializeTask();
         Capture<Collection<SinkRecord>> capturedRecords = expectPolls(1L);
         expectStopTask(10L);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         for (int i = 0; i < 10; i++) {
             workerThread.iteration();
         }
         workerTask.stop();
-        // No need for awaitStop since the thread is mocked
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         // Verify contents match expected values, i.e. that they were translated properly. With max
@@ -183,18 +191,17 @@ public class WorkerSinkTaskTest extends ThreadedTest {
 
     @Test
     public void testCommit() throws Exception {
-        Properties taskProps = new Properties();
-
-        expectInitializeTask(taskProps);
+        expectInitializeTask();
         // Make each poll() take the offset commit interval
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, null, null, 0, true);
         expectStopTask(2);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         // First iteration gets one record
         workerThread.iteration();
         // Second triggers commit, gets a second offset
@@ -202,6 +209,7 @@ public class WorkerSinkTaskTest extends ThreadedTest {
         // Commit finishes synchronously for testing so we can check this immediately
         assertEquals(0, workerThread.commitFailures());
         workerTask.stop();
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         assertEquals(2, capturedRecords.getValues().size());
@@ -211,41 +219,79 @@ public class WorkerSinkTaskTest extends ThreadedTest {
 
     @Test
     public void testCommitTaskFlushFailure() throws Exception {
-        Properties taskProps = new Properties();
-
-        expectInitializeTask(taskProps);
-        Capture<Collection<SinkRecord>> capturedRecords
-                = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
+        expectInitializeTask();
+        Capture<Collection<SinkRecord>> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, new RuntimeException(), null, 0, true);
+        // Should rewind to last known good positions, which in this case will be the offsets loaded during initialization
+        // for all topic partitions
+        consumer.seek(TOPIC_PARTITION, FIRST_OFFSET);
+        PowerMock.expectLastCall();
+        consumer.seek(TOPIC_PARTITION2, FIRST_OFFSET);
+        PowerMock.expectLastCall();
+        consumer.seek(TOPIC_PARTITION3, FIRST_OFFSET);
+        PowerMock.expectLastCall();
         expectStopTask(2);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         // Second iteration triggers commit
         workerThread.iteration();
         workerThread.iteration();
         assertEquals(1, workerThread.commitFailures());
         assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
         workerTask.stop();
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         PowerMock.verifyAll();
     }
 
     @Test
-    public void testCommitConsumerFailure() throws Exception {
-        Properties taskProps = new Properties();
+    public void testCommitTaskSuccessAndFlushFailure() throws Exception {
+        // Validate that we rewind to the correct
+
+        expectInitializeTask();
+        Capture<Collection<SinkRecord>> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
+        expectOffsetFlush(1L, null, null, 0, true);
+        expectOffsetFlush(2L, new RuntimeException(), null, 0, true);
+        // Should rewind to last known good positions, which in this case will be the offsets last committed. This test
+        // isn't quite accurate since we started with assigning 3 topic partitions and then only committed one, but what
+        // is important here is that we roll back to the last committed values.
+        consumer.seek(TOPIC_PARTITION, FIRST_OFFSET);
+        PowerMock.expectLastCall();
+        expectStopTask(2);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
+
+        PowerMock.replayAll();
+
+        workerTask.start(TASK_PROPS);
+        // Second iteration triggers first commit, third iteration triggers second (failing) commit
+        workerThread.iteration();
+        workerThread.iteration();
+        workerThread.iteration();
+        assertEquals(1, workerThread.commitFailures());
+        assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
+        workerTask.stop();
+        workerTask.awaitStop(Long.MAX_VALUE);
+        workerTask.close();
+
+        PowerMock.verifyAll();
+    }
 
-        expectInitializeTask(taskProps);
+    @Test
+    public void testCommitConsumerFailure() throws Exception {
+        expectInitializeTask();
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, null, new Exception(), 0, true);
         expectStopTask(2);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         // Second iteration triggers commit
         workerThread.iteration();
         workerThread.iteration();
@@ -253,6 +299,7 @@ public class WorkerSinkTaskTest extends ThreadedTest {
         assertEquals(1, workerThread.commitFailures());
         assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
         workerTask.stop();
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         PowerMock.verifyAll();
@@ -260,18 +307,17 @@ public class WorkerSinkTaskTest extends ThreadedTest {
 
     @Test
     public void testCommitTimeout() throws Exception {
-        Properties taskProps = new Properties();
-
-        expectInitializeTask(taskProps);
+        expectInitializeTask();
         // Cut down amount of time to pass in each poll so we trigger exactly 1 offset commit
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT / 2);
         expectOffsetFlush(2L, null, null, WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT, false);
         expectStopTask(4);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         // Third iteration triggers commit, fourth gives a chance to trigger the timeout but doesn't
         // trigger another commit
         workerThread.iteration();
@@ -282,6 +328,7 @@ public class WorkerSinkTaskTest extends ThreadedTest {
         assertEquals(1, workerThread.commitFailures());
         assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
         workerTask.stop();
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         PowerMock.verifyAll();
@@ -291,10 +338,7 @@ public class WorkerSinkTaskTest extends ThreadedTest {
     public void testAssignmentPauseResume() throws Exception {
         // Just validate that the calls are passed through to the consumer, and that where appropriate errors are
         // converted
-
-        Properties taskProps = new Properties();
-
-        expectInitializeTask(taskProps);
+        expectInitializeTask();
 
         expectOnePoll().andAnswer(new IAnswer<Object>() {
             @Override
@@ -344,14 +388,16 @@ public class WorkerSinkTaskTest extends ThreadedTest {
         PowerMock.expectLastCall();
 
         expectStopTask(0);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         workerThread.iteration();
         workerThread.iteration();
         workerThread.iteration();
         workerTask.stop();
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         PowerMock.verifyAll();
@@ -359,8 +405,7 @@ public class WorkerSinkTaskTest extends ThreadedTest {
 
     @Test
     public void testRewind() throws Exception {
-        Properties taskProps = new Properties();
-        expectInitializeTask(taskProps);
+        expectInitializeTask();
         final long startOffset = 40L;
         final Map<TopicPartition, Long> offsets = new HashMap<>();
 
@@ -386,31 +431,41 @@ public class WorkerSinkTaskTest extends ThreadedTest {
         });
 
         expectStopTask(3);
+        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(taskProps);
+        workerTask.start(TASK_PROPS);
         workerThread.iteration();
         workerThread.iteration();
         workerTask.stop();
-        // No need for awaitStop since the thread is mocked
+        workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
 
         PowerMock.verifyAll();
     }
 
-    private void expectInitializeTask(Properties taskProps) throws Exception {
-        PowerMock.expectPrivate(workerTask, "createConsumer", taskProps)
-                .andReturn(consumer);
+    private void expectInitializeTask() throws Exception {
+        PowerMock.expectPrivate(workerTask, "createConsumer").andReturn(consumer);
 
-        EasyMock.expect(consumer.poll(EasyMock.anyLong())).andReturn(ConsumerRecords.<byte[], byte[]>empty());
+        consumer.subscribe(EasyMock.eq(Arrays.asList(TOPIC)), EasyMock.capture(rebalanceListener));
+        EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(new IAnswer<ConsumerRecords<byte[], byte[]>>() {
+            @Override
+            public ConsumerRecords<byte[], byte[]> answer() throws Throwable {
+                rebalanceListener.getValue().onPartitionsAssigned(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3));
+                return ConsumerRecords.empty();
+            }
+        });
+        EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
+        EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
+        EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET);
 
         sinkTask.initialize(EasyMock.capture(sinkTaskContext));
         PowerMock.expectLastCall();
-        sinkTask.start(taskProps);
+        sinkTask.start(TASK_PROPS);
         PowerMock.expectLastCall();
 
-        workerThread = PowerMock.createPartialMock(WorkerSinkTaskThread.class, new String[]{"start"},
+        workerThread = PowerMock.createPartialMock(WorkerSinkTaskThread.class, new String[]{"start", "awaitShutdown"},
                 workerTask, "mock-worker-thread", time,
                 workerConfig);
         PowerMock.expectPrivate(workerTask, "createWorkerThread")