You are viewing a plain-text version of this content. The link to the canonical version is not preserved in this plain-text rendering.
Posted to commits@beam.apache.org by bo...@apache.org on 2020/08/27 02:02:19 UTC

[beam] branch master updated: Handle split when truncate observes windows.

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6473909  Handle split when truncate observes windows.
     new d6adcdf  Merge pull request #12419 from boyuanzz/window_split
6473909 is described below

commit 6473909fb600631455126a968eb5c86908c145e9
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Wed Jul 29 20:01:03 2020 -0700

    Handle split when truncate observes windows.
---
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    |  704 +-
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       | 7330 +++++++++++---------
 2 files changed, 4704 insertions(+), 3330 deletions(-)

diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index de479d5..8c7b677 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -31,7 +31,6 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
-import java.util.ListIterator;
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
@@ -114,6 +113,7 @@ import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.util.Durations;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableListMultimap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
@@ -248,8 +248,12 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   /** Only valid during {@code processElement...} methods, null otherwise. */
   private WindowedValue<InputT> currentElement;
 
-  /** Only valid during {@link #processElementForSizedElementAndRestriction}. */
-  private ListIterator<BoundedWindow> currentWindowIterator;
+  /**
+   * Only valid during {@link
+   * #processElementForWindowObservingSizedElementAndRestriction(WindowedValue)} and {@link
+   * #processElementForWindowObservingTruncateRestriction(WindowedValue)}, null otherwise.
+   */
+  private List<BoundedWindow> currentWindows;
 
   /**
    * Only valud during {@link
@@ -259,6 +263,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   private int windowStopIndex;
 
   /**
+   * Only valid during {@link
+   * #processElementForWindowObservingSizedElementAndRestriction(WindowedValue)} and {@link
+   * #processElementForWindowObservingTruncateRestriction(WindowedValue)}.
+   */
+  private int windowCurrentIndex;
+
+  /**
    * Only valid during {@link #processElementForPairWithRestriction}, {@link
    * #processElementForSplitRestriction}, and {@link #processElementForSizedElementAndRestriction},
    * null otherwise.
@@ -271,6 +282,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
    */
   private WatermarkEstimatorStateT currentWatermarkEstimatorState;
 
+  /**
+   * Only valid during {@link
+   * #processElementForWindowObservingSizedElementAndRestriction(WindowedValue)} and {@link
+   * #processElementForWindowObservingTruncateRestriction(WindowedValue)}, null otherwise.
+   */
+  private Instant initialWatermark;
+
   /** Only valid during {@link #processElementForSizedElementAndRestriction}, null otherwise. */
   private WatermarkEstimators.WatermarkAndStateObserver<WatermarkEstimatorStateT>
       currentWatermarkEstimator;
@@ -532,10 +550,10 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
                     processElementForWindowObservingTruncateRestriction(input);
                   }
 
-                  // TODO(BEAM-10303): Split should work with window observing optimization.
                   @Override
                   public SplitResult trySplit(double fractionOfRemainder) {
-                    return null;
+                    return trySplitForWindowObservingTruncateRestriction(
+                        fractionOfRemainder, splitDelegate);
                   }
 
                   @Override
@@ -893,23 +911,34 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   private void processElementForWindowObservingTruncateRestriction(
       WindowedValue<KV<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>, Double>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey().getKey());
-    currentRestriction = elem.getValue().getKey().getValue().getKey();
-    currentWatermarkEstimatorState = elem.getValue().getKey().getValue().getValue();
     try {
-      Iterator<BoundedWindow> windowIterator =
-          (Iterator<BoundedWindow>) elem.getWindows().iterator();
-      while (windowIterator.hasNext()) {
-        currentWindow = windowIterator.next();
-        currentTracker =
-            RestrictionTrackers.observe(
-                doFnInvoker.invokeNewTracker(processContext),
-                new ClaimObserver<PositionT>() {
-                  @Override
-                  public void onClaimed(PositionT position) {}
+      windowCurrentIndex = -1;
+      windowStopIndex = currentElement.getWindows().size();
+      currentWindows = ImmutableList.copyOf(currentElement.getWindows());
+      while (true) {
+        synchronized (splitLock) {
+          windowCurrentIndex++;
+          if (windowCurrentIndex >= windowStopIndex) {
+            break;
+          }
+          currentRestriction = elem.getValue().getKey().getValue().getKey();
+          currentWatermarkEstimatorState = elem.getValue().getKey().getValue().getValue();
+          currentWindow = currentWindows.get(windowCurrentIndex);
+          currentTracker =
+              RestrictionTrackers.observe(
+                  doFnInvoker.invokeNewTracker(processContext),
+                  new ClaimObserver<PositionT>() {
+                    @Override
+                    public void onClaimed(PositionT position) {}
 
-                  @Override
-                  public void onClaimFailed(PositionT position) {}
-                });
+                    @Override
+                    public void onClaimFailed(PositionT position) {}
+                  });
+          currentWatermarkEstimator =
+              WatermarkEstimators.threadSafe(
+                  doFnInvoker.invokeNewWatermarkEstimator(processContext));
+          initialWatermark = currentWatermarkEstimator.getWatermarkAndState().getKey();
+        }
         TruncateResult<OutputT> truncatedRestriction =
             doFnInvoker.invokeTruncateRestriction(processContext);
         if (truncatedRestriction != null) {
@@ -921,7 +950,10 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
       currentElement = null;
       currentRestriction = null;
       currentWatermarkEstimatorState = null;
+      currentWatermarkEstimator = null;
       currentWindow = null;
+      currentWindows = null;
+      initialWatermark = null;
     }
 
     // TODO(BEAM-10212): Support caching state data across bundle boundaries.
@@ -945,9 +977,9 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
 
     public abstract @Nullable WindowedValue getPrimaryInFullyProcessedWindowsRoot();
 
-    public abstract WindowedValue getPrimarySplitRoot();
+    public abstract @Nullable WindowedValue getPrimarySplitRoot();
 
-    public abstract WindowedValue getResidualSplitRoot();
+    public abstract @Nullable WindowedValue getResidualSplitRoot();
 
     public abstract @Nullable WindowedValue getResidualInUnprocessedWindowsRoot();
   }
@@ -956,19 +988,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
       WindowedValue<KV<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>, Double>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey().getKey());
     try {
-      currentWindowIterator =
-          currentElement.getWindows() instanceof List
-              ? ((List) currentElement.getWindows()).listIterator()
-              : ImmutableList.<BoundedWindow>copyOf(elem.getWindows()).listIterator();
+      windowCurrentIndex = -1;
       windowStopIndex = currentElement.getWindows().size();
+      currentWindows = ImmutableList.copyOf(currentElement.getWindows());
       while (true) {
         synchronized (splitLock) {
-          if (!currentWindowIterator.hasNext()) {
+          windowCurrentIndex++;
+          if (windowCurrentIndex >= windowStopIndex) {
             return;
           }
           currentRestriction = elem.getValue().getKey().getValue().getKey();
           currentWatermarkEstimatorState = elem.getValue().getKey().getValue().getValue();
-          currentWindow = currentWindowIterator.next();
+          currentWindow = currentWindows.get(windowCurrentIndex);
           currentTracker =
               RestrictionTrackers.observe(
                   doFnInvoker.invokeNewTracker(processContext),
@@ -982,6 +1013,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
           currentWatermarkEstimator =
               WatermarkEstimators.threadSafe(
                   doFnInvoker.invokeNewWatermarkEstimator(processContext));
+          initialWatermark = currentWatermarkEstimator.getWatermarkAndState().getKey();
         }
 
         // It is important to ensure that {@code splitLock} is not held during #invokeProcessElement
@@ -1015,10 +1047,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
         currentElement = null;
         currentRestriction = null;
         currentWatermarkEstimatorState = null;
-        currentWindowIterator = null;
         currentWindow = null;
         currentTracker = null;
         currentWatermarkEstimator = null;
+        currentWindows = null;
+        initialWatermark = null;
       }
     }
   }
@@ -1051,9 +1084,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     synchronized (splitLock) {
       if (currentTracker instanceof RestrictionTracker.HasProgress) {
         return scaleProgress(
-            ((HasProgress) currentTracker).getProgress(),
-            currentWindowIterator.previousIndex(),
-            windowStopIndex);
+            ((HasProgress) currentTracker).getProgress(), windowCurrentIndex, windowStopIndex);
       }
     }
     return null;
@@ -1064,15 +1095,15 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
       if (currentWindow != null) {
         return scaleProgress(
             Progress.from(elementCompleted, 1 - elementCompleted),
-            currentWindowIterator.previousIndex(),
+            windowCurrentIndex,
             windowStopIndex);
       }
     }
     return null;
   }
 
