Posted to commits@kafka.apache.org by ew...@apache.org on 2016/02/03 20:29:14 UTC

kafka git commit: KAFKA-3092: Replace SinkTask onPartitionsAssigned/onPartitionsRevoked with open/close

Repository: kafka
Updated Branches:
  refs/heads/trunk e8343e67e -> 1d80f563b


KAFKA-3092: Replace SinkTask onPartitionsAssigned/onPartitionsRevoked with open/close

Author: Jason Gustafson <ja...@confluent.io>

Reviewers: Liquan Pei <li...@gmail.com>, Ewen Cheslack-Postava <ew...@confluent.io>

Closes #815 from hachikuji/KAFKA-3092


Project: http://git-wip-us.apache.org/repos/asf/kafka/repo
Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/1d80f563
Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/1d80f563
Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/1d80f563

Branch: refs/heads/trunk
Commit: 1d80f563bcd043cd464003782802906b60a0ade8
Parents: e8343e6
Author: Jason Gustafson <ja...@confluent.io>
Authored: Wed Feb 3 11:28:58 2016 -0800
Committer: Ewen Cheslack-Postava <me...@ewencp.org>
Committed: Wed Feb 3 11:28:58 2016 -0800

----------------------------------------------------------------------
 .../org/apache/kafka/connect/sink/SinkTask.java |  64 ++++-
 .../apache/kafka/connect/runtime/Worker.java    |  22 +-
 .../kafka/connect/runtime/WorkerSinkTask.java   | 257 +++++++++++--------
 .../connect/runtime/WorkerSinkTaskThread.java   | 112 --------
 .../kafka/connect/runtime/WorkerSourceTask.java | 152 +++++------
 .../kafka/connect/runtime/WorkerTask.java       |  81 +++++-
 .../connect/runtime/WorkerSinkTaskTest.java     |  66 ++---
 .../runtime/WorkerSinkTaskThreadedTest.java     | 196 ++++++++------
 .../connect/runtime/WorkerSourceTaskTest.java   |  15 +-
 .../kafka/connect/runtime/WorkerTest.java       |   6 +-
 10 files changed, 512 insertions(+), 459 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java
----------------------------------------------------------------------
diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java
index 85ce88a..3d0becc 100644
--- a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java
+++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java
@@ -25,9 +25,32 @@ import java.util.Collection;
 import java.util.Map;
 
 /**
- * SinkTask is a Task takes records loaded from Kafka and sends them to another system. In
- * addition to the basic {@link #put} interface, SinkTasks must also implement {@link #flush}
- * to support offset commits.
+ * SinkTask is a Task that takes records loaded from Kafka and sends them to another system. Each task
+ * instance is assigned a set of partitions by the Connect framework and will handle all records received
+ * from those partitions. As records are fetched from Kafka, they will be passed to the sink task using the
+ * {@link #put(Collection)} API, which should either write them to the downstream system or batch them for
+ * later writing. Periodically, Connect will call {@link #flush(Map)} to ensure that batched records are
+ * actually pushed to the downstream system.
+ *
+ * Below we describe the lifecycle of a SinkTask.
+ *
+ * <ol>
+ *     <li><b>Initialization:</b> SinkTasks are first initialized using {@link #initialize(SinkTaskContext)}
+ *     to prepare the task's context and {@link #start(Map)} to accept configuration and start any services
+ *     needed for processing.</li>
+ *     <li><b>Partition Assignment:</b> After initialization, Connect will assign the task a set of partitions
+ *     using {@link #open(Collection)}. These partitions are owned exclusively by this task until they
+ *     have been closed with {@link #close(Collection)}.</li>
+ *     <li><b>Record Processing:</b> Once partitions have been opened for writing, Connect will begin forwarding
+ *     records from Kafka using the {@link #put(Collection)} API. Periodically, Connect will ask the task
+ *     to flush records using {@link #flush(Map)} as described above.</li>
+ *     <li><b>Partition Rebalancing:</b> Occasionally, Connect will need to change the assignment of this task.
+ *     When this happens, the currently assigned partitions will be closed with {@link #close(Collection)} and
+ *     the new assignment will be opened using {@link #open(Collection)}.</li>
+ *     <li><b>Shutdown:</b> When the task needs to be shut down, Connect will close active partitions (if there
+ *     are any) and stop the task using {@link #stop()}.</li>
+ * </ol>
+ *
  */
 @InterfaceStability.Unstable
 public abstract class SinkTask implements Task {
@@ -42,6 +65,11 @@ public abstract class SinkTask implements Task {
 
     protected SinkTaskContext context;
 
+    /**
+     * Initialize the context of this task. Note that the partition assignment will be empty until
+     * Connect has opened the partitions for writing with {@link #open(Collection)}.
+     * @param context The sink task's context
+     */
     public void initialize(SinkTaskContext context) {
         this.context = context;
     }
@@ -77,24 +105,38 @@ public abstract class SinkTask implements Task {
 
     /**
      * The SinkTask uses this method to create writers for newly assigned partitions in case of partition
-     * re-assignment. In partition re-assignment, some new partitions may be assigned to the SinkTask.
-     * The SinkTask needs to create writers and perform necessary recovery for the newly assigned partitions.
-     * This method will be called after partition re-assignment completes and before the SinkTask starts
+     * rebalance. This method will be called after partition re-assignment completes and before the SinkTask starts
      * fetching data. Note that any errors raised from this method will cause the task to stop.
      * @param partitions The list of partitions that are now assigned to the task (may include
      *                   partitions previously assigned to the task)
      */
+    public void open(Collection<TopicPartition> partitions) {
+        this.onPartitionsAssigned(partitions);
+    }
+
+    /**
+     * @deprecated Use {@link #open(Collection)} for partition initialization.
+     */
+    @Deprecated
     public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
     }
 
     /**
-     * The SinkTask use this method to close writers and commit offsets for partitions that are no
+     * The SinkTask uses this method to close writers for partitions that are no
      * longer assigned to the SinkTask. This method will be called before a rebalance operation starts
-     * and after the SinkTask stops fetching data. Note that any errors raised from this method will cause
-     * the task to stop.
-     * @param partitions The list of partitions that were assigned to the consumer on the last
-     *                   rebalance
+     * and after the SinkTask stops fetching data. Once the partitions have been closed, Connect will not write
+     * any records to the task until a new set of partitions has been opened. Note that any errors raised
+     * from this method will cause the task to stop.
+     * @param partitions The list of partitions that should be closed
+     */
+    public void close(Collection<TopicPartition> partitions) {
+        this.onPartitionsRevoked(partitions);
+    }
+
+    /**
+     * @deprecated Use {@link #close(Collection)} instead for partition cleanup.
      */
+    @Deprecated
     public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
     }
 

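To make the new lifecycle concrete, here is a minimal sketch of a sink task
written against the open()/close() API introduced by this commit. PrintSinkTask
and its stdout "writer" are hypothetical, not part of this change; it batches
records per partition in put() and pushes them in flush(), so buffered data is
durable downstream before Connect commits offsets.

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.connect.sink.SinkRecord;
import org.apache.kafka.connect.sink.SinkTask;

public class PrintSinkTask extends SinkTask {
    private final Map<TopicPartition, StringBuilder> buffers = new HashMap<>();

    @Override
    public String version() {
        return "1.0";
    }

    @Override
    public void start(Map<String, String> props) {
        // Accept configuration and start any services needed for processing.
    }

    @Override
    public void open(Collection<TopicPartition> partitions) {
        // These partitions are owned exclusively by this task until close(),
        // so per-partition state can safely be allocated here.
        for (TopicPartition tp : partitions)
            buffers.put(tp, new StringBuilder());
    }

    @Override
    public void put(Collection<SinkRecord> records) {
        // Batch records for later writing; the actual write happens in flush().
        for (SinkRecord record : records) {
            TopicPartition tp = new TopicPartition(record.topic(), record.kafkaPartition());
            buffers.get(tp).append(record.value()).append('\n');
        }
    }

    @Override
    public void flush(Map<TopicPartition, OffsetAndMetadata> offsets) {
        // Push batched data downstream before Connect commits these offsets.
        for (StringBuilder buffer : buffers.values()) {
            System.out.print(buffer);
            buffer.setLength(0);
        }
    }

    @Override
    public void close(Collection<TopicPartition> partitions) {
        // Connect flushes before closing, so only per-partition state remains.
        for (TopicPartition tp : partitions)
            buffers.remove(tp);
    }

    @Override
    public void stop() {
        buffers.clear();
    }
}
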
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
index 88b4c10..0a4bb7f 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java
@@ -17,11 +17,11 @@
 
 package org.apache.kafka.connect.runtime;
 
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.KafkaException;
 import org.apache.kafka.common.utils.SystemTime;
 import org.apache.kafka.common.utils.Time;
-import org.apache.kafka.clients.producer.KafkaProducer;
-import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.utils.Utils;
 import org.apache.kafka.connect.connector.Connector;
 import org.apache.kafka.connect.connector.ConnectorContext;
@@ -33,8 +33,8 @@ import org.apache.kafka.connect.source.SourceTask;
 import org.apache.kafka.connect.storage.Converter;
 import org.apache.kafka.connect.storage.OffsetBackingStore;
 import org.apache.kafka.connect.storage.OffsetStorageReader;
-import org.apache.kafka.connect.storage.OffsetStorageWriter;
 import org.apache.kafka.connect.storage.OffsetStorageReaderImpl;
+import org.apache.kafka.connect.storage.OffsetStorageWriter;
 import org.apache.kafka.connect.util.ConnectorTaskId;
 import org.reflections.Reflections;
 import org.reflections.util.ClasspathHelper;
@@ -48,6 +48,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 
 /**
@@ -62,8 +64,10 @@ import java.util.Set;
 public class Worker {
     private static final Logger log = LoggerFactory.getLogger(Worker.class);
 
-    private Time time;
-    private WorkerConfig config;
+    private final ExecutorService executor;
+    private final Time time;
+    private final WorkerConfig config;
+
     private Converter keyConverter;
     private Converter valueConverter;
     private Converter internalKeyConverter;
@@ -80,6 +84,7 @@ public class Worker {
 
     @SuppressWarnings("unchecked")
     public Worker(Time time, WorkerConfig config, OffsetBackingStore offsetBackingStore) {
+        this.executor = Executors.newCachedThreadPool();
         this.time = time;
         this.config = config;
         this.keyConverter = config.getConfiguredInstance(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, Converter.class);
@@ -154,7 +159,6 @@ public class Worker {
             log.debug("Waiting for task {} to finish shutting down", task);
             if (!task.awaitStop(Math.max(limit - time.milliseconds(), 0)))
                 log.error("Graceful shutdown of task {} failed.", task);
-            task.close();
         }
 
         long timeoutMs = limit - time.milliseconds();
@@ -342,7 +346,9 @@ public class Worker {
 
         // Start the task before modifying any state; any exceptions are caught higher up the
         // call chain and there's no cleanup to do here
-        workerTask.start(taskConfig.originalsStrings());
+        workerTask.initialize(taskConfig.originalsStrings());
+        executor.submit(workerTask);
+
         if (task instanceof SourceTask) {
             WorkerSourceTask workerSourceTask = (WorkerSourceTask) workerTask;
             sourceTaskOffsetCommitter.schedule(id, workerSourceTask);
@@ -367,7 +373,6 @@ public class Worker {
         task.stop();
         if (!task.awaitStop(config.getLong(WorkerConfig.TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_CONFIG)))
             log.error("Graceful stop of task {} failed.", task);
-        task.close();
         tasks.remove(id);
     }
 
@@ -394,4 +399,5 @@ public class Worker {
     public Converter getInternalValueConverter() {
         return internalValueConverter;
     }
+
 }

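Since WorkerTask now implements Runnable (see WorkerTask.java below), the Worker
drives every task from the shared cached thread pool created in its constructor
instead of one dedicated thread per task. A minimal sketch of the resulting
start/stop sequence, using a hypothetical TaskLifecycle helper; it must live in
org.apache.kafka.connect.runtime because WorkerTask is package-private.

package org.apache.kafka.connect.runtime;

import java.util.Map;
import java.util.concurrent.ExecutorService;

final class TaskLifecycle {

    // Starting: configure the task first, then submit it; the pool invokes run().
    static void start(ExecutorService executor, WorkerTask task, Map<String, String> props) {
        task.initialize(props);
        executor.submit(task);
    }

    // Stopping: stop() is a non-blocking signal; awaitStop() then blocks on the
    // task's shutdown latch for up to the graceful timeout.
    static boolean stop(WorkerTask task, long gracefulTimeoutMs) {
        task.stop();
        return task.awaitStop(gracefulTimeoutMs);
    }
}
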
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
index f48a734..8c5bd9f 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java
@@ -48,119 +48,155 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.TimeUnit;
 
 /**
  * WorkerTask that uses a SinkTask to export data from Kafka.
  */
-class WorkerSinkTask implements WorkerTask {
+class WorkerSinkTask extends WorkerTask {
     private static final Logger log = LoggerFactory.getLogger(WorkerSinkTask.class);
 
-    private final ConnectorTaskId id;
-    private final SinkTask task;
     private final WorkerConfig workerConfig;
+    private final SinkTask task;
+    private Map<String, String> taskConfig;
     private final Time time;
     private final Converter keyConverter;
     private final Converter valueConverter;
-    private WorkerSinkTaskThread workThread;
-    private Map<String, String> taskProps;
     private KafkaConsumer<byte[], byte[]> consumer;
     private WorkerSinkTaskContext context;
-    private boolean started;
     private final List<SinkRecord> messageBatch;
     private Map<TopicPartition, OffsetAndMetadata> lastCommittedOffsets;
     private Map<TopicPartition, OffsetAndMetadata> currentOffsets;
-    private boolean pausedForRedelivery;
     private RuntimeException rebalanceException;
+    private long nextCommit;
+    private int commitSeqno;
+    private long commitStarted;
+    private int commitFailures;
+    private boolean pausedForRedelivery;
+    private boolean committing;
+
+    public WorkerSinkTask(ConnectorTaskId id,
+                          SinkTask task,
+                          WorkerConfig workerConfig,
+                          Converter keyConverter,
+                          Converter valueConverter,
+                          Time time) {
+        super(id);
 
-    public WorkerSinkTask(ConnectorTaskId id, SinkTask task, WorkerConfig workerConfig,
-                          Converter keyConverter, Converter valueConverter, Time time) {
-        this.id = id;
-        this.task = task;
         this.workerConfig = workerConfig;
+        this.task = task;
         this.keyConverter = keyConverter;
         this.valueConverter = valueConverter;
         this.time = time;
-        this.started = false;
         this.messageBatch = new ArrayList<>();
         this.currentOffsets = new HashMap<>();
         this.pausedForRedelivery = false;
         this.rebalanceException = null;
+        this.nextCommit = time.milliseconds() +
+                workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG);
+        this.committing = false;
+        this.commitSeqno = 0;
+        this.commitStarted = -1;
+        this.commitFailures = 0;
     }
 
     @Override
-    public void start(Map<String, String> props) {
-        taskProps = props;
-        consumer = createConsumer();
-        context = new WorkerSinkTaskContext(consumer);
-
-        workThread = createWorkerThread();
-        workThread.start();
+    public void initialize(Map<String, String> taskConfig) {
+        this.taskConfig = taskConfig;
+        this.consumer = createConsumer();
+        this.context = new WorkerSinkTaskContext(consumer);
     }
 
     @Override
     public void stop() {
         // Offset commit is handled upon exit in work thread
-        if (workThread != null)
-            workThread.startGracefulShutdown();
+        super.stop();
         consumer.wakeup();
     }
 
     @Override
-    public boolean awaitStop(long timeoutMs) {
-        boolean success = true;
-        if (workThread != null) {
-            try {
-                success = workThread.awaitShutdown(timeoutMs, TimeUnit.MILLISECONDS);
-                if (!success)
-                    workThread.forceShutdown();
-            } catch (InterruptedException e) {
-                success = false;
-            }
-        }
-        task.stop();
-        return success;
-    }
-
-    @Override
-    public void close() {
+    protected void close() {
         // FIXME Kafka needs to add a timeout parameter here for us to properly obey the timeout
         // passed in
+        task.stop();
         if (consumer != null)
             consumer.close();
     }
 
+    @Override
+    public void execute() {
+        initializeAndStart();
+        try {
+            while (!isStopping())
+                iteration();
+        } finally {
+            // Make sure any uncommitted data has been committed and the task has
+            // a chance to clean up its state
+            closePartitions();
+        }
+    }
+
+    protected void iteration() {
+        long now = time.milliseconds();
+
+        // Maybe commit
+        if (!committing && now >= nextCommit) {
+            commitOffsets(now, false);
+            nextCommit += workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG);
+        }
+
+        // Check for timed out commits
+        long commitTimeout = commitStarted + workerConfig.getLong(
+                WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_CONFIG);
+        if (committing && now >= commitTimeout) {
+            log.warn("Commit of {} offsets timed out", this);
+            commitFailures++;
+            committing = false;
+        }
+
+        // And process messages
+        long timeoutMs = Math.max(nextCommit - now, 0);
+        poll(timeoutMs);
+    }
+
+    private void onCommitCompleted(Throwable error, long seqno) {
+        if (commitSeqno != seqno) {
+            log.debug("Got callback for timed out commit {}: {}, but most recent commit is {}",
+                    this,
+                    seqno, commitSeqno);
+        } else {
+            if (error != null) {
+                log.error("Commit of {} offsets threw an unexpected exception: ", this, error);
+                commitFailures++;
+            } else {
+                log.debug("Finished {} offset commit successfully in {} ms",
+                        this, time.milliseconds() - commitStarted);
+                commitFailures = 0;
+            }
+            committing = false;
+        }
+    }
+
+    public int commitFailures() {
+        return commitFailures;
+    }
+
     /**
-     * Performs initial join process for consumer group, ensures we have an assignment, and initializes + starts the
-     * SinkTask.
-     *
-     * @returns true if successful, false if joining the consumer group was interrupted
+     * Initializes and starts the SinkTask.
      */
-    public boolean joinConsumerGroupAndStart() {
-        String topicsStr = taskProps.get(SinkTask.TOPICS_CONFIG);
+    protected void initializeAndStart() {
+        String topicsStr = taskConfig.get(SinkTask.TOPICS_CONFIG);
         if (topicsStr == null || topicsStr.isEmpty())
             throw new ConnectException("Sink tasks require a list of topics.");
         String[] topics = topicsStr.split(",");
         log.debug("Task {} subscribing to topics {}", id, topics);
         consumer.subscribe(Arrays.asList(topics), new HandleRebalance());
-
-        // Ensure we're in the group so that if start() wants to rewind offsets, it will have an assignment of partitions
-        // to work with. Any rewinding will be handled immediately when polling starts.
-        try {
-            pollConsumer(0);
-        } catch (WakeupException e) {
-            log.error("Sink task {} was stopped before completing join group. Task initialization and start is being skipped", this);
-            return false;
-        }
         task.initialize(context);
-        task.start(taskProps);
+        task.start(taskConfig);
         log.info("Sink task {} finished initialization and start", this);
-        started = true;
-        return true;
     }
 
     /** Poll for new messages with the given timeout. Should only be invoked by the worker thread. */
-    public void poll(long timeoutMs) {
+    protected void poll(long timeoutMs) {
         try {
             rewind();
             long retryTimeout = context.timeout();
@@ -183,55 +219,62 @@ class WorkerSinkTask implements WorkerTask {
 
     /**
      * Starts an offset commit by flushing outstanding messages from the task and then starting
-     * the write commit. This should only be invoked by the WorkerSinkTaskThread.
+     * the write commit.
      **/
-    public void commitOffsets(boolean sync, final int seqno) {
+    private void doCommit(Map<TopicPartition, OffsetAndMetadata> offsets, boolean closing, final int seqno) {
         log.info("{} Committing offsets", this);
-
-        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>(currentOffsets);
-
-        try {
-            task.flush(offsets);
-        } catch (Throwable t) {
-            log.error("Commit of {} offsets failed due to exception while flushing:", this, t);
-            log.error("Rewinding offsets to last committed offsets");
-            for (Map.Entry<TopicPartition, OffsetAndMetadata> entry : lastCommittedOffsets.entrySet()) {
-                log.debug("{} Rewinding topic partition {} to offset {}", id, entry.getKey(), entry.getValue().offset());
-                consumer.seek(entry.getKey(), entry.getValue().offset());
-            }
-            currentOffsets = new HashMap<>(lastCommittedOffsets);
-            workThread.onCommitCompleted(t, seqno);
-            return;
-        }
-
-        if (sync) {
+        if (closing) {
             try {
                 consumer.commitSync(offsets);
                 lastCommittedOffsets = offsets;
-                workThread.onCommitCompleted(null, seqno);
+                onCommitCompleted(null, seqno);
             } catch (KafkaException e) {
-                workThread.onCommitCompleted(e, seqno);
+                onCommitCompleted(e, seqno);
             }
         } else {
             OffsetCommitCallback cb = new OffsetCommitCallback() {
                 @Override
                 public void onComplete(Map<TopicPartition, OffsetAndMetadata> offsets, Exception error) {
                     lastCommittedOffsets = offsets;
-                    workThread.onCommitCompleted(error, seqno);
+                    onCommitCompleted(error, seqno);
                 }
             };
             consumer.commitAsync(offsets, cb);
         }
     }
 
-    public Time time() {
-        return time;
-    }
+    private void commitOffsets(long now, boolean closing) {
+        if (currentOffsets.isEmpty())
+            return;
 
-    public WorkerConfig workerConfig() {
-        return workerConfig;
+        committing = true;
+        commitSeqno += 1;
+        commitStarted = now;
+
+        Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>(currentOffsets);
+        try {
+            task.flush(offsets);
+        } catch (Throwable t) {
+            log.error("Commit of {} offsets failed due to exception while flushing:", this, t);
+            log.error("Rewinding offsets to last committed offsets");
+            for (Map.Entry<TopicPartition, OffsetAndMetadata> entry : lastCommittedOffsets.entrySet()) {
+                log.debug("{} Rewinding topic partition {} to offset {}", id, entry.getKey(), entry.getValue().offset());
+                consumer.seek(entry.getKey(), entry.getValue().offset());
+            }
+            currentOffsets = new HashMap<>(lastCommittedOffsets);
+            onCommitCompleted(t, commitSeqno);
+            return;
+        } finally {
+            // Close the task if needed before committing the offsets. This is basically the last chance for
+            // the connector to actually flush data that has been written to it.
+            if (closing)
+                task.close(currentOffsets.keySet());
+        }
+
+        doCommit(offsets, closing, commitSeqno);
     }
 
+
     @Override
     public String toString() {
         return "WorkerSinkTask{" +
@@ -277,10 +320,6 @@ class WorkerSinkTask implements WorkerTask {
         return newConsumer;
     }
 
-    private WorkerSinkTaskThread createWorkerThread() {
-        return new WorkerSinkTaskThread(this, "WorkerSinkTask-" + id, time, workerConfig);
-    }
-
     private void convertMessages(ConsumerRecords<byte[], byte[]> msgs) {
         for (ConsumerRecord<byte[], byte[]> msg : msgs) {
             log.trace("Consuming message with key {}, value {}", msg.key(), msg.value());
@@ -321,8 +360,8 @@ class WorkerSinkTask implements WorkerTask {
                 consumer.pause(tp);
             // Let this exit normally, the batch will be reprocessed on the next loop.
         } catch (Throwable t) {
-            log.error("Task {} threw an uncaught and unrecoverable exception", id);
-            log.error("Task is being killed and will not recover until manually restarted:", t);
+            log.error("Task {} threw an uncaught and unrecoverable exception", id, t);
+            log.error("Task is being killed and will not recover until manually restarted");
             throw new ConnectException("Exiting WorkerSinkTask due to unrecoverable exception.");
         }
     }
@@ -344,12 +383,20 @@ class WorkerSinkTask implements WorkerTask {
         context.clearOffsets();
     }
 
+    private void openPartitions(Collection<TopicPartition> partitions) {
+        if (partitions.isEmpty())
+            return;
+
+        task.open(partitions);
+    }
+
+    private void closePartitions() {
+        commitOffsets(time.milliseconds(), true);
+    }
+
     private class HandleRebalance implements ConsumerRebalanceListener {
         @Override
         public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
-            if (rebalanceException != null)
-                return;
-
             lastCommittedOffsets = new HashMap<>();
             currentOffsets = new HashMap<>();
             for (TopicPartition tp : partitions) {
@@ -364,6 +411,7 @@ class WorkerSinkTask implements WorkerTask {
             // Also make sure our tracking of paused partitions is updated to remove any partitions we no longer own.
             if (pausedForRedelivery) {
                 pausedForRedelivery = false;
+
                 Set<TopicPartition> assigned = new HashSet<>(partitions);
                 Set<TopicPartition> taskPaused = context.pausedPartitions();
 
@@ -383,9 +431,9 @@ class WorkerSinkTask implements WorkerTask {
             // Instead of invoking the assignment callback on initialization, we guarantee the consumer is ready upon
             // task start. Since this callback gets invoked during that initial setup before we've started the task, we
             // need to guard against invoking the user's callback method during that period.
-            if (started) {
+            if (rebalanceException == null) {
                 try {
-                    task.onPartitionsAssigned(partitions);
+                    openPartitions(partitions);
                 } catch (RuntimeException e) {
                     // The consumer swallows exceptions raised in the rebalance listener, so we need to store
                     // exceptions and rethrow when poll() returns.
@@ -396,15 +444,12 @@ class WorkerSinkTask implements WorkerTask {
 
         @Override
         public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
-            if (started) {
-                try {
-                    task.onPartitionsRevoked(partitions);
-                    commitOffsets(true, -1);
-                } catch (RuntimeException e) {
-                    // The consumer swallows exceptions raised in the rebalance listener, so we need to store
-                    // exceptions and rethrow when poll() returns.
-                    rebalanceException = e;
-                }
+            try {
+                closePartitions();
+            } catch (RuntimeException e) {
+                // The consumer swallows exceptions raised in the rebalance listener, so we need to store
+                // exceptions and rethrow when poll() returns.
+                rebalanceException = e;
             }
 
             // Make sure we don't have any leftover data since offsets will be reset to committed positions

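A pattern worth calling out in HandleRebalance above: the consumer swallows
exceptions raised inside rebalance callbacks, so the listener stores the failure
in rebalanceException and the worker rethrows it once poll() returns. A
self-contained sketch of that pattern, with hypothetical openPartitions and
closePartitions standing in for the task callbacks:

import java.util.Collection;

import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.common.TopicPartition;

class DeferredRebalanceErrors implements ConsumerRebalanceListener {
    private RuntimeException rebalanceException = null;

    @Override
    public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
        try {
            openPartitions(partitions);
        } catch (RuntimeException e) {
            rebalanceException = e;   // stash rather than throw
        }
    }

    @Override
    public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
        try {
            closePartitions(partitions);
        } catch (RuntimeException e) {
            rebalanceException = e;
        }
    }

    // Invoked on the worker thread immediately after consumer.poll() returns.
    void maybeThrowRebalanceError() {
        if (rebalanceException != null) {
            RuntimeException e = rebalanceException;
            rebalanceException = null;
            throw e;
        }
    }

    private void openPartitions(Collection<TopicPartition> partitions) { }
    private void closePartitions(Collection<TopicPartition> partitions) { }
}
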
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThread.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThread.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThread.java
deleted file mode 100644
index 93e210a..0000000
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThread.java
+++ /dev/null
@@ -1,112 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- * <p/>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p/>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- **/
-
-package org.apache.kafka.connect.runtime;
-
-import org.apache.kafka.common.utils.Time;
-import org.apache.kafka.connect.util.ShutdownableThread;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Worker thread for a WorkerSinkTask. These classes are very tightly coupled, but separated to
- * simplify testing.
- */
-class WorkerSinkTaskThread extends ShutdownableThread {
-    private static final Logger log = LoggerFactory.getLogger(WorkerSinkTask.class);
-
-    private final WorkerSinkTask task;
-    private long nextCommit;
-    private boolean committing;
-    private int commitSeqno;
-    private long commitStarted;
-    private int commitFailures;
-
-    public WorkerSinkTaskThread(WorkerSinkTask task, String name, Time time,
-                                WorkerConfig workerConfig) {
-        super(name);
-        this.task = task;
-        this.nextCommit = time.milliseconds() +
-                workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG);
-        this.committing = false;
-        this.commitSeqno = 0;
-        this.commitStarted = -1;
-        this.commitFailures = 0;
-    }
-
-    @Override
-    public void execute() {
-        // Try to join and start. If we're interrupted before this completes, bail.
-        if (!task.joinConsumerGroupAndStart())
-            return;
-
-        while (getRunning()) {
-            iteration();
-        }
-
-        // Make sure any uncommitted data has committed
-        task.commitOffsets(true, -1);
-    }
-
-    public void iteration() {
-        long now = task.time().milliseconds();
-
-        // Maybe commit
-        if (!committing && now >= nextCommit) {
-            committing = true;
-            commitSeqno += 1;
-            commitStarted = now;
-            task.commitOffsets(false, commitSeqno);
-            nextCommit += task.workerConfig().getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG);
-        }
-
-        // Check for timed out commits
-        long commitTimeout = commitStarted + task.workerConfig().getLong(
-                WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_CONFIG);
-        if (committing && now >= commitTimeout) {
-            log.warn("Commit of {} offsets timed out", task);
-            commitFailures++;
-            committing = false;
-        }
-
-        // And process messages
-        long timeoutMs = Math.max(nextCommit - now, 0);
-        task.poll(timeoutMs);
-    }
-
-    public void onCommitCompleted(Throwable error, long seqno) {
-        if (commitSeqno != seqno) {
-            log.debug("Got callback for timed out commit {}: {}, but most recent commit is {}",
-                    this,
-                    seqno, commitSeqno);
-        } else {
-            if (error != null) {
-                log.error("Commit of {} offsets threw an unexpected exception: ", task, error);
-                commitFailures++;
-            } else {
-                log.debug("Finished {} offset commit successfully in {} ms",
-                        task, task.time().milliseconds() - commitStarted);
-                commitFailures = 0;
-            }
-            committing = false;
-        }
-    }
-
-    public int commitFailures() {
-        return commitFailures;
-    }
-}

http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java
index 6c61d79..30c2262 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java
@@ -31,7 +31,6 @@ import org.apache.kafka.connect.storage.Converter;
 import org.apache.kafka.connect.storage.OffsetStorageReader;
 import org.apache.kafka.connect.storage.OffsetStorageWriter;
 import org.apache.kafka.connect.util.ConnectorTaskId;
-import org.apache.kafka.connect.util.ShutdownableThread;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,20 +46,18 @@ import java.util.concurrent.TimeoutException;
 /**
  * WorkerTask that uses a SourceTask to ingest data into Kafka.
  */
-class WorkerSourceTask implements WorkerTask {
+class WorkerSourceTask extends WorkerTask {
     private static final Logger log = LoggerFactory.getLogger(WorkerSourceTask.class);
 
     private static final long SEND_FAILED_BACKOFF_MS = 100;
 
-    private final ConnectorTaskId id;
+    private final WorkerConfig workerConfig;
     private final SourceTask task;
     private final Converter keyConverter;
     private final Converter valueConverter;
     private KafkaProducer<byte[], byte[]> producer;
-    private WorkerSourceTaskThread workThread;
     private final OffsetStorageReader offsetReader;
     private final OffsetStorageWriter offsetWriter;
-    private final WorkerConfig workerConfig;
     private final Time time;
 
     private List<SourceRecord> toSend;
@@ -73,19 +70,28 @@ class WorkerSourceTask implements WorkerTask {
     private boolean flushing;
     private CountDownLatch stopRequestedLatch;
 
-    public WorkerSourceTask(ConnectorTaskId id, SourceTask task,
-                            Converter keyConverter, Converter valueConverter,
+    private Map<String, String> taskConfig;
+    private boolean finishedStart = false;
+    private boolean startedShutdownBeforeStartCompleted = false;
+
+    public WorkerSourceTask(ConnectorTaskId id,
+                            SourceTask task,
+                            Converter keyConverter,
+                            Converter valueConverter,
                             KafkaProducer<byte[], byte[]> producer,
-                            OffsetStorageReader offsetReader, OffsetStorageWriter offsetWriter,
-                            WorkerConfig workerConfig, Time time) {
-        this.id = id;
+                            OffsetStorageReader offsetReader,
+                            OffsetStorageWriter offsetWriter,
+                            WorkerConfig workerConfig,
+                            Time time) {
+        super(id);
+
+        this.workerConfig = workerConfig;
         this.task = task;
         this.keyConverter = keyConverter;
         this.valueConverter = valueConverter;
         this.producer = producer;
         this.offsetReader = offsetReader;
         this.offsetWriter = offsetWriter;
-        this.workerConfig = workerConfig;
         this.time = time;
 
         this.toSend = null;
@@ -97,37 +103,60 @@ class WorkerSourceTask implements WorkerTask {
     }
 
     @Override
-    public void start(Map<String, String> props) {
-        workThread = new WorkerSourceTaskThread("WorkerSourceTask-" + id, props);
-        workThread.start();
+    public void initialize(Map<String, String> config) {
+        this.taskConfig = config;
+    }
+
+    protected void close() {
+        // nothing to do
     }
 
     @Override
     public void stop() {
-        if (workThread != null) {
-            workThread.startGracefulShutdown();
-            stopRequestedLatch.countDown();
+        super.stop();
+        stopRequestedLatch.countDown();
+        synchronized (this) {
+            if (finishedStart)
+                task.stop();
+            else
+                startedShutdownBeforeStartCompleted = true;
         }
     }
 
     @Override
-    public boolean awaitStop(long timeoutMs) {
-        boolean success = true;
-        if (workThread != null) {
-            try {
-                success = workThread.awaitShutdown(timeoutMs, TimeUnit.MILLISECONDS);
-                if (!success)
-                    workThread.forceShutdown();
-            } catch (InterruptedException e) {
-                success = false;
+    public void execute() {
+        try {
+            task.initialize(new WorkerSourceTaskContext(offsetReader));
+            task.start(taskConfig);
+            log.info("Source task {} finished initialization and start", this);
+            synchronized (this) {
+                if (startedShutdownBeforeStartCompleted) {
+                    task.stop();
+                    return;
+                }
+                finishedStart = true;
             }
+
+            while (!isStopping()) {
+                if (toSend == null)
+                    toSend = task.poll();
+                if (toSend == null)
+                    continue;
+                if (!sendRecords())
+                    stopRequestedLatch.await(SEND_FAILED_BACKOFF_MS, TimeUnit.MILLISECONDS);
+            }
+        } catch (InterruptedException e) {
+            // Ignore and allow to exit.
+        } catch (Throwable t) {
+            log.error("Task {} threw an uncaught and unrecoverable exception", id);
+            log.error("Task is being killed and will not recover until manually restarted:", t);
+            // It should still be safe to let this fall through and commit offsets since this exception would have
+            // simply resulted in not getting more records but all the existing records should be ok to flush
+            // and commit offsets. Worst case, task.flush() will also throw an exception causing the offset commit
+            // to fail.
         }
-        return success;
-    }
 
-    @Override
-    public void close() {
-        // Nothing to do
+        commitOffsets();
     }
 
     /**
@@ -323,67 +352,6 @@ class WorkerSourceTask implements WorkerTask {
         flushing = false;
     }
 
-
-    private class WorkerSourceTaskThread extends ShutdownableThread {
-        private Map<String, String> workerProps;
-        private boolean finishedStart;
-        private boolean startedShutdownBeforeStartCompleted;
-
-        public WorkerSourceTaskThread(String name, Map<String, String> workerProps) {
-            super(name);
-            this.workerProps = workerProps;
-            this.finishedStart = false;
-            this.startedShutdownBeforeStartCompleted = false;
-        }
-
-        @Override
-        public void execute() {
-            try {
-                task.initialize(new WorkerSourceTaskContext(offsetReader));
-                task.start(workerProps);
-                log.info("Source task {} finished initialization and start", this);
-                synchronized (this) {
-                    if (startedShutdownBeforeStartCompleted) {
-                        task.stop();
-                        return;
-                    }
-                    finishedStart = true;
-                }
-
-                while (getRunning()) {
-                    if (toSend == null)
-                        toSend = task.poll();
-                    if (toSend == null)
-                        continue;
-                    if (!sendRecords())
-                        stopRequestedLatch.await(SEND_FAILED_BACKOFF_MS, TimeUnit.MILLISECONDS);
-                }
-            } catch (InterruptedException e) {
-                // Ignore and allow to exit.
-            } catch (Throwable t) {
-                log.error("Task {} threw an uncaught and unrecoverable exception", id);
-                log.error("Task is being killed and will not recover until manually restarted:", t);
-                // It should still be safe to let this fall through and commit offsets since this exception would have
-                // simply resulted in not getting more records but all the existing records should be ok to flush
-                // and commit offsets. Worst case, task.flush() will also throw an exception causing the offset commit
-                // to fail.
-            }
-
-            commitOffsets();
-        }
-
-        @Override
-        public void startGracefulShutdown() {
-            super.startGracefulShutdown();
-            synchronized (this) {
-                if (finishedStart)
-                    task.stop();
-                else
-                    startedShutdownBeforeStartCompleted = true;
-            }
-        }
-    }
-
     @Override
     public String toString() {
         return "WorkerSourceTask{" +

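One subtlety in the reworked WorkerSourceTask: stop() can arrive while execute()
is still inside task.start(), so the two coordinate through a pair of flags
guarded by the object's monitor rather than racing to call task.stop(). A
stripped-down sketch of the handshake; the class and method names are
illustrative only.

import java.util.Map;

import org.apache.kafka.connect.source.SourceTask;

class StartStopHandshake {
    private final SourceTask task;
    private boolean finishedStart = false;
    private boolean startedShutdownBeforeStartCompleted = false;

    StartStopHandshake(SourceTask task) {
        this.task = task;
    }

    void execute(Map<String, String> config) {
        task.start(config);
        synchronized (this) {
            if (startedShutdownBeforeStartCompleted) {
                task.stop();   // stop() won the race; honor it now
                return;
            }
            finishedStart = true;
        }
        // ... the poll/send loop runs here until stopped ...
    }

    void stop() {
        synchronized (this) {
            if (finishedStart)
                task.stop();   // start completed; safe to stop directly
            else
                startedShutdownBeforeStartCompleted = true;
        }
    }
}
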
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java
index 66fc45b..b4d427a 100644
--- a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java
+++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java
@@ -17,25 +17,48 @@
 
 package org.apache.kafka.connect.runtime;
 
+import org.apache.kafka.connect.util.ConnectorTaskId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 /**
  * Handles processing for an individual task. This interface only provides the basic methods
  * used by {@link Worker} to manage the tasks. Implementations combine a user-specified Task with
  * Kafka to create a data flow.
  */
-interface WorkerTask {
+abstract class WorkerTask implements Runnable {
+    private static final Logger log = LoggerFactory.getLogger(WorkerTask.class);
+
+    protected final ConnectorTaskId id;
+    private final AtomicBoolean stopping;
+    private final AtomicBoolean running;
+    private final CountDownLatch shutdownLatch;
+
+    public WorkerTask(ConnectorTaskId id) {
+        this.id = id;
+        this.stopping = new AtomicBoolean(false);
+        this.running = new AtomicBoolean(false);
+        this.shutdownLatch = new CountDownLatch(1);
+    }
+
     /**
-     * Start the Task
+     * Initialize the task for execution.
      * @param props initial configuration
      */
-    void start(Map<String, String> props);
+    public abstract void initialize(Map<String, String> props);
 
     /**
      * Stop this task from processing messages. This method does not block, it only triggers
      * shutdown. Use #{@link #awaitStop} to block until completion.
      */
-    void stop();
+    public void stop() {
+        this.stopping.set(true);
+    }
 
     /**
      * Wait for this task to finish stopping.
@@ -43,12 +66,48 @@ interface WorkerTask {
      * @param timeoutMs
      * @return true if successful, false if the timeout was reached
      */
-    boolean awaitStop(long timeoutMs);
+    public boolean awaitStop(long timeoutMs) {
+        if (!running.get())
+            return true;
+
+        try {
+            return shutdownLatch.await(timeoutMs, TimeUnit.MILLISECONDS);
+        } catch (InterruptedException e) {
+            return false;
+        }
+    }
+
+    protected abstract void execute();
+
+    protected abstract void close();
+
+    protected boolean isStopping() {
+        return stopping.get();
+    }
+
+    private void doClose() {
+        try {
+            close();
+        } catch (Throwable t) {
+            log.error("Unhandled exception in task shutdown {}", id, t);
+        } finally {
+            running.set(false);
+            shutdownLatch.countDown();
+        }
+    }
+
+    @Override
+    public void run() {
+        if (!this.running.compareAndSet(false, true))
+            throw new IllegalStateException("The task cannot be started while still running");
+
+        try {
+            execute();
+        } catch (Throwable t) {
+            log.error("Unhandled exception in task {}", id, t);
+        } finally {
+            doClose();
+        }
+    }
 
-    /**
-     * Close this task. This is different from #{@link #stop} and #{@link #awaitStop} in that the
-     * stop methods ensure processing has stopped but may leave resources allocated. This method
-     * should clean up all resources.
-     */
-    void close();
 }

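WorkerTask is now a small template: run() sets the running flag, invokes the
subclass's execute(), logs any uncaught error, and finally calls close() and
counts down the latch that awaitStop() blocks on. A hypothetical no-op subclass
makes the contract explicit; it must sit in org.apache.kafka.connect.runtime
since WorkerTask is package-private.

package org.apache.kafka.connect.runtime;

import java.util.Map;

import org.apache.kafka.connect.util.ConnectorTaskId;

class NoOpWorkerTask extends WorkerTask {

    public NoOpWorkerTask(ConnectorTaskId id) {
        super(id);
    }

    @Override
    public void initialize(Map<String, String> props) {
        // Stash configuration; no work happens until run() is invoked.
    }

    @Override
    protected void execute() {
        // Runs on the Worker's executor thread; loop until stop() is called.
        while (!isStopping()) {
            try {
                Thread.sleep(100);   // stand-in for real processing
            } catch (InterruptedException e) {
                break;
            }
        }
    }

    @Override
    protected void close() {
        // Called exactly once after execute() returns or throws; release any
        // resources here. awaitStop() unblocks once this completes.
    }
}
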
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
index 305a61e..04b08b3 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java
@@ -97,8 +97,6 @@ public class WorkerSinkTaskTest {
     @Mock
     private Converter valueConverter;
     @Mock
-    private WorkerSinkTaskThread workerThread;
-    @Mock
     private KafkaConsumer<byte[], byte[]> consumer;
     private Capture<ConsumerRebalanceListener> rebalanceListener = EasyMock.newCapture();
 
@@ -116,7 +114,7 @@ public class WorkerSinkTaskTest {
         workerProps.put("internal.value.converter.schemas.enable", "false");
         workerConfig = new StandaloneConfig(workerProps);
         workerTask = PowerMock.createPartialMock(
-                WorkerSinkTask.class, new String[]{"createConsumer", "createWorkerThread"},
+                WorkerSinkTask.class, new String[]{"createConsumer"},
                 taskId, sinkTask, workerConfig, keyConverter, valueConverter, time);
 
         recordsReturned = 0;
@@ -125,6 +123,7 @@ public class WorkerSinkTaskTest {
     @Test
     public void testPollRedelivery() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
 
         // If a retriable exception is thrown, we should redeliver the same batch, pausing the consumer in the meantime
         expectConsumerPoll(1);
@@ -152,8 +151,9 @@ public class WorkerSinkTaskTest {
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+        workerTask.poll(Long.MAX_VALUE);
         workerTask.poll(Long.MAX_VALUE);
         workerTask.poll(Long.MAX_VALUE);
 
@@ -165,12 +165,14 @@ public class WorkerSinkTaskTest {
         RuntimeException exception = new RuntimeException("Revocation error");
 
         expectInitializeTask();
+        expectPollInitialAssignment();
         expectRebalanceRevocationError(exception);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+        workerTask.poll(Long.MAX_VALUE);
         try {
             workerTask.poll(Long.MAX_VALUE);
             fail("Poll should have raised the rebalance exception");
@@ -186,12 +188,14 @@ public class WorkerSinkTaskTest {
         RuntimeException exception = new RuntimeException("Assignment error");
 
         expectInitializeTask();
+        expectPollInitialAssignment();
         expectRebalanceAssignmentError(exception);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+        workerTask.poll(Long.MAX_VALUE);
         try {
             workerTask.poll(Long.MAX_VALUE);
             fail("Poll should have raised the rebalance exception");
@@ -205,24 +209,9 @@ public class WorkerSinkTaskTest {
 
     private void expectInitializeTask() throws Exception {
         PowerMock.expectPrivate(workerTask, "createConsumer").andReturn(consumer);
-        PowerMock.expectPrivate(workerTask, "createWorkerThread")
-                .andReturn(workerThread);
-        workerThread.start();
-        PowerMock.expectLastCall();
-
         consumer.subscribe(EasyMock.eq(Arrays.asList(TOPIC)), EasyMock.capture(rebalanceListener));
         PowerMock.expectLastCall();
 
-        EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(new IAnswer<ConsumerRecords<byte[], byte[]>>() {
-            @Override
-            public ConsumerRecords<byte[], byte[]> answer() throws Throwable {
-                rebalanceListener.getValue().onPartitionsAssigned(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2));
-                return ConsumerRecords.empty();
-            }
-        });
-        EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
-        EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
-
         sinkTask.initialize(EasyMock.capture(sinkTaskContext));
         PowerMock.expectLastCall();
         sinkTask.start(TASK_PROPS);
@@ -232,7 +221,7 @@ public class WorkerSinkTaskTest {
     private void expectRebalanceRevocationError(RuntimeException e) {
         final List<TopicPartition> partitions = Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2);
 
-        sinkTask.onPartitionsRevoked(partitions);
+        sinkTask.close(new HashSet<>(partitions));
         EasyMock.expectLastCall().andThrow(e);
 
         EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(
@@ -248,7 +237,7 @@ public class WorkerSinkTaskTest {
     private void expectRebalanceAssignmentError(RuntimeException e) {
         final List<TopicPartition> partitions = Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2);
 
-        sinkTask.onPartitionsRevoked(partitions);
+        sinkTask.close(new HashSet<>(partitions));
         EasyMock.expectLastCall();
 
         sinkTask.flush(EasyMock.<Map<TopicPartition, OffsetAndMetadata>>anyObject());
@@ -257,13 +246,10 @@ public class WorkerSinkTaskTest {
         consumer.commitSync(EasyMock.<Map<TopicPartition, OffsetAndMetadata>>anyObject());
         EasyMock.expectLastCall();
 
-        workerThread.onCommitCompleted(EasyMock.<Throwable>isNull(), EasyMock.anyLong());
-        EasyMock.expectLastCall();
-
         EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
         EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
 
-        sinkTask.onPartitionsAssigned(partitions);
+        sinkTask.open(partitions);
         EasyMock.expectLastCall().andThrow(e);
 
         EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(
@@ -277,6 +263,26 @@ public class WorkerSinkTaskTest {
                 });
     }
 
+    private void expectPollInitialAssignment() {
+        final List<TopicPartition> partitions = Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2);
+
+        sinkTask.open(partitions);
+        EasyMock.expectLastCall();
+
+        EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(new IAnswer<ConsumerRecords<byte[], byte[]>>() {
+            @Override
+            public ConsumerRecords<byte[], byte[]> answer() throws Throwable {
+                rebalanceListener.getValue().onPartitionsAssigned(partitions);
+                return ConsumerRecords.empty();
+            }
+        });
+        EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET);
+        EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
+
+        sinkTask.put(Collections.<SinkRecord>emptyList());
+        EasyMock.expectLastCall();
+    }
+
     private void expectConsumerPoll(final int numMessages) {
         EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(
                 new IAnswer<ConsumerRecords<byte[], byte[]>>() {

http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
index 6915631..3bf653e 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java
@@ -55,8 +55,8 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
-import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -100,7 +100,6 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     private Converter valueConverter;
     private WorkerSinkTask workerTask;
     @Mock private KafkaConsumer<byte[], byte[]> consumer;
-    private WorkerSinkTaskThread workerThread;
     private Capture<ConsumerRebalanceListener> rebalanceListener = EasyMock.newCapture();
 
     private long recordsReturned;
@@ -119,7 +118,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         workerProps.put("internal.value.converter.schemas.enable", "false");
         workerConfig = new StandaloneConfig(workerProps);
         workerTask = PowerMock.createPartialMock(
-                WorkerSinkTask.class, new String[]{"createConsumer", "createWorkerThread"},
+                WorkerSinkTask.class, new String[]{"createConsumer"},
                 taskId, sinkTask, workerConfig, keyConverter, valueConverter, time);
 
         recordsReturned = 0;
@@ -128,16 +127,22 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     @Test
     public void testPollsInBackground() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
+
         Capture<Collection<SinkRecord>> capturedRecords = expectPolls(1L);
         expectStopTask(10L);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+
+        // First iteration initializes partition assignment
+        workerTask.iteration();
+
+        // Then we iterate to fetch data
         for (int i = 0; i < 10; i++) {
-            workerThread.iteration();
+            workerTask.iteration();
         }
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
@@ -163,23 +168,28 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     @Test
     public void testCommit() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
+
         // Make each poll() take the offset commit interval
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, null, null, 0, true);
         expectStopTask(2);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        // First iteration gets one record
-        workerThread.iteration();
-        // Second triggers commit, gets a second offset
-        workerThread.iteration();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+
+        // Initialize partition assignment
+        workerTask.iteration();
+        // Fetch one record
+        workerTask.iteration();
+        // Trigger the commit
+        workerTask.iteration();
+
         // Commit finishes synchronously for testing so we can check this immediately
-        assertEquals(0, workerThread.commitFailures());
+        assertEquals(0, workerTask.commitFailures());
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -192,6 +202,8 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     @Test
     public void testCommitTaskFlushFailure() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
+
         Capture<Collection<SinkRecord>> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, new RuntimeException(), null, 0, true);
         // Should rewind to last known good positions, which in this case will be the offsets loaded during initialization
@@ -203,17 +215,21 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         consumer.seek(TOPIC_PARTITION3, FIRST_OFFSET);
         PowerMock.expectLastCall();
         expectStopTask(2);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        // Second iteration triggers commit
-        workerThread.iteration();
-        workerThread.iteration();
-        assertEquals(1, workerThread.commitFailures());
-        assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+
+        // Initialize partition assignment
+        workerTask.iteration();
+        // Fetch some data
+        workerTask.iteration();
+        // Trigger the commit
+        workerTask.iteration();
+
+        assertEquals(1, workerTask.commitFailures());
+        assertEquals(false, Whitebox.getInternalState(workerTask, "committing"));
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -226,6 +242,7 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         // Validate that we rewind to the correct offsets if a task's flush method throws an exception
 
         expectInitializeTask();
+        expectPollInitialAssignment();
         Capture<Collection<SinkRecord>> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, null, null, 0, true);
         expectOffsetFlush(2L, new RuntimeException(), null, 0, true);
@@ -237,18 +254,23 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         consumer.seek(TOPIC_PARTITION3, FIRST_OFFSET);
         PowerMock.expectLastCall();
         expectStopTask(2);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        // Second iteration triggers first commit, third iteration triggers second (failing) commit
-        workerThread.iteration();
-        workerThread.iteration();
-        workerThread.iteration();
-        assertEquals(1, workerThread.commitFailures());
-        assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+
+        // Initialize partition assignment
+        workerTask.iteration();
+        // Fetch some data
+        workerTask.iteration();
+        // Trigger the first commit
+        workerTask.iteration();
+        // Trigger second (failing) commit
+        workerTask.iteration();
+
+        assertEquals(1, workerTask.commitFailures());
+        assertEquals(false, Whitebox.getInternalState(workerTask, "committing"));
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -259,22 +281,28 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     @Test
     public void testCommitConsumerFailure() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
+
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT);
         expectOffsetFlush(1L, null, new Exception(), 0, true);
         expectStopTask(2);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        // Second iteration triggers commit
-        workerThread.iteration();
-        workerThread.iteration();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+
+        // Initialize partition assignment
+        workerTask.iteration();
+        // Fetch some data
+        workerTask.iteration();
+        // Trigger commit
+        workerTask.iteration();
+
         // TODO Response to consistent failures?
-        assertEquals(1, workerThread.commitFailures());
-        assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
+        assertEquals(1, workerTask.commitFailures());
+        assertEquals(false, Whitebox.getInternalState(workerTask, "committing"));
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -285,26 +313,32 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     @Test
     public void testCommitTimeout() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
+
         // Cut down the amount of time passed in each poll so we trigger exactly one offset commit
         Capture<Collection<SinkRecord>> capturedRecords
                 = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT / 2);
         expectOffsetFlush(2L, null, null, WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT, false);
         expectStopTask(4);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        // Third iteration triggers commit, fourth gives a chance to trigger the timeout but doesn't
-        // trigger another commit
-        workerThread.iteration();
-        workerThread.iteration();
-        workerThread.iteration();
-        workerThread.iteration();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+
+        // Initialize partition assignment
+        workerTask.iteration();
+        // Fetch some data
+        workerTask.iteration();
+        workerTask.iteration();
+        // Trigger the commit
+        workerTask.iteration();
+        // Trigger the timeout without another commit
+        workerTask.iteration();
+
         // TODO Response to consistent failures?
-        assertEquals(1, workerThread.commitFailures());
-        assertEquals(false, Whitebox.getInternalState(workerThread, "committing"));
+        assertEquals(1, workerTask.commitFailures());
+        assertEquals(false, Whitebox.getInternalState(workerTask, "committing"));
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -366,15 +400,14 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         PowerMock.expectLastCall();
 
         expectStopTask(0);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
 
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        workerThread.iteration();
-        workerThread.iteration();
-        workerThread.iteration();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+        workerTask.iteration();
+        workerTask.iteration();
+        workerTask.iteration();
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -385,6 +418,8 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     @Test
     public void testRewind() throws Exception {
         expectInitializeTask();
+        expectPollInitialAssignment();
+
         final long startOffset = 40L;
         final Map<TopicPartition, Long> offsets = new HashMap<>();
 
@@ -410,14 +445,13 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         });
 
         expectStopTask(3);
-        EasyMock.expect(workerThread.awaitShutdown(EasyMock.anyLong(), EasyMock.<TimeUnit>anyObject())).andReturn(true);
-
         PowerMock.replayAll();
 
-        workerTask.start(TASK_PROPS);
-        workerTask.joinConsumerGroupAndStart();
-        workerThread.iteration();
-        workerThread.iteration();
+        workerTask.initialize(TASK_PROPS);
+        workerTask.initializeAndStart();
+        workerTask.iteration();
+        workerTask.iteration();
+        workerTask.iteration();
         workerTask.stop();
         workerTask.awaitStop(Long.MAX_VALUE);
         workerTask.close();
@@ -428,21 +462,25 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     private void expectInitializeTask() throws Exception {
         PowerMock.expectPrivate(workerTask, "createConsumer").andReturn(consumer);
 
-        workerThread = PowerMock.createPartialMock(WorkerSinkTaskThread.class, new String[]{"start", "awaitShutdown"},
-                workerTask, "mock-worker-thread", time,
-                workerConfig);
-        PowerMock.expectPrivate(workerTask, "createWorkerThread")
-                .andReturn(workerThread);
-        workerThread.start();
+        consumer.subscribe(EasyMock.eq(Arrays.asList(TOPIC)), EasyMock.capture(rebalanceListener));
         PowerMock.expectLastCall();
 
-        consumer.subscribe(EasyMock.eq(Arrays.asList(TOPIC)), EasyMock.capture(rebalanceListener));
+        sinkTask.initialize(EasyMock.capture(sinkTaskContext));
         PowerMock.expectLastCall();
+        sinkTask.start(TASK_PROPS);
+        PowerMock.expectLastCall();
+    }
+
+    private void expectPollInitialAssignment() throws Exception {
+        final List<TopicPartition> partitions = Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3);
+
+        sinkTask.open(partitions);
+        EasyMock.expectLastCall();
 
         EasyMock.expect(consumer.poll(EasyMock.anyLong())).andAnswer(new IAnswer<ConsumerRecords<byte[], byte[]>>() {
             @Override
             public ConsumerRecords<byte[], byte[]> answer() throws Throwable {
-                rebalanceListener.getValue().onPartitionsAssigned(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3));
+                rebalanceListener.getValue().onPartitionsAssigned(partitions);
                 return ConsumerRecords.empty();
             }
         });
@@ -450,15 +488,11 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
         EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET);
         EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET);
 
-        sinkTask.initialize(EasyMock.capture(sinkTaskContext));
-        PowerMock.expectLastCall();
-        sinkTask.start(TASK_PROPS);
-        PowerMock.expectLastCall();
+        sinkTask.put(Collections.<SinkRecord>emptyList());
+        EasyMock.expectLastCall();
     }
 
     private void expectStopTask(final long expectedMessages) throws Exception {
-        final long finalOffset = FIRST_OFFSET + expectedMessages - 1;
-
         sinkTask.stop();
         PowerMock.expectLastCall();
 
@@ -526,10 +560,10 @@ public class WorkerSinkTaskThreadedTest extends ThreadedTest {
     }
 
     private Capture<OffsetCommitCallback> expectOffsetFlush(final long expectedMessages,
-                                                              final RuntimeException flushError,
-                                                              final Exception consumerCommitError,
-                                                              final long consumerCommitDelayMs,
-                                                              final boolean invokeCallback)
+                                                            final RuntimeException flushError,
+                                                            final Exception consumerCommitError,
+                                                            final long consumerCommitDelayMs,
+                                                            final boolean invokeCallback)
             throws Exception {
         final long finalOffset = FIRST_OFFSET + expectedMessages;
 

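The threaded sink task test is now driven through WorkerTask's own lifecycle methods rather than a separate WorkerSinkTaskThread. Condensed, the sequence the tests above exercise looks roughly like the sketch below; the runSinkTask wrapper and the loop bound are illustrative, not part of the patch:

    // Illustrative wrapper mirroring the calls the tests make by hand.
    void runSinkTask(WorkerSinkTask task, Map<String, String> taskProps) {
        task.initialize(taskProps);     // store the task configuration; no work happens yet
        task.initializeAndStart();      // create the consumer, subscribe, call SinkTask.initialize()/start()
        task.iteration();               // first iteration polls, receives the initial assignment, calls open()
        for (int i = 0; i < 10; i++)    // subsequent iterations poll, deliver records via put(),
            task.iteration();           // and commit offsets when the commit interval elapses
        task.stop();                    // signal the task to shut down
        task.awaitStop(Long.MAX_VALUE); // wait for outstanding work to finish
        task.close();                   // release the consumer and other resources
    }
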
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java
index f16cbeb..3888534 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java
@@ -52,6 +52,8 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
@@ -77,6 +79,7 @@ public class WorkerSourceTaskTest extends ThreadedTest {
     private static final byte[] SERIALIZED_KEY = "converted-key".getBytes();
     private static final byte[] SERIALIZED_RECORD = "converted-record".getBytes();
 
+    private ExecutorService executor = Executors.newSingleThreadExecutor();
     private ConnectorTaskId taskId = new ConnectorTaskId("job", 0);
     private WorkerConfig config;
     @Mock private SourceTask sourceTask;
@@ -132,7 +135,8 @@ public class WorkerSourceTaskTest extends ThreadedTest {
 
         PowerMock.replayAll();
 
-        workerTask.start(EMPTY_TASK_PROPS);
+        workerTask.initialize(EMPTY_TASK_PROPS);
+        executor.submit(workerTask);
         awaitPolls(pollLatch);
         workerTask.stop();
         assertEquals(true, workerTask.awaitStop(1000));
@@ -160,7 +164,8 @@ public class WorkerSourceTaskTest extends ThreadedTest {
 
         PowerMock.replayAll();
 
-        workerTask.start(EMPTY_TASK_PROPS);
+        workerTask.initialize(EMPTY_TASK_PROPS);
+        executor.submit(workerTask);
         awaitPolls(pollLatch);
         assertTrue(workerTask.commitOffsets());
         workerTask.stop();
@@ -189,7 +194,8 @@ public class WorkerSourceTaskTest extends ThreadedTest {
 
         PowerMock.replayAll();
 
-        workerTask.start(EMPTY_TASK_PROPS);
+        workerTask.initialize(EMPTY_TASK_PROPS);
+        executor.submit(workerTask);
         awaitPolls(pollLatch);
         assertFalse(workerTask.commitOffsets());
         workerTask.stop();
@@ -271,7 +277,8 @@ public class WorkerSourceTaskTest extends ThreadedTest {
 
         PowerMock.replayAll();
 
-        workerTask.start(EMPTY_TASK_PROPS);
+        workerTask.initialize(EMPTY_TASK_PROPS);
+        executor.submit(workerTask);
         // Stopping immediately while the other thread has work to do should result in no polling, no offset
         // commits, and the work thread exiting immediately. The stop() method will be invoked in the background
         // thread since it cannot be invoked immediately in the thread trying to stop the task.

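Since initialize() no longer starts any threads, the source task tests run the task by submitting it to an ExecutorService, relying on WorkerTask being a Runnable. A rough sketch of that pattern follows; the runOnExecutor wrapper and the 5000 ms timeout are illustrative, not part of the patch:

    import java.util.Map;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;

    // Illustrative wrapper mirroring the pattern used by these tests.
    void runOnExecutor(WorkerSourceTask workerTask, Map<String, String> props) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        workerTask.initialize(props);   // configure the task; nothing runs yet
        executor.submit(workerTask);    // WorkerTask is a Runnable; run() drives the task loop
        // ... later, during shutdown:
        workerTask.stop();              // ask the task loop to exit
        workerTask.awaitStop(5000L);    // the 5000 ms timeout is illustrative
        executor.shutdown();
    }
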
http://git-wip-us.apache.org/repos/asf/kafka/blob/1d80f563/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java
----------------------------------------------------------------------
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java
index 335e0ce..f33347a 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java
@@ -354,14 +354,13 @@ public class WorkerTest extends ThreadedTest {
                 .andReturn(workerTask);
         Map<String, String> origProps = new HashMap<>();
         origProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName());
-        workerTask.start(origProps);
+        workerTask.initialize(origProps);
         EasyMock.expectLastCall();
 
         // Remove
         workerTask.stop();
         EasyMock.expectLastCall();
         EasyMock.expect(workerTask.awaitStop(EasyMock.anyLong())).andStubReturn(true);
-        workerTask.close();
         EasyMock.expectLastCall();
 
         offsetBackingStore.stop();
@@ -424,7 +423,7 @@ public class WorkerTest extends ThreadedTest {
                 .andReturn(workerTask);
         Map<String, String> origProps = new HashMap<>();
         origProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName());
-        workerTask.start(origProps);
+        workerTask.initialize(origProps);
         EasyMock.expectLastCall();
 
         // Remove on Worker.stop()
@@ -432,7 +431,6 @@ public class WorkerTest extends ThreadedTest {
         EasyMock.expectLastCall();
         EasyMock.expect(workerTask.awaitStop(EasyMock.anyLong())).andReturn(true);
         // Note that in this case we *do not* commit offsets since it's an unclean shutdown
-        workerTask.close();
         EasyMock.expectLastCall();
 
         offsetBackingStore.stop();