You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@samza.apache.org by ca...@apache.org on 2020/06/01 17:35:27 UTC

[samza] branch master updated: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop (#1366)

This is an automated email from the ASF dual-hosted git repository.

cameronlee pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new f134f0f  SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop (#1366)
f134f0f is described below

commit f134f0f467814c09e24d00c6b0fc47be7e47622c
Author: bkonold <bk...@users.noreply.github.com>
AuthorDate: Mon Jun 1 10:35:14 2020 -0700

    SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop (#1366)
    
    API Changes: None. RunLoop, TaskInstance, RunLoopTask are internal to Samza and are not user facing.
    Upgrade Instructions: None.
    Usage Instructions: None.
---
 .../java/org/apache/samza/container/RunLoop.java   |  30 +-
 .../org/apache/samza/container/RunLoopFactory.java |  16 +-
 .../org/apache/samza/container/RunLoopTask.java    | 146 ++++
 .../apache/samza/container/SamzaContainer.scala    |   3 +-
 .../org/apache/samza/container/TaskInstance.scala  |  38 +-
 .../org/apache/samza/container/TestRunLoop.java    | 859 +++++++--------------
 .../apache/samza/container/TestTaskInstance.scala  |  41 +-
 7 files changed, 489 insertions(+), 644 deletions(-)

diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
index 6974f35..2917c83 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
@@ -23,7 +23,6 @@ import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -51,7 +50,6 @@ import org.apache.samza.util.Throttleable;
 import org.apache.samza.util.ThrottlingScheduler;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.collection.JavaConverters;
 
 
 /**
@@ -88,7 +86,7 @@ public class RunLoop implements Runnable, Throttleable {
   private final boolean isAsyncCommitEnabled;
   private volatile boolean runLoopResumedSinceLastChecked;
 
-  public RunLoop(Map<TaskName, TaskInstance> taskInstances,
+  public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
       ExecutorService threadPool,
       SystemConsumers consumerMultiplexer,
       int maxConcurrency,
@@ -111,16 +109,16 @@ public class RunLoop implements Runnable, Throttleable {
     this.maxIdleMs = maxIdleMs;
     this.callbackTimer = (callbackTimeoutMs > 0) ? Executors.newSingleThreadScheduledExecutor() : null;
     this.callbackExecutor = new ThrottlingScheduler(maxThrottlingDelayMs);
-    this.coordinatorRequests = new CoordinatorRequests(taskInstances.keySet());
+    this.coordinatorRequests = new CoordinatorRequests(runLoopTasks.keySet());
     this.latch = new Object();
     this.workerTimer = Executors.newSingleThreadScheduledExecutor();
     this.clock = clock;
     Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
-    for (TaskInstance task : taskInstances.values()) {
+    for (RunLoopTask task : runLoopTasks.values()) {
       workers.put(task.taskName(), new AsyncTaskWorker(task));
     }
     // Partions and tasks assigned to the container will not change during the run loop life time
-    this.sspToTaskWorkerMapping = Collections.unmodifiableMap(getSspToAsyncTaskWorkerMap(taskInstances, workers));
+    this.sspToTaskWorkerMapping = Collections.unmodifiableMap(getSspToAsyncTaskWorkerMap(runLoopTasks, workers));
     this.taskWorkers = Collections.unmodifiableList(new ArrayList<>(workers.values()));
     this.isAsyncCommitEnabled = isAsyncCommitEnabled;
   }
@@ -129,10 +127,10 @@ public class RunLoop implements Runnable, Throttleable {
    * Returns mapping of the SystemStreamPartition to the AsyncTaskWorkers to efficiently route the envelopes
    */
   private static Map<SystemStreamPartition, List<AsyncTaskWorker>> getSspToAsyncTaskWorkerMap(
-      Map<TaskName, TaskInstance> taskInstances, Map<TaskName, AsyncTaskWorker> taskWorkers) {
+      Map<TaskName, RunLoopTask> runLoopTasks, Map<TaskName, AsyncTaskWorker> taskWorkers) {
     Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToWorkerMap = new HashMap<>();
-    for (TaskInstance task : taskInstances.values()) {
-      Set<SystemStreamPartition> ssps = JavaConverters.setAsJavaSetConverter(task.systemStreamPartitions()).asJava();
+    for (RunLoopTask task : runLoopTasks.values()) {
+      Set<SystemStreamPartition> ssps = task.systemStreamPartitions();
       for (SystemStreamPartition ssp : ssps) {
         sspToWorkerMap.putIfAbsent(ssp, new ArrayList<>());
         sspToWorkerMap.get(ssp).add(taskWorkers.get(task.taskName()));
@@ -361,15 +359,15 @@ public class RunLoop implements Runnable, Throttleable {
    * will run the task asynchronously. It runs window and commit in the provided thread pool.
    */
   private class AsyncTaskWorker implements TaskCallbackListener {
-    private final TaskInstance task;
+    private final RunLoopTask task;
     private final TaskCallbackManager callbackManager;
     private volatile AsyncTaskState state;
 
-    AsyncTaskWorker(TaskInstance task) {
+    AsyncTaskWorker(RunLoopTask task) {
       this.task = task;
       this.callbackManager = new TaskCallbackManager(this, callbackTimer, callbackTimeoutMs, maxConcurrency, clock);
       Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task);
-      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, task.intermediateStreams().nonEmpty());
+      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty());
     }
 
     private void init() {
@@ -409,9 +407,9 @@ public class RunLoop implements Runnable, Throttleable {
      * @param task
      * @return a Set of SSPs such that all SSPs are not at end of stream.
      */
-    private Set<SystemStreamPartition> getWorkingSSPSet(TaskInstance task) {
+    private Set<SystemStreamPartition> getWorkingSSPSet(RunLoopTask task) {
 
-      Set<SystemStreamPartition> allPartitions = new HashSet<>(JavaConverters.setAsJavaSetConverter(task.systemStreamPartitions()).asJava());
+      Set<SystemStreamPartition> allPartitions = task.systemStreamPartitions();
 
       // filter only those SSPs that are not at end of stream.
       Set<SystemStreamPartition> workingSSPSet = allPartitions.stream()
@@ -631,7 +629,9 @@ public class RunLoop implements Runnable, Throttleable {
               log.trace("Update offset for ssp {}, offset {}", envelope.getSystemStreamPartition(), envelope.getOffset());
 
               // update offset
-              task.offsetManager().update(task.taskName(), envelope.getSystemStreamPartition(), envelope.getOffset());
+              if (task.offsetManager() != null) {
+                task.offsetManager().update(task.taskName(), envelope.getSystemStreamPartition(), envelope.getOffset());
+              }
 
               // update coordinator
               coordinatorRequests.update(callbackToUpdate.getCoordinator());
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
index 0e0a01c..e2069f4 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
@@ -19,14 +19,12 @@
 
 package org.apache.samza.container;
 
-import org.apache.samza.SamzaException;
 import org.apache.samza.config.TaskConfig;
 import org.apache.samza.system.SystemConsumers;
 import org.apache.samza.util.HighResolutionClock;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.collection.JavaConverters;
-import scala.runtime.AbstractFunction1;
 import java.util.concurrent.ExecutorService;
 
 /**
@@ -36,7 +34,7 @@ import java.util.concurrent.ExecutorService;
 public class RunLoopFactory {
   private static final Logger log = LoggerFactory.getLogger(RunLoopFactory.class);
 
-  public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, TaskInstance> taskInstances,
+  public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, RunLoopTask> taskInstances,
       SystemConsumers consumerMultiplexer,
       ExecutorService threadPool,
       long maxThrottlingDelayMs,
@@ -52,18 +50,6 @@ public class RunLoopFactory {
 
     log.info("Got commit milliseconds: {}.", taskCommitMs);
 
-    int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance, Object>() {
-      @Override
-      public Boolean apply(TaskInstance t) {
-        return t.isAsyncTask();
-      }
-    });
-
-    // asyncTaskCount should be either 0 or the number of all taskInstances
-    if (asyncTaskCount > 0 && asyncTaskCount < taskInstances.size()) {
-      throw new SamzaException("Mixing StreamTask and AsyncStreamTask is not supported");
-    }
-
     int taskMaxConcurrency = taskConfig.getMaxConcurrency();
     log.info("Got taskMaxConcurrency: {}.", taskMaxConcurrency);
 
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
new file mode 100644
index 0000000..551da88
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Set;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+
+
+/**
+ * The interface required for a task's execution to be managed within {@link RunLoop}.
+ *
+ * Some notes on thread safety and exclusivity between methods:
+ *
+ * TODO SAMZA-2531: isAsyncCommitEnabled is either an incomplete feature or misnamed
+ * RunLoop will ensure exclusivity between {@link #window}, {@link #commit}, {@link #scheduler}, and
+ * {@link #endOfStream}.
+ *
+ * There is an exception for {@link #process}, which may execute concurrently with {@link #commit} IF the encapsulating
+ * {@link RunLoop} has its isAsyncCommitEnabled set to true. In this case, the implementer of this interface should
+ * take care to ensure that any objects shared between commit and process are thread safe.
+ *
+ * Be aware that {@link #commit}, {@link #window} and {@link #scheduler} can be run in their own thread pool outside
+ * the main RunLoop thread (which executes {@link #process}) so may run concurrently between tasks. For example, one
+ * task may be executing a commit while another is executing window. For this reason, implementers of this interface
+ * must ensure that objects shared between instances of RunLoopTask are thread safe.
+ */
+public interface RunLoopTask {
+
+  /**
+   * The {@link TaskName} associated with this RunLoopTask.
+   *
+   * @return taskName
+   */
+  TaskName taskName();
+
+  /**
+   * Process an incoming message envelope.
+   *
+   * @param envelope The envelope to be processed
+   * @param coordinator Manages execution of tasks
+   * @param callbackFactory Creates a callback to be used to indicate completion of or failure to process the
+   *                        envelope. {@link TaskCallbackFactory#createCallback()} should be called before processing
+   *                        begins.
+   */
+  void process(IncomingMessageEnvelope envelope, ReadableCoordinator coordinator, TaskCallbackFactory callbackFactory);
+
+  /**
+   * Performs a window for this task. If {@link #isWindowableTask()} is true, this method will be invoked periodically
+   * by {@link RunLoop} according to its windowMs.
+   *
+   * This method can be used to perform aggregations within a task.
+   *
+   * @param coordinator Manages execution of tasks
+   */
+  void window(ReadableCoordinator coordinator);
+
+  /**
+   * Used in conjunction with {@link #epochTimeScheduler()} to execute scheduled callbacks. See documentation of
+   * {@link EpochTimeScheduler} for more information.
+   *
+   * @param coordinator Manages execution of tasks.
+   */
+  void scheduler(ReadableCoordinator coordinator);
+
+  /**
+   * Performs a commit for this task. Operations for persisting checkpoint-related information for this task should
+   * be done here.
+   */
+  void commit();
+
+  /**
+   * Called when all {@link SystemStreamPartition} processed by a task have reached end of stream. This is called only
+   * once per task. {@link RunLoop} will issue a shutdown request to the coordinator immediately following the
+   * invocation of this method.
+   *
+   * @param coordinator manages execution of tasks.
+   */
+  void endOfStream(ReadableCoordinator coordinator);
+
+  /**
+   * Indicates whether {@link #window} should be invoked on this task. If true, {@link RunLoop}
+   * will schedule window to execute periodically according to its windowMs.
+   *
+   * @return whether the task should perform window
+   */
+  boolean isWindowableTask();
+
+  /**
+   * The set of intermediate streams used by this task. Intermediate streams may be used to facilitate task processing
+   * before terminal output is produced. {@link RunLoop} uses this information to determine when the task has reached
+   * end of stream.
+   *
+   * @return the set of intermediate stream ids; empty if the task uses no intermediate streams
+   */
+  Set<String> intermediateStreams();
+
+  /**
+   * The set of {@link SystemStreamPartition} this task consumes from.
+   *
+   * @return the set of SSPs
+   */
+  Set<SystemStreamPartition> systemStreamPartitions();
+
+  /**
+   * An {@link OffsetManager}, if any, to use to track offsets for each input SSP. Offsets will be updated after successful
+   * completion of an envelope from an SSP.
+   *
+   * @return the offset manager, or null if this task does not have one
+   */
+  OffsetManager offsetManager();
+
+  /**
+   * The metrics instance {@link RunLoop} will use to emit metrics related to execution of this task.
+   *
+   * @return metrics instance for this task
+   */
+  TaskInstanceMetrics metrics();
+
+  /**
+   * An {@link EpochTimeScheduler}, if any, used by the task to handle timer based callbacks.
+   *
+   * @return the scheduler, or null if this task does not have one
+   */
+  EpochTimeScheduler epochTimeScheduler();
+}
\ No newline at end of file
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 6fab351..83ce3a1 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -52,6 +52,7 @@ import org.apache.samza.util.{Util, _}
 import org.apache.samza.SamzaException
 import org.apache.samza.clustermanager.StandbyTaskUtil
 
+import scala.collection.JavaConversions
 import scala.collection.JavaConverters._
 
 object SamzaContainer extends Logging {
@@ -587,7 +588,7 @@ object SamzaContainer extends Logging {
           offsetManager = offsetManager,
           storageManager = storageManager,
           tableManager = tableManager,
-          systemStreamPartitions = taskSSPs -- taskSideInputSSPs,
+          systemStreamPartitions = JavaConversions.setAsJavaSet(taskSSPs -- taskSideInputSSPs),
           exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics.get(taskName).get, taskConfig),
           jobModel = jobModel,
           streamMetadataCache = streamMetadataCache,
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 37aaeff..2ebe465 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -20,7 +20,7 @@
 package org.apache.samza.container
 
 
-import java.util.{Objects, Optional}
+import java.util.{Collections, Objects, Optional}
 import java.util.concurrent.ScheduledExecutorService
 
 import org.apache.samza.SamzaException
@@ -38,7 +38,7 @@ import org.apache.samza.util.{Logging, ScalaJavaUtil}
 
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
-import scala.collection.Map
+import scala.collection.{JavaConverters, Map}
 
 class TaskInstance(
   val task: Any,
@@ -47,10 +47,10 @@ class TaskInstance(
   systemAdmins: SystemAdmins,
   consumerMultiplexer: SystemConsumers,
   collector: TaskInstanceCollector,
-  val offsetManager: OffsetManager = new OffsetManager,
+  override val offsetManager: OffsetManager = new OffsetManager,
   storageManager: TaskStorageManager = null,
   tableManager: TableManager = null,
-  val systemStreamPartitions: Set[SystemStreamPartition] = Set(),
+  val systemStreamPartitions: java.util.Set[SystemStreamPartition] = Collections.emptySet(),
   val exceptionHandler: TaskInstanceExceptionHandler = new TaskInstanceExceptionHandler,
   jobModel: JobModel = null,
   streamMetadataCache: StreamMetadataCache = null,
@@ -60,16 +60,16 @@ class TaskInstance(
   containerContext: ContainerContext,
   applicationContainerContextOption: Option[ApplicationContainerContext],
   applicationTaskContextFactoryOption: Option[ApplicationTaskContextFactory[ApplicationTaskContext]],
-  externalContextOption: Option[ExternalContext]) extends Logging {
+  externalContextOption: Option[ExternalContext]) extends Logging with RunLoopTask {
 
   val taskName: TaskName = taskModel.getTaskName
   val isInitableTask = task.isInstanceOf[InitableTask]
-  val isWindowableTask = task.isInstanceOf[WindowableTask]
   val isEndOfStreamListenerTask = task.isInstanceOf[EndOfStreamListenerTask]
   val isClosableTask = task.isInstanceOf[ClosableTask]
-  val isAsyncTask = task.isInstanceOf[AsyncStreamTask]
 
-  val epochTimeScheduler: EpochTimeScheduler = EpochTimeScheduler.create(timerExecutor)
+  override val isWindowableTask = task.isInstanceOf[WindowableTask]
+
+  override val epochTimeScheduler: EpochTimeScheduler = EpochTimeScheduler.create(timerExecutor)
 
   private val kvStoreSupplier = ScalaJavaUtil.toJavaFunction(
     (storeName: String) => {
@@ -99,7 +99,7 @@ class TaskInstance(
   private val config: Config = jobContext.getConfig
 
   val streamConfig: StreamConfig = new StreamConfig(config)
-  val intermediateStreams: Set[String] = streamConfig.getStreamIds.filter(streamConfig.getIsIntermediateStream).toSet
+  override val intermediateStreams: java.util.Set[String] = JavaConverters.setAsJavaSetConverter(streamConfig.getStreamIds.filter(streamConfig.getIsIntermediateStream)).asJava
 
   val streamsToDeleteCommittedMessages: Set[String] = streamConfig.getStreamIds.filter(streamConfig.getDeleteCommittedMessages).map(streamConfig.getPhysicalName).toSet
 
@@ -165,7 +165,7 @@ class TaskInstance(
   }
 
   def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator,
-    callbackFactory: TaskCallbackFactory = null) {
+    callbackFactory: TaskCallbackFactory) {
     metrics.processes.inc
 
     val incomingMessageSsp = envelope.getSystemStreamPartition
@@ -181,22 +181,10 @@ class TaskInstance(
       trace("Processing incoming message envelope for taskName and SSP: %s, %s"
         format (taskName, incomingMessageSsp))
 
-      if (isAsyncTask) {
-        exceptionHandler.maybeHandle {
-          val callback = callbackFactory.createCallback()
-          task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback)
-        }
-      } else {
-        exceptionHandler.maybeHandle {
-          task.asInstanceOf[StreamTask].process(envelope, collector, coordinator)
-        }
-
-        trace("Updating offset map for taskName, SSP and offset: %s, %s, %s"
-          format(taskName, incomingMessageSsp, envelope.getOffset))
-
-        offsetManager.update(taskName, incomingMessageSsp, envelope.getOffset)
+      exceptionHandler.maybeHandle {
+        val callback = callbackFactory.createCallback()
+        task.asInstanceOf[AsyncStreamTask].processAsync(envelope, collector, coordinator, callback)
       }
-
     }
   }
 
diff --git a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
index 9d65a57..1ec718e 100644
--- a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
@@ -20,10 +20,8 @@
 package org.apache.samza.container;
 
 import com.google.common.collect.ImmutableMap;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
@@ -31,44 +29,23 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.samza.Partition;
 import org.apache.samza.SamzaException;
-import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.checkpoint.OffsetManager;
-import org.apache.samza.context.ContainerContext;
-import org.apache.samza.context.JobContext;
-import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.system.IncomingMessageEnvelope;
-import org.apache.samza.system.SystemAdmin;
-import org.apache.samza.system.SystemAdmins;
-import org.apache.samza.system.SystemConsumer;
 import org.apache.samza.system.SystemConsumers;
 import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.system.TestSystemConsumers;
-import org.apache.samza.task.AsyncStreamTask;
-import org.apache.samza.task.EndOfStreamListenerTask;
-import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.ReadableCoordinator;
 import org.apache.samza.task.TaskCallback;
-import org.apache.samza.task.TaskCallbackImpl;
+import org.apache.samza.task.TaskCallbackFactory;
 import org.apache.samza.task.TaskCoordinator;
-import org.apache.samza.task.TaskInstanceCollector;
-import org.apache.samza.task.WindowableTask;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.Timeout;
-import org.mockito.Mockito;
-import scala.Option;
-import scala.collection.JavaConverters;
+import org.mockito.InOrder;
 
 import static org.junit.Assert.assertEquals;
-import static org.mockito.Mockito.any;
-import static org.mockito.Mockito.anyLong;
-import static org.mockito.Mockito.anyObject;
-import static org.mockito.Mockito.atLeastOnce;
-import static org.mockito.Mockito.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
+
 
 public class TestRunLoop {
   // Immutable objects shared by all test methods.
@@ -85,725 +62,473 @@ public class TestRunLoop {
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope11 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
-  public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope11), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
+        assertEquals(0, task0.metrics().asyncCallbackCompleted().getCount());
+
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = getMockRunLoopTask(taskName0, ssp0);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope11), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope11)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope11), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    task0.callbackHandler = buildOutofOrderCallback(task0);
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    runLoop.run();
-
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-  }
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-  @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        assertEquals(2, task0.metrics().asyncCallbackCompleted().getCount());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", "stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
+  public void testEndOfStreamCommitBehavior() {
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
 
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
 
-    when(offsetManager.getLastProcessedOffset(taskName1, ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), anyObject())).thenReturn(Option.empty());
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
+    tasks.put(taskName0, task0);
 
     int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", "key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", "key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
     int maxMessagesInFlight = 2;
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();
   }
 
   @Test(expected = SamzaException.class)
   public void testExceptionIsPropagated() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(false, false, false, null);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        callbackFactory.createCallback().failure(new Exception("Intentional failure"));
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = ImmutableMap.of(taskName0, t0);
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
 
-    int maxMessagesInFlight = 2;
+    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
         () -> 0L, false);
 
     when(consumerMultiplexer.choose(false))
-        .thenReturn(envelope0)
+        .thenReturn(envelope00)
         .thenReturn(ssp0EndOfStream)
         .thenReturn(null);
 
     runLoop.run();
   }
+
+  private RunLoopTask getMockRunLoopTask(TaskName taskName, SystemStreamPartition ssp0) {
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(new TaskInstanceMetrics("test", new MetricsRegistryMap()));
+    when(task0.taskName()).thenReturn(taskName);
+    return task0;
+  }
 }
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 90f1b58..4cab185 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -21,6 +21,7 @@ package org.apache.samza.container
 
 import java.util.Collections
 
+import com.google.common.collect.ImmutableSet
 import org.apache.samza.{Partition, SamzaException}
 import org.apache.samza.checkpoint.{Checkpoint, CheckpointedChangelogOffset, OffsetManager}
 import org.apache.samza.config.MapConfig
@@ -48,7 +49,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
   private val TASK_NAME = new TaskName("taskName")
   private val SYSTEM_STREAM_PARTITION =
     new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-stream"), new Partition(0))
-  private val SYSTEM_STREAM_PARTITIONS = Set(SYSTEM_STREAM_PARTITION)
+  private val SYSTEM_STREAM_PARTITIONS = ImmutableSet.of(SYSTEM_STREAM_PARTITION)
 
   @Mock
   private var task: AllTask = null
@@ -110,9 +111,12 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
     when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("0"))
     val envelope = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "0", null, null)
     val coordinator = mock[ReadableCoordinator]
-    this.taskInstance.process(envelope, coordinator)
+    val callbackFactory = mock[TaskCallbackFactory]
+    val callback = mock[TaskCallback]
+    when(callbackFactory.createCallback()).thenReturn(callback)
+    this.taskInstance.process(envelope, coordinator, callbackFactory)
     assertEquals(1, this.taskInstanceExceptionHandler.numTimesCalled)
-    verify(this.task).process(envelope, this.collector, coordinator)
+    verify(this.task).processAsync(envelope, this.collector, coordinator, callback)
     verify(processesCounter).inc()
     verify(messagesActuallyProcessedCounter).inc()
   }
@@ -152,16 +156,6 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
     verify(this.task).close()
   }
 
-  @Test
-  def testOffsetsAreUpdatedOnProcess() {
-    when(this.metrics.processes).thenReturn(mock[Counter])
-    when(this.metrics.messagesActuallyProcessed).thenReturn(mock[Counter])
-    when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("2"))
-    this.taskInstance.process(new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "4", null, null),
-      mock[ReadableCoordinator])
-    verify(this.offsetManager).update(TASK_NAME, SYSTEM_STREAM_PARTITION, "4")
-  }
-
   /**
    * Tests that the init() method of task can override the existing offset assignment.
    * This helps verify wiring for the task context (i.e. offset manager).
@@ -199,12 +193,17 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
     val newEnvelope0 = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "5", null, null)
     val newEnvelope1 = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "7", null, null)
 
-    this.taskInstance.process(oldEnvelope, mock[ReadableCoordinator])
-    this.taskInstance.process(newEnvelope0, mock[ReadableCoordinator])
-    this.taskInstance.process(newEnvelope1, mock[ReadableCoordinator])
-    verify(this.task).process(Matchers.eq(newEnvelope0), Matchers.eq(this.collector), any())
-    verify(this.task).process(Matchers.eq(newEnvelope1), Matchers.eq(this.collector), any())
-    verify(this.task, never()).process(Matchers.eq(oldEnvelope), any(), any())
+    val mockCoordinator = mock[ReadableCoordinator]
+    val mockCallback = mock[TaskCallback]
+    val mockCallbackFactory = mock[TaskCallbackFactory]
+    when(mockCallbackFactory.createCallback()).thenReturn(mockCallback)
+
+    this.taskInstance.process(oldEnvelope, mockCoordinator, mockCallbackFactory)
+    this.taskInstance.process(newEnvelope0, mockCoordinator, mockCallbackFactory)
+    this.taskInstance.process(newEnvelope1, mockCoordinator, mockCallbackFactory)
+    verify(this.task).processAsync(Matchers.eq(newEnvelope0), Matchers.eq(this.collector), Matchers.eq(mockCoordinator), Matchers.eq(mockCallback))
+    verify(this.task).processAsync(Matchers.eq(newEnvelope1), Matchers.eq(this.collector), Matchers.eq(mockCoordinator), Matchers.eq(mockCallback))
+    verify(this.task, never()).processAsync(Matchers.eq(oldEnvelope), any(), any(), any())
     verify(processesCounter, times(3)).inc()
     verify(messagesActuallyProcessedCounter, times(2)).inc()
   }
@@ -403,7 +402,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
       offsetManager = offsetManagerMock,
       storageManager = this.taskStorageManager,
       tableManager = this.taskTableManager,
-      systemStreamPartitions = Set(ssp),
+      systemStreamPartitions = ImmutableSet.of(ssp),
       exceptionHandler = this.taskInstanceExceptionHandler,
       streamMetadataCache = cacheMock,
       inputStreamMetadata = Map.empty ++ inputStreamMetadata,
@@ -441,7 +440,7 @@ class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
   /**
     * Task type which has all task traits, which can be mocked.
     */
-  trait AllTask extends StreamTask with InitableTask with ClosableTask with WindowableTask {}
+  trait AllTask extends AsyncStreamTask with InitableTask with ClosableTask with WindowableTask {}
 
   /**
     * Mock version of [TaskInstanceExceptionHandler] which just does a passthrough execution and keeps track of the