-  private static Progress scaleProgress(
-      Progress progress, int currentWindowIndex, int stopWindowIndex) {
+  @VisibleForTesting
+  static Progress scaleProgress(Progress progress, int currentWindowIndex, int stopWindowIndex) {
     double totalWorkPerWindow = progress.getWorkCompleted() + progress.getWorkRemaining();
     double completed = totalWorkPerWindow * currentWindowIndex + progress.getWorkCompleted();
     double remaining =
@@ -1081,136 +1112,486 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     return Progress.from(completed, remaining);
   }
 
+  /**
+   * Returns a copy of {@code splitResult} whose non-null roots have their values wrapped as {@code
+   * KV.of(value, size)}, with each size computed through {@code doFnInvoker.invokeGetSize}.
+   *
+   * <p>The two window-level roots (fully processed and unprocessed windows) carry the unsplit
+   * {@code currentRestriction} and therefore share one size; the primary and residual split roots
+   * are sized from their own restrictions. The size of an absent root is never computed.
+   *
+   * @param splitResult split whose root values have the shape {@code KV<element, KV<restriction,
+   *     watermarkEstimatorState>>}
+   * @param errorContext passed to the argument providers used when invoking {@code GetSize}
+   */
+  private WindowedSplitResult calculateRestrictionSize(
+      WindowedSplitResult splitResult, String errorContext) {
+    WindowedValue primaryInFullyProcessedWindows =
+        splitResult.getPrimaryInFullyProcessedWindowsRoot();
+    WindowedValue primarySplit = splitResult.getPrimarySplitRoot();
+    WindowedValue residualSplit = splitResult.getResidualSplitRoot();
+    WindowedValue residualInUnprocessedWindows =
+        splitResult.getResidualInUnprocessedWindowsRoot();
+
+    double fullSize =
+        primaryInFullyProcessedWindows == null && residualInUnprocessedWindows == null
+            ? 0
+            : invokeGetSizeForRestriction(currentRestriction, errorContext);
+    double primarySize =
+        primarySplit == null
+            ? 0
+            : invokeGetSizeForRestriction(restrictionOfSplitRoot(primarySplit), errorContext);
+    double residualSize =
+        residualSplit == null
+            ? 0
+            : invokeGetSizeForRestriction(restrictionOfSplitRoot(residualSplit), errorContext);
+    return WindowedSplitResult.forRoots(
+        rootWithSize(primaryInFullyProcessedWindows, fullSize),
+        rootWithSize(primarySplit, primarySize),
+        rootWithSize(residualSplit, residualSize),
+        rootWithSize(residualInUnprocessedWindows, fullSize));
+  }
+
+  /** Extracts the restriction from a root valued {@code KV<element, KV<restriction, state>>}. */
+  private static Object restrictionOfSplitRoot(WindowedValue root) {
+    return ((KV<?, KV<?, ?>>) root.getValue()).getValue().getKey();
+  }
+
+  /** Invokes {@code GetSize} for {@code restrictionToSize} using a newly created tracker. */
+  private double invokeGetSizeForRestriction(Object restrictionToSize, String errorContext) {
+    return doFnInvoker.invokeGetSize(
+        new DelegatingArgumentProvider<InputT, OutputT>(processContext, errorContext) {
+          @Override
+          public Object restriction() {
+            return restrictionToSize;
+          }
+
+          @Override
+          public RestrictionTracker<?, ?> restrictionTracker() {
+            return doFnInvoker.invokeNewTracker(this);
+          }
+        });
+  }
+
+  /** Returns {@code root} with its value replaced by {@code KV.of(value, size)}, or null. */
+  private static @Nullable WindowedValue rootWithSize(@Nullable WindowedValue root, double size) {
+    return root == null
+        ? null
+        : WindowedValue.of(
+            KV.of(root.getValue(), size), root.getTimestamp(), root.getWindows(), root.getPane());
+  }
+
+  /**
+   * Computes a split for the window-observing truncate transform while it processes window
+   * {@code currentWindowIndex} of {@code windows}.
+   *
+   * <p>If the current window is not the last one to process, {@code fractionOfRemainder} is first
+   * scaled by the progress across all windows up to {@code stopWindowIndex} (see {@code
+   * scaleProgress}). If the scaled amount reaches past the current window, the split is placed at a
+   * window boundary after the current window; otherwise the downstream consumer is asked to split
+   * the remainder of the current window. On the last window the downstream consumer is asked to
+   * split with {@code fractionOfRemainder} directly.
+   *
+   * <p>When the current window is not the last one and the downstream consumer declines to split,
+   * the split falls back to the window boundary right after the current window, which then stays
+   * whole in the primary. This mirrors {@link #trySplitForProcess}.
+   *
+   * @return the windowed split of the truncate input paired with the downstream split (null when
+   *     the downstream did not split), together with the new window stop index; or null when the
+   *     current window is the last one and the downstream consumer could not split
+   */
+  @VisibleForTesting
+  static <WatermarkEstimatorStateT>
+      KV<KV<WindowedSplitResult, HandlesSplits.SplitResult>, Integer> trySplitForTruncate(
+          WindowedValue currentElement,
+          Object currentRestriction,
+          BoundedWindow currentWindow,
+          List<BoundedWindow> windows,
+          WatermarkEstimatorStateT currentWatermarkEstimatorState,
+          double fractionOfRemainder,
+          HandlesSplits splitDelegate,
+          int currentWindowIndex,
+          int stopWindowIndex) {
+    HandlesSplits.SplitResult downstreamSplitResult = null;
+    int newWindowStopIndex;
+    // If we are not on the last window, try to compute the split which is on the current window or
+    // on a future window.
+    if (currentWindowIndex != stopWindowIndex - 1) {
+      // Compute the fraction of the remainder relative to the scaled progress.
+      double elementCompleted = splitDelegate.getProgress();
+      Progress elementProgress = Progress.from(elementCompleted, 1 - elementCompleted);
+      Progress scaledProgress = scaleProgress(elementProgress, currentWindowIndex, stopWindowIndex);
+      double scaledFractionOfRemainder = scaledProgress.getWorkRemaining() * fractionOfRemainder;
+      if (scaledFractionOfRemainder >= elementProgress.getWorkRemaining()) {
+        // The fraction is out of the current window and hence we will split at the closest window
+        // boundary.
+        newWindowStopIndex =
+            (int)
+                Math.min(
+                    stopWindowIndex - 1,
+                    currentWindowIndex
+                        + Math.max(
+                            1,
+                            Math.round(
+                                (elementProgress.getWorkCompleted() + scaledFractionOfRemainder)
+                                    / (elementProgress.getWorkCompleted()
+                                        + elementProgress.getWorkRemaining()))));
+      } else {
+        // The downstream expects a fraction of its own remainder, so re-express the scaled amount
+        // relative to the work remaining in the current window (as trySplitForProcess does).
+        // The divisor is positive here since it is strictly larger than a non-negative value.
+        downstreamSplitResult =
+            splitDelegate.trySplit(scaledFractionOfRemainder / elementProgress.getWorkRemaining());
+        newWindowStopIndex = currentWindowIndex + 1;
+      }
+    } else {
+      // We are on the last window then compute the downstream element split with given fraction.
+      downstreamSplitResult = splitDelegate.trySplit(fractionOfRemainder);
+      // We cannot produce any split if the downstream is not splittable.
+      if (downstreamSplitResult == null) {
+        return null;
+      }
+      newWindowStopIndex = stopWindowIndex;
+    }
+    // A downstream split covers the current window with its own primary and residual roots;
+    // otherwise the current window is kept whole in the primary windows.
+    int primaryWindowsEnd =
+        downstreamSplitResult == null ? newWindowStopIndex : currentWindowIndex;
+    WindowedSplitResult windowedSplitResult =
+        computeWindowSplitResult(
+            currentElement,
+            currentRestriction,
+            currentWindow,
+            windows,
+            currentWatermarkEstimatorState,
+            primaryWindowsEnd,
+            newWindowStopIndex,
+            stopWindowIndex,
+            null,
+            null);
+    return KV.of(KV.of(windowedSplitResult, downstreamSplitResult), newWindowStopIndex);
+  }
+
+  /**
+   * Splits the window-observing truncate transform for the element currently being processed.
+   *
+   * <p>The split itself is computed by {@link #trySplitForTruncate} while holding {@code
+   * splitLock}. Roots spanning whole windows are encoded with the full input coder and addressed to
+   * this transform's main input; any primary and residual roots produced by the downstream
+   * consumer are appended afterwards.
+   *
+   * @param fractionOfRemainder the requested fraction of the remaining work
+   * @param splitDelegate the downstream consumer, asked to split within the current window
+   * @return the primary and residual roots, or null if there is nothing to split
+   */
+  private HandlesSplits.SplitResult trySplitForWindowObservingTruncateRestriction(
+      double fractionOfRemainder, HandlesSplits splitDelegate) {
+    // Note that the assumption here is the fullInputCoder of the Truncate transform should be the
+    // same as the SDF/Process transform.
+    Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
+    WindowedSplitResult windowedSplitResult;
+    HandlesSplits.SplitResult downstreamSplitResult;
+    // initialWatermark is reassigned for every window and cleared once the element is done, so a
+    // snapshot is taken while holding splitLock instead of reading the field after releasing it.
+    Instant initialWatermarkAtSplit;
+    synchronized (splitLock) {
+      // There is nothing to split if we are between truncate processing calls.
+      if (currentWindow == null) {
+        return null;
+      }
+
+      KV<KV<WindowedSplitResult, HandlesSplits.SplitResult>, Integer> result =
+          trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              currentWindow,
+              currentWindows,
+              currentWatermarkEstimatorState,
+              fractionOfRemainder,
+              splitDelegate,
+              windowCurrentIndex,
+              windowStopIndex);
+      if (result == null) {
+        return null;
+      }
+      windowStopIndex = result.getValue();
+      windowedSplitResult =
+          calculateRestrictionSize(
+              result.getKey().getKey(),
+              PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN + "/GetSize");
+      downstreamSplitResult = result.getKey().getValue();
+      initialWatermarkAtSplit = initialWatermark;
+    }
+
+    List<BundleApplication> primaryRoots = new ArrayList<>();
+    List<DelayedBundleApplication> residualRoots = new ArrayList<>();
+
+    if (windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot() != null) {
+      ByteString.Output primaryInOtherWindowsBytes = ByteString.newOutput();
+      try {
+        fullInputCoder.encode(
+            windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot(),
+            primaryInOtherWindowsBytes);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      BundleApplication.Builder primaryApplicationInOtherWindows =
+          BundleApplication.newBuilder()
+              .setTransformId(pTransformId)
+              .setInputId(mainInputId)
+              .setElement(primaryInOtherWindowsBytes.toByteString());
+      primaryRoots.add(primaryApplicationInOtherWindows.build());
+    }
+    if (windowedSplitResult.getResidualInUnprocessedWindowsRoot() != null) {
+      ByteString.Output residualInUnprocessedWindowsBytesOut = ByteString.newOutput();
+      try {
+        fullInputCoder.encode(
+            windowedSplitResult.getResidualInUnprocessedWindowsRoot(),
+            residualInUnprocessedWindowsBytesOut);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      // We don't want to change the output watermarks or set the checkpoint resume time since
+      // that applies to the current window.
+      Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+          outputWatermarkMap = new HashMap<>();
+      if (!initialWatermarkAtSplit.equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
+        long watermarkMillis = initialWatermarkAtSplit.getMillis();
+        // Floor division keeps nanos non-negative, as protobuf Timestamp requires, even for
+        // watermarks before the epoch.
+        org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp outputWatermark =
+            org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                .setSeconds(Math.floorDiv(watermarkMillis, 1000L))
+                .setNanos((int) Math.floorMod(watermarkMillis, 1000L) * 1000000)
+                .build();
+        for (String outputId : pTransform.getOutputsMap().keySet()) {
+          outputWatermarkMap.put(outputId, outputWatermark);
+        }
+      }
+
+      BundleApplication.Builder residualApplicationInUnprocessedWindows =
+          BundleApplication.newBuilder()
+              .setTransformId(pTransformId)
+              .setInputId(mainInputId)
+              .putAllOutputWatermarks(outputWatermarkMap)
+              .setElement(residualInUnprocessedWindowsBytesOut.toByteString());
+
+      residualRoots.add(
+          DelayedBundleApplication.newBuilder()
+              .setApplication(residualApplicationInUnprocessedWindows)
+              .build());
+    }
+
+    if (downstreamSplitResult != null) {
+      primaryRoots.add(Iterables.getOnlyElement(downstreamSplitResult.getPrimaryRoots()));
+      residualRoots.add(Iterables.getOnlyElement(downstreamSplitResult.getResidualRoots()));
+    }
+
+    return HandlesSplits.SplitResult.of(primaryRoots, residualRoots);
+  }
+
+  /**
+   * Assembles the four roots of a window-aware split of {@code currentElement}.
+   *
+   * <p>Windows {@code [0, primaryWindowsEnd)} and {@code [residualWindowsStart, stopWindowIndex)}
+   * each become one root carrying the unsplit {@code currentRestriction}; an empty range yields a
+   * null root. When {@code splitResult} is non-null, its primary and residual restrictions become
+   * roots in {@code currentWindow}, and the residual is paired with the watermark estimator state
+   * held by {@code watermarkAndState} (which must then be non-null).
+   */
+  private static <WatermarkEstimatorStateT> WindowedSplitResult computeWindowSplitResult(
+      WindowedValue currentElement,
+      Object currentRestriction,
+      BoundedWindow currentWindow,
+      List<BoundedWindow> windows,
+      WatermarkEstimatorStateT currentWatermarkEstimatorState,
+      int primaryWindowsEnd,
+      int residualWindowsStart,
+      int stopWindowIndex,
+      SplitResult<?> splitResult,
+      KV<Instant, WatermarkEstimatorStateT> watermarkAndState) {
+    List<BoundedWindow> processedWindows = windows.subList(0, primaryWindowsEnd);
+    List<BoundedWindow> unprocessedWindows =
+        windows.subList(residualWindowsStart, stopWindowIndex);
+    Object element = currentElement.getValue();
+
+    WindowedValue primaryInProcessedWindows = null;
+    if (!processedWindows.isEmpty()) {
+      primaryInProcessedWindows =
+          WindowedValue.of(
+              KV.of(element, KV.of(currentRestriction, currentWatermarkEstimatorState)),
+              currentElement.getTimestamp(),
+              processedWindows,
+              currentElement.getPane());
+    }
+
+    WindowedValue primaryInCurrentWindow = null;
+    WindowedValue residualInCurrentWindow = null;
+    if (splitResult != null) {
+      primaryInCurrentWindow =
+          WindowedValue.of(
+              KV.of(element, KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
+              currentElement.getTimestamp(),
+              currentWindow,
+              currentElement.getPane());
+      residualInCurrentWindow =
+          WindowedValue.of(
+              KV.of(element, KV.of(splitResult.getResidual(), watermarkAndState.getValue())),
+              currentElement.getTimestamp(),
+              currentWindow,
+              currentElement.getPane());
+    }
+
+    WindowedValue residualInUnprocessedWindows = null;
+    if (!unprocessedWindows.isEmpty()) {
+      residualInUnprocessedWindows =
+          WindowedValue.of(
+              KV.of(element, KV.of(currentRestriction, currentWatermarkEstimatorState)),
+              currentElement.getTimestamp(),
+              unprocessedWindows,
+              currentElement.getPane());
+    }
+
+    return WindowedSplitResult.forRoots(
+        primaryInProcessedWindows,
+        primaryInCurrentWindow,
+        residualInCurrentWindow,
+        residualInUnprocessedWindows);
+  }
+
+  /**
+   * Computes a window-aware split of the process transform while it processes window {@code
+   * currentWindowIndex} of {@code windows} with {@code currentTracker}.
+   *
+   * <p>If the current window is not the last one to process, {@code fractionOfRemainder} is scaled
+   * by the progress across all windows up to {@code stopWindowIndex} (see {@code scaleProgress});
+   * a tracker without {@link HasProgress} is treated as having made no progress. If the scaled
+   * amount reaches past the current window, or the tracker declines to split, the split is placed
+   * at a window boundary after the current window, which stays whole in the primary. Otherwise the
+   * tracker is split within the current window. On the last window the tracker is split with
+   * {@code fractionOfRemainder} directly.
+   *
+   * @return the windowed split together with the new window stop index, or null when the current
+   *     window is the last one and the tracker could not split
+   */
+  @VisibleForTesting
+  static <WatermarkEstimatorStateT> KV<WindowedSplitResult, Integer> trySplitForProcess(
+      WindowedValue currentElement,
+      Object currentRestriction,
+      BoundedWindow currentWindow,
+      List<BoundedWindow> windows,
+      WatermarkEstimatorStateT currentWatermarkEstimatorState,
+      double fractionOfRemainder,
+      RestrictionTracker currentTracker,
+      KV<Instant, WatermarkEstimatorStateT> watermarkAndState,
+      int currentWindowIndex,
+      int stopWindowIndex) {
+    SplitResult<?> elementSplit = null;
+    int newWindowStopIndex;
+    // If we are not on the last window, try to compute the split which is on the current window or
+    // on a future window.
+    if (currentWindowIndex != stopWindowIndex - 1) {
+      // Compute the fraction of the remainder relative to the scaled progress.
+      Progress elementProgress =
+          currentTracker instanceof HasProgress
+              ? ((HasProgress) currentTracker).getProgress()
+              : Progress.from(0, 1);
+      Progress scaledProgress = scaleProgress(elementProgress, currentWindowIndex, stopWindowIndex);
+      double scaledFractionOfRemainder = scaledProgress.getWorkRemaining() * fractionOfRemainder;
+      if (scaledFractionOfRemainder >= elementProgress.getWorkRemaining()) {
+        // The fraction is out of the current window and hence we will split at the closest window
+        // boundary.
+        newWindowStopIndex =
+            (int)
+                Math.min(
+                    stopWindowIndex - 1,
+                    currentWindowIndex
+                        + Math.max(
+                            1,
+                            Math.round(
+                                (elementProgress.getWorkCompleted() + scaledFractionOfRemainder)
+                                    / (elementProgress.getWorkCompleted()
+                                        + elementProgress.getWorkRemaining()))));
+      } else {
+        // Split the element with the scaled amount re-expressed relative to the work remaining in
+        // the current window.
+        elementSplit =
+            currentTracker.trySplit(scaledFractionOfRemainder / elementProgress.getWorkRemaining());
+        newWindowStopIndex = currentWindowIndex + 1;
+      }
+    } else {
+      // We are on the last window then compute the element split with given fraction.
+      elementSplit = currentTracker.trySplit(fractionOfRemainder);
+      if (elementSplit == null) {
+        return null;
+      }
+      newWindowStopIndex = stopWindowIndex;
+    }
+    // An element split covers the current window with its own primary and residual roots;
+    // otherwise the current window is kept whole in the primary windows.
+    int primaryWindowsEnd = elementSplit == null ? newWindowStopIndex : currentWindowIndex;
+    WindowedSplitResult windowedSplitResult =
+        computeWindowSplitResult(
+            currentElement,
+            currentRestriction,
+            currentWindow,
+            windows,
+            currentWatermarkEstimatorState,
+            primaryWindowsEnd,
+            newWindowStopIndex,
+            stopWindowIndex,
+            elementSplit,
+            watermarkAndState);
+    return KV.of(windowedSplitResult, newWindowStopIndex);
+  }
+
   private HandlesSplits.SplitResult trySplitForElementAndRestriction(
       double fractionOfRemainder, Duration resumeDelay) {
     KV<Instant, WatermarkEstimatorStateT> watermarkAndState;
-    WindowedSplitResult windowedSplitResult;
+    WindowedSplitResult windowedSplitResult = null;
     synchronized (splitLock) {
       // There is nothing to split if we are between element and restriction processing calls.
       if (currentTracker == null) {
         return null;
       }
-
       // Make sure to get the output watermark before we split to ensure that the lower bound
       // applies to the residual.
       watermarkAndState = currentWatermarkEstimator.getWatermarkAndState();
-      SplitResult<RestrictionT> splitResult = currentTracker.trySplit(fractionOfRemainder);
+      KV<WindowedSplitResult, Integer> splitResult =
+          trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              currentWindow,
+              currentWindows,
+              currentWatermarkEstimatorState,
+              fractionOfRemainder,
+              currentTracker,
+              watermarkAndState,
+              windowCurrentIndex,
+              windowStopIndex);
       if (splitResult == null) {
         return null;
       }
-
-      // We have a successful self split, either runner initiated or via a self checkpoint.
-      // Convert the split taking into account the processed windows, the current window and the
-      // yet to be processed windows.
-      List<BoundedWindow> primaryFullyProcessedWindows =
-          ImmutableList.copyOf(
-              Iterables.limit(currentElement.getWindows(), currentWindowIterator.previousIndex()));
-      windowStopIndex = currentWindowIterator.nextIndex();
-      // Advances the iterator consuming the remaining windows.
-      List<BoundedWindow> residualUnprocessedWindows = ImmutableList.copyOf(currentWindowIterator);
-      // If the window has been observed then the splitAndSize method would have already
-      // output sizes for each window separately.
-      //
-      // TODO: Consider using the original size on the element instead of recomputing
-      // this here.
-      double fullSize =
-          primaryFullyProcessedWindows.isEmpty() && residualUnprocessedWindows.isEmpty()
-              ? 0
-              : doFnInvoker.invokeGetSize(
-                  new DelegatingArgumentProvider<InputT, OutputT>(
-                      processContext,
-                      PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
-                          + "/GetPrimarySize") {
-                    @Override
-                    public Object restriction() {
-                      return currentRestriction;
-                    }
-
-                    @Override
-                    public RestrictionTracker<?, ?> restrictionTracker() {
-                      return doFnInvoker.invokeNewTracker(this);
-                    }
-                  });
-      double primarySize =
-          doFnInvoker.invokeGetSize(
-              new DelegatingArgumentProvider<InputT, OutputT>(
-                  processContext,
-                  PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
-                      + "/GetPrimarySize") {
-                @Override
-                public Object restriction() {
-                  return splitResult.getPrimary();
-                }
-
-                @Override
-                public RestrictionTracker<?, ?> restrictionTracker() {
-                  return doFnInvoker.invokeNewTracker(this);
-                }
-              });
-      double residualSize =
-          doFnInvoker.invokeGetSize(
-              new DelegatingArgumentProvider<InputT, OutputT>(
-                  processContext,
-                  PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
-                      + "/GetResidualSize") {
-                @Override
-                public Object restriction() {
-                  return splitResult.getResidual();
-                }
-
-                @Override
-                public RestrictionTracker<?, ?> restrictionTracker() {
-                  return doFnInvoker.invokeNewTracker(this);
-                }
-              });
+      windowStopIndex = splitResult.getValue();
       windowedSplitResult =
-          WindowedSplitResult.forRoots(
-              primaryFullyProcessedWindows.isEmpty()
-                  ? null
-                  : WindowedValue.of(
-                      KV.of(
-                          KV.of(
-                              currentElement.getValue(),
-                              KV.of(currentRestriction, currentWatermarkEstimatorState)),
-                          fullSize),
-                      currentElement.getTimestamp(),
-                      primaryFullyProcessedWindows,
-                      currentElement.getPane()),
-              WindowedValue.of(
-                  KV.of(
-                      KV.of(
-                          currentElement.getValue(),
-                          KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
-                      primarySize),
-                  currentElement.getTimestamp(),
-                  currentWindow,
-                  currentElement.getPane()),
-              WindowedValue.of(
-                  KV.of(
-                      KV.of(
-                          currentElement.getValue(),
-                          KV.of(splitResult.getResidual(), watermarkAndState.getValue())),
-                      residualSize),
-                  currentElement.getTimestamp(),
-                  currentWindow,
-                  currentElement.getPane()),
-              residualUnprocessedWindows.isEmpty()
-                  ? null
-                  : WindowedValue.of(
-                      KV.of(
-                          KV.of(
-                              currentElement.getValue(),
-                              KV.of(currentRestriction, currentWatermarkEstimatorState)),
-                          fullSize),
-                      currentElement.getTimestamp(),
-                      residualUnprocessedWindows,
-                      currentElement.getPane()));
+          calculateRestrictionSize(
+              splitResult.getKey(),
+              PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
+                  + "/GetSize");
     }
 
     List<BundleApplication> primaryRoots = new ArrayList<>();
     List<DelayedBundleApplication> residualRoots = new ArrayList<>();
     Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
-    if (windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot() != null) {
+    if (windowedSplitResult != null
+        && windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot() != null) {
       ByteString.Output primaryInOtherWindowsBytes = ByteString.newOutput();
       try {
         fullInputCoder.encode(
@@ -1226,7 +1607,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
               .setElement(primaryInOtherWindowsBytes.toByteString());
       primaryRoots.add(primaryApplicationInOtherWindows.build());
     }
-    if (windowedSplitResult.getResidualInUnprocessedWindowsRoot() != null) {
+    if (windowedSplitResult != null
+        && windowedSplitResult.getResidualInUnprocessedWindowsRoot() != null) {
       ByteString.Output bytesOut = ByteString.newOutput();
       try {
         fullInputCoder.encode(windowedSplitResult.getResidualInUnprocessedWindowsRoot(), bytesOut);
@@ -1240,6 +1622,20 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
               .setElement(bytesOut.toByteString());
       // We don't want to change the output watermarks or set the checkpoint resume time since
       // that applies to the current window.
+      Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+          outputWatermarkMapForUnprocessedWindows = new HashMap<>();
+      if (!initialWatermark.equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
+        org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp outputWatermark =
+            org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                .setSeconds(initialWatermark.getMillis() / 1000)
+                .setNanos((int) (initialWatermark.getMillis() % 1000) * 1000000)
+                .build();
+        for (String outputId : pTransform.getOutputsMap().keySet()) {
+          outputWatermarkMapForUnprocessedWindows.put(outputId, outputWatermark);
+        }
+      }
+      residualInUnprocessedWindowsRoot.putAllOutputWatermarks(
+          outputWatermarkMapForUnprocessedWindows);
       residualRoots.add(
           DelayedBundleApplication.newBuilder()
               .setApplication(residualInUnprocessedWindowsRoot)
@@ -1265,17 +1661,19 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
             .setTransformId(pTransformId)
             .setInputId(mainInputId)
             .setElement(residualBytes.toByteString());
-
+    Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+        outputWatermarkMap = new HashMap<>();
     if (!watermarkAndState.getKey().equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
+      org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp outputWatermark =
+          org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+              .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
+              .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
+              .build();
       for (String outputId : pTransform.getOutputsMap().keySet()) {
-        residualApplication.putOutputWatermarks(
-            outputId,
-            org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
-                .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
-                .build());
+        outputWatermarkMap.put(outputId, outputWatermark);
       }
     }
+    residualApplication.putAllOutputWatermarks(outputWatermarkMap);
     residualRoots.add(
         DelayedBundleApplication.newBuilder()
             .setApplication(residualApplication)
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 8e36ed9..b388587 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.fn.harness;
 import static org.apache.beam.sdk.options.ExperimentalOptions.addExperiment;
 import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.empty;
@@ -39,6 +40,7 @@ import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.ServiceLoader;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
@@ -48,6 +50,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import org.apache.beam.fn.harness.FnApiDoFnRunner.WindowedSplitResult;
 import org.apache.beam.fn.harness.HandlesSplits.SplitResult;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.ProgressRequestCallback;
 import org.apache.beam.fn.harness.control.BundleSplitListener;
@@ -112,6 +115,7 @@ import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Progress;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.TruncateResult;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -137,3348 +141,4320 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterable
 import org.hamcrest.collection.IsMapContaining;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
+import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /** Tests for {@link FnApiDoFnRunner}. */
-@RunWith(JUnit4.class)
+@RunWith(Enclosed.class)
 public class FnApiDoFnRunnerTest implements Serializable {
 
-  @Rule public transient ResetDateTimeProvider dateTimeProvider = new ResetDateTimeProvider();
+  @RunWith(JUnit4.class)
+  public static class ExecutionTest implements Serializable {
+    @Rule public transient ResetDateTimeProvider dateTimeProvider = new ResetDateTimeProvider();
 
-  private static final Logger LOG = LoggerFactory.getLogger(FnApiDoFnRunnerTest.class);
+    private static final Logger LOG = LoggerFactory.getLogger(FnApiDoFnRunnerTest.class);
 
-  public static final String TEST_TRANSFORM_ID = "pTransformId";
+    public static final String TEST_TRANSFORM_ID = "pTransformId";
 
-  private static class ConcatCombineFn extends CombineFn<String, String, String> {
-    @Override
-    public String createAccumulator() {
-      return "";
-    }
+    private static class ConcatCombineFn extends CombineFn<String, String, String> {
+      @Override
+      public String createAccumulator() {
+        return "";
+      }
 
-    @Override
-    public String addInput(String accumulator, String input) {
-      return accumulator.concat(input);
-    }
+      @Override
+      public String addInput(String accumulator, String input) {
+        return accumulator.concat(input);
+      }
 
-    @Override
-    public String mergeAccumulators(Iterable<String> accumulators) {
-      StringBuilder builder = new StringBuilder();
-      for (String value : accumulators) {
-        builder.append(value);
+      @Override
+      public String mergeAccumulators(Iterable<String> accumulators) {
+        StringBuilder builder = new StringBuilder();
+        for (String value : accumulators) {
+          builder.append(value);
+        }
+        return builder.toString();
       }
-      return builder.toString();
-    }
 
-    @Override
-    public String extractOutput(String accumulator) {
-      return accumulator;
+      @Override
+      public String extractOutput(String accumulator) {
+        return accumulator;
+      }
     }
-  }
 
-  private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
-    private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
-    private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
+    private static class TestStatefulDoFn extends DoFn<KV<String, String>, String> {
+      private static final TupleTag<String> mainOutput = new TupleTag<>("mainOutput");
+      private static final TupleTag<String> additionalOutput = new TupleTag<>("output");
 
-    @StateId("value")
-    private final StateSpec<ValueState<String>> valueStateSpec =
-        StateSpecs.value(StringUtf8Coder.of());
+      @StateId("value")
+      private final StateSpec<ValueState<String>> valueStateSpec =
+          StateSpecs.value(StringUtf8Coder.of());
 
-    @StateId("bag")
-    private final StateSpec<BagState<String>> bagStateSpec = StateSpecs.bag(StringUtf8Coder.of());
+      @StateId("bag")
+      private final StateSpec<BagState<String>> bagStateSpec = StateSpecs.bag(StringUtf8Coder.of());
 
-    @StateId("combine")
-    private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
-        StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
+      @StateId("combine")
+      private final StateSpec<CombiningState<String, String, String>> combiningStateSpec =
+          StateSpecs.combining(StringUtf8Coder.of(), new ConcatCombineFn());
 
-    @ProcessElement
-    public void processElement(
-        ProcessContext context,
-        @StateId("value") ValueState<String> valueState,
-        @StateId("bag") BagState<String> bagState,
-        @StateId("combine") CombiningState<String, String, String> combiningState) {
-      context.output("value:" + valueState.read());
-      valueState.write(context.element().getValue());
+      @ProcessElement
+      public void processElement(
+          ProcessContext context,
+          @StateId("value") ValueState<String> valueState,
+          @StateId("bag") BagState<String> bagState,
+          @StateId("combine") CombiningState<String, String, String> combiningState) {
+        context.output("value:" + valueState.read());
+        valueState.write(context.element().getValue());
 
-      context.output("bag:" + Iterables.toString(bagState.read()));
-      bagState.add(context.element().getValue());
+        context.output("bag:" + Iterables.toString(bagState.read()));
+        bagState.add(context.element().getValue());
 
-      context.output("combine:" + combiningState.read());
-      combiningState.add(context.element().getValue());
+        context.output("combine:" + combiningState.read());
+        combiningState.add(context.element().getValue());
+      }
     }
-  }
 
-  @Test
-  public void testUsingUserState() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<KV<String, String>> valuePCollection =
-        p.apply(Create.of(KV.of("unused", "unused")));
-    PCollection<String> outputPCollection =
-        valuePCollection.apply(TEST_TRANSFORM_ID, ParDo.of(new TestStatefulDoFn()));
-
-    SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
-    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
-    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
-    RunnerApi.PTransform pTransform =
-        pProto
-            .getComponents()
-            .getTransformsOrThrow(
-                pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID).getSubtransforms(0));
-
-    FakeBeamFnStateClient fakeClient =
-        new FakeBeamFnStateClient(
-            ImmutableMap.of(
-                bagUserStateKey("value", "X"), encode("X0"),
-                bagUserStateKey("bag", "X"), encode("X0"),
-                bagUserStateKey("combine", "X"), encode("X0")));
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* splitListener */,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    // Ensure that bag user state that is initially empty or populated works.
-    // Ensure that the key order does not matter when we traverse over KV pairs.
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
-    mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
-    mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
-    mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y2")));
-    assertThat(
-        mainOutputValues,
-        contains(
-            valueInGlobalWindow("value:X0"),
-            valueInGlobalWindow("bag:[X0]"),
-            valueInGlobalWindow("combine:X0"),
-            valueInGlobalWindow("value:null"),
-            valueInGlobalWindow("bag:[]"),
-            valueInGlobalWindow("combine:"),
-            valueInGlobalWindow("value:X1"),
-            valueInGlobalWindow("bag:[X0, X1]"),
-            valueInGlobalWindow("combine:X0X1"),
-            valueInGlobalWindow("value:Y1"),
-            valueInGlobalWindow("bag:[Y1]"),
-            valueInGlobalWindow("combine:Y1")));
-    mainOutputValues.clear();
-
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-
-    assertEquals(
-        ImmutableMap.<StateKey, ByteString>builder()
-            .put(bagUserStateKey("value", "X"), encode("X2"))
-            .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
-            .put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
-            .put(bagUserStateKey("value", "Y"), encode("Y2"))
-            .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
-            .put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
-            .build(),
-        fakeClient.getData());
-  }
+    @Test
+    public void testUsingUserState() throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<KV<String, String>> valuePCollection =
+          p.apply(Create.of(KV.of("unused", "unused")));
+      PCollection<String> outputPCollection =
+          valuePCollection.apply(TEST_TRANSFORM_ID, ParDo.of(new TestStatefulDoFn()));
+
+      SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+      RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+      String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+      String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+      RunnerApi.PTransform pTransform =
+          pProto
+              .getComponents()
+              .getTransformsOrThrow(
+                  pProto
+                      .getComponents()
+                      .getTransformsOrThrow(TEST_TRANSFORM_ID)
+                      .getSubtransforms(0));
+
+      FakeBeamFnStateClient fakeClient =
+          new FakeBeamFnStateClient(
+              ImmutableMap.of(
+                  bagUserStateKey("value", "X"), encode("X0"),
+                  bagUserStateKey("bag", "X"), encode("X0"),
+                  bagUserStateKey("combine", "X"), encode("X0")));
+
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
 
-  /** Produces a bag user {@link StateKey} for the test PTransform id in the global window. */
-  private StateKey bagUserStateKey(String userStateId, String key) throws IOException {
-    return StateKey.newBuilder()
-        .setBagUserState(
-            StateKey.BagUserState.newBuilder()
-                .setTransformId(TEST_TRANSFORM_ID)
-                .setUserStateId(userStateId)
-                .setKey(encode(key))
-                .setWindow(
-                    ByteString.copyFrom(
-                        CoderUtils.encodeToByteArray(
-                            GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE))))
-        .build();
-  }
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
 
-  private static class TestSideInputDoFn extends DoFn<String, String> {
-    private final PCollectionView<String> defaultSingletonSideInput;
-    private final PCollectionView<String> singletonSideInput;
-    private final PCollectionView<Iterable<String>> iterableSideInput;
-    private final TupleTag<String> additionalOutput;
-
-    private TestSideInputDoFn(
-        PCollectionView<String> defaultSingletonSideInput,
-        PCollectionView<String> singletonSideInput,
-        PCollectionView<Iterable<String>> iterableSideInput,
-        TupleTag<String> additionalOutput) {
-      this.defaultSingletonSideInput = defaultSingletonSideInput;
-      this.singletonSideInput = singletonSideInput;
-      this.iterableSideInput = iterableSideInput;
-      this.additionalOutput = additionalOutput;
-    }
-
-    @ProcessElement
-    public void processElement(ProcessContext context) {
-      context.output(context.element() + ":" + context.sideInput(defaultSingletonSideInput));
-      context.output(context.element() + ":" + context.sideInput(singletonSideInput));
-      for (String sideInputValue : context.sideInput(iterableSideInput)) {
-        context.output(context.element() + ":" + sideInputValue);
-      }
-      context.output(additionalOutput, context.element() + ":additional");
-    }
-  }
+      // Ensure that bag user state that is initially empty or populated works.
+      // Ensure that the key order does not matter when we traverse over KV pairs.
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(valueInGlobalWindow(KV.of("X", "X1")));
+      mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y1")));
+      mainInput.accept(valueInGlobalWindow(KV.of("X", "X2")));
+      mainInput.accept(valueInGlobalWindow(KV.of("Y", "Y2")));
+      assertThat(
+          mainOutputValues,
+          contains(
+              valueInGlobalWindow("value:X0"),
+              valueInGlobalWindow("bag:[X0]"),
+              valueInGlobalWindow("combine:X0"),
+              valueInGlobalWindow("value:null"),
+              valueInGlobalWindow("bag:[]"),
+              valueInGlobalWindow("combine:"),
+              valueInGlobalWindow("value:X1"),
+              valueInGlobalWindow("bag:[X0, X1]"),
+              valueInGlobalWindow("combine:X0X1"),
+              valueInGlobalWindow("value:Y1"),
+              valueInGlobalWindow("bag:[Y1]"),
+              valueInGlobalWindow("combine:Y1")));
+      mainOutputValues.clear();
 
-  @Test
-  public void testProcessElementWithSideInputsAndOutputs() throws Exception {
-    Pipeline p = Pipeline.create();
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
-    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> defaultSingletonSideInputView =
-        valuePCollection.apply(
-            View.<String>asSingleton().withDefaultValue("defaultSingletonValue"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    PCollectionView<Iterable<String>> iterableSideInputView =
-        valuePCollection.apply(View.asIterable());
-    TupleTag<String> mainOutput = new TupleTag<String>("main") {};
-    TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
-    PCollectionTuple outputPCollection =
-        valuePCollection.apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(
-                    new TestSideInputDoFn(
-                        defaultSingletonSideInputView,
-                        singletonSideInputView,
-                        iterableSideInputView,
-                        additionalOutput))
-                .withSideInputs(
-                    defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView)
-                .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
-
-    SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
-    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
-    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId =
-        sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
-    String additionalPCollectionId =
-        sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
-
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID);
-
-    ImmutableMap<StateKey, ByteString> stateData =
-        ImmutableMap.of(
-            iterableSideInputKey(singletonSideInputView.getTagInternal().getId()),
-            encode("singletonValue"),
-            iterableSideInputKey(iterableSideInputView.getTagInternal().getId()),
-            encode("iterableValue1", "iterableValue2", "iterableValue3"));
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    consumers.register(
-        additionalPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* splitListener */,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(
-        consumers.keySet(),
-        containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
-
-    // Ensure that bag user state that is initially empty or populated works.
-    // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInGlobalWindow("X"));
-    mainInput.accept(valueInGlobalWindow("Y"));
-    assertThat(
-        mainOutputValues,
-        contains(
-            valueInGlobalWindow("X:defaultSingletonValue"),
-            valueInGlobalWindow("X:singletonValue"),
-            valueInGlobalWindow("X:iterableValue1"),
-            valueInGlobalWindow("X:iterableValue2"),
-            valueInGlobalWindow("X:iterableValue3"),
-            valueInGlobalWindow("Y:defaultSingletonValue"),
-            valueInGlobalWindow("Y:singletonValue"),
-            valueInGlobalWindow("Y:iterableValue1"),
-            valueInGlobalWindow("Y:iterableValue2"),
-            valueInGlobalWindow("Y:iterableValue3")));
-    assertThat(
-        additionalOutputValues,
-        contains(valueInGlobalWindow("X:additional"), valueInGlobalWindow("Y:additional")));
-    mainOutputValues.clear();
-
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-
-    // Assert that state data did not change
-    assertEquals(stateData, fakeClient.getData());
-  }
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
 
-  private static class TestNonWindowObservingDoFn extends DoFn<String, String> {
-    private final TupleTag<String> additionalOutput;
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
 
-    private TestNonWindowObservingDoFn(TupleTag<String> additionalOutput) {
-      this.additionalOutput = additionalOutput;
+      assertEquals(
+          ImmutableMap.<StateKey, ByteString>builder()
+              .put(bagUserStateKey("value", "X"), encode("X2"))
+              .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2"))
+              .put(bagUserStateKey("combine", "X"), encode("X0X1X2"))
+              .put(bagUserStateKey("value", "Y"), encode("Y2"))
+              .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2"))
+              .put(bagUserStateKey("combine", "Y"), encode("Y1Y2"))
+              .build(),
+          fakeClient.getData());
     }
 
-    @ProcessElement
-    public void processElement(ProcessContext context) {
-      context.output(context.element() + ":main");
-      context.output(additionalOutput, context.element() + ":additional");
+    /** Produces a bag user {@link StateKey} for the test PTransform id in the global window. */
+    private StateKey bagUserStateKey(String userStateId, String key) throws IOException {
+      return StateKey.newBuilder()
+          .setBagUserState(
+              StateKey.BagUserState.newBuilder()
+                  .setTransformId(TEST_TRANSFORM_ID)
+                  .setUserStateId(userStateId)
+                  .setKey(encode(key))
+                  .setWindow(
+                      ByteString.copyFrom(
+                          CoderUtils.encodeToByteArray(
+                              GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE))))
+          .build();
     }
-  }
 
-  @Test
-  public void testProcessElementWithNonWindowObservingOptimization() throws Exception {
-    Pipeline p = Pipeline.create();
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
-    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
-    PCollection<String> valuePCollection =
-        p.apply(Create.of("unused"))
-            .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1))));
-    TupleTag<String> mainOutput = new TupleTag<String>("main") {};
-    TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
-    PCollectionTuple outputPCollection =
-        valuePCollection.apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(new TestNonWindowObservingDoFn(additionalOutput))
-                .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
-
-    SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
-    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
-    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId =
-        sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
-    String additionalPCollectionId =
-        sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
-
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID);
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    consumers.register(
-        additionalPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            null /* beamFnStateClient */,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* splitListener */,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(
-        consumers.keySet(),
-        containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(
-        valueInWindows(
-            "X",
-            new IntervalWindow(new Instant(0L), Duration.standardMinutes(1)),
-            new IntervalWindow(new Instant(10L), Duration.standardMinutes(1))));
-    mainInput.accept(
-        valueInWindows(
-            "Y",
-            new IntervalWindow(new Instant(1000L), Duration.standardMinutes(1)),
-            new IntervalWindow(new Instant(1010L), Duration.standardMinutes(1))));
-    // Ensure that each output element is in all the windows and not one per window.
-    assertThat(
-        mainOutputValues,
-        contains(
-            valueInWindows(
-                "X:main",
-                new IntervalWindow(new Instant(0L), Duration.standardMinutes(1)),
-                new IntervalWindow(new Instant(10L), Duration.standardMinutes(1))),
-            valueInWindows(
-                "Y:main",
-                new IntervalWindow(new Instant(1000L), Duration.standardMinutes(1)),
-                new IntervalWindow(new Instant(1010L), Duration.standardMinutes(1)))));
-    assertThat(
-        additionalOutputValues,
-        contains(
-            valueInWindows(
-                "X:additional",
-                new IntervalWindow(new Instant(0L), Duration.standardMinutes(1)),
-                new IntervalWindow(new Instant(10L), Duration.standardMinutes(1))),
-            valueInWindows(
-                "Y:additional",
-                new IntervalWindow(new Instant(1000L), Duration.standardMinutes(1)),
-                new IntervalWindow(new Instant(1010L), Duration.standardMinutes(1)))));
-    mainOutputValues.clear();
+    private static class TestSideInputDoFn extends DoFn<String, String> {
+      private final PCollectionView<String> defaultSingletonSideInput;
+      private final PCollectionView<String> singletonSideInput;
+      private final PCollectionView<Iterable<String>> iterableSideInput;
+      private final TupleTag<String> additionalOutput;
+
+      private TestSideInputDoFn(
+          PCollectionView<String> defaultSingletonSideInput,
+          PCollectionView<String> singletonSideInput,
+          PCollectionView<Iterable<String>> iterableSideInput,
+          TupleTag<String> additionalOutput) {
+        this.defaultSingletonSideInput = defaultSingletonSideInput;
+        this.singletonSideInput = singletonSideInput;
+        this.iterableSideInput = iterableSideInput;
+        this.additionalOutput = additionalOutput;
+      }
 
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
+      @ProcessElement
+      public void processElement(ProcessContext context) {
+        context.output(context.element() + ":" + context.sideInput(defaultSingletonSideInput));
+        context.output(context.element() + ":" + context.sideInput(singletonSideInput));
+        for (String sideInputValue : context.sideInput(iterableSideInput)) {
+          context.output(context.element() + ":" + sideInputValue);
+        }
+        context.output(additionalOutput, context.element() + ":additional");
+      }
+    }
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
+    @Test
+    public void testProcessElementWithSideInputsAndOutputs() throws Exception {
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> defaultSingletonSideInputView =
+          valuePCollection.apply(
+              View.<String>asSingleton().withDefaultValue("defaultSingletonValue"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      PCollectionView<Iterable<String>> iterableSideInputView =
+          valuePCollection.apply(View.asIterable());
+      TupleTag<String> mainOutput = new TupleTag<String>("main") {};
+      TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
+      PCollectionTuple outputPCollection =
+          valuePCollection.apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(
+                      new TestSideInputDoFn(
+                          defaultSingletonSideInputView,
+                          singletonSideInputView,
+                          iterableSideInputView,
+                          additionalOutput))
+                  .withSideInputs(
+                      defaultSingletonSideInputView, singletonSideInputView, iterableSideInputView)
+                  .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
+
+      SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+      RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
+      String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+      String outputPCollectionId =
+          sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
+      String additionalPCollectionId =
+          sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
+
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID);
+
+      ImmutableMap<StateKey, ByteString> stateData =
+          ImmutableMap.of(
+              iterableSideInputKey(singletonSideInputView.getTagInternal().getId()),
+              encode("singletonValue"),
+              iterableSideInputKey(iterableSideInputView.getTagInternal().getId()),
+              encode("iterableValue1", "iterableValue2", "iterableValue3"));
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+      consumers.register(
+          additionalPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
+
+      assertThat(
+          consumers.keySet(),
+          containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
+
+      // Ensure that bag user state that is initially empty or populated works.
+      // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(valueInGlobalWindow("X"));
+      mainInput.accept(valueInGlobalWindow("Y"));
+      assertThat(
+          mainOutputValues,
+          contains(
+              valueInGlobalWindow("X:defaultSingletonValue"),
+              valueInGlobalWindow("X:singletonValue"),
+              valueInGlobalWindow("X:iterableValue1"),
+              valueInGlobalWindow("X:iterableValue2"),
+              valueInGlobalWindow("X:iterableValue3"),
+              valueInGlobalWindow("Y:defaultSingletonValue"),
+              valueInGlobalWindow("Y:singletonValue"),
+              valueInGlobalWindow("Y:iterableValue1"),
+              valueInGlobalWindow("Y:iterableValue2"),
+              valueInGlobalWindow("Y:iterableValue3")));
+      assertThat(
+          additionalOutputValues,
+          contains(valueInGlobalWindow("X:additional"), valueInGlobalWindow("Y:additional")));
+      mainOutputValues.clear();
 
-  private static class TestSideInputIsAccessibleForDownstreamCallersDoFn
-      extends DoFn<String, Iterable<String>> {
-    public static final String USER_COUNTER_NAME = "userCountedElems";
-    private final Counter countedElements =
-        Metrics.counter(TestSideInputIsAccessibleForDownstreamCallersDoFn.class, USER_COUNTER_NAME);
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
 
-    private final PCollectionView<Iterable<String>> iterableSideInput;
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
 
-    private TestSideInputIsAccessibleForDownstreamCallersDoFn(
-        PCollectionView<Iterable<String>> iterableSideInput) {
-      this.iterableSideInput = iterableSideInput;
+      // Assert that state data did not change
+      assertEquals(stateData, fakeClient.getData());
     }
 
-    @ProcessElement
-    public void processElement(ProcessContext context) {
-      countedElements.inc();
-      context.output(context.sideInput(iterableSideInput));
+    private static class TestNonWindowObservingDoFn extends DoFn<String, String> {
+      private final TupleTag<String> additionalOutput;
+
+      private TestNonWindowObservingDoFn(TupleTag<String> additionalOutput) {
+        this.additionalOutput = additionalOutput;
+      }
+
+      @ProcessElement
+      public void processElement(ProcessContext context) {
+        context.output(context.element() + ":main");
+        context.output(additionalOutput, context.element() + ":additional");
+      }
     }
-  }
 
-  @Test
-  public void testSideInputIsAccessibleForDownstreamCallers() throws Exception {
-    FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
-    IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
-    IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
-    ByteString encodedWindowA =
-        ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowA));
-    ByteString encodedWindowB =
-        ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowB));
-
-    Pipeline p = Pipeline.create();
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
-    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
-    PCollection<String> valuePCollection =
-        p.apply(Create.of("unused")).apply(Window.into(windowFn));
-    PCollectionView<Iterable<String>> iterableSideInputView =
-        valuePCollection.apply(View.asIterable());
-    PCollection<Iterable<String>> outputPCollection =
-        valuePCollection.apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(new TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
-                .withSideInputs(iterableSideInputView));
-
-    SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
-    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
-    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
-
-    RunnerApi.PTransform pTransform =
-        pProto
-            .getComponents()
-            .getTransformsOrThrow(
-                pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID).getSubtransforms(0));
-
-    ImmutableMap<StateKey, ByteString> stateData =
-        ImmutableMap.of(
-            iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowA),
-            encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
-            iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowB),
-            encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
-
-    List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) mainOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* splitListener */,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    // Ensure that bag user state that is initially empty or populated works.
-    // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInWindows("X", windowA));
-    mainInput.accept(valueInWindows("Y", windowB));
-    assertThat(mainOutputValues, hasSize(2));
-    assertThat(
-        mainOutputValues.get(0).getValue(),
-        contains("iterableValue1A", "iterableValue2A", "iterableValue3A"));
-    assertThat(
-        mainOutputValues.get(1).getValue(),
-        contains("iterableValue1B", "iterableValue2B", "iterableValue3B"));
-    mainOutputValues.clear();
-
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-
-    // Assert that state data did not change
-    assertEquals(stateData, fakeClient.getData());
-  }
+    @Test
+    public void testProcessElementWithNonWindowObservingOptimization() throws Exception {
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+      PCollection<String> valuePCollection =
+          p.apply(Create.of("unused"))
+              .apply(Window.into(FixedWindows.of(Duration.standardMinutes(1))));
+      TupleTag<String> mainOutput = new TupleTag<String>("main") {};
+      TupleTag<String> additionalOutput = new TupleTag<String>("additional") {};
+      PCollectionTuple outputPCollection =
+          valuePCollection.apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(new TestNonWindowObservingDoFn(additionalOutput))
+                  .withOutputTags(mainOutput, TupleTagList.of(additionalOutput)));
+
+      SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+      RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
+      String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+      String outputPCollectionId =
+          sdkComponents.registerPCollection(outputPCollection.get(mainOutput));
+      String additionalPCollectionId =
+          sdkComponents.registerPCollection(outputPCollection.get(additionalOutput));
+
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID);
+
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      List<WindowedValue<String>> additionalOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+      consumers.register(
+          additionalPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) additionalOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              null /* beamFnStateClient */,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
 
-  /** @return a test MetricUpdate for expected metrics to compare against */
-  public MetricUpdate create(String stepName, MetricName name, long value) {
-    return MetricUpdate.create(MetricKey.create(stepName, name), value);
-  }
+      assertThat(
+          consumers.keySet(),
+          containsInAnyOrder(inputPCollectionId, outputPCollectionId, additionalPCollectionId));
 
-  @Test
-  public void testUsingMetrics() throws Exception {
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    MetricsContainerImpl metricsContainer = metricsContainerRegistry.getUnboundContainer();
-    Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer);
-    FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
-    IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
-    IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
-    ByteString encodedWindowA =
-        ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowA));
-    ByteString encodedWindowB =
-        ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowB));
-
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection =
-        p.apply(Create.of("unused")).apply(Window.into(windowFn));
-    PCollectionView<Iterable<String>> iterableSideInputView =
-        valuePCollection.apply(View.asIterable());
-    PCollection<Iterable<String>> outputPCollection =
-        valuePCollection.apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(new TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
-                .withSideInputs(iterableSideInputView));
-
-    SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
-    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
-    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
-
-    RunnerApi.PTransform pTransform =
-        pProto
-            .getComponents()
-            .getTransformsOrThrow(
-                pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID).getSubtransforms(0));
-
-    ImmutableMap<StateKey, ByteString> stateData =
-        ImmutableMap.of(
-            iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowA),
-            encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
-            iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowB),
-            encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
-
-    List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
-
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) mainOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* splitListener */,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    // Ensure that bag user state that is initially empty or populated works.
-    // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInWindows("X", windowA));
-    mainInput.accept(valueInWindows("Y", windowB));
-    mainOutputValues.clear();
-
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-
-    MetricsContainer mc = MetricsEnvironment.getCurrentContainer();
-
-    List<MonitoringInfo> expected = new ArrayList<MonitoringInfo>();
-    SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
-    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
-    builder.setLabel(MonitoringInfoConstants.Labels.PCOLLECTION, "Window.Into()/Window.Assign.out");
-    builder.setInt64SumValue(2);
-    expected.add(builder.build());
-
-    builder = new SimpleMonitoringInfoBuilder();
-    builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
-    builder.setLabel(
-        MonitoringInfoConstants.Labels.PCOLLECTION,
-        "pTransformId/ParMultiDo(TestSideInputIsAccessibleForDownstreamCallers).output");
-    builder.setInt64SumValue(2);
-    expected.add(builder.build());
-
-    builder = new SimpleMonitoringInfoBuilder();
-    builder
-        .setUrn(MonitoringInfoConstants.Urns.USER_SUM_INT64)
-        .setLabel(
-            MonitoringInfoConstants.Labels.NAMESPACE,
-            TestSideInputIsAccessibleForDownstreamCallersDoFn.class.getName())
-        .setLabel(
-            MonitoringInfoConstants.Labels.NAME,
-            TestSideInputIsAccessibleForDownstreamCallersDoFn.USER_COUNTER_NAME);
-    builder.setLabel(MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
-    builder.setInt64SumValue(2);
-    expected.add(builder.build());
-
-    closeable.close();
-    List<MonitoringInfo> result = new ArrayList<MonitoringInfo>();
-    for (MonitoringInfo mi : metricsContainerRegistry.getMonitoringInfos()) {
-      result.add(mi);
-    }
-    assertThat(result, containsInAnyOrder(expected.toArray()));
-  }
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(
+          valueInWindows(
+              "X",
+              new IntervalWindow(new Instant(0L), Duration.standardMinutes(1)),
+              new IntervalWindow(new Instant(10L), Duration.standardMinutes(1))));
+      mainInput.accept(
+          valueInWindows(
+              "Y",
+              new IntervalWindow(new Instant(1000L), Duration.standardMinutes(1)),
+              new IntervalWindow(new Instant(1010L), Duration.standardMinutes(1))));
+      // Ensure that each output element is in all the windows and not one per window.
+      assertThat(
+          mainOutputValues,
+          contains(
+              valueInWindows(
+                  "X:main",
+                  new IntervalWindow(new Instant(0L), Duration.standardMinutes(1)),
+                  new IntervalWindow(new Instant(10L), Duration.standardMinutes(1))),
+              valueInWindows(
+                  "Y:main",
+                  new IntervalWindow(new Instant(1000L), Duration.standardMinutes(1)),
+                  new IntervalWindow(new Instant(1010L), Duration.standardMinutes(1)))));
+      assertThat(
+          additionalOutputValues,
+          contains(
+              valueInWindows(
+                  "X:additional",
+                  new IntervalWindow(new Instant(0L), Duration.standardMinutes(1)),
+                  new IntervalWindow(new Instant(10L), Duration.standardMinutes(1))),
+              valueInWindows(
+                  "Y:additional",
+                  new IntervalWindow(new Instant(1000L), Duration.standardMinutes(1)),
+                  new IntervalWindow(new Instant(1010L), Duration.standardMinutes(1)))));
+      mainOutputValues.clear();
 
-  @Test
-  public void testTimers() throws Exception {
-    dateTimeProvider.setDateTimeFixed(10000L);
-
-    Pipeline p = Pipeline.create();
-    PCollection<KV<String, String>> valuePCollection =
-        p.apply(Create.of(KV.of("unused", "unused")));
-    PCollection<String> outputPCollection =
-        valuePCollection.apply(TEST_TRANSFORM_ID, ParDo.of(new TestTimerfulDoFn()));
-
-    SdkComponents sdkComponents = SdkComponents.create();
-    sdkComponents.registerEnvironment(Environment.getDefaultInstance());
-    RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
-    String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
-    String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
-
-    RunnerApi.PTransform pTransform =
-        pProto
-            .getComponents()
-            .getTransformsOrThrow(
-                pProto.getComponents().getTransformsOrThrow(TEST_TRANSFORM_ID).getSubtransforms(0))
-            .toBuilder()
-            .build();
-
-    FakeBeamFnStateClient fakeStateClient =
-        new FakeBeamFnStateClient(
-            ImmutableMap.of(
-                bagUserStateKey("bag", "X"), encode("X0"),
-                bagUserStateKey("bag", "A"), encode("A0"),
-                bagUserStateKey("bag", "C"), encode("C0")));
-    FakeBeamFnTimerClient fakeTimerClient = new FakeBeamFnTimerClient();
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeStateClient,
-            fakeTimerClient,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* splitListener */,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    LogicalEndpoint eventTimer = LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "ts-event");
-    LogicalEndpoint processingTimer =
-        LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "ts-processing");
-    LogicalEndpoint eventFamilyTimer =
-        LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "tfs-event-family");
-    LogicalEndpoint processingFamilyTimer =
-        LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "tfs-processing-family");
-    // Ensure that bag user state that is initially empty or populated works.
-    // Ensure that the key order does not matter when we traverse over KV pairs.
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(timestampedValueInGlobalWindow(KV.of("X", "X1"), new Instant(1000L)));
-    mainInput.accept(timestampedValueInGlobalWindow(KV.of("Y", "Y1"), new Instant(1100L)));
-    mainInput.accept(timestampedValueInGlobalWindow(KV.of("X", "X2"), new Instant(1200L)));
-    mainInput.accept(timestampedValueInGlobalWindow(KV.of("Y", "Y2"), new Instant(1300L)));
-    fakeTimerClient.sendTimer(
-        eventTimer, timerInGlobalWindow("A", new Instant(1400L), new Instant(2400L)));
-    fakeTimerClient.sendTimer(
-        eventTimer, timerInGlobalWindow("B", new Instant(1500L), new Instant(2500L)));
-    fakeTimerClient.sendTimer(
-        eventTimer, timerInGlobalWindow("A", new Instant(1600L), new Instant(2600L)));
-    fakeTimerClient.sendTimer(
-        processingTimer, timerInGlobalWindow("X", new Instant(1700L), new Instant(2700L)));
-    fakeTimerClient.sendTimer(
-        processingTimer, timerInGlobalWindow("C", new Instant(1800L), new Instant(2800L)));
-    fakeTimerClient.sendTimer(
-        processingTimer, timerInGlobalWindow("B", new Instant(1900L), new Instant(2900L)));
-    fakeTimerClient.sendTimer(
-        eventFamilyTimer,
-        dynamicTimerInGlobalWindow("B", "event-timer2", new Instant(2000L), new Instant(3000L)));
-    fakeTimerClient.sendTimer(
-        processingFamilyTimer,
-        dynamicTimerInGlobalWindow(
-            "Y", "processing-timer2", new Instant(2100L), new Instant(3100L)));
-    assertThat(
-        mainOutputValues,
-        contains(
-            timestampedValueInGlobalWindow("mainX[X0]", new Instant(1000L)),
-            timestampedValueInGlobalWindow("mainY[]", new Instant(1100L)),
-            timestampedValueInGlobalWindow("mainX[X0, X1]", new Instant(1200L)),
-            timestampedValueInGlobalWindow("mainY[Y1]", new Instant(1300L)),
-            timestampedValueInGlobalWindow("event[A0]", new Instant(1400L)),
-            timestampedValueInGlobalWindow("event[]", new Instant(1500L)),
-            timestampedValueInGlobalWindow("event[A0, event]", new Instant(1600L)),
-            timestampedValueInGlobalWindow("processing[X0, X1, X2]", new Instant(1700L)),
-            timestampedValueInGlobalWindow("processing[C0]", new Instant(1800L)),
-            timestampedValueInGlobalWindow("processing[event]", new Instant(1900L)),
-            timestampedValueInGlobalWindow("event-family[event, processing]", new Instant(2000L)),
-            timestampedValueInGlobalWindow("processing-family[Y1, Y2]", new Instant(2100L))));
-    assertThat(
-        fakeTimerClient.getTimers(eventTimer),
-        contains(
-            timerInGlobalWindow("X", new Instant(1000L), new Instant(1001L)),
-            timerInGlobalWindow("Y", new Instant(1100L), new Instant(1101L)),
-            timerInGlobalWindow("X", new Instant(1200L), new Instant(1201L)),
-            timerInGlobalWindow("Y", new Instant(1300L), new Instant(1301L)),
-            timerInGlobalWindow("A", new Instant(1400L), new Instant(2411L)),
-            timerInGlobalWindow("B", new Instant(1500L), new Instant(2511L)),
-            timerInGlobalWindow("A", new Instant(1600L), new Instant(2611L)),
-            timerInGlobalWindow("X", new Instant(1700L), new Instant(1721L)),
-            timerInGlobalWindow("C", new Instant(1800L), new Instant(1821L)),
-            timerInGlobalWindow("B", new Instant(1900L), new Instant(1921L)),
-            timerInGlobalWindow("B", new Instant(2000L), new Instant(2031L)),
-            timerInGlobalWindow("Y", new Instant(2100L), new Instant(2141L))));
-    assertThat(
-        fakeTimerClient.getTimers(processingTimer),
-        contains(
-            timerInGlobalWindow("X", new Instant(1000L), new Instant(10002L)),
-            timerInGlobalWindow("Y", new Instant(1100L), new Instant(10002L)),
-            timerInGlobalWindow("X", new Instant(1200L), new Instant(10002L)),
-            timerInGlobalWindow("Y", new Instant(1300L), new Instant(10002L)),
-            timerInGlobalWindow("A", new Instant(1400L), new Instant(10012L)),
-            timerInGlobalWindow("B", new Instant(1500L), new Instant(10012L)),
-            timerInGlobalWindow("A", new Instant(1600L), new Instant(10012L)),
-            timerInGlobalWindow("X", new Instant(1700L), new Instant(10022L)),
-            timerInGlobalWindow("C", new Instant(1800L), new Instant(10022L)),
-            timerInGlobalWindow("B", new Instant(1900L), new Instant(10022L)),
-            timerInGlobalWindow("B", new Instant(2000L), new Instant(10032L)),
-            timerInGlobalWindow("Y", new Instant(2100L), new Instant(10042L))));
-    assertThat(
-        fakeTimerClient.getTimers(eventFamilyTimer),
-        contains(
-            dynamicTimerInGlobalWindow("X", "event-timer1", new Instant(1000L), new Instant(1003L)),
-            dynamicTimerInGlobalWindow("Y", "event-timer1", new Instant(1100L), new Instant(1103L)),
-            dynamicTimerInGlobalWindow("X", "event-timer1", new Instant(1200L), new Instant(1203L)),
-            dynamicTimerInGlobalWindow("Y", "event-timer1", new Instant(1300L), new Instant(1303L)),
-            dynamicTimerInGlobalWindow("A", "event-timer1", new Instant(1400L), new Instant(2413L)),
-            dynamicTimerInGlobalWindow("B", "event-timer1", new Instant(1500L), new Instant(2513L)),
-            dynamicTimerInGlobalWindow("A", "event-timer1", new Instant(1600L), new Instant(2613L)),
-            dynamicTimerInGlobalWindow("X", "event-timer1", new Instant(1700L), new Instant(1723L)),
-            dynamicTimerInGlobalWindow("C", "event-timer1", new Instant(1800L), new Instant(1823L)),
-            dynamicTimerInGlobalWindow("B", "event-timer1", new Instant(1900L), new Instant(1923L)),
-            dynamicTimerInGlobalWindow("B", "event-timer1", new Instant(2000L), new Instant(2033L)),
-            dynamicTimerInGlobalWindow(
-                "Y", "event-timer1", new Instant(2100L), new Instant(2143L))));
-    assertThat(
-        fakeTimerClient.getTimers(processingFamilyTimer),
-        contains(
-            dynamicTimerInGlobalWindow(
-                "X", "processing-timer1", new Instant(1000L), new Instant(10004L)),
-            dynamicTimerInGlobalWindow(
-                "Y", "processing-timer1", new Instant(1100L), new Instant(10004L)),
-            dynamicTimerInGlobalWindow(
-                "X", "processing-timer1", new Instant(1200L), new Instant(10004L)),
-            dynamicTimerInGlobalWindow(
-                "Y", "processing-timer1", new Instant(1300L), new Instant(10004L)),
-            dynamicTimerInGlobalWindow(
-                "A", "processing-timer1", new Instant(1400L), new Instant(10014L)),
-            dynamicTimerInGlobalWindow(
-                "B", "processing-timer1", new Instant(1500L), new Instant(10014L)),
-            dynamicTimerInGlobalWindow(
-                "A", "processing-timer1", new Instant(1600L), new Instant(10014L)),
-            dynamicTimerInGlobalWindow(
-                "X", "processing-timer1", new Instant(1700L), new Instant(10024L)),
-            dynamicTimerInGlobalWindow(
-                "C", "processing-timer1", new Instant(1800L), new Instant(10024L)),
-            dynamicTimerInGlobalWindow(
-                "B", "processing-timer1", new Instant(1900L), new Instant(10024L)),
-            dynamicTimerInGlobalWindow(
-                "B", "processing-timer1", new Instant(2000L), new Instant(10034L)),
-            dynamicTimerInGlobalWindow(
-                "Y", "processing-timer1", new Instant(2100L), new Instant(10044L))));
-    mainOutputValues.clear();
-
-    assertFalse(fakeTimerClient.isOutboundClosed(eventTimer));
-    assertFalse(fakeTimerClient.isOutboundClosed(processingTimer));
-    assertFalse(fakeTimerClient.isOutboundClosed(eventFamilyTimer));
-    assertFalse(fakeTimerClient.isOutboundClosed(processingFamilyTimer));
-    fakeTimerClient.closeInbound(eventTimer);
-    fakeTimerClient.closeInbound(processingTimer);
-    fakeTimerClient.closeInbound(eventFamilyTimer);
-    fakeTimerClient.closeInbound(processingFamilyTimer);
-
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
-
-    assertTrue(fakeTimerClient.isOutboundClosed(eventTimer));
-    assertTrue(fakeTimerClient.isOutboundClosed(processingTimer));
-    assertTrue(fakeTimerClient.isOutboundClosed(eventFamilyTimer));
-    assertTrue(fakeTimerClient.isOutboundClosed(processingFamilyTimer));
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-
-    assertEquals(
-        ImmutableMap.<StateKey, ByteString>builder()
-            .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2", "processing"))
-            .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2", "processing-family"))
-            .put(bagUserStateKey("bag", "A"), encode("A0", "event", "event"))
-            .put(bagUserStateKey("bag", "B"), encode("event", "processing", "event-family"))
-            .put(bagUserStateKey("bag", "C"), encode("C0", "processing"))
-            .build(),
-        fakeStateClient.getData());
-  }
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
 
-  private <K> org.apache.beam.runners.core.construction.Timer<K> timerInGlobalWindow(
-      K userKey, Instant holdTimestamp, Instant fireTimestamp) {
-    return dynamicTimerInGlobalWindow(userKey, "", holdTimestamp, fireTimestamp);
-  }
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
 
-  private <K> org.apache.beam.runners.core.construction.Timer<K> dynamicTimerInGlobalWindow(
-      K userKey, String dynamicTimerTag, Instant holdTimestamp, Instant fireTimestamp) {
-    return org.apache.beam.runners.core.construction.Timer.of(
-        userKey,
-        dynamicTimerTag,
-        Collections.singletonList(GlobalWindow.INSTANCE),
-        fireTimestamp,
-        holdTimestamp,
-        PaneInfo.NO_FIRING);
-  }
+    private static class TestSideInputIsAccessibleForDownstreamCallersDoFn
+        extends DoFn<String, Iterable<String>> {
+      public static final String USER_COUNTER_NAME = "userCountedElems";
+      private final Counter countedElements =
+          Metrics.counter(
+              TestSideInputIsAccessibleForDownstreamCallersDoFn.class, USER_COUNTER_NAME);
 
-  private <T> WindowedValue<T> valueInWindows(
-      T value, BoundedWindow window, BoundedWindow... windows) {
-    return WindowedValue.of(
-        value,
-        window.maxTimestamp(),
-        ImmutableList.<BoundedWindow>builder().add(window).add(windows).build(),
-        PaneInfo.NO_FIRING);
-  }
+      private final PCollectionView<Iterable<String>> iterableSideInput;
+
+      private TestSideInputIsAccessibleForDownstreamCallersDoFn(
+          PCollectionView<Iterable<String>> iterableSideInput) {
+        this.iterableSideInput = iterableSideInput;
+      }
 
-  private static class TestTimerfulDoFn extends DoFn<KV<String, String>, String> {
-    @StateId("bag")
-    private final StateSpec<BagState<String>> bagStateSpec = StateSpecs.bag(StringUtf8Coder.of());
-
-    @TimerId("event")
-    private final TimerSpec eventTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
-
-    @TimerId("processing")
-    private final TimerSpec processingTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
-
-    @TimerFamily("event-family")
-    private final TimerSpec eventTimerFamilySpec = TimerSpecs.timerMap(TimeDomain.EVENT_TIME);
-
-    @TimerFamily("processing-family")
-    private final TimerSpec processingTimerFamilySpec =
-        TimerSpecs.timerMap(TimeDomain.PROCESSING_TIME);
-
-    @ProcessElement
-    public void processElement(
-        ProcessContext context,
-        @StateId("bag") BagState<String> bagState,
-        @TimerId("event") Timer eventTimeTimer,
-        @TimerId("processing") Timer processingTimeTimer,
-        @TimerFamily("event-family") TimerMap eventTimerFamily,
-        @TimerFamily("processing-family") TimerMap processingTimerFamily) {
-      context.output("main" + context.element().getKey() + Iterables.toString(bagState.read()));
-      bagState.add(context.element().getValue());
-      eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(1L));
-      processingTimeTimer.offset(Duration.millis(2L));
-      processingTimeTimer.setRelative();
-      eventTimerFamily
-          .get("event-timer1")
-          .withOutputTimestamp(context.timestamp())
-          .set(context.timestamp().plus(3L));
-      processingTimerFamily.get("processing-timer1").offset(Duration.millis(4L)).setRelative();
-    }
-
-    @OnTimer("event")
-    public void eventTimer(
-        OnTimerContext context,
-        @StateId("bag") BagState<String> bagState,
-        @TimerId("event") Timer eventTimeTimer,
-        @TimerId("processing") Timer processingTimeTimer,
-        @TimerFamily("event-family") TimerMap eventTimerFamily,
-        @TimerFamily("processing-family") TimerMap processingTimerFamily) {
-      context.output("event" + Iterables.toString(bagState.read()));
-      bagState.add("event");
-      eventTimeTimer
-          .withOutputTimestamp(context.timestamp())
-          .set(context.fireTimestamp().plus(11L));
-      processingTimeTimer.offset(Duration.millis(12L));
-      processingTimeTimer.setRelative();
-      eventTimerFamily
-          .get("event-timer1")
-          .withOutputTimestamp(context.timestamp())
-          .set(context.fireTimestamp().plus(13L));
-      processingTimerFamily.get("processing-timer1").offset(Duration.millis(14L)).setRelative();
-    }
-
-    @OnTimer("processing")
-    public void processingTimer(
-        OnTimerContext context,
-        @StateId("bag") BagState<String> bagState,
-        @TimerId("event") Timer eventTimeTimer,
-        @TimerId("processing") Timer processingTimeTimer,
-        @TimerFamily("event-family") TimerMap eventTimerFamily,
-        @TimerFamily("processing-family") TimerMap processingTimerFamily) {
-      context.output("processing" + Iterables.toString(bagState.read()));
-      bagState.add("processing");
-      eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(21L));
-      processingTimeTimer.offset(Duration.millis(22L));
-      processingTimeTimer.setRelative();
-      eventTimerFamily
-          .get("event-timer1")
-          .withOutputTimestamp(context.timestamp())
-          .set(context.timestamp().plus(23L));
-      processingTimerFamily.get("processing-timer1").offset(Duration.millis(24L)).setRelative();
-    }
-
-    @OnTimerFamily("event-family")
-    public void eventFamilyOnTimer(
-        OnTimerContext context,
-        @StateId("bag") BagState<String> bagState,
-        @TimerId("event") Timer eventTimeTimer,
-        @TimerId("processing") Timer processingTimeTimer,
-        @TimerFamily("event-family") TimerMap eventTimerFamily,
-        @TimerFamily("processing-family") TimerMap processingTimerFamily) {
-      context.output("event-family" + Iterables.toString(bagState.read()));
-      bagState.add("event-family");
-      eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(31L));
-      processingTimeTimer.offset(Duration.millis(32L));
-      processingTimeTimer.setRelative();
-      eventTimerFamily
-          .get("event-timer1")
-          .withOutputTimestamp(context.timestamp())
-          .set(context.timestamp().plus(33L));
-      processingTimerFamily.get("processing-timer1").offset(Duration.millis(34L)).setRelative();
-    }
-
-    @OnTimerFamily("processing-family")
-    public void processingFamilyOnTimer(
-        OnTimerContext context,
-        @StateId("bag") BagState<String> bagState,
-        @TimerId("event") Timer eventTimeTimer,
-        @TimerId("processing") Timer processingTimeTimer,
-        @TimerFamily("event-family") TimerMap eventTimerFamily,
-        @TimerFamily("processing-family") TimerMap processingTimerFamily) {
-      context.output("processing-family" + Iterables.toString(bagState.read()));
-      bagState.add("processing-family");
-      eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(41L));
-      processingTimeTimer.offset(Duration.millis(42L));
-      processingTimeTimer.setRelative();
-      eventTimerFamily
-          .get("event-timer1")
-          .withOutputTimestamp(context.timestamp())
-          .set(context.timestamp().plus(43L));
-      processingTimerFamily.get("processing-timer1").offset(Duration.millis(44L)).setRelative();
+      @ProcessElement
+      public void processElement(ProcessContext context) {
+        countedElements.inc();
+        context.output(context.sideInput(iterableSideInput));
+      }
     }
-  }
 
-  /**
-   * Produces an iterable side input {@link StateKey} for the test PTransform id in the global
-   * window.
-   */
-  private StateKey iterableSideInputKey(String sideInputId) throws IOException {
-    return iterableSideInputKey(
-        sideInputId,
-        ByteString.copyFrom(
-            CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE)));
-  }
+    @Test
+    public void testSideInputIsAccessibleForDownstreamCallers() throws Exception {
+      FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
+      IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
+      IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
+      ByteString encodedWindowA =
+          ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowA));
+      ByteString encodedWindowB =
+          ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowB));
+
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+      PCollection<String> valuePCollection =
+          p.apply(Create.of("unused")).apply(Window.into(windowFn));
+      PCollectionView<Iterable<String>> iterableSideInputView =
+          valuePCollection.apply(View.asIterable());
+      PCollection<Iterable<String>> outputPCollection =
+          valuePCollection.apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(new TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
+                  .withSideInputs(iterableSideInputView));
+
+      SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+      RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
+      String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+      String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+
+      RunnerApi.PTransform pTransform =
+          pProto
+              .getComponents()
+              .getTransformsOrThrow(
+                  pProto
+                      .getComponents()
+                      .getTransformsOrThrow(TEST_TRANSFORM_ID)
+                      .getSubtransforms(0));
+
+      ImmutableMap<StateKey, ByteString> stateData =
+          ImmutableMap.of(
+              iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowA),
+              encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
+              iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowB),
+              encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+      List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) mainOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
 
-  /**
-   * Produces an iterable side input {@link StateKey} for the test PTransform id in the supplied
-   * window.
-   */
-  private StateKey iterableSideInputKey(String sideInputId, ByteString windowKey) {
-    return StateKey.newBuilder()
-        .setIterableSideInput(
-            StateKey.IterableSideInput.newBuilder()
-                .setTransformId(TEST_TRANSFORM_ID)
-                .setSideInputId(sideInputId)
-                .setWindow(windowKey))
-        .build();
-  }
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      // Ensure that bag user state that is initially empty or populated works.
+      // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(valueInWindows("X", windowA));
+      mainInput.accept(valueInWindows("Y", windowB));
+      assertThat(mainOutputValues, hasSize(2));
+      assertThat(
+          mainOutputValues.get(0).getValue(),
+          contains("iterableValue1A", "iterableValue2A", "iterableValue3A"));
+      assertThat(
+          mainOutputValues.get(1).getValue(),
+          contains("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+      mainOutputValues.clear();
+
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
 
-  private ByteString encode(String... values) throws IOException {
-    ByteString.Output out = ByteString.newOutput();
-    for (String value : values) {
-      StringUtf8Coder.of().encode(value, out);
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+
+      // Assert that state data did not change
+      assertEquals(stateData, fakeClient.getData());
     }
-    return out.toByteString();
-  }
 
-  @Test
-  public void testRegistration() {
-    for (PTransformRunnerFactory.Registrar registrar :
-        ServiceLoader.load(PTransformRunnerFactory.Registrar.class)) {
-      if (registrar instanceof FnApiDoFnRunner.Registrar) {
-        assertThat(
-            registrar.getPTransformRunnerFactories(),
-            IsMapContaining.hasKey(PTransformTranslation.PAR_DO_TRANSFORM_URN));
-        return;
-      }
+    /** @return a test MetricUpdate for expected metrics to compare against */
+    public MetricUpdate create(String stepName, MetricName name, long value) {
+      return MetricUpdate.create(MetricKey.create(stepName, name), value);
     }
-    fail("Expected registrar not found.");
-  }
 
-  /**
-   * The trySplit testing of this splittable DoFn is done when processing the {@link
-   * NonWindowObservingTestSplittableDoFn#SPLIT_ELEMENT}. Always checkpoints at element {@link
-   * NonWindowObservingTestSplittableDoFn#CHECKPOINT_UPPER_BOUND}.
-   *
-   * <p>The expected thread flow is:
-   *
-   * <ul>
-   *   <li>splitting thread: {@link
-   *       NonWindowObservingTestSplittableDoFn#waitForSplitElementToBeProcessed()}
-   *   <li>process element thread: {@link
-   *       NonWindowObservingTestSplittableDoFn#enableAndWaitForTrySplitToHappen()}
-   *   <li>splitting thread: perform try split
-   *   <li>splitting thread: {@link
-   *       NonWindowObservingTestSplittableDoFn#releaseWaitingProcessElementThread()}
-   * </ul>
-   */
-  static class NonWindowObservingTestSplittableDoFn extends DoFn<String, String> {
-    private static final ConcurrentMap<String, KV<CountDownLatch, CountDownLatch>>
-        DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>();
-    private static final long SPLIT_ELEMENT = 3;
-    private static final long CHECKPOINT_UPPER_BOUND = 8;
-
-    private KV<CountDownLatch, CountDownLatch> getLatches() {
-      return DOFN_INSTANCE_TO_LOCK.computeIfAbsent(
-          this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new CountDownLatch(1)));
-    }
-
-    public void enableAndWaitForTrySplitToHappen() throws Exception {
-      KV<CountDownLatch, CountDownLatch> latches = getLatches();
-      latches.getKey().countDown();
-      if (!latches.getValue().await(30, TimeUnit.SECONDS)) {
-        fail("Failed to wait for trySplit to occur.");
+    @Test
+    public void testUsingMetrics() throws Exception {
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      MetricsContainerImpl metricsContainer = metricsContainerRegistry.getUnboundContainer();
+      Closeable closeable = MetricsEnvironment.scopedMetricsContainer(metricsContainer);
+      FixedWindows windowFn = FixedWindows.of(Duration.millis(1L));
+      IntervalWindow windowA = windowFn.assignWindow(new Instant(1L));
+      IntervalWindow windowB = windowFn.assignWindow(new Instant(2L));
+      ByteString encodedWindowA =
+          ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowA));
+      ByteString encodedWindowB =
+          ByteString.copyFrom(CoderUtils.encodeToByteArray(windowFn.windowCoder(), windowB));
+
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection =
+          p.apply(Create.of("unused")).apply(Window.into(windowFn));
+      PCollectionView<Iterable<String>> iterableSideInputView =
+          valuePCollection.apply(View.asIterable());
+      PCollection<Iterable<String>> outputPCollection =
+          valuePCollection.apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(new TestSideInputIsAccessibleForDownstreamCallersDoFn(iterableSideInputView))
+                  .withSideInputs(iterableSideInputView));
+
+      SdkComponents sdkComponents = SdkComponents.create(p.getOptions());
+      RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents, true);
+      String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+      String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+
+      RunnerApi.PTransform pTransform =
+          pProto
+              .getComponents()
+              .getTransformsOrThrow(
+                  pProto
+                      .getComponents()
+                      .getTransformsOrThrow(TEST_TRANSFORM_ID)
+                      .getSubtransforms(0));
+
+      ImmutableMap<StateKey, ByteString> stateData =
+          ImmutableMap.of(
+              iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowA),
+              encode("iterableValue1A", "iterableValue2A", "iterableValue3A"),
+              iterableSideInputKey(iterableSideInputView.getTagInternal().getId(), encodedWindowB),
+              encode("iterableValue1B", "iterableValue2B", "iterableValue3B"));
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+      List<WindowedValue<Iterable<String>>> mainOutputValues = new ArrayList<>();
+
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          Iterables.getOnlyElement(pTransform.getOutputsMap().values()),
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<Iterable<String>>>) mainOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      // Ensure that bag user state that is initially empty or populated works.
+      // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(valueInWindows("X", windowA));
+      mainInput.accept(valueInWindows("Y", windowB));
+      mainOutputValues.clear();
+
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+
+      MetricsContainer mc = MetricsEnvironment.getCurrentContainer();
+
+      List<MonitoringInfo> expected = new ArrayList<MonitoringInfo>();
+      SimpleMonitoringInfoBuilder builder = new SimpleMonitoringInfoBuilder();
+      builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+      builder.setLabel(
+          MonitoringInfoConstants.Labels.PCOLLECTION, "Window.Into()/Window.Assign.out");
+      builder.setInt64SumValue(2);
+      expected.add(builder.build());
+
+      builder = new SimpleMonitoringInfoBuilder();
+      builder.setUrn(MonitoringInfoConstants.Urns.ELEMENT_COUNT);
+      builder.setLabel(
+          MonitoringInfoConstants.Labels.PCOLLECTION,
+          "pTransformId/ParMultiDo(TestSideInputIsAccessibleForDownstreamCallers).output");
+      builder.setInt64SumValue(2);
+      expected.add(builder.build());
+
+      builder = new SimpleMonitoringInfoBuilder();
+      builder
+          .setUrn(MonitoringInfoConstants.Urns.USER_SUM_INT64)
+          .setLabel(
+              MonitoringInfoConstants.Labels.NAMESPACE,
+              TestSideInputIsAccessibleForDownstreamCallersDoFn.class.getName())
+          .setLabel(
+              MonitoringInfoConstants.Labels.NAME,
+              TestSideInputIsAccessibleForDownstreamCallersDoFn.USER_COUNTER_NAME);
+      builder.setLabel(MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+      builder.setInt64SumValue(2);
+      expected.add(builder.build());
+
+      closeable.close();
+      List<MonitoringInfo> result = new ArrayList<MonitoringInfo>();
+      for (MonitoringInfo mi : metricsContainerRegistry.getMonitoringInfos()) {
+        result.add(mi);
       }
+      assertThat(result, containsInAnyOrder(expected.toArray()));
     }
 
-    public void waitForSplitElementToBeProcessed() throws Exception {
-      KV<CountDownLatch, CountDownLatch> latches = getLatches();
-      if (!latches.getKey().await(30, TimeUnit.SECONDS)) {
-        fail("Failed to wait for split element to be processed.");
-      }
+    @Test
+    public void testTimers() throws Exception {
+      dateTimeProvider.setDateTimeFixed(10000L);
+
+      Pipeline p = Pipeline.create();
+      PCollection<KV<String, String>> valuePCollection =
+          p.apply(Create.of(KV.of("unused", "unused")));
+      PCollection<String> outputPCollection =
+          valuePCollection.apply(TEST_TRANSFORM_ID, ParDo.of(new TestTimerfulDoFn()));
+
+      SdkComponents sdkComponents = SdkComponents.create();
+      sdkComponents.registerEnvironment(Environment.getDefaultInstance());
+      RunnerApi.Pipeline pProto = PipelineTranslation.toProto(p, sdkComponents);
+      String inputPCollectionId = sdkComponents.registerPCollection(valuePCollection);
+      String outputPCollectionId = sdkComponents.registerPCollection(outputPCollection);
+
+      RunnerApi.PTransform pTransform =
+          pProto
+              .getComponents()
+              .getTransformsOrThrow(
+                  pProto
+                      .getComponents()
+                      .getTransformsOrThrow(TEST_TRANSFORM_ID)
+                      .getSubtransforms(0))
+              .toBuilder()
+              .build();
+
+      FakeBeamFnStateClient fakeStateClient =
+          new FakeBeamFnStateClient(
+              ImmutableMap.of(
+                  bagUserStateKey("bag", "X"), encode("X0"),
+                  bagUserStateKey("bag", "A"), encode("A0"),
+                  bagUserStateKey("bag", "C"), encode("C0")));
+      FakeBeamFnTimerClient fakeTimerClient = new FakeBeamFnTimerClient();
+
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeStateClient,
+              fakeTimerClient,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      LogicalEndpoint eventTimer = LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "ts-event");
+      LogicalEndpoint processingTimer =
+          LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "ts-processing");
+      LogicalEndpoint eventFamilyTimer =
+          LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "tfs-event-family");
+      LogicalEndpoint processingFamilyTimer =
+          LogicalEndpoint.timer("57L", TEST_TRANSFORM_ID, "tfs-processing-family");
+      // Ensure that bag user state that is initially empty or populated works.
+      // Ensure that the key order does not matter when we traverse over KV pairs.
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(timestampedValueInGlobalWindow(KV.of("X", "X1"), new Instant(1000L)));
+      mainInput.accept(timestampedValueInGlobalWindow(KV.of("Y", "Y1"), new Instant(1100L)));
+      mainInput.accept(timestampedValueInGlobalWindow(KV.of("X", "X2"), new Instant(1200L)));
+      mainInput.accept(timestampedValueInGlobalWindow(KV.of("Y", "Y2"), new Instant(1300L)));
+      fakeTimerClient.sendTimer(
+          eventTimer, timerInGlobalWindow("A", new Instant(1400L), new Instant(2400L)));
+      fakeTimerClient.sendTimer(
+          eventTimer, timerInGlobalWindow("B", new Instant(1500L), new Instant(2500L)));
+      fakeTimerClient.sendTimer(
+          eventTimer, timerInGlobalWindow("A", new Instant(1600L), new Instant(2600L)));
+      fakeTimerClient.sendTimer(
+          processingTimer, timerInGlobalWindow("X", new Instant(1700L), new Instant(2700L)));
+      fakeTimerClient.sendTimer(
+          processingTimer, timerInGlobalWindow("C", new Instant(1800L), new Instant(2800L)));
+      fakeTimerClient.sendTimer(
+          processingTimer, timerInGlobalWindow("B", new Instant(1900L), new Instant(2900L)));
+      fakeTimerClient.sendTimer(
+          eventFamilyTimer,
+          dynamicTimerInGlobalWindow("B", "event-timer2", new Instant(2000L), new Instant(3000L)));
+      fakeTimerClient.sendTimer(
+          processingFamilyTimer,
+          dynamicTimerInGlobalWindow(
+              "Y", "processing-timer2", new Instant(2100L), new Instant(3100L)));
+      assertThat(
+          mainOutputValues,
+          contains(
+              timestampedValueInGlobalWindow("mainX[X0]", new Instant(1000L)),
+              timestampedValueInGlobalWindow("mainY[]", new Instant(1100L)),
+              timestampedValueInGlobalWindow("mainX[X0, X1]", new Instant(1200L)),
+              timestampedValueInGlobalWindow("mainY[Y1]", new Instant(1300L)),
+              timestampedValueInGlobalWindow("event[A0]", new Instant(1400L)),
+              timestampedValueInGlobalWindow("event[]", new Instant(1500L)),
+              timestampedValueInGlobalWindow("event[A0, event]", new Instant(1600L)),
+              timestampedValueInGlobalWindow("processing[X0, X1, X2]", new Instant(1700L)),
+              timestampedValueInGlobalWindow("processing[C0]", new Instant(1800L)),
+              timestampedValueInGlobalWindow("processing[event]", new Instant(1900L)),
+              timestampedValueInGlobalWindow("event-family[event, processing]", new Instant(2000L)),
+              timestampedValueInGlobalWindow("processing-family[Y1, Y2]", new Instant(2100L))));
+      assertThat(
+          fakeTimerClient.getTimers(eventTimer),
+          contains(
+              timerInGlobalWindow("X", new Instant(1000L), new Instant(1001L)),
+              timerInGlobalWindow("Y", new Instant(1100L), new Instant(1101L)),
+              timerInGlobalWindow("X", new Instant(1200L), new Instant(1201L)),
+              timerInGlobalWindow("Y", new Instant(1300L), new Instant(1301L)),
+              timerInGlobalWindow("A", new Instant(1400L), new Instant(2411L)),
+              timerInGlobalWindow("B", new Instant(1500L), new Instant(2511L)),
+              timerInGlobalWindow("A", new Instant(1600L), new Instant(2611L)),
+              timerInGlobalWindow("X", new Instant(1700L), new Instant(1721L)),
+              timerInGlobalWindow("C", new Instant(1800L), new Instant(1821L)),
+              timerInGlobalWindow("B", new Instant(1900L), new Instant(1921L)),
+              timerInGlobalWindow("B", new Instant(2000L), new Instant(2031L)),
+              timerInGlobalWindow("Y", new Instant(2100L), new Instant(2141L))));
+      assertThat(
+          fakeTimerClient.getTimers(processingTimer),
+          contains(
+              timerInGlobalWindow("X", new Instant(1000L), new Instant(10002L)),
+              timerInGlobalWindow("Y", new Instant(1100L), new Instant(10002L)),
+              timerInGlobalWindow("X", new Instant(1200L), new Instant(10002L)),
+              timerInGlobalWindow("Y", new Instant(1300L), new Instant(10002L)),
+              timerInGlobalWindow("A", new Instant(1400L), new Instant(10012L)),
+              timerInGlobalWindow("B", new Instant(1500L), new Instant(10012L)),
+              timerInGlobalWindow("A", new Instant(1600L), new Instant(10012L)),
+              timerInGlobalWindow("X", new Instant(1700L), new Instant(10022L)),
+              timerInGlobalWindow("C", new Instant(1800L), new Instant(10022L)),
+              timerInGlobalWindow("B", new Instant(1900L), new Instant(10022L)),
+              timerInGlobalWindow("B", new Instant(2000L), new Instant(10032L)),
+              timerInGlobalWindow("Y", new Instant(2100L), new Instant(10042L))));
+      assertThat(
+          fakeTimerClient.getTimers(eventFamilyTimer),
+          contains(
+              dynamicTimerInGlobalWindow(
+                  "X", "event-timer1", new Instant(1000L), new Instant(1003L)),
+              dynamicTimerInGlobalWindow(
+                  "Y", "event-timer1", new Instant(1100L), new Instant(1103L)),
+              dynamicTimerInGlobalWindow(
+                  "X", "event-timer1", new Instant(1200L), new Instant(1203L)),
+              dynamicTimerInGlobalWindow(
+                  "Y", "event-timer1", new Instant(1300L), new Instant(1303L)),
+              dynamicTimerInGlobalWindow(
+                  "A", "event-timer1", new Instant(1400L), new Instant(2413L)),
+              dynamicTimerInGlobalWindow(
+                  "B", "event-timer1", new Instant(1500L), new Instant(2513L)),
+              dynamicTimerInGlobalWindow(
+                  "A", "event-timer1", new Instant(1600L), new Instant(2613L)),
+              dynamicTimerInGlobalWindow(
+                  "X", "event-timer1", new Instant(1700L), new Instant(1723L)),
+              dynamicTimerInGlobalWindow(
+                  "C", "event-timer1", new Instant(1800L), new Instant(1823L)),
+              dynamicTimerInGlobalWindow(
+                  "B", "event-timer1", new Instant(1900L), new Instant(1923L)),
+              dynamicTimerInGlobalWindow(
+                  "B", "event-timer1", new Instant(2000L), new Instant(2033L)),
+              dynamicTimerInGlobalWindow(
+                  "Y", "event-timer1", new Instant(2100L), new Instant(2143L))));
+      assertThat(
+          fakeTimerClient.getTimers(processingFamilyTimer),
+          contains(
+              dynamicTimerInGlobalWindow(
+                  "X", "processing-timer1", new Instant(1000L), new Instant(10004L)),
+              dynamicTimerInGlobalWindow(
+                  "Y", "processing-timer1", new Instant(1100L), new Instant(10004L)),
+              dynamicTimerInGlobalWindow(
+                  "X", "processing-timer1", new Instant(1200L), new Instant(10004L)),
+              dynamicTimerInGlobalWindow(
+                  "Y", "processing-timer1", new Instant(1300L), new Instant(10004L)),
+              dynamicTimerInGlobalWindow(
+                  "A", "processing-timer1", new Instant(1400L), new Instant(10014L)),
+              dynamicTimerInGlobalWindow(
+                  "B", "processing-timer1", new Instant(1500L), new Instant(10014L)),
+              dynamicTimerInGlobalWindow(
+                  "A", "processing-timer1", new Instant(1600L), new Instant(10014L)),
+              dynamicTimerInGlobalWindow(
+                  "X", "processing-timer1", new Instant(1700L), new Instant(10024L)),
+              dynamicTimerInGlobalWindow(
+                  "C", "processing-timer1", new Instant(1800L), new Instant(10024L)),
+              dynamicTimerInGlobalWindow(
+                  "B", "processing-timer1", new Instant(1900L), new Instant(10024L)),
+              dynamicTimerInGlobalWindow(
+                  "B", "processing-timer1", new Instant(2000L), new Instant(10034L)),
+              dynamicTimerInGlobalWindow(
+                  "Y", "processing-timer1", new Instant(2100L), new Instant(10044L))));
+      mainOutputValues.clear();
+
+      assertFalse(fakeTimerClient.isOutboundClosed(eventTimer));
+      assertFalse(fakeTimerClient.isOutboundClosed(processingTimer));
+      assertFalse(fakeTimerClient.isOutboundClosed(eventFamilyTimer));
+      assertFalse(fakeTimerClient.isOutboundClosed(processingFamilyTimer));
+      fakeTimerClient.closeInbound(eventTimer);
+      fakeTimerClient.closeInbound(processingTimer);
+      fakeTimerClient.closeInbound(eventFamilyTimer);
+      fakeTimerClient.closeInbound(processingFamilyTimer);
+
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
+
+      assertTrue(fakeTimerClient.isOutboundClosed(eventTimer));
+      assertTrue(fakeTimerClient.isOutboundClosed(processingTimer));
+      assertTrue(fakeTimerClient.isOutboundClosed(eventFamilyTimer));
+      assertTrue(fakeTimerClient.isOutboundClosed(processingFamilyTimer));
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+
+      assertEquals(
+          ImmutableMap.<StateKey, ByteString>builder()
+              .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2", "processing"))
+              .put(bagUserStateKey("bag", "Y"), encode("Y1", "Y2", "processing-family"))
+              .put(bagUserStateKey("bag", "A"), encode("A0", "event", "event"))
+              .put(bagUserStateKey("bag", "B"), encode("event", "processing", "event-family"))
+              .put(bagUserStateKey("bag", "C"), encode("C0", "processing"))
+              .build(),
+          fakeStateClient.getData());
     }
 
-    public void releaseWaitingProcessElementThread() {
-      KV<CountDownLatch, CountDownLatch> latches = getLatches();
-      latches.getValue().countDown();
+    private <K> org.apache.beam.runners.core.construction.Timer<K> timerInGlobalWindow(
+        K userKey, Instant holdTimestamp, Instant fireTimestamp) {
+      return dynamicTimerInGlobalWindow(userKey, "", holdTimestamp, fireTimestamp);
     }
 
-    private final String uuid;
+    private <K> org.apache.beam.runners.core.construction.Timer<K> dynamicTimerInGlobalWindow(
+        K userKey, String dynamicTimerTag, Instant holdTimestamp, Instant fireTimestamp) {
+      return org.apache.beam.runners.core.construction.Timer.of(
+          userKey,
+          dynamicTimerTag,
+          Collections.singletonList(GlobalWindow.INSTANCE),
+          fireTimestamp,
+          holdTimestamp,
+          PaneInfo.NO_FIRING);
+    }
 
-    private NonWindowObservingTestSplittableDoFn() {
-      this.uuid = UUID.randomUUID().toString();
+    private <T> WindowedValue<T> valueInWindows(
+        T value, BoundedWindow window, BoundedWindow... windows) {
+      return WindowedValue.of(
+          value,
+          window.maxTimestamp(),
+          ImmutableList.<BoundedWindow>builder().add(window).add(windows).build(),
+          PaneInfo.NO_FIRING);
     }
 
-    @ProcessElement
-    public ProcessContinuation processElement(
-        ProcessContext context,
-        RestrictionTracker<OffsetRange, Long> tracker,
-        ManualWatermarkEstimator<Instant> watermarkEstimator)
-        throws Exception {
-      long checkpointUpperBound = CHECKPOINT_UPPER_BOUND;
-      long position = tracker.currentRestriction().getFrom();
-      boolean claimStatus;
-      while (true) {
-        claimStatus = (tracker.tryClaim(position));
-        if (!claimStatus) {
-          break;
-        } else if (position == SPLIT_ELEMENT) {
-          enableAndWaitForTrySplitToHappen();
-        }
-        context.outputWithTimestamp(
-            context.element() + ":" + position, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
-        watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
-        position += 1L;
-        if (position == checkpointUpperBound) {
-          break;
-        }
+    private static class TestTimerfulDoFn extends DoFn<KV<String, String>, String> {
+      @StateId("bag")
+      private final StateSpec<BagState<String>> bagStateSpec = StateSpecs.bag(StringUtf8Coder.of());
+
+      @TimerId("event")
+      private final TimerSpec eventTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+      @TimerId("processing")
+      private final TimerSpec processingTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);
+
+      @TimerFamily("event-family")
+      private final TimerSpec eventTimerFamilySpec = TimerSpecs.timerMap(TimeDomain.EVENT_TIME);
+
+      @TimerFamily("processing-family")
+      private final TimerSpec processingTimerFamilySpec =
+          TimerSpecs.timerMap(TimeDomain.PROCESSING_TIME);
+
+      @ProcessElement
+      public void processElement(
+          ProcessContext context,
+          @StateId("bag") BagState<String> bagState,
+          @TimerId("event") Timer eventTimeTimer,
+          @TimerId("processing") Timer processingTimeTimer,
+          @TimerFamily("event-family") TimerMap eventTimerFamily,
+          @TimerFamily("processing-family") TimerMap processingTimerFamily) {
+        context.output("main" + context.element().getKey() + Iterables.toString(bagState.read()));
+        bagState.add(context.element().getValue());
+
+        eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(1L));
+        processingTimeTimer.offset(Duration.millis(2L));
+        processingTimeTimer.setRelative();
+        eventTimerFamily
+            .get("event-timer1")
+            .withOutputTimestamp(context.timestamp())
+            .set(context.timestamp().plus(3L));
+        processingTimerFamily.get("processing-timer1").offset(Duration.millis(4L)).setRelative();
+      }
+
+      @OnTimer("event")
+      public void eventTimer(
+          OnTimerContext context,
+          @StateId("bag") BagState<String> bagState,
+          @TimerId("event") Timer eventTimeTimer,
+          @TimerId("processing") Timer processingTimeTimer,
+          @TimerFamily("event-family") TimerMap eventTimerFamily,
+          @TimerFamily("processing-family") TimerMap processingTimerFamily) {
+        context.output("event" + Iterables.toString(bagState.read()));
+        bagState.add("event");
+        eventTimeTimer
+            .withOutputTimestamp(context.timestamp())
+            .set(context.fireTimestamp().plus(11L));
+        processingTimeTimer.offset(Duration.millis(12L));
+        processingTimeTimer.setRelative();
+        eventTimerFamily
+            .get("event-timer1")
+            .withOutputTimestamp(context.timestamp())
+            .set(context.fireTimestamp().plus(13L));
+
+        processingTimerFamily.get("processing-timer1").offset(Duration.millis(14L)).setRelative();
       }
-      if (!claimStatus) {
-        return ProcessContinuation.stop();
-      } else {
-        return ProcessContinuation.resume().withResumeDelay(Duration.millis(54321L));
+
+      @OnTimer("processing")
+      public void processingTimer(
+          OnTimerContext context,
+          @StateId("bag") BagState<String> bagState,
+          @TimerId("event") Timer eventTimeTimer,
+          @TimerId("processing") Timer processingTimeTimer,
+          @TimerFamily("event-family") TimerMap eventTimerFamily,
+          @TimerFamily("processing-family") TimerMap processingTimerFamily) {
+        context.output("processing" + Iterables.toString(bagState.read()));
+        bagState.add("processing");
+
+        eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(21L));
+        processingTimeTimer.offset(Duration.millis(22L));
+        processingTimeTimer.setRelative();
+        eventTimerFamily
+            .get("event-timer1")
+            .withOutputTimestamp(context.timestamp())
+            .set(context.timestamp().plus(23L));
+
+        processingTimerFamily.get("processing-timer1").offset(Duration.millis(24L)).setRelative();
       }
-    }
 
-    @GetInitialRestriction
-    public OffsetRange restriction(@Element String element) {
-      return new OffsetRange(0, Integer.parseInt(element));
-    }
+      @OnTimerFamily("event-family")
+      public void eventFamilyOnTimer(
+          OnTimerContext context,
+          @StateId("bag") BagState<String> bagState,
+          @TimerId("event") Timer eventTimeTimer,
+          @TimerId("processing") Timer processingTimeTimer,
+          @TimerFamily("event-family") TimerMap eventTimerFamily,
+          @TimerFamily("processing-family") TimerMap processingTimerFamily) {
+        context.output("event-family" + Iterables.toString(bagState.read()));
+        bagState.add("event-family");
+
+        eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(31L));
+        processingTimeTimer.offset(Duration.millis(32L));
+        processingTimeTimer.setRelative();
+        eventTimerFamily
+            .get("event-timer1")
+            .withOutputTimestamp(context.timestamp())
+            .set(context.timestamp().plus(33L));
+
+        processingTimerFamily.get("processing-timer1").offset(Duration.millis(34L)).setRelative();
+      }
 
-    @NewTracker
-    public RestrictionTracker<OffsetRange, Long> newTracker(@Restriction OffsetRange restriction) {
-      return new OffsetRangeTracker(restriction);
+      @OnTimerFamily("processing-family")
+      public void processingFamilyOnTimer(
+          OnTimerContext context,
+          @StateId("bag") BagState<String> bagState,
+          @TimerId("event") Timer eventTimeTimer,
+          @TimerId("processing") Timer processingTimeTimer,
+          @TimerFamily("event-family") TimerMap eventTimerFamily,
+          @TimerFamily("processing-family") TimerMap processingTimerFamily) {
+        context.output("processing-family" + Iterables.toString(bagState.read()));
+        bagState.add("processing-family");
+
+        eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(41L));
+        processingTimeTimer.offset(Duration.millis(42L));
+        processingTimeTimer.setRelative();
+        eventTimerFamily
+            .get("event-timer1")
+            .withOutputTimestamp(context.timestamp())
+            .set(context.timestamp().plus(43L));
+
+        processingTimerFamily.get("processing-timer1").offset(Duration.millis(44L)).setRelative();
+      }
     }
 
-    @SplitRestriction
-    public void splitRange(@Restriction OffsetRange range, OutputReceiver<OffsetRange> receiver) {
-      receiver.output(new OffsetRange(range.getFrom(), (range.getFrom() + range.getTo()) / 2));
-      receiver.output(new OffsetRange((range.getFrom() + range.getTo()) / 2, range.getTo()));
+    /**
+     * Produces an iterable side input {@link StateKey} for the test PTransform id in the global
+     * window.
+     */
+    private StateKey iterableSideInputKey(String sideInputId) throws IOException {
+      return iterableSideInputKey(
+          sideInputId,
+          ByteString.copyFrom(
+              CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE)));
     }
 
-    @TruncateRestriction
-    public TruncateResult<OffsetRange> truncateRestriction(@Restriction OffsetRange range) {
-      return TruncateResult.of(new OffsetRange(range.getFrom(), range.getTo() / 2));
+    /**
+     * Produces an iterable side input {@link StateKey} for the test PTransform id in the supplied
+     * window.
+     */
+    private StateKey iterableSideInputKey(String sideInputId, ByteString windowKey) {
+      return StateKey.newBuilder()
+          .setIterableSideInput(
+              StateKey.IterableSideInput.newBuilder()
+                  .setTransformId(TEST_TRANSFORM_ID)
+                  .setSideInputId(sideInputId)
+                  .setWindow(windowKey))
+          .build();
     }
 
-    @GetInitialWatermarkEstimatorState
-    public Instant getInitialWatermarkEstimatorState() {
-      return GlobalWindow.TIMESTAMP_MIN_VALUE;
+    private ByteString encode(String... values) throws IOException {
+      ByteString.Output out = ByteString.newOutput();
+      for (String value : values) {
+        StringUtf8Coder.of().encode(value, out);
+      }
+      return out.toByteString();
     }
 
-    @NewWatermarkEstimator
-    public WatermarkEstimators.Manual newWatermarkEstimator(
-        @WatermarkEstimatorState Instant watermark) {
-      return new WatermarkEstimators.Manual(watermark);
+    @Test
+    public void testRegistration() {
+      for (PTransformRunnerFactory.Registrar registrar :
+          ServiceLoader.load(PTransformRunnerFactory.Registrar.class)) {
+        if (registrar instanceof FnApiDoFnRunner.Registrar) {
+          assertThat(
+              registrar.getPTransformRunnerFactories(),
+              IsMapContaining.hasKey(PTransformTranslation.PAR_DO_TRANSFORM_URN));
+          return;
+        }
+      }
+      fail("Expected registrar not found.");
     }
-  }
 
-  /**
-   * A window observing variant of {@link NonWindowObservingTestSplittableDoFn} which uses the side
-   * inputs to choose the checkpoint upper bound.
-   */
-  static class WindowObservingTestSplittableDoFn extends NonWindowObservingTestSplittableDoFn {
+    /**
+     * The trySplit testing of this splittable DoFn is done when processing the {@link
+     * NonWindowObservingTestSplittableDoFn#SPLIT_ELEMENT}. Always checkpoints at element {@link
+     * NonWindowObservingTestSplittableDoFn#CHECKPOINT_UPPER_BOUND}.
+     *
+     * <p>The expected thread flow is:
+     *
+     * <ul>
+     *   <li>splitting thread: {@link
+     *       NonWindowObservingTestSplittableDoFn#waitForSplitElementToBeProcessed()}
+     *   <li>process element thread: {@link
+     *       NonWindowObservingTestSplittableDoFn#enableAndWaitForTrySplitToHappen()}
+     *   <li>splitting thread: perform try split
+     *   <li>splitting thread: {@link
+     *       NonWindowObservingTestSplittableDoFn#releaseWaitingProcessElementThread()}
+     * </ul>
+     */
+    static class NonWindowObservingTestSplittableDoFn extends DoFn<String, String> {
+      private static final ConcurrentMap<String, KV<CountDownLatch, CountDownLatch>>
+          DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>();
+      private static final long SPLIT_ELEMENT = 3;
+      private static final long CHECKPOINT_UPPER_BOUND = 8;
+
+      private KV<CountDownLatch, CountDownLatch> getLatches() {
+        return DOFN_INSTANCE_TO_LOCK.computeIfAbsent(
+            this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new CountDownLatch(1)));
+      }
 
-    private final PCollectionView<String> singletonSideInput;
+      public void enableAndWaitForTrySplitToHappen() throws Exception {
+        KV<CountDownLatch, CountDownLatch> latches = getLatches();
+        latches.getKey().countDown();
+        if (!latches.getValue().await(30, TimeUnit.SECONDS)) {
+          fail("Failed to wait for trySplit to occur.");
+        }
+      }
 
-    private WindowObservingTestSplittableDoFn(PCollectionView<String> singletonSideInput) {
-      this.singletonSideInput = singletonSideInput;
-    }
+      public void waitForSplitElementToBeProcessed() throws Exception {
+        KV<CountDownLatch, CountDownLatch> latches = getLatches();
+        if (!latches.getKey().await(30, TimeUnit.SECONDS)) {
+          fail("Failed to wait for split element to be processed.");
+        }
+      }
 
-    @Override
-    @ProcessElement
-    public ProcessContinuation processElement(
-        ProcessContext context,
-        RestrictionTracker<OffsetRange, Long> tracker,
-        ManualWatermarkEstimator<Instant> watermarkEstimator)
-        throws Exception {
-      long checkpointUpperBound = Long.parseLong(context.sideInput(singletonSideInput));
-      long position = tracker.currentRestriction().getFrom();
-      boolean claimStatus;
-      while (true) {
-        claimStatus = (tracker.tryClaim(position));
-        if (!claimStatus) {
-          break;
-        } else if (position == NonWindowObservingTestSplittableDoFn.SPLIT_ELEMENT) {
-          enableAndWaitForTrySplitToHappen();
+      public void releaseWaitingProcessElementThread() {
+        KV<CountDownLatch, CountDownLatch> latches = getLatches();
+        latches.getValue().countDown();
+      }
+
+      private final String uuid;
+
+      private NonWindowObservingTestSplittableDoFn() {
+        this.uuid = UUID.randomUUID().toString();
+      }
+
+      @ProcessElement
+      public ProcessContinuation processElement(
+          ProcessContext context,
+          RestrictionTracker<OffsetRange, Long> tracker,
+          ManualWatermarkEstimator<Instant> watermarkEstimator)
+          throws Exception {
+        long checkpointUpperBound = CHECKPOINT_UPPER_BOUND;
+        long position = tracker.currentRestriction().getFrom();
+        boolean claimStatus;
+        while (true) {
+          claimStatus = (tracker.tryClaim(position));
+          if (!claimStatus) {
+            break;
+          } else if (position == SPLIT_ELEMENT) {
+            enableAndWaitForTrySplitToHappen();
+          }
+          context.outputWithTimestamp(
+              context.element() + ":" + position, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+          watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+          position += 1L;
+          if (position == checkpointUpperBound) {
+            break;
+          }
         }
-        context.outputWithTimestamp(
-            context.element() + ":" + position, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
-        watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
-        position += 1L;
-        if (position == checkpointUpperBound) {
-          break;
+        if (!claimStatus) {
+          return ProcessContinuation.stop();
+        } else {
+          return ProcessContinuation.resume().withResumeDelay(Duration.millis(54321L));
         }
       }
-      if (!claimStatus) {
-        return ProcessContinuation.stop();
-      } else {
-        return ProcessContinuation.resume().withResumeDelay(Duration.millis(54321L));
+
+      @GetInitialRestriction
+      public OffsetRange restriction(@Element String element) {
+        return new OffsetRange(0, Integer.parseInt(element));
       }
-    }
-  }
 
-  @Test
-  public void testProcessElementForSizedElementAndRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
-    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    WindowObservingTestSplittableDoFn doFn =
-        new WindowObservingTestSplittableDoFn(singletonSideInputView);
-    valuePCollection.apply(
-        TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(
-                                PTransformTranslation
-                                    .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    RunnerApi.PCollection inputPCollection =
-        pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
-    RehydratedComponents rehydratedComponents =
-        RehydratedComponents.forComponents(pProto.getComponents());
-    Coder<WindowedValue> inputCoder =
-        WindowedValue.getFullCoder(
-            CoderTranslation.fromProto(
-                pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
-                rehydratedComponents,
-                TranslationContext.DEFAULT),
-            (Coder)
-                CoderTranslation.fromProto(
-                    pProto
-                        .getComponents()
-                        .getCodersOrThrow(
-                            pProto
-                                .getComponents()
-                                .getWindowingStrategiesOrThrow(
-                                    inputPCollection.getWindowingStrategyId())
-                                .getWindowCoderId()),
-                    rehydratedComponents,
-                    TranslationContext.DEFAULT));
-    String outputPCollectionId = pTransform.getOutputsOrThrow("output");
-
-    ImmutableMap<StateKey, ByteString> stateData =
-        ImmutableMap.of(
-            iterableSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
-            encode("8"));
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-    List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
-    BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            progressRequestCallbacks::add,
-            splitListener,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    {
-      // Check that before processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-      mainInput.accept(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                  5.0)));
-      // Check that after processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-
-      // Since the side input upperBound is 8 we will process 5, 6, and 7 then checkpoint.
-      // We expect that the watermark advances to MIN + 7 and that the primary represents [5, 8)
-      // with
-      // the original watermark while the residual represents [8, 10) with the new MIN + 7
-      // watermark.
-      BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
-      DelayedBundleApplication residualRoot =
-          Iterables.getOnlyElement(splitListener.getResidualRoots());
-      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
-      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
-      assertEquals(
-          ParDoTranslation.getMainInputName(pTransform),
-          residualRoot.getApplication().getInputId());
-      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
-      assertEquals(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)), 3.0)),
-          inputCoder.decode(primaryRoot.getElement().newInput()));
-      assertEquals(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of(
-                      "5", KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
-                  2.0)),
-          inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
-      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
-      assertEquals(
-          ImmutableMap.of(
-              "output",
-              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
-                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
-                  .build()),
-          residualRoot.getApplication().getOutputWatermarksMap());
-      assertEquals(
-          org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
-              .setSeconds(54)
-              .setNanos(321000000)
-              .build(),
-          residualRoot.getRequestedTimeDelay());
-      splitListener.clear();
+      @NewTracker
+      public RestrictionTracker<OffsetRange, Long> newTracker(
+          @Restriction OffsetRange restriction) {
+        return new OffsetRangeTracker(restriction);
+      }
 
-      // Check that before processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-      mainInput.accept(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                  2.0)));
-      // Check that after processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+      @SplitRestriction
+      public void splitRange(@Restriction OffsetRange range, OutputReceiver<OffsetRange> receiver) {
+        receiver.output(new OffsetRange(range.getFrom(), (range.getFrom() + range.getTo()) / 2));
+        receiver.output(new OffsetRange((range.getFrom() + range.getTo()) / 2, range.getTo()));
+      }
 
-      assertThat(
-          mainOutputValues,
-          contains(
-              timestampedValueInGlobalWindow("5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5)),
-              timestampedValueInGlobalWindow("5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6)),
-              timestampedValueInGlobalWindow("5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7)),
-              timestampedValueInGlobalWindow("2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
-              timestampedValueInGlobalWindow("2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))));
-      assertTrue(splitListener.getPrimaryRoots().isEmpty());
-      assertTrue(splitListener.getResidualRoots().isEmpty());
-      mainOutputValues.clear();
+      @TruncateRestriction
+      public TruncateResult<OffsetRange> truncateRestriction(@Restriction OffsetRange range)
+          throws Exception {
+        return TruncateResult.of(new OffsetRange(range.getFrom(), range.getTo() / 2));
+      }
+
+      @GetInitialWatermarkEstimatorState
+      public Instant getInitialWatermarkEstimatorState() {
+        return GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1);
+      }
+
+      @NewWatermarkEstimator
+      public WatermarkEstimators.Manual newWatermarkEstimator(
+          @WatermarkEstimatorState Instant watermark) {
+        return new WatermarkEstimators.Manual(watermark);
+      }
     }
 
-    {
-      // Setup and launch the trySplit thread.
-      ExecutorService executorService = Executors.newSingleThreadExecutor();
-      Future<HandlesSplits.SplitResult> trySplitFuture =
-          executorService.submit(
-              () -> {
-                try {
-                  doFn.waitForSplitElementToBeProcessed();
-                  // Currently processing "3" out of range [0, 5) elements.
-                  assertEquals(0.6, ((HandlesSplits) mainInput).getProgress(), 0.01);
-
-                  // Check that during progressing of an element we report progress
-                  List<MonitoringInfo> mis =
-                      Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
-                  MonitoringInfo.Builder expectedCompleted = MonitoringInfo.newBuilder();
-                  expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
-                  expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
-                  expectedCompleted.putLabels(
-                      MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
-                  expectedCompleted.setPayload(
-                      ByteString.copyFrom(
-                          CoderUtils.encodeToByteArray(
-                              IterableCoder.of(DoubleCoder.of()), Collections.singletonList(3.0))));
-                  MonitoringInfo.Builder expectedRemaining = MonitoringInfo.newBuilder();
-                  expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
-                  expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
-                  expectedRemaining.putLabels(
-                      MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
-                  expectedRemaining.setPayload(
-                      ByteString.copyFrom(
-                          CoderUtils.encodeToByteArray(
-                              IterableCoder.of(DoubleCoder.of()), Collections.singletonList(2.0))));
-                  assertThat(
-                      mis,
-                      containsInAnyOrder(expectedCompleted.build(), expectedRemaining.build()));
+    /**
+     * A window observing variant of {@link NonWindowObservingTestSplittableDoFn} which uses the
+     * side inputs to choose the checkpoint upper bound.
+     */
+    static class WindowObservingTestSplittableDoFn extends NonWindowObservingTestSplittableDoFn {
 
-                  return ((HandlesSplits) mainInput).trySplit(0);
-                } finally {
-                  doFn.releaseWaitingProcessElementThread();
-                }
-              });
+      private final PCollectionView<String> singletonSideInput;
+      private static final long PROCESSED_WINDOW = 1;
+      private boolean splitAtTruncate = false;
+      private long processedWindowCount = 0;
 
-      // Check that before processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-      mainInput.accept(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                  2.0)));
-      HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+      private WindowObservingTestSplittableDoFn(PCollectionView<String> singletonSideInput) {
+        this.singletonSideInput = singletonSideInput;
+      }
 
-      // Check that after processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+      private static WindowObservingTestSplittableDoFn forSplitAtTruncate(
+          PCollectionView<String> singletonSideInput) {
+        WindowObservingTestSplittableDoFn doFn =
+            new WindowObservingTestSplittableDoFn(singletonSideInput);
+        doFn.splitAtTruncate = true;
+        return doFn;
+      }
 
-      // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split.
-      // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
-      // has yet to be invoked for the split element and that the primary represents [0, 4) with
-      // the original watermark while the residual represents [4, 5) with the new MIN + 2 watermark.
-      assertThat(
-          mainOutputValues,
-          contains(
-              timestampedValueInGlobalWindow("7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
-              timestampedValueInGlobalWindow("7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)),
-              timestampedValueInGlobalWindow("7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
-              timestampedValueInGlobalWindow("7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3))));
-
-      BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
-      DelayedBundleApplication residualRoot =
-          Iterables.getOnlyElement(trySplitResult.getResidualRoots());
-      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
-      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
-      assertEquals(
-          ParDoTranslation.getMainInputName(pTransform),
-          residualRoot.getApplication().getInputId());
-      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
-      assertEquals(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), 4.0)),
-          inputCoder.decode(primaryRoot.getElement().newInput()));
-      assertEquals(
-          valueInGlobalWindow(
-              KV.of(
-                  KV.of(
-                      "7", KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
-                  1.0)),
-          inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
-      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
-      assertEquals(
+      @Override
+      @ProcessElement
+      public ProcessContinuation processElement(
+          ProcessContext context,
+          RestrictionTracker<OffsetRange, Long> tracker,
+          ManualWatermarkEstimator<Instant> watermarkEstimator)
+          throws Exception {
+        long checkpointUpperBound = Long.parseLong(context.sideInput(singletonSideInput));
+        long position = tracker.currentRestriction().getFrom();
+        boolean claimStatus;
+        while (true) {
+          claimStatus = (tracker.tryClaim(position));
+          if (!claimStatus) {
+            break;
+          } else if (position == NonWindowObservingTestSplittableDoFn.SPLIT_ELEMENT) {
+            enableAndWaitForTrySplitToHappen();
+          }
+          context.outputWithTimestamp(
+              context.element() + ":" + position, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+          watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+          position += 1L;
+          if (position == checkpointUpperBound) {
+            break;
+          }
+        }
+        if (!claimStatus) {
+          return ProcessContinuation.stop();
+        } else {
+          return ProcessContinuation.resume().withResumeDelay(Duration.millis(54321L));
+        }
+      }
+
+      @Override
+      @TruncateRestriction
+      public TruncateResult<OffsetRange> truncateRestriction(@Restriction OffsetRange range)
+          throws Exception {
+        // Waiting for split when we are on the second window.
+        if (splitAtTruncate && processedWindowCount == PROCESSED_WINDOW) {
+          enableAndWaitForTrySplitToHappen();
+        }
+        processedWindowCount += 1;
+        return TruncateResult.of(new OffsetRange(range.getFrom(), range.getTo() / 2));
+      }
+    }
+
+    @Test
+    public void testProcessElementForSizedElementAndRestriction() throws Exception {
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      WindowObservingTestSplittableDoFn doFn =
+          new WindowObservingTestSplittableDoFn(singletonSideInputView);
+      valuePCollection.apply(
+          TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation
+                                      .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      RunnerApi.PCollection inputPCollection =
+          pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+      RehydratedComponents rehydratedComponents =
+          RehydratedComponents.forComponents(pProto.getComponents());
+      Coder<WindowedValue> inputCoder =
+          WindowedValue.getFullCoder(
+              CoderTranslation.fromProto(
+                  pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                  rehydratedComponents,
+                  TranslationContext.DEFAULT),
+              (Coder)
+                  CoderTranslation.fromProto(
+                      pProto
+                          .getComponents()
+                          .getCodersOrThrow(
+                              pProto
+                                  .getComponents()
+                                  .getWindowingStrategiesOrThrow(
+                                      inputPCollection.getWindowingStrategyId())
+                                  .getWindowCoderId()),
+                      rehydratedComponents,
+                      TranslationContext.DEFAULT));
+      String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+      ImmutableMap<StateKey, ByteString> stateData =
           ImmutableMap.of(
-              "output",
-              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
-                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
-                  .build()),
-          residualRoot.getApplication().getOutputWatermarksMap());
-      // We expect 0 resume delay.
-      assertEquals(
-          residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
-          residualRoot.getRequestedTimeDelay());
-      // We don't expect the outputs to goto the SDK initiated checkpointing listener.
-      assertTrue(splitListener.getPrimaryRoots().isEmpty());
-      assertTrue(splitListener.getResidualRoots().isEmpty());
+              iterableSideInputKey(
+                  singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
+              encode("8"));
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+      List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
+      BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              progressRequestCallbacks::add,
+              splitListener,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
       mainOutputValues.clear();
-      executorService.shutdown();
-    }
 
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
 
-    // Assert that state data did not change
-    assertEquals(stateData, fakeClient.getData());
-  }
+      {
+        // Check that before processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+        mainInput.accept(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    5.0)));
+        // Check that after processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+        // Since the side input upperBound is 8 we will process 5, 6, and 7 then checkpoint.
+        // We expect that the watermark advances to MIN + 7 and that the primary represents [5, 8)
+        // with
+        // the original watermark while the residual represents [8, 10) with the new MIN + 7
+        // watermark.
+        BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+        DelayedBundleApplication residualRoot =
+            Iterables.getOnlyElement(splitListener.getResidualRoots());
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+        assertEquals(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    3.0)),
+            inputCoder.decode(primaryRoot.getElement().newInput()));
+        assertEquals(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of(
+                        "5",
+                        KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+                    2.0)),
+            inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+        Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+        assertEquals(
+            ImmutableMap.of(
+                "output",
+                org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                    .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                    .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                    .build()),
+            residualRoot.getApplication().getOutputWatermarksMap());
+        assertEquals(
+            org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+                .setSeconds(54)
+                .setNanos(321000000)
+                .build(),
+            residualRoot.getRequestedTimeDelay());
+        splitListener.clear();
+
+        // Check that before processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+        mainInput.accept(
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0)));
+        // Check that after processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
 
-  @Test
-  public void testProcessElementForWindowedSizedElementAndRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
-    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
-    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    WindowObservingTestSplittableDoFn doFn =
-        new WindowObservingTestSplittableDoFn(singletonSideInputView);
-
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(
-                                PTransformTranslation
-                                    .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    RunnerApi.PCollection inputPCollection =
-        pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
-    RehydratedComponents rehydratedComponents =
-        RehydratedComponents.forComponents(pProto.getComponents());
-    Coder<WindowedValue> inputCoder =
-        WindowedValue.getFullCoder(
-            CoderTranslation.fromProto(
-                pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
-                rehydratedComponents,
-                TranslationContext.DEFAULT),
-            (Coder)
-                CoderTranslation.fromProto(
-                    pProto
-                        .getComponents()
-                        .getCodersOrThrow(
-                            pProto
-                                .getComponents()
-                                .getWindowingStrategiesOrThrow(
-                                    inputPCollection.getWindowingStrategyId())
-                                .getWindowCoderId()),
-                    rehydratedComponents,
-                    TranslationContext.DEFAULT));
-    String outputPCollectionId = pTransform.getOutputsOrThrow("output");
-
-    ImmutableMap<StateKey, ByteString> stateData =
-        ImmutableMap.of(
-            iterableSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
-            encode("8"));
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
-
-    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-    List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
-    BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            progressRequestCallbacks::add,
-            splitListener,
-            null /* bundleFinalizer */);
-
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    BoundedWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    BoundedWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    {
-      // Check that before processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-      WindowedValue<?> firstValue =
-          valueInWindows(
-              KV.of(
-                  KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
-              window1,
-              window2);
-      mainInput.accept(firstValue);
-      // Check that after processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-
-      // Since the side input upperBound is 8 we will process 5, 6, and 7 then checkpoint.
-      // We expect that the watermark advances to MIN + 7 and that the primary represents [5, 8)
-      // with the original watermark while the residual represents [8, 10) with the new MIN + 7
-      // watermark.
-      //
-      // Since we were on the first window, we expect only a single primary root and two residual
-      // roots (the split + the unprocessed window).
-      BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
-      assertEquals(2, splitListener.getResidualRoots().size());
-      DelayedBundleApplication residualRoot = splitListener.getResidualRoots().get(1);
-      DelayedBundleApplication residualRootForUnprocessedWindows =
-          splitListener.getResidualRoots().get(0);
-      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
-      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
-      assertEquals(
-          ParDoTranslation.getMainInputName(pTransform),
-          residualRoot.getApplication().getInputId());
-      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
-      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
-      assertEquals(
-          ImmutableMap.of(
-              "output",
-              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
-                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
-                  .build()),
-          residualRoot.getApplication().getOutputWatermarksMap());
-      assertEquals(
-          org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
-              .setSeconds(54)
-              .setNanos(321000000)
-              .build(),
-          residualRoot.getRequestedTimeDelay());
-      assertEquals(
-          ParDoTranslation.getMainInputName(pTransform),
-          residualRootForUnprocessedWindows.getApplication().getInputId());
-      assertEquals(
-          TEST_TRANSFORM_ID, residualRootForUnprocessedWindows.getApplication().getTransformId());
-      assertEquals(
-          residualRootForUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
-          residualRootForUnprocessedWindows.getRequestedTimeDelay());
-      assertTrue(
-          residualRootForUnprocessedWindows.getApplication().getOutputWatermarksMap().isEmpty());
-
-      assertEquals(
-          decode(inputCoder, primaryRoot.getElement()),
-          WindowedValue.of(
-              KV.of(
-                  KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)), 3.0),
-              firstValue.getTimestamp(),
-              window1,
-              firstValue.getPane()));
-      assertEquals(
-          decode(inputCoder, residualRoot.getApplication().getElement()),
-          WindowedValue.of(
-              KV.of(
-                  KV.of(
-                      "5", KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
-                  2.0),
-              firstValue.getTimestamp(),
-              window1,
-              firstValue.getPane()));
-      assertEquals(
-          decode(inputCoder, residualRootForUnprocessedWindows.getApplication().getElement()),
-          WindowedValue.of(
-              KV.of(
-                  KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
-              firstValue.getTimestamp(),
-              window2,
-              firstValue.getPane()));
-      splitListener.clear();
-
-      // Check that before processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-      WindowedValue<?> secondValue =
-          valueInWindows(
-              KV.of(
-                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
-              window1,
-              window2);
-      mainInput.accept(secondValue);
-      // Check that after processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-
-      assertThat(
-          mainOutputValues,
-          contains(
-              WindowedValue.of(
-                  "5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5), window1, firstValue.getPane()),
-              WindowedValue.of(
-                  "5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6), window1, firstValue.getPane()),
-              WindowedValue.of(
-                  "5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7), window1, firstValue.getPane()),
-              WindowedValue.of(
-                  "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, firstValue.getPane()),
-              WindowedValue.of(
-                  "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, firstValue.getPane()),
-              WindowedValue.of(
-                  "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window2, firstValue.getPane()),
-              WindowedValue.of(
-                  "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window2, firstValue.getPane())));
-      assertTrue(splitListener.getPrimaryRoots().isEmpty());
-      assertTrue(splitListener.getResidualRoots().isEmpty());
-      mainOutputValues.clear();
-    }
-
-    {
-      // Setup and launch the trySplit thread.
-      ExecutorService executorService = Executors.newSingleThreadExecutor();
-      Future<HandlesSplits.SplitResult> trySplitFuture =
-          executorService.submit(
-              () -> {
-                try {
-                  doFn.waitForSplitElementToBeProcessed();
-                  // Currently processing "3" out of range [0, 5) elements for the first window.
-                  assertEquals(0.3, ((HandlesSplits) mainInput).getProgress(), 0.01);
-
-                  // Check that during progressing of an element we report progress
-                  List<MonitoringInfo> mis =
-                      Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
-                  MonitoringInfo.Builder expectedCompleted = MonitoringInfo.newBuilder();
-                  expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
-                  expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
-                  expectedCompleted.putLabels(
-                      MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
-                  expectedCompleted.setPayload(
-                      ByteString.copyFrom(
-                          CoderUtils.encodeToByteArray(
-                              IterableCoder.of(DoubleCoder.of()), Collections.singletonList(3.0))));
-                  MonitoringInfo.Builder expectedRemaining = MonitoringInfo.newBuilder();
-                  expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
-                  expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
-                  expectedRemaining.putLabels(
-                      MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
-                  expectedRemaining.setPayload(
-                      ByteString.copyFrom(
-                          CoderUtils.encodeToByteArray(
-                              IterableCoder.of(DoubleCoder.of()), Collections.singletonList(7.0))));
-                  assertThat(
-                      mis,
-                      containsInAnyOrder(expectedCompleted.build(), expectedRemaining.build()));
-
-                  return ((HandlesSplits) mainInput).trySplit(0);
-                } finally {
-                  doFn.releaseWaitingProcessElementThread();
-                }
-              });
-
-      // Check that before processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-      WindowedValue<?> splitValue =
-          valueInWindows(
-              KV.of(
-                  KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
-              window1,
-              window2);
-      mainInput.accept(splitValue);
-      HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
-
-      // Check that after processing an element we don't report progress
-      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
-
-      // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split on the first window.
-      // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
-      // has yet to be invoked for the split element and that the primary represents [0, 4) with
-      // the original watermark while the residual represents [4, 5) with the new MIN + 2 watermark.
-      //
-      // We expect to see none of the output for the second window.
-      assertThat(
-          mainOutputValues,
-          contains(
-              WindowedValue.of(
-                  "7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, splitValue.getPane()),
-              WindowedValue.of(
-                  "7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, splitValue.getPane()),
-              WindowedValue.of(
-                  "7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2), window1, splitValue.getPane()),
-              WindowedValue.of(
-                  "7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3), window1, splitValue.getPane())));
-
-      BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
-      assertEquals(2, trySplitResult.getResidualRoots().size());
-      DelayedBundleApplication residualRoot = trySplitResult.getResidualRoots().get(1);
-      DelayedBundleApplication residualRootInUnprocessedWindows =
-          trySplitResult.getResidualRoots().get(0);
-      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
-      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
-      assertEquals(
-          ParDoTranslation.getMainInputName(pTransform),
-          residualRoot.getApplication().getInputId());
-      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
-      assertEquals(
-          TEST_TRANSFORM_ID, residualRootInUnprocessedWindows.getApplication().getTransformId());
-      assertEquals(
-          residualRootInUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
-          residualRootInUnprocessedWindows.getRequestedTimeDelay());
-      assertTrue(
-          residualRootInUnprocessedWindows.getApplication().getOutputWatermarksMap().isEmpty());
-      assertEquals(
-          valueInWindows(
-              KV.of(
-                  KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), 4.0),
-              window1),
-          inputCoder.decode(primaryRoot.getElement().newInput()));
-      assertEquals(
-          valueInWindows(
-              KV.of(
-                  KV.of(
-                      "7", KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
-                  1.0),
-              window1),
-          inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
-      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
-      assertEquals(
-          ImmutableMap.of(
-              "output",
-              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
-                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
-                  .build()),
-          residualRoot.getApplication().getOutputWatermarksMap());
-      assertEquals(
-          WindowedValue.of(
-              KV.of(
-                  KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
-              splitValue.getTimestamp(),
-              window2,
-              splitValue.getPane()),
-          inputCoder.decode(
-              residualRootInUnprocessedWindows.getApplication().getElement().newInput()));
-
-      // We expect 0 resume delay.
-      assertEquals(
-          residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
-          residualRoot.getRequestedTimeDelay());
-      // We don't expect the outputs to goto the SDK initiated checkpointing listener.
-      assertTrue(splitListener.getPrimaryRoots().isEmpty());
-      assertTrue(splitListener.getResidualRoots().isEmpty());
-      mainOutputValues.clear();
-      executorService.shutdown();
-    }
-
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-
-    // Assert that state data did not change
-    assertEquals(stateData, fakeClient.getData());
-  }
-
-  private static <T> T decode(Coder<T> coder, ByteString value) {
-    try {
-      return coder.decode(value.newInput());
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
-  }
+        assertThat(
+            mainOutputValues,
+            contains(
+                timestampedValueInGlobalWindow("5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5)),
+                timestampedValueInGlobalWindow("5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6)),
+                timestampedValueInGlobalWindow("5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7)),
+                timestampedValueInGlobalWindow("2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
+                timestampedValueInGlobalWindow("2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))));
+        assertTrue(splitListener.getPrimaryRoots().isEmpty());
+        assertTrue(splitListener.getResidualRoots().isEmpty());
+        mainOutputValues.clear();
+      }
 
-  @Test
-  public void testProcessElementForPairWithRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection.apply(
-        TEST_TRANSFORM_ID,
-        ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-            .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInGlobalWindow("5"));
-    mainInput.accept(valueInGlobalWindow("2"));
-    assertThat(
-        mainOutputValues,
-        contains(
-            valueInGlobalWindow(
-                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE))),
+      {
+        // Setup and launch the trySplit thread.
+        ExecutorService executorService = Executors.newSingleThreadExecutor();
+        Future<HandlesSplits.SplitResult> trySplitFuture =
+            executorService.submit(
+                () -> {
+                  try {
+                    doFn.waitForSplitElementToBeProcessed();
+                    // Currently processing "3" out of range [0, 5) elements.
+                    assertEquals(0.6, ((HandlesSplits) mainInput).getProgress(), 0.01);
+
+                    // Check that we report progress while an element is being processed.
+                    List<MonitoringInfo> mis =
+                        Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
+                    MonitoringInfo.Builder expectedCompleted = MonitoringInfo.newBuilder();
+                    expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
+                    expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                    expectedCompleted.putLabels(
+                        MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+                    expectedCompleted.setPayload(
+                        ByteString.copyFrom(
+                            CoderUtils.encodeToByteArray(
+                                IterableCoder.of(DoubleCoder.of()),
+                                Collections.singletonList(3.0))));
+                    MonitoringInfo.Builder expectedRemaining = MonitoringInfo.newBuilder();
+                    expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
+                    expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                    expectedRemaining.putLabels(
+                        MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+                    expectedRemaining.setPayload(
+                        ByteString.copyFrom(
+                            CoderUtils.encodeToByteArray(
+                                IterableCoder.of(DoubleCoder.of()),
+                                Collections.singletonList(2.0))));
+                    assertThat(
+                        mis,
+                        containsInAnyOrder(expectedCompleted.build(), expectedRemaining.build()));
+
+                    return ((HandlesSplits) mainInput).trySplit(0);
+                  } finally {
+                    doFn.releaseWaitingProcessElementThread();
+                  }
+                });
+
+        // Check that before processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+        mainInput.accept(
             valueInGlobalWindow(
-                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)))));
-    mainOutputValues.clear();
-
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
-
-  @Test
-  public void testProcessElementForWindowedPairWithRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-                .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    WindowedValue<?> firstValue = valueInWindows("5", window1, window2);
-    WindowedValue<?> secondValue = valueInWindows("2", window1, window2);
-    mainInput.accept(firstValue);
-    mainInput.accept(secondValue);
-    assertThat(
-        mainOutputValues,
-        contains(
-            WindowedValue.of(
-                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                firstValue.getTimestamp(),
-                window1,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                firstValue.getTimestamp(),
-                window2,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                secondValue.getTimestamp(),
-                window1,
-                secondValue.getPane()),
-            WindowedValue.of(
-                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                secondValue.getTimestamp(),
-                window2,
-                secondValue.getPane())));
-    mainOutputValues.clear();
-
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
-
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
-
-  @Test
-  public void testProcessElementForWindowedPairWithRestrictionWithNonWindowObservingOptimization()
-      throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    WindowedValue<?> firstValue = valueInWindows("5", window1, window2);
-    WindowedValue<?> secondValue = valueInWindows("2", window1, window2);
-    mainInput.accept(firstValue);
-    mainInput.accept(secondValue);
-    // Ensure that each output element is in all the windows and not one per window.
-    assertThat(
-        mainOutputValues,
-        contains(
-            WindowedValue.of(
-                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                secondValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                secondValue.getPane())));
-    mainOutputValues.clear();
-
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+                KV.of(
+                    KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0)));
+        HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
+        // Check that after processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
 
-  @Test
-  public void testProcessElementForSplitAndSizeRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection.apply(
-        TEST_TRANSFORM_ID,
-        ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-            .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(
-                                PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(
-        valueInGlobalWindow(
-            KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE))));
-    mainInput.accept(
-        valueInGlobalWindow(
-            KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE))));
-    assertThat(
-        mainOutputValues,
-        contains(
-            valueInGlobalWindow(
-                KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    2.0)),
+        // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split.
+        // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
+        // has yet to be invoked for the split element and that the primary represents [0, 4) with
+        // the original watermark while the residual represents [4, 5) with the new MIN + 2
+        // watermark.
+        assertThat(
+            mainOutputValues,
+            contains(
+                timestampedValueInGlobalWindow("7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
+                timestampedValueInGlobalWindow("7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)),
+                timestampedValueInGlobalWindow("7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
+                timestampedValueInGlobalWindow("7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3))));
+
+        BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+        DelayedBundleApplication residualRoot =
+            Iterables.getOnlyElement(trySplitResult.getResidualRoots());
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+        assertEquals(
             valueInGlobalWindow(
                 KV.of(
-                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    3.0)),
+                    KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    4.0)),
+            inputCoder.decode(primaryRoot.getElement().newInput()));
+        assertEquals(
             valueInGlobalWindow(
                 KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    KV.of(
+                        "7",
+                        KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
                     1.0)),
-            valueInGlobalWindow(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0))));
-    mainOutputValues.clear();
+            inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+        Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
+        assertEquals(
+            ImmutableMap.of(
+                "output",
+                org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                    .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                    .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                    .build()),
+            residualRoot.getApplication().getOutputWatermarksMap());
+        // We expect 0 resume delay.
+        assertEquals(
+            residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+            residualRoot.getRequestedTimeDelay());
+        // We don't expect the outputs to go to the SDK-initiated checkpointing listener.
+        assertTrue(splitListener.getPrimaryRoots().isEmpty());
+        assertTrue(splitListener.getResidualRoots().isEmpty());
+        mainOutputValues.clear();
+        executorService.shutdown();
+      }
 
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
 
-  @Test
-  public void testProcessElementForWindowedSplitAndSizeRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-                .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(
-                                PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    WindowedValue<?> firstValue =
-        valueInWindows(
-            KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-            window1,
-            window2);
-    WindowedValue<?> secondValue =
-        valueInWindows(
-            KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-            window1,
-            window2);
-    mainInput.accept(firstValue);
-    mainInput.accept(secondValue);
-    assertThat(
-        mainOutputValues,
-        contains(
-            WindowedValue.of(
+      // Assert that state data did not change
+      assertEquals(stateData, fakeClient.getData());
+    }
+
+    @Test
+    public void testProcessElementForWindowedSizedElementAndRestriction() throws Exception {
+      Pipeline p = Pipeline.create();
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
+      // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
+      addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      WindowObservingTestSplittableDoFn doFn =
+          new WindowObservingTestSplittableDoFn(singletonSideInputView);
+
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation
+                                      .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      RunnerApi.PCollection inputPCollection =
+          pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+      RehydratedComponents rehydratedComponents =
+          RehydratedComponents.forComponents(pProto.getComponents());
+      Coder<WindowedValue> inputCoder =
+          WindowedValue.getFullCoder(
+              CoderTranslation.fromProto(
+                  pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                  rehydratedComponents,
+                  TranslationContext.DEFAULT),
+              (Coder)
+                  CoderTranslation.fromProto(
+                      pProto
+                          .getComponents()
+                          .getCodersOrThrow(
+                              pProto
+                                  .getComponents()
+                                  .getWindowingStrategiesOrThrow(
+                                      inputPCollection.getWindowingStrategyId())
+                                  .getWindowCoderId()),
+                      rehydratedComponents,
+                      TranslationContext.DEFAULT));
+      String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+      ImmutableMap<StateKey, ByteString> stateData =
+          ImmutableMap.of(
+              iterableSideInputKey(
+                  singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
+              encode("8"));
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+      List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+      List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
+      BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              progressRequestCallbacks::add,
+              splitListener,
+              null /* bundleFinalizer */);
+
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      BoundedWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      BoundedWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      {
+        // Check that before processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+        WindowedValue<?> firstValue =
+            valueInWindows(
                 KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    2.0),
-                firstValue.getTimestamp(),
+                    KV.of(
+                        "5",
+                        KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                    5.0),
                 window1,
-                firstValue.getPane()),
+                window2);
+        mainInput.accept(firstValue);
+        // Check that after processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+        // Since the side input upperBound is 8 we will process 5, 6, and 7 then checkpoint.
+        // We expect that the watermark advances to MIN + 7 and that the primary represents [5, 8)
+        // with the original watermark while the residual represents [8, 10) with the new MIN + 7
+        // watermark.
+        //
+        // Since we were on the first window, we expect only a single primary root and two residual
+        // roots (the split + the unprocessed window).
+        BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+        assertEquals(2, splitListener.getResidualRoots().size());
+        DelayedBundleApplication residualRoot = splitListener.getResidualRoots().get(1);
+        DelayedBundleApplication residualRootForUnprocessedWindows =
+            splitListener.getResidualRoots().get(0);
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+        Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+        Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+            expectedOutputWatmermarkMap =
+                ImmutableMap.of(
+                    "output",
+                    org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                        .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                        .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                        .build());
+        Instant initialWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1);
+        Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+            expectedOutputWatmermarkMapForUnprocessedWindows =
+                ImmutableMap.of(
+                    "output",
+                    org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                        .setSeconds(initialWatermark.getMillis() / 1000)
+                        .setNanos((int) (initialWatermark.getMillis() % 1000) * 1000000)
+                        .build());
+        assertEquals(
+            expectedOutputWatmermarkMap, residualRoot.getApplication().getOutputWatermarksMap());
+        assertEquals(
+            org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+                .setSeconds(54)
+                .setNanos(321000000)
+                .build(),
+            residualRoot.getRequestedTimeDelay());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRootForUnprocessedWindows.getApplication().getInputId());
+        assertEquals(
+            TEST_TRANSFORM_ID, residualRootForUnprocessedWindows.getApplication().getTransformId());
+        assertEquals(
+            residualRootForUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
+            residualRootForUnprocessedWindows.getRequestedTimeDelay());
+        assertEquals(
+            expectedOutputWatmermarkMapForUnprocessedWindows,
+            residualRootForUnprocessedWindows.getApplication().getOutputWatermarksMap());
+
+        assertEquals(
+            decode(inputCoder, primaryRoot.getElement()),
             WindowedValue.of(
                 KV.of(
-                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    KV.of(
+                        "5",
+                        KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
                     3.0),
                 firstValue.getTimestamp(),
                 window1,
-                firstValue.getPane()),
+                firstValue.getPane()));
+        assertEquals(
+            decode(inputCoder, residualRoot.getApplication().getElement()),
             WindowedValue.of(
                 KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    KV.of(
+                        "5",
+                        KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
                     2.0),
                 firstValue.getTimestamp(),
-                window2,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    3.0),
-                firstValue.getTimestamp(),
-                window2,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
-                window1,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
                 window1,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
-                window2,
-                firstValue.getPane()),
+                firstValue.getPane()));
+        assertEquals(
+            decode(inputCoder, residualRootForUnprocessedWindows.getApplication().getElement()),
             WindowedValue.of(
                 KV.of(
-                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
+                    KV.of(
+                        "5",
+                        KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                    5.0),
                 firstValue.getTimestamp(),
                 window2,
-                firstValue.getPane())));
-    mainOutputValues.clear();
+                firstValue.getPane()));
+        splitListener.clear();
 
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+        // Check that before processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+        WindowedValue<?> secondValue =
+            valueInWindows(
+                KV.of(
+                    KV.of(
+                        "2",
+                        KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                    2.0),
+                window1,
+                window2);
+        mainInput.accept(secondValue);
+        // Check that after processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
+        assertThat(
+            mainOutputValues,
+            contains(
+                WindowedValue.of(
+                    "5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5), window1, firstValue.getPane()),
+                WindowedValue.of(
+                    "5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6), window1, firstValue.getPane()),
+                WindowedValue.of(
+                    "5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7), window1, firstValue.getPane()),
+                WindowedValue.of(
+                    "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, firstValue.getPane()),
+                WindowedValue.of(
+                    "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, firstValue.getPane()),
+                WindowedValue.of(
+                    "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window2, firstValue.getPane()),
+                WindowedValue.of(
+                    "2:1",
+                    GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1),
+                    window2,
+                    firstValue.getPane())));
+        assertTrue(splitListener.getPrimaryRoots().isEmpty());
+        assertTrue(splitListener.getResidualRoots().isEmpty());
+        mainOutputValues.clear();
+      }
 
-  @Test
-  public void
-      testProcessElementForWindowedSplitAndSizeRestrictionWithNonWindowObservingOptimization()
-          throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createSizedReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(
-                                PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            null /* beamFnStateClient */,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    WindowedValue<?> firstValue =
-        valueInWindows(
-            KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-            window1,
-            window2);
-    WindowedValue<?> secondValue =
-        valueInWindows(
-            KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-            window1,
-            window2);
-    mainInput.accept(firstValue);
-    mainInput.accept(secondValue);
-    // Ensure that each output element is in all the windows and not one per window.
-    assertThat(
-        mainOutputValues,
-        contains(
-            WindowedValue.of(
+      {
+        // Setup and launch the trySplit thread.
+        ExecutorService executorService = Executors.newSingleThreadExecutor();
+        Future<HandlesSplits.SplitResult> trySplitFuture =
+            executorService.submit(
+                () -> {
+                  try {
+                    doFn.waitForSplitElementToBeProcessed();
+                    // Currently processing "3" out of range [0, 5) elements for the first window.
+                    assertEquals(0.3, ((HandlesSplits) mainInput).getProgress(), 0.01);
+
+                    // Check that during progressing of an element we report progress
+                    List<MonitoringInfo> mis =
+                        Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
+                    MonitoringInfo.Builder expectedCompleted = MonitoringInfo.newBuilder();
+                    expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
+                    expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                    expectedCompleted.putLabels(
+                        MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+                    expectedCompleted.setPayload(
+                        ByteString.copyFrom(
+                            CoderUtils.encodeToByteArray(
+                                IterableCoder.of(DoubleCoder.of()),
+                                Collections.singletonList(3.0))));
+                    MonitoringInfo.Builder expectedRemaining = MonitoringInfo.newBuilder();
+                    expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
+                    expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                    expectedRemaining.putLabels(
+                        MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+                    expectedRemaining.setPayload(
+                        ByteString.copyFrom(
+                            CoderUtils.encodeToByteArray(
+                                IterableCoder.of(DoubleCoder.of()),
+                                Collections.singletonList(7.0))));
+                    assertThat(
+                        mis,
+                        containsInAnyOrder(expectedCompleted.build(), expectedRemaining.build()));
+
+                    return ((HandlesSplits) mainInput).trySplit(0);
+                  } finally {
+                    doFn.releaseWaitingProcessElementThread();
+                  }
+                });
+
+        // Check that before processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+        WindowedValue<?> splitValue =
+            valueInWindows(
                 KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    KV.of(
+                        "7",
+                        KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
                     2.0),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane()),
-            WindowedValue.of(
+                window1,
+                window2);
+        mainInput.accept(splitValue);
+        HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+
+        // Check that after processing an element we don't report progress
+        assertThat(
+            Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+        // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split on the first
+        // window.
+        // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
+        // has yet to be invoked for the split element and that the primary represents [0, 4) with
+        // the original watermark while the residual represents [4, 5) with the new MIN + 2
+        // watermark.
+        //
+        // We expect to see none of the output for the second window.
+        assertThat(
+            mainOutputValues,
+            contains(
+                WindowedValue.of(
+                    "7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, splitValue.getPane()),
+                WindowedValue.of(
+                    "7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, splitValue.getPane()),
+                WindowedValue.of(
+                    "7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2), window1, splitValue.getPane()),
+                WindowedValue.of(
+                    "7:3",
+                    GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3),
+                    window1,
+                    splitValue.getPane())));
+
+        BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+        assertEquals(2, trySplitResult.getResidualRoots().size());
+        DelayedBundleApplication residualRoot = trySplitResult.getResidualRoots().get(1);
+        DelayedBundleApplication residualRootInUnprocessedWindows =
+            trySplitResult.getResidualRoots().get(0);
+        assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+        assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+        assertEquals(
+            ParDoTranslation.getMainInputName(pTransform),
+            residualRoot.getApplication().getInputId());
+        assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+        assertEquals(
+            TEST_TRANSFORM_ID, residualRootInUnprocessedWindows.getApplication().getTransformId());
+        assertEquals(
+            residualRootInUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
+            residualRootInUnprocessedWindows.getRequestedTimeDelay());
+        Instant initialWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1);
+        Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
+        Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+            expectedOutputWatermarkMapInUnprocessedResiduals =
+                ImmutableMap.of(
+                    "output",
+                    org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                        .setSeconds(initialWatermark.getMillis() / 1000)
+                        .setNanos((int) (initialWatermark.getMillis() % 1000) * 1000000)
+                        .build());
+        Map<String, org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp>
+            expectedOutputWatermarkMap =
+                ImmutableMap.of(
+                    "output",
+                    org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                        .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                        .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                        .build());
+        assertEquals(
+            expectedOutputWatermarkMapInUnprocessedResiduals,
+            residualRootInUnprocessedWindows.getApplication().getOutputWatermarksMap());
+        assertEquals(
+            valueInWindows(
                 KV.of(
-                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    3.0),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane()),
-            WindowedValue.of(
+                    KV.of(
+                        "7",
+                        KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                    4.0),
+                window1),
+            inputCoder.decode(primaryRoot.getElement().newInput()));
+        assertEquals(
+            valueInWindows(
                 KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    KV.of(
+                        "7",
+                        KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
                     1.0),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane()),
+                window1),
+            inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+        assertEquals(
+            expectedOutputWatermarkMap, residualRoot.getApplication().getOutputWatermarksMap());
+        assertEquals(
             WindowedValue.of(
                 KV.of(
-                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane())));
-    mainOutputValues.clear();
+                    KV.of(
+                        "7",
+                        KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                    5.0),
+                splitValue.getTimestamp(),
+                window2,
+                splitValue.getPane()),
+            inputCoder.decode(
+                residualRootInUnprocessedWindows.getApplication().getElement().newInput()));
+
+        // We expect 0 resume delay.
+        assertEquals(
+            residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+            residualRoot.getRequestedTimeDelay());
+        // We don't expect the outputs to goto the SDK initiated checkpointing listener.
+        assertTrue(splitListener.getPrimaryRoots().isEmpty());
+        assertTrue(splitListener.getResidualRoots().isEmpty());
+        mainOutputValues.clear();
+        executorService.shutdown();
+      }
 
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+      assertThat(mainOutputValues, empty());
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
 
-  private static SplitResult createSplitResult(double fractionOfRemainder) {
-    ByteString.Output primaryBytes = ByteString.newOutput();
-    ByteString.Output residualBytes = ByteString.newOutput();
-    try {
-      DoubleCoder.of().encode(fractionOfRemainder, primaryBytes);
-      DoubleCoder.of().encode(1 - fractionOfRemainder, residualBytes);
-    } catch (Exception e) {
-      // No-op.
-    }
-    return SplitResult.of(
-        ImmutableList.of(
-            BundleApplication.newBuilder().setElement(primaryBytes.toByteString()).build()),
-        ImmutableList.of(
-            DelayedBundleApplication.newBuilder()
-                .setApplication(
-                    BundleApplication.newBuilder().setElement(residualBytes.toByteString()).build())
-                .build()));
-  }
+      // Assert that state data did not change
+      assertEquals(stateData, fakeClient.getData());
+    }
 
-  private static class SplittableFnDataReceiver
-      implements HandlesSplits, FnDataReceiver<WindowedValue> {
-    SplittableFnDataReceiver(
-        List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues) {
-      this.mainOutputValues = mainOutputValues;
+    private static <T> T decode(Coder<T> coder, ByteString value) { // decode bytes via coder
+      try {
+        return coder.decode(value.newInput());
+      } catch (IOException e) { // rethrown unchecked so callers need not declare IOException
+        throw new RuntimeException(e);
+      }
     }
 
-    private final List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues;
+    @Test
+    public void testProcessElementForPairWithRestriction() throws Exception {
+      Pipeline p = Pipeline.create(); // exercises only the expanded PAIR_WITH_RESTRICTION step
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      valuePCollection.apply(
+          TEST_TRANSFORM_ID,
+          ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
+              .withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of()); // no state
+
+      List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty()); // no start function registered
+      mainOutputValues.clear();
 
-    @Override
-    public SplitResult trySplit(double fractionOfRemainder) {
-      return createSplitResult(fractionOfRemainder);
-    }
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
 
-    @Override
-    public double getProgress() {
-      return 0.7;
-    }
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      mainInput.accept(valueInGlobalWindow("5")); // expect restriction [0, 5)
+      mainInput.accept(valueInGlobalWindow("2")); // expect restriction [0, 2)
+      assertThat(
+          mainOutputValues,
+          contains(
+              valueInGlobalWindow(
+                  KV.of(
+                      "5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)))),
+              valueInGlobalWindow(
+                  KV.of(
+                      "2",
+                      KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))))));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty()); // no finish function registered
+      assertThat(mainOutputValues, empty());
 
-    @Override
-    public void accept(WindowedValue input) throws Exception {
-      mainOutputValues.add(input);
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
     }
-  }
 
-  @Test
-  public void testProcessElementForTruncateAndSizeRestrictionForwardSplitWhenObservingWindow()
-      throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection.apply(
-        TEST_TRANSFORM_ID,
-        ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-            .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createTruncateReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    assertEquals(0, ((HandlesSplits) mainInput).getProgress(), 0.0);
-    assertNull(((HandlesSplits) mainInput).trySplit(0.4));
-  }
+    @Test
+    public void testProcessElementForWindowedPairWithRestriction() throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
+                  .withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+      List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
 
-  @Test
-  public void testProcessElementForTruncateAndSizeRestrictionForwardSplitWhenoutObservingWindow()
-      throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    valuePCollection.apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createTruncateReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    assertEquals(0.7, ((HandlesSplits) mainInput).getProgress(), 0.0);
-    assertEquals(createSplitResult(0.4), ((HandlesSplits) mainInput).trySplit(0.4));
-  }
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
 
-  @Test
-  public void testProcessElementForTruncateAndSizeRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection.apply(
-        TEST_TRANSFORM_ID,
-        ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-            .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createTruncateReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    mainInput.accept(
-        valueInGlobalWindow(
-            KV.of(
-                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0)));
-    mainInput.accept(
-        valueInGlobalWindow(
-            KV.of(
-                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0)));
-    assertThat(
-        mainOutputValues,
-        contains(
-            valueInGlobalWindow(
-                KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    2.0)),
-            valueInGlobalWindow(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0))));
-    mainOutputValues.clear();
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      WindowedValue<?> firstValue = valueInWindows("5", window1, window2);
+      WindowedValue<?> secondValue = valueInWindows("2", window1, window2);
+      mainInput.accept(firstValue);
+      mainInput.accept(secondValue);
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      "5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                  firstValue.getTimestamp(),
+                  window1,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      "5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                  firstValue.getTimestamp(),
+                  window2,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      "2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                  secondValue.getTimestamp(),
+                  window1,
+                  secondValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      "2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                  secondValue.getTimestamp(),
+                  window2,
+                  secondValue.getPane())));
+      mainOutputValues.clear();
 
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
-  }
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
 
-  @Test
-  public void testProcessElementForWindowedTruncateAndSizeRestriction() throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(
-            TEST_TRANSFORM_ID,
-            ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
-                .withSideInputs(singletonSideInputView));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createTruncateReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            fakeClient,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    WindowedValue<?> firstValue =
-        valueInWindows(
-            KV.of(KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
-            window1,
-            window2);
-    WindowedValue<?> secondValue =
-        valueInWindows(
-            KV.of(KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
-            window1,
-            window2);
-    mainInput.accept(firstValue);
-    mainInput.accept(secondValue);
-    assertThat(
-        mainOutputValues,
-        contains(
-            WindowedValue.of(
-                KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    2.0),
-                firstValue.getTimestamp(),
-                window1,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    2.0),
-                firstValue.getTimestamp(),
-                window2,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
-                window1,
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
-                window2,
-                firstValue.getPane())));
-    mainOutputValues.clear();
+    @Test
+    public void testProcessElementForWindowedPairWithRestrictionWithNonWindowObservingOptimization()
+        throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+      List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      WindowedValue<?> firstValue = valueInWindows("5", window1, window2);
+      WindowedValue<?> secondValue = valueInWindows("2", window1, window2);
+      mainInput.accept(firstValue);
+      mainInput.accept(secondValue);
+      // Ensure that each output element is in all the windows and not one per window.
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      "5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                  firstValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      "2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))),
+                  secondValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  secondValue.getPane())));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
+
+    /**
+     * Runs the SplitAndSizeRestrictions expansion of {@link WindowObservingTestSplittableDoFn} over
+     * two globally windowed elements and checks the split and sized restrictions it emits.
+     */
+    @Test
+    public void testProcessElementForSplitAndSizeRestriction() throws Exception {
+      Pipeline pipeline = Pipeline.create();
+      PCollection<String> strings = pipeline.apply(Create.of("unused"));
+      PCollectionView<String> sideInputView = strings.apply(View.asSingleton());
+      strings.apply(
+          TEST_TRANSFORM_ID,
+          ParDo.of(new WindowObservingTestSplittableDoFn(sideInputView))
+              .withSideInputs(sideInputView));
+
+      RunnerApi.Pipeline pipelineProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(
+                  pipeline, SdkComponents.create(pipeline.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      RunnerApi.Components components = pipelineProto.getComponents();
+      // Locate the expanded split-and-size transform that belongs to the transform under test.
+      String splitAndSizeUrn = PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN;
+      String expandedTransformId =
+          Iterables.find(
+                  components.getTransformsMap().entrySet(),
+                  entry ->
+                      splitAndSizeUrn.equals(entry.getValue().getSpec().getUrn())
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform = components.getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              new MetricsContainerStepMap(), mock(ExecutionStateTracker.class));
+      consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              new FakeBeamFnStateClient(ImmutableMap.of()),
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              components.getPcollectionsMap(),
+              components.getCodersMap(),
+              components.getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      Instant minTimestamp = GlobalWindow.TIMESTAMP_MIN_VALUE;
+      mainInput.accept(valueInGlobalWindow(KV.of("5", KV.of(new OffsetRange(0, 5), minTimestamp))));
+      mainInput.accept(valueInGlobalWindow(KV.of("2", KV.of(new OffsetRange(0, 2), minTimestamp))));
+      assertThat(
+          mainOutputValues,
+          contains(
+              valueInGlobalWindow(
+                  KV.of(KV.of("5", KV.of(new OffsetRange(0, 2), minTimestamp)), 2.0)),
+              valueInGlobalWindow(
+                  KV.of(KV.of("5", KV.of(new OffsetRange(2, 5), minTimestamp)), 3.0)),
+              valueInGlobalWindow(
+                  KV.of(KV.of("2", KV.of(new OffsetRange(0, 1), minTimestamp)), 1.0)),
+              valueInGlobalWindow(
+                  KV.of(KV.of("2", KV.of(new OffsetRange(1, 2), minTimestamp)), 1.0))));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
+
+    /**
+     * Runs the SplitAndSizeRestrictions expansion of a window-observing splittable DoFn on elements
+     * that are each in two sliding windows, and checks that the split and sized restrictions are
+     * emitted separately per window (all of window1, then all of window2, for each element).
+     */
+    @Test
+    public void testProcessElementForWindowedSplitAndSizeRestriction() throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
+                  .withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      WindowedValue<?> firstValue =
+          valueInWindows(
+              KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+              window1,
+              window2);
+      WindowedValue<?> secondValue =
+          valueInWindows(
+              KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+              window1,
+              window2);
+      mainInput.accept(firstValue);
+      mainInput.accept(secondValue);
+      // Each output is in exactly one window. Outputs derived from secondValue are expected to carry
+      // secondValue's timestamp and pane, not firstValue's.
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  window1,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  firstValue.getTimestamp(),
+                  window1,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  window2,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  firstValue.getTimestamp(),
+                  window2,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  window1,
+                  secondValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  window1,
+                  secondValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  window2,
+                  secondValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  window2,
+                  secondValue.getPane())));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
+
+    /**
+     * Runs the SplitAndSizeRestrictions expansion of a splittable DoFn that does not observe windows
+     * on elements that are each in two sliding windows, and checks that every split and sized
+     * restriction is emitted once, carrying both windows.
+     */
+    @Test
+    public void
+        testProcessElementForWindowedSplitAndSizeRestrictionWithNonWindowObservingOptimization()
+            throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createSizedReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              null /* beamFnStateClient */,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      WindowedValue<?> firstValue =
+          valueInWindows(
+              KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+              window1,
+              window2);
+      WindowedValue<?> secondValue =
+          valueInWindows(
+              KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+              window1,
+              window2);
+      mainInput.accept(firstValue);
+      mainInput.accept(secondValue);
+      // Ensure that each output element is in all the windows and not one per window. Outputs
+      // derived from secondValue are expected to carry secondValue's timestamp and pane.
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  firstValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  secondValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  secondValue.getPane())));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
+
+    /**
+     * Builds a {@link SplitResult} with one primary root whose element is {@code
+     * fractionOfRemainder} and one residual root whose element is {@code 1 - fractionOfRemainder},
+     * both encoded with {@link DoubleCoder} and addressed to a fixed fake input/transform id.
+     *
+     * <p>Used both as the split returned by the fake downstream receiver and as the expected value
+     * in assertions.
+     *
+     * @throws IllegalStateException if encoding the fractions fails
+     */
+    private static SplitResult createSplitResult(double fractionOfRemainder) {
+      ByteString.Output primaryBytes = ByteString.newOutput();
+      ByteString.Output residualBytes = ByteString.newOutput();
+      try {
+        DoubleCoder.of().encode(fractionOfRemainder, primaryBytes);
+        DoubleCoder.of().encode(1 - fractionOfRemainder, residualBytes);
+      } catch (java.io.IOException e) {
+        // Writing to an in-memory buffer is not expected to fail. Surface a failure instead of
+        // silently building a split result with empty element bytes, which would mask test bugs.
+        throw new IllegalStateException(
+            "Failed to encode split fraction " + fractionOfRemainder, e);
+      }
+      return SplitResult.of(
+          ImmutableList.of(
+              BundleApplication.newBuilder()
+                  .setElement(primaryBytes.toByteString())
+                  .setInputId("mainInputId-process")
+                  .setTransformId("processPTransfromId")
+                  .build()),
+          ImmutableList.of(
+              DelayedBundleApplication.newBuilder()
+                  .setApplication(
+                      BundleApplication.newBuilder()
+                          .setElement(residualBytes.toByteString())
+                          .setInputId("mainInputId-process")
+                          .setTransformId("processPTransfromId")
+                          .build())
+                  .build()));
+    }
+
+    /**
+     * Fake downstream receiver that records every element it accepts into the supplied list,
+     * reports a fixed progress of 0.7, and answers every split request with {@code
+     * createSplitResult} for the requested fraction.
+     */
+    private static class SplittableFnDataReceiver
+        implements HandlesSplits, FnDataReceiver<WindowedValue> {
+      /** Destination for every accepted element; shared with the test that created it. */
+      private final List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> recordedOutputs;
+
+      SplittableFnDataReceiver(
+          List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues) {
+        recordedOutputs = mainOutputValues;
+      }
+
+      @Override
+      public void accept(WindowedValue input) throws Exception {
+        recordedOutputs.add(input);
+      }
+
+      @Override
+      public double getProgress() {
+        return 0.7;
+      }
+
+      @Override
+      public SplitResult trySplit(double fractionOfRemainder) {
+        return createSplitResult(fractionOfRemainder);
+      }
+    }
+
+    /**
+     * Checks that, for a window-observing DoFn, a trySplit issued on the truncate transform while
+     * an element in three windows is being processed returns both a window-level primary/residual
+     * (the element in window1 and in window3 respectively) and the split forwarded from the
+     * downstream {@link HandlesSplits} receiver.
+     */
+    @Test
+    public void testProcessElementForTruncateAndSizeRestrictionForwardSplitWhenObservingWindows()
+        throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      WindowObservingTestSplittableDoFn doFn =
+          WindowObservingTestSplittableDoFn.forSplitAtTruncate(singletonSideInputView);
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createTruncateReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      RunnerApi.PCollection inputPCollection =
+          pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+      RehydratedComponents rehydratedComponents =
+          RehydratedComponents.forComponents(pProto.getComponents());
+      // Full windowed-value coder of the truncate input, used to encode the expected split roots.
+      Coder<WindowedValue> inputCoder =
+          WindowedValue.getFullCoder(
+              CoderTranslation.fromProto(
+                  pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                  rehydratedComponents,
+                  TranslationContext.DEFAULT),
+              (Coder)
+                  CoderTranslation.fromProto(
+                      pProto
+                          .getComponents()
+                          .getCodersOrThrow(
+                              pProto
+                                  .getComponents()
+                                  .getWindowingStrategiesOrThrow(
+                                      inputPCollection.getWindowingStrategyId())
+                                  .getWindowCoderId()),
+                      rehydratedComponents,
+                      TranslationContext.DEFAULT));
+
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      mainOutputValues.clear();
+      BoundedWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      BoundedWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      BoundedWindow window3 = new IntervalWindow(new Instant(7), new Instant(12));
+      WindowedValue<?> splitValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 6), GlobalWindow.TIMESTAMP_MIN_VALUE)), 6.0),
+              window1,
+              window2,
+              window3);
+
+      // Setup and launch the trySplit thread. The executor is always shut down (interrupting the
+      // helper thread if it is still running) so a failing accept() or get() cannot leak a thread.
+      ExecutorService executorService = Executors.newSingleThreadExecutor();
+      HandlesSplits.SplitResult trySplitResult;
+      try {
+        Future<HandlesSplits.SplitResult> trySplitFuture =
+            executorService.submit(
+                () -> {
+                  try {
+                    doFn.waitForSplitElementToBeProcessed();
 
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+                    return ((HandlesSplits) mainInput).trySplit(0);
+                  } finally {
+                    doFn.releaseWaitingProcessElementThread();
+                  }
+                });
+
+        mainInput.accept(splitValue);
+        trySplitResult = trySplitFuture.get();
+      } finally {
+        executorService.shutdownNow();
+      }
+
+      // We expect that there are outputs from window1 and window2
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("7", KV.of(new OffsetRange(0, 3), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  splitValue.getTimestamp(),
+                  window1,
+                  splitValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("7", KV.of(new OffsetRange(0, 3), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      3.0),
+                  splitValue.getTimestamp(),
+                  window2,
+                  splitValue.getPane())));
+
+      SplitResult expectedElementSplit = createSplitResult(0);
+      BundleApplication expectedElementSplitPrimary =
+          Iterables.getOnlyElement(expectedElementSplit.getPrimaryRoots());
+      ByteString.Output primaryBytes = ByteString.newOutput();
+      inputCoder.encode(
+          WindowedValue.of(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 6), GlobalWindow.TIMESTAMP_MIN_VALUE)), 6.0),
+              splitValue.getTimestamp(),
+              window1,
+              splitValue.getPane()),
+          primaryBytes);
+      BundleApplication expectedWindowedPrimary =
+          BundleApplication.newBuilder()
+              .setElement(primaryBytes.toByteString())
+              .setInputId(ParDoTranslation.getMainInputName(pTransform))
+              .setTransformId(TEST_TRANSFORM_ID)
+              .build();
+      DelayedBundleApplication expectedElementSplitResidual =
+          Iterables.getOnlyElement(expectedElementSplit.getResidualRoots());
+      ByteString.Output residualBytes = ByteString.newOutput();
+      inputCoder.encode(
+          WindowedValue.of(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 6), GlobalWindow.TIMESTAMP_MIN_VALUE)), 6.0),
+              splitValue.getTimestamp(),
+              window3,
+              splitValue.getPane()),
+          residualBytes);
+      DelayedBundleApplication expectedWindowedResidual =
+          DelayedBundleApplication.newBuilder()
+              .setApplication(
+                  BundleApplication.newBuilder()
+                      .setElement(residualBytes.toByteString())
+                      .setInputId(ParDoTranslation.getMainInputName(pTransform))
+                      .setTransformId(TEST_TRANSFORM_ID)
+                      .build())
+              .build();
+      assertThat(
+          trySplitResult.getPrimaryRoots(),
+          contains(expectedWindowedPrimary, expectedElementSplitPrimary));
+      assertThat(
+          trySplitResult.getResidualRoots(),
+          contains(expectedWindowedResidual, expectedElementSplitResidual));
+    }
+
+    /**
+     * Checks that, for a DoFn that does not observe windows, the truncate transform's input
+     * receiver reports the downstream receiver's progress and returns its split result unchanged.
+     */
+    @Test
+    public void testProcessElementForTruncateAndSizeRestrictionForwardSplitWithoutObservingWindow()
+        throws Exception {
+      Pipeline pipeline = Pipeline.create();
+      pipeline
+          .apply(Create.of("unused"))
+          .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
+
+      RunnerApi.Pipeline pipelineProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(
+                  pipeline, SdkComponents.create(pipeline.getOptions()), true),
+              SplittableParDoExpander.createTruncateReplacement());
+      RunnerApi.Components components = pipelineProto.getComponents();
+      // Locate the expanded truncate transform that belongs to the transform under test.
+      String truncateUrn = PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN;
+      String expandedTransformId =
+          Iterables.find(
+                  components.getTransformsMap().entrySet(),
+                  entry ->
+                      truncateUrn.equals(entry.getValue().getSpec().getUrn())
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform = components.getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              new MetricsContainerStepMap(), mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              new FakeBeamFnStateClient(ImmutableMap.of()),
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              components.getPcollectionsMap(),
+              components.getCodersMap(),
+              components.getWindowingStrategiesMap(),
+              consumers,
+              new PTransformFunctionRegistry(
+                  mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start"),
+              new PTransformFunctionRegistry(
+                  mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"),
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      // Progress and split requests should be answered by SplittableFnDataReceiver downstream.
+      HandlesSplits splitHandler = (HandlesSplits) mainInput;
+      assertEquals(0.7, splitHandler.getProgress(), 0.0);
+      assertEquals(createSplitResult(0.4), splitHandler.trySplit(0.4));
+    }
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
+    @Test
+    public void testProcessElementForTruncateAndSizeRestriction() throws Exception {
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      valuePCollection.apply(
+          TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createTruncateReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      mainInput.accept(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                  5.0)));
+      mainInput.accept(
+          valueInGlobalWindow(
+              KV.of(
+                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                  2.0)));
+      assertThat(
+          mainOutputValues,
+          contains(
+              valueInGlobalWindow(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0)),
+              valueInGlobalWindow(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0))));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
+
+    @Test
+    public void testProcessElementForWindowedTruncateAndSizeRestriction() throws Exception {
+      // The DoFn under test presumably observes windows (it is handed a side input), so the
+      // expected outputs below contain one output per (element, window) pair rather than one
+      // multi-window output per element.
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(
+              TEST_TRANSFORM_ID,
+              ParDo.of(new WindowObservingTestSplittableDoFn(singletonSideInputView))
+                  .withSideInputs(singletonSideInputView));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createTruncateReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              fakeClient,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      WindowedValue<?> firstValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+              window1,
+              window2);
+      WindowedValue<?> secondValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+              window1,
+              window2);
+      mainInput.accept(firstValue);
+      mainInput.accept(secondValue);
+      // Each output must carry the timestamp and pane of the input it was produced from, so the
+      // "2" outputs are compared against secondValue rather than firstValue.
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  window1,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  window2,
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  window1,
+                  secondValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  window2,
+                  secondValue.getPane())));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
+
+    @Test
+    public void
+        testProcessElementForWindowedTruncateAndSizeRestrictionWithNonWindowObservingOptimization()
+            throws Exception {
+      // With a non-window-observing DoFn, the expected outputs below keep every input's full window
+      // set on a single output instead of exploding it per window.
+      Pipeline p = Pipeline.create();
+      PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+      valuePCollection
+          .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+          .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
+
+      RunnerApi.Pipeline pProto =
+          ProtoOverrides.updateTransform(
+              PTransformTranslation.PAR_DO_TRANSFORM_URN,
+              PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+              SplittableParDoExpander.createTruncateReplacement());
+      String expandedTransformId =
+          Iterables.find(
+                  pProto.getComponents().getTransformsMap().entrySet(),
+                  entry ->
+                      entry
+                              .getValue()
+                              .getSpec()
+                              .getUrn()
+                              .equals(
+                                  PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
+                          && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+              .getKey();
+      RunnerApi.PTransform pTransform =
+          pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+      String inputPCollectionId =
+          pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+      String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+      List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      consumers.register(
+          outputPCollectionId,
+          TEST_TRANSFORM_ID,
+          (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      new FnApiDoFnRunner.Factory<>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              null /* beamFnDataClient */,
+              null /* beamFnStateClient */,
+              null /* beamFnTimerClient */,
+              TEST_TRANSFORM_ID,
+              pTransform,
+              Suppliers.ofInstance("57L")::get,
+              pProto.getComponents().getPcollectionsMap(),
+              pProto.getComponents().getCodersMap(),
+              pProto.getComponents().getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              null /* addResetFunction */,
+              teardownFunctions::add,
+              null /* addProgressRequestCallback */,
+              null /* bundleSplitListener */,
+              null /* bundleFinalizer */);
+
+      assertTrue(startFunctionRegistry.getFunctions().isEmpty());
+      mainOutputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+      FnDataReceiver<WindowedValue<?>> mainInput =
+          consumers.getMultiplexingConsumer(inputPCollectionId);
+      assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+      IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+      IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+      WindowedValue<?> firstValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+              window1,
+              window2);
+      WindowedValue<?> secondValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+              window1,
+              window2);
+      mainInput.accept(firstValue);
+      mainInput.accept(secondValue);
+      // Ensure that each output element is in all the windows and not one per window. Each output
+      // carries the timestamp and pane of the input it came from.
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      2.0),
+                  firstValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  firstValue.getPane()),
+              WindowedValue.of(
+                  KV.of(
+                      KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                      1.0),
+                  secondValue.getTimestamp(),
+                  ImmutableList.of(window1, window2),
+                  secondValue.getPane())));
+      mainOutputValues.clear();
+
+      assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
+      assertThat(mainOutputValues, empty());
+
+      Iterables.getOnlyElement(teardownFunctions).run();
+      assertThat(mainOutputValues, empty());
+    }
   }
 
-  @Test
-  public void
-      testProcessElementForWindowedTruncateAndSizeRestrictionWithNonWindowObservingOptimization()
-          throws Exception {
-    Pipeline p = Pipeline.create();
-    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
-    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
-    valuePCollection
-        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
-        .apply(TEST_TRANSFORM_ID, ParDo.of(new NonWindowObservingTestSplittableDoFn()));
-
-    RunnerApi.Pipeline pProto =
-        ProtoOverrides.updateTransform(
-            PTransformTranslation.PAR_DO_TRANSFORM_URN,
-            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
-            SplittableParDoExpander.createTruncateReplacement());
-    String expandedTransformId =
-        Iterables.find(
-                pProto.getComponents().getTransformsMap().entrySet(),
-                entry ->
-                    entry
-                            .getValue()
-                            .getSpec()
-                            .getUrn()
-                            .equals(PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN)
-                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
-            .getKey();
-    RunnerApi.PTransform pTransform =
-        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
-    String inputPCollectionId =
-        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
-    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
-
-    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    consumers.register(
-        outputPCollectionId,
-        TEST_TRANSFORM_ID,
-        (FnDataReceiver) new SplittableFnDataReceiver(mainOutputValues));
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    new FnApiDoFnRunner.Factory<>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            null /* beamFnDataClient */,
-            null /* beamFnStateClient */,
-            null /* beamFnTimerClient */,
-            TEST_TRANSFORM_ID,
-            pTransform,
-            Suppliers.ofInstance("57L")::get,
-            pProto.getComponents().getPcollectionsMap(),
-            pProto.getComponents().getCodersMap(),
-            pProto.getComponents().getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            null /* addResetFunction */,
-            teardownFunctions::add,
-            null /* addProgressRequestCallback */,
-            null /* bundleSplitListener */,
-            null /* bundleFinalizer */);
-
-    assertTrue(startFunctionRegistry.getFunctions().isEmpty());
-    mainOutputValues.clear();
-
-    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
-
-    FnDataReceiver<WindowedValue<?>> mainInput =
-        consumers.getMultiplexingConsumer(inputPCollectionId);
-    assertThat(mainInput, instanceOf(HandlesSplits.class));
-
-    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
-    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
-    WindowedValue<?> firstValue =
-        valueInWindows(
-            KV.of(KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
-            window1,
-            window2);
-    WindowedValue<?> secondValue =
-        valueInWindows(
-            KV.of(KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
-            window1,
-            window2);
-    mainInput.accept(firstValue);
-    mainInput.accept(secondValue);
-    // Ensure that each output element is in all the windows and not one per window.
-    assertThat(
-        mainOutputValues,
-        contains(
-            WindowedValue.of(
-                KV.of(
-                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    2.0),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane()),
-            WindowedValue.of(
-                KV.of(
-                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
-                    1.0),
-                firstValue.getTimestamp(),
-                ImmutableList.of(window1, window2),
-                firstValue.getPane())));
-    mainOutputValues.clear();
+  @RunWith(JUnit4.class)
+  public static class SplitTest {
+    // Windows the element under test is assigned to; initialised in setUp().
+    private IntervalWindow window1;
+    private IntervalWindow window2;
+    private IntervalWindow window3;
+    // Element, restriction and watermark estimator state that the split tests operate on.
+    private WindowedValue<String> currentElement;
+    private OffsetRange currentRestriction;
+    private Instant currentWatermarkEstimatorState;
+    // Watermark/estimator-state pair passed to trySplitForProcess; createSplitInWindow uses its
+    // value as the residual's watermark estimator state.
+    private KV<Instant, Instant> watermarkAndState;
+
+    private KV<WindowedValue, WindowedValue> createSplitInWindow(
+        OffsetRange primaryRestriction, OffsetRange residualRestriction, BoundedWindow window) {
+      // Both halves keep the current element's value, timestamp and pane and live only in the given
+      // window. They differ in their restriction and in the watermark estimator state: the primary
+      // keeps the current state, while the residual takes the value of watermarkAndState.
+      String element = currentElement.getValue();
+      Instant timestamp = currentElement.getTimestamp();
+      PaneInfo pane = currentElement.getPane();
+      WindowedValue primary =
+          WindowedValue.of(
+              KV.of(element, KV.of(primaryRestriction, currentWatermarkEstimatorState)),
+              timestamp,
+              window,
+              pane);
+      WindowedValue residual =
+          WindowedValue.of(
+              KV.of(element, KV.of(residualRestriction, watermarkAndState.getValue())),
+              timestamp,
+              window,
+              pane);
+      return KV.of(primary, residual);
+    }
+
+    private KV<WindowedValue, WindowedValue> createSplitAcrossWindows(
+        List<BoundedWindow> primaryWindows, List<BoundedWindow> residualWindows) {
+      // Both sides carry the whole, unsplit restriction and the current estimator state; only the
+      // window sets differ. A side whose window list is empty is represented by null.
+      KV<String, KV<OffsetRange, Instant>> unsplitValue =
+          KV.of(
+              currentElement.getValue(),
+              KV.of(currentRestriction, currentWatermarkEstimatorState));
+      WindowedValue primary = null;
+      if (!primaryWindows.isEmpty()) {
+        primary =
+            WindowedValue.of(
+                unsplitValue,
+                currentElement.getTimestamp(),
+                primaryWindows,
+                currentElement.getPane());
+      }
+      WindowedValue residual = null;
+      if (!residualWindows.isEmpty()) {
+        residual =
+            WindowedValue.of(
+                unsplitValue,
+                currentElement.getTimestamp(),
+                residualWindows,
+                currentElement.getPane());
+      }
+      return KV.of(primary, residual);
+    }
 
-    assertTrue(finishFunctionRegistry.getFunctions().isEmpty());
-    assertThat(mainOutputValues, empty());
+    @Before
+    public void setUp() {
+      window1 = new IntervalWindow(Instant.ofEpochMilli(0), Instant.ofEpochMilli(10));
+      window2 = new IntervalWindow(Instant.ofEpochMilli(10), Instant.ofEpochMilli(20));
+      window3 = new IntervalWindow(Instant.ofEpochMilli(20), Instant.ofEpochMilli(30));
+      currentElement =
+          WindowedValue.of(
+              "a",
+              Instant.ofEpochMilli(57),
+              ImmutableList.of(window1, window2, window3),
+              PaneInfo.NO_FIRING);
+      currentRestriction = new OffsetRange(0L, 100L);
+      currentWatermarkEstimatorState = Instant.ofEpochMilli(21);
+      watermarkAndState = KV.of(Instant.ofEpochMilli(42), Instant.ofEpochMilli(42));
+    }
+
+    @Test
+    public void testScaleProgress() throws Exception {
+      Progress elementProgress = Progress.from(2, 8);
+      // There is only one window.
+      Progress scaledResult = FnApiDoFnRunner.scaleProgress(elementProgress, 0, 1);
+      assertEquals(2, scaledResult.getWorkCompleted(), 0.0);
+      assertEquals(8, scaledResult.getWorkRemaining(), 0.0);
+
+      // We are at the first window of 3 in total.
+      scaledResult = FnApiDoFnRunner.scaleProgress(elementProgress, 0, 3);
+      assertEquals(2, scaledResult.getWorkCompleted(), 0.0);
+      assertEquals(28, scaledResult.getWorkRemaining(), 0.0);
+
+      // We are at the second window of 3 in total.
+      scaledResult = FnApiDoFnRunner.scaleProgress(elementProgress, 1, 3);
+      assertEquals(12, scaledResult.getWorkCompleted(), 0.0);
+      assertEquals(18, scaledResult.getWorkRemaining(), 0.0);
+
+      // We are at the last window of 3 in total.
+      scaledResult = FnApiDoFnRunner.scaleProgress(elementProgress, 2, 3);
+      assertEquals(22, scaledResult.getWorkCompleted(), 0.0);
+      assertEquals(8, scaledResult.getWorkRemaining(), 0.0);
+    }
 
-    Iterables.getOnlyElement(teardownFunctions).run();
-    assertThat(mainOutputValues, empty());
+    @Test
+    public void testTrySplitForProcessCheckpointOnFirstWindow() throws Exception {
+      // Checkpoint (fraction 0.0) after claiming offset 30 while processing window 0 of 3.
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.0,
+              tracker,
+              watermarkAndState,
+              0,
+              3);
+      WindowedSplitResult split = result.getKey();
+      assertEquals(1, result.getValue().intValue());
+
+      // window1's restriction splits right after the claimed offset. window2 and window3 form the
+      // unprocessed residual, and no window is fully processed yet.
+      KV<WindowedValue, WindowedValue> expectedElementSplit =
+          createSplitInWindow(new OffsetRange(0, 31), new OffsetRange(31, 100), window1);
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(), ImmutableList.of(window2, window3));
+      assertEquals(expectedElementSplit.getKey(), split.getPrimarySplitRoot());
+      assertEquals(expectedElementSplit.getValue(), split.getResidualSplitRoot());
+      assertEquals(expectedWindowSplit.getKey(), split.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(expectedWindowSplit.getValue(), split.getResidualInUnprocessedWindowsRoot());
+    }
+
+    @Test
+    public void testTrySplitForProcessCheckpointOnFirstWindowAfterOneSplit() throws Exception {
+      // Same checkpoint as above, but the stop index is 2: an earlier split has presumably already
+      // given away window3, so only window2 should remain as unprocessed residual.
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.0,
+              tracker,
+              watermarkAndState,
+              0,
+              2);
+      WindowedSplitResult split = result.getKey();
+      assertEquals(1, result.getValue().intValue());
+
+      KV<WindowedValue, WindowedValue> expectedElementSplit =
+          createSplitInWindow(new OffsetRange(0, 31), new OffsetRange(31, 100), window1);
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(), ImmutableList.of(window2));
+      assertEquals(expectedElementSplit.getKey(), split.getPrimarySplitRoot());
+      assertEquals(expectedElementSplit.getValue(), split.getResidualSplitRoot());
+      assertEquals(expectedWindowSplit.getKey(), split.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(expectedWindowSplit.getValue(), split.getResidualInUnprocessedWindowsRoot());
+    }
+
+    @Test
+    public void testTrySplitForProcessSplitOnFirstWindow() throws Exception {
+      // Split at fraction 0.2 of the remaining work after claiming offset 30 in window 0 of 3.
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.2,
+              tracker,
+              watermarkAndState,
+              0,
+              3);
+      WindowedSplitResult split = result.getKey();
+      assertEquals(1, result.getValue().intValue());
+
+      // Per the expected values, the in-window split lands at offset 84, and window2/window3 are
+      // still handed off whole as the residual.
+      KV<WindowedValue, WindowedValue> expectedElementSplit =
+          createSplitInWindow(new OffsetRange(0, 84), new OffsetRange(84, 100), window1);
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(), ImmutableList.of(window2, window3));
+      assertEquals(expectedElementSplit.getKey(), split.getPrimarySplitRoot());
+      assertEquals(expectedElementSplit.getValue(), split.getResidualSplitRoot());
+      assertEquals(expectedWindowSplit.getKey(), split.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(expectedWindowSplit.getValue(), split.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Splits with fraction 0.2 while processing the middle of three windows; the split lands inside
+     * window2, window1 is reported as fully processed and window3 as unprocessed.
+     */
+    @Test
+    public void testTrySplitForProcessSplitOnMiddleWindow() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              window2,
+              windows,
+              currentWatermarkEstimatorState,
+              0.2,
+              tracker,
+              watermarkAndState,
+              1,
+              3);
+      assertEquals(2, (int) result.getValue());
+      // The remaining work is 70 offsets of window2 plus 100 offsets of window3 (170 in total), so
+      // the ideal split position is offset 30 + 0.2 * 170 = 64. Floating-point error in the
+      // computed split fraction yields a position just below 64 (63.99...), which is truncated
+      // when converted to a long offset (BigDecimal.longValue()), so the primary ends at 63.
+      KV<WindowedValue, WindowedValue> expectedElementSplit =
+          createSplitInWindow(new OffsetRange(0, 63), new OffsetRange(63, 100), window2);
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1), ImmutableList.of(window3));
+      assertEquals(expectedElementSplit.getKey(), result.getKey().getPrimarySplitRoot());
+      assertEquals(expectedElementSplit.getValue(), result.getKey().getResidualSplitRoot());
+      assertEquals(
+          expectedWindowSplit.getKey(), result.getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(), result.getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Splits with fraction 0.2 while processing the last of three windows; the split lands inside
+     * window3 (at offset 44), and window1 and window2 are reported as fully processed.
+     */
+    @Test
+    public void testTrySplitForProcessSplitOnLastWindow() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker restrictionTracker = new OffsetRangeTracker(currentRestriction);
+      restrictionTracker.tryClaim(30L);
+
+      KV<WindowedSplitResult, Integer> splitAndStopIndex =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              window3,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.2,
+              restrictionTracker,
+              watermarkAndState,
+              2,
+              3);
+      assertEquals(3, (int) splitAndStopIndex.getValue());
+
+      WindowedSplitResult actual = splitAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> inWindow =
+          createSplitInWindow(new OffsetRange(0, 44), new OffsetRange(44, 100), window3);
+      assertEquals(inWindow.getKey(), actual.getPrimarySplitRoot());
+      assertEquals(inWindow.getValue(), actual.getResidualSplitRoot());
+
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(window1, window2), ImmutableList.of());
+      assertEquals(acrossWindows.getKey(), actual.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), actual.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * When the tracker can no longer split the element (it has already attempted offset 100), a
+     * split requested on the first window falls back to the window boundary after window1.
+     */
+    @Test
+    public void testTrySplitForProcessSplitOnFirstWindowFallback() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(100L);
+      // Precondition: the tracker itself refuses to split any further.
+      assertNull(tracker.trySplit(0.0));
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (0 below).
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.0,
+              tracker,
+              watermarkAndState,
+              0,
+              3);
+      assertEquals(1, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1), ImmutableList.of(window2, window3));
+      // No element split was possible, so only the window split is populated.
+      assertNull(result.getKey().getPrimarySplitRoot());
+      assertNull(result.getKey().getResidualSplitRoot());
+      assertEquals(
+          expectedWindowSplit.getKey(), result.getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(), result.getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * On the last window, after the tracker has attempted offset 100, neither an element split nor
+     * a window split is possible, so no split result is returned.
+     */
+    @Test
+    public void testTrySplitForProcessSplitOnLastWindowWhenNoElementSplit() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker restrictionTracker = new OffsetRangeTracker(currentRestriction);
+      restrictionTracker.tryClaim(100L);
+      // Precondition: the tracker itself refuses to split any further.
+      assertNull(restrictionTracker.trySplit(0.0));
+
+      assertNull(
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              window3,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.0,
+              restrictionTracker,
+              watermarkAndState,
+              2,
+              3));
+    }
+
+    /**
+     * A fraction of 0.6 requested on the first window lands just before the window1/window2
+     * boundary and is rounded up to the boundary after window2, so no element split happens.
+     */
+    @Test
+    public void testTrySplitForProcessOnWindowBoundaryRoundUp() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (0 below).
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.6,
+              tracker,
+              watermarkAndState,
+              0,
+              3);
+      assertEquals(2, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1, window2), ImmutableList.of(window3));
+      assertNull(result.getKey().getPrimarySplitRoot());
+      assertNull(result.getKey().getResidualSplitRoot());
+      assertEquals(
+          expectedWindowSplit.getKey(), result.getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(), result.getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * A fraction of 0.3 requested on the first window lands just past the window1/window2 boundary
+     * and is rounded down to that boundary, so no element split happens.
+     */
+    @Test
+    public void testTrySplitForProcessOnWindowBoundaryRoundDown() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (0 below).
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.3,
+              tracker,
+              watermarkAndState,
+              0,
+              3);
+      assertEquals(1, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1), ImmutableList.of(window2, window3));
+      assertNull(result.getKey().getPrimarySplitRoot());
+      assertNull(result.getKey().getResidualSplitRoot());
+      assertEquals(
+          expectedWindowSplit.getKey(), result.getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(), result.getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * A fraction of 0.9 requested on the first window lands inside the last window; the split is
+     * rounded down to the boundary before window3, so no element split happens.
+     */
+    @Test
+    public void testTrySplitForProcessOnWindowBoundaryRoundDownOnLastWindow() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      OffsetRangeTracker tracker = new OffsetRangeTracker(currentRestriction);
+      tracker.tryClaim(30L);
+      KV<WindowedSplitResult, Integer> result =
+          FnApiDoFnRunner.trySplitForProcess(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (0 below).
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.9,
+              tracker,
+              watermarkAndState,
+              0,
+              3);
+      assertEquals(2, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1, window2), ImmutableList.of(window3));
+      assertNull(result.getKey().getPrimarySplitRoot());
+      assertNull(result.getKey().getResidualSplitRoot());
+      assertEquals(
+          expectedWindowSplit.getKey(), result.getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(), result.getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Creates a fake downstream {@link HandlesSplits} for the truncate split tests.
+     *
+     * @param progress the value reported by {@link HandlesSplits#getProgress()}
+     * @param expectedFraction the fraction {@link HandlesSplits#trySplit(double)} must be called
+     *     with; a small tolerance absorbs floating-point error in the caller's computation
+     * @param result the value returned by {@link HandlesSplits#trySplit(double)}, may be null
+     * @return a delegate whose {@code trySplit} throws {@link IllegalArgumentException} when called
+     *     with an unexpected (or NaN) fraction
+     */
+    private HandlesSplits createSplitDelegate(
+        double progress, double expectedFraction, HandlesSplits.SplitResult result) {
+      // Fractions are derived through double arithmetic, so exact equality is too brittle.
+      final double fractionTolerance = 1e-9;
+      return new HandlesSplits() {
+        @Override
+        public SplitResult trySplit(double fractionOfRemainder) {
+          // Written as !(x <= tol) so that a NaN fraction is rejected, just like with '=='.
+          if (!(Math.abs(fractionOfRemainder - expectedFraction) <= fractionTolerance)) {
+            throw new IllegalArgumentException(
+                String.format(
+                    "Expected trySplit(%s) but was called with %s",
+                    expectedFraction, fractionOfRemainder));
+          }
+          return result;
+        }
+
+        @Override
+        public double getProgress() {
+          return progress;
+        }
+      };
+    }
+
+    /**
+     * Checkpoints (fraction 0.0) on the first of three windows; the downstream delegate's split
+     * result is passed through and window2/window3 become unprocessed residual windows.
+     */
+    @Test
+    public void testTrySplitForTruncateCheckpointOnFirstWindow() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      SplitResult downstreamSplit =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits fakeDelegate = createSplitDelegate(0.3, 0.0, downstreamSplit);
+
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> splitsAndStopIndex =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.0,
+              fakeDelegate,
+              0,
+              3);
+      assertEquals(1, (int) splitsAndStopIndex.getValue());
+
+      KV<WindowedSplitResult, SplitResult> splits = splitsAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(), ImmutableList.of(window2, window3));
+      assertEquals(downstreamSplit, splits.getValue());
+      WindowedSplitResult windowSplit = splits.getKey();
+      assertEquals(acrossWindows.getKey(), windowSplit.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), windowSplit.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Checkpoints (fraction 0.0) on window1 with a stop window index of 2; the delegate's split
+     * result is passed through and only window2 becomes an unprocessed residual window.
+     */
+    @Test
+    public void testTrySplitForTruncateCheckpointOnFirstWindowAfterOneSplit() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      SplitResult downstreamSplit =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits fakeDelegate = createSplitDelegate(0.3, 0.0, downstreamSplit);
+
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> splitsAndStopIndex =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.0,
+              fakeDelegate,
+              0,
+              2);
+      assertEquals(1, (int) splitsAndStopIndex.getValue());
+
+      KV<WindowedSplitResult, SplitResult> splits = splitsAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(), ImmutableList.of(window2));
+      assertEquals(downstreamSplit, splits.getValue());
+      WindowedSplitResult windowSplit = splits.getKey();
+      assertEquals(acrossWindows.getKey(), windowSplit.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), windowSplit.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Splits with fraction 0.2 on the first of three windows; the delegate is expected to be asked
+     * for a fraction of 0.54, and window2/window3 become unprocessed residual windows.
+     */
+    @Test
+    public void testTrySplitForTruncateSplitOnFirstWindow() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      SplitResult downstreamSplit =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits fakeDelegate = createSplitDelegate(0.3, 0.54, downstreamSplit);
+
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> splitsAndStopIndex =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.2,
+              fakeDelegate,
+              0,
+              3);
+      assertEquals(1, (int) splitsAndStopIndex.getValue());
+
+      KV<WindowedSplitResult, SplitResult> splits = splitsAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(), ImmutableList.of(window2, window3));
+      assertEquals(downstreamSplit, splits.getValue());
+      WindowedSplitResult windowSplit = splits.getKey();
+      assertEquals(acrossWindows.getKey(), windowSplit.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), windowSplit.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Splits with fraction 0.2 on the middle of three windows; the delegate is expected to be asked
+     * for a fraction of 0.34, window1 is fully processed and window3 is unprocessed.
+     */
+    @Test
+    public void testTrySplitForTruncateSplitOnMiddleWindow() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      SplitResult splitResult =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits splitDelegate = createSplitDelegate(0.3, 0.34, splitResult);
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> result =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (1 below).
+              window2,
+              windows,
+              currentWatermarkEstimatorState,
+              0.2,
+              splitDelegate,
+              1,
+              3);
+      assertEquals(2, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1), ImmutableList.of(window3));
+      assertEquals(splitResult, result.getKey().getValue());
+      assertEquals(
+          expectedWindowSplit.getKey(),
+          result.getKey().getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(),
+          result.getKey().getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * Splits with fraction 0.2 on the last of three windows; the fraction is expected to reach the
+     * delegate unchanged, and window1/window2 are reported as fully processed.
+     */
+    @Test
+    public void testTrySplitForTruncateSplitOnLastWindow() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      SplitResult splitResult =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits splitDelegate = createSplitDelegate(0.3, 0.2, splitResult);
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> result =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (2 below).
+              window3,
+              windows,
+              currentWatermarkEstimatorState,
+              0.2,
+              splitDelegate,
+              2,
+              3);
+      assertEquals(3, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1, window2), ImmutableList.of());
+      assertEquals(splitResult, result.getKey().getValue());
+      assertEquals(
+          expectedWindowSplit.getKey(),
+          result.getKey().getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(),
+          result.getKey().getKey().getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * With the delegate reporting progress 1.0, a checkpoint on the first window falls back to the
+     * window boundary after window1 and no downstream split result is returned.
+     */
+    @Test
+    public void testTrySplitForTruncateSplitOnFirstWindowFallback() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      // Handed to the delegate but expected never to surface in the result.
+      SplitResult ignoredDownstreamSplit =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits fakeDelegate = createSplitDelegate(1.0, 0.0, ignoredDownstreamSplit);
+
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> splitsAndStopIndex =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.0,
+              fakeDelegate,
+              0,
+              3);
+      assertEquals(1, (int) splitsAndStopIndex.getValue());
+
+      KV<WindowedSplitResult, SplitResult> splits = splitsAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(window1), ImmutableList.of(window2, window3));
+      assertNull(splits.getValue());
+      WindowedSplitResult windowSplit = splits.getKey();
+      assertEquals(acrossWindows.getKey(), windowSplit.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), windowSplit.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * On the last window, with a delegate that returns no split result, no split is produced at
+     * all.
+     */
+    @Test
+    public void testTrySplitForTruncateSplitOnLastWindowWhenNoElementSplit() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      HandlesSplits splitDelegate = createSplitDelegate(1.0, 0.0, null);
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> result =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              // The current window must be the one at currentWindowIndex (2 below).
+              window3,
+              windows,
+              currentWatermarkEstimatorState,
+              0.0,
+              splitDelegate,
+              2,
+              3);
+      assertNull(result);
+    }
+
+    /**
+     * A fraction of 0.6 on the first window is rounded up to the boundary after window2, so the
+     * delegate's split result is not used and window3 is the only unprocessed window.
+     */
+    @Test
+    public void testTrySplitForTruncateOnWindowBoundaryRoundUp() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      // Handed to the delegate but expected never to surface in the result.
+      SplitResult ignoredDownstreamSplit =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits fakeDelegate = createSplitDelegate(0.3, 0.0, ignoredDownstreamSplit);
+
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> splitsAndStopIndex =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.6,
+              fakeDelegate,
+              0,
+              3);
+      assertEquals(2, (int) splitsAndStopIndex.getValue());
+
+      KV<WindowedSplitResult, SplitResult> splits = splitsAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(window1, window2), ImmutableList.of(window3));
+      assertNull(splits.getValue());
+      WindowedSplitResult windowSplit = splits.getKey();
+      assertEquals(acrossWindows.getKey(), windowSplit.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), windowSplit.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * A fraction of 0.3 on the first window is rounded down to the boundary after window1, so the
+     * delegate's split result is not used and window2/window3 are unprocessed.
+     */
+    @Test
+    public void testTrySplitForTruncateOnWindowBoundaryRoundDown() throws Exception {
+      List<BoundedWindow> allWindows = ImmutableList.copyOf(currentElement.getWindows());
+      // Handed to the delegate but expected never to surface in the result.
+      SplitResult ignoredDownstreamSplit =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits fakeDelegate = createSplitDelegate(0.3, 0.0, ignoredDownstreamSplit);
+
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> splitsAndStopIndex =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              allWindows,
+              currentWatermarkEstimatorState,
+              0.3,
+              fakeDelegate,
+              0,
+              3);
+      assertEquals(1, (int) splitsAndStopIndex.getValue());
+
+      KV<WindowedSplitResult, SplitResult> splits = splitsAndStopIndex.getKey();
+      KV<WindowedValue, WindowedValue> acrossWindows =
+          createSplitAcrossWindows(ImmutableList.of(window1), ImmutableList.of(window2, window3));
+      assertNull(splits.getValue());
+      WindowedSplitResult windowSplit = splits.getKey();
+      assertEquals(acrossWindows.getKey(), windowSplit.getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(acrossWindows.getValue(), windowSplit.getResidualInUnprocessedWindowsRoot());
+    }
+
+    /**
+     * A fraction of 0.9 on the first window lands inside the last window; the split is rounded
+     * down to the boundary before window3, so the delegate's split result is not used. This
+     * mirrors testTrySplitForProcessOnWindowBoundaryRoundDownOnLastWindow (previously this test
+     * used 0.6 and merely duplicated the round-up case).
+     */
+    @Test
+    public void testTrySplitForTruncateOnWindowBoundaryRoundDownOnLastWindow() throws Exception {
+      List<BoundedWindow> windows = ImmutableList.copyOf(currentElement.getWindows());
+      SplitResult unusedSplitResult =
+          SplitResult.of(
+              ImmutableList.of(BundleApplication.getDefaultInstance()),
+              ImmutableList.of(DelayedBundleApplication.getDefaultInstance()));
+      HandlesSplits splitDelegate = createSplitDelegate(0.3, 0.0, unusedSplitResult);
+      KV<KV<WindowedSplitResult, SplitResult>, Integer> result =
+          FnApiDoFnRunner.trySplitForTruncate(
+              currentElement,
+              currentRestriction,
+              window1,
+              windows,
+              currentWatermarkEstimatorState,
+              0.9,
+              splitDelegate,
+              0,
+              3);
+      assertEquals(2, (int) result.getValue());
+      KV<WindowedValue, WindowedValue> expectedWindowSplit =
+          createSplitAcrossWindows(ImmutableList.of(window1, window2), ImmutableList.of(window3));
+      assertNull(result.getKey().getValue());
+      assertEquals(
+          expectedWindowSplit.getKey(),
+          result.getKey().getKey().getPrimaryInFullyProcessedWindowsRoot());
+      assertEquals(
+          expectedWindowSplit.getValue(),
+          result.getKey().getKey().getResidualInUnprocessedWindowsRoot());
+    }
   }
 }