You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text copy.
Posted to commits@samza.apache.org by xi...@apache.org on 2022/11/03 20:31:30 UTC

[samza] branch master updated: SAMZA-2741: Pipeline Drain- Fix processing of Drain messages for both low-level and high level API (#1639)

This is an automated email from the ASF dual-hosted git repository.

xinyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d4747c5c SAMZA-2741: Pipeline Drain- Fix processing of Drain messages for both low-level and high level API (#1639)
3d4747c5c is described below

commit 3d4747c5c2cfc2aadff42bca562f1fbfb4acad0b
Author: ajo thomas <aj...@linkedin.com>
AuthorDate: Thu Nov 3 13:31:24 2022 -0700

    SAMZA-2741: Pipeline Drain- Fix processing of Drain messages for both low-level and high level API (#1639)
---
 .../apache/samza/application/ApplicationUtil.java  |  28 ++-
 .../java/org/apache/samza/container/RunLoop.java   |  68 +++----
 .../org/apache/samza/container/RunLoopFactory.java |   6 +-
 .../samza/operators/impl/ControlMessageSender.java |   2 +
 .../apache/samza/operators/impl/OperatorImpl.java  |  28 ++-
 .../apache/samza/container/SamzaContainer.scala    |   6 +-
 .../org/apache/samza/container/TaskInstance.scala  |   3 +-
 .../samza/storage/ContainerStorageManager.java     |   7 +-
 .../org/apache/samza/system/SystemConsumers.scala  |   9 +-
 .../org/apache/samza/container/TestRunLoop.java    |  44 +---
 .../drain/DrainHighLevelApiIntegrationTest.java    | 222 +++++++++++++++++++--
 .../drain/DrainLowLevelApiIntegrationTest.java     |  91 +++++++--
 12 files changed, 377 insertions(+), 137 deletions(-)

diff --git a/samza-core/src/main/java/org/apache/samza/application/ApplicationUtil.java b/samza-core/src/main/java/org/apache/samza/application/ApplicationUtil.java
index edcb3aea8..cf52a2568 100644
--- a/samza-core/src/main/java/org/apache/samza/application/ApplicationUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/application/ApplicationUtil.java
@@ -18,18 +18,22 @@
  */
 package org.apache.samza.application;
 
+import com.google.common.base.Strings;
 import java.util.Optional;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.config.ApplicationConfig;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.ConfigException;
 import org.apache.samza.config.TaskConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
  * Util class to create {@link SamzaApplication} from the configuration.
  */
 public class ApplicationUtil {
+  private static final Logger LOG = LoggerFactory.getLogger(ApplicationUtil.class);
 
   /**
    * Creates the {@link SamzaApplication} object from the task or application class name specified in {@code config}
@@ -38,6 +42,7 @@ public class ApplicationUtil {
    * @return the {@link SamzaApplication} object
    */
   public static SamzaApplication fromConfig(Config config) {
+
     String appClassName = new ApplicationConfig(config).getAppClass();
     if (StringUtils.isNotBlank(appClassName)) {
       // app.class is configured
@@ -60,4 +65,25 @@ public class ApplicationUtil {
     }
     return new LegacyTaskApplication(taskClassOption.get());
   }
-}
\ No newline at end of file
+
+  /**
+   * Determines if the job is a Samza high-level job.
+   * High-level job implements StreamApplication.
+   * @param config config
+   * */
+  public static boolean isHighLevelApiJob(Config config) {
+    final ApplicationConfig applicationConfig = new ApplicationConfig(config);
+    final String appClass = applicationConfig.getAppClass();
+    if (!Strings.isNullOrEmpty(appClass)) {
+      try {
+        return StreamApplication.class.isAssignableFrom(Class.forName(appClass));
+      } catch (ClassNotFoundException e) {
+        LOG.debug("Error while determining if the job is a high level API job", e);
+        return false;
+      }
+    } else {
+      LOG.warn("Config {} is empty or null. Cannot determine if the job is a high-level API job", ApplicationConfig.APP_CLASS);
+      return false;
+    }
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
index 3334adc74..ac89028e3 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
@@ -27,7 +27,6 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -91,6 +90,8 @@ public class RunLoop implements Runnable, Throttleable {
   private volatile boolean runLoopResumedSinceLastChecked;
   private final int elasticityFactor;
   private final String runId;
+
+  private final boolean isHighLevelApiJob;
   private boolean isDraining = false;
 
   public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
@@ -106,7 +107,7 @@ public class RunLoop implements Runnable, Throttleable {
       HighResolutionClock clock,
       boolean isAsyncCommitEnabled) {
     this(runLoopTasks, threadPool, consumerMultiplexer, maxConcurrency, windowMs, commitMs, callbackTimeoutMs,
-        maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1, null);
+        maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1, null, false);
   }
 
   public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
@@ -122,7 +123,8 @@ public class RunLoop implements Runnable, Throttleable {
       HighResolutionClock clock,
       boolean isAsyncCommitEnabled,
       int elasticityFactor,
-      String runId) {
+      String runId,
+      boolean isHighLevelApiJob) {
 
     this.threadPool = threadPool;
     this.consumerMultiplexer = consumerMultiplexer;
@@ -141,23 +143,26 @@ public class RunLoop implements Runnable, Throttleable {
     // assign runId before creating workers. As the inner AsyncTaskWorker class is not static, it relies on
     // the outer class fields to be init first
     this.runId = runId;
+    this.isHighLevelApiJob = isHighLevelApiJob;
+    this.isAsyncCommitEnabled = isAsyncCommitEnabled;
+    this.elasticityFactor = elasticityFactor;
+
     Map<TaskName, AsyncTaskWorker>  workers = new HashMap<>();
     for (RunLoopTask task : runLoopTasks.values()) {
       workers.put(task.taskName(), new AsyncTaskWorker(task));
     }
     // Partions and tasks assigned to the container will not change during the run loop life time
     this.sspToTaskWorkerMapping = Collections.unmodifiableMap(getSspToAsyncTaskWorkerMap(runLoopTasks, workers));
-
     this.taskWorkers = Collections.unmodifiableList(new ArrayList<>(workers.values()));
-    this.isAsyncCommitEnabled = isAsyncCommitEnabled;
-    this.elasticityFactor = elasticityFactor;
   }
 
   /**
    * Sets the RunLoop to drain mode.
    * */
   private void drain() {
+    log.info("Setting the RunLoop to drain mode.");
     isDraining = true;
+    log.debug("Disabling async commit when the pipeline is draining.");
     isAsyncCommitEnabled = false;
   }
 
@@ -480,7 +485,7 @@ public class RunLoop implements Runnable, Throttleable {
       this.task = task;
       this.callbackManager = new TaskCallbackManager(this, callbackTimer, callbackTimeoutMs, maxConcurrency, clock);
       Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task);
-      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty(),
+      this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty(), isHighLevelApiJob,
           runId);
     }
 
@@ -841,16 +846,18 @@ public class RunLoop implements Runnable, Throttleable {
     private final TaskName taskName;
     private final TaskInstanceMetrics taskMetrics;
     private final boolean hasIntermediateStreams;
+    private final boolean isHighLevelApiJob;
     private final String runId;
 
     AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet,
-        boolean hasIntermediateStreams, String runId) {
+        boolean hasIntermediateStreams, boolean isHighLevelApiJob, String runId) {
       this.taskName = taskName;
       this.taskMetrics = taskMetrics;
       this.pendingEnvelopeQueue = new ArrayDeque<>();
       this.processingSspSet = sspSet;
       this.processingSspSetToDrain = new HashSet<>(sspSet);
       this.hasIntermediateStreams = hasIntermediateStreams;
+      this.isHighLevelApiJob = isHighLevelApiJob;
       this.runId = runId;
     }
 
@@ -875,50 +882,39 @@ public class RunLoop implements Runnable, Throttleable {
         return false;
       }
 
-      if (!pendingEnvelopeQueue.isEmpty()) {
-        PendingEnvelope pendingEnvelope = pendingEnvelopeQueue.peek();
-        IncomingMessageEnvelope envelope = pendingEnvelope.envelope;
+      if (pendingEnvelopeQueue.size() > 0) {
+        final PendingEnvelope pendingEnvelope = pendingEnvelopeQueue.peek();
+        final IncomingMessageEnvelope envelope = pendingEnvelope.envelope;
 
         if (envelope.isDrain()) {
           final DrainMessage message = (DrainMessage) envelope.getMessage();
           if (!message.getRunId().equals(runId)) {
-            // Removing the drain message from the pending queue as it doesn't match with the current runId
-            // Removing it will ensure that it is not picked up by process()
-            pendingEnvelopeQueue.remove();
+            // Removing the drain message from the pending queue as it doesn't match with the current deployment
+            final PendingEnvelope discardedDrainMessage = pendingEnvelopeQueue.remove();
+            consumerMultiplexer.tryUpdate(discardedDrainMessage.envelope.getSystemStreamPartition());
           } else {
+            // Found drain message matching the current deployment
+
             // set the RunLoop to drain mode
             if (!isDraining) {
               drain();
             }
 
-            if (elasticityFactor <= 1) {
-              SystemStreamPartition ssp = envelope.getSystemStreamPartition();
-              processingSspSetToDrain.remove(ssp);
-            } else {
-              // SystemConsumers will write only one envelope (enclosing DrainMessage) per SSP in its buffer.
-              // This envelope doesn't have keybucket info it's SSP. With elasticity, the same SSP can be processed by
-              // multiple tasks. Therefore, if envelope contains drain message, the ssp of envelope should be removed
-              // from task's processing set irrespective of keyBucket.
-              SystemStreamPartition sspOfEnvelope = envelope.getSystemStreamPartition();
-              Optional<SystemStreamPartition> ssp = processingSspSetToDrain.stream()
-                  .filter(sspInSet -> sspInSet.getSystemStream().equals(sspOfEnvelope.getSystemStream())
-                      && sspInSet.getPartition().equals(sspOfEnvelope.getPartition()))
-                  .findFirst();
-              ssp.ifPresent(processingSspSetToDrain::remove);
-            }
+            if (!isHighLevelApiJob) {
+              // The flow below only applies to samza low-level API
+
+              // For high-level API, we do not remove the message from pending queue.
+              // It will be picked by the process flow instead of drain flow, as we want the drain control message
+              // to be processed by the High-level API Operator DAG.
 
-            if (!hasIntermediateStreams) {
-              // Don't remove from the pending queue as we want the DAG to pick up Drain message and propagate it to
-              // intermediate streams
+              processingSspSetToDrain.remove(envelope.getSystemStreamPartition());
               pendingEnvelopeQueue.remove();
             }
           }
         }
-        return processingSspSetToDrain.isEmpty();
       }
-      // if no messages are in the queue, the task has probably already drained or there are no messages from
-      // the chooser
-      return false;
+
+      return processingSspSetToDrain.isEmpty();
     }
 
     /**
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
index cb021d098..54c5bc5a9 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
@@ -42,7 +42,8 @@ public class RunLoopFactory {
       TaskConfig taskConfig,
       HighResolutionClock clock,
       int elasticityFactor,
-      String runId) {
+      String runId,
+      boolean isHighLevelApiJob) {
 
     long taskWindowMs = taskConfig.getWindowMs();
 
@@ -84,6 +85,7 @@ public class RunLoopFactory {
       clock,
       isAsyncCommitEnabled,
       elasticityFactor,
-      runId);
+      runId,
+      isHighLevelApiJob);
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
index 779644d3c..fdfd28774 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/ControlMessageSender.java
@@ -67,6 +67,8 @@ class ControlMessageSender {
     int currentPartition = ssp.getPartition().getPartitionId();
     for (int i = 0; i < partitionCount; i++) {
       if (i != currentPartition) {
+        LOG.debug(String.format("Broadcast %s message from task %s to %s partition %d for aggregation",
+            MessageType.of(message).name(), message.getTaskName(), systemStream, i));
         OutgoingMessageEnvelope envelopeOut = new OutgoingMessageEnvelope(systemStream, i, null, message);
         collector.send(envelopeOut);
       }
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index 5723d912b..61f32bcd7 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -129,6 +129,7 @@ public abstract class OperatorImpl<M, RM> {
     this.eosStates = (EndOfStreamStates) internalTaskContext.fetchObject(EndOfStreamStates.class.getName());
     this.watermarkStates = (WatermarkStates) internalTaskContext.fetchObject(WatermarkStates.class.getName());
     this.drainStates = (DrainStates) internalTaskContext.fetchObject(DrainStates.class.getName());
+
     this.controlMessageSender = new ControlMessageSender(internalTaskContext.getStreamMetadataCache());
     this.taskModel = taskContext.getTaskModel();
     this.callbackScheduler = taskContext.getCallbackScheduler();
@@ -366,13 +367,13 @@ public abstract class OperatorImpl<M, RM> {
   }
 
   /**
-   * This method is implemented when all input stream to this operation have encountered drain-and-stop control message.
-   * Inherited class should handle drain-and-stop by overriding this function.
-   * By default noop implementation is for in-memory operator to handle the drain-and-stop. Output operator need to
-   * override this to actually propagate drain-and-stop over the wire.
+   * This method is implemented when all input stream to this operation have encountered drain control message.
+   * Inherited operator implementation should handle drain by overriding this function.
+   * By default, noop implementation is for in-memory operator to handle the drain. Output operator need to
+   * override this to actually propagate drain control message over the wire.
    * @param collector message collector
    * @param coordinator task coordinator
-   * @return results to be emitted when this operator reaches drain-and-stop
+   * @return results to be emitted when this operator encounters drain control message
    */
   protected Collection<RM> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
     return Collections.emptyList();
@@ -395,20 +396,19 @@ public abstract class OperatorImpl<M, RM> {
     CompletionStage<Void> drainFuture = CompletableFuture.completedFuture(null);
 
     if (drainStates.isDrained(stream)) {
-      LOG.info("Input {} reaches the end for task {}", stream.toString(), taskName.getTaskName());
-      if (drainMessage.getTaskName() != null && shouldTaskBroadcastToOtherPartitions(ssp)) {
-        // This is the aggregation task, which already received all the eos messages from upstream
-        // broadcast the end-of-stream to all the peer partitions
-        // additionally if elasiticty is enabled
-        // then only one of the elastic tasks of the ssp will broadcast
+      LOG.info("Input {} is drained for task {}", stream.toString(), taskName.getTaskName());
+      if (drainMessage.getTaskName() != null) {
+        // This is the aggregation task which already received all the drain messages from upstream.
+        // Broadcast the drain messages to all the peer partitions.
         controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
       }
 
       drainFuture = onDrainOfStream(collector, coordinator)
           .thenAccept(result -> {
             if (drainStates.areAllStreamsDrained()) {
-              // all input streams have been drained, shut down the task
-              LOG.info("All input streams have been drained for task {}", taskName.getTaskName());
+              // All input streams have been drained, shut down the task
+              LOG.info("All input streams have been drained for task {}. Requesting shutdown.", taskName.getTaskName());
+              coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
               coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
             }
           });
@@ -427,10 +427,8 @@ public abstract class OperatorImpl<M, RM> {
    */
   private CompletionStage<Void> onDrainOfStream(MessageCollector collector, TaskCoordinator coordinator) {
     CompletionStage<Void> drainFuture = CompletableFuture.completedFuture(null);
-
     if (inputStreams.stream().allMatch(input -> drainStates.isDrained(input))) {
       Collection<RM> results = handleDrain(collector, coordinator);
-
       CompletionStage<Void> resultFuture = CompletableFuture.allOf(
           results.stream()
               .flatMap(r -> this.registeredOperators.stream()
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index f0aa20ba4..69c7d4e5a 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -29,6 +29,7 @@ import java.util.function.Consumer
 import java.util.{Base64, Optional}
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import org.apache.samza.SamzaException
+import org.apache.samza.application.ApplicationUtil
 import org.apache.samza.checkpoint.{CheckpointListener, OffsetManager, OffsetManagerMetrics}
 import org.apache.samza.clustermanager.StandbyTaskUtil
 import org.apache.samza.config.{StreamConfig, _}
@@ -624,6 +625,8 @@ object SamzaContainer extends Logging {
 
     val maxThrottlingDelayMs = config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1))
 
+    val isHighLevelApiJob = ApplicationUtil.isHighLevelApiJob(config)
+
     val runLoop: Runnable = RunLoopFactory.createRunLoop(
       taskInstances,
       consumerMultiplexer,
@@ -633,7 +636,8 @@ object SamzaContainer extends Logging {
       taskConfig,
       clock,
       jobConfig.getElasticityFactor,
-      appConfig.getRunId)
+      appConfig.getRunId,
+      isHighLevelApiJob)
 
     val containerMemoryMb : Int = new ClusterManagerConfig(config).getContainerMemoryMb
 
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 75f4a8d8f..564c47344 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -568,7 +568,8 @@ class TaskInstance(
     // within the keyBucket of the SSP assigned to the task.
     val incomingMessageSsp = envelope.getSystemStreamPartition(elasticityFactor)
 
-    if (IncomingMessageEnvelope.END_OF_STREAM_OFFSET.equals(envelope.getOffset)) {
+    if (IncomingMessageEnvelope.END_OF_STREAM_OFFSET.equals(envelope.getOffset)
+      || envelope.isDrain) {
       ssp2CaughtupMapping(incomingMessageSsp) = true
     } else {
       systemAdmins match {
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index 40b299993..6c70fe308 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -42,6 +42,7 @@ import java.util.function.Function;
 import java.util.stream.Collectors;
 import org.apache.commons.collections4.MapUtils;
 import org.apache.samza.SamzaException;
+import org.apache.samza.application.ApplicationUtil;
 import org.apache.samza.checkpoint.Checkpoint;
 import org.apache.samza.checkpoint.CheckpointManager;
 import org.apache.samza.checkpoint.CheckpointV2;
@@ -927,7 +928,7 @@ public class ContainerStorageManager {
         new SamzaContainerMetrics(SIDEINPUTS_METRICS_PREFIX + this.samzaContainerMetrics.source(),
             this.samzaContainerMetrics.registry(), SIDEINPUTS_METRICS_PREFIX);
 
-    ApplicationConfig applicationConfig = new ApplicationConfig(config);
+    final ApplicationConfig applicationConfig = new ApplicationConfig(config);
 
     this.sideInputRunLoop = new RunLoop(sideInputTasks,
         null, // all operations are executed in the main runloop thread
@@ -943,7 +944,9 @@ public class ContainerStorageManager {
         System::nanoTime,
         false,
         DEFAULT_SIDE_INPUT_ELASTICITY_FACTOR,
-        applicationConfig.getRunId()); // commit must be synchronous to ensure integrity of state flush
+        applicationConfig.getRunId(),
+        ApplicationUtil.isHighLevelApiJob(config)
+        ); // commit must be synchronous to ensure integrity of state flush
 
     try {
       sideInputsExecutor.submit(() -> {
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index d1b639cef..f71f7a3a3 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -219,6 +219,10 @@ class SystemConsumers (
 
     chooser.start
 
+    // SystemConsumers could be set to drain mode prior to start if a drain message was encountered on container start
+    if (isDraining) {
+      writeDrainControlMessageToSspQueue()
+    }
 
     started = true
 
@@ -243,8 +247,9 @@ class SystemConsumers (
     if (!isDraining) {
       isDraining = true;
       info("SystemConsumers is set to drain mode.")
-      consumers.values.foreach(_.stop)
-      writeDrainControlMessageToSspQueue()
+      if (started) {
+        writeDrainControlMessageToSspQueue()
+      }
     }
   }
 
diff --git a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
index 16ef93de1..faa0152d5 100644
--- a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
@@ -96,7 +96,7 @@ public class TestRunLoop {
 
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false, 1, "foo");
+        callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false, 1, "foo", false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(sspA0EndOfStream).thenReturn(
         sspA1EndOfStream).thenReturn(null);
     runLoop.run();
@@ -215,7 +215,7 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null);
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null, false);
     runLoop.run();
 
     verify(task0).process(eq(envelope00), any(), any());
@@ -224,38 +224,6 @@ public class TestRunLoop {
     assertEquals(1, containerMetrics.processes().getCount()); // only envelope00 and not envelope01 and not end of stream
   }
 
-  @Test
-  public void testDrainWithElasticityEnabled() {
-    TaskName taskName0 = new TaskName(p0.toString() + " 0");
-    TaskName taskName1 = new TaskName(p0.toString() + " 1");
-    SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStreamA", p0);
-    SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
-    SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
-
-    // create EOS IME such that its ssp keybucket maps to ssp0 and not to ssp1
-    // task in the runloop should give this ime to both it tasks
-    IncomingMessageEnvelope envelopeDrain = spy(IncomingMessageEnvelope.buildDrainMessage(ssp, runId));
-    when(envelopeDrain.getSystemStreamPartition(2)).thenReturn(ssp0);
-
-    // two task in the run loop that processes ssp0 -> 0th keybucket of ssp and ssp1 -> 1st keybucket of ssp
-    // Drain ime should be given to both the tasks irrespective of the keybucket
-    RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
-    RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
-
-    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
-    when(consumerMultiplexer.choose(false)).thenReturn(envelopeDrain).thenReturn(null);
-
-    Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
-    int maxMessagesInFlight = 1;
-    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, runId);
-    runLoop.run();
-
-    verify(task0).drain(any());
-    verify(task1).drain(any());
-  }
-
-
   @Test
   public void testDrainForTasksWithSingleSSP() {
     TaskName taskName0 = new TaskName(p0.toString() + " 0");
@@ -273,7 +241,7 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId);
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId, false);
     runLoop.run();
 
     // check if process was called once for each task
@@ -303,7 +271,7 @@ public class TestRunLoop {
     Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
-        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId);
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId, false);
     runLoop.run();
 
     // check if process was called twice for each task
@@ -602,7 +570,7 @@ public class TestRunLoop {
 
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
-        1, runId);
+        1, runId, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0Drain)
         .thenAnswer(invocation -> {
           // this ensures that the drain message has passed through run loop BEFORE the flight message
@@ -657,7 +625,7 @@ public class TestRunLoop {
     int maxMessagesInFlight = 1;
     RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
         callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
-        1, runId);
+        1, runId, false);
     when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(sspA0Drain).thenReturn(null);
 
     runLoop.run();
diff --git a/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
index c8cd5340f..9985b673b 100644
--- a/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
@@ -29,6 +29,7 @@ import java.util.concurrent.Callable;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import net.jcip.annotations.NotThreadSafe;
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
 import org.apache.samza.config.ApplicationConfig;
@@ -50,7 +51,7 @@ import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescript
 import org.apache.samza.test.table.TestTableData;
 import org.apache.samza.test.table.TestTableData.PageView;
 import org.apache.samza.util.CoordinatorStreamUtil;
-import org.junit.Ignore;
+import org.junit.FixMethodOrder;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
@@ -59,38 +60,72 @@ import static org.junit.Assert.*;
 /**
  * End to end integration test to check drain functionality with samza high-level API.
  */
+@FixMethodOrder // value defaults to MethodSorters.DEFAULT: deterministic, but not declaration, order
+@NotThreadSafe
 public class DrainHighLevelApiIntegrationTest {
   private static final List<PageView> RECEIVED = new ArrayList<>();
-  private static final String SYSTEM_NAME = "test";
-  private static final String STREAM_ID = "PageView";
+  private static final String SYSTEM_NAME1 = "test1";
+  private static final String STREAM_ID1 = "PageView1";
 
-  private static class  PageViewEventCountHighLevelApplication implements StreamApplication {
+  private static final String SYSTEM_NAME2 = "test2";
+  private static final String STREAM_ID2 = "PageView2";
+
+  /**
+   * High-level job with two shuffle stages: partitionBy "p11" on the member id, then partitionBy "p21".
+   * */
+  private static class PageViewEventCountHighLevelApplication implements StreamApplication {
     @Override
     public void describe(StreamApplicationDescriptor appDescriptor) {
-      DelegatingSystemDescriptor sd = new DelegatingSystemDescriptor(SYSTEM_NAME);
+      DelegatingSystemDescriptor sd = new DelegatingSystemDescriptor(SYSTEM_NAME1);
       GenericInputDescriptor<KV<String, PageView>> isd =
-          sd.getInputDescriptor(STREAM_ID, KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
+          sd.getInputDescriptor(STREAM_ID1, KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
       appDescriptor.getInputStream(isd)
           .map(KV::getValue)
           .partitionBy(PageView::getMemberId, pv -> pv,
-              KVSerde.of(new IntegerSerde(), new TestTableData.PageViewJsonSerde()), "p1")
+              KVSerde.of(new IntegerSerde(), new TestTableData.PageViewJsonSerde()), "p11")
+          .map(kv -> KV.of(kv.getKey() * 31, kv.getValue()))
+          .partitionBy(KV::getKey, KV::getValue, KVSerde.of(new IntegerSerde(), new TestTableData.PageViewJsonSerde()), "p21")
+          .sink((m, collector, coordinator) -> {
+            RECEIVED.add(m.getValue());
+          });
+    }
+  }
+
+  /**
+   * Simple high-level job without shuffle stages: every input message goes straight to the sink.
+   * */
+  private static class SimpleHighLevelApplication implements StreamApplication {
+    @Override
+    public void describe(StreamApplicationDescriptor appDescriptor) {
+      DelegatingSystemDescriptor sd = new DelegatingSystemDescriptor(SYSTEM_NAME2);
+      GenericInputDescriptor<KV<String, PageView>> isd =
+          sd.getInputDescriptor(STREAM_ID2, KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
+      appDescriptor.getInputStream(isd)
           .sink((m, collector, coordinator) -> {
             RECEIVED.add(m.getValue());
           });
     }
   }
 
-  // The test can be occasionally flaky, so we set Ignore annotation
-  // Remove ignore annotation and run the test as follows:
-  // ./gradlew :samza-test:test --tests org.apache.samza.test.drain.DrainHighLevelApiIntegrationTest -PscalaSuffix=2.12
-  @Ignore
+
+  /**
+   * Tests that draining mid-stream stops the job after it has consumed only some of the messages from the in-memory topic.
+   * In order to simulate the real-world behaviour of drain, the test adds messages to the in-memory topic buffer periodically
+   * in a delayed fashion instead of all at once. The test then writes the drain notification message to the in-memory
+   * metadata store to drain and stop the pipeline. This write is done shortly after the pipeline starts and before all
+   * the messages are written to the topic's buffer. As a result, the count of processed messages will be non-zero but less
+   * than the number of messages generated.
+   * */
   @Test
-  public void testPipeline() {
+  public void testDrain() {
+    int numPageViews = 200;
+    int numPartitions = 4;
+    long delayBetweenMessagesInMillis = 500L;
+    long drainTriggerDelay = 10_000L;
     String runId = "DrainTestId";
-    int numPageViews = 40;
 
-    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME);
-    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID, new NoOpSerde<>());
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME1);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID1, new NoOpSerde<>());
     InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
 
     Map<String, String> customConfig = ImmutableMap.of(
@@ -99,14 +134,11 @@ public class DrainHighLevelApiIntegrationTest {
         JobConfig.DRAIN_MONITOR_ENABLED, "true");
 
     // Create a TestRunner
-    // Set a InMemoryMetadataFactory. We will use this factory in the test to create a metadata store and
-    // write drain message to it
-    // Mock data comprises of 40 messages across 4 partitions. TestRunner adds a 1 second delay between messages
-    // per partition when writing messages to the InputStream
+    // Set an InMemoryMetadataStoreFactory; the same factory provides the metadata store the drain notification is written to
     TestRunner testRunner = TestRunner.of(new PageViewEventCountHighLevelApplication())
         .setInMemoryMetadataFactory(metadataStoreFactory)
         .addConfig(customConfig)
-        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, 4), 1000L);
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, numPartitions), delayBetweenMessagesInMillis);
 
     Config configFromRunner = testRunner.getConfig();
     MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
@@ -125,10 +157,177 @@ public class DrainHighLevelApiIntegrationTest {
         UUID uuid = DrainUtils.writeDrainNotification(metadataStore);
         return uuid.toString();
       }
-    }, 2000L, TimeUnit.MILLISECONDS);
+    }, drainTriggerDelay, TimeUnit.MILLISECONDS);
 
-    testRunner.run(Duration.ofSeconds(20));
+    testRunner.run(Duration.ofSeconds(40));
 
     assertTrue(RECEIVED.size() < numPageViews && RECEIVED.size() > 0);
+    RECEIVED.clear();
+    clearMetadataStore(metadataStore);
+  }
+
+  /**
+   * Same scenario as {@link #testDrain()}, but for {@link SimpleHighLevelApplication}, which has no shuffle stages.
+   * Drain is triggered while page views are still being added to the input stream, so only a strict, non-empty
+   * subset of them should be processed.
+   * */
+  @Test
+  public void testDrainWithoutReshuffleStages() {
+    int numPageViews = 200;
+    int numPartitions = 4;
+    long delayBetweenMessagesInMillis = 500L;
+    long drainTriggerDelay = 10_000L;
+    String runId = "DrainTestId";
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME2);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID2, new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner with an InMemoryMetadataStoreFactory. The same factory is used below to obtain the metadata
+    // store that the drain notification is written to.
+    TestRunner testRunner = TestRunner.of(new SimpleHighLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, numPartitions), delayBetweenMessagesInMillis);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+    try {
+      // Write configs to the coordinator stream here as neither the passthrough JC nor the StreamProcessor is writing
+      // configs to coordinator stream. RemoteApplicationRunner typically writes the configs to the metadata store
+      // before starting the JC.
+      // We are doing this so that DrainUtils.writeDrainNotification can read app.run.id from the config.
+      CoordinatorStreamUtil.writeConfigToCoordinatorStream(metadataStore, configFromRunner);
+
+      // Write the drain notification after a delay, while page views are still being added to the input stream.
+      executorService.schedule(() -> DrainUtils.writeDrainNotification(metadataStore), drainTriggerDelay,
+          TimeUnit.MILLISECONDS);
+
+      testRunner.run(Duration.ofSeconds(40));
+
+      assertTrue("Expected a strict, non-empty subset of " + numPageViews + " page views but received " + RECEIVED.size(),
+          RECEIVED.size() < numPageViews && RECEIVED.size() > 0);
+    } finally {
+      // Stop the scheduler first so that a drain write that has not started yet cannot run after the cleanup below.
+      // RECEIVED is static, so reset it (and the metadata store) even when the test fails.
+      executorService.shutdownNow();
+      RECEIVED.clear();
+      clearMetadataStore(metadataStore);
+    }
+  }
+
+  /**
+   * Verifies that {@link PageViewEventCountHighLevelApplication} processes no page views when the drain notification
+   * is written to the metadata store before the pipeline starts.
+   * */
+  @Test
+  public void testDrainOnContainerStart() {
+    int numPageViews = 200;
+    int numPartitions = 4;
+    long delayBetweenMessagesInMillis = 500L;
+    String runId = "DrainTestId";
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME1);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID1, new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner with an InMemoryMetadataStoreFactory. The same factory is used below to obtain the metadata
+    // store that the drain notification is written to.
+    TestRunner testRunner = TestRunner.of(new PageViewEventCountHighLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, numPartitions), delayBetweenMessagesInMillis);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    try {
+      // Write configs to the coordinator stream here as neither the passthrough JC nor the StreamProcessor is writing
+      // configs to coordinator stream. RemoteApplicationRunner typically writes the configs to the metadata store
+      // before starting the JC.
+      // We are doing this so that DrainUtils.writeDrainNotification can read app.run.id from the config.
+      CoordinatorStreamUtil.writeConfigToCoordinatorStream(metadataStore, configFromRunner);
+
+      // Write on the test thread to ensure that the drain notification is available on container start.
+      DrainUtils.writeDrainNotification(metadataStore);
+
+      testRunner.run(Duration.ofSeconds(20));
+
+      assertEquals(0, RECEIVED.size());
+    } finally {
+      // RECEIVED is static, so reset it (and the metadata store) even when the test fails.
+      RECEIVED.clear();
+      clearMetadataStore(metadataStore);
+    }
+  }
+
+  /**
+   * Same as {@link #testDrainOnContainerStart()}, but for {@link SimpleHighLevelApplication}, which has no shuffle
+   * stages: no page views should be processed.
+   * */
+  @Test
+  public void testDrainOnContainerStartWithoutReshuffleStages() {
+    int numPageViews = 200;
+    int numPartitions = 4;
+    long delayBetweenMessagesInMillis = 500L;
+    String runId = "DrainTestId";
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME2);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID2, new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner with an InMemoryMetadataStoreFactory. The same factory is used below to obtain the metadata
+    // store that the drain notification is written to.
+    TestRunner testRunner = TestRunner.of(new SimpleHighLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, numPartitions), delayBetweenMessagesInMillis);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    try {
+      // Write configs to the coordinator stream here as neither the passthrough JC nor the StreamProcessor is writing
+      // configs to coordinator stream. RemoteApplicationRunner typically writes the configs to the metadata store
+      // before starting the JC.
+      // We are doing this so that DrainUtils.writeDrainNotification can read app.run.id from the config.
+      CoordinatorStreamUtil.writeConfigToCoordinatorStream(metadataStore, configFromRunner);
+
+      // Write on the test thread to ensure that the drain notification is available on container start.
+      DrainUtils.writeDrainNotification(metadataStore);
+
+      testRunner.run(Duration.ofSeconds(20));
+
+      assertEquals(0, RECEIVED.size());
+    } finally {
+      // RECEIVED is static, so reset it (and the metadata store) even when the test fails.
+      RECEIVED.clear();
+      clearMetadataStore(metadataStore);
+    }
+  }
+
+  /**
+   * Deletes every entry currently in {@code store}; each test calls this to remove the configs and drain notification
+   * it wrote. The keys are copied first in case {@code all()} returns a live view of the store.
+   * */
+  private static void clearMetadataStore(MetadataStore store) {
+    new ArrayList<>(store.all().keySet()).forEach(store::delete);
   }
 }
diff --git a/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
index ed385f035..ab92df4e0 100644
--- a/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
@@ -29,6 +29,7 @@ import java.util.concurrent.Callable;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
+import net.jcip.annotations.NotThreadSafe;
 import org.apache.samza.application.TaskApplication;
 import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
 import org.apache.samza.config.ApplicationConfig;
@@ -55,7 +56,7 @@ import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescripto
 import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
 import org.apache.samza.test.table.TestTableData;
 import org.apache.samza.util.CoordinatorStreamUtil;
-import org.junit.Ignore;
+import org.junit.FixMethodOrder;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
@@ -63,10 +64,12 @@ import static org.junit.Assert.*;
 
 /**
  * End to end integration test to check drain functionality with samza low-level API.
+ * These tests were previously disabled with @Ignore because they timed out on the build system; they are now enabled.
  * */
+@NotThreadSafe
+@FixMethodOrder // value defaults to MethodSorters.DEFAULT: deterministic, but not declaration, order
 public class DrainLowLevelApiIntegrationTest {
   private static final List<TestTableData.PageView> RECEIVED = new ArrayList<>();
-
   private static Integer drainCounter = 0;
   private static Integer eosCounter = 0;
 
@@ -109,32 +112,37 @@ public class DrainLowLevelApiIntegrationTest {
     }
   }
 
-  // The test can be occasionally flaky, so we set Ignore annotation
-  // Remove ignore annotation and run the test as follows:
-  // ./gradlew :samza-test:test --tests org.apache.samza.test.drain.DrainLowLevelApiIntegrationTest -PscalaSuffix=2.12
-  @Ignore
+  /**
+   * Tests that draining mid-stream stops the job after it has consumed only some of the messages from the in-memory topic.
+   * In order to simulate the real-world behaviour of drain, the test adds messages to the in-memory topic buffer in
+   * a delayed fashion instead of all at once. The test then writes the drain notification message to the in-memory
+   * metadata store to drain and stop the pipeline. This write is done shortly after the pipeline starts and before all
+   * the messages are written to the topic's buffer. As a result, the count of processed messages will be non-zero but less
+   * than the number of messages generated.
+   * */
   @Test
-  public void testPipeline() {
-    int numPageViews = 40;
+  public void testDrain() {
+    int numPageViews = 200;
+    int numPartitions = 4;
+    long delayBetweenMessagesInMillis = 500L;
+    long drainTriggerDelay = 5000L;
+    String runId = "DrainTestId";
+
     InMemorySystemDescriptor isd = new InMemorySystemDescriptor("test");
     InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor("PageView", new NoOpSerde<>());
     InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
 
-    String runId = "DrainTestId";
     Map<String, String> customConfig = ImmutableMap.of(
         ApplicationConfig.APP_RUN_ID, runId,
         JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
         JobConfig.DRAIN_MONITOR_ENABLED, "true");
 
     // Create a TestRunner
-    // Set a InMemoryMetadataFactory. We will use this factory in the test to create a metadata store and
-    // write drain message to it
-    // Mock data comprises of 40 messages across 4 partitions. TestRunner adds a 1 second delay between messages
-    // per partition when writing messages to the InputStream
+    // Set an InMemoryMetadataStoreFactory; the same factory provides the metadata store the drain notification is written to
     TestRunner testRunner = TestRunner.of(new PageViewEventCountLowLevelApplication())
         .setInMemoryMetadataFactory(metadataStoreFactory)
         .addConfig(customConfig)
-        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, 4), 1000L);
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, numPartitions), delayBetweenMessagesInMillis);
 
     Config configFromRunner = testRunner.getConfig();
     MetadataStore
@@ -146,7 +154,7 @@ public class DrainLowLevelApiIntegrationTest {
     // We are doing this so that DrainUtils.writeDrainNotification can read app.run.id from the config
     CoordinatorStreamUtil.writeConfigToCoordinatorStream(metadataStore, configFromRunner);
 
-    // write drain message after a delay
+    // Trigger drain after drainTriggerDelay ms by writing a drain notification to the metadata store
     ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
     executorService.schedule(new Callable<String>() {
       @Override
@@ -154,10 +162,70 @@ public class DrainLowLevelApiIntegrationTest {
         UUID uuid = DrainUtils.writeDrainNotification(metadataStore);
         return uuid.toString();
       }
-    }, 2000L, TimeUnit.MILLISECONDS);
+    }, drainTriggerDelay, TimeUnit.MILLISECONDS);
 
-    testRunner.run(Duration.ofSeconds(25));
+    testRunner.run(Duration.ofSeconds(40));
 
     assertTrue(RECEIVED.size() < numPageViews && RECEIVED.size() > 0);
+    RECEIVED.clear();
+    clearMetadataStore(metadataStore);
+  }
+
+  /**
+   * Verifies that {@link PageViewEventCountLowLevelApplication} processes no page views when the drain notification
+   * is written to the metadata store before the pipeline starts.
+   * */
+  @Test
+  public void testDrainOnContainerStart() {
+    int numPageViews = 200;
+    int numPartitions = 4;
+    long delayBetweenMessagesInMillis = 500L;
+    String runId = "DrainTestId";
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor("test");
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor("PageView", new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner with an InMemoryMetadataStoreFactory. The same factory is used below to obtain the metadata
+    // store that the drain notification is written to.
+    TestRunner testRunner = TestRunner.of(new PageViewEventCountLowLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, numPartitions), delayBetweenMessagesInMillis);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    try {
+      // Write configs to the coordinator stream here as neither the passthrough JC nor the StreamProcessor is writing
+      // configs to coordinator stream. RemoteApplicationRunner typically writes the configs to the metadata store
+      // before starting the JC.
+      // We are doing this so that DrainUtils.writeDrainNotification can read app.run.id from the config.
+      CoordinatorStreamUtil.writeConfigToCoordinatorStream(metadataStore, configFromRunner);
+
+      // Write on the test thread to ensure that the drain notification is available on container start.
+      DrainUtils.writeDrainNotification(metadataStore);
+
+      testRunner.run(Duration.ofSeconds(40));
+
+      assertEquals(0, RECEIVED.size());
+    } finally {
+      // RECEIVED is static, so reset it (and the metadata store) even when the test fails.
+      RECEIVED.clear();
+      clearMetadataStore(metadataStore);
+    }
+  }
+
+  /**
+   * Deletes every entry currently in {@code store}; each test calls this to remove the configs and drain notification
+   * it wrote. The keys are copied first in case {@code all()} returns a live view of the store.
+   * */
+  private static void clearMetadataStore(MetadataStore store) {
+    new ArrayList<>(store.all().keySet()).forEach(store::delete);
   }
 }