You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/04/17 00:35:29 UTC

[beam] branch master updated: [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls. (#11414)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 99444c6  [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls. (#11414)
99444c6 is described below

commit 99444c6cefbe30ebb222f2adad52fdd150310780
Author: Lukasz Cwik <lu...@gmail.com>
AuthorDate: Thu Apr 16 17:34:53 2020 -0700

    [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls. (#11414)
    
    * [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls.
    
    The next step is to plumb the split request to the BeamFnDataReadRunner which will forward it to the FnApiDoFnRunner.
    
    * fixup! Minor comment change.
---
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 223 ++++++++++-----
 .../org/apache/beam/fn/harness/HandlesSplits.java  |   8 +
 .../fn/harness/control/BundleSplitListener.java    |  29 ++
 .../fn/harness/control/ProcessBundleHandler.java   |  39 +--
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       | 305 ++++++++++++++++-----
 .../harness/control/BundleSplitListenerTest.java   |  58 ++++
 .../harness/control/ProcessBundleHandlerTest.java  |  11 +-
 7 files changed, 499 insertions(+), 174 deletions(-)

diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 13b9c57..41daef7 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -60,6 +60,8 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.data.LogicalEndpoint;
+import org.apache.beam.sdk.fn.splittabledofn.RestrictionTrackers;
+import org.apache.beam.sdk.fn.splittabledofn.RestrictionTrackers.ClaimObserver;
 import org.apache.beam.sdk.fn.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.function.ThrowingRunnable;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -102,7 +104,6 @@ import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.util.Durations;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -184,7 +185,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
       } catch (IOException e) {
         throw new RuntimeException(e);
       }
-      FnDataReceiver<WindowedValue> mainInputConsumer;
+      final FnDataReceiver<WindowedValue> mainInputConsumer;
       switch (pTransform.getSpec().getUrn()) {
         case PTransformTranslation.PAR_DO_TRANSFORM_URN:
           mainInputConsumer = runner::processElementForParDo;
@@ -197,10 +198,22 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
           mainInputConsumer = runner::processElementForSplitRestriction;
           break;
         case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN:
-          mainInputConsumer = runner::processElementForElementAndRestriction;
+          mainInputConsumer =
+              runner.new SplittableFnDataReceiver() {
+                @Override
+                public void accept(WindowedValue input) throws Exception {
+                  runner.processElementForElementAndRestriction(input);
+                }
+              };
           break;
         case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN:
-          mainInputConsumer = runner::processElementForSizedElementAndRestriction;
+          mainInputConsumer =
+              runner.new SplittableFnDataReceiver() {
+                @Override
+                public void accept(WindowedValue input) throws Exception {
+                  runner.processElementForSizedElementAndRestriction(input);
+                }
+              };
           break;
         default:
           throw new IllegalStateException("Unknown urn: " + pTransform.getSpec().getUrn());
@@ -250,6 +263,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   private final FinishBundleArgumentProvider finishBundleArgumentProvider;
 
   /**
+   * Used to guarantee a consistent view of this {@link FnApiDoFnRunner} while setting up for {@link
+   * DoFnInvoker#invokeProcessElement} since {@link #trySplitForElementAndRestriction} may access
+   * internal {@link FnApiDoFnRunner} state concurrently.
+   */
+  private final Object splitLock = new Object();
+
+  /**
    * Only set for {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} and {@link
    * PTransformTranslation#SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN} transforms. Can
    * only be invoked from within {@code processElement...} methods.
@@ -551,7 +571,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
                     currentElement.withValue(
                         KV.of(
                             currentElement.getValue(),
-                            KV.of(splitResult.getPrimary(), watermarkEstimatorState))),
+                            KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState))),
                     currentElement.withValue(
                         KV.of(
                             currentElement.getValue(),
@@ -599,7 +619,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
                       KV.of(
                           KV.of(
                               currentElement.getValue(),
-                              KV.of(splitResult.getPrimary(), watermarkEstimatorState)),
+                              KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
                           primarySize)),
                   currentElement.withValue(
                       KV.of(
@@ -620,7 +640,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     }
   }
 
-  public void startBundle() {
+  private void startBundle() {
     this.stateAccessor =
         new FnApiStateAccessor(
             pipelineOptions,
@@ -662,7 +682,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     doFnInvoker.invokeStartBundle(startBundleArgumentProvider);
   }
 
-  public void processElementForParDo(WindowedValue<InputT> elem) {
+  private void processElementForParDo(WindowedValue<InputT> elem) {
     currentElement = elem;
     try {
       Iterator<BoundedWindow> windowIterator =
@@ -677,7 +697,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     }
   }
 
-  public void processElementForPairWithRestriction(WindowedValue<InputT> elem) {
+  private void processElementForPairWithRestriction(WindowedValue<InputT> elem) {
     currentElement = elem;
     try {
       Iterator<BoundedWindow> windowIterator =
@@ -702,7 +722,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     }
   }
 
-  public void processElementForSplitRestriction(
+  private void processElementForSplitRestriction(
       WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey());
     currentRestriction = elem.getValue().getValue().getKey();
@@ -735,24 +755,38 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     public abstract WindowedValue getResidualRoot();
   }
 
-  public void processElementForSizedElementAndRestriction(
+  private void processElementForSizedElementAndRestriction(
       WindowedValue<KV<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>, Double>> elem) {
     processElementForElementAndRestriction(elem.withValue(elem.getValue().getKey()));
   }
 
-  public void processElementForElementAndRestriction(
+  private void processElementForElementAndRestriction(
       WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey());
     try {
       Iterator<BoundedWindow> windowIterator =
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
       while (windowIterator.hasNext()) {
-        currentRestriction = elem.getValue().getValue().getKey();
-        currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
-        currentWindow = windowIterator.next();
-        currentTracker = doFnInvoker.invokeNewTracker(processContext);
-        currentWatermarkEstimator =
-            WatermarkEstimators.threadSafe(doFnInvoker.invokeNewWatermarkEstimator(processContext));
+        synchronized (splitLock) {
+          currentRestriction = elem.getValue().getValue().getKey();
+          currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
+          currentWindow = windowIterator.next();
+          currentTracker =
+              RestrictionTrackers.observe(
+                  doFnInvoker.invokeNewTracker(processContext),
+                  new ClaimObserver<PositionT>() {
+                    @Override
+                    public void onClaimed(PositionT position) {}
+
+                    @Override
+                    public void onClaimFailed(PositionT position) {}
+                  });
+          currentWatermarkEstimator =
+              WatermarkEstimators.threadSafe(
+                  doFnInvoker.invokeNewWatermarkEstimator(processContext));
+        }
+
+        // It is important to ensure that {@code splitLock} is not held during #invokeProcessElement
         DoFn.ProcessContinuation continuation = doFnInvoker.invokeProcessElement(processContext);
         // Ensure that all the work is done if the user tells us that they don't want to
         // resume processing.
@@ -761,72 +795,111 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
           continue;
         }
 
-        // Make sure to get the output watermark before we split to ensure that the lower bound
-        // applies to both the primary and residual.
-        KV<Instant, WatermarkEstimatorStateT> watermarkAndState =
-            currentWatermarkEstimator.getWatermarkAndState();
-        SplitResult<RestrictionT> result = currentTracker.trySplit(0);
+        // Attempt to checkpoint the current restriction.
+        HandlesSplits.SplitResult splitResult =
+            trySplitForElementAndRestriction(0, continuation.resumeDelay());
         // After the user has chosen to resume processing later, the Runner may have stolen
-        // the remainder of work through a split call so the above trySplit may fail. If so,
+        // the remainder of work through a split call so the above trySplit may return null. If so,
         // the current restriction must be done.
-        if (result == null) {
+        if (splitResult == null) {
           currentTracker.checkDone();
           continue;
         }
-
-        // Otherwise we have a successful self checkpoint.
-        WindowedSplitResult windowedSplitResult =
-            convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
-        ByteString.Output primaryBytes = ByteString.newOutput();
-        ByteString.Output residualBytes = ByteString.newOutput();
-        try {
-          Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
-          fullInputCoder.encode(windowedSplitResult.getPrimaryRoot(), primaryBytes);
-          fullInputCoder.encode(windowedSplitResult.getResidualRoot(), residualBytes);
-        } catch (IOException e) {
-          throw new RuntimeException(e);
-        }
-        BundleApplication.Builder primaryApplication =
-            BundleApplication.newBuilder()
-                .setTransformId(pTransformId)
-                .setInputId(mainInputId)
-                .setElement(primaryBytes.toByteString());
-        BundleApplication.Builder residualApplication =
-            BundleApplication.newBuilder()
-                .setTransformId(pTransformId)
-                .setInputId(mainInputId)
-                .setElement(residualBytes.toByteString());
-
-        if (!watermarkAndState.getKey().equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
-          for (String outputId : pTransform.getOutputsMap().keySet()) {
-            residualApplication.putOutputWatermarks(
-                outputId,
-                org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                    .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
-                    .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
-                    .build());
-          }
-        }
+        // Forward the split to the bundle level split listener.
         splitListener.split(
-            ImmutableList.of(primaryApplication.build()),
-            ImmutableList.of(
-                DelayedBundleApplication.newBuilder()
-                    .setApplication(residualApplication.build())
-                    .setRequestedTimeDelay(
-                        Durations.fromMillis(continuation.resumeDelay().getMillis()))
-                    .build()));
+            Collections.singletonList(splitResult.getPrimaryRoot()),
+            Collections.singletonList(splitResult.getResidualRoot()));
       }
     } finally {
-      currentElement = null;
-      currentRestriction = null;
-      currentWatermarkEstimatorState = null;
-      currentWindow = null;
-      currentTracker = null;
-      currentWatermarkEstimator = null;
+      synchronized (splitLock) {
+        currentElement = null;
+        currentRestriction = null;
+        currentWatermarkEstimatorState = null;
+        currentWindow = null;
+        currentTracker = null;
+        currentWatermarkEstimator = null;
+      }
+    }
+  }
+
+  /**
+   * An abstract class which forwards split and progress calls allowing the implementer to choose
+   * where input elements are sent.
+   */
+  private abstract class SplittableFnDataReceiver
+      implements HandlesSplits, FnDataReceiver<WindowedValue> {
+    @Override
+    public SplitResult trySplit(double fractionOfRemainder) {
+      return trySplitForElementAndRestriction(fractionOfRemainder, Duration.ZERO);
+    }
+
+    @Override
+    public double getProgress() {
+      // TODO(BEAM-2939): Implement plumbing progress through for splitting.
+      return 0;
+    }
+  }
+
+  private HandlesSplits.SplitResult trySplitForElementAndRestriction(
+      double fractionOfRemainder, Duration resumeDelay) {
+    synchronized (splitLock) {
+      // There is nothing to split if we are between element and restriction processing calls.
+      if (currentTracker == null) {
+        return null;
+      }
+
+      // Make sure to get the output watermark before we split to ensure that the lower bound
+      // applies to the residual.
+      KV<Instant, WatermarkEstimatorStateT> watermarkAndState =
+          currentWatermarkEstimator.getWatermarkAndState();
+      SplitResult<RestrictionT> result = currentTracker.trySplit(fractionOfRemainder);
+      if (result == null) {
+        return null;
+      }
+
+      // We have a successful split, either runner initiated or via a self checkpoint.
+      WindowedSplitResult windowedSplitResult =
+          convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
+      ByteString.Output primaryBytes = ByteString.newOutput();
+      ByteString.Output residualBytes = ByteString.newOutput();
+      try {
+        Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
+        fullInputCoder.encode(windowedSplitResult.getPrimaryRoot(), primaryBytes);
+        fullInputCoder.encode(windowedSplitResult.getResidualRoot(), residualBytes);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      BundleApplication.Builder primaryApplication =
+          BundleApplication.newBuilder()
+              .setTransformId(pTransformId)
+              .setInputId(mainInputId)
+              .setElement(primaryBytes.toByteString());
+      BundleApplication.Builder residualApplication =
+          BundleApplication.newBuilder()
+              .setTransformId(pTransformId)
+              .setInputId(mainInputId)
+              .setElement(residualBytes.toByteString());
+
+      if (!watermarkAndState.getKey().equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
+        for (String outputId : pTransform.getOutputsMap().keySet()) {
+          residualApplication.putOutputWatermarks(
+              outputId,
+              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                  .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
+                  .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
+                  .build());
+        }
+      }
+      return HandlesSplits.SplitResult.of(
+          primaryApplication.build(),
+          DelayedBundleApplication.newBuilder()
+              .setApplication(residualApplication.build())
+              .setRequestedTimeDelay(Durations.fromMillis(resumeDelay.getMillis()))
+              .build());
     }
   }
 
-  public <K> void processTimer(String timerId, TimeDomain timeDomain, Timer<K> timer) {
+  private <K> void processTimer(String timerId, TimeDomain timeDomain, Timer<K> timer) {
     currentTimer = timer;
     currentTimeDomain = timeDomain;
     try {
@@ -843,7 +916,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     }
   }
 
-  public void finishBundle() throws Exception {
+  private void finishBundle() throws Exception {
     for (TimerHandler timerHandler : timerHandlers.values()) {
       timerHandler.awaitCompletion();
     }
@@ -858,7 +931,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     this.stateAccessor = null;
   }
 
-  public void tearDown() {
+  private void tearDown() {
     doFnInvoker.invokeTeardown();
   }
 
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
index a2ac123..d6a161a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
@@ -19,10 +19,18 @@ package org.apache.beam.fn.harness;
 
 import com.google.auto.value.AutoValue;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
 
+/**
+ * An interface that may be used to extend a {@link FnDataReceiver} signalling that the downstream
+ * runner is capable of performing splitting and providing progress reporting.
+ */
 public interface HandlesSplits {
+
+  /** Returns null if the split was unsuccessful. */
   SplitResult trySplit(double fractionOfRemainder);
 
+  /** Returns the current progress of the active element. */
   double getProgress();
 
   @AutoValue
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
index 9eab245..830834e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -17,7 +17,10 @@
  */
 package org.apache.beam.fn.harness.control;
 
+import com.google.auto.value.AutoValue;
+import java.util.ArrayList;
 import java.util.List;
+import javax.annotation.concurrent.NotThreadSafe;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
 
@@ -37,4 +40,30 @@ public interface BundleSplitListener {
    * it for someone else to execute.
    */
   void split(List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots);
+
+  /** A {@link BundleSplitListener} which gathers all splits produced and stores them in memory. */
+  @AutoValue
+  @NotThreadSafe
+  abstract class InMemory implements BundleSplitListener {
+    public static InMemory create() {
+      return new AutoValue_BundleSplitListener_InMemory(
+          new ArrayList<BundleApplication>(), new ArrayList<DelayedBundleApplication>());
+    }
+
+    @Override
+    public void split(
+        List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots) {
+      getPrimaryRoots().addAll(primaryRoots);
+      getResidualRoots().addAll(residualRoots);
+    }
+
+    public void clear() {
+      getPrimaryRoots().clear();
+      getResidualRoots().clear();
+    }
+
+    public abstract List<BundleApplication> getPrimaryRoots();
+
+    public abstract List<DelayedBundleApplication> getResidualRoots();
+  }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index a788348..e11d1ec 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -47,8 +47,6 @@ import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
 import org.apache.beam.fn.harness.state.BeamFnStateClient;
 import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
@@ -75,13 +73,11 @@ import org.apache.beam.sdk.util.common.ReflectHelpers;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.TextFormat;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.SetMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
 import org.joda.time.Instant;
@@ -275,7 +271,6 @@ public class ProcessBundleHandler {
             });
     PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry();
     PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry();
-    Multimap<String, DelayedBundleApplication> allResiduals = bundleProcessor.getAllResiduals();
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         bundleProcessor.getpCollectionConsumerRegistry();
     MetricsContainerStepMap metricsContainerRegistry =
@@ -299,10 +294,11 @@ public class ProcessBundleHandler {
           LOG.debug("Finishing function {}", finishFunction);
           finishFunction.run();
         }
-        if (!allResiduals.isEmpty()) {
-          response.addAllResidualRoots(allResiduals.values());
-        }
       }
+
+      // Add all checkpointed residuals to the response.
+      response.addAllResidualRoots(bundleProcessor.getSplitListener().getResidualRoots());
+
       // Get start bundle Execution Time Metrics.
       for (MonitoringInfo mi : startFunctionRegistry.getExecutionTimeMonitoringInfos()) {
         response.addMonitoringInfos(mi);
@@ -403,22 +399,7 @@ public class ProcessBundleHandler {
                 queueingClient, bundleDescriptor.getTimerApiServiceDescriptor())
             : new FailAllTimerRegistrations(processBundleRequest);
 
-    Multimap<String, DelayedBundleApplication> allResiduals = ArrayListMultimap.create();
-    Multimap<String, BundleApplication> allPrimaries = ArrayListMultimap.create();
-    BundleSplitListener splitListener =
-        (List<BundleApplication> primaries, List<DelayedBundleApplication> residuals) -> {
-          // Reset primaries and accumulate residuals.
-          Multimap<String, BundleApplication> newPrimaries = ArrayListMultimap.create();
-          for (BundleApplication primary : primaries) {
-            newPrimaries.put(primary.getTransformId(), primary);
-          }
-          allPrimaries.clear();
-          allPrimaries.putAll(newPrimaries);
-
-          for (DelayedBundleApplication residual : residuals) {
-            allResiduals.put(residual.getApplication().getTransformId(), residual);
-          }
-        };
+    BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
 
     Collection<CallbackRegistration> bundleFinalizationCallbackRegistrations = new ArrayList<>();
     BundleFinalizer bundleFinalizer =
@@ -435,7 +416,7 @@ public class ProcessBundleHandler {
             startFunctionRegistry,
             finishFunctionRegistry,
             tearDownFunctions,
-            allResiduals,
+            splitListener,
             pCollectionConsumerRegistry,
             metricsContainerRegistry,
             stateTracker,
@@ -563,7 +544,7 @@ public class ProcessBundleHandler {
         PTransformFunctionRegistry startFunctionRegistry,
         PTransformFunctionRegistry finishFunctionRegistry,
         List<ThrowingRunnable> tearDownFunctions,
-        Multimap<String, DelayedBundleApplication> allResiduals,
+        BundleSplitListener.InMemory splitListener,
         PCollectionConsumerRegistry pCollectionConsumerRegistry,
         MetricsContainerStepMap metricsContainerRegistry,
         ExecutionStateTracker stateTracker,
@@ -574,7 +555,7 @@ public class ProcessBundleHandler {
           startFunctionRegistry,
           finishFunctionRegistry,
           tearDownFunctions,
-          allResiduals,
+          splitListener,
           pCollectionConsumerRegistry,
           metricsContainerRegistry,
           stateTracker,
@@ -591,7 +572,7 @@ public class ProcessBundleHandler {
 
     abstract List<ThrowingRunnable> getTearDownFunctions();
 
-    abstract Multimap<String, DelayedBundleApplication> getAllResiduals();
+    abstract BundleSplitListener.InMemory getSplitListener();
 
     abstract PCollectionConsumerRegistry getpCollectionConsumerRegistry();
 
@@ -616,7 +597,7 @@ public class ProcessBundleHandler {
     void reset() {
       getStartFunctionRegistry().reset();
       getFinishFunctionRegistry().reset();
-      getAllResiduals().clear();
+      getSplitListener().clear();
       getpCollectionConsumerRegistry().reset();
       getMetricsContainerRegistry().reset();
       getStateTracker().reset();
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index b37ac2c..0aa2381 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
@@ -37,6 +38,14 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.ServiceLoader;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
 import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.FakeBeamFnTimerClient;
 import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
@@ -48,9 +57,11 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
+import org.apache.beam.runners.core.construction.CoderTranslation;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
 import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
 import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
@@ -61,6 +72,7 @@ import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.fn.data.LogicalEndpoint;
@@ -1019,30 +1031,85 @@ public class FnApiDoFnRunnerTest implements Serializable {
     fail("Expected registrar not found.");
   }
 
+  /**
+   * The trySplit testing of this splittable DoFn is done when processing the {@link
+   * TestSplittableDoFn#SPLIT_ELEMENT}.
+   *
+   * <p>The expected thread flow is:
+   *
+   * <ul>
+   *   <li>splitting thread: {@link TestSplittableDoFn#waitForSplitElementToBeProcessed()}
+   *   <li>process element thread: {@link TestSplittableDoFn#enableAndWaitForTrySplitToHappen()}
+   *   <li>splitting thread: perform try split
+   *   <li>splitting thread: {@link TestSplittableDoFn#releaseWaitingProcessElementThread()}
+   * </ul>
+   */
   static class TestSplittableDoFn extends DoFn<String, String> {
+    private static final ConcurrentMap<String, KV<CountDownLatch, CountDownLatch>>
+        DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>();
+    private static final long SPLIT_ELEMENT = 3;
+
+    private KV<CountDownLatch, CountDownLatch> getLatches() {
+      return DOFN_INSTANCE_TO_LOCK.computeIfAbsent(
+          this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new CountDownLatch(1)));
+    }
+
+    private void enableAndWaitForTrySplitToHappen() throws Exception {
+      KV<CountDownLatch, CountDownLatch> latches = getLatches();
+      latches.getKey().countDown();
+      if (!latches.getValue().await(30, TimeUnit.SECONDS)) {
+        fail("Failed to wait for trySplit to occur.");
+      }
+    }
+
+    private void waitForSplitElementToBeProcessed() throws Exception {
+      KV<CountDownLatch, CountDownLatch> latches = getLatches();
+      if (!latches.getKey().await(30, TimeUnit.SECONDS)) {
+        fail("Failed to wait for split element to be processed.");
+      }
+    }
+
+    private void releaseWaitingProcessElementThread() {
+      KV<CountDownLatch, CountDownLatch> latches = getLatches();
+      latches.getValue().countDown();
+    }
+
     private final PCollectionView<String> singletonSideInput;
+    private final String uuid;
 
     private TestSplittableDoFn(PCollectionView<String> singletonSideInput) {
       this.singletonSideInput = singletonSideInput;
+      this.uuid = UUID.randomUUID().toString();
     }
 
     @ProcessElement
     public ProcessContinuation processElement(
         ProcessContext context,
         RestrictionTracker<OffsetRange, Long> tracker,
-        ManualWatermarkEstimator<Instant> watermarkEstimator) {
-      int upperBound = Integer.parseInt(context.sideInput(singletonSideInput));
-      for (int i = 0; i < upperBound; ++i) {
-        if (tracker.tryClaim((long) i)) {
-          context.outputWithTimestamp(
-              context.element() + ":" + i, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
-          watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
+        ManualWatermarkEstimator<Instant> watermarkEstimator)
+        throws Exception {
+      long checkpointUpperBound = Long.parseLong(context.sideInput(singletonSideInput));
+      long position = tracker.currentRestriction().getFrom();
+      boolean claimStatus;
+      while (true) {
+        claimStatus = (tracker.tryClaim(position));
+        if (!claimStatus) {
+          break;
+        } else if (position == SPLIT_ELEMENT) {
+          enableAndWaitForTrySplitToHappen();
+        }
+        context.outputWithTimestamp(
+            context.element() + ":" + position, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+        watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+        position += 1L;
+        if (position == checkpointUpperBound) {
+          break;
         }
       }
-      if (tracker.currentRestriction().getTo() > upperBound) {
-        return ProcessContinuation.resume().withResumeDelay(Duration.millis(42L));
-      } else {
+      if (!claimStatus) {
         return ProcessContinuation.stop();
+      } else {
+        return ProcessContinuation.resume().withResumeDelay(Duration.millis(54321L));
       }
     }
 
@@ -1079,10 +1146,9 @@ public class FnApiDoFnRunnerTest implements Serializable {
     Pipeline p = Pipeline.create();
     PCollection<String> valuePCollection = p.apply(Create.of("unused"));
     PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+    TestSplittableDoFn doFn = new TestSplittableDoFn(singletonSideInputView);
     valuePCollection.apply(
-        TEST_TRANSFORM_ID,
-        ParDo.of(new TestSplittableDoFn(singletonSideInputView))
-            .withSideInputs(singletonSideInputView));
+        TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
 
     RunnerApi.Pipeline pProto =
         ProtoOverrides.updateTransform(
@@ -1106,12 +1172,32 @@ public class FnApiDoFnRunnerTest implements Serializable {
         pProto.getComponents().getTransformsOrThrow(expandedTransformId);
     String inputPCollectionId =
         pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    RunnerApi.PCollection inputPCollection =
+        pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+    RehydratedComponents rehydratedComponents =
+        RehydratedComponents.forComponents(pProto.getComponents());
+    Coder<WindowedValue> inputCoder =
+        WindowedValue.getFullCoder(
+            CoderTranslation.fromProto(
+                pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                rehydratedComponents),
+            (Coder)
+                CoderTranslation.fromProto(
+                    pProto
+                        .getComponents()
+                        .getCodersOrThrow(
+                            pProto
+                                .getComponents()
+                                .getWindowingStrategiesOrThrow(
+                                    inputPCollection.getWindowingStrategyId())
+                                .getWindowCoderId()),
+                    rehydratedComponents));
     String outputPCollectionId = pTransform.getOutputsOrThrow("output");
 
     ImmutableMap<StateKey, ByteString> stateData =
         ImmutableMap.of(
             multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
-            encode("3"));
+            encode("8"));
 
     FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
 
@@ -1131,8 +1217,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
         new PTransformFunctionRegistry(
             mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
     List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-    List<BundleApplication> primarySplits = new ArrayList<>();
-    List<DelayedBundleApplication> residualSplits = new ArrayList<>();
+    BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
 
     new FnApiDoFnRunner.Factory<>()
         .createRunnerForPTransform(
@@ -1150,15 +1235,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
             startFunctionRegistry,
             finishFunctionRegistry,
             teardownFunctions::add,
-            new BundleSplitListener() {
-              @Override
-              public void split(
-                  List<BundleApplication> primaryRoots,
-                  List<DelayedBundleApplication> residualRoots) {
-                primarySplits.addAll(primaryRoots);
-                residualSplits.addAll(residualRoots);
-              }
-            },
+            splitListener,
             null /* bundleFinalizer */);
 
     Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
@@ -1168,46 +1245,146 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
     FnDataReceiver<WindowedValue<?>> mainInput =
         consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(
-        valueInGlobalWindow(
-            KV.of(
-                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0)));
-    BundleApplication primaryRoot = Iterables.getOnlyElement(primarySplits);
-    DelayedBundleApplication residualRoot = Iterables.getOnlyElement(residualSplits);
-    assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
-    assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
-    assertEquals(
-        ParDoTranslation.getMainInputName(pTransform), residualRoot.getApplication().getInputId());
-    assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
-    Instant expectedOutputWatermark =
-        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(
-            2); // side input upperBound is 3 hence we only process the first two elements
-    assertEquals(
-        ImmutableMap.of(
-            "output",
-            org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                .setSeconds(expectedOutputWatermark.getMillis() / 1000)
-                .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
-                .build()),
-        residualRoot.getApplication().getOutputWatermarksMap());
-    primarySplits.clear();
-    residualSplits.clear();
+    assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+    {
+      mainInput.accept(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                  5.0)));
+      // Since the side input upperBound is 8, we will process 5, 6, and 7 then checkpoint.
+      // We expect that:
+      //   - the watermark advances to MIN + 7,
+      //   - the primary represents [5, 8) with the original watermark, and
+      //   - the residual represents [8, 10) with the new MIN + 7 watermark.
+      BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+      DelayedBundleApplication residualRoot =
+          Iterables.getOnlyElement(splitListener.getResidualRoots());
+      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+      assertEquals(
+          ParDoTranslation.getMainInputName(pTransform),
+          residualRoot.getApplication().getInputId());
+      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+      assertEquals(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)), 3.0)),
+          inputCoder.decode(primaryRoot.getElement().newInput()));
+      assertEquals(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of(
+                      "5", KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+                  2.0)),
+          inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+      assertEquals(
+          ImmutableMap.of(
+              "output",
+              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                  .build()),
+          residualRoot.getApplication().getOutputWatermarksMap());
+      assertEquals(
+          org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+              .setSeconds(54)
+              .setNanos(321000000)
+              .build(),
+          residualRoot.getRequestedTimeDelay());
+      splitListener.clear();
+
+      mainInput.accept(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                  2.0)));
+      assertThat(
+          mainOutputValues,
+          contains(
+              timestampedValueInGlobalWindow("5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5)),
+              timestampedValueInGlobalWindow("5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6)),
+              timestampedValueInGlobalWindow("5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7)),
+              timestampedValueInGlobalWindow("2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
+              timestampedValueInGlobalWindow("2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))));
+      assertTrue(splitListener.getPrimaryRoots().isEmpty());
+      assertTrue(splitListener.getResidualRoots().isEmpty());
+      mainOutputValues.clear();
+    }
 
-    mainInput.accept(
-        valueInGlobalWindow(
-            KV.of(
-                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0)));
-    assertThat(
-        mainOutputValues,
-        contains(
-            timestampedValueInGlobalWindow("5:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
-            timestampedValueInGlobalWindow("5:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)),
-            timestampedValueInGlobalWindow("5:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
-            timestampedValueInGlobalWindow("2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
-            timestampedValueInGlobalWindow("2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))));
-    assertTrue(primarySplits.isEmpty());
-    assertTrue(residualSplits.isEmpty());
-    mainOutputValues.clear();
+    {
+      // Setup and launch the trySplit thread.
+      ExecutorService executorService = Executors.newSingleThreadExecutor();
+      Future<HandlesSplits.SplitResult> trySplitFuture =
+          executorService.submit(
+              () -> {
+                try {
+                  doFn.waitForSplitElementToBeProcessed();
+                  return ((HandlesSplits) mainInput).trySplit(0);
+                } finally {
+                  doFn.releaseWaitingProcessElementThread();
+                }
+              });
+
+      mainInput.accept(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                  2.0)));
+      // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split.
+      // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
+      // has yet to be invoked for the split element and that the primary represents [0, 4) with
+      // the original watermark while the residual represents [4, 5) with the new MIN + 2 watermark.
+      assertThat(
+          mainOutputValues,
+          contains(
+              timestampedValueInGlobalWindow("7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
+              timestampedValueInGlobalWindow("7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)),
+              timestampedValueInGlobalWindow("7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
+              timestampedValueInGlobalWindow("7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3))));
+
+      HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+      BundleApplication primaryRoot = trySplitResult.getPrimaryRoot();
+      DelayedBundleApplication residualRoot = trySplitResult.getResidualRoot();
+      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+      assertEquals(
+          ParDoTranslation.getMainInputName(pTransform),
+          residualRoot.getApplication().getInputId());
+      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+      assertEquals(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), 4.0)),
+          inputCoder.decode(primaryRoot.getElement().newInput()));
+      assertEquals(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of(
+                      "7", KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
+                  1.0)),
+          inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
+      assertEquals(
+          ImmutableMap.of(
+              "output",
+              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                  .build()),
+          residualRoot.getApplication().getOutputWatermarksMap());
+      // We expect 0 resume delay.
+      assertEquals(
+          residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+          residualRoot.getRequestedTimeDelay());
+      // We don't expect the outputs to go to the SDK-initiated checkpointing listener.
+      assertTrue(splitListener.getPrimaryRoots().isEmpty());
+      assertTrue(splitListener.getResidualRoots().isEmpty());
+      mainOutputValues.clear();
+      executorService.shutdown();
+    }
 
     Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
     assertThat(mainOutputValues, empty());
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
new file mode 100644
index 0000000..00b4e15
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.control;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
+import static org.junit.Assert.assertThat;
+
+import java.util.Arrays;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link BundleSplitListener}. */
+@RunWith(JUnit4.class)
+public class BundleSplitListenerTest {
+  private static final BundleApplication TEST_PRIMARY_1 =
+      BundleApplication.newBuilder().setInputId("primary1").build();
+  private static final BundleApplication TEST_PRIMARY_2 =
+      BundleApplication.newBuilder().setInputId("primary2").build();
+  private static final DelayedBundleApplication TEST_RESIDUAL_1 =
+      DelayedBundleApplication.newBuilder()
+          .setApplication(BundleApplication.newBuilder().setInputId("residual1").build())
+          .build();
+  private static final DelayedBundleApplication TEST_RESIDUAL_2 =
+      DelayedBundleApplication.newBuilder()
+          .setApplication(BundleApplication.newBuilder().setInputId("residual2").build())
+          .build();
+
+  @Test
+  public void testInMemory() {
+    BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+    splitListener.split(Arrays.asList(TEST_PRIMARY_1), Arrays.asList(TEST_RESIDUAL_1));
+    splitListener.split(Arrays.asList(TEST_PRIMARY_2), Arrays.asList(TEST_RESIDUAL_2));
+    assertThat(splitListener.getPrimaryRoots(), contains(TEST_PRIMARY_1, TEST_PRIMARY_2));
+    assertThat(splitListener.getResidualRoots(), contains(TEST_RESIDUAL_1, TEST_RESIDUAL_2));
+    splitListener.clear();
+    assertThat(splitListener.getPrimaryRoots(), empty());
+    assertThat(splitListener.getResidualRoots(), empty());
+  }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index ff97ed9..982fd96 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -92,7 +92,6 @@ import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
 import org.joda.time.Instant;
 import org.junit.Before;
@@ -196,8 +195,8 @@ public class ProcessBundleHandlerTest {
     }
 
     @Override
-    Multimap<String, BeamFnApi.DelayedBundleApplication> getAllResiduals() {
-      return wrappedBundleProcessor.getAllResiduals();
+    BundleSplitListener.InMemory getSplitListener() {
+      return wrappedBundleProcessor.getSplitListener();
     }
 
     @Override
@@ -535,7 +534,7 @@ public class ProcessBundleHandlerTest {
   public void testBundleProcessorReset() {
     PTransformFunctionRegistry startFunctionRegistry = mock(PTransformFunctionRegistry.class);
     PTransformFunctionRegistry finishFunctionRegistry = mock(PTransformFunctionRegistry.class);
-    Multimap<String, BeamFnApi.DelayedBundleApplication> allResiduals = mock(Multimap.class);
+    BundleSplitListener.InMemory splitListener = mock(BundleSplitListener.InMemory.class);
     Collection<CallbackRegistration> bundleFinalizationCallbacks = mock(Collection.class);
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         mock(PCollectionConsumerRegistry.class);
@@ -549,7 +548,7 @@ public class ProcessBundleHandlerTest {
             startFunctionRegistry,
             finishFunctionRegistry,
             new ArrayList<>(),
-            allResiduals,
+            splitListener,
             pCollectionConsumerRegistry,
             metricsContainerRegistry,
             stateTracker,
@@ -560,7 +559,7 @@ public class ProcessBundleHandlerTest {
     bundleProcessor.reset();
     verify(startFunctionRegistry, times(1)).reset();
     verify(finishFunctionRegistry, times(1)).reset();
-    verify(allResiduals, times(1)).clear();
+    verify(splitListener, times(1)).clear();
     verify(pCollectionConsumerRegistry, times(1)).reset();
     verify(metricsContainerRegistry, times(1)).reset();
     verify(stateTracker, times(1)).reset();