You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/06/05 21:04:50 UTC

[beam] branch master updated: [BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window when processing a splittable dofn (#11922)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 96836a7  [BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window when processing a splittable dofn (#11922)
96836a7 is described below

commit 96836a741558ba93ab367b44ad46fbbf28b86449
Author: Lukasz Cwik <lu...@gmail.com>
AuthorDate: Fri Jun 5 14:04:28 2020 -0700

    [BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window when processing a splittable dofn (#11922)
    
    * [BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window.
    
    This fixes a bug where we would output within all the windows instead of just the current window.
    This would not impact any SDF that used only a single window while processing.
    
    * Make sure that splitting/checkpointing is window aware.
    
    * fixup! Address PR comments.
---
 .../beam/fn/harness/BeamFnDataReadRunner.java      |   4 +-
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 238 +++++--
 .../org/apache/beam/fn/harness/HandlesSplits.java  |  11 +-
 .../fn/harness/control/BundleSplitListener.java    |  12 +-
 .../beam/fn/harness/BeamFnDataReadRunnerTest.java  |  20 +-
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       | 686 ++++++++++++++++++++-
 .../harness/control/BundleSplitListenerTest.java   |   7 +-
 7 files changed, 910 insertions(+), 68 deletions(-)

diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index 3b10b33..34cdcb4 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -292,8 +292,8 @@ public class BeamFnDataReadRunner<OutputT> {
           if (splitResult != null) {
             stopIndex = index + 1;
             response
-                .addPrimaryRoots(splitResult.getPrimaryRoot())
-                .addResidualRoots(splitResult.getResidualRoot())
+                .addAllPrimaryRoots(splitResult.getPrimaryRoots())
+                .addAllResidualRoots(splitResult.getResidualRoots())
                 .addChannelSplitsBuilder()
                 .setLastPrimaryElement(index - 1)
                 .setFirstResidualElement(stopIndex);
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 0baf76b..1c8ba32 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -24,16 +24,19 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
 import com.google.auto.service.AutoService;
 import com.google.auto.value.AutoValue;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
+import java.util.ListIterator;
 import java.util.Map;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
+import javax.annotation.Nullable;
 import org.apache.beam.fn.harness.PTransformRunnerFactory.ProgressRequestCallback;
 import org.apache.beam.fn.harness.control.BundleSplitListener;
 import org.apache.beam.fn.harness.data.BeamFnDataClient;
@@ -297,6 +300,12 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   private WindowedValue<InputT> currentElement;
 
   /**
+   * Only valid during {@link #processElementForElementAndRestriction} and {@link
+   * #processElementForSizedElementAndRestriction}.
+   */
+  private ListIterator<BoundedWindow> currentWindowIterator;
+
+  /**
    * Only valid during {@link #processElementForPairWithRestriction}, {@link
    * #processElementForSplitRestriction}, {@link #processElementForElementAndRestriction} and {@link
    * #processElementForSizedElementAndRestriction}, null otherwise.
@@ -577,20 +586,83 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
     switch (pTransform.getSpec().getUrn()) {
       case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN:
         this.convertSplitResultToWindowedSplitResult =
-            (splitResult, watermarkEstimatorState) ->
-                WindowedSplitResult.forRoots(
-                    currentElement.withValue(
-                        KV.of(
-                            currentElement.getValue(),
-                            KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState))),
-                    currentElement.withValue(
-                        KV.of(
-                            currentElement.getValue(),
-                            KV.of(splitResult.getResidual(), watermarkEstimatorState))));
+            (splitResult, watermarkEstimatorState) -> {
+              List<BoundedWindow> primaryFullyProcessedWindows =
+                  ImmutableList.copyOf(
+                      Iterables.limit(
+                          currentElement.getWindows(), currentWindowIterator.previousIndex()));
+              // Advances the iterator consuming the remaining windows.
+              List<BoundedWindow> residualUnprocessedWindows =
+                  ImmutableList.copyOf(currentWindowIterator);
+              return WindowedSplitResult.forRoots(
+                  primaryFullyProcessedWindows.isEmpty()
+                      ? null
+                      : WindowedValue.of(
+                          KV.of(
+                              currentElement.getValue(),
+                              KV.of(currentRestriction, currentWatermarkEstimatorState)),
+                          currentElement.getTimestamp(),
+                          primaryFullyProcessedWindows,
+                          currentElement.getPane()),
+                  WindowedValue.of(
+                      KV.of(
+                          currentElement.getValue(),
+                          KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
+                      currentElement.getTimestamp(),
+                      currentWindow,
+                      currentElement.getPane()),
+                  WindowedValue.of(
+                      KV.of(
+                          currentElement.getValue(),
+                          KV.of(splitResult.getResidual(), watermarkEstimatorState)),
+                      currentElement.getTimestamp(),
+                      currentWindow,
+                      currentElement.getPane()),
+                  residualUnprocessedWindows.isEmpty()
+                      ? null
+                      : WindowedValue.of(
+                          KV.of(
+                              currentElement.getValue(),
+                              KV.of(currentRestriction, currentWatermarkEstimatorState)),
+                          currentElement.getTimestamp(),
+                          residualUnprocessedWindows,
+                          currentElement.getPane()));
+            };
         break;
       case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN:
         this.convertSplitResultToWindowedSplitResult =
             (splitResult, watermarkEstimatorState) -> {
+              List<BoundedWindow> primaryFullyProcessedWindows =
+                  ImmutableList.copyOf(
+                      Iterables.limit(
+                          currentElement.getWindows(), currentWindowIterator.previousIndex()));
+              // Advances the iterator consuming the remaining windows.
+              List<BoundedWindow> residualUnprocessedWindows =
+                  ImmutableList.copyOf(currentWindowIterator);
+              // If the window has been observed then the splitAndSize method would have already
+              // output sizes for each window separately.
+              //
+              // TODO: Consider using the original size on the element instead of recomputing
+              // this here.
+              double fullSize =
+                  primaryFullyProcessedWindows.isEmpty() && residualUnprocessedWindows.isEmpty()
+                      ? 0
+                      : doFnInvoker.invokeGetSize(
+                          new DelegatingArgumentProvider<InputT, OutputT>(
+                              processContext,
+                              PTransformTranslation
+                                      .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
+                                  + "/GetPrimarySize") {
+                            @Override
+                            public Object restriction() {
+                              return currentRestriction;
+                            }
+
+                            @Override
+                            public RestrictionTracker<?, ?> restrictionTracker() {
+                              return doFnInvoker.invokeNewTracker(this);
+                            }
+                          });
               double primarySize =
                   doFnInvoker.invokeGetSize(
                       new DelegatingArgumentProvider<InputT, OutputT>(
@@ -626,18 +698,46 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
                         }
                       });
               return WindowedSplitResult.forRoots(
-                  currentElement.withValue(
+                  primaryFullyProcessedWindows.isEmpty()
+                      ? null
+                      : WindowedValue.of(
+                          KV.of(
+                              KV.of(
+                                  currentElement.getValue(),
+                                  KV.of(currentRestriction, currentWatermarkEstimatorState)),
+                              fullSize),
+                          currentElement.getTimestamp(),
+                          primaryFullyProcessedWindows,
+                          currentElement.getPane()),
+                  WindowedValue.of(
                       KV.of(
                           KV.of(
                               currentElement.getValue(),
                               KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
-                          primarySize)),
-                  currentElement.withValue(
+                          primarySize),
+                      currentElement.getTimestamp(),
+                      currentWindow,
+                      currentElement.getPane()),
+                  WindowedValue.of(
                       KV.of(
                           KV.of(
                               currentElement.getValue(),
                               KV.of(splitResult.getResidual(), watermarkEstimatorState)),
-                          residualSize)));
+                          residualSize),
+                      currentElement.getTimestamp(),
+                      currentWindow,
+                      currentElement.getPane()),
+                  residualUnprocessedWindows.isEmpty()
+                      ? null
+                      : WindowedValue.of(
+                          KV.of(
+                              KV.of(
+                                  currentElement.getValue(),
+                                  KV.of(currentRestriction, currentWatermarkEstimatorState)),
+                              fullSize),
+                          currentElement.getTimestamp(),
+                          residualUnprocessedWindows,
+                          currentElement.getPane()));
             };
         break;
       default:
@@ -683,7 +783,6 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
         break;
       default:
         // no-op
-
     }
   }
 
@@ -755,12 +854,15 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
         outputTo(
             mainOutputConsumers,
             (WindowedValue)
-                elem.withValue(
+                WindowedValue.of(
                     KV.of(
                         elem.getValue(),
                         KV.of(
                             currentRestriction,
-                            doFnInvoker.invokeGetInitialWatermarkEstimatorState(processContext)))));
+                            doFnInvoker.invokeGetInitialWatermarkEstimatorState(processContext))),
+                    currentElement.getTimestamp(),
+                    currentWindow,
+                    currentElement.getPane()));
       }
     } finally {
       currentElement = null;
@@ -793,13 +895,26 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   @AutoValue
   abstract static class WindowedSplitResult {
     public static WindowedSplitResult forRoots(
-        WindowedValue primaryRoot, WindowedValue residualRoot) {
-      return new AutoValue_FnApiDoFnRunner_WindowedSplitResult(primaryRoot, residualRoot);
+        WindowedValue primaryInFullyProcessedWindowsRoot,
+        WindowedValue primarySplitRoot,
+        WindowedValue residualSplitRoot,
+        WindowedValue residualInUnprocessedWindowsRoot) {
+      return new AutoValue_FnApiDoFnRunner_WindowedSplitResult(
+          primaryInFullyProcessedWindowsRoot,
+          primarySplitRoot,
+          residualSplitRoot,
+          residualInUnprocessedWindowsRoot);
     }
 
-    public abstract WindowedValue getPrimaryRoot();
+    @Nullable
+    public abstract WindowedValue getPrimaryInFullyProcessedWindowsRoot();
 
-    public abstract WindowedValue getResidualRoot();
+    public abstract WindowedValue getPrimarySplitRoot();
+
+    public abstract WindowedValue getResidualSplitRoot();
+
+    @Nullable
+    public abstract WindowedValue getResidualInUnprocessedWindowsRoot();
   }
 
   private void processElementForSizedElementAndRestriction(
@@ -811,13 +926,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
       WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey());
     try {
-      Iterator<BoundedWindow> windowIterator =
-          (Iterator<BoundedWindow>) elem.getWindows().iterator();
-      while (windowIterator.hasNext()) {
+      currentWindowIterator =
+          currentElement.getWindows() instanceof List
+              ? ((List) currentElement.getWindows()).listIterator()
+              : ImmutableList.<BoundedWindow>copyOf(elem.getWindows()).listIterator();
+      while (true) {
         synchronized (splitLock) {
+          if (!currentWindowIterator.hasNext()) {
+            return;
+          }
           currentRestriction = elem.getValue().getValue().getKey();
           currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
-          currentWindow = windowIterator.next();
+          currentWindow = currentWindowIterator.next();
           currentTracker =
               RestrictionTrackers.observe(
                   doFnInvoker.invokeNewTracker(processContext),
@@ -845,21 +965,25 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
         // Attempt to checkpoint the current restriction.
         HandlesSplits.SplitResult splitResult =
             trySplitForElementAndRestriction(0, continuation.resumeDelay());
-        // After the user has chosen to resume processing later, the Runner may have stolen
-        // the remainder of work through a split call so the above trySplit may return null. If so,
-        // the current restriction must be done.
+        /**
+         * After the user has chosen to resume processing later, the above trySplit may still
+         * return null: either the restriction was already done (the user unknowingly claimed the
+         * last element) or the Runner stole the remainder of the work through a split call. In
+         * both cases the current restriction must be done.
+         */
         if (splitResult == null) {
           currentTracker.checkDone();
           continue;
         }
         // Forward the split to the bundle level split listener.
-        splitListener.split(splitResult.getPrimaryRoot(), splitResult.getResidualRoot());
+        splitListener.split(splitResult.getPrimaryRoots(), splitResult.getResidualRoots());
       }
     } finally {
       synchronized (splitLock) {
         currentElement = null;
         currentRestriction = null;
         currentWatermarkEstimatorState = null;
+        currentWindowIterator = null;
         currentWindow = null;
         currentTracker = null;
         currentWatermarkEstimator = null;
@@ -923,20 +1047,59 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
           convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
     }
 
+    List<BundleApplication> primaryRoots = new ArrayList<>();
+    List<DelayedBundleApplication> residualRoots = new ArrayList<>();
+    Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
+    if (windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot() != null) {
+      ByteString.Output primaryInOtherWindowsBytes = ByteString.newOutput();
+      try {
+        fullInputCoder.encode(
+            windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot(),
+            primaryInOtherWindowsBytes);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      BundleApplication.Builder primaryApplicationInOtherWindows =
+          BundleApplication.newBuilder()
+              .setTransformId(pTransformId)
+              .setInputId(mainInputId)
+              .setElement(primaryInOtherWindowsBytes.toByteString());
+      primaryRoots.add(primaryApplicationInOtherWindows.build());
+    }
+    if (windowedSplitResult.getResidualInUnprocessedWindowsRoot() != null) {
+      ByteString.Output bytesOut = ByteString.newOutput();
+      try {
+        fullInputCoder.encode(windowedSplitResult.getResidualInUnprocessedWindowsRoot(), bytesOut);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+      BundleApplication.Builder residualInUnprocessedWindowsRoot =
+          BundleApplication.newBuilder()
+              .setTransformId(pTransformId)
+              .setInputId(mainInputId)
+              .setElement(bytesOut.toByteString());
+      // We don't want to change the output watermarks or set the checkpoint resume time since
+      // those apply to the current window.
+      residualRoots.add(
+          DelayedBundleApplication.newBuilder()
+              .setApplication(residualInUnprocessedWindowsRoot)
+              .build());
+    }
+
     ByteString.Output primaryBytes = ByteString.newOutput();
     ByteString.Output residualBytes = ByteString.newOutput();
     try {
-      Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
-      fullInputCoder.encode(windowedSplitResult.getPrimaryRoot(), primaryBytes);
-      fullInputCoder.encode(windowedSplitResult.getResidualRoot(), residualBytes);
+      fullInputCoder.encode(windowedSplitResult.getPrimarySplitRoot(), primaryBytes);
+      fullInputCoder.encode(windowedSplitResult.getResidualSplitRoot(), residualBytes);
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
-    BundleApplication.Builder primaryApplication =
+    primaryRoots.add(
         BundleApplication.newBuilder()
             .setTransformId(pTransformId)
             .setInputId(mainInputId)
-            .setElement(primaryBytes.toByteString());
+            .setElement(primaryBytes.toByteString())
+            .build());
     BundleApplication.Builder residualApplication =
         BundleApplication.newBuilder()
             .setTransformId(pTransformId)
@@ -953,12 +1116,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
                 .build());
       }
     }
-    return HandlesSplits.SplitResult.of(
-        primaryApplication.build(),
+    residualRoots.add(
         DelayedBundleApplication.newBuilder()
-            .setApplication(residualApplication.build())
+            .setApplication(residualApplication)
             .setRequestedTimeDelay(Durations.fromMillis(resumeDelay.getMillis()))
             .build());
+
+    return HandlesSplits.SplitResult.of(primaryRoots, residualRoots);
   }
 
   private <K> void processTimer(
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
index 63b5868..54e1983 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
@@ -18,7 +18,9 @@
 package org.apache.beam.fn.harness;
 
 import com.google.auto.value.AutoValue;
+import java.util.List;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 
 /**
@@ -36,12 +38,13 @@ public interface HandlesSplits {
   @AutoValue
   abstract class SplitResult {
     public static SplitResult of(
-        BeamFnApi.BundleApplication primaryRoot, BeamFnApi.DelayedBundleApplication residualRoot) {
-      return new AutoValue_HandlesSplits_SplitResult(primaryRoot, residualRoot);
+        List<BundleApplication> primaryRoots,
+        List<BeamFnApi.DelayedBundleApplication> residualRoots) {
+      return new AutoValue_HandlesSplits_SplitResult(primaryRoots, residualRoots);
     }
 
-    public abstract BeamFnApi.BundleApplication getPrimaryRoot();
+    public abstract List<BeamFnApi.BundleApplication> getPrimaryRoots();
 
-    public abstract BeamFnApi.DelayedBundleApplication getResidualRoot();
+    public abstract List<BeamFnApi.DelayedBundleApplication> getResidualRoots();
   }
 }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
index 5557be8..f897298 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -38,21 +38,21 @@ public interface BundleSplitListener {
    * are a decomposition of work that has been given away by the bundle, so the runner must delegate
    * it for someone else to execute.
    */
-  void split(BundleApplication primaryRoot, DelayedBundleApplication residualRoots);
+  void split(List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots);
 
   /** A {@link BundleSplitListener} which gathers all splits produced and stores them in memory. */
   @AutoValue
   @NotThreadSafe
   abstract class InMemory implements BundleSplitListener {
     public static InMemory create() {
-      return new AutoValue_BundleSplitListener_InMemory(
-          new ArrayList<BundleApplication>(), new ArrayList<DelayedBundleApplication>());
+      return new AutoValue_BundleSplitListener_InMemory(new ArrayList<>(), new ArrayList<>());
     }
 
     @Override
-    public void split(BundleApplication primaryRoot, DelayedBundleApplication residualRoot) {
-      getPrimaryRoots().add(primaryRoot);
-      getResidualRoots().add(residualRoot);
+    public void split(
+        List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots) {
+      getPrimaryRoots().addAll(primaryRoots);
+      getResidualRoots().addAll(residualRoots);
     }
 
     public void clear() {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index 9293d01..18994e7 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -664,15 +664,17 @@ public class BeamFnDataReadRunnerTest {
     @Override
     public SplitResult trySplit(double fractionOfRemainder) {
       return SplitResult.of(
-          BundleApplication.newBuilder()
-              .setInputId(String.format("primary%.1f", fractionOfRemainder))
-              .build(),
-          DelayedBundleApplication.newBuilder()
-              .setApplication(
-                  BundleApplication.newBuilder()
-                      .setInputId(String.format("residual%.1f", 1 - fractionOfRemainder))
-                      .build())
-              .build());
+          Collections.singletonList(
+              BundleApplication.newBuilder()
+                  .setInputId(String.format("primary%.1f", fractionOfRemainder))
+                  .build()),
+          Collections.singletonList(
+              DelayedBundleApplication.newBuilder()
+                  .setApplication(
+                      BundleApplication.newBuilder()
+                          .setInputId(String.format("residual%.1f", 1 - fractionOfRemainder))
+                          .build())
+                  .build()));
     }
   }
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index fb00550..f9eec08 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -114,6 +114,7 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -125,6 +126,7 @@ import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.hamcrest.collection.IsMapContaining;
@@ -595,8 +597,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
     // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
     FnDataReceiver<WindowedValue<?>> mainInput =
         consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInWindow("X", windowA));
-    mainInput.accept(valueInWindow("Y", windowB));
+    mainInput.accept(valueInWindows("X", windowA));
+    mainInput.accept(valueInWindows("Y", windowB));
     assertThat(mainOutputValues, hasSize(2));
     assertThat(
         mainOutputValues.get(0).getValue(),
@@ -713,8 +715,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
     // Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
     FnDataReceiver<WindowedValue<?>> mainInput =
         consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInWindow("X", windowA));
-    mainInput.accept(valueInWindow("Y", windowB));
+    mainInput.accept(valueInWindows("X", windowA));
+    mainInput.accept(valueInWindows("Y", windowB));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
@@ -1007,8 +1009,13 @@ public class FnApiDoFnRunnerTest implements Serializable {
         PaneInfo.NO_FIRING);
   }
 
-  private <T> WindowedValue<T> valueInWindow(T value, BoundedWindow window) {
-    return WindowedValue.of(value, window.maxTimestamp(), window, PaneInfo.NO_FIRING);
+  private <T> WindowedValue<T> valueInWindows(
+      T value, BoundedWindow window, BoundedWindow... windows) {
+    return WindowedValue.of(
+        value,
+        window.maxTimestamp(),
+        ImmutableList.<BoundedWindow>builder().add(window).add(windows).build(),
+        PaneInfo.NO_FIRING);
   }
 
   private static class TestTimerfulDoFn extends DoFn<KV<String, String>, String> {
@@ -1543,8 +1550,9 @@ public class FnApiDoFnRunnerTest implements Serializable {
               timestampedValueInGlobalWindow("7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
               timestampedValueInGlobalWindow("7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3))));
 
-      BundleApplication primaryRoot = trySplitResult.getPrimaryRoot();
-      DelayedBundleApplication residualRoot = trySplitResult.getResidualRoot();
+      BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+      DelayedBundleApplication residualRoot =
+          Iterables.getOnlyElement(trySplitResult.getResidualRoots());
       assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
       assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
       assertEquals(
@@ -1594,6 +1602,393 @@ public class FnApiDoFnRunnerTest implements Serializable {
   }
 
   @Test
+  public void testProcessElementForWindowedSizedElementAndRestriction() throws Exception {
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+    TestSplittableDoFn doFn = new TestSplittableDoFn(singletonSideInputView);
+
+    valuePCollection
+        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+        .apply(TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+    RunnerApi.Pipeline pProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+            SplittableParDoExpander.createSizedReplacement());
+    String expandedTransformId =
+        Iterables.find(
+                pProto.getComponents().getTransformsMap().entrySet(),
+                entry ->
+                    entry
+                            .getValue()
+                            .getSpec()
+                            .getUrn()
+                            .equals(
+                                PTransformTranslation
+                                    .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+            .getKey();
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+    String inputPCollectionId =
+        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    RunnerApi.PCollection inputPCollection =
+        pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+    RehydratedComponents rehydratedComponents =
+        RehydratedComponents.forComponents(pProto.getComponents());
+    Coder<WindowedValue> inputCoder =
+        WindowedValue.getFullCoder(
+            CoderTranslation.fromProto(
+                pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+                rehydratedComponents,
+                TranslationContext.DEFAULT),
+            (Coder)
+                CoderTranslation.fromProto(
+                    pProto
+                        .getComponents()
+                        .getCodersOrThrow(
+                            pProto
+                                .getComponents()
+                                .getWindowingStrategiesOrThrow(
+                                    inputPCollection.getWindowingStrategyId())
+                                .getWindowCoderId()),
+                    rehydratedComponents,
+                    TranslationContext.DEFAULT));
+    String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+    ImmutableMap<StateKey, ByteString> stateData =
+        ImmutableMap.of(
+            multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
+            encode("8"));
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+    List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    consumers.register(
+        outputPCollectionId,
+        TEST_TRANSFORM_ID,
+        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+    List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
+    BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+
+    new FnApiDoFnRunner.Factory<>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            null /* beamFnDataClient */,
+            fakeClient,
+            null /* beamFnTimerClient */,
+            TEST_TRANSFORM_ID,
+            pTransform,
+            Suppliers.ofInstance("57L")::get,
+            pProto.getComponents().getPcollectionsMap(),
+            pProto.getComponents().getCodersMap(),
+            pProto.getComponents().getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            progressRequestCallbacks::add,
+            splitListener,
+            null /* bundleFinalizer */);
+
+    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        consumers.getMultiplexingConsumer(inputPCollectionId);
+    assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+    BoundedWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+    BoundedWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+    {
+      // Check that before processing an element we don't report progress
+      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+      WindowedValue<?> firstValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+              window1,
+              window2);
+      mainInput.accept(firstValue);
+      // Check that after processing an element we don't report progress
+      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+      // Since the side input upperBound is 8 we will process 5, 6, and 7 then checkpoint.
+      // We expect that the watermark advances to MIN + 7 and that the primary represents [5, 8)
+      // with the original watermark while the residual represents [8, 10) with the new MIN + 7
+      // watermark.
+      //
+      // Since we were on the first window, we expect only a single primary root and two residual
+      // roots (the split + the unprocessed window).
+      BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+      assertEquals(2, splitListener.getResidualRoots().size());
+      DelayedBundleApplication residualRoot = splitListener.getResidualRoots().get(1);
+      DelayedBundleApplication residualRootForUnprocessedWindows =
+          splitListener.getResidualRoots().get(0);
+      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+      assertEquals(
+          ParDoTranslation.getMainInputName(pTransform),
+          residualRoot.getApplication().getInputId());
+      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+      assertEquals(
+          ImmutableMap.of(
+              "output",
+              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                  .build()),
+          residualRoot.getApplication().getOutputWatermarksMap());
+      assertEquals(
+          org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+              .setSeconds(54)
+              .setNanos(321000000)
+              .build(),
+          residualRoot.getRequestedTimeDelay());
+      assertEquals(
+          ParDoTranslation.getMainInputName(pTransform),
+          residualRootForUnprocessedWindows.getApplication().getInputId());
+      assertEquals(
+          TEST_TRANSFORM_ID, residualRootForUnprocessedWindows.getApplication().getTransformId());
+      assertEquals(
+          residualRootForUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
+          residualRootForUnprocessedWindows.getRequestedTimeDelay());
+      assertTrue(
+          residualRootForUnprocessedWindows.getApplication().getOutputWatermarksMap().isEmpty());
+
+      assertEquals(
+          decode(inputCoder, primaryRoot.getElement()),
+          WindowedValue.of(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)), 3.0),
+              firstValue.getTimestamp(),
+              window1,
+              firstValue.getPane()));
+      assertEquals(
+          decode(inputCoder, residualRoot.getApplication().getElement()),
+          WindowedValue.of(
+              KV.of(
+                  KV.of(
+                      "5", KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+                  2.0),
+              firstValue.getTimestamp(),
+              window1,
+              firstValue.getPane()));
+      assertEquals(
+          decode(inputCoder, residualRootForUnprocessedWindows.getApplication().getElement()),
+          WindowedValue.of(
+              KV.of(
+                  KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+              firstValue.getTimestamp(),
+              window2,
+              firstValue.getPane()));
+      splitListener.clear();
+
+      // Check that before processing an element we don't report progress
+      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+      WindowedValue<?> secondValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+              window1,
+              window2);
+      mainInput.accept(secondValue);
+      // Check that after processing an element we don't report progress
+      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  "5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5), window1, firstValue.getPane()),
+              WindowedValue.of(
+                  "5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6), window1, firstValue.getPane()),
+              WindowedValue.of(
+                  "5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7), window1, firstValue.getPane()),
+              WindowedValue.of(
+                  "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, firstValue.getPane()),
+              WindowedValue.of(
+                  "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, firstValue.getPane()),
+              WindowedValue.of(
+                  "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window2, firstValue.getPane()),
+              WindowedValue.of(
+                  "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window2, firstValue.getPane())));
+      assertTrue(splitListener.getPrimaryRoots().isEmpty());
+      assertTrue(splitListener.getResidualRoots().isEmpty());
+      mainOutputValues.clear();
+    }
+
+    {
+      // Setup and launch the trySplit thread.
+      ExecutorService executorService = Executors.newSingleThreadExecutor();
+      Future<HandlesSplits.SplitResult> trySplitFuture =
+          executorService.submit(
+              () -> {
+                try {
+                  doFn.waitForSplitElementToBeProcessed();
+                  // Currently processing "3" out of range [0, 5) elements.
+                  assertEquals(0.6, ((HandlesSplits) mainInput).getProgress(), 0.01);
+
+                  // Check that during processing of an element we report progress
+                  List<MonitoringInfo> mis =
+                      Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
+                  MonitoringInfo.Builder expectedCompleted = MonitoringInfo.newBuilder();
+                  expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
+                  expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                  expectedCompleted.putLabels(
+                      MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+                  expectedCompleted.setPayload(
+                      ByteString.copyFrom(
+                          CoderUtils.encodeToByteArray(
+                              IterableCoder.of(DoubleCoder.of()), Collections.singletonList(3.0))));
+                  MonitoringInfo.Builder expectedRemaining = MonitoringInfo.newBuilder();
+                  expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
+                  expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+                  expectedRemaining.putLabels(
+                      MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+                  expectedRemaining.setPayload(
+                      ByteString.copyFrom(
+                          CoderUtils.encodeToByteArray(
+                              IterableCoder.of(DoubleCoder.of()), Collections.singletonList(2.0))));
+                  assertThat(
+                      mis,
+                      containsInAnyOrder(expectedCompleted.build(), expectedRemaining.build()));
+
+                  return ((HandlesSplits) mainInput).trySplit(0);
+                } finally {
+                  doFn.releaseWaitingProcessElementThread();
+                }
+              });
+
+      // Check that before processing an element we don't report progress
+      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+      WindowedValue<?> splitValue =
+          valueInWindows(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+              window1,
+              window2);
+      mainInput.accept(splitValue);
+      HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+
+      // Check that after processing an element we don't report progress
+      assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+      // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split on the first window.
+      // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
+      // has yet to be invoked for the split element and that the primary represents [0, 4) with
+      // the original watermark while the residual represents [4, 5) with the new MIN + 2 watermark.
+      //
+      // We expect to see none of the output for the second window.
+      assertThat(
+          mainOutputValues,
+          contains(
+              WindowedValue.of(
+                  "7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, splitValue.getPane()),
+              WindowedValue.of(
+                  "7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, splitValue.getPane()),
+              WindowedValue.of(
+                  "7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2), window1, splitValue.getPane()),
+              WindowedValue.of(
+                  "7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3), window1, splitValue.getPane())));
+
+      BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+      assertEquals(2, trySplitResult.getResidualRoots().size());
+      DelayedBundleApplication residualRoot = trySplitResult.getResidualRoots().get(1);
+      DelayedBundleApplication residualRootInUnprocessedWindows =
+          trySplitResult.getResidualRoots().get(0);
+      assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+      assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+      assertEquals(
+          ParDoTranslation.getMainInputName(pTransform),
+          residualRoot.getApplication().getInputId());
+      assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+      assertEquals(
+          TEST_TRANSFORM_ID, residualRootInUnprocessedWindows.getApplication().getTransformId());
+      assertEquals(
+          residualRootInUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
+          residualRootInUnprocessedWindows.getRequestedTimeDelay());
+      assertTrue(
+          residualRootInUnprocessedWindows.getApplication().getOutputWatermarksMap().isEmpty());
+      assertEquals(
+          valueInWindows(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), 4.0),
+              window1),
+          inputCoder.decode(primaryRoot.getElement().newInput()));
+      assertEquals(
+          valueInWindows(
+              KV.of(
+                  KV.of(
+                      "7", KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
+                  1.0),
+              window1),
+          inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+      Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
+      assertEquals(
+          ImmutableMap.of(
+              "output",
+              org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+                  .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+                  .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+                  .build()),
+          residualRoot.getApplication().getOutputWatermarksMap());
+      assertEquals(
+          WindowedValue.of(
+              KV.of(
+                  KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+              splitValue.getTimestamp(),
+              window2,
+              splitValue.getPane()),
+          inputCoder.decode(
+              residualRootInUnprocessedWindows.getApplication().getElement().newInput()));
+
+      // We expect 0 resume delay.
+      assertEquals(
+          residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+          residualRoot.getRequestedTimeDelay());
+      // We don't expect the outputs to go to the SDK initiated checkpointing listener.
+      assertTrue(splitListener.getPrimaryRoots().isEmpty());
+      assertTrue(splitListener.getResidualRoots().isEmpty());
+      mainOutputValues.clear();
+      executorService.shutdown();
+    }
+
+    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+    assertThat(mainOutputValues, empty());
+
+    Iterables.getOnlyElement(teardownFunctions).run();
+    assertThat(mainOutputValues, empty());
+
+    // Assert that state data did not change
+    assertEquals(stateData, fakeClient.getData());
+  }
+
+  private static <T> T decode(Coder<T> coder, ByteString value) {
+    try {
+      return coder.decode(value.newInput());
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @Test
   public void testProcessElementForPairWithRestriction() throws Exception {
     Pipeline p = Pipeline.create();
     PCollection<String> valuePCollection = p.apply(Create.of("unused"));
@@ -1687,6 +2082,121 @@ public class FnApiDoFnRunnerTest implements Serializable {
   }
 
   @Test
+  public void testProcessElementForWindowedPairWithRestriction() throws Exception {
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+    valuePCollection
+        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+        .apply(
+            TEST_TRANSFORM_ID,
+            ParDo.of(new TestSplittableDoFn(singletonSideInputView))
+                .withSideInputs(singletonSideInputView));
+
+    RunnerApi.Pipeline pProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+            SplittableParDoExpander.createSizedReplacement());
+    String expandedTransformId =
+        Iterables.find(
+                pProto.getComponents().getTransformsMap().entrySet(),
+                entry ->
+                    entry
+                            .getValue()
+                            .getSpec()
+                            .getUrn()
+                            .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+            .getKey();
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+    String inputPCollectionId =
+        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+    List<WindowedValue<KV<String, OffsetRange>>> mainOutputValues = new ArrayList<>();
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.Factory<>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            null /* beamFnDataClient */,
+            fakeClient,
+            null /* beamFnTimerClient */,
+            TEST_TRANSFORM_ID,
+            pTransform,
+            Suppliers.ofInstance("57L")::get,
+            pProto.getComponents().getPcollectionsMap(),
+            pProto.getComponents().getCodersMap(),
+            pProto.getComponents().getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            null /* addProgressRequestCallback */,
+            null /* bundleSplitListener */,
+            null /* bundleFinalizer */);
+
+    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        consumers.getMultiplexingConsumer(inputPCollectionId);
+    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+    WindowedValue<?> firstValue = valueInWindows("5", window1, window2);
+    WindowedValue<?> secondValue = valueInWindows("2", window1, window2);
+    mainInput.accept(firstValue);
+    mainInput.accept(secondValue);
+    assertThat(
+        mainOutputValues,
+        contains(
+            WindowedValue.of(
+                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                secondValue.getTimestamp(),
+                window1,
+                secondValue.getPane()),
+            WindowedValue.of(
+                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                secondValue.getTimestamp(),
+                window2,
+                secondValue.getPane())));
+    mainOutputValues.clear();
+
+    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+    assertThat(mainOutputValues, empty());
+
+    Iterables.getOnlyElement(teardownFunctions).run();
+    assertThat(mainOutputValues, empty());
+  }
+
+  @Test
   public void testProcessElementForSplitAndSizeRestriction() throws Exception {
     Pipeline p = Pipeline.create();
     PCollection<String> valuePCollection = p.apply(Create.of("unused"));
@@ -1795,4 +2305,164 @@ public class FnApiDoFnRunnerTest implements Serializable {
     Iterables.getOnlyElement(teardownFunctions).run();
     assertThat(mainOutputValues, empty());
   }
+
+  @Test
+  public void testProcessElementForWindowedSplitAndSizeRestriction() throws Exception {
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+    valuePCollection
+        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+        .apply(
+            TEST_TRANSFORM_ID,
+            ParDo.of(new TestSplittableDoFn(singletonSideInputView))
+                .withSideInputs(singletonSideInputView));
+
+    RunnerApi.Pipeline pProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+            SplittableParDoExpander.createSizedReplacement());
+    String expandedTransformId =
+        Iterables.find(
+                pProto.getComponents().getTransformsMap().entrySet(),
+                entry ->
+                    entry
+                            .getValue()
+                            .getSpec()
+                            .getUrn()
+                            .equals(
+                                PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
+                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+            .getKey();
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+    String inputPCollectionId =
+        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+    List<WindowedValue<KV<KV<String, OffsetRange>, Double>>> mainOutputValues = new ArrayList<>();
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.Factory<>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            null /* beamFnDataClient */,
+            fakeClient,
+            null /* beamFnTimerClient */,
+            TEST_TRANSFORM_ID,
+            pTransform,
+            Suppliers.ofInstance("57L")::get,
+            pProto.getComponents().getPcollectionsMap(),
+            pProto.getComponents().getCodersMap(),
+            pProto.getComponents().getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            null /* addProgressRequestCallback */,
+            null /* bundleSplitListener */,
+            null /* bundleFinalizer */);
+
+    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        consumers.getMultiplexingConsumer(inputPCollectionId);
+    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+    WindowedValue<?> firstValue =
+        valueInWindows(
+            KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+            window1,
+            window2);
+    WindowedValue<?> secondValue =
+        valueInWindows(
+            KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+            window1,
+            window2);
+    mainInput.accept(firstValue);
+    mainInput.accept(secondValue);
+    assertThat(
+        mainOutputValues,
+        contains(
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    3.0),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    3.0),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane())));
+    mainOutputValues.clear();
+
+    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+    assertThat(mainOutputValues, empty());
+
+    Iterables.getOnlyElement(teardownFunctions).run();
+    assertThat(mainOutputValues, empty());
+  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
index 340bac2a..5a48efe 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
@@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.empty;
 import static org.junit.Assert.assertThat;
 
+import java.util.Collections;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
 import org.junit.Test;
@@ -46,8 +47,10 @@ public class BundleSplitListenerTest {
   @Test
   public void testInMemory() {
     BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
-    splitListener.split(TEST_PRIMARY_1, TEST_RESIDUAL_1);
-    splitListener.split(TEST_PRIMARY_2, TEST_RESIDUAL_2);
+    splitListener.split(
+        Collections.singletonList(TEST_PRIMARY_1), Collections.singletonList(TEST_RESIDUAL_1));
+    splitListener.split(
+        Collections.singletonList(TEST_PRIMARY_2), Collections.singletonList(TEST_RESIDUAL_2));
     assertThat(splitListener.getPrimaryRoots(), contains(TEST_PRIMARY_1, TEST_PRIMARY_2));
     assertThat(splitListener.getResidualRoots(), contains(TEST_RESIDUAL_1, TEST_RESIDUAL_2));
     splitListener.clear();