You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@samza.apache.org by bh...@apache.org on 2022/11/10 06:27:53 UTC

[samza] branch master updated: SAMZA-2765: [Pipeline Drain] Adding config for task callback timeout during drain (#1637)

This is an automated email from the ASF dual-hosted git repository.

bharathkk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a1652f9a SAMZA-2765: [Pipeline Drain] Adding config for task callback timeout during drain (#1637)
6a1652f9a is described below

commit 6a1652f9a5d0e675df434d41189b9cc60793a8de
Author: ajo thomas <aj...@linkedin.com>
AuthorDate: Wed Nov 9 22:27:47 2022 -0800

    SAMZA-2765: [Pipeline Drain] Adding config for task callback timeout during drain (#1637)
    
    Adds a new config, task.callback.drain.timeout.ms, to TaskConfig.
    Adds logic so that, while draining, task workers override TaskCallbackManager's callback timeout for drain and watermark messages.
---
 .../samza/system/IncomingMessageEnvelope.java      |  4 +++
 .../java/org/apache/samza/config/TaskConfig.java   | 11 ++++++
 .../java/org/apache/samza/container/RunLoop.java   | 11 ++++--
 .../org/apache/samza/container/RunLoopFactory.java |  4 +++
 .../org/apache/samza/task/TaskCallbackManager.java | 30 ++++++++++++++++-
 .../samza/storage/ContainerStorageManager.java     |  1 +
 .../org/apache/samza/config/TestTaskConfig.java    |  9 +++++
 .../org/apache/samza/container/TestRunLoop.java    | 39 ++++++++++++----------
 .../apache/samza/task/TestTaskCallbackManager.java | 17 ++++++++--
 9 files changed, 103 insertions(+), 23 deletions(-)

diff --git a/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java b/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
index 39081703c..69d85d6ee 100644
--- a/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
+++ b/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
@@ -163,6 +163,10 @@ public class IncomingMessageEnvelope {
     return message != null && DrainMessage.class.isAssignableFrom(message.getClass());
   }
 
+  public boolean isWatermark() {
+    return message instanceof WatermarkMessage; // instanceof already yields false for a null message
+  }
+
   /**
    * This method is deprecated in favor of WatermarkManager.buildEndOfStreamEnvelope(SystemStreamPartition ssp).
    *
diff --git a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
index a2b96ebdf..4d7847a91 100644
--- a/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/TaskConfig.java
@@ -90,6 +90,13 @@ public class TaskConfig extends MapConfig {
   // timeout for triggering a callback
   public static final String CALLBACK_TIMEOUT_MS = "task.callback.timeout.ms";
   static final long DEFAULT_CALLBACK_TIMEOUT_MS = -1L;
+
+  // timeout (ms) for task callbacks of drain and watermark messages processed while the run loop is draining
+  public static final String DRAIN_CALLBACK_TIMEOUT_MS = "task.callback.drain.timeout.ms";
+
+  // default for DRAIN_CALLBACK_TIMEOUT_MS; same -1 value as DEFAULT_CALLBACK_TIMEOUT_MS
+  static final long DEFAULT_DRAIN_CALLBACK_TIMEOUT_MS = -1L;
+
   // enable async commit
   public static final String ASYNC_COMMIT = "task.async.commit";
   // maximum time to wait for a task worker to complete when there are no new messages to handle
@@ -224,6 +231,10 @@ public class TaskConfig extends MapConfig {
     return getLong(CALLBACK_TIMEOUT_MS, DEFAULT_CALLBACK_TIMEOUT_MS);
   }
 
+  public long getDrainCallbackTimeoutMs() { // task.callback.drain.timeout.ms (ms); DEFAULT_DRAIN_CALLBACK_TIMEOUT_MS if unset
+    return getLong(DRAIN_CALLBACK_TIMEOUT_MS, DEFAULT_DRAIN_CALLBACK_TIMEOUT_MS);
+  }
+
   public boolean getAsyncCommit() {
     return getBoolean(ASYNC_COMMIT, false);
   }
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
index ac89028e3..c075c44cf 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
@@ -78,6 +78,7 @@ public class RunLoop implements Runnable, Throttleable {
   private final long windowMs;
   private final long commitMs;
   private final long callbackTimeoutMs;
+  private final long drainCallbackTimeoutMs;
   private final long maxIdleMs;
   private final SamzaContainerMetrics containerMetrics;
   private final ScheduledExecutorService workerTimer;
@@ -90,7 +91,6 @@ public class RunLoop implements Runnable, Throttleable {
   private volatile boolean runLoopResumedSinceLastChecked;
   private final int elasticityFactor;
   private final String runId;
-
   private final boolean isHighLevelApiJob;
   private boolean isDraining = false;
 
@@ -101,13 +101,14 @@ public class RunLoop implements Runnable, Throttleable {
       long windowMs,
       long commitMs,
       long callbackTimeoutMs,
+      long drainCallbackTimeoutMs,
       long maxThrottlingDelayMs,
       long maxIdleMs,
       SamzaContainerMetrics containerMetrics,
       HighResolutionClock clock,
       boolean isAsyncCommitEnabled) {
     this(runLoopTasks, threadPool, consumerMultiplexer, maxConcurrency, windowMs, commitMs, callbackTimeoutMs,
-        maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1, null, false);
+        drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1, null, false);
   }
 
   public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
@@ -117,6 +118,7 @@ public class RunLoop implements Runnable, Throttleable {
       long windowMs,
       long commitMs,
       long callbackTimeoutMs,
+      long drainCallbackTimeoutMs,
       long maxThrottlingDelayMs,
       long maxIdleMs,
       SamzaContainerMetrics containerMetrics,
@@ -133,6 +135,7 @@ public class RunLoop implements Runnable, Throttleable {
     this.commitMs = commitMs;
     this.maxConcurrency = maxConcurrency;
     this.callbackTimeoutMs = callbackTimeoutMs;
+    this.drainCallbackTimeoutMs = drainCallbackTimeoutMs;
     this.maxIdleMs = maxIdleMs;
     this.callbackTimer = (callbackTimeoutMs > 0) ? Executors.newSingleThreadScheduledExecutor() : null;
     this.callbackExecutor = new ThrottlingScheduler(maxThrottlingDelayMs);
@@ -624,7 +627,11 @@ public class RunLoop implements Runnable, Throttleable {
         public TaskCallback createCallback() {
           state.startProcess();
           containerMetrics.processes().inc();
-          return callbackManager.createCallback(task.taskName(), envelope, coordinator);
+          // Override the timeout only when a positive drain timeout is configured: the -1 default would otherwise be
+          // scheduled on the callback timer as a non-positive delay, i.e. time out the drain callback immediately.
+          return isDraining && drainCallbackTimeoutMs > 0 && (envelope.isDrain() || envelope.isWatermark())
+              ? callbackManager.createCallbackForDrain(task.taskName(), envelope, coordinator, drainCallbackTimeoutMs)
+              : callbackManager.createCallback(task.taskName(), envelope, coordinator);
         }
       };
 
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
index 07fa88a70..8dac7a1ab 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
@@ -62,6 +62,9 @@ public class RunLoopFactory {
     long callbackTimeout = taskConfig.getCallbackTimeoutMs();
     log.info("Got callbackTimeout: {}.", callbackTimeout);
 
+    long drainCallbackTimeout = taskConfig.getDrainCallbackTimeoutMs();
+    log.info("Got callback timeout for drain: {}.", drainCallbackTimeout);
+
     long maxIdleMs = taskConfig.getMaxIdleMs();
     log.info("Got maxIdleMs: {}.", maxIdleMs);
 
@@ -85,6 +88,7 @@ public class RunLoopFactory {
       taskWindowMs,
       taskCommitMs,
       callbackTimeout,
+      drainCallbackTimeout,
       maxThrottlingDelayMs,
       maxIdleMs,
       containerMetrics,
diff --git a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java
index 2d49de740..d435615d2 100644
--- a/samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java
+++ b/samza-core/src/main/java/org/apache/samza/task/TaskCallbackManager.java
@@ -92,9 +92,36 @@ public class TaskCallbackManager {
     this.clock = clock;
   }
 
+  /**
+   * Creates a task callback whose timeout is this manager's default {@code timeout}.
+   * @param taskName task name
+   * @param envelope incoming envelope
+   * @param coordinator coordinator
+   */
   public TaskCallbackImpl createCallback(TaskName taskName,
       IncomingMessageEnvelope envelope,
       ReadableCoordinator coordinator) {
+    return createCallback(taskName, envelope, coordinator, timeout);
+  }
+
+  /**
+   * Creates a task callback that uses {@code drainTimeout} instead of this manager's default {@code timeout}.
+   * @param taskName task name
+   * @param envelope incoming envelope
+   * @param coordinator coordinator
+   * @param drainTimeout timeout in milliseconds; only scheduled when this manager has a non-null timer.
+   */
+  public TaskCallbackImpl createCallbackForDrain(TaskName taskName,
+      IncomingMessageEnvelope envelope,
+      ReadableCoordinator coordinator,
+      long drainTimeout) {
+    return createCallback(taskName, envelope, coordinator, drainTimeout);
+  }
+
+  private TaskCallbackImpl createCallback(TaskName taskName,
+      IncomingMessageEnvelope envelope,
+      ReadableCoordinator coordinator,
+      long callbackTimeout) {
     final TaskCallbackImpl callback =
         new TaskCallbackImpl(listener, taskName, envelope, coordinator, seqNum++, clock.nanoTime());
     if (timer != null) {
@@ -106,7 +133,8 @@ public class TaskCallbackManager {
           callback.failure(new SamzaException(msg));
         }
       };
-      ScheduledFuture scheduledFuture = timer.schedule(timerTask, timeout, TimeUnit.MILLISECONDS);
+
+      final ScheduledFuture scheduledFuture = timer.schedule(timerTask, callbackTimeout, TimeUnit.MILLISECONDS);
       callback.setScheduledFuture(scheduledFuture);
     }
 
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index 6c70fe308..326908535 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -937,6 +937,7 @@ public class ContainerStorageManager {
         -1, // no windowing
         taskConfig.getCommitMs(),
         taskConfig.getCallbackTimeoutMs(),
+        taskConfig.getDrainCallbackTimeoutMs(),
         // TODO consolidate these container configs SAMZA-2275
         this.config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1)),
         taskConfig.getMaxIdleMs(),
diff --git a/samza-core/src/test/java/org/apache/samza/config/TestTaskConfig.java b/samza-core/src/test/java/org/apache/samza/config/TestTaskConfig.java
index f6df6f7c7..d2fce7e26 100644
--- a/samza-core/src/test/java/org/apache/samza/config/TestTaskConfig.java
+++ b/samza-core/src/test/java/org/apache/samza/config/TestTaskConfig.java
@@ -204,6 +204,15 @@ public class TestTaskConfig {
     assertEquals(TaskConfig.DEFAULT_CALLBACK_TIMEOUT_MS, new TaskConfig(new MapConfig()).getCallbackTimeoutMs());
   }
 
+  @Test
+  public void testGetDrainCallbackTimeoutMs() {
+    // an explicitly configured value is returned as-is
+    TaskConfig configured = new TaskConfig(new MapConfig(ImmutableMap.of(TaskConfig.DRAIN_CALLBACK_TIMEOUT_MS, "100")));
+    assertEquals(100L, configured.getDrainCallbackTimeoutMs());
+    // an absent key falls back to the default
+    assertEquals(TaskConfig.DEFAULT_DRAIN_CALLBACK_TIMEOUT_MS, new TaskConfig(new MapConfig()).getDrainCallbackTimeoutMs());
+  }
+
   @Test
   public void testGetAsyncCommit() {
     Config config = new MapConfig(ImmutableMap.of(TaskConfig.ASYNC_COMMIT, "true"));
diff --git a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
index faa0152d5..73f71d73d 100644
--- a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
@@ -56,6 +56,8 @@ public class TestRunLoop {
   private final long windowMs = -1;
   private final long commitMs = -1;
   private final long callbackTimeoutMs = 0;
+
+  private final long drainCallbackTimeoutMs = 0;
   private final long maxThrottlingDelayMs = 0;
   private final long maxIdleMs = 10;
   private final Partition p0 = new Partition(0);
@@ -96,7 +98,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false, 1, "foo", false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false, 1, "foo", false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(sspA0EndOfStream).thenReturn(
         sspA1EndOfStream).thenReturn(null);
     runLoop.run();
@@ -117,7 +119,7 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     runLoop.run();
 
     InOrder inOrder = inOrder(task0);
@@ -164,7 +166,7 @@ public class TestRunLoop {
     tasks.put(taskName0, task0);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(null);
     runLoop.run();
 
@@ -215,7 +217,8 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L,
+        false, 2, null, false);
     runLoop.run();
 
     verify(task0).process(eq(envelope00), any(), any());
@@ -241,7 +244,8 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L,
+        false, 1, runId, false);
     runLoop.run();
 
     // check if process was called once for each task
@@ -271,7 +275,8 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L,
+        false, 1, runId, false);
     runLoop.run();
 
     // check if process was called twice for each task
@@ -306,7 +311,7 @@ public class TestRunLoop {
     tasks.put(taskName0, task);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(null);
     runLoop.run();
 
@@ -338,7 +343,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
 
@@ -376,7 +381,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     //have a null message in between to make sure task0 finishes processing and invoke the commit
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
 
@@ -421,7 +426,7 @@ public class TestRunLoop {
     tasks.put(taskName1, task1);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     // consensus is reached after envelope1 is processed.
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
     runLoop.run();
@@ -447,7 +452,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     when(consumerMultiplexer.choose(false))
       .thenReturn(envelopeA00)
       .thenReturn(envelopeA11)
@@ -514,7 +519,7 @@ public class TestRunLoop {
     tasks.put(taskName0, task0);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0EndOfStream)
         .thenAnswer(invocation -> {
           // this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
@@ -569,7 +574,7 @@ public class TestRunLoop {
     tasks.put(taskName0, task0);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
         1, runId, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0Drain)
         .thenAnswer(invocation -> {
@@ -602,7 +607,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-                                            callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(sspA0EndOfStream).thenReturn(null);
 
     runLoop.run();
@@ -624,7 +629,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
         1, runId, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(sspA0Drain).thenReturn(null);
 
@@ -690,7 +695,7 @@ public class TestRunLoop {
     tasks.put(taskName0, task0);
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(null);
     runLoop.run();
 
@@ -715,7 +720,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
+        callbackTimeoutMs, drainCallbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
 
     when(consumerMultiplexer.choose(false))
         .thenReturn(envelopeA00)
diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
index 405157ad6..de418c014 100644
--- a/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
+++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskCallbackManager.java
@@ -26,6 +26,7 @@ import org.apache.samza.container.TaskName;
 import org.apache.samza.metrics.MetricsRegistryMap;
 import org.apache.samza.system.IncomingMessageEnvelope;
 import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.HighResolutionClock;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -40,7 +41,10 @@ public class TestTaskCallbackManager {
 
   @Before
   public void setup() {
-    TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap(), "");
+    final TaskInstanceMetrics metrics = new TaskInstanceMetrics("Partition 0", new MetricsRegistryMap(), "");
+    final long timeout = -1L;
+    final int maxConcurrency = 2;
+    final HighResolutionClock highResolutionClock = System::nanoTime;
     listener = new TaskCallbackListener() {
       @Override
       public void onComplete(TaskCallback callback) {
@@ -49,7 +53,7 @@ public class TestTaskCallbackManager {
       public void onFailure(TaskCallback callback, Throwable t) {
       }
     };
-    callbackManager = new TaskCallbackManager(listener, null, -1, 2, () -> System.nanoTime());
+    callbackManager = new TaskCallbackManager(listener, null, timeout, maxConcurrency, highResolutionClock);
   }
 
   @Test
@@ -61,6 +65,15 @@ public class TestTaskCallbackManager {
     assertTrue(callback.matchSeqNum(1));
   }
 
+  @Test
+  public void testCreateDrainCallback() {
+    final TaskName partitionZero = new TaskName("Partition 0");
+    TaskCallbackImpl first = callbackManager.createCallbackForDrain(partitionZero, mock(IncomingMessageEnvelope.class), null, -1);
+    assertTrue(first.matchSeqNum(0));
+    TaskCallbackImpl second = callbackManager.createCallbackForDrain(partitionZero, mock(IncomingMessageEnvelope.class), null, -1);
+    assertTrue(second.matchSeqNum(1));
+  }
+
   @Test
   public void testUpdateCallbackInOrder() {
     TaskName taskName = new TaskName("Partition 0");