You are viewing a plain-text version of this content. The canonical version is linked from the original mailing-list archive page (the hyperlink was lost in this plain-text copy).
Posted to commits@beam.apache.org by lc...@apache.org on 2020/06/05 21:04:50 UTC
[beam] branch master updated: [BEAM-2939] Fix FnApiDoFnRunner to
ensure that we output within the correct window when processing a
splittable dofn (#11922)
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 96836a7 [BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window when processing a splittable dofn (#11922)
96836a7 is described below
commit 96836a741558ba93ab367b44ad46fbbf28b86449
Author: Lukasz Cwik <lu...@gmail.com>
AuthorDate: Fri Jun 5 14:04:28 2020 -0700
[BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window when processing a splittable dofn (#11922)
* [BEAM-2939] Fix FnApiDoFnRunner to ensure that we output within the correct window.
This fixes a bug where output was emitted into all of the element's windows instead of only the window currently being processed.
This did not affect any SDF whose input elements belonged to only a single window.
* Make sure that splitting/checkpointing is window-aware.
* fixup! Address PR comments.
---
.../beam/fn/harness/BeamFnDataReadRunner.java | 4 +-
.../apache/beam/fn/harness/FnApiDoFnRunner.java | 238 +++++--
.../org/apache/beam/fn/harness/HandlesSplits.java | 11 +-
.../fn/harness/control/BundleSplitListener.java | 12 +-
.../beam/fn/harness/BeamFnDataReadRunnerTest.java | 20 +-
.../beam/fn/harness/FnApiDoFnRunnerTest.java | 686 ++++++++++++++++++++-
.../harness/control/BundleSplitListenerTest.java | 7 +-
7 files changed, 910 insertions(+), 68 deletions(-)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
index 3b10b33..34cdcb4 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java
@@ -292,8 +292,8 @@ public class BeamFnDataReadRunner<OutputT> {
if (splitResult != null) {
stopIndex = index + 1;
response
- .addPrimaryRoots(splitResult.getPrimaryRoot())
- .addResidualRoots(splitResult.getResidualRoot())
+ .addAllPrimaryRoots(splitResult.getPrimaryRoots())
+ .addAllResidualRoots(splitResult.getResidualRoots())
.addChannelSplitsBuilder()
.setLastPrimaryElement(index - 1)
.setFirstResidualElement(stopIndex);
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 0baf76b..1c8ba32 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -24,16 +24,19 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.ListIterator;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
import org.apache.beam.fn.harness.PTransformRunnerFactory.ProgressRequestCallback;
import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
@@ -297,6 +300,12 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
private WindowedValue<InputT> currentElement;
/**
+ * Only valid during {@link #processElementForElementAndRestriction} and {@link
+ * #processElementForSizedElementAndRestriction}.
+ */
+ private ListIterator<BoundedWindow> currentWindowIterator;
+
+ /**
* Only valid during {@link #processElementForPairWithRestriction}, {@link
* #processElementForSplitRestriction}, {@link #processElementForElementAndRestriction} and {@link
* #processElementForSizedElementAndRestriction}, null otherwise.
@@ -577,20 +586,83 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
switch (pTransform.getSpec().getUrn()) {
case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN:
this.convertSplitResultToWindowedSplitResult =
- (splitResult, watermarkEstimatorState) ->
- WindowedSplitResult.forRoots(
- currentElement.withValue(
- KV.of(
- currentElement.getValue(),
- KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState))),
- currentElement.withValue(
- KV.of(
- currentElement.getValue(),
- KV.of(splitResult.getResidual(), watermarkEstimatorState))));
+ (splitResult, watermarkEstimatorState) -> {
+ List<BoundedWindow> primaryFullyProcessedWindows =
+ ImmutableList.copyOf(
+ Iterables.limit(
+ currentElement.getWindows(), currentWindowIterator.previousIndex()));
+ // Advances the iterator consuming the remaining windows.
+ List<BoundedWindow> residualUnprocessedWindows =
+ ImmutableList.copyOf(currentWindowIterator);
+ return WindowedSplitResult.forRoots(
+ primaryFullyProcessedWindows.isEmpty()
+ ? null
+ : WindowedValue.of(
+ KV.of(
+ currentElement.getValue(),
+ KV.of(currentRestriction, currentWatermarkEstimatorState)),
+ currentElement.getTimestamp(),
+ primaryFullyProcessedWindows,
+ currentElement.getPane()),
+ WindowedValue.of(
+ KV.of(
+ currentElement.getValue(),
+ KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
+ currentElement.getTimestamp(),
+ currentWindow,
+ currentElement.getPane()),
+ WindowedValue.of(
+ KV.of(
+ currentElement.getValue(),
+ KV.of(splitResult.getResidual(), watermarkEstimatorState)),
+ currentElement.getTimestamp(),
+ currentWindow,
+ currentElement.getPane()),
+ residualUnprocessedWindows.isEmpty()
+ ? null
+ : WindowedValue.of(
+ KV.of(
+ currentElement.getValue(),
+ KV.of(currentRestriction, currentWatermarkEstimatorState)),
+ currentElement.getTimestamp(),
+ residualUnprocessedWindows,
+ currentElement.getPane()));
+ };
break;
case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN:
this.convertSplitResultToWindowedSplitResult =
(splitResult, watermarkEstimatorState) -> {
+ List<BoundedWindow> primaryFullyProcessedWindows =
+ ImmutableList.copyOf(
+ Iterables.limit(
+ currentElement.getWindows(), currentWindowIterator.previousIndex()));
+ // Advances the iterator consuming the remaining windows.
+ List<BoundedWindow> residualUnprocessedWindows =
+ ImmutableList.copyOf(currentWindowIterator);
+ // If the window has been observed then the splitAndSize method would have already
+ // output sizes for each window separately.
+ //
+ // TODO: Consider using the original size on the element instead of recomputing
+ // this here.
+ double fullSize =
+ primaryFullyProcessedWindows.isEmpty() && residualUnprocessedWindows.isEmpty()
+ ? 0
+ : doFnInvoker.invokeGetSize(
+ new DelegatingArgumentProvider<InputT, OutputT>(
+ processContext,
+ PTransformTranslation
+ .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN
+ + "/GetPrimarySize") {
+ @Override
+ public Object restriction() {
+ return currentRestriction;
+ }
+
+ @Override
+ public RestrictionTracker<?, ?> restrictionTracker() {
+ return doFnInvoker.invokeNewTracker(this);
+ }
+ });
double primarySize =
doFnInvoker.invokeGetSize(
new DelegatingArgumentProvider<InputT, OutputT>(
@@ -626,18 +698,46 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
}
});
return WindowedSplitResult.forRoots(
- currentElement.withValue(
+ primaryFullyProcessedWindows.isEmpty()
+ ? null
+ : WindowedValue.of(
+ KV.of(
+ KV.of(
+ currentElement.getValue(),
+ KV.of(currentRestriction, currentWatermarkEstimatorState)),
+ fullSize),
+ currentElement.getTimestamp(),
+ primaryFullyProcessedWindows,
+ currentElement.getPane()),
+ WindowedValue.of(
KV.of(
KV.of(
currentElement.getValue(),
KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
- primarySize)),
- currentElement.withValue(
+ primarySize),
+ currentElement.getTimestamp(),
+ currentWindow,
+ currentElement.getPane()),
+ WindowedValue.of(
KV.of(
KV.of(
currentElement.getValue(),
KV.of(splitResult.getResidual(), watermarkEstimatorState)),
- residualSize)));
+ residualSize),
+ currentElement.getTimestamp(),
+ currentWindow,
+ currentElement.getPane()),
+ residualUnprocessedWindows.isEmpty()
+ ? null
+ : WindowedValue.of(
+ KV.of(
+ KV.of(
+ currentElement.getValue(),
+ KV.of(currentRestriction, currentWatermarkEstimatorState)),
+ fullSize),
+ currentElement.getTimestamp(),
+ residualUnprocessedWindows,
+ currentElement.getPane()));
};
break;
default:
@@ -683,7 +783,6 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
break;
default:
// no-op
-
}
}
@@ -755,12 +854,15 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
outputTo(
mainOutputConsumers,
(WindowedValue)
- elem.withValue(
+ WindowedValue.of(
KV.of(
elem.getValue(),
KV.of(
currentRestriction,
- doFnInvoker.invokeGetInitialWatermarkEstimatorState(processContext)))));
+ doFnInvoker.invokeGetInitialWatermarkEstimatorState(processContext))),
+ currentElement.getTimestamp(),
+ currentWindow,
+ currentElement.getPane()));
}
} finally {
currentElement = null;
@@ -793,13 +895,26 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
@AutoValue
abstract static class WindowedSplitResult {
public static WindowedSplitResult forRoots(
- WindowedValue primaryRoot, WindowedValue residualRoot) {
- return new AutoValue_FnApiDoFnRunner_WindowedSplitResult(primaryRoot, residualRoot);
+ WindowedValue primaryInFullyProcessedWindowsRoot,
+ WindowedValue primarySplitRoot,
+ WindowedValue residualSplitRoot,
+ WindowedValue residualInUnprocessedWindowsRoot) {
+ return new AutoValue_FnApiDoFnRunner_WindowedSplitResult(
+ primaryInFullyProcessedWindowsRoot,
+ primarySplitRoot,
+ residualSplitRoot,
+ residualInUnprocessedWindowsRoot);
}
- public abstract WindowedValue getPrimaryRoot();
+ @Nullable
+ public abstract WindowedValue getPrimaryInFullyProcessedWindowsRoot();
- public abstract WindowedValue getResidualRoot();
+ public abstract WindowedValue getPrimarySplitRoot();
+
+ public abstract WindowedValue getResidualSplitRoot();
+
+ @Nullable
+ public abstract WindowedValue getResidualInUnprocessedWindowsRoot();
}
private void processElementForSizedElementAndRestriction(
@@ -811,13 +926,18 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
currentElement = elem.withValue(elem.getValue().getKey());
try {
- Iterator<BoundedWindow> windowIterator =
- (Iterator<BoundedWindow>) elem.getWindows().iterator();
- while (windowIterator.hasNext()) {
+ currentWindowIterator =
+ currentElement.getWindows() instanceof List
+ ? ((List) currentElement.getWindows()).listIterator()
+ : ImmutableList.<BoundedWindow>copyOf(elem.getWindows()).listIterator();
+ while (true) {
synchronized (splitLock) {
+ if (!currentWindowIterator.hasNext()) {
+ return;
+ }
currentRestriction = elem.getValue().getValue().getKey();
currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
- currentWindow = windowIterator.next();
+ currentWindow = currentWindowIterator.next();
currentTracker =
RestrictionTrackers.observe(
doFnInvoker.invokeNewTracker(processContext),
@@ -845,21 +965,25 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
// Attempt to checkpoint the current restriction.
HandlesSplits.SplitResult splitResult =
trySplitForElementAndRestriction(0, continuation.resumeDelay());
- // After the user has chosen to resume processing later, the Runner may have stolen
- // the remainder of work through a split call so the above trySplit may return null. If so,
- // the current restriction must be done.
+ /**
+ * After the user has chosen to resume processing later, either the restriction is already
+ * done and the user unknowingly claimed the last element or the Runner may have stolen the
+ * remainder of work through a split call so the above trySplit may return null. If so, the
+ * current restriction must be done.
+ */
if (splitResult == null) {
currentTracker.checkDone();
continue;
}
// Forward the split to the bundle level split listener.
- splitListener.split(splitResult.getPrimaryRoot(), splitResult.getResidualRoot());
+ splitListener.split(splitResult.getPrimaryRoots(), splitResult.getResidualRoots());
}
} finally {
synchronized (splitLock) {
currentElement = null;
currentRestriction = null;
currentWatermarkEstimatorState = null;
+ currentWindowIterator = null;
currentWindow = null;
currentTracker = null;
currentWatermarkEstimator = null;
@@ -923,20 +1047,59 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
}
+ List<BundleApplication> primaryRoots = new ArrayList<>();
+ List<DelayedBundleApplication> residualRoots = new ArrayList<>();
+ Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
+ if (windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot() != null) {
+ ByteString.Output primaryInOtherWindowsBytes = ByteString.newOutput();
+ try {
+ fullInputCoder.encode(
+ windowedSplitResult.getPrimaryInFullyProcessedWindowsRoot(),
+ primaryInOtherWindowsBytes);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ BundleApplication.Builder primaryApplicationInOtherWindows =
+ BundleApplication.newBuilder()
+ .setTransformId(pTransformId)
+ .setInputId(mainInputId)
+ .setElement(primaryInOtherWindowsBytes.toByteString());
+ primaryRoots.add(primaryApplicationInOtherWindows.build());
+ }
+ if (windowedSplitResult.getResidualInUnprocessedWindowsRoot() != null) {
+ ByteString.Output bytesOut = ByteString.newOutput();
+ try {
+ fullInputCoder.encode(windowedSplitResult.getResidualInUnprocessedWindowsRoot(), bytesOut);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ BundleApplication.Builder residualInUnprocessedWindowsRoot =
+ BundleApplication.newBuilder()
+ .setTransformId(pTransformId)
+ .setInputId(mainInputId)
+ .setElement(bytesOut.toByteString());
+ // We don't want to change the output watermarks or set the checkpoint resume time since
+ // that applies to the current window.
+ residualRoots.add(
+ DelayedBundleApplication.newBuilder()
+ .setApplication(residualInUnprocessedWindowsRoot)
+ .build());
+ }
+
ByteString.Output primaryBytes = ByteString.newOutput();
ByteString.Output residualBytes = ByteString.newOutput();
try {
- Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
- fullInputCoder.encode(windowedSplitResult.getPrimaryRoot(), primaryBytes);
- fullInputCoder.encode(windowedSplitResult.getResidualRoot(), residualBytes);
+ fullInputCoder.encode(windowedSplitResult.getPrimarySplitRoot(), primaryBytes);
+ fullInputCoder.encode(windowedSplitResult.getResidualSplitRoot(), residualBytes);
} catch (IOException e) {
throw new RuntimeException(e);
}
- BundleApplication.Builder primaryApplication =
+ primaryRoots.add(
BundleApplication.newBuilder()
.setTransformId(pTransformId)
.setInputId(mainInputId)
- .setElement(primaryBytes.toByteString());
+ .setElement(primaryBytes.toByteString())
+ .build());
BundleApplication.Builder residualApplication =
BundleApplication.newBuilder()
.setTransformId(pTransformId)
@@ -953,12 +1116,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
.build());
}
}
- return HandlesSplits.SplitResult.of(
- primaryApplication.build(),
+ residualRoots.add(
DelayedBundleApplication.newBuilder()
- .setApplication(residualApplication.build())
+ .setApplication(residualApplication)
.setRequestedTimeDelay(Durations.fromMillis(resumeDelay.getMillis()))
.build());
+
+ return HandlesSplits.SplitResult.of(primaryRoots, residualRoots);
}
private <K> void processTimer(
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
index 63b5868..54e1983 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
@@ -18,7 +18,9 @@
package org.apache.beam.fn.harness;
import com.google.auto.value.AutoValue;
+import java.util.List;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
/**
@@ -36,12 +38,13 @@ public interface HandlesSplits {
@AutoValue
abstract class SplitResult {
public static SplitResult of(
- BeamFnApi.BundleApplication primaryRoot, BeamFnApi.DelayedBundleApplication residualRoot) {
- return new AutoValue_HandlesSplits_SplitResult(primaryRoot, residualRoot);
+ List<BundleApplication> primaryRoots,
+ List<BeamFnApi.DelayedBundleApplication> residualRoots) {
+ return new AutoValue_HandlesSplits_SplitResult(primaryRoots, residualRoots);
}
- public abstract BeamFnApi.BundleApplication getPrimaryRoot();
+ public abstract List<BeamFnApi.BundleApplication> getPrimaryRoots();
- public abstract BeamFnApi.DelayedBundleApplication getResidualRoot();
+ public abstract List<BeamFnApi.DelayedBundleApplication> getResidualRoots();
}
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
index 5557be8..f897298 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -38,21 +38,21 @@ public interface BundleSplitListener {
* are a decomposition of work that has been given away by the bundle, so the runner must delegate
* it for someone else to execute.
*/
- void split(BundleApplication primaryRoot, DelayedBundleApplication residualRoots);
+ void split(List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots);
/** A {@link BundleSplitListener} which gathers all splits produced and stores them in memory. */
@AutoValue
@NotThreadSafe
abstract class InMemory implements BundleSplitListener {
public static InMemory create() {
- return new AutoValue_BundleSplitListener_InMemory(
- new ArrayList<BundleApplication>(), new ArrayList<DelayedBundleApplication>());
+ return new AutoValue_BundleSplitListener_InMemory(new ArrayList<>(), new ArrayList<>());
}
@Override
- public void split(BundleApplication primaryRoot, DelayedBundleApplication residualRoot) {
- getPrimaryRoots().add(primaryRoot);
- getResidualRoots().add(residualRoot);
+ public void split(
+ List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots) {
+ getPrimaryRoots().addAll(primaryRoots);
+ getResidualRoots().addAll(residualRoots);
}
public void clear() {
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
index 9293d01..18994e7 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java
@@ -664,15 +664,17 @@ public class BeamFnDataReadRunnerTest {
@Override
public SplitResult trySplit(double fractionOfRemainder) {
return SplitResult.of(
- BundleApplication.newBuilder()
- .setInputId(String.format("primary%.1f", fractionOfRemainder))
- .build(),
- DelayedBundleApplication.newBuilder()
- .setApplication(
- BundleApplication.newBuilder()
- .setInputId(String.format("residual%.1f", 1 - fractionOfRemainder))
- .build())
- .build());
+ Collections.singletonList(
+ BundleApplication.newBuilder()
+ .setInputId(String.format("primary%.1f", fractionOfRemainder))
+ .build()),
+ Collections.singletonList(
+ DelayedBundleApplication.newBuilder()
+ .setApplication(
+ BundleApplication.newBuilder()
+ .setInputId(String.format("residual%.1f", 1 - fractionOfRemainder))
+ .build())
+ .build()));
}
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index fb00550..f9eec08 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -114,6 +114,7 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
@@ -125,6 +126,7 @@ import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.hamcrest.collection.IsMapContaining;
@@ -595,8 +597,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
// Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
FnDataReceiver<WindowedValue<?>> mainInput =
consumers.getMultiplexingConsumer(inputPCollectionId);
- mainInput.accept(valueInWindow("X", windowA));
- mainInput.accept(valueInWindow("Y", windowB));
+ mainInput.accept(valueInWindows("X", windowA));
+ mainInput.accept(valueInWindows("Y", windowB));
assertThat(mainOutputValues, hasSize(2));
assertThat(
mainOutputValues.get(0).getValue(),
@@ -713,8 +715,8 @@ public class FnApiDoFnRunnerTest implements Serializable {
// Ensure that the bagUserStateKey order does not matter when we traverse over KV pairs.
FnDataReceiver<WindowedValue<?>> mainInput =
consumers.getMultiplexingConsumer(inputPCollectionId);
- mainInput.accept(valueInWindow("X", windowA));
- mainInput.accept(valueInWindow("Y", windowB));
+ mainInput.accept(valueInWindows("X", windowA));
+ mainInput.accept(valueInWindows("Y", windowB));
mainOutputValues.clear();
Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
@@ -1007,8 +1009,13 @@ public class FnApiDoFnRunnerTest implements Serializable {
PaneInfo.NO_FIRING);
}
- private <T> WindowedValue<T> valueInWindow(T value, BoundedWindow window) {
- return WindowedValue.of(value, window.maxTimestamp(), window, PaneInfo.NO_FIRING);
+ private <T> WindowedValue<T> valueInWindows(
+ T value, BoundedWindow window, BoundedWindow... windows) {
+ return WindowedValue.of(
+ value,
+ window.maxTimestamp(),
+ ImmutableList.<BoundedWindow>builder().add(window).add(windows).build(),
+ PaneInfo.NO_FIRING);
}
private static class TestTimerfulDoFn extends DoFn<KV<String, String>, String> {
@@ -1543,8 +1550,9 @@ public class FnApiDoFnRunnerTest implements Serializable {
timestampedValueInGlobalWindow("7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
timestampedValueInGlobalWindow("7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3))));
- BundleApplication primaryRoot = trySplitResult.getPrimaryRoot();
- DelayedBundleApplication residualRoot = trySplitResult.getResidualRoot();
+ BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+ DelayedBundleApplication residualRoot =
+ Iterables.getOnlyElement(trySplitResult.getResidualRoots());
assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
assertEquals(
@@ -1594,6 +1602,393 @@ public class FnApiDoFnRunnerTest implements Serializable {
}
@Test
+ public void testProcessElementForWindowedSizedElementAndRestriction() throws Exception {
+ Pipeline p = Pipeline.create();
+ PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+ PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+ TestSplittableDoFn doFn = new TestSplittableDoFn(singletonSideInputView);
+
+ valuePCollection
+ .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+ .apply(TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
+
+ RunnerApi.Pipeline pProto =
+ ProtoOverrides.updateTransform(
+ PTransformTranslation.PAR_DO_TRANSFORM_URN,
+ PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+ SplittableParDoExpander.createSizedReplacement());
+ String expandedTransformId =
+ Iterables.find(
+ pProto.getComponents().getTransformsMap().entrySet(),
+ entry ->
+ entry
+ .getValue()
+ .getSpec()
+ .getUrn()
+ .equals(
+ PTransformTranslation
+ .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)
+ && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+ .getKey();
+ RunnerApi.PTransform pTransform =
+ pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+ String inputPCollectionId =
+ pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+ RunnerApi.PCollection inputPCollection =
+ pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+ RehydratedComponents rehydratedComponents =
+ RehydratedComponents.forComponents(pProto.getComponents());
+ Coder<WindowedValue> inputCoder =
+ WindowedValue.getFullCoder(
+ CoderTranslation.fromProto(
+ pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+ rehydratedComponents,
+ TranslationContext.DEFAULT),
+ (Coder)
+ CoderTranslation.fromProto(
+ pProto
+ .getComponents()
+ .getCodersOrThrow(
+ pProto
+ .getComponents()
+ .getWindowingStrategiesOrThrow(
+ inputPCollection.getWindowingStrategyId())
+ .getWindowCoderId()),
+ rehydratedComponents,
+ TranslationContext.DEFAULT));
+ String outputPCollectionId = pTransform.getOutputsOrThrow("output");
+
+ ImmutableMap<StateKey, ByteString> stateData =
+ ImmutableMap.of(
+ multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
+ encode("8"));
+
+ FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
+
+ List<WindowedValue<String>> mainOutputValues = new ArrayList<>();
+ MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+ PCollectionConsumerRegistry consumers =
+ new PCollectionConsumerRegistry(
+ metricsContainerRegistry, mock(ExecutionStateTracker.class));
+ consumers.register(
+ outputPCollectionId,
+ TEST_TRANSFORM_ID,
+ (FnDataReceiver) (FnDataReceiver<WindowedValue<String>>) mainOutputValues::add);
+ PTransformFunctionRegistry startFunctionRegistry =
+ new PTransformFunctionRegistry(
+ mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+ PTransformFunctionRegistry finishFunctionRegistry =
+ new PTransformFunctionRegistry(
+ mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+ List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+ List<ProgressRequestCallback> progressRequestCallbacks = new ArrayList<>();
+ BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+
+ new FnApiDoFnRunner.Factory<>()
+ .createRunnerForPTransform(
+ PipelineOptionsFactory.create(),
+ null /* beamFnDataClient */,
+ fakeClient,
+ null /* beamFnTimerClient */,
+ TEST_TRANSFORM_ID,
+ pTransform,
+ Suppliers.ofInstance("57L")::get,
+ pProto.getComponents().getPcollectionsMap(),
+ pProto.getComponents().getCodersMap(),
+ pProto.getComponents().getWindowingStrategiesMap(),
+ consumers,
+ startFunctionRegistry,
+ finishFunctionRegistry,
+ teardownFunctions::add,
+ progressRequestCallbacks::add,
+ splitListener,
+ null /* bundleFinalizer */);
+
+ Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+ mainOutputValues.clear();
+
+ assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+ FnDataReceiver<WindowedValue<?>> mainInput =
+ consumers.getMultiplexingConsumer(inputPCollectionId);
+ assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+ BoundedWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+ BoundedWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+ {
+ // Check that before processing an element we don't report progress
+ assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+ WindowedValue<?> firstValue =
+ valueInWindows(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+ window1,
+ window2);
+ mainInput.accept(firstValue);
+ // Check that after processing an element we don't report progress
+ assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+ // Since the side input upperBound is 8 we will process 5, 6, and 7 then checkpoint.
+ // We expect that the watermark advances to MIN + 7 and that the primary represents [5, 8)
+ // with the original watermark while the residual represents [8, 10) with the new MIN + 7
+ // watermark.
+ //
+ // Since we were on the first window, we expect only a single primary root and two residual
+ // roots (the split + the unprocessed window).
+ BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+ assertEquals(2, splitListener.getResidualRoots().size());
+ DelayedBundleApplication residualRoot = splitListener.getResidualRoots().get(1);
+ DelayedBundleApplication residualRootForUnprocessedWindows =
+ splitListener.getResidualRoots().get(0);
+ assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+ assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRoot.getApplication().getInputId());
+ assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+ Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+ assertEquals(
+ ImmutableMap.of(
+ "output",
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+ .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+ .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+ .build()),
+ residualRoot.getApplication().getOutputWatermarksMap());
+ assertEquals(
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+ .setSeconds(54)
+ .setNanos(321000000)
+ .build(),
+ residualRoot.getRequestedTimeDelay());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRootForUnprocessedWindows.getApplication().getInputId());
+ assertEquals(
+ TEST_TRANSFORM_ID, residualRootForUnprocessedWindows.getApplication().getTransformId());
+ assertEquals(
+ residualRootForUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
+ residualRootForUnprocessedWindows.getRequestedTimeDelay());
+ assertTrue(
+ residualRootForUnprocessedWindows.getApplication().getOutputWatermarksMap().isEmpty());
+
+ assertEquals(
+ decode(inputCoder, primaryRoot.getElement()),
+ WindowedValue.of(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)), 3.0),
+ firstValue.getTimestamp(),
+ window1,
+ firstValue.getPane()));
+ assertEquals(
+ decode(inputCoder, residualRoot.getApplication().getElement()),
+ WindowedValue.of(
+ KV.of(
+ KV.of(
+ "5", KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+ 2.0),
+ firstValue.getTimestamp(),
+ window1,
+ firstValue.getPane()));
+ assertEquals(
+ decode(inputCoder, residualRootForUnprocessedWindows.getApplication().getElement()),
+ WindowedValue.of(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+ firstValue.getTimestamp(),
+ window2,
+ firstValue.getPane()));
+ splitListener.clear();
+
+ // Check that before processing an element we don't report progress
+ assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+ WindowedValue<?> secondValue =
+ valueInWindows(
+ KV.of(
+ KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+ window1,
+ window2);
+ mainInput.accept(secondValue);
+ // Check that after processing an element we don't report progress
+ assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+ assertThat(
+ mainOutputValues,
+ contains(
+ WindowedValue.of(
+ "5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5), window1, firstValue.getPane()),
+ WindowedValue.of(
+ "5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6), window1, firstValue.getPane()),
+ WindowedValue.of(
+ "5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7), window1, firstValue.getPane()),
+ WindowedValue.of(
+ "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, firstValue.getPane()),
+ WindowedValue.of(
+ "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, firstValue.getPane()),
+ WindowedValue.of(
+ "2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window2, firstValue.getPane()),
+ WindowedValue.of(
+ "2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window2, firstValue.getPane())));
+ assertTrue(splitListener.getPrimaryRoots().isEmpty());
+ assertTrue(splitListener.getResidualRoots().isEmpty());
+ mainOutputValues.clear();
+ }
+
+ {
+ // Setup and launch the trySplit thread.
+ ExecutorService executorService = Executors.newSingleThreadExecutor();
+ Future<HandlesSplits.SplitResult> trySplitFuture =
+ executorService.submit(
+ () -> {
+ try {
+ doFn.waitForSplitElementToBeProcessed();
+ // Currently processing "3" out of range [0, 5) elements.
+ assertEquals(0.6, ((HandlesSplits) mainInput).getProgress(), 0.01);
+
+ // Check that during progressing of an element we report progress
+ List<MonitoringInfo> mis =
+ Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos();
+ MonitoringInfo.Builder expectedCompleted = MonitoringInfo.newBuilder();
+ expectedCompleted.setUrn(MonitoringInfoConstants.Urns.WORK_COMPLETED);
+ expectedCompleted.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+ expectedCompleted.putLabels(
+ MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+ expectedCompleted.setPayload(
+ ByteString.copyFrom(
+ CoderUtils.encodeToByteArray(
+ IterableCoder.of(DoubleCoder.of()), Collections.singletonList(3.0))));
+ MonitoringInfo.Builder expectedRemaining = MonitoringInfo.newBuilder();
+ expectedRemaining.setUrn(MonitoringInfoConstants.Urns.WORK_REMAINING);
+ expectedRemaining.setType(MonitoringInfoConstants.TypeUrns.PROGRESS_TYPE);
+ expectedRemaining.putLabels(
+ MonitoringInfoConstants.Labels.PTRANSFORM, TEST_TRANSFORM_ID);
+ expectedRemaining.setPayload(
+ ByteString.copyFrom(
+ CoderUtils.encodeToByteArray(
+ IterableCoder.of(DoubleCoder.of()), Collections.singletonList(2.0))));
+ assertThat(
+ mis,
+ containsInAnyOrder(expectedCompleted.build(), expectedRemaining.build()));
+
+ return ((HandlesSplits) mainInput).trySplit(0);
+ } finally {
+ doFn.releaseWaitingProcessElementThread();
+ }
+ });
+
+ // Check that before processing an element we don't report progress
+ assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+ WindowedValue<?> splitValue =
+ valueInWindows(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0),
+ window1,
+ window2);
+ mainInput.accept(splitValue);
+ HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+
+ // Check that after processing an element we don't report progress
+ assertThat(Iterables.getOnlyElement(progressRequestCallbacks).getMonitoringInfos(), empty());
+
+ // Since the SPLIT_ELEMENT is 3 we will process 0, 1, 2, 3 then be split on the first window.
+ // We expect that the watermark advances to MIN + 2 since the manual watermark estimator
+ // has yet to be invoked for the split element and that the primary represents [0, 4) with
+ // the original watermark while the residual represents [4, 5) with the new MIN + 2 watermark.
+ //
+ // We expect to see none of the output for the second window.
+ assertThat(
+ mainOutputValues,
+ contains(
+ WindowedValue.of(
+ "7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0), window1, splitValue.getPane()),
+ WindowedValue.of(
+ "7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1), window1, splitValue.getPane()),
+ WindowedValue.of(
+ "7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2), window1, splitValue.getPane()),
+ WindowedValue.of(
+ "7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3), window1, splitValue.getPane())));
+
+ BundleApplication primaryRoot = Iterables.getOnlyElement(trySplitResult.getPrimaryRoots());
+ assertEquals(2, trySplitResult.getResidualRoots().size());
+ DelayedBundleApplication residualRoot = trySplitResult.getResidualRoots().get(1);
+ DelayedBundleApplication residualRootInUnprocessedWindows =
+ trySplitResult.getResidualRoots().get(0);
+ assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+ assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRoot.getApplication().getInputId());
+ assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+ assertEquals(
+ TEST_TRANSFORM_ID, residualRootInUnprocessedWindows.getApplication().getTransformId());
+ assertEquals(
+ residualRootInUnprocessedWindows.getRequestedTimeDelay().getDefaultInstanceForType(),
+ residualRootInUnprocessedWindows.getRequestedTimeDelay());
+ assertTrue(
+ residualRootInUnprocessedWindows.getApplication().getOutputWatermarksMap().isEmpty());
+ assertEquals(
+ valueInWindows(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), 4.0),
+ window1),
+ inputCoder.decode(primaryRoot.getElement().newInput()));
+ assertEquals(
+ valueInWindows(
+ KV.of(
+ KV.of(
+ "7", KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
+ 1.0),
+ window1),
+ inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+ Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
+ assertEquals(
+ ImmutableMap.of(
+ "output",
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+ .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+ .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+ .build()),
+ residualRoot.getApplication().getOutputWatermarksMap());
+ assertEquals(
+ WindowedValue.of(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0),
+ splitValue.getTimestamp(),
+ window2,
+ splitValue.getPane()),
+ inputCoder.decode(
+ residualRootInUnprocessedWindows.getApplication().getElement().newInput()));
+
+ // We expect 0 resume delay.
+ assertEquals(
+ residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+ residualRoot.getRequestedTimeDelay());
+ // We don't expect the outputs to goto the SDK initiated checkpointing listener.
+ assertTrue(splitListener.getPrimaryRoots().isEmpty());
+ assertTrue(splitListener.getResidualRoots().isEmpty());
+ mainOutputValues.clear();
+ executorService.shutdown();
+ }
+
+ Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+ assertThat(mainOutputValues, empty());
+
+ Iterables.getOnlyElement(teardownFunctions).run();
+ assertThat(mainOutputValues, empty());
+
+ // Assert that state data did not change
+ assertEquals(stateData, fakeClient.getData());
+ }
+
+  /**
+   * Decodes {@code value} with {@code coder}.
+   *
+   * <p>Lets assertions compare decoded elements inline without the caller having to handle the
+   * checked {@link IOException} thrown by {@link Coder#decode}.
+   *
+   * @param coder the coder used to decode the bytes
+   * @param value the encoded bytes
+   * @return the decoded value
+   * @throws java.io.UncheckedIOException wrapping the original {@link IOException} if decoding
+   *     fails
+   */
+  private static <T> T decode(Coder<T> coder, ByteString value) {
+    try {
+      return coder.decode(value.newInput());
+    } catch (IOException e) {
+      // UncheckedIOException keeps the failure typed as an I/O problem while staying unchecked.
+      throw new java.io.UncheckedIOException(e);
+    }
+  }
+
+ @Test
public void testProcessElementForPairWithRestriction() throws Exception {
Pipeline p = Pipeline.create();
PCollection<String> valuePCollection = p.apply(Create.of("unused"));
@@ -1687,6 +2082,121 @@ public class FnApiDoFnRunnerTest implements Serializable {
}
@Test
+  public void testProcessElementForWindowedPairWithRestriction() throws Exception {
+    // Verifies that the expanded pair-with-restriction step emits one output per window for an
+    // element assigned to multiple windows, preserving the element's timestamp and pane.
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+    valuePCollection
+        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+        .apply(
+            TEST_TRANSFORM_ID,
+            ParDo.of(new TestSplittableDoFn(singletonSideInputView))
+                .withSideInputs(singletonSideInputView));
+
+    // Expand the splittable ParDo and locate the pair-with-restriction transform it produced.
+    RunnerApi.Pipeline pProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+            SplittableParDoExpander.createSizedReplacement());
+    String expandedTransformId =
+        Iterables.find(
+                pProto.getComponents().getTransformsMap().entrySet(),
+                entry ->
+                    entry
+                            .getValue()
+                            .getSpec()
+                            .getUrn()
+                            .equals(PTransformTranslation.SPLITTABLE_PAIR_WITH_RESTRICTION_URN)
+                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+            .getKey();
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+    String inputPCollectionId =
+        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+    // Each output pairs the element with a (restriction, initial watermark) pair, as asserted below.
+    List<WindowedValue<KV<String, KV<OffsetRange, Instant>>>> mainOutputValues =
+        new ArrayList<>();
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.Factory<>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            null /* beamFnDataClient */,
+            fakeClient,
+            null /* beamFnTimerClient */,
+            TEST_TRANSFORM_ID,
+            pTransform,
+            Suppliers.ofInstance("57L")::get,
+            pProto.getComponents().getPcollectionsMap(),
+            pProto.getComponents().getCodersMap(),
+            pProto.getComponents().getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            null /* addProgressRequestCallback */,
+            null /* bundleSplitListener */,
+            null /* bundleFinalizer */);
+
+    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+    // Feed two elements, each assigned to the same two overlapping windows.
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        consumers.getMultiplexingConsumer(inputPCollectionId);
+    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+    WindowedValue<?> firstValue = valueInWindows("5", window1, window2);
+    WindowedValue<?> secondValue = valueInWindows("2", window1, window2);
+    mainInput.accept(firstValue);
+    mainInput.accept(secondValue);
+
+    // Expect one output per (element, window), in input order and window order.
+    assertThat(
+        mainOutputValues,
+        contains(
+            WindowedValue.of(
+                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                secondValue.getTimestamp(),
+                window1,
+                secondValue.getPane()),
+            WindowedValue.of(
+                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                secondValue.getTimestamp(),
+                window2,
+                secondValue.getPane())));
+    mainOutputValues.clear();
+
+    // Finishing and tearing down the bundle must not emit any further output.
+    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+    assertThat(mainOutputValues, empty());
+
+    Iterables.getOnlyElement(teardownFunctions).run();
+    assertThat(mainOutputValues, empty());
+  }
+
+ @Test
public void testProcessElementForSplitAndSizeRestriction() throws Exception {
Pipeline p = Pipeline.create();
PCollection<String> valuePCollection = p.apply(Create.of("unused"));
@@ -1795,4 +2305,164 @@ public class FnApiDoFnRunnerTest implements Serializable {
Iterables.getOnlyElement(teardownFunctions).run();
assertThat(mainOutputValues, empty());
}
+
+  @Test
+  public void testProcessElementForWindowedSplitAndSizeRestriction() throws Exception {
+    // Verifies that the expanded split-and-size-restriction step splits and sizes each element
+    // separately in every window it is assigned to, preserving the element's timestamp and pane.
+    Pipeline p = Pipeline.create();
+    PCollection<String> valuePCollection = p.apply(Create.of("unused"));
+    PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+    valuePCollection
+        .apply(Window.into(SlidingWindows.of(Duration.standardSeconds(1))))
+        .apply(
+            TEST_TRANSFORM_ID,
+            ParDo.of(new TestSplittableDoFn(singletonSideInputView))
+                .withSideInputs(singletonSideInputView));
+
+    // Expand the splittable ParDo and locate the split-and-size-restriction transform it produced.
+    RunnerApi.Pipeline pProto =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            PipelineTranslation.toProto(p, SdkComponents.create(p.getOptions()), true),
+            SplittableParDoExpander.createSizedReplacement());
+    String expandedTransformId =
+        Iterables.find(
+                pProto.getComponents().getTransformsMap().entrySet(),
+                entry ->
+                    entry
+                            .getValue()
+                            .getSpec()
+                            .getUrn()
+                            .equals(
+                                PTransformTranslation.SPLITTABLE_SPLIT_AND_SIZE_RESTRICTIONS_URN)
+                        && entry.getValue().getUniqueName().contains(TEST_TRANSFORM_ID))
+            .getKey();
+    RunnerApi.PTransform pTransform =
+        pProto.getComponents().getTransformsOrThrow(expandedTransformId);
+    String inputPCollectionId =
+        pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+    String outputPCollectionId = Iterables.getOnlyElement(pTransform.getOutputsMap().values());
+
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(ImmutableMap.of());
+
+    // Each output is ((element, (restriction, watermark)), size), matching the assertions below.
+    List<WindowedValue<KV<KV<String, KV<OffsetRange, Instant>>, Double>>> mainOutputValues =
+        new ArrayList<>();
+    MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap();
+    PCollectionConsumerRegistry consumers =
+        new PCollectionConsumerRegistry(
+            metricsContainerRegistry, mock(ExecutionStateTracker.class));
+    consumers.register(outputPCollectionId, TEST_TRANSFORM_ID, ((List) mainOutputValues)::add);
+    PTransformFunctionRegistry startFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "start");
+    PTransformFunctionRegistry finishFunctionRegistry =
+        new PTransformFunctionRegistry(
+            mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
+    List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
+
+    new FnApiDoFnRunner.Factory<>()
+        .createRunnerForPTransform(
+            PipelineOptionsFactory.create(),
+            null /* beamFnDataClient */,
+            fakeClient,
+            null /* beamFnTimerClient */,
+            TEST_TRANSFORM_ID,
+            pTransform,
+            Suppliers.ofInstance("57L")::get,
+            pProto.getComponents().getPcollectionsMap(),
+            pProto.getComponents().getCodersMap(),
+            pProto.getComponents().getWindowingStrategiesMap(),
+            consumers,
+            startFunctionRegistry,
+            finishFunctionRegistry,
+            teardownFunctions::add,
+            null /* addProgressRequestCallback */,
+            null /* bundleSplitListener */,
+            null /* bundleFinalizer */);
+
+    Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
+    mainOutputValues.clear();
+
+    assertThat(consumers.keySet(), containsInAnyOrder(inputPCollectionId, outputPCollectionId));
+
+    // Feed two already-paired elements, each assigned to the same two overlapping windows.
+    FnDataReceiver<WindowedValue<?>> mainInput =
+        consumers.getMultiplexingConsumer(inputPCollectionId);
+    IntervalWindow window1 = new IntervalWindow(new Instant(5), new Instant(10));
+    IntervalWindow window2 = new IntervalWindow(new Instant(6), new Instant(11));
+    WindowedValue<?> firstValue =
+        valueInWindows(
+            KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+            window1,
+            window2);
+    WindowedValue<?> secondValue =
+        valueInWindows(
+            KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+            window1,
+            window2);
+    mainInput.accept(firstValue);
+    mainInput.accept(secondValue);
+
+    // Each restriction is split into sub-ranges, each sized by its length. The sub-ranges are
+    // emitted per window, in input order. Outputs derived from secondValue carry secondValue's
+    // timestamp and pane.
+    assertThat(
+        mainOutputValues,
+        contains(
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    3.0),
+                firstValue.getTimestamp(),
+                window1,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    3.0),
+                firstValue.getTimestamp(),
+                window2,
+                firstValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                secondValue.getTimestamp(),
+                window1,
+                secondValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                secondValue.getTimestamp(),
+                window1,
+                secondValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                secondValue.getTimestamp(),
+                window2,
+                secondValue.getPane()),
+            WindowedValue.of(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0),
+                secondValue.getTimestamp(),
+                window2,
+                secondValue.getPane())));
+    mainOutputValues.clear();
+
+    // Finishing and tearing down the bundle must not emit any further output.
+    Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
+    assertThat(mainOutputValues, empty());
+
+    Iterables.getOnlyElement(teardownFunctions).run();
+    assertThat(mainOutputValues, empty());
+  }
}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
index 340bac2a..5a48efe 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
@@ -21,6 +21,7 @@ import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.junit.Assert.assertThat;
+import java.util.Collections;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
import org.junit.Test;
@@ -46,8 +47,10 @@ public class BundleSplitListenerTest {
@Test
public void testInMemory() {
BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
- splitListener.split(TEST_PRIMARY_1, TEST_RESIDUAL_1);
- splitListener.split(TEST_PRIMARY_2, TEST_RESIDUAL_2);
+ splitListener.split(
+ Collections.singletonList(TEST_PRIMARY_1), Collections.singletonList(TEST_RESIDUAL_1));
+ splitListener.split(
+ Collections.singletonList(TEST_PRIMARY_2), Collections.singletonList(TEST_RESIDUAL_2));
assertThat(splitListener.getPrimaryRoots(), contains(TEST_PRIMARY_1, TEST_PRIMARY_2));
assertThat(splitListener.getResidualRoots(), contains(TEST_RESIDUAL_1, TEST_RESIDUAL_2));
splitListener.clear();