Posted to commits@samza.apache.org by GitBox <gi...@apache.org> on 2020/05/22 17:52:00 UTC

[GitHub] [samza] bkonold opened a new pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

bkonold opened a new pull request #1366:
URL: https://github.com/apache/samza/pull/1366


   **Issues**: RunLoop is hard to reuse for non-TaskInstance use cases because it is tightly coupled to TaskInstance.
    
   **Changes**: This is an initial pass at extracting an interface from TaskInstance. I've introduced RunLoopTask, which represents the set of methods that RunLoop uses from TaskInstance.
   
   There is no new functionality added in this patch.
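To illustrate the decoupling in isolation, here is a minimal plain-Java sketch with no Samza dependencies. `LoopTask`, `Envelope`, `SimpleRunLoop` and `UppercaseTask` are hypothetical names chosen for illustration, not Samza classes; the real RunLoopTask carries more methods (the updated tests stub calls such as `process`, `window`, `commit`, `endOfStream`, `metrics` and `offsetManager`). The point is only that the loop depends on an interface, so any implementation can be driven by it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical, pared-down analogue of RunLoopTask: only the operations the loop calls.
interface LoopTask {
  String taskName();

  // Handle a single message addressed to this task.
  void process(String payload);

  // Invoked once the loop has drained its input.
  void endOfStream();
}

// Hypothetical message type: the task a payload is routed to, plus the payload itself.
final class Envelope {
  final String taskName;
  final String payload;

  Envelope(String taskName, String payload) {
    this.taskName = taskName;
    this.payload = payload;
  }
}

// The loop knows only the LoopTask interface, so it can drive any implementation,
// not just a TaskInstance-like class.
class SimpleRunLoop {
  private final Map<String, LoopTask> tasks;

  SimpleRunLoop(Map<String, LoopTask> tasks) {
    this.tasks = tasks;
  }

  void run(List<Envelope> input) {
    // Dispatch messages in arrival order; envelopes for unknown task names are skipped.
    for (Envelope envelope : input) {
      LoopTask task = tasks.get(envelope.taskName);
      if (task != null) {
        task.process(envelope.payload);
      }
    }
    // Signal end of stream to every task once the input is exhausted.
    for (LoopTask task : tasks.values()) {
      task.endOfStream();
    }
  }
}

// A non-TaskInstance implementation that records upper-cased payloads.
class UppercaseTask implements LoopTask {
  private final String name;
  final List<String> output = new ArrayList<>();
  boolean ended = false;

  UppercaseTask(String name) {
    this.name = name;
  }

  @Override
  public String taskName() {
    return name;
  }

  @Override
  public void process(String payload) {
    output.add(payload.toUpperCase());
  }

  @Override
  public void endOfStream() {
    ended = true;
  }
}
```

Because `SimpleRunLoop` only sees `LoopTask`, a test can hand it a lightweight fake (much as the updated tests hand RunLoop Mockito mocks of RunLoopTask) instead of constructing a full TaskInstance with its long constructor argument list.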
    
   **Tests**: Existing unit tests have been adapted to work on RunLoopTask rather than TaskInstance.
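The adapted tests also pin down RunLoop behaviour that any RunLoopTask implementation relies on. For example, testProcessCallbacksCompletedOutOfOrder in the TestRunLoop diff lets the callback for envelope01 complete before the callback for envelope00, yet still expects `offsetManager.update` to be called for envelope00 first. The sketch below shows one way to get that ordering, assuming a simple sequence-number scheme; `InOrderCompletionTracker` is a hypothetical name and this is not Samza's actual implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

// Hypothetical helper, not Samza's implementation: hands out sequence numbers as messages are
// dispatched and releases offsets strictly in dispatch order, even if callbacks finish out of order.
class InOrderCompletionTracker {
  private long nextSeqToDispatch = 0;
  private long nextSeqToRelease = 0;
  // Completed messages whose offsets cannot be released yet, keyed by sequence number.
  private final TreeMap<Long, String> completed = new TreeMap<>();

  // Call when a message is handed to the task; returns its sequence number.
  synchronized long dispatch() {
    return nextSeqToDispatch++;
  }

  // Call from the task's completion callback. Returns the offsets that are now safe to record,
  // oldest first; the list is empty while an earlier message is still pending.
  synchronized List<String> complete(long seq, String offset) {
    completed.put(seq, offset);
    List<String> releasable = new ArrayList<>();
    while (!completed.isEmpty() && completed.firstKey() == nextSeqToRelease) {
      releasable.add(completed.pollFirstEntry().getValue());
      nextSeqToRelease++;
    }
    return releasable;
  }
}
```

In this sketch, releasing only the contiguous prefix of completed messages means a recorded offset never moves past a message whose processing has not finished.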
   
   **API Changes**: None. RunLoop & TaskInstance are not public.
    
   **Upgrade Instructions**: None.
    
   **Usage Instructions**: None.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [samza] cameronlee314 commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432176313



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());

Review comment:
       Whichever you prefer

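The `firstMessageBarrier` trick in `testEndOfStreamWaitsForInFlightMessages` (and `testProcessCallbacksCompletedOutOfOrder`) is hard to follow from the mocks alone. Below is a minimal, Samza-free sketch of the same pattern that uses only `java.util.concurrent`. Message `m0` hands its completion to a worker thread that blocks on a latch. Message `m1` completes inline and then opens the latch, so the callbacks finish out of order. "End of stream" runs only after every in-flight callback has completed. The class, method and field names are hypothetical, and this is not how `RunLoop` itself is implemented.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical sketch: force two async "messages" to complete out of order using a latch. */
public final class OutOfOrderCompletionSketch {

  /** What one run observed. */
  public static final class Result {
    public final List<String> completionOrder;
    public final int peakInFlight;
    public final int inFlightAtEndOfStream;

    Result(List<String> completionOrder, int peakInFlight, int inFlightAtEndOfStream) {
      this.completionOrder = completionOrder;
      this.peakInFlight = peakInFlight;
      this.inFlightAtEndOfStream = inFlightAtEndOfStream;
    }
  }

  public static Result run() {
    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
    CountDownLatch allCallbacksDone = new CountDownLatch(2);
    AtomicInteger inFlight = new AtomicInteger(0);
    List<String> completionOrder = Collections.synchronizedList(new ArrayList<>());
    try {
      // m0: start it, but let the worker complete it only once the barrier opens.
      inFlight.incrementAndGet();
      taskExecutor.submit(() -> {
        firstMessageBarrier.await();
        completionOrder.add("m0");
        inFlight.decrementAndGet();
        allCallbacksDone.countDown();
        return null;
      });

      // m1: started while m0 is still blocked, so two messages are in flight here.
      int peakInFlight = inFlight.incrementAndGet();
      completionOrder.add("m1");
      inFlight.decrementAndGet();
      allCallbacksDone.countDown();
      firstMessageBarrier.countDown(); // release m0 only after m1 has completed

      // "End of stream": proceed only once every in-flight callback has completed.
      if (!allCallbacksDone.await(5, TimeUnit.SECONDS)) {
        throw new IllegalStateException("callbacks did not complete in time");
      }
      return new Result(new ArrayList<>(completionOrder), peakInFlight, inFlight.get());
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    } finally {
      taskExecutor.shutdownNow();
    }
  }

  public static void main(String[] args) {
    Result r = run();
    System.out.println(r.completionOrder + " peak=" + r.peakInFlight + " atEnd=" + r.inFlightAtEndOfStream);
  }
}
```

Running `main` prints `[m1, m0] peak=2 atEnd=0`. The latch makes the ordering deterministic without any `Thread.sleep`. The test gets the same effect by counting the barrier down from a stubbed `choose(false)` answer instead of releasing it directly.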






[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432165430



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");

Review comment:
       Yeah, I should probably change this. I was using the second digit mainly to indicate that it is used in tests as the "0th" message for ssp1, but having it not match the offset is confusing.

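One way to act on this is to derive each test envelope from a (partition, index) pair and use the index as the offset, so that the second digit of a field name like `envelope10` always matches the envelope's offset. This sketch assumes the offsets in these tests are arbitrary strings. `Envelope`, `envelope(...)` and `fieldName(...)` are hypothetical stand-ins and are not Samza's `IncomingMessageEnvelope` API. The key and value follow the existing fields (`key0`/`value0` for ssp0, `key1`/`value1` for ssp1).

```java
/**
 * Hypothetical sketch: build test envelopes from (partition, index) so that the field name
 * envelope<partition><index> always agrees with the envelope's offset.
 */
public final class EnvelopeNamingSketch {

  /** Minimal stand-in for IncomingMessageEnvelope; not Samza's class. */
  public static final class Envelope {
    public final int partition;
    public final String offset;
    public final String key;
    public final String value;

    Envelope(int partition, String offset, String key, String value) {
      this.partition = partition;
      this.offset = offset;
      this.key = key;
      this.value = value;
    }
  }

  /** The index-th message of a partition; the index doubles as the offset. */
  public static Envelope envelope(int partition, int index) {
    return new Envelope(partition, Integer.toString(index), "key" + partition, "value" + partition);
  }

  /** Field name for this envelope; only unambiguous for single-digit partitions and indices. */
  public static String fieldName(Envelope e) {
    return "envelope" + e.partition + e.offset;
  }

  public static void main(String[] args) {
    Envelope envelope00 = envelope(0, 0);
    Envelope envelope01 = envelope(0, 1);
    Envelope envelope10 = envelope(1, 0); // offset "0", unlike the "1" used in the diff
    for (Envelope e : new Envelope[] {envelope00, envelope01, envelope10}) {
      System.out.println(fieldName(e) + " -> offset " + e.offset);
    }
  }
}
```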






[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432176157



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", "stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
-
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    when(offsetManager.getLastProcessedOffset(taskName1, ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), anyObject())).thenReturn(Option.empty());
-
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
-
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
-
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", "key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", "key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
-  }
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-  @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
-    int maxMessagesInFlight = 2;
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
-
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();

Review comment:
      `RunLoopTask.process` and `RunLoopTask.commit` are both invoked from the same thread, so they should complete in this order.
   
   Are you referring to callback completion order?
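   The reviewer is distinguishing two orders: the order in which the loop *invokes* `process` and `commit`, and the order in which asynchronous *callbacks* complete. Below is a minimal, stdlib-only sketch of that distinction. It is not Samza code: `OrderDemo`, the message names, and the two-thread pool are hypothetical stand-ins. The loop thread invokes `process` twice and then `commit`, so invocation order is fixed. The callbacks run on a separate pool and are forced to finish in the opposite order.

```java
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical stand-in for a run loop (not Samza's RunLoop): process/commit
// are invoked on the calling "loop" thread; callbacks complete on a separate pool.
class OrderDemo {
  final List<String> invocations = new CopyOnWriteArrayList<>();
  final List<String> completions = new CopyOnWriteArrayList<>();

  void run() {
    ExecutorService callbackPool = Executors.newFixedThreadPool(2);
    CountDownLatch secondCompleted = new CountDownLatch(1);
    CountDownLatch bothCompleted = new CountDownLatch(2);
    try {
      // Message m0: its callback is held back until m1's callback has completed.
      invocations.add("process(m0)");
      callbackPool.submit(() -> {
        secondCompleted.await();
        completions.add("m0");
        bothCompleted.countDown();
        return null;
      });

      // Message m1: its callback completes right away.
      invocations.add("process(m1)");
      callbackPool.submit(() -> {
        completions.add("m1");
        secondCompleted.countDown();
        bothCompleted.countDown();
        return null;
      });

      // commit() is invoked on the loop thread after both process() calls,
      // independent of the order in which the callbacks finish.
      invocations.add("commit()");

      bothCompleted.await();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    } finally {
      callbackPool.shutdown();
    }
  }
}
```

   An `InOrder` verification like the one in the diff only constrains the `invocations` order. Completion order is a separate property. The diff's tests control it explicitly with `CountDownLatch` barriers rather than asserting on it after the fact.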







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432043551



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -390,23 +401,21 @@ public void testCommitSingleTask() throws Exception {
     when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
+    TestTask task0 = spy(createTestTask(true, true, false, task0ProcessedMessagesLatch, 0, taskName0, ssp0, offsetManager));
     task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    TestTask task1 = spy(createTestTask(true, false, true, task1ProcessedMessagesLatch, 0, taskName1, ssp1, offsetManager));
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
     when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
         .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
+//            task0ProcessedMessagesLatch.await();

Review comment:
      Resolving this, since I've rewritten all the tests to use mocks instead and tried to clean up cruft I found along the way.







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r431492372



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -236,7 +236,7 @@ public void testProcessMultipleTasks() throws Exception {
     TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
     TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();

Review comment:
      Removed `TaskInstance` and made `TestTask` implement the `RunLoopTask` interface instead.
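   Here is a rough sketch of why implementing a narrow interface directly simplifies tests. A test double only needs the handful of methods the loop actually calls. It does not need a concrete task built from a long argument list of nulls and mocks, as the old `createTaskInstance` helper was. `LoopTask`, `RecordingTask`, and `TinyLoop.drive` are hypothetical and far smaller than Samza's `RunLoopTask`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified set of methods a run loop needs from a task.
interface LoopTask {
  void process(String message);
  void endOfStream();
  void commit();
}

// Test double: records every call so a test can assert on interactions
// directly instead of waiting on latches inside a full task implementation.
class RecordingTask implements LoopTask {
  final List<String> calls = new ArrayList<>();

  @Override public void process(String message) { calls.add("process:" + message); }
  @Override public void endOfStream() { calls.add("endOfStream"); }
  @Override public void commit() { calls.add("commit"); }
}

class TinyLoop {
  // Hypothetical driver: processes messages in order; a null entry stands in
  // for an end-of-stream marker, after which the task is told to commit.
  static void drive(LoopTask task, List<String> messages) {
    for (String m : messages) {
      if (m == null) {
        task.endOfStream();
        task.commit();
        return;
      }
      task.process(m);
    }
  }
}
```

   The real `RunLoopTask` methods take envelopes, coordinators, and callback factories. This sketch drops all of that to show only the shape of the test double.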







[GitHub] [samza] bkonold edited a comment on pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold edited a comment on pull request #1366:
URL: https://github.com/apache/samza/pull/1366#issuecomment-632934829


   > What would be an example use case for `RunLoopTask`, other than `TaskInstance`?
   
   @cameronlee314 
   This will be used as an entry point for side input processing to leverage RunLoop.
   
   E.g.
   https://github.com/apache/samza/pull/1343/files#diff-81b48c3d365639da045f19f1f46c138e
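   The following is a hedged illustration of the kind of reuse this enables. It is not the code in the linked PR. Once the loop depends only on an interface, a non-task component can be plugged into the same loop. The example component is a hypothetical side-input processor that restores key/value updates into a local store. `Processor`, `SideInputProcessor`, `GenericLoop`, the null-means-delete convention, and the `commitEvery` parameter are all illustrative assumptions.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical minimal contract a generic run loop depends on.
interface Processor {
  void process(String key, String value);
  void commit();
}

// Hypothetical side-input consumer: applies updates to an in-memory "store"
// and snapshots the applied state on commit (a stand-in for flushing a real store).
class SideInputProcessor implements Processor {
  final Map<String, String> store = new HashMap<>();
  final Map<String, String> committed = new HashMap<>();

  @Override public void process(String key, String value) {
    if (value == null) {
      store.remove(key); // assumption: a null value means delete
    } else {
      store.put(key, value);
    }
  }

  @Override public void commit() {
    committed.clear();
    committed.putAll(store);
  }
}

class GenericLoop {
  // Drives any Processor: processes each {key, value} record, commits every
  // commitEvery records, and commits once more at the end.
  static void run(Processor p, List<String[]> records, int commitEvery) {
    int sinceCommit = 0;
    for (String[] r : records) {
      p.process(r[0], r[1]);
      if (++sinceCommit == commitEvery) {
        p.commit();
        sinceCommit = 0;
      }
    }
    p.commit();
  }
}
```

   The loop has no knowledge of what a "task" is. Any implementation of the interface gets the same processing and commit cadence.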





[GitHub] [samza] mynameborat commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
mynameborat commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r429451120



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
      I have seen multiple patterns for grouping methods: by related functionality, by access modifier, or just plain alphabetical. I wanted to suggest grouping by access modifier, but pushing all the default methods down addresses that.
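   This is a sketch of the layout under discussion, an interface with its abstract methods grouped first and its `default` methods pushed to the bottom. The method names are assumptions chosen for illustration, not the contents of `RunLoopTask.java`. A default returning an empty collection is one plausible use of the `Collections` import in the header above, but that is a guess.

```java
import java.util.Collections;
import java.util.Set;

// Hypothetical interface illustrating the layout discussed in the review:
// abstract methods first, default methods grouped at the bottom.
interface ExampleLoopTask {
  // Abstract methods: every implementation must provide these.
  String taskName();
  void process(String message);
  void commit();

  // Default methods: optional behavior with fallbacks implementations may override.
  default Set<String> intermediateStreams() {
    return Collections.emptySet();
  }

  default boolean isWindowableTask() {
    return false;
  }
}

// A minimal implementation only has to supply the abstract methods.
class MinimalTask implements ExampleLoopTask {
  int processed = 0;

  @Override public String taskName() { return "task-0"; }
  @Override public void process(String message) { processed++; }
  @Override public void commit() { }
}
```

   Keeping the defaults together at the end makes the required surface of the interface readable at a glance. Implementers see what they must provide before what they may override.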







[GitHub] [samza] cameronlee314 merged pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 merged pull request #1366:
URL: https://github.com/apache/samza/pull/1366


   





[GitHub] [samza] bkonold commented on pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on pull request #1366:
URL: https://github.com/apache/samza/pull/1366#issuecomment-632934829


   > What would be an example use case for `RunLoopTask`, other than `TaskInstance`?
   
   This will be used as an entry point for side input processing to leverage RunLoop.
   
   E.g.
   https://github.com/apache/samza/pull/1343/files#diff-81b48c3d365639da045f19f1f46c138e





[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432168140



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());

Review comment:
       Good point. I think this should be added to the `Answer` for the mocked `endOfStream` call. When `endOfStream` is finally called, the task should have seen all messages. If the check happens later (after `runLoop.run()` returns), the test might pass despite bad behavior, e.g. if `RunLoop` were to somehow touch the process metric after `endOfStream` is called.
   
   What do you think?
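A toy model may make the reviewer's point concrete. The plain-Java sketch below is not Samza's `RunLoop` or `RunLoopTask`; `ToyTask`, `runToyLoop`, and the `buggy` flag are hypothetical names invented for illustration. The buggy loop records the process metric for the last message only after calling `endOfStream()`. A check made after the loop returns sees the same final count for both loops, while a value captured inside `endOfStream()` tells them apart. In the real test, the equivalent would presumably be an assertion on `containerMetrics.processes().getCount()` inside the `doAnswer(...)` stubbed for `task0.endOfStream(any())`.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class EndOfStreamCheckSketch {

  // Hypothetical stand-in for the two RunLoopTask methods this toy needs.
  interface ToyTask {
    void process(String message);
    void endOfStream();
  }

  // Hypothetical toy loop, NOT Samza's RunLoop. When buggy is true, the "processes"
  // metric for the last message is recorded only after endOfStream() is called.
  static void runToyLoop(List<String> messages, ToyTask task, AtomicInteger processes, boolean buggy) {
    for (int i = 0; i < messages.size(); i++) {
      task.process(messages.get(i));
      boolean last = i == messages.size() - 1;
      if (!(buggy && last)) {
        processes.incrementAndGet();
      }
    }
    task.endOfStream();
    if (buggy && !messages.isEmpty()) {
      processes.incrementAndGet(); // late metric update after end-of-stream
    }
  }

  // Returns the metric value observed *inside* endOfStream(), like an assertion in its Answer.
  static int processesSeenAtEndOfStream(List<String> messages, boolean buggy) {
    AtomicInteger processes = new AtomicInteger();
    AtomicInteger seen = new AtomicInteger(-1);
    ToyTask task = new ToyTask() {
      @Override public void process(String message) { }
      @Override public void endOfStream() { seen.set(processes.get()); }
    };
    runToyLoop(messages, task, processes, buggy);
    return seen.get();
  }

  // Returns the metric value observed only after the loop has returned.
  static int processesSeenAfterRun(List<String> messages, boolean buggy) {
    AtomicInteger processes = new AtomicInteger();
    ToyTask task = new ToyTask() {
      @Override public void process(String message) { }
      @Override public void endOfStream() { }
    };
    runToyLoop(messages, task, processes, buggy);
    return processes.get();
  }

  public static void main(String[] args) {
    List<String> messages = Arrays.asList("m0", "m1");
    // After run(): the correct and buggy loops look identical.
    System.out.println(processesSeenAfterRun(messages, false) + " " + processesSeenAfterRun(messages, true));
    // Inside endOfStream(): the buggy loop is caught.
    System.out.println(processesSeenAtEndOfStream(messages, false) + " " + processesSeenAtEndOfStream(messages, true));
  }
}
```

Running `main` prints `2 2` and then `2 1`: only the check placed at end-of-stream distinguishes the two loops.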







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432050248



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
##########
@@ -52,18 +50,6 @@ public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, Ta
 
     log.info("Got commit milliseconds: {}.", taskCommitMs);
 
-    int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance, Object>() {
-      @Override
-      public Boolean apply(TaskInstance t) {
-        return t.isAsyncTask();
-      }
-    });
-
-    // asyncTaskCount should be either 0 or the number of all taskInstances
-    if (asyncTaskCount > 0 && asyncTaskCount < taskInstances.size()) {
-      throw new SamzaException("Mixing StreamTask and AsyncStreamTask is not supported");
-    }

Review comment:
       Updated PR description with these details.
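The check removed from `RunLoopFactory` in this hunk enforced that a container's tasks are either all `AsyncStreamTask`s or none of them. As a rough plain-Java restatement of that rule (a sketch under stated assumptions, not code from the PR): the Scala `Map<TaskName, TaskInstance>` and `TaskInstance.isAsyncTask()` are replaced by a hypothetical `Map<String, Boolean>` from task name to "is async", and `IllegalStateException` stands in for `SamzaException`. The helper names `isMixed` and `validateNoMixedTaskTypes` are likewise hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class MixedTaskCheckSketch {

  // Returns true when some, but not all, tasks are async.
  // Hypothetical input: task name -> whether that task is an AsyncStreamTask.
  static boolean isMixed(Map<String, Boolean> isAsyncByTask) {
    long asyncTaskCount = isAsyncByTask.values().stream().filter(Boolean::booleanValue).count();
    // asyncTaskCount should be either 0 or the number of all tasks
    return asyncTaskCount > 0 && asyncTaskCount < isAsyncByTask.size();
  }

  // Throws when task types are mixed; IllegalStateException stands in for SamzaException.
  static void validateNoMixedTaskTypes(Map<String, Boolean> isAsyncByTask) {
    if (isMixed(isAsyncByTask)) {
      throw new IllegalStateException("Mixing StreamTask and AsyncStreamTask is not supported");
    }
  }

  public static void main(String[] args) {
    Map<String, Boolean> allAsync = new LinkedHashMap<>();
    allAsync.put("task0", true);
    allAsync.put("task1", true);
    validateNoMixedTaskTypes(allAsync); // passes: every task is async

    Map<String, Boolean> mixed = new LinkedHashMap<>(allAsync);
    mixed.put("task1", false);
    try {
      validateNoMixedTaskTypes(mixed);
    } catch (IllegalStateException e) {
      System.out.println(e.getMessage()); // prints "Mixing StreamTask and AsyncStreamTask is not supported"
    }
  }
}
```

Counting with a stream avoids the `AbstractFunction1` adapter that the original needed in order to call Scala's `count` from Java.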







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432180833



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
    // consensus is reached after envelope10 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", "stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
-
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    when(offsetManager.getLastProcessedOffset(taskName1, ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), anyObject())).thenReturn(Option.empty());
-
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
-
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
-
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", "key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", "key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
-  }
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-  @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
-    int maxMessagesInFlight = 2;
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
-
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();

Review comment:
       Cool.
   
   One thing to note: the current set of tests does not verify what happens when `RunLoop` is given a non-null executor and can therefore run window/commit/scheduler work on a separate thread from process. This was also true before this change, but I wanted to call it out.
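   A test for this gap would need to observe which thread each operation runs on. The sketch below illustrates that pattern only. `MiniTask`, `ThreadRecordingTask`, and `dispatch` are hypothetical stand-ins, not Samza's `RunLoop` or `RunLoopTask`. It assumes a loop that runs `process` on its own thread and hands `commit` to the supplied executor when one is non-null.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical sketch: illustrates a thread-affinity test, NOT Samza's RunLoop.
public class SeparateThreadCommitSketch {

  /** Hypothetical stand-in for the subset of RunLoopTask exercised here. */
  public interface MiniTask {
    void process(String message);
    void commit();
  }

  /** Records the thread that executed each operation. */
  public static final class ThreadRecordingTask implements MiniTask {
    public final AtomicReference<Thread> processThread = new AtomicReference<>();
    public final AtomicReference<Thread> commitThread = new AtomicReference<>();

    @Override
    public void process(String message) {
      processThread.set(Thread.currentThread());
    }

    @Override
    public void commit() {
      commitThread.set(Thread.currentThread());
    }
  }

  /**
   * Hypothetical dispatcher: runs process() on the calling thread and, if an
   * executor is supplied, runs commit() on that executor (otherwise inline).
   */
  public static void dispatch(MiniTask task, String message, ExecutorService executor) throws Exception {
    task.process(message);
    if (executor != null) {
      Future<?> pending = executor.submit(task::commit);
      // Block so the caller can inspect where commit() ran.
      pending.get(5, TimeUnit.SECONDS);
    } else {
      task.commit();
    }
  }

  public static void main(String[] args) throws Exception {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      ThreadRecordingTask task = new ThreadRecordingTask();
      dispatch(task, "m0", executor);
      boolean separate = task.processThread.get() != task.commitThread.get();
      System.out.println("commit ran on a separate thread: " + separate);
    } finally {
      executor.shutdownNow();
    }
  }
}
```

   A test against `RunLoop` itself could probably follow the same idea: stub `RunLoopTask.commit()` with Mockito's `doAnswer`, as the tests in this diff already do, and capture `Thread.currentThread()` inside the answer. That is an untested suggestion.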







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r431484365



##########
File path: samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
##########
@@ -181,22 +181,10 @@ class TaskInstance(
       trace("Processing incoming message envelope for taskName and SSP: %s, %s"
         format (taskName, incomingMessageSsp))
 
-      if (isAsyncTask) {

Review comment:
      See my other comment: this check is now redundant because `StreamTask` is no longer used directly internally; it is always wrapped by `AsyncStreamTaskAdapter`.
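   The redundancy argument rests on the adapter pattern. Once every synchronous task is wrapped so that it only exposes the asynchronous, callback-based interface, the caller never needs to branch on the task type. This is a minimal sketch that uses hypothetical, simplified interfaces (`SyncTask`, `AsyncTask`, `Callback`, `SyncTaskAdapter`) rather than the actual signatures of Samza's `StreamTask`, `AsyncStreamTask`, `TaskCallback`, and `AsyncStreamTaskAdapter`. The optional executor is an assumption meant to model handing work to a thread pool.

```java
import java.util.concurrent.ExecutorService;

// Hypothetical sketch of wrapping a synchronous task behind an async interface.
public class SyncToAsyncAdapterSketch {

  /** Hypothetical stand-in for a completion callback. */
  public interface Callback {
    void complete();
    void failure(Throwable t);
  }

  /** Hypothetical synchronous task: returns when processing is done. */
  public interface SyncTask {
    void process(String message) throws Exception;
  }

  /** Hypothetical async task: signals completion through the callback. */
  public interface AsyncTask {
    void processAsync(String message, Callback callback);
  }

  /** Adapts a SyncTask so callers only ever see AsyncTask. */
  public static final class SyncTaskAdapter implements AsyncTask {
    private final SyncTask wrapped;
    private final ExecutorService executor; // may be null: run inline

    public SyncTaskAdapter(SyncTask wrapped, ExecutorService executor) {
      this.wrapped = wrapped;
      this.executor = executor;
    }

    @Override
    public void processAsync(String message, Callback callback) {
      Runnable work = () -> {
        try {
          wrapped.process(message);
        } catch (Throwable t) {
          callback.failure(t);
          return; // signal exactly once
        }
        callback.complete();
      };
      if (executor != null) {
        executor.submit(work);
      } else {
        work.run(); // callback fires before processAsync returns
      }
    }
  }

  public static void main(String[] args) {
    StringBuilder log = new StringBuilder();
    Callback recorder = new Callback() {
      @Override public void complete() { log.append("complete;"); }
      @Override public void failure(Throwable t) { log.append("failure:").append(t.getMessage()).append(';'); }
    };
    new SyncTaskAdapter(msg -> log.append("processed ").append(msg).append(';'), null).processAsync("m0", recorder);
    new SyncTaskAdapter(msg -> { throw new IllegalStateException("boom"); }, null).processAsync("m1", recorder);
    System.out.println(log); // processed m0;complete;failure:boom;
  }
}
```

   Because the caller only ever holds an `AsyncTask`, a per-message `isAsyncTask` branch like the one removed here has nothing left to decide.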







[GitHub] [samza] cameronlee314 commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432179780



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", "stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
-
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    when(offsetManager.getLastProcessedOffset(taskName1, ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), anyObject())).thenReturn(Option.empty());
-
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
-
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
-
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", "key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", "key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
-  }
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-  @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
-    int maxMessagesInFlight = 2;
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
-
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();

Review comment:
       Oh, I see. I got confused between the `executor` and the `taskExecutor`. I was thinking the thread pool was running the `RunLoopTask.process` and `RunLoopTask.commit`, but that's not the case.
   My comment doesn't apply then.
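The distinction the comment draws can be sketched in plain Java without Mockito. The sketch assumes what the comment concludes: `process` is called directly on the caller's ("run loop") thread, and only the callback completion is handed to a separate `taskExecutor`. `Callback`, `process` and `CallbackExecutorSketch` are hypothetical stand-ins for `TaskCallback` and `RunLoopTask.process`, not Samza APIs. The latch gating imitates the `firstMessageBarrier` pattern in the diff: messages are processed in order, but their callbacks complete out of order.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class CallbackExecutorSketch {
  // Hypothetical stand-in for TaskCallback: only the completion hook matters here.
  interface Callback {
    void complete();
  }

  private static final List<String> events = Collections.synchronizedList(new ArrayList<>());

  // Hypothetical stand-in for RunLoopTask.process: runs on the caller's ("run loop") thread.
  static void process(String envelope, Callback callback, ExecutorService taskExecutor,
                      CountDownLatch firstMessageBarrier) {
    events.add("process " + envelope);
    if (envelope.equals("envelope00")) {
      // First message: hand completion to taskExecutor, blocked until the barrier opens.
      taskExecutor.submit(() -> {
        firstMessageBarrier.await();
        callback.complete();
        return null;
      });
    } else {
      // Second message: complete inline, then release the first message's callback.
      callback.complete();
      firstMessageBarrier.countDown();
    }
  }

  public static List<String> run() {
    events.clear();
    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
    CountDownLatch completions = new CountDownLatch(2);
    try {
      for (String envelope : new String[] {"envelope00", "envelope01"}) {
        process(envelope, () -> {
          events.add("complete " + envelope);
          completions.countDown();
        }, taskExecutor, firstMessageBarrier);
      }
      if (!completions.await(5, TimeUnit.SECONDS)) {
        throw new IllegalStateException("callbacks did not complete in time");
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    } finally {
      taskExecutor.shutdownNow();
    }
    return new ArrayList<>(events);
  }

  public static void main(String[] args) {
    // Processing order is 00 then 01; completion order is 01 then 00.
    System.out.println(run());
  }
}
```

Because `taskExecutor` has a single thread, the second message must complete its callback inline. If it were also submitted to `taskExecutor`, it would queue behind the blocked first task and the two latches would deadlock.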







[GitHub] [samza] prateekm commented on pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
prateekm commented on pull request #1366:
URL: https://github.com/apache/samza/pull/1366#issuecomment-632841552


   Thanks for the cleanup. Please get a review from @cameronlee314 and @bharathkk on this.





[GitHub] [samza] cameronlee314 commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432127546



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");

Review comment:
       Nit: Should this be `envelope11`?
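The nit depends on reading the new field names as `envelope<partition><offset>`. That reading is suggested by `envelope00` (`ssp0`, offset `"0"`) and by `envelope01` (`ssp0`, offset `"1"`, declared elsewhere in the diff). The helper below is hypothetical and exists only to make that convention explicit. Under it, the `ssp1` envelope with offset `"1"` would be named `envelope11`.

```java
public class EnvelopeNaming {
  // Hypothetical helper: builds a test-field name as envelope<partition><offset>,
  // the convention implied by envelope00 (ssp0, offset "0") and envelope01 (ssp0, offset "1").
  static String fieldName(int partition, String offset) {
    return "envelope" + partition + offset;
  }

  public static void main(String[] args) {
    System.out.println(fieldName(0, "0")); // envelope00
    System.out.println(fieldName(0, "1")); // envelope01
    // ssp1 with offset "1", which the diff names envelope10
    System.out.println(fieldName(1, "1")); // envelope11
  }
}
```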

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());

Review comment:
       Would it be good to use containerMetrics to verify that you saw all of the messages?
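Below is a minimal, self-contained sketch of the reviewer's suggestion. It uses plain JDK types rather than Samza or Mockito. The idea is to count every envelope the loop chooses at the container level, count dispatches to the task separately, and assert both totals against what the mocked multiplexer was told to return. `Envelope`, `MiniRunLoop`, and its counters are hypothetical stand-ins, not Samza's `RunLoop` or container-metrics API.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

// Hypothetical envelope: an offset plus an end-of-stream flag.
final class Envelope {
  final String offset;
  final boolean endOfStream;

  Envelope(String offset, boolean endOfStream) {
    this.offset = offset;
    this.endOfStream = endOfStream;
  }
}

// Hypothetical single-threaded stand-in for a run loop. Unlike Samza's RunLoop,
// it simply stops once its input is drained.
final class MiniRunLoop {
  final AtomicLong envelopes = new AtomicLong(); // every envelope chosen, end-of-stream included
  final AtomicLong processes = new AtomicLong(); // only envelopes dispatched to the task

  private final Deque<Envelope> input;
  private final Consumer<Envelope> task;

  MiniRunLoop(List<Envelope> input, Consumer<Envelope> task) {
    this.input = new ArrayDeque<>(input);
    this.task = task;
  }

  void run() {
    Envelope next;
    // poll() returns null once the queue is empty, playing the role of choose() returning null.
    while ((next = input.poll()) != null) {
      envelopes.incrementAndGet();
      if (!next.endOfStream) {
        processes.incrementAndGet();
        task.accept(next);
      }
    }
  }
}
```

In the real test, the equivalent check would use `containerMetrics.envelopes()` and `containerMetrics.processes()`. `testProcessCallbacksCompletedOutOfOrder` in this same diff already does this, asserting 3 envelopes and 2 processes for two data envelopes plus one end-of-stream marker.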

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task0.taskName()).thenReturn(taskName1);

Review comment:
       Should this be `when(task1.taskName()).thenReturn(taskName1);`? Although if the test passed, then does that mean the test doesn't need the result of `taskName()`?
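In Mockito, stubbing the same call a second time replaces the first stubbing. With this line as written, `task0.taskName()` returns `taskName1`, while `task1.taskName()` is never stubbed and falls back to Mockito's default answer, `null`. Other tests in this diff already build their mocks through a `getMockRunLoopTask(taskName, ssp)` helper, whose body is not part of this excerpt. Routing this test through the same kind of factory would remove the copy-paste hazard. The sketch below illustrates the idea with a plain JDK interface instead of Mockito; `TaskView` and `TaskStubs.stubTask` are hypothetical names.

```java
import java.util.Collections;
import java.util.Set;

// Hypothetical, trimmed-down view of the getters this test stubs on each RunLoopTask.
interface TaskView {
  String taskName();

  Set<String> systemStreamPartitions();
}

final class TaskStubs {
  private TaskStubs() {
  }

  // Every stubbed value is derived from the arguments, so a copy-pasted line
  // cannot make one task report another task's name.
  static TaskView stubTask(String taskName, String ssp) {
    Set<String> ssps = Collections.singleton(ssp);
    return new TaskView() {
      @Override
      public String taskName() {
        return taskName;
      }

      @Override
      public Set<String> systemStreamPartitions() {
        return ssps;
      }
    };
  }
}
```

The same shape carries over to Mockito: the helper creates the mock and performs every `when(...)` on that mock using only its own parameters.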

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", "stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
-
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    when(offsetManager.getLastProcessedOffset(taskName1, ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), anyObject())).thenReturn(Option.empty());
-
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
-
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
-
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", "key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", "key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
-  }
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-  @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
-    int maxMessagesInFlight = 2;
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
-
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();

Review comment:
      Do you know if `InOrder` tracks when a method is initially called or when it finishes executing? If it tracks the initial call, this seems OK. However, if it tracks when the method finishes, this test could end up being flaky, since it looks possible for `commit` to finish before the second `process` is done.
  Maybe `InOrder` isn't really necessary for this test, since the latches already take care of the ordering.
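To make the distinction in this question concrete, here is a plain-Java sketch (no Mockito; `CallOrderDemo` and `run` are made-up names for illustration). It records each call once when it is entered and once when it returns, and uses latches, as the test does, to force the second call to finish before the first. Entry order stays stable while completion order is reversed, which is exactly the case where an order check based on completion would be flaky. My understanding is that Mockito registers an invocation when the mock method is entered, before the stubbed answer runs, so `InOrder` should follow entry order; treat that as an assumption to confirm against the Mockito documentation.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

// Hypothetical demo, not Mockito: logs each call on entry and again on return.
public class CallOrderDemo {

  /** Returns [entry order, completion order] for two overlapping calls. */
  public static List<List<String>> run() {
    List<String> entered = new CopyOnWriteArrayList<>();
    List<String> completed = new CopyOnWriteArrayList<>();
    CountDownLatch firstEntered = new CountDownLatch(1);
    CountDownLatch secondCompleted = new CountDownLatch(1);
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      // "first" is entered first, then blocks until "second" has completed.
      Future<?> first = executor.submit(() -> {
        entered.add("first");
        firstEntered.countDown();
        secondCompleted.await();
        completed.add("first");
        return null;
      });
      firstEntered.await();
      // "second" is entered after "first" but completes before it.
      entered.add("second");
      completed.add("second");
      secondCompleted.countDown();
      first.get(5, TimeUnit.SECONDS);
    } catch (Exception e) {
      throw new IllegalStateException(e);
    } finally {
      executor.shutdownNow();
    }
    return Arrays.asList(entered, completed);
  }

  public static void main(String[] args) {
    List<List<String>> orders = run();
    System.out.println("entry order:      " + orders.get(0)); // [first, second]
    System.out.println("completion order: " + orders.get(1)); // [second, first]
  }
}
```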







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r429421367



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
      Yes, I'll add a description to the class declaration.
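For readers following the extraction, the idea behind `RunLoopTask` can be sketched in a heavily simplified, self-contained form. Everything below is hypothetical: `MiniRunLoopTask`, `runLoop` and `RecordingTask` are stand-ins invented for illustration, not Samza's `RunLoop` or `RunLoopTask`, and the real interface has more methods (`window`, `endOfStream`, `metrics` and so on, judging by how `TestRunLoop` calls it). The sketch only shows that once the loop depends on an interface, a task that is not a `TaskInstance` can be driven by it.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins for illustration only; these are NOT Samza's RunLoop or RunLoopTask.
public class MiniRunLoopDemo {

  /** Stand-in for the extracted interface: only what the loop needs from a task. */
  public interface MiniRunLoopTask {
    String taskName();

    void process(String envelope, Runnable onComplete); // onComplete plays the callback's role

    void commit();
  }

  /** Stand-in for the loop: routes each (taskName, payload) message, then commits every task. */
  public static void runLoop(Map<String, MiniRunLoopTask> tasks, List<Map.Entry<String, String>> messages) {
    for (Map.Entry<String, String> message : messages) {
      tasks.get(message.getKey()).process(message.getValue(), () -> { });
    }
    tasks.values().forEach(MiniRunLoopTask::commit);
  }

  /** A task that is not a TaskInstance at all; it just records what the loop asked of it. */
  public static class RecordingTask implements MiniRunLoopTask {
    private final String name;
    final List<String> events = new ArrayList<>();

    RecordingTask(String name) {
      this.name = name;
    }

    @Override
    public String taskName() {
      return name;
    }

    @Override
    public void process(String envelope, Runnable onComplete) {
      events.add("process:" + envelope);
      onComplete.run();
    }

    @Override
    public void commit() {
      events.add("commit");
    }
  }

  /** Drives two recording tasks through the loop and returns each task's event log. */
  public static Map<String, List<String>> demo() {
    RecordingTask task0 = new RecordingTask("task0");
    RecordingTask task1 = new RecordingTask("task1");
    Map<String, MiniRunLoopTask> tasks = new LinkedHashMap<>();
    tasks.put(task0.taskName(), task0);
    tasks.put(task1.taskName(), task1);

    List<Map.Entry<String, String>> messages = new ArrayList<>();
    messages.add(new AbstractMap.SimpleEntry<>("task0", "m00"));
    messages.add(new AbstractMap.SimpleEntry<>("task1", "m10"));
    messages.add(new AbstractMap.SimpleEntry<>("task0", "m01"));
    runLoop(tasks, messages);

    Map<String, List<String>> logs = new LinkedHashMap<>();
    logs.put("task0", task0.events);
    logs.put("task1", task1.events);
    return logs;
  }

  public static void main(String[] args) {
    System.out.println(demo()); // {task0=[process:m00, process:m01, commit], task1=[process:m10, commit]}
  }
}
```

Because the loop only sees the interface, the tests in this PR can hand `RunLoop` a Mockito mock of `RunLoopTask` instead of building a full `TaskInstance`, which is what the `createTaskInstance` helper removed from `TestRunLoop` used to do.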







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432180833



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task1.taskName()).thenReturn(taskName1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task1).process(eq(envelope10), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(2L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
+    assertEquals(4L, containerMetrics.envelopes().getCount());
   }
 
   @Test
-  public void testProcessInOrder() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
+  public void testProcessInOrder() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics taskMetrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(taskMetrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
+    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
     runLoop.run();
 
-    // Wait till the tasks completes processing all the messages.
-    task0ProcessedMessages.await();
-    task1ProcessedMessages.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
-    assertEquals(2L, t0.metrics().asyncCallbackCompleted().getCount());
-    assertEquals(1L, t1.metrics().asyncCallbackCompleted().getCount());
-  }
-
-  private TestCode buildOutofOrderCallback(final TestTask task) {
-    final CountDownLatch latch = new CountDownLatch(1);
-    return new TestCode() {
-      @Override
-      public void run(TaskCallback callback) {
-        IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-        if (envelope.equals(envelope0)) {
-          // process first message will wait till the second one is processed
-          try {
-            latch.await();
-          } catch (InterruptedException e) {
-            e.printStackTrace();
-          }
-        } else {
-          // second envelope complete first
-          assertEquals(0, task.completed.get());
-          latch.countDown();
-        }
-      }
-    };
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
   }
 
   @Test
-  public void testProcessOutOfOrder() throws Exception {
+  public void testProcessCallbacksCompletedOutOfOrder() {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+
+    InOrder inOrderOffsetManager = inOrder(offsetManager);
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+    inOrderOffsetManager.verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope01.getOffset()));
 
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    assertEquals(3L, containerMetrics.envelopes().getCount());
+    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testWindow() throws Exception {
-    TestTask task0 = new TestTask(true, true, false, null);
-    TestTask task1 = new TestTask(true, false, true, null);
-
+  public void testWindow() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    int maxMessagesInFlight = 1;
+    long windowMs = 1;
+    RunLoopTask task = mock(RunLoopTask.class);
+    when(task.isWindowableTask()).thenReturn(true);
+
+    final AtomicInteger windowCount = new AtomicInteger(0);
+    doAnswer(x -> {
+        windowCount.incrementAndGet();
+        if (windowCount.get() == 4) {
+          x.getArgumentAt(0, ReadableCoordinator.class).shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+        Thread.sleep(windowMs);
+        return null;
+      }).when(task).window(any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task);
 
-    long windowMs = 1;
-    int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
-    assertEquals(4, task1.windowCount);
+    verify(task, times(4)).window(any());
   }
 
   @Test
-  public void testCommitSingleTask() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitSingleTask() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager, never()).buildCheckpoint(eq(taskName1));
-    verify(offsetManager, never()).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1, never()).commit();
   }
 
   @Test
-  public void testCommitAllTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testCommitAllTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+        coordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
+
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(this.taskName0, task0);
+    tasks.put(taskName1, task1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
-        .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
-            return null;
-          }).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
+
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    verify(task0).commit();
+    verify(task1).commit();
   }
 
   @Test
-  public void testShutdownOnConsensus() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testShutdownOnConsensus() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    TestTask task0 = new TestTask(true, true, true, task0ProcessedMessagesLatch);
-    task0.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    task1.setShutdownRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
 
     int maxMessagesInFlight = 1;
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+
+        TaskCallback callback = callbackFactory.createCallback();
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        callback.complete();
+        return null;
+      }).when(task1).process(eq(envelope10), any(), any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     // consensus is reached after envelope1 is processed.
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope10).thenReturn(null);
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(any(), any(), any());
+    verify(task1).process(any(), any(), any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(2L, containerMetrics.envelopes().getCount());
     assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamWithMultipleTasks() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
                                             () -> 0L, false);
     when(consumerMultiplexer.choose(false))
-      .thenReturn(envelope0)
-      .thenReturn(envelope1)
+      .thenReturn(envelope00)
+      .thenReturn(envelope10)
       .thenReturn(ssp0EndOfStream)
       .thenReturn(ssp1EndOfStream)
       .thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    verify(task0).process(eq(envelope00), any(), any());
+    verify(task0).endOfStream(any());
+
+    verify(task1).process(eq(envelope10), any(), any());
+    verify(task1).endOfStream(any());
 
-    assertEquals(1, task0.processed);
-    assertEquals(1, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
     assertEquals(4L, containerMetrics.envelopes().getCount());
-    assertEquals(2L, containerMetrics.processes().getCount());
   }
 
   @Test
-  public void testEndOfStreamWithOutOfOrderProcess() throws Exception {
+  public void testEndOfStreamWaitsForInFlightMessages() throws Exception {
     int maxMessagesInFlight = 2;
-
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(2);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch, maxMessagesInFlight);
-    TestTask task1 = new TestTask(true, true, false, task1ProcessedMessagesLatch, maxMessagesInFlight);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+    doAnswer(invocation -> {
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+        callback.complete();
+        firstMessageBarrier.countDown();
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    doAnswer(invocation -> {
+        assertEquals(0, task0.metrics().messagesInFlight().getValue());
+        return null;
+      }).when(task0).endOfStream(any());
+
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    task0.callbackHandler = buildOutofOrderCallback(task0);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope3).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+        .thenAnswer(invocation -> {
+            // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
+            firstMessageBarrier.countDown();
+            return null;
+          });
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
-
-    assertEquals(2, task0.processed);
-    assertEquals(2, task0.completed.get());
-    assertEquals(1, task1.processed);
-    assertEquals(1, task1.completed.get());
-    assertEquals(5L, containerMetrics.envelopes().getCount());
-    assertEquals(3L, containerMetrics.processes().getCount());
+    verify(task0).endOfStream(any());
   }
 
   @Test
-  public void testEndOfStreamCommitBehavior() throws Exception {
-    CountDownLatch task0ProcessedMessagesLatch = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessagesLatch = new CountDownLatch(1);
-
+  public void testEndOfStreamCommitBehavior() {
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask task0 = new TestTask(true, false, false, task0ProcessedMessagesLatch);
-    TestTask task1 = new TestTask(true, false, false, task1ProcessedMessagesLatch);
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
+
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+        return null;
+      }).when(task0).endOfStream(any());
 
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope0).thenReturn(envelope1).thenReturn(null).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
 
     runLoop.run();
 
-    task0ProcessedMessagesLatch.await();
-    task1ProcessedMessagesLatch.await();
+    InOrder inOrder = inOrder(task0);
 
-    verify(offsetManager).buildCheckpoint(eq(taskName0));
-    verify(offsetManager).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    verify(offsetManager).buildCheckpoint(eq(taskName1));
-    verify(offsetManager).writeCheckpoint(eq(taskName1), any(Checkpoint.class));
+    inOrder.verify(task0).endOfStream(any());
+    inOrder.verify(task0).commit();
   }
 
   @Test
-  public void testEndOfStreamOffsetManagement() throws Exception {
-    //explicitly configure to disable commits inside process or window calls and invoke commit from end of stream
-    TestTask mockStreamTask1 = new TestTask(true, false, false, null);
-    TestTask mockStreamTask2 = new TestTask(true, false, false, null);
-
-    Partition p1 = new Partition(1);
-    Partition p2 = new Partition(2);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system1", "stream1", p1);
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system1", "stream2", p2);
-    IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp2, "1", "key1", "message1");
-    IncomingMessageEnvelope envelope2 = new IncomingMessageEnvelope(ssp2, "2", "key1", "message1");
-    IncomingMessageEnvelope envelope3 = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp2);
-
-    Map<SystemStreamPartition, List<IncomingMessageEnvelope>> sspMap = new HashMap<>();
-    List<IncomingMessageEnvelope> messageList = new ArrayList<>();
-    messageList.add(envelope1);
-    messageList.add(envelope2);
-    messageList.add(envelope3);
-    sspMap.put(ssp2, messageList);
-
-    SystemConsumer mockConsumer = mock(SystemConsumer.class);
-    when(mockConsumer.poll(anyObject(), anyLong())).thenReturn(sspMap);
-
-    SystemAdmins systemAdmins = Mockito.mock(SystemAdmins.class);
-    Mockito.when(systemAdmins.getSystemAdmin("system1")).thenReturn(Mockito.mock(SystemAdmin.class));
-    Mockito.when(systemAdmins.getSystemAdmin("testSystem")).thenReturn(Mockito.mock(SystemAdmin.class));
-
-    HashMap<String, SystemConsumer> systemConsumerMap = new HashMap<>();
-    systemConsumerMap.put("system1", mockConsumer);
-
-    SystemConsumers consumers = TestSystemConsumers.getSystemConsumers(systemConsumerMap, systemAdmins);
-
-    TaskName taskName1 = new TaskName("task1");
-    TaskName taskName2 = new TaskName("task2");
-
-    OffsetManager offsetManager = mock(OffsetManager.class);
-
-    when(offsetManager.getLastProcessedOffset(taskName1, ssp1)).thenReturn(Option.apply("3"));
-    when(offsetManager.getLastProcessedOffset(taskName2, ssp2)).thenReturn(Option.apply("0"));
-    when(offsetManager.getStartingOffset(taskName1, ssp1)).thenReturn(Option.apply(IncomingMessageEnvelope.END_OF_STREAM_OFFSET));
-    when(offsetManager.getStartingOffset(taskName2, ssp2)).thenReturn(Option.apply("1"));
-    when(offsetManager.getStartpoint(anyObject(), anyObject())).thenReturn(Option.empty());
-
-    TaskInstance taskInstance1 = createTaskInstance(mockStreamTask1, taskName1, ssp1, offsetManager, consumers);
-    TaskInstance taskInstance2 = createTaskInstance(mockStreamTask2, taskName2, ssp2, offsetManager, consumers);
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName1, taskInstance1);
-    tasks.put(taskName2, taskInstance2);
-
-    taskInstance1.registerConsumers();
-    taskInstance2.registerConsumers();
-    consumers.start();
-
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-  }
-
-  //@Test
-  public void testCommitBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
+  public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
+    int maxMessagesInFlight = 2;
+    ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    int maxMessagesInFlight = 3;
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, false, null, maxMessagesInFlight);
-
-    IncomingMessageEnvelope firstMsg = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-    IncomingMessageEnvelope secondMsg = new IncomingMessageEnvelope(ssp0, "1", "key1", "value1");
-    IncomingMessageEnvelope thirdMsg = new IncomingMessageEnvelope(ssp0, "2", "key0", "value0");
-
-    final CountDownLatch firstMsgCompletionLatch = new CountDownLatch(1);
-    final CountDownLatch secondMsgCompletionLatch = new CountDownLatch(1);
-    task0.callbackHandler = callback -> {
-      IncomingMessageEnvelope envelope = ((TaskCallbackImpl) callback).getEnvelope();
-      try {
-        if (envelope.equals(firstMsg)) {
-          firstMsgCompletionLatch.await();
-        } else if (envelope.equals(secondMsg)) {
-          firstMsgCompletionLatch.countDown();
-          secondMsgCompletionLatch.await();
-        } else if (envelope.equals(thirdMsg)) {
-          secondMsgCompletionLatch.countDown();
-          // OffsetManager.update with firstMsg offset, task.commit has happened when second message callback has not completed.
-          verify(offsetManager).update(eq(taskName0), eq(firstMsg.getSystemStreamPartition()), eq(firstMsg.getOffset()));
-        }
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    tasks.put(taskName1, createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(firstMsg).thenReturn(secondMsg).thenReturn(thirdMsg).thenReturn(envelope1).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
-
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
-
-    runLoop.run();
-
-    firstMsgCompletionLatch.await();
-    secondMsgCompletionLatch.await();
+    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+    when(task0.offsetManager()).thenReturn(offsetManager);
+    CountDownLatch firstMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            firstMessageBarrier.await();
+            coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope00), any(), any());
+
+    CountDownLatch secondMessageBarrier = new CountDownLatch(1);
+    doAnswer(invocation -> {
+        ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
+        TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+        TaskCallback callback = callbackFactory.createCallback();
+
+        taskExecutor.submit(() -> {
+            // let the first message proceed to ask for a commit
+            firstMessageBarrier.countDown();
+            // block this message until commit is executed
+            secondMessageBarrier.await();
+            coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+            callback.complete();
+            return null;
+          });
+        return null;
+      }).when(task0).process(eq(envelope01), any(), any());
 
-    verify(offsetManager, atLeastOnce()).buildCheckpoint(eq(taskName0));
-    verify(offsetManager, atLeastOnce()).writeCheckpoint(eq(taskName0), any(Checkpoint.class));
-    assertEquals(3, task0.processed);
-    assertEquals(3, task0.committed);
-    assertEquals(1, task1.processed);
-    assertEquals(0, task1.committed);
-  }
+    doAnswer(invocation -> {
+        assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
+        assertEquals(1, task0.metrics().messagesInFlight().getValue());
 
-  @Test
-  public void testProcessBehaviourWhenAsyncCommitIsEnabled() throws InterruptedException {
-    int maxMessagesInFlight = 2;
+        secondMessageBarrier.countDown();
+        return null;
+      }).when(task0).commit();
 
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
 
-    TestTask task0 = new TestTask(true, true, false, null, maxMessagesInFlight);
-    CountDownLatch commitLatch = new CountDownLatch(1);
-    task0.commitHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope3)) {
-        try {
-          commitLatch.await();
-        } catch (InterruptedException e) {
-          e.printStackTrace();
-        }
-      }
-    };
-
-    task0.callbackHandler = callback -> {
-      TaskCallbackImpl taskCallback = (TaskCallbackImpl) callback;
-      if (taskCallback.getEnvelope().equals(envelope0)) {
-        // Both the process call has gone through when the first commit is in progress.
-        assertEquals(2, containerMetrics.processes().getCount());
-        assertEquals(0, containerMetrics.commits().getCount());
-        commitLatch.countDown();
-      }
-    };
-
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-
-    tasks.put(taskName0, createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer));
-    when(consumerMultiplexer.choose(false)).thenReturn(envelope3).thenReturn(envelope0).thenReturn(ssp0EndOfStream).thenReturn(null);
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics,
-                                            () -> 0L, true);
-
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+    when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
     runLoop.run();
 
-    commitLatch.await();
+    InOrder inOrder = inOrder(task0);
+    inOrder.verify(task0).process(eq(envelope00), any(), any());
+    inOrder.verify(task0).process(eq(envelope01), any(), any());
+    inOrder.verify(task0).commit();

Review comment:
       Cool.
   
  One thing to note is that the current set of tests does not verify behavior when `RunLoop` is passed a non-null executor and can execute window/commit/scheduler in a separate thread from process. This was also the case before, but I wanted to call it out.
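A test for that case would need to observe which thread actually runs window/commit. The following is a rough, generic sketch of the thread-identity check such a test might rely on. It uses only the JDK and does not exercise `RunLoop`; `threadThatRan` is a hypothetical helper, not a Samza API.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;

public class ExecutorThreadCheckSketch {

  /**
   * Hypothetical helper: submits `work` to `executor`, waits for it, and returns
   * the thread that executed it, so a test can compare it with the caller's thread.
   */
  public static Thread threadThatRan(ExecutorService executor, Runnable work) throws Exception {
    AtomicReference<Thread> ranOn = new AtomicReference<>();
    Future<?> done = executor.submit(() -> {
      ranOn.set(Thread.currentThread());
      work.run(); // stand-in for a window/commit/scheduler callback
    });
    done.get(); // blocks until completion and rethrows any failure from `work`
    return ranOn.get();
  }

  public static void main(String[] args) throws Exception {
    ExecutorService executor = Executors.newSingleThreadExecutor();
    try {
      Thread workerThread = threadThatRan(executor, () -> { });
      System.out.println("ran off the calling thread: " + (workerThread != Thread.currentThread()));
    } finally {
      executor.shutdownNow();
    }
  }
}
```

In `TestRunLoop`, the recording step could live inside a `doAnswer(...)` for the mocked `commit()` or window call, with the test asserting that the recorded thread differs from the one driving `runLoop.run()`.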







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r431491471



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -390,23 +401,21 @@ public void testCommitSingleTask() throws Exception {
     when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
+    TestTask task0 = spy(createTestTask(true, true, false, task0ProcessedMessagesLatch, 0, taskName0, ssp0, offsetManager));
     task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    TestTask task1 = spy(createTestTask(true, false, true, task1ProcessedMessagesLatch, 0, taskName1, ssp1, offsetManager));
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
     when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
         .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
+//            task0ProcessedMessagesLatch.await();

Review comment:
      I was fiddling around with this as I was unsure whether the test actually required it (I see lots of copy pasta in this file). Looks like I also left one of the tests commented out.
   
   I'll comb back through to correct these.







[GitHub] [samza] cameronlee314 commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r431474125



##########
File path: samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
##########
@@ -181,22 +181,10 @@ class TaskInstance(
       trace("Processing incoming message envelope for taskName and SSP: %s, %s"
         format (taskName, incomingMessageSsp))
 
-      if (isAsyncTask) {

Review comment:
       Just double checking here too: Is this check no longer necessary?

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
##########
@@ -52,18 +50,6 @@ public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, Ta
 
     log.info("Got commit milliseconds: {}.", taskCommitMs);
 
-    int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance, Object>() {
-      @Override
-      public Boolean apply(TaskInstance t) {
-        return t.isAsyncTask();
-      }
-    });
-
-    // asyncTaskCount should be either 0 or the number of all taskInstances
-    if (asyncTaskCount > 0 && asyncTaskCount < taskInstances.size()) {
-      throw new SamzaException("Mixing StreamTask and AsyncStreamTask is not supported");
-    }

Review comment:
       Was this validation moved somewhere else? Or is it no longer necessary?
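If the validation turns out to still be needed, it does not require Scala's `AbstractFunction1`. Here is a minimal sketch of the same "all async or none" rule over plain Java collections. `TaskKind` and `validateNoMixedTaskTypes` are hypothetical names. The removed code threw `SamzaException`; this sketch throws `IllegalStateException` so that it stays self-contained.

```java
import java.util.Arrays;
import java.util.Collection;

public class MixedTaskValidationSketch {

  /** Hypothetical stand-in for the part of a task that reports whether it is async. */
  public interface TaskKind {
    boolean isAsyncTask();
  }

  /** Same rule as the removed RunLoopFactory check: either no task is async, or all of them are. */
  public static void validateNoMixedTaskTypes(Collection<? extends TaskKind> tasks) {
    long asyncTaskCount = tasks.stream().filter(TaskKind::isAsyncTask).count();
    if (asyncTaskCount > 0 && asyncTaskCount < tasks.size()) {
      // The original code threw SamzaException; a JDK exception keeps the sketch self-contained.
      throw new IllegalStateException("Mixing StreamTask and AsyncStreamTask is not supported");
    }
  }

  public static void main(String[] args) {
    TaskKind sync = () -> false;
    TaskKind async = () -> true;
    validateNoMixedTaskTypes(Arrays.asList(sync, sync));   // accepted: all synchronous
    validateNoMixedTaskTypes(Arrays.asList(async, async)); // accepted: all asynchronous
    try {
      validateNoMixedTaskTypes(Arrays.asList(sync, async));
    } catch (IllegalStateException e) {
      System.out.println("rejected: " + e.getMessage());
    }
  }
}
```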

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -231,14 +251,12 @@ public void testProcessMultipleTasks() throws Exception {
     when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);

Review comment:
       Is this unused now? Could you please check the other tests to see if there are other unused variables too?

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -390,23 +401,21 @@ public void testCommitSingleTask() throws Exception {
     when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
     OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessagesLatch);
+    TestTask task0 = spy(createTestTask(true, true, false, task0ProcessedMessagesLatch, 0, taskName0, ssp0, offsetManager));
     task0.setCommitRequest(TaskCoordinator.RequestScope.CURRENT_TASK);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessagesLatch);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    TestTask task1 = spy(createTestTask(true, false, true, task1ProcessedMessagesLatch, 0, taskName1, ssp1, offsetManager));
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+    tasks.put(taskName0, task0);
+    tasks.put(taskName1, task1);
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
                                             callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
     when(consumerMultiplexer.choose(false)).thenReturn(envelope0)
         .thenAnswer(x -> {
-            task0ProcessedMessagesLatch.await();
+//            task0ProcessedMessagesLatch.await();

Review comment:
       Was this intended to be removed?

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -538,6 +533,9 @@ public void testEndOfStreamWithMultipleTasks() throws Exception {
     task0ProcessedMessagesLatch.await();
     task1ProcessedMessagesLatch.await();
 
+    verify(task0, times(1)).endOfStream(any());

Review comment:
       `times(1)` is the default for `verify`, so you don't need to pass as an argument.







[GitHub] [samza] cameronlee314 commented on pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 commented on pull request #1366:
URL: https://github.com/apache/samza/pull/1366#issuecomment-632907550


   What would be an example use case for `RunLoopTask`, other than `TaskInstance`?





[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r431487323



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -538,6 +533,9 @@ public void testEndOfStreamWithMultipleTasks() throws Exception {
     task0ProcessedMessagesLatch.await();
     task1ProcessedMessagesLatch.await();
 
+    verify(task0, times(1)).endOfStream(any());

Review comment:
       Thanks, didn't know this.







[GitHub] [samza] mynameborat commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
mynameborat commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r429404171



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       Can we add some java docs for this class?

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+public interface RunLoopTask {

Review comment:
      minor: can we group the methods, ordering them by ones that require an implementation vs. ones with a default?

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+public interface RunLoopTask {
+
+  TaskName taskName();
+
+  default boolean isWindowableTask() {
+    return false;
+  }
+
+  default boolean isAsyncTask() {
+    return false;
+  }
+
+  default EpochTimeScheduler epochTimeScheduler() {
+    return null;
+  }
+
+  default scala.collection.immutable.Set<String> intermediateStreams() {
+    return JavaConversions.asScalaSet(Collections.emptySet()).toSet();
+  }
+
+  scala.collection.immutable.Set<SystemStreamPartition> systemStreamPartitions();

Review comment:
       Can we make this interface free of Scala types and have the implementors or callers adapt to Scala if necessary?
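
Below is a minimal sketch of a Scala-free signature. It assumes the interface exposes a plain `java.util.Set`, and that any Scala implementor converts its `immutable.Set` once at that boundary (for example, via `scala.collection.JavaConverters`). `PartitionId`, `SimpleTask` and `FixedPartitionTask` are hypothetical stand-ins, not Samza classes.

```java
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;

// Hypothetical stand-in for SystemStreamPartition (system name omitted).
final class PartitionId {
  final String stream;
  final int partition;

  PartitionId(String stream, int partition) {
    this.stream = stream;
    this.partition = partition;
  }

  @Override
  public boolean equals(Object o) {
    if (!(o instanceof PartitionId)) {
      return false;
    }
    PartitionId other = (PartitionId) o;
    return stream.equals(other.stream) && partition == other.partition;
  }

  @Override
  public int hashCode() {
    return 31 * stream.hashCode() + partition;
  }
}

// Hypothetical Scala-free interface: only java.util types in its signatures.
interface SimpleTask {
  Set<PartitionId> systemStreamPartitions();
}

// A Java implementor keeps a read-only defensive copy. A Scala implementor
// would instead convert its immutable.Set to a java.util.Set at this point.
final class FixedPartitionTask implements SimpleTask {
  private final Set<PartitionId> partitions;

  FixedPartitionTask(Set<PartitionId> partitions) {
    this.partitions = Collections.unmodifiableSet(new LinkedHashSet<>(partitions));
  }

  @Override
  public Set<PartitionId> systemStreamPartitions() {
    return partitions;
  }
}
```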







[GitHub] [samza] mynameborat commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
mynameborat commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r430537718



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       Let us take this opportunity to document in more detail how the run loop interacts with implementations of this interface:
   
   1. Exclusivity: `process`, `window`, `scheduler`, `commit` and `endOfStream` never run concurrently, with the exception of async commit.
   2. Thread safety: implementations are not thread-safe. Objects shared between instances need synchronization, as do objects shared between `commit` and the other methods of the same instance when async commit is enabled.
   3. Lifecycle of this class, if applicable. This raises the question of whether the interface needs `init()` and `close()`.
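
The docs could illustrate point 2 with a sketch like the one below. It assumes that when async commit is enabled, `commit()` can run on another thread while `process()` continues for the same instance. `OffsetTrackingTask` is hypothetical and far simpler than `TaskInstance`. The lock gives each commit a consistent snapshot. The slow checkpoint write happens outside the lock.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical, simplified task; not Samza's RunLoopTask or TaskInstance.
// Assumption: with async commit, commit() may overlap with process().
final class OffsetTrackingTask {
  // Guards lastOffsets, which is shared between process() and commit().
  private final Object lock = new Object();
  private final Map<String, Long> lastOffsets = new HashMap<>();

  // Invoked by the run loop for each message.
  void process(String partition, long offset) {
    synchronized (lock) {
      lastOffsets.put(partition, offset);
    }
  }

  // Copies the offsets under the lock, then returns the copy; a real task
  // would write the snapshot to its checkpoint store outside the lock so
  // that process() is not blocked during the write.
  Map<String, Long> commit() {
    Map<String, Long> snapshot;
    synchronized (lock) {
      snapshot = new HashMap<>(lastOffsets);
    }
    return snapshot;
  }
}
```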
   







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r429421558



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {
+
+  TaskName taskName();
+
+  default boolean isWindowableTask() {
+    return false;
+  }
+
+  default boolean isAsyncTask() {
+    return false;
+  }
+
+  default EpochTimeScheduler epochTimeScheduler() {
+    return null;
+  }
+
+  default scala.collection.immutable.Set<String> intermediateStreams() {
+    return JavaConversions.asScalaSet(Collections.emptySet()).toSet();
+  }
+
+  scala.collection.immutable.Set<SystemStreamPartition> systemStreamPartitions();

Review comment:
       Sure. This isn't much more of a change, and it makes the interactions in RunLoop cleaner.







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r431484131



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
##########
@@ -52,18 +50,6 @@ public static Runnable createRunLoop(scala.collection.immutable.Map<TaskName, Ta
 
     log.info("Got commit milliseconds: {}.", taskCommitMs);
 
-    int asyncTaskCount = taskInstances.values().count(new AbstractFunction1<TaskInstance, Object>() {
-      @Override
-      public Boolean apply(TaskInstance t) {
-        return t.isAsyncTask();
-      }
-    });
-
-    // asyncTaskCount should be either 0 or the number of all taskInstances
-    if (asyncTaskCount > 0 && asyncTaskCount < taskInstances.size()) {
-      throw new SamzaException("Mixing StreamTask and AsyncStreamTask is not supported");
-    }

Review comment:
       Thanks for asking. I should update the PR description to cover this.
   
   The check is no longer necessary. `SamzaContainer` always wraps `StreamTask` instances in `AsyncStreamTaskAdapter`, so `isAsyncTask` is always true and therefore redundant.
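
The wrapping described above is an instance of the adapter pattern. The sketch below uses hypothetical `SyncTask`, `AsyncTask` and `Callback` interfaces as simplified stand-ins for `StreamTask`, `AsyncStreamTask` and `TaskCallback`. It is not Samza's actual `AsyncStreamTaskAdapter`, which may differ, for example in which thread runs the wrapped task.

```java
// Hypothetical simplified stand-ins for TaskCallback, StreamTask and AsyncStreamTask.
interface Callback {
  void complete();

  void failure(Throwable t);
}

interface SyncTask {
  void process(String message) throws Exception;
}

interface AsyncTask {
  void processAsync(String message, Callback callback);
}

// Wraps a synchronous task so the run loop only ever sees AsyncTask.
// Because every task is async after wrapping, a separate
// "is this task async?" check no longer tells the run loop anything.
final class SyncToAsyncAdapter implements AsyncTask {
  private final SyncTask wrapped;

  SyncToAsyncAdapter(SyncTask wrapped) {
    this.wrapped = wrapped;
  }

  @Override
  public void processAsync(String message, Callback callback) {
    try {
      wrapped.process(message);
      callback.complete(); // completes inline, on the calling thread
    } catch (Exception e) {
      callback.failure(e);
    }
  }
}
```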







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432165558



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task0.taskName()).thenReturn(taskName1);

Review comment:
       Good catch. I should use `getMockRunLoopTask` here instead of creating the mock inline.
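
Below is a dependency-free sketch of such a helper. It assumes the helper takes a task name and its single partition together. `MiniRunLoopTask` and this `getMockRunLoopTask` signature are hypothetical; the actual test builds Mockito mocks of `RunLoopTask`. Because one call builds both stubs, the copy-paste slip in the hunk above cannot happen. In that hunk, `task0.taskName()` was stubbed to return `taskName1`.

```java
import java.util.Collections;
import java.util.Set;

// Hypothetical, trimmed-down view of RunLoopTask, for illustration only.
interface MiniRunLoopTask {
  String taskName();

  Set<String> systemStreamPartitions();
}

final class RunLoopTestHelpers {
  private RunLoopTestHelpers() {
  }

  // Creates the name and partition stubs in one place, so one task's name
  // cannot accidentally be attached to another task's stub.
  static MiniRunLoopTask getMockRunLoopTask(String taskName, String ssp) {
    Set<String> ssps = Collections.singleton(ssp);
    return new MiniRunLoopTask() {
      @Override
      public String taskName() {
        return taskName;
      }

      @Override
      public Set<String> systemStreamPartitions() {
        return ssps;
      }
    };
  }
}
```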







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432043820



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -231,14 +251,12 @@ public void testProcessMultipleTasks() throws Exception {
     when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);

Review comment:
       Done







[GitHub] [samza] cameronlee314 commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
cameronlee314 commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r430587635



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       +1 regarding adding docs to each method. Ideally, there would have already been more docs in `TaskInstance`, but it is more useful now that this is an interface.

##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -236,7 +236,7 @@ public void testProcessMultipleTasks() throws Exception {
     TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
     TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();

Review comment:
       Should `RunLoopTask` now be mocked in this test instead of using a concrete `TaskInstance`? It could help to validate your extraction of the interface and simplify the test to not depend on `TaskInstance`.

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import java.util.Set;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+
+
+/**
+ * The interface required for a task's execution to be managed within {@link RunLoop}.
+ */
+public interface RunLoopTask {
+
+  TaskName taskName();
+
+  Set<SystemStreamPartition> systemStreamPartitions();
+
+  TaskInstanceMetrics metrics();
+
+  void process(IncomingMessageEnvelope envelope, ReadableCoordinator coordinator, TaskCallbackFactory callbackFactory);
+
+  void endOfStream(ReadableCoordinator coordinator);
+
+  void window(ReadableCoordinator coordinator);
+
+  void scheduler(ReadableCoordinator coordinator);
+
+  void commit();
+
+  default boolean isWindowableTask() {
+    return false;
+  }
+
+  default boolean isAsyncTask() {
+    return false;
+  }
+
+  default EpochTimeScheduler epochTimeScheduler() {
+    return null;
+  }
+
+  default Set<String> intermediateStreams() {
+    return Collections.emptySet();
+  }
+
+  default OffsetManager offsetManager() {
+    return null;
+  }

Review comment:
       In my opinion, default implementations should be used when most of the implementors do not need to implement the methods or when you want to evolve an interface in a backwards compatible way. It doesn't sound like that is the case here, so maybe just require all implementors to implement these methods. The disadvantage of default implementations is that someone could unintentionally forget to implement them when they were supposed to implement them. Requiring implementations makes things explicit, and sometimes that is helpful.
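
Below is a minimal sketch of the failure mode described above. It uses hypothetical interfaces rather than the real `RunLoopTask`. A task that was meant to be windowable but forgot to override the default never has `window()` invoked, and nothing flags the omission.

```java
// Hypothetical interfaces contrasting the two styles discussed above.
interface TaskWithDefaults {
  // If an implementor forgets to override this, the run loop silently
  // treats the task as non-windowable.
  default boolean isWindowableTask() {
    return false;
  }

  void window();
}

interface TaskWithoutDefaults {
  // No default: every implementor must decide, or compilation fails.
  boolean isWindowableTask();

  void window();
}

// An implementor that meant to be windowable but forgot the override.
final class ForgetfulTask implements TaskWithDefaults {
  int windowCalls = 0;

  @Override
  public void window() {
    windowCalls++;
  }
}

final class MiniRunLoop {
  // Calls window() only on tasks that report themselves as windowable.
  static void tick(TaskWithDefaults task) {
    if (task.isWindowableTask()) {
      task.window();
    }
  }
}
```

Had `ForgetfulTask` implemented `TaskWithoutDefaults` instead, the missing `isWindowableTask()` would be a compile-time error rather than a silent no-op.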







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r429418245



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       Sure. I'll move the default methods to the end of the interface. What do you mean by "group"?







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r432165558



##########
File path: samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
##########
@@ -85,725 +62,484 @@
   private final TaskName taskName1 = new TaskName(p1.toString());
   private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
   private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
-  private final IncomingMessageEnvelope envelope0 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
-  private final IncomingMessageEnvelope envelope1 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
-  private final IncomingMessageEnvelope envelope3 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
+  private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
+  private final IncomingMessageEnvelope envelope10 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
+  private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
   private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
   private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
 
-  TaskInstance createTaskInstance(AsyncStreamTask task, TaskName taskName, SystemStreamPartition ssp, OffsetManager manager, SystemConsumers consumers) {
-    TaskModel taskModel = mock(TaskModel.class);
-    when(taskModel.getTaskName()).thenReturn(taskName);
-    TaskInstanceMetrics taskInstanceMetrics = new TaskInstanceMetrics("task", new MetricsRegistryMap());
-    scala.collection.immutable.Set<SystemStreamPartition> sspSet = JavaConverters.asScalaSetConverter(Collections.singleton(ssp)).asScala().toSet();
-    return new TaskInstance(task,
-        taskModel,
-        taskInstanceMetrics,
-        null,
-        consumers,
-        mock(TaskInstanceCollector.class),
-        manager,
-        null,
-        null,
-        sspSet,
-        new TaskInstanceExceptionHandler(taskInstanceMetrics, new scala.collection.immutable.HashSet<String>()),
-        null,
-        null,
-        null,
-        null,
-        mock(JobContext.class),
-        mock(ContainerContext.class),
-        Option.apply(null),
-        Option.apply(null),
-        Option.apply(null));
-  }
-
-  interface TestCode {
-    void run(TaskCallback callback);
-  }
-
-  class TestTask implements AsyncStreamTask, WindowableTask, EndOfStreamListenerTask {
-    private final boolean shutdown;
-    private final boolean commit;
-    private final boolean success;
-    private final ExecutorService callbackExecutor = Executors.newFixedThreadPool(4);
-
-    private AtomicInteger completed = new AtomicInteger(0);
-    private TestCode callbackHandler = null;
-    private TestCode commitHandler = null;
-    private TaskCoordinator.RequestScope commitRequest = null;
-    private TaskCoordinator.RequestScope shutdownRequest = TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER;
-
-    private CountDownLatch processedMessagesLatch = null;
-
-    private volatile int windowCount = 0;
-    private volatile int processed = 0;
-    private volatile int committed = 0;
-
-    private int maxMessagesInFlight;
-
-    TestTask(boolean success, boolean commit, boolean shutdown, CountDownLatch processedMessagesLatch) {
-      this.success = success;
-      this.shutdown = shutdown;
-      this.commit = commit;
-      this.processedMessagesLatch = processedMessagesLatch;
-    }
-
-    TestTask(boolean success, boolean commit, boolean shutdown,
-             CountDownLatch processedMessagesLatch, int maxMessagesInFlight) {
-      this(success, commit, shutdown, processedMessagesLatch);
-      this.maxMessagesInFlight = maxMessagesInFlight;
-    }
-
-    @Override
-    public void processAsync(IncomingMessageEnvelope envelope, MessageCollector collector, TaskCoordinator coordinator, TaskCallback callback) {
-
-      if (maxMessagesInFlight == 1) {
-        assertEquals(processed, completed.get());
-      }
-
-      processed++;
-
-      if (commit) {
-        if (commitHandler != null) {
-          callbackExecutor.submit(() -> commitHandler.run(callback));
-        }
-        if (commitRequest != null) {
-          coordinator.commit(commitRequest);
-        }
-        committed++;
-      }
-
-      if (shutdown) {
-        coordinator.shutdown(shutdownRequest);
-      }
-
-      callbackExecutor.submit(() -> {
-          if (callbackHandler != null) {
-            callbackHandler.run(callback);
-          }
-
-          completed.incrementAndGet();
-
-          if (success) {
-            callback.complete();
-          } else {
-            callback.failure(new Exception("process failure"));
-          }
-
-          if (processedMessagesLatch != null) {
-            processedMessagesLatch.countDown();
-          }
-        });
-    }
-
-    @Override
-    public void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
-      windowCount++;
-
-      if (shutdown && windowCount == 4) {
-        coordinator.shutdown(shutdownRequest);
-      }
-    }
-
-    @Override
-    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) {
-      coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-    }
-
-    void setShutdownRequest(TaskCoordinator.RequestScope shutdownRequest) {
-      this.shutdownRequest = shutdownRequest;
-    }
-
-    void setCommitRequest(TaskCoordinator.RequestScope commitRequest) {
-      this.commitRequest = commitRequest;
-    }
-  }
-
   @Rule
   public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
 
   @Test
   public void testProcessMultipleTasks() throws Exception {
-    CountDownLatch task0ProcessedMessages = new CountDownLatch(1);
-    CountDownLatch task1ProcessedMessages = new CountDownLatch(1);
     SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.pollIntervalMs()).thenReturn(10);
-    OffsetManager offsetManager = mock(OffsetManager.class);
 
-    TestTask task0 = new TestTask(true, true, false, task0ProcessedMessages);
-    TestTask task1 = new TestTask(true, false, true, task1ProcessedMessages);
-    TaskInstance t0 = createTaskInstance(task0, taskName0, ssp0, offsetManager, consumerMultiplexer);
-    TaskInstance t1 = createTaskInstance(task1, taskName1, ssp1, offsetManager, consumerMultiplexer);
+    RunLoopTask task0 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task0Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+    when(task0.metrics()).thenReturn(task0Metrics);
+    when(task0.taskName()).thenReturn(taskName0);
 
-    Map<TaskName, TaskInstance> tasks = new HashMap<>();
-    tasks.put(taskName0, t0);
-    tasks.put(taskName1, t1);
+    RunLoopTask task1 = mock(RunLoopTask.class);
+    TaskInstanceMetrics task1Metrics = new TaskInstanceMetrics("test", new MetricsRegistryMap());
+    when(task1.systemStreamPartitions()).thenReturn(Collections.singleton(ssp1));
+    when(task1.metrics()).thenReturn(task1Metrics);
+    when(task0.taskName()).thenReturn(taskName1);

Review comment:
       Good catch. I should be using `getMockRunLoopTask` here instead of creating the mock inline.
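
A helper like this configures every property of one task on a single object, which rules out copy-paste slips such as stubbing `task0.taskName()` with `taskName1`. The PR's tests use Mockito's `mock`/`when`. To stay self-contained, the sketch below uses a hand-written stub instead. `MiniRunLoopTask`, `StubTasks`, and `stubRunLoopTask` are hypothetical names, not the actual `getMockRunLoopTask` signature, and plain `String`s stand in for Samza's `TaskName` and `SystemStreamPartition`.

```java
import java.util.Collections;
import java.util.Set;

// Hypothetical, simplified stand-in for RunLoopTask: String-based names instead of
// Samza's TaskName and SystemStreamPartition, so the sketch compiles on its own.
interface MiniRunLoopTask {
  String taskName();

  Set<String> systemStreamPartitions();
}

// Hypothetical test helper in the spirit of getMockRunLoopTask: the task's name and
// partitions are bound to the same object in one call, so "task0" cannot end up
// with task1's name.
final class StubTasks {
  private StubTasks() { }

  static MiniRunLoopTask stubRunLoopTask(String taskName, String ssp) {
    Set<String> ssps = Collections.singleton(ssp);
    return new MiniRunLoopTask() {
      @Override
      public String taskName() {
        return taskName;
      }

      @Override
      public Set<String> systemStreamPartitions() {
        return ssps;
      }
    };
  }
}
```

Usage (hypothetical names): `StubTasks.stubRunLoopTask("task1", "ssp1")` returns a task whose name and partition set both belong to task1.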







[GitHub] [samza] bkonold commented on a change in pull request #1366: SAMZA-2529: Extract interface from TaskInstance for reuse of RunLoop

Posted by GitBox <gi...@apache.org>.
bkonold commented on a change in pull request #1366:
URL: https://github.com/apache/samza/pull/1366#discussion_r430684873



##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import java.util.Set;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+
+
+/**
+ * The interface required for a task's execution to be managed within {@link RunLoop}.
+ */
+public interface RunLoopTask {
+
+  TaskName taskName();
+
+  Set<SystemStreamPartition> systemStreamPartitions();
+
+  TaskInstanceMetrics metrics();
+
+  void process(IncomingMessageEnvelope envelope, ReadableCoordinator coordinator, TaskCallbackFactory callbackFactory);
+
+  void endOfStream(ReadableCoordinator coordinator);
+
+  void window(ReadableCoordinator coordinator);
+
+  void scheduler(ReadableCoordinator coordinator);
+
+  void commit();
+
+  default boolean isWindowableTask() {
+    return false;
+  }
+
+  default boolean isAsyncTask() {
+    return false;
+  }
+
+  default EpochTimeScheduler epochTimeScheduler() {
+    return null;
+  }
+
+  default Set<String> intermediateStreams() {
+    return Collections.emptySet();
+  }
+
+  default OffsetManager offsetManager() {
+    return null;
+  }

Review comment:
       I agree; in retrospect, the defaults were added for convenience rather than out of necessity. I'll remove the default modifiers.
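
Here is a trimmed-down sketch of what dropping the default bodies means for implementers. `CapabilityAwareTask` and `PlainSyncTask` are hypothetical, and plain `String`s stand in for Samza's stream types. Each implementation must now state its capabilities explicitly instead of silently inheriting `false` or an empty set.

```java
import java.util.Collections;
import java.util.Set;

// Hypothetical, trimmed-down RunLoopTask with the default bodies removed:
// each capability query is abstract, so every implementation must answer it.
interface CapabilityAwareTask {
  /** Whether the underlying task supports windowing. */
  boolean isWindowableTask();

  /** Whether the underlying task is asynchronous (completes processing via callbacks). */
  boolean isAsyncTask();

  /** Intermediate streams consumed by this task; empty if there are none. */
  Set<String> intermediateStreams();
}

// A hypothetical synchronous, non-windowed task. Without defaults, the compiler
// forces this class to spell out each answer.
final class PlainSyncTask implements CapabilityAwareTask {
  @Override
  public boolean isWindowableTask() {
    return false;
  }

  @Override
  public boolean isAsyncTask() {
    return false;
  }

  @Override
  public Set<String> intermediateStreams() {
    return Collections.emptySet();
  }
}
```

A new implementation that forgets one of these methods now fails to compile rather than quietly picking up a default that may be wrong for it.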

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       I've updated the interface with better class-level and method-level docs.

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       Regarding life cycle: RunLoop currently does not manage the life cycle of the tasks it executes. That is done at the scope of SamzaContainer (i.e. the entity that creates and runs the RunLoop).
   
   Since this interface targets the relationship between RunLoop and the tasks it executes, I don't think life cycle management belongs here.
   
   What do you think @mynameborat @cameronlee314 ?
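
The sketch below shows this split under simplifying assumptions. `LifecycleTask`, `MiniRunLoop`, and `MiniContainer` are hypothetical stand-ins for the task, RunLoop, and SamzaContainer. The loop also hands every message to every task instead of routing by partition. It only drives processing, and the container brackets it with `init()` and `close()`.

```java
import java.util.List;

// Hypothetical task: lifecycle methods exist on the task, but are not part of
// what the loop calls.
final class LifecycleTask {
  private final String name;
  private final List<String> events;

  LifecycleTask(String name, List<String> events) {
    this.name = name;
    this.events = events;
  }

  void init() {
    events.add("init " + name);
  }

  void process(String message) {
    events.add("process " + name + ":" + message);
  }

  void close() {
    events.add("close " + name);
  }
}

// Hypothetical loop: drives processing only and never calls init() or close().
final class MiniRunLoop {
  private final List<LifecycleTask> tasks;

  MiniRunLoop(List<LifecycleTask> tasks) {
    this.tasks = tasks;
  }

  // Simplification: every task receives every message; no partition routing.
  void run(List<String> messages) {
    for (String message : messages) {
      for (LifecycleTask task : tasks) {
        task.process(message);
      }
    }
  }
}

// Hypothetical container: the entity that creates and runs the loop also owns the
// tasks' life cycle, bracketing run() with init() and close().
final class MiniContainer {
  private MiniContainer() { }

  static void runContainer(List<LifecycleTask> tasks, List<String> messages) {
    tasks.forEach(LifecycleTask::init);
    try {
      new MiniRunLoop(tasks).run(messages);
    } finally {
      tasks.forEach(LifecycleTask::close);
    }
  }
}
```

Because `init()`/`close()` stay out of the loop's contract, the loop remains agnostic to how its tasks are created and torn down.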

##########
File path: samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
##########
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.container;
+
+import java.util.Collections;
+import org.apache.samza.checkpoint.OffsetManager;
+import org.apache.samza.scheduler.EpochTimeScheduler;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.task.ReadableCoordinator;
+import org.apache.samza.task.TaskCallbackFactory;
+import scala.collection.JavaConversions;
+
+
+public interface RunLoopTask {

Review comment:
       I discussed this with @mynameborat, and we think unified lifecycle management will eventually become necessary once side inputs are moved back to using RunLoop (initially through a separate instance). To consolidate to a single RunLoop instance within the container, we'd want a unified way to init/close a RunLoopTask, whether it is a TaskInstance or a "NotYetImplementedSideInputTask".
   
   @cameronlee314 Feel free to offer your thoughts, but for now I think we can consider the life cycle issue out of scope for this PR.
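
The sketch below shows one possible shape for the deferred unified lifecycle, under explicit assumptions. `ManagedRunLoopTask`, `InstanceBackedTask`, `SideInputTaskStub`, and `LifecycleOwner` are hypothetical names, since the side-input task does not exist yet. Closing in reverse initialization order is a choice of this sketch, not established Samza behavior. A single owner initializes and closes both kinds of task through the same two calls.

```java
import java.util.List;

// Hypothetical RunLoop-facing contract extended with lifecycle hooks.
interface ManagedRunLoopTask {
  String taskName();

  void init();

  void close();
}

// Shared recording logic so both hypothetical task kinds stay short.
abstract class RecordingTask implements ManagedRunLoopTask {
  private final String name;
  private final List<String> log;

  RecordingTask(String name, List<String> log) {
    this.name = name;
    this.log = log;
  }

  @Override
  public String taskName() {
    return name;
  }

  @Override
  public void init() {
    log.add("init " + name);
  }

  @Override
  public void close() {
    log.add("close " + name);
  }
}

// Hypothetical stand-in for a TaskInstance-backed task.
final class InstanceBackedTask extends RecordingTask {
  InstanceBackedTask(String name, List<String> log) {
    super(name, log);
  }
}

// Hypothetical stand-in for the not-yet-implemented side-input task.
final class SideInputTaskStub extends RecordingTask {
  SideInputTaskStub(String name, List<String> log) {
    super(name, log);
  }
}

// Hypothetical single owner that manages every task's lifecycle the same way.
final class LifecycleOwner {
  private LifecycleOwner() { }

  static void initAll(List<? extends ManagedRunLoopTask> tasks) {
    for (ManagedRunLoopTask task : tasks) {
      task.init();
    }
  }

  // Closes in reverse initialization order (a design choice of this sketch).
  static void closeAll(List<? extends ManagedRunLoopTask> tasks) {
    for (int i = tasks.size() - 1; i >= 0; i--) {
      tasks.get(i).close();
    }
  }
}
```

The owner never branches on the concrete task type, which is what would let one RunLoop instance host both regular and side-input tasks.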



