You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bo...@apache.org on 2020/05/13 23:53:19 UTC

[beam] branch master updated: [BEAM-9935] Respect allowed split points in Java

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 216dbe6  [BEAM-9935] Respect allowed split points in Java
     new 644b219  Merge pull request #11688 from boyuanzz/java_split
216dbe6 is described below

commit 216dbe671f1203f73861317c8dce7f3a7f5c0e8e
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Mon May 11 20:19:38 2020 -0700

    [BEAM-9935] Respect allowed split points in Java
---
 .../beam/fn/harness/BeamFnDataReadRunner.java      |  45 +-
 .../beam/fn/harness/BeamFnDataReadRunnerTest.java  | 929 +++++++++++++--------
 .../runners/worker/bundle_processor_test.py        |   2 +-
 3 files changed, 627 insertions(+), 349 deletions(-)

diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index 33e4493..eedb42c 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -17,11 +17,13 @@
  */
 package org.apache.beam.fn.harness;
 
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.getOnlyElement;
 
 import com.google.auto.service.AutoService;
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
@@ -57,7 +59,6 @@ import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -211,6 +212,7 @@ public class BeamFnDataReadRunner<OutputT> {
     }
 
     long totalBufferSize = desiredSplit.getEstimatedInputElements();
+    List<Long> allowedSplitPoints = new ArrayList<>(desiredSplit.getAllowedSplitPointsList());
 
     HandlesSplits splittingConsumer = null;
     if (consumer instanceof HandlesSplits) {
@@ -249,10 +251,6 @@ public class BeamFnDataReadRunner<OutputT> {
         }
       }
 
-      checkArgument(
-          desiredSplit.getAllowedSplitPointsList().isEmpty(),
-          "TODO: BEAM-3836, support split point restrictions.");
-
       // Now figure out where to split.
       //
       // The units here (except for keepOfElementRemainder) are all in terms of number or
@@ -269,7 +267,10 @@ public class BeamFnDataReadRunner<OutputT> {
         // See if the amount we need to keep falls within the current element's remainder and if
         // so, attempt to split it.
         double keepOfElementRemainder = keep / (1 - currentElementProgress);
-        if (keepOfElementRemainder < 1) {
+        // If both index and index + 1 are allowed split points, we can split within the element at index.
+        if (keepOfElementRemainder < 1
+            && isValidSplitPoint(allowedSplitPoints, index)
+            && isValidSplitPoint(allowedSplitPoints, index + 1)) {
           SplitResult splitResult =
               splittingConsumer != null ? splittingConsumer.trySplit(keepOfElementRemainder) : null;
           if (splitResult != null) {
@@ -285,10 +286,28 @@ public class BeamFnDataReadRunner<OutputT> {
         }
       }
 
-      // Otherwise, split at the closest element boundary.
-      int newStopIndex =
-          Ints.checkedCast(index + Math.max(1, Math.round(currentElementProgress + keep)));
-      if (newStopIndex < stopIndex) {
+      // Otherwise, split at the closest allowed element boundary.
+      long newStopIndex = index + Math.max(1, Math.round(currentElementProgress + keep));
+      if (!isValidSplitPoint(allowedSplitPoints, newStopIndex)) {
+        // Choose the closest allowed split point.
+        Collections.sort(allowedSplitPoints);
+        int closestSplitPointIndex =
+            -(Collections.binarySearch(allowedSplitPoints, newStopIndex) + 1);
+        if (closestSplitPointIndex == 0) {
+          newStopIndex = allowedSplitPoints.get(0);
+        } else if (closestSplitPointIndex == allowedSplitPoints.size()) {
+          newStopIndex = allowedSplitPoints.get(closestSplitPointIndex - 1);
+        } else {
+          long prevPoint = allowedSplitPoints.get(closestSplitPointIndex - 1);
+          long nextPoint = allowedSplitPoints.get(closestSplitPointIndex);
+          if (index < prevPoint && newStopIndex - prevPoint < nextPoint - newStopIndex) {
+            newStopIndex = prevPoint;
+          } else {
+            newStopIndex = nextPoint;
+          }
+        }
+      }
+      if (newStopIndex < stopIndex && newStopIndex > index) {
         stopIndex = newStopIndex;
         response
             .addChannelSplitsBuilder()
@@ -310,4 +329,8 @@ public class BeamFnDataReadRunner<OutputT> {
       index += 1;
     }
   }
+
+  private boolean isValidSplitPoint(List<Long> allowedSplitPoints, long index) {
+    return allowedSplitPoints.isEmpty() || allowedSplitPoints.contains(index);
+  }
 }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index bc7417f..75f3105 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -35,6 +35,7 @@ import static org.mockito.Mockito.when;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.ServiceLoader;
 import java.util.concurrent.Executors;
@@ -76,6 +77,7 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
@@ -83,17 +85,18 @@ import org.hamcrest.collection.IsMapContaining;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.junit.runners.Parameterized;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 /** Tests for {@link BeamFnDataReadRunner}. */
-@RunWith(JUnit4.class)
+@RunWith(Enclosed.class)
 public class BeamFnDataReadRunnerTest {
-
   private static final Coder<String> ELEMENT_CODER = StringUtf8Coder.of();
   private static final String ELEMENT_CODER_SPEC_ID = "string-coder-id";
   private static final Coder<WindowedValue<String>> CODER =
@@ -125,363 +128,558 @@ public class BeamFnDataReadRunnerTest {
 
   private static final String INPUT_TRANSFORM_ID = "1";
 
-  @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
-  @Mock private BeamFnDataClient mockBeamFnDataClient;
-  @Captor private ArgumentCaptor<FnDataReceiver<WindowedValue<String>>> consumerCaptor;
-
-  @Before
-  public void setUp() {
-    MockitoAnnotations.initMocks(this);
-  }
-
-  @Test
-  public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
-    String bundleId = "57";
-
-    List<WindowedValue<String>> outputValues = new ArrayList<>();
-
-    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
-    PCollectionConsumerRegistry consumers =
-        new PCollectionConsumerRegistry(
-            metricsContainerRegistry, mock(ExecutionStateTracker.class));
-    String localOutputId = "outputPC";
-    String pTransformId = "pTransformId";
-    consumers.register(
-        localOutputId,
-        pTransformId,
-        (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) outputValues::add);
-    PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
-    PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
-    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
-
-    RunnerApi.PTransform pTransform =
-        RemoteGrpcPortRead.readFromPort(PORT_SPEC, localOutputId).toPTransform();
-
-    new BeamFnDataReadRunner.Factory<String>()
-        .createRunnerForPTransform(
-            PipelineOptionsFactory.create(),
-            mockBeamFnDataClient,
-            null /* beamFnStateClient */,
-            null /* beamFnTimerClient */,
-            pTransformId,
-            pTransform,
-            Suppliers.ofInstance(bundleId)::get,
-            ImmutableMap.of(
-                localOutputId,
-                RunnerApi.PCollection.newBuilder().setCoderId(ELEMENT_CODER_SPEC_ID).build()),
-            COMPONENTS.getCodersMap(),
-            COMPONENTS.getWindowingStrategiesMap(),
-            consumers,
-            startFunctionRegistry,
-            finishFunctionRegistry,
-            teardownFunctions::add,
-            (PTransformRunnerFactory.ProgressRequestCallback callback) -> {},
-            null /* splitListener */,
-            null /* bundleFinalizer */);
+  private static final String PTRANSFORM_ID = "ptransform_id";
 
-    assertThat(teardownFunctions, empty());
+  // Test basic executions of BeamFnDataReadRunner.
+  @RunWith(JUnit4.class)
+  public static class BeamFnDataReadRunnerExecutionTest {
+    @Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
+    @Mock private BeamFnDataClient mockBeamFnDataClient;
+    @Captor private ArgumentCaptor<FnDataReceiver<WindowedValue<String>>> consumerCaptor;
 
-    verifyZeroInteractions(mockBeamFnDataClient);
+    @Before
+    public void setUp() {
+      MockitoAnnotations.initMocks(this);
+    }
 
-    InboundDataClient completionFuture = CompletableFutureInboundDataClient.create();
-    when(mockBeamFnDataClient.receive(any(), any(), any(), any())).thenReturn(completionFuture);
-    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
-    verify(mockBeamFnDataClient)
-        .receive(
-            eq(PORT_SPEC.getApiServiceDescriptor()),
-            eq(LogicalEndpoint.data(bundleId, pTransformId)),
-            eq(CODER),
-            consumerCaptor.capture());
+    @Test
+    public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception {
+      String bundleId = "57";
+
+      List<WindowedValue<String>> outputValues = new ArrayList<>();
+
+      MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+      PCollectionConsumerRegistry consumers =
+          new PCollectionConsumerRegistry(
+              metricsContainerRegistry, mock(ExecutionStateTracker.class));
+      String localOutputId = "outputPC";
+      String pTransformId = "pTransformId";
+      consumers.register(
+          localOutputId,
+          pTransformId,
+          (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) outputValues::add);
+      PTransformFunctionRegistry startFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+      PTransformFunctionRegistry finishFunctionRegistry =
+          new PTransformFunctionRegistry(
+              mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+      List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+      RunnerApi.PTransform pTransform =
+          RemoteGrpcPortRead.readFromPort(PORT_SPEC, localOutputId).toPTransform();
+
+      new BeamFnDataReadRunner.Factory<String>()
+          .createRunnerForPTransform(
+              PipelineOptionsFactory.create(),
+              mockBeamFnDataClient,
+              null /* beamFnStateClient */,
+              null /* beamFnTimerClient */,
+              pTransformId,
+              pTransform,
+              Suppliers.ofInstance(bundleId)::get,
+              ImmutableMap.of(
+                  localOutputId,
+                  RunnerApi.PCollection.newBuilder().setCoderId(ELEMENT_CODER_SPEC_ID).build()),
+              COMPONENTS.getCodersMap(),
+              COMPONENTS.getWindowingStrategiesMap(),
+              consumers,
+              startFunctionRegistry,
+              finishFunctionRegistry,
+              teardownFunctions::add,
+              (PTransformRunnerFactory.ProgressRequestCallback callback) -> {},
+              null /* splitListener */,
+              null /* bundleFinalizer */);
+
+      assertThat(teardownFunctions, empty());
+
+      verifyZeroInteractions(mockBeamFnDataClient);
+
+      InboundDataClient completionFuture = CompletableFutureInboundDataClient.create();
+      when(mockBeamFnDataClient.receive(any(), any(), any(), any())).thenReturn(completionFuture);
+      Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+      verify(mockBeamFnDataClient)
+          .receive(
+              eq(PORT_SPEC.getApiServiceDescriptor()),
+              eq(LogicalEndpoint.data(bundleId, pTransformId)),
+              eq(CODER),
+              consumerCaptor.capture());
+
+      consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue"));
+      assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
+      outputValues.clear();
+
+      assertThat(consumers.keySet(), containsInAnyOrder(localOutputId));
+
+      completionFuture.complete();
+      Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+
+      verifyNoMoreInteractions(mockBeamFnDataClient);
+    }
 
-    consumerCaptor.getValue().accept(valueInGlobalWindow("TestValue"));
-    assertThat(outputValues, contains(valueInGlobalWindow("TestValue")));
-    outputValues.clear();
+    @Test
+    public void testReuseForMultipleBundles() throws Exception {
+      InboundDataClient bundle1Future = CompletableFutureInboundDataClient.create();
+      InboundDataClient bundle2Future = CompletableFutureInboundDataClient.create();
+      when(mockBeamFnDataClient.receive(any(), any(), any(), any()))
+          .thenReturn(bundle1Future)
+          .thenReturn(bundle2Future);
+      List<WindowedValue<String>> values = new ArrayList<>();
+      FnDataReceiver<WindowedValue<String>> consumers = values::add;
+      AtomicReference<String> bundleId = new AtomicReference<>("0");
+      List<PTransformRunnerFactory.ProgressRequestCallback> progressCallbacks = new ArrayList<>();
+      BeamFnDataReadRunner<String> readRunner =
+          new BeamFnDataReadRunner<>(
+              INPUT_TRANSFORM_ID,
+              RemoteGrpcPortRead.readFromPort(PORT_SPEC, "localOutput").toPTransform(),
+              bundleId::get,
+              COMPONENTS.getCodersMap(),
+              mockBeamFnDataClient,
+              (PTransformRunnerFactory.ProgressRequestCallback callback) -> {
+                progressCallbacks.add(callback);
+              },
+              consumers);
+
+      // Process for bundle id 0
+      readRunner.registerInputLocation();
+
+      assertEquals(
+          createReadIndexMonitoringInfoAt(-1),
+          Iterables.getOnlyElement(
+              Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+
+      verify(mockBeamFnDataClient)
+          .receive(
+              eq(PORT_SPEC.getApiServiceDescriptor()),
+              eq(LogicalEndpoint.data(bundleId.get(), INPUT_TRANSFORM_ID)),
+              eq(CODER),
+              consumerCaptor.capture());
+
+      Future<?> future =
+          executor.submit(
+              () -> {
+                // Sleep for some small amount of time simulating the parent blocking
+                Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
+                try {
+                  consumerCaptor.getValue().accept(valueInGlobalWindow("ABC"));
+                  assertEquals(
+                      createReadIndexMonitoringInfoAt(0),
+                      Iterables.getOnlyElement(
+                          Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+                  consumerCaptor.getValue().accept(valueInGlobalWindow("DEF"));
+                  assertEquals(
+                      createReadIndexMonitoringInfoAt(1),
+                      Iterables.getOnlyElement(
+                          Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+                } catch (Exception e) {
+                  bundle1Future.fail(e);
+                } finally {
+                  bundle1Future.complete();
+                }
+              });
+
+      readRunner.blockTillReadFinishes();
+      future.get();
+      assertEquals(
+          createReadIndexMonitoringInfoAt(2),
+          Iterables.getOnlyElement(
+              Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+      assertThat(values, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
+
+      // Process for bundle id 1
+      bundleId.set("1");
+      values.clear();
+      readRunner.registerInputLocation();
+      // Ensure that when we reuse the BeamFnDataReadRunner the read index is reset to -1
+      assertEquals(
+          createReadIndexMonitoringInfoAt(-1),
+          Iterables.getOnlyElement(
+              Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+
+      verify(mockBeamFnDataClient)
+          .receive(
+              eq(PORT_SPEC.getApiServiceDescriptor()),
+              eq(LogicalEndpoint.data(bundleId.get(), INPUT_TRANSFORM_ID)),
+              eq(CODER),
+              consumerCaptor.capture());
+
+      future =
+          executor.submit(
+              () -> {
+                // Sleep for some small amount of time simulating the parent blocking
+                Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
+                try {
+                  consumerCaptor.getValue().accept(valueInGlobalWindow("GHI"));
+                  assertEquals(
+                      createReadIndexMonitoringInfoAt(0),
+                      Iterables.getOnlyElement(
+                          Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+                  consumerCaptor.getValue().accept(valueInGlobalWindow("JKL"));
+                  assertEquals(
+                      createReadIndexMonitoringInfoAt(1),
+                      Iterables.getOnlyElement(
+                          Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+                } catch (Exception e) {
+                  bundle2Future.fail(e);
+                } finally {
+                  bundle2Future.complete();
+                }
+              });
+
+      readRunner.blockTillReadFinishes();
+      future.get();
+      assertEquals(
+          createReadIndexMonitoringInfoAt(2),
+          Iterables.getOnlyElement(
+              Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
+      assertThat(values, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
+
+      verifyNoMoreInteractions(mockBeamFnDataClient);
+    }
 
-    assertThat(consumers.keySet(), containsInAnyOrder(localOutputId));
+    @Test
+    public void testRegistration() {
+      for (Registrar registrar : ServiceLoader.load(Registrar.class)) {
+        if (registrar instanceof BeamFnDataReadRunner.Registrar) {
+          assertThat(
+              registrar.getPTransformRunnerFactories(),
+              IsMapContaining.hasKey(RemoteGrpcPortRead.URN));
+          return;
+        }
+      }
+      fail("Expected registrar not found.");
+    }
 
-    completionFuture.complete();
-    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+    @Test
+    public void testSplittingWhenNoElementsProcessed() throws Exception {
+      List<WindowedValue<String>> outputValues = new ArrayList<>();
+      BeamFnDataReadRunner<String> readRunner =
+          createReadRunner(outputValues::add, PTRANSFORM_ID, mockBeamFnDataClient);
+      readRunner.registerInputLocation();
+      // The split should happen at 5 since allowedSplitPoints is empty.
+      assertEquals(
+          channelSplitResult(5),
+          executeSplit(readRunner, PTRANSFORM_ID, -1L, 0.5, 10, Collections.EMPTY_LIST));
+
+      // Ensure that we process the correct number of elements after splitting.
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("D"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("E"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("F"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("G"));
+      assertThat(
+          outputValues,
+          contains(
+              valueInGlobalWindow("A"),
+              valueInGlobalWindow("B"),
+              valueInGlobalWindow("C"),
+              valueInGlobalWindow("D"),
+              valueInGlobalWindow("E")));
+    }
 
-    verifyNoMoreInteractions(mockBeamFnDataClient);
+    @Test
+    public void testSplittingWhenSomeElementsProcessed() throws Exception {
+      List<WindowedValue<String>> outputValues = new ArrayList<>();
+      BeamFnDataReadRunner<String> readRunner =
+          createReadRunner(outputValues::add, PTRANSFORM_ID, mockBeamFnDataClient);
+      readRunner.registerInputLocation();
+      assertEquals(
+          channelSplitResult(6),
+          executeSplit(readRunner, PTRANSFORM_ID, 1L, 0.5, 10, Collections.EMPTY_LIST));
+
+      // Ensure that we process the correct number of elements after splitting.
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("1"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("2"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("3"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("4"));
+      readRunner.forwardElementToConsumer(valueInGlobalWindow("5"));
+      assertThat(
+          outputValues,
+          contains(
+              valueInGlobalWindow("-1"),
+              valueInGlobalWindow("0"),
+              valueInGlobalWindow("1"),
+              valueInGlobalWindow("2"),
+              valueInGlobalWindow("3"),
+              valueInGlobalWindow("4")));
+    }
   }
 
-  @Test
-  public void testReuseForMultipleBundles() throws Exception {
-    InboundDataClient bundle1Future = CompletableFutureInboundDataClient.create();
-    InboundDataClient bundle2Future = CompletableFutureInboundDataClient.create();
-    when(mockBeamFnDataClient.receive(any(), any(), any(), any()))
-        .thenReturn(bundle1Future)
-        .thenReturn(bundle2Future);
-    List<WindowedValue<String>> values = new ArrayList<>();
-    FnDataReceiver<WindowedValue<String>> consumers = values::add;
-    AtomicReference<String> bundleId = new AtomicReference<>("0");
-    List<PTransformRunnerFactory.ProgressRequestCallback> progressCallbacks = new ArrayList<>();
-    BeamFnDataReadRunner<String> readRunner =
-        new BeamFnDataReadRunner<>(
-            INPUT_TRANSFORM_ID,
-            RemoteGrpcPortRead.readFromPort(PORT_SPEC, "localOutput").toPTransform(),
-            bundleId::get,
-            COMPONENTS.getCodersMap(),
-            mockBeamFnDataClient,
-            (PTransformRunnerFactory.ProgressRequestCallback callback) -> {
-              progressCallbacks.add(callback);
-            },
-            consumers);
-
-    // Process for bundle id 0
-    readRunner.registerInputLocation();
-
-    assertEquals(
-        createReadIndexMonitoringInfoAt(-1),
-        Iterables.getOnlyElement(Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-
-    verify(mockBeamFnDataClient)
-        .receive(
-            eq(PORT_SPEC.getApiServiceDescriptor()),
-            eq(LogicalEndpoint.data(bundleId.get(), INPUT_TRANSFORM_ID)),
-            eq(CODER),
-            consumerCaptor.capture());
-
-    Future<?> future =
-        executor.submit(
-            () -> {
-              // Sleep for some small amount of time simulating the parent blocking
-              Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
-              try {
-                consumerCaptor.getValue().accept(valueInGlobalWindow("ABC"));
-                assertEquals(
-                    createReadIndexMonitoringInfoAt(0),
-                    Iterables.getOnlyElement(
-                        Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-                consumerCaptor.getValue().accept(valueInGlobalWindow("DEF"));
-                assertEquals(
-                    createReadIndexMonitoringInfoAt(1),
-                    Iterables.getOnlyElement(
-                        Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-              } catch (Exception e) {
-                bundle1Future.fail(e);
-              } finally {
-                bundle1Future.complete();
-              }
-            });
-
-    readRunner.blockTillReadFinishes();
-    future.get();
-    assertEquals(
-        createReadIndexMonitoringInfoAt(2),
-        Iterables.getOnlyElement(Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-    assertThat(values, contains(valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF")));
-
-    // Process for bundle id 1
-    bundleId.set("1");
-    values.clear();
-    readRunner.registerInputLocation();
-    // Ensure that when we reuse the BeamFnDataReadRunner the read index is reset to -1
-    assertEquals(
-        createReadIndexMonitoringInfoAt(-1),
-        Iterables.getOnlyElement(Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-
-    verify(mockBeamFnDataClient)
-        .receive(
-            eq(PORT_SPEC.getApiServiceDescriptor()),
-            eq(LogicalEndpoint.data(bundleId.get(), INPUT_TRANSFORM_ID)),
-            eq(CODER),
-            consumerCaptor.capture());
-
-    future =
-        executor.submit(
-            () -> {
-              // Sleep for some small amount of time simulating the parent blocking
-              Uninterruptibles.sleepUninterruptibly(100, TimeUnit.MILLISECONDS);
-              try {
-                consumerCaptor.getValue().accept(valueInGlobalWindow("GHI"));
-                assertEquals(
-                    createReadIndexMonitoringInfoAt(0),
-                    Iterables.getOnlyElement(
-                        Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-                consumerCaptor.getValue().accept(valueInGlobalWindow("JKL"));
-                assertEquals(
-                    createReadIndexMonitoringInfoAt(1),
-                    Iterables.getOnlyElement(
-                        Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-              } catch (Exception e) {
-                bundle2Future.fail(e);
-              } finally {
-                bundle2Future.complete();
-              }
-            });
-
-    readRunner.blockTillReadFinishes();
-    future.get();
-    assertEquals(
-        createReadIndexMonitoringInfoAt(2),
-        Iterables.getOnlyElement(Iterables.getOnlyElement(progressCallbacks).getMonitoringInfos()));
-    assertThat(values, contains(valueInGlobalWindow("GHI"), valueInGlobalWindow("JKL")));
-
-    verifyNoMoreInteractions(mockBeamFnDataClient);
-  }
+  // Test different cases of channel split with empty allowed split points.
+  @RunWith(Parameterized.class)
+  public static class ChannelSplitTest {
+
+    @Parameterized.Parameters
+    public static Iterable<Object[]> data() {
+      return ImmutableList.<Object[]>builder()
+          // Split as close to the beginning as possible.
+          .add(new Object[] {channelSplitResult(1L), 0L, 0, 0, 16L})
+          // The closest split is at 4, even when just above or below it.
+          .add(new Object[] {channelSplitResult(4L), 0L, 0, 0.24, 16L})
+          .add(new Object[] {channelSplitResult(4L), 0L, 0, 0.25, 16L})
+          .add(new Object[] {channelSplitResult(4L), 0L, 0, 0.26, 16L})
+          // Split the *remainder* in half.
+          .add(new Object[] {channelSplitResult(8L), 0L, 0, 0.5, 16L})
+          .add(new Object[] {channelSplitResult(9L), 2, 0, 0.5, 16L})
+          .add(new Object[] {channelSplitResult(11L), 6L, 0, 0.5, 16L})
+          // Progress into the active element influences where the split of the remainder falls.
+          .add(new Object[] {channelSplitResult(1L), 0L, 0.5, 0.25, 4L})
+          .add(new Object[] {channelSplitResult(2L), 0L, 0.9, 0.25, 4L})
+          .add(new Object[] {channelSplitResult(2L), 1L, 0, 0.25, 4L})
+          .add(new Object[] {channelSplitResult(2L), 1L, 0.1, 0.25, 4L})
+          .build();
+    }
 
-  @Test
-  public void testRegistration() {
-    for (Registrar registrar : ServiceLoader.load(Registrar.class)) {
-      if (registrar instanceof BeamFnDataReadRunner.Registrar) {
-        assertThat(
-            registrar.getPTransformRunnerFactories(),
-            IsMapContaining.hasKey(RemoteGrpcPortRead.URN));
-        return;
-      }
+    @Parameterized.Parameter(0)
+    public ProcessBundleSplitResponse expectedResponse;
+
+    @Parameterized.Parameter(1)
+    public long index;
+
+    @Parameterized.Parameter(2)
+    public double elementProgress;
+
+    @Parameterized.Parameter(3)
+    public double fractionOfRemainder;
+
+    @Parameterized.Parameter(4)
+    public long bufferSize;
+
+    @Test
+    public void testChannelSplit() throws Exception {
+      SplittingReceiver splittingReceiver = mock(SplittingReceiver.class);
+      BeamFnDataClient mockBeamFnDataClient = mock(BeamFnDataClient.class);
+      when(splittingReceiver.getProgress()).thenReturn(elementProgress);
+      BeamFnDataReadRunner<String> readRunner =
+          createReadRunner(splittingReceiver, PTRANSFORM_ID, mockBeamFnDataClient);
+      readRunner.registerInputLocation();
+      assertEquals(
+          expectedResponse,
+          executeSplit(
+              readRunner,
+              PTRANSFORM_ID,
+              index,
+              fractionOfRemainder,
+              bufferSize,
+              Collections.EMPTY_LIST));
     }
-    fail("Expected registrar not found.");
   }
 
-  @Test
-  public void testSplittingWhenNoElementsProcessed() throws Exception {
-    List<WindowedValue<String>> outputValues = new ArrayList<>();
-    BeamFnDataReadRunner<String> readRunner = createReadRunner(outputValues::add);
-    readRunner.registerInputLocation();
-
-    ProcessBundleSplitRequest request =
-        ProcessBundleSplitRequest.newBuilder()
-            .putDesiredSplits(
-                "pTransformId",
-                DesiredSplit.newBuilder()
-                    .setEstimatedInputElements(10)
-                    .setFractionOfRemainder(0.5)
-                    .build())
-            .build();
-    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
-    readRunner.trySplit(request, responseBuilder);
+  // Test different cases of channel split with non-empty allowed split points.
+  @RunWith(Parameterized.class)
+  public static class ChannelSplitWithAllowedSplitPointsTest {
+    @Parameterized.Parameters
+    public static Iterable<Object[]> data() {
+      return ImmutableList.<Object[]>builder()
+          // The desired split point is at 4.
+          .add(
+              new Object[] {
+                channelSplitResult(4L), 0L, 0.25, 16L, ImmutableList.of(2L, 3L, 4L, 5L)
+              })
+          // If we can't split at 4, choose the closest possible split point.
+          .add(new Object[] {channelSplitResult(5L), 0L, 0.25, 16L, ImmutableList.of(2L, 3L, 5L)})
+          .add(new Object[] {channelSplitResult(3L), 0L, 0.25, 16L, ImmutableList.of(2L, 3L, 6L)})
+          // Also test the case where all possible split points lie above or below the desired split
+          // point.
+          .add(new Object[] {channelSplitResult(5L), 0L, 0.25, 16L, ImmutableList.of(5L, 6L, 7L)})
+          .add(new Object[] {channelSplitResult(3L), 0L, 0.25, 16L, ImmutableList.of(1L, 2L, 3L)})
+          // We have progressed beyond all possible split points, so can't split.
+          .add(
+              new Object[] {
+                ProcessBundleSplitResponse.getDefaultInstance(),
+                5L,
+                0.25,
+                16L,
+                ImmutableList.of(1L, 2L, 3L)
+              })
+          .build();
+    }
 
-    ProcessBundleSplitResponse expected =
-        ProcessBundleSplitResponse.newBuilder()
-            .addChannelSplits(
-                ChannelSplit.newBuilder()
-                    .setLastPrimaryElement(4)
-                    .setFirstResidualElement(5)
-                    .build())
-            .build();
-    assertEquals(expected, responseBuilder.build());
-
-    // Ensure that we process the correct number of elements after splitting.
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("D"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("E"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("F"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("G"));
-    assertThat(
-        outputValues,
-        contains(
-            valueInGlobalWindow("A"),
-            valueInGlobalWindow("B"),
-            valueInGlobalWindow("C"),
-            valueInGlobalWindow("D"),
-            valueInGlobalWindow("E")));
+    @Parameterized.Parameter(0)
+    public ProcessBundleSplitResponse expectedResponse;
+
+    @Parameterized.Parameter(1)
+    public long index;
+
+    @Parameterized.Parameter(2)
+    public double fractionOfRemainder;
+
+    @Parameterized.Parameter(3)
+    public long bufferSize;
+
+    @Parameterized.Parameter(4)
+    public List<Long> allowedSplitPoints;
+
+    @Test
+    public void testChannelSplittingWithAllowedSplitPoints() throws Exception {
+      List<WindowedValue<String>> outputValues = new ArrayList<>();
+      BeamFnDataClient mockBeamFnDataClient = mock(BeamFnDataClient.class);
+      BeamFnDataReadRunner<String> readRunner =
+          createReadRunner(outputValues::add, PTRANSFORM_ID, mockBeamFnDataClient);
+      readRunner.registerInputLocation();
+      assertEquals(
+          expectedResponse,
+          executeSplit(
+              readRunner,
+              PTRANSFORM_ID,
+              index,
+              fractionOfRemainder,
+              bufferSize,
+              allowedSplitPoints));
+    }
   }
 
-  @Test
-  public void testSplittingWhenSomeElementsProcessed() throws Exception {
-    List<WindowedValue<String>> outputValues = new ArrayList<>();
-    BeamFnDataReadRunner<String> readRunner = createReadRunner(outputValues::add);
-    readRunner.registerInputLocation();
-
-    ProcessBundleSplitRequest request =
-        ProcessBundleSplitRequest.newBuilder()
-            .putDesiredSplits(
-                "pTransformId",
-                DesiredSplit.newBuilder()
-                    .setEstimatedInputElements(10)
-                    .setFractionOfRemainder(0.5)
-                    .build())
-            .build();
-    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
-
-    // Process 2 elements then split
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
-    readRunner.trySplit(request, responseBuilder);
+  // Test different cases of element split with empty allowed split points.
+  @RunWith(Parameterized.class)
+  public static class ElementSplitTest {
+    @Parameterized.Parameters
+    public static Iterable<Object[]> data() {
+      return ImmutableList.<Object[]>builder()
+          // Split between future elements at element boundaries.
+          .add(new Object[] {channelSplitResult(2L), 0L, 0, 0.51, 4L})
+          .add(new Object[] {channelSplitResult(2L), 0L, 0, 0.49, 4L})
+          .add(new Object[] {channelSplitResult(1L), 0L, 0, 0.26, 4L})
+          .add(new Object[] {channelSplitResult(1L), 0L, 0, 0.25, 4L})
+          // If the split falls inside the first, splittable element, split there.
+          .add(new Object[] {elementSplitResult(0L, 0.8), 0L, 0, 0.2, 4L})
+          // The choice of split depends on the progress into the first element.
+          .add(new Object[] {elementSplitResult(0L, 0.5), 0L, 0, 0.125, 4L})
+          // Here we are far enough into the first element that splitting at 0.2 of the remainder
+          // falls outside the first element.
+          .add(new Object[] {channelSplitResult(1L), 0L, 0.5, 0.2, 4L})
+          // Verify the above logic when we are partially through the stream.
+          .add(new Object[] {channelSplitResult(3L), 2L, 0, 0.6, 4L})
+          .add(new Object[] {channelSplitResult(4L), 2L, 0.9, 0.6, 4L})
+          .add(new Object[] {elementSplitResult(2L, 0.6), 2L, 0.5, 0.2, 4L})
+          .build();
+    }
 
-    ProcessBundleSplitResponse expected =
-        ProcessBundleSplitResponse.newBuilder()
-            .addChannelSplits(
-                ChannelSplit.newBuilder()
-                    .setLastPrimaryElement(5)
-                    .setFirstResidualElement(6)
-                    .build())
-            .build();
-    assertEquals(expected, responseBuilder.build());
-
-    // Ensure that we process the correct number of elements after splitting.
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("D"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("E"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("F"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("G"));
-    assertThat(
-        outputValues,
-        contains(
-            valueInGlobalWindow("A"),
-            valueInGlobalWindow("B"),
-            valueInGlobalWindow("C"),
-            valueInGlobalWindow("D"),
-            valueInGlobalWindow("E"),
-            valueInGlobalWindow("F")));
+    @Parameterized.Parameter(0)
+    public ProcessBundleSplitResponse expectedResponse;
+
+    @Parameterized.Parameter(1)
+    public long index;
+
+    @Parameterized.Parameter(2)
+    public double elementProgress;
+
+    @Parameterized.Parameter(3)
+    public double fractionOfRemainder;
+
+    @Parameterized.Parameter(4)
+    public long bufferSize;
+
+    @Test
+    public void testElementSplit() throws Exception {
+      SplittingReceiver splittingReceiver = mock(SplittingReceiver.class);
+      BeamFnDataClient mockBeamFnDataClient = mock(BeamFnDataClient.class);
+      when(splittingReceiver.getProgress()).thenReturn(elementProgress);
+      when(splittingReceiver.trySplit(anyDouble())).thenCallRealMethod();
+      BeamFnDataReadRunner<String> readRunner =
+          createReadRunner(splittingReceiver, PTRANSFORM_ID, mockBeamFnDataClient);
+      readRunner.registerInputLocation();
+
+      assertEquals(
+          expectedResponse,
+          executeSplit(
+              readRunner,
+              PTRANSFORM_ID,
+              index,
+              fractionOfRemainder,
+              bufferSize,
+              Collections.EMPTY_LIST));
+    }
   }
 
-  @Test
-  public void testSplittingDownstreamReceiver() throws Exception {
-    SplitResult splitResult =
-        SplitResult.of(
-            BundleApplication.newBuilder().setInputId("primary").build(),
-            DelayedBundleApplication.newBuilder()
-                .setApplication(BundleApplication.newBuilder().setInputId("residual").build())
-                .build());
-    SplittingReceiver splittingReceiver = mock(SplittingReceiver.class);
-    when(splittingReceiver.getProgress()).thenReturn(0.3);
-    when(splittingReceiver.trySplit(anyDouble())).thenReturn(splitResult);
-    BeamFnDataReadRunner<String> readRunner = createReadRunner(splittingReceiver);
-    readRunner.registerInputLocation();
-
-    ProcessBundleSplitRequest request =
-        ProcessBundleSplitRequest.newBuilder()
-            .putDesiredSplits(
-                "pTransformId",
-                DesiredSplit.newBuilder()
-                    .setEstimatedInputElements(10)
-                    .setFractionOfRemainder(0.05)
-                    .build())
-            .build();
-    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
-
-    // We will be "processing" the 'C' element, aka 2nd index
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("A"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("B"));
-    readRunner.forwardElementToConsumer(valueInGlobalWindow("C"));
-    readRunner.trySplit(request, responseBuilder);
+  // Test different cases of element split with non-empty allowed split points.
+  @RunWith(Parameterized.class)
+  public static class ElementSplitWithAllowedSplitPointsTest {
+    @Parameterized.Parameters
+    public static Iterable<Object[]> data() {
+      return ImmutableList.<Object[]>builder()
+          // This is where we would like to split, when all split points are available.
+          .add(
+              new Object[] {
+                elementSplitResult(2L, 0.6), 2L, 0, 0.2, 5L, ImmutableList.of(1L, 2L, 3L, 4L, 5L)
+              })
+          // Can't split inside element 2 because 3 is not a split point, so split at 4 instead.
+          .add(
+              new Object[] {
+                channelSplitResult(4L), 2L, 0, 0.2, 5L, ImmutableList.of(1L, 2L, 4L, 5L)
+              })
+          // Neither 3 nor 4 is a split point here, so split at the next allowed point, 5.
+          .add(new Object[] {channelSplitResult(5L), 2L, 0, 0.2, 5L, ImmutableList.of(1L, 2L, 5L)})
+          // We can't split element at index 2, because 2 is not a split point.
+          .add(
+              new Object[] {
+                channelSplitResult(3L), 2L, 0, 0.2, 5L, ImmutableList.of(1L, 3L, 4L, 5L)
+              })
+          .build();
+    }
 
-    ProcessBundleSplitResponse expected =
-        ProcessBundleSplitResponse.newBuilder()
-            .addPrimaryRoots(splitResult.getPrimaryRoot())
-            .addResidualRoots(splitResult.getResidualRoot())
-            .addChannelSplits(
-                ChannelSplit.newBuilder()
-                    .setLastPrimaryElement(1)
-                    .setFirstResidualElement(3)
-                    .build())
-            .build();
-    assertEquals(expected, responseBuilder.build());
+    @Parameterized.Parameter(0)
+    public ProcessBundleSplitResponse expectedResponse;
+
+    @Parameterized.Parameter(1)
+    public long index;
+
+    @Parameterized.Parameter(2)
+    public double elementProgress;
+
+    @Parameterized.Parameter(3)
+    public double fractionOfRemainder;
+
+    @Parameterized.Parameter(4)
+    public long bufferSize;
+
+    @Parameterized.Parameter(5)
+    public List<Long> allowedSplitPoints;
+
+    @Test
+    public void testElementSplittingWithAllowedSplitPoints() throws Exception {
+      SplittingReceiver splittingReceiver = mock(SplittingReceiver.class);
+      BeamFnDataClient mockBeamFnDataClient = mock(BeamFnDataClient.class);
+      when(splittingReceiver.getProgress()).thenReturn(elementProgress);
+      when(splittingReceiver.trySplit(anyDouble())).thenCallRealMethod();
+      BeamFnDataReadRunner<String> readRunner =
+          createReadRunner(splittingReceiver, PTRANSFORM_ID, mockBeamFnDataClient);
+      readRunner.registerInputLocation();
+      assertEquals(
+          expectedResponse,
+          executeSplit(
+              readRunner,
+              PTRANSFORM_ID,
+              index,
+              fractionOfRemainder,
+              bufferSize,
+              allowedSplitPoints));
+    }
   }
 
   private abstract static class SplittingReceiver
-      implements FnDataReceiver<WindowedValue<String>>, HandlesSplits {}
+      implements FnDataReceiver<WindowedValue<String>>, HandlesSplits {
+    @Override
+    public SplitResult trySplit(double fractionOfRemainder) {
+      return SplitResult.of(
+          BundleApplication.newBuilder()
+              .setInputId(String.format("primary%.1f", fractionOfRemainder))
+              .build(),
+          DelayedBundleApplication.newBuilder()
+              .setApplication(
+                  BundleApplication.newBuilder()
+                      .setInputId(String.format("residual%.1f", 1 - fractionOfRemainder))
+                      .build())
+              .build());
+    }
+  }
 
-  private BeamFnDataReadRunner<String> createReadRunner(
-      FnDataReceiver<WindowedValue<String>> consumer) throws Exception {
+  private static BeamFnDataReadRunner<String> createReadRunner(
+      FnDataReceiver<WindowedValue<String>> consumer,
+      String pTransformId,
+      BeamFnDataClient dataClient)
+      throws Exception {
     String bundleId = "57";
 
     MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
@@ -489,7 +687,6 @@ public class BeamFnDataReadRunnerTest {
         new PCollectionConsumerRegistry(
             metricsContainerRegistry, mock(ExecutionStateTracker.class));
     String localOutputId = "outputPC";
-    String pTransformId = "pTransformId";
     consumers.register(localOutputId, pTransformId, consumer);
     PTransformFunctionRegistry startFunctionRegistry =
         new PTransformFunctionRegistry(
@@ -505,7 +702,7 @@ public class BeamFnDataReadRunnerTest {
     return new BeamFnDataReadRunner.Factory<String>()
         .createRunnerForPTransform(
             PipelineOptionsFactory.create(),
-            mockBeamFnDataClient,
+            dataClient,
             null /* beamFnStateClient */,
             null /* beamFnTimerClient */,
             pTransformId,
@@ -532,4 +729,62 @@ public class BeamFnDataReadRunnerTest {
         .setInt64SumValue(index)
         .build();
   }
+
+  private static ProcessBundleSplitResponse executeSplit(
+      BeamFnDataReadRunner<String> readRunner,
+      String pTransformId,
+      long index,
+      double fractionOfRemainder,
+      long inputElements,
+      List<Long> allowedSplitPoints)
+      throws Exception {
+    for (long i = -1; i < index; i++) {
+      readRunner.forwardElementToConsumer(valueInGlobalWindow(Long.valueOf(i).toString()));
+    }
+    ProcessBundleSplitRequest request =
+        ProcessBundleSplitRequest.newBuilder()
+            .putDesiredSplits(
+                pTransformId,
+                DesiredSplit.newBuilder()
+                    .setEstimatedInputElements(inputElements)
+                    .setFractionOfRemainder(fractionOfRemainder)
+                    .addAllAllowedSplitPoints(allowedSplitPoints)
+                    .build())
+            .build();
+    ProcessBundleSplitResponse.Builder responseBuilder = ProcessBundleSplitResponse.newBuilder();
+    readRunner.trySplit(request, responseBuilder);
+    return responseBuilder.build();
+  }
+
+  private static ProcessBundleSplitResponse channelSplitResult(long firstResidualIndex) {
+    return ProcessBundleSplitResponse.newBuilder()
+        .addChannelSplits(
+            ChannelSplit.newBuilder()
+                .setLastPrimaryElement(firstResidualIndex - 1)
+                .setFirstResidualElement(firstResidualIndex)
+                .build())
+        .build();
+  }
+
+  private static ProcessBundleSplitResponse elementSplitResult(
+      long index, double fractionOfRemainder) {
+    return ProcessBundleSplitResponse.newBuilder()
+        .addPrimaryRoots(
+            BundleApplication.newBuilder()
+                .setInputId(String.format("primary%.1f", fractionOfRemainder))
+                .build())
+        .addResidualRoots(
+            DelayedBundleApplication.newBuilder()
+                .setApplication(
+                    BundleApplication.newBuilder()
+                        .setInputId(String.format("residual%.1f", 1 - fractionOfRemainder))
+                        .build())
+                .build())
+        .addChannelSplits(
+            ChannelSplit.newBuilder()
+                .setLastPrimaryElement(index - 1)
+                .setFirstResidualElement(index + 1)
+                .build())
+        .build();
+  }
 }
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
index 6c65189..722e560 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor_test.py
@@ -115,7 +115,7 @@ class SplitTest(unittest.TestCase):
     # remainder falls outside the first element.
     self.assertEqual(self.sdf_split(0, .5, 0.2, 4), simple_split(1))
 
-    # Verify the above logic when we are partially throug the stream.
+    # Verify the above logic when we are partially through the stream.
     self.assertEqual(self.sdf_split(2, 0, 0.6, 4), simple_split(3))
     self.assertEqual(self.sdf_split(2, 0.9, 0.6, 4), simple_split(4))
     self.assertEqual(