You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/04/17 00:35:29 UTC
[beam] branch master updated: [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls. (#11414)
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 99444c6 [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls. (#11414)
99444c6 is described below
commit 99444c6cefbe30ebb222f2adad52fdd150310780
Author: Lukasz Cwik <lu...@gmail.com>
AuthorDate: Thu Apr 16 17:34:53 2020 -0700
[BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls. (#11414)
* [BEAM-5605, BEAM-2939] Add support for FnApiDoFnRunner to handle split calls.
The next step is to plumb the split request to the BeamFnDataReadRunner which will forward it to the FnApiDoFnRunner.
* fixup! Minor comment change.
---
.../apache/beam/fn/harness/FnApiDoFnRunner.java | 223 ++++++++++-----
.../org/apache/beam/fn/harness/HandlesSplits.java | 8 +
.../fn/harness/control/BundleSplitListener.java | 29 ++
.../fn/harness/control/ProcessBundleHandler.java | 39 +--
.../beam/fn/harness/FnApiDoFnRunnerTest.java | 305 ++++++++++++++++-----
.../harness/control/BundleSplitListenerTest.java | 58 ++++
.../harness/control/ProcessBundleHandlerTest.java | 11 +-
7 files changed, 499 insertions(+), 174 deletions(-)
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 13b9c57..41daef7 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -60,6 +60,8 @@ import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
+import org.apache.beam.sdk.fn.splittabledofn.RestrictionTrackers;
+import org.apache.beam.sdk.fn.splittabledofn.RestrictionTrackers.ClaimObserver;
import org.apache.beam.sdk.fn.splittabledofn.WatermarkEstimators;
import org.apache.beam.sdk.function.ThrowingRunnable;
import org.apache.beam.sdk.options.PipelineOptions;
@@ -102,7 +104,6 @@ import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.util.Durations;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -184,7 +185,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
} catch (IOException e) {
throw new RuntimeException(e);
}
- FnDataReceiver<WindowedValue> mainInputConsumer;
+ final FnDataReceiver<WindowedValue> mainInputConsumer;
switch (pTransform.getSpec().getUrn()) {
case PTransformTranslation.PAR_DO_TRANSFORM_URN:
mainInputConsumer = runner::processElementForParDo;
@@ -197,10 +198,22 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
mainInputConsumer = runner::processElementForSplitRestriction;
break;
case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN:
- mainInputConsumer = runner::processElementForElementAndRestriction;
+ mainInputConsumer =
+ runner.new SplittableFnDataReceiver() {
+ @Override
+ public void accept(WindowedValue input) throws Exception {
+ runner.processElementForElementAndRestriction(input);
+ }
+ };
break;
case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN:
- mainInputConsumer = runner::processElementForSizedElementAndRestriction;
+ mainInputConsumer =
+ runner.new SplittableFnDataReceiver() {
+ @Override
+ public void accept(WindowedValue input) throws Exception {
+ runner.processElementForSizedElementAndRestriction(input);
+ }
+ };
break;
default:
throw new IllegalStateException("Unknown urn: " + pTransform.getSpec().getUrn());
@@ -250,6 +263,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
private final FinishBundleArgumentProvider finishBundleArgumentProvider;
/**
+ * Used to guarantee a consistent view of this {@link FnApiDoFnRunner} while setting up for {@link
+ * DoFnInvoker#invokeProcessElement} since {@link #trySplitForElementAndRestriction} may access
+ * internal {@link FnApiDoFnRunner} state concurrently.
+ */
+ private final Object splitLock = new Object();
+
+ /**
* Only set for {@link PTransformTranslation#SPLITTABLE_PROCESS_ELEMENTS_URN} and {@link
* PTransformTranslation#SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN} transforms. Can
* only be invoked from within {@code processElement...} methods.
@@ -551,7 +571,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
currentElement.withValue(
KV.of(
currentElement.getValue(),
- KV.of(splitResult.getPrimary(), watermarkEstimatorState))),
+ KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState))),
currentElement.withValue(
KV.of(
currentElement.getValue(),
@@ -599,7 +619,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
KV.of(
KV.of(
currentElement.getValue(),
- KV.of(splitResult.getPrimary(), watermarkEstimatorState)),
+ KV.of(splitResult.getPrimary(), currentWatermarkEstimatorState)),
primarySize)),
currentElement.withValue(
KV.of(
@@ -620,7 +640,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
}
}
- public void startBundle() {
+ private void startBundle() {
this.stateAccessor =
new FnApiStateAccessor(
pipelineOptions,
@@ -662,7 +682,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
doFnInvoker.invokeStartBundle(startBundleArgumentProvider);
}
- public void processElementForParDo(WindowedValue<InputT> elem) {
+ private void processElementForParDo(WindowedValue<InputT> elem) {
currentElement = elem;
try {
Iterator<BoundedWindow> windowIterator =
@@ -677,7 +697,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
}
}
- public void processElementForPairWithRestriction(WindowedValue<InputT> elem) {
+ private void processElementForPairWithRestriction(WindowedValue<InputT> elem) {
currentElement = elem;
try {
Iterator<BoundedWindow> windowIterator =
@@ -702,7 +722,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
}
}
- public void processElementForSplitRestriction(
+ private void processElementForSplitRestriction(
WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
currentElement = elem.withValue(elem.getValue().getKey());
currentRestriction = elem.getValue().getValue().getKey();
@@ -735,24 +755,38 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
public abstract WindowedValue getResidualRoot();
}
- public void processElementForSizedElementAndRestriction(
+ private void processElementForSizedElementAndRestriction(
WindowedValue<KV<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>, Double>> elem) {
processElementForElementAndRestriction(elem.withValue(elem.getValue().getKey()));
}
- public void processElementForElementAndRestriction(
+ private void processElementForElementAndRestriction(
WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
currentElement = elem.withValue(elem.getValue().getKey());
try {
Iterator<BoundedWindow> windowIterator =
(Iterator<BoundedWindow>) elem.getWindows().iterator();
while (windowIterator.hasNext()) {
- currentRestriction = elem.getValue().getValue().getKey();
- currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
- currentWindow = windowIterator.next();
- currentTracker = doFnInvoker.invokeNewTracker(processContext);
- currentWatermarkEstimator =
- WatermarkEstimators.threadSafe(doFnInvoker.invokeNewWatermarkEstimator(processContext));
+ synchronized (splitLock) {
+ currentRestriction = elem.getValue().getValue().getKey();
+ currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
+ currentWindow = windowIterator.next();
+ currentTracker =
+ RestrictionTrackers.observe(
+ doFnInvoker.invokeNewTracker(processContext),
+ new ClaimObserver<PositionT>() {
+ @Override
+ public void onClaimed(PositionT position) {}
+
+ @Override
+ public void onClaimFailed(PositionT position) {}
+ });
+ currentWatermarkEstimator =
+ WatermarkEstimators.threadSafe(
+ doFnInvoker.invokeNewWatermarkEstimator(processContext));
+ }
+
+ // It is important to ensure that {@code splitLock} is not held during #invokeProcessElement
DoFn.ProcessContinuation continuation = doFnInvoker.invokeProcessElement(processContext);
// Ensure that all the work is done if the user tells us that they don't want to
// resume processing.
@@ -761,72 +795,111 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
continue;
}
- // Make sure to get the output watermark before we split to ensure that the lower bound
- // applies to both the primary and residual.
- KV<Instant, WatermarkEstimatorStateT> watermarkAndState =
- currentWatermarkEstimator.getWatermarkAndState();
- SplitResult<RestrictionT> result = currentTracker.trySplit(0);
+ // Attempt to checkpoint the current restriction.
+ HandlesSplits.SplitResult splitResult =
+ trySplitForElementAndRestriction(0, continuation.resumeDelay());
// After the user has chosen to resume processing later, the Runner may have stolen
- // the remainder of work through a split call so the above trySplit may fail. If so,
+ // the remainder of work through a split call so the above trySplit may return null. If so,
// the current restriction must be done.
- if (result == null) {
+ if (splitResult == null) {
currentTracker.checkDone();
continue;
}
-
- // Otherwise we have a successful self checkpoint.
- WindowedSplitResult windowedSplitResult =
- convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
- ByteString.Output primaryBytes = ByteString.newOutput();
- ByteString.Output residualBytes = ByteString.newOutput();
- try {
- Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
- fullInputCoder.encode(windowedSplitResult.getPrimaryRoot(), primaryBytes);
- fullInputCoder.encode(windowedSplitResult.getResidualRoot(), residualBytes);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- BundleApplication.Builder primaryApplication =
- BundleApplication.newBuilder()
- .setTransformId(pTransformId)
- .setInputId(mainInputId)
- .setElement(primaryBytes.toByteString());
- BundleApplication.Builder residualApplication =
- BundleApplication.newBuilder()
- .setTransformId(pTransformId)
- .setInputId(mainInputId)
- .setElement(residualBytes.toByteString());
-
- if (!watermarkAndState.getKey().equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
- for (String outputId : pTransform.getOutputsMap().keySet()) {
- residualApplication.putOutputWatermarks(
- outputId,
- org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
- .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
- .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
- .build());
- }
- }
+ // Forward the split to the bundle level split listener.
splitListener.split(
- ImmutableList.of(primaryApplication.build()),
- ImmutableList.of(
- DelayedBundleApplication.newBuilder()
- .setApplication(residualApplication.build())
- .setRequestedTimeDelay(
- Durations.fromMillis(continuation.resumeDelay().getMillis()))
- .build()));
+ Collections.singletonList(splitResult.getPrimaryRoot()),
+ Collections.singletonList(splitResult.getResidualRoot()));
}
} finally {
- currentElement = null;
- currentRestriction = null;
- currentWatermarkEstimatorState = null;
- currentWindow = null;
- currentTracker = null;
- currentWatermarkEstimator = null;
+ synchronized (splitLock) {
+ currentElement = null;
+ currentRestriction = null;
+ currentWatermarkEstimatorState = null;
+ currentWindow = null;
+ currentTracker = null;
+ currentWatermarkEstimator = null;
+ }
+ }
+ }
+
+ /**
+ * An abstract class which forwards split and progress calls allowing the implementer to choose
+ * where input elements are sent.
+ */
+ private abstract class SplittableFnDataReceiver
+ implements HandlesSplits, FnDataReceiver<WindowedValue> {
+ @Override
+ public SplitResult trySplit(double fractionOfRemainder) {
+ return trySplitForElementAndRestriction(fractionOfRemainder, Duration.ZERO);
+ }
+
+ @Override
+ public double getProgress() {
+ // TODO(BEAM-2939): Implement plumbing progress through for splitting.
+ return 0;
+ }
+ }
+
+ private HandlesSplits.SplitResult trySplitForElementAndRestriction(
+ double fractionOfRemainder, Duration resumeDelay) {
+ synchronized (splitLock) {
+ // There is nothing to split if we are between element and restriction processing calls.
+ if (currentTracker == null) {
+ return null;
+ }
+
+ // Make sure to get the output watermark before we split to ensure that the lower bound
+ // applies to the residual.
+ KV<Instant, WatermarkEstimatorStateT> watermarkAndState =
+ currentWatermarkEstimator.getWatermarkAndState();
+ SplitResult<RestrictionT> result = currentTracker.trySplit(fractionOfRemainder);
+ if (result == null) {
+ return null;
+ }
+
+ // We have a successful self split, either runner initiated or via a self checkpoint.
+ WindowedSplitResult windowedSplitResult =
+ convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
+ ByteString.Output primaryBytes = ByteString.newOutput();
+ ByteString.Output residualBytes = ByteString.newOutput();
+ try {
+ Coder fullInputCoder = WindowedValue.getFullCoder(inputCoder, windowCoder);
+ fullInputCoder.encode(windowedSplitResult.getPrimaryRoot(), primaryBytes);
+ fullInputCoder.encode(windowedSplitResult.getResidualRoot(), residualBytes);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ BundleApplication.Builder primaryApplication =
+ BundleApplication.newBuilder()
+ .setTransformId(pTransformId)
+ .setInputId(mainInputId)
+ .setElement(primaryBytes.toByteString());
+ BundleApplication.Builder residualApplication =
+ BundleApplication.newBuilder()
+ .setTransformId(pTransformId)
+ .setInputId(mainInputId)
+ .setElement(residualBytes.toByteString());
+
+ if (!watermarkAndState.getKey().equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
+ for (String outputId : pTransform.getOutputsMap().keySet()) {
+ residualApplication.putOutputWatermarks(
+ outputId,
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+ .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
+ .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
+ .build());
+ }
+ }
+ return HandlesSplits.SplitResult.of(
+ primaryApplication.build(),
+ DelayedBundleApplication.newBuilder()
+ .setApplication(residualApplication.build())
+ .setRequestedTimeDelay(Durations.fromMillis(resumeDelay.getMillis()))
+ .build());
}
}
- public <K> void processTimer(String timerId, TimeDomain timeDomain, Timer<K> timer) {
+ private <K> void processTimer(String timerId, TimeDomain timeDomain, Timer<K> timer) {
currentTimer = timer;
currentTimeDomain = timeDomain;
try {
@@ -843,7 +916,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
}
}
- public void finishBundle() throws Exception {
+ private void finishBundle() throws Exception {
for (TimerHandler timerHandler : timerHandlers.values()) {
timerHandler.awaitCompletion();
}
@@ -858,7 +931,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
this.stateAccessor = null;
}
- public void tearDown() {
+ private void tearDown() {
doFnInvoker.invokeTeardown();
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
index a2ac123..d6a161a 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/HandlesSplits.java
@@ -19,10 +19,18 @@ package org.apache.beam.fn.harness;
import com.google.auto.value.AutoValue;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
+import org.apache.beam.sdk.fn.data.FnDataReceiver;
+/**
+ * An interface that may be used to extend a {@link FnDataReceiver} signalling that the downstream
+ * runner is capable of performing splitting and providing progress reporting.
+ */
public interface HandlesSplits {
+
+ /** Returns null if the split was unsuccessful. */
SplitResult trySplit(double fractionOfRemainder);
+ /** Returns the current progress of the active element. */
double getProgress();
@AutoValue
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
index 9eab245..830834e 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/BundleSplitListener.java
@@ -17,7 +17,10 @@
*/
package org.apache.beam.fn.harness.control;
+import com.google.auto.value.AutoValue;
+import java.util.ArrayList;
import java.util.List;
+import javax.annotation.concurrent.NotThreadSafe;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
@@ -37,4 +40,30 @@ public interface BundleSplitListener {
* it for someone else to execute.
*/
void split(List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots);
+
+ /** A {@link BundleSplitListener} which gathers all splits produced and stores them in memory. */
+ @AutoValue
+ @NotThreadSafe
+ abstract class InMemory implements BundleSplitListener {
+ public static InMemory create() {
+ return new AutoValue_BundleSplitListener_InMemory(
+ new ArrayList<BundleApplication>(), new ArrayList<DelayedBundleApplication>());
+ }
+
+ @Override
+ public void split(
+ List<BundleApplication> primaryRoots, List<DelayedBundleApplication> residualRoots) {
+ getPrimaryRoots().addAll(primaryRoots);
+ getResidualRoots().addAll(residualRoots);
+ }
+
+ public void clear() {
+ getPrimaryRoots().clear();
+ getResidualRoots().clear();
+ }
+
+ public abstract List<BundleApplication> getPrimaryRoots();
+
+ public abstract List<DelayedBundleApplication> getResidualRoots();
+ }
}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index a788348..e11d1ec 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -47,8 +47,6 @@ import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
-import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
@@ -75,13 +73,11 @@ import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message;
import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.TextFormat;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.SetMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
import org.joda.time.Instant;
@@ -275,7 +271,6 @@ public class ProcessBundleHandler {
});
PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry();
PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry();
- Multimap<String, DelayedBundleApplication> allResiduals = bundleProcessor.getAllResiduals();
PCollectionConsumerRegistry pCollectionConsumerRegistry =
bundleProcessor.getpCollectionConsumerRegistry();
MetricsContainerStepMap metricsContainerRegistry =
@@ -299,10 +294,11 @@ public class ProcessBundleHandler {
LOG.debug("Finishing function {}", finishFunction);
finishFunction.run();
}
- if (!allResiduals.isEmpty()) {
- response.addAllResidualRoots(allResiduals.values());
- }
}
+
+ // Add all checkpointed residuals to the response.
+ response.addAllResidualRoots(bundleProcessor.getSplitListener().getResidualRoots());
+
// Get start bundle Execution Time Metrics.
for (MonitoringInfo mi : startFunctionRegistry.getExecutionTimeMonitoringInfos()) {
response.addMonitoringInfos(mi);
@@ -403,22 +399,7 @@ public class ProcessBundleHandler {
queueingClient, bundleDescriptor.getTimerApiServiceDescriptor())
: new FailAllTimerRegistrations(processBundleRequest);
- Multimap<String, DelayedBundleApplication> allResiduals = ArrayListMultimap.create();
- Multimap<String, BundleApplication> allPrimaries = ArrayListMultimap.create();
- BundleSplitListener splitListener =
- (List<BundleApplication> primaries, List<DelayedBundleApplication> residuals) -> {
- // Reset primaries and accumulate residuals.
- Multimap<String, BundleApplication> newPrimaries = ArrayListMultimap.create();
- for (BundleApplication primary : primaries) {
- newPrimaries.put(primary.getTransformId(), primary);
- }
- allPrimaries.clear();
- allPrimaries.putAll(newPrimaries);
-
- for (DelayedBundleApplication residual : residuals) {
- allResiduals.put(residual.getApplication().getTransformId(), residual);
- }
- };
+ BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
Collection<CallbackRegistration> bundleFinalizationCallbackRegistrations = new ArrayList<>();
BundleFinalizer bundleFinalizer =
@@ -435,7 +416,7 @@ public class ProcessBundleHandler {
startFunctionRegistry,
finishFunctionRegistry,
tearDownFunctions,
- allResiduals,
+ splitListener,
pCollectionConsumerRegistry,
metricsContainerRegistry,
stateTracker,
@@ -563,7 +544,7 @@ public class ProcessBundleHandler {
PTransformFunctionRegistry startFunctionRegistry,
PTransformFunctionRegistry finishFunctionRegistry,
List<ThrowingRunnable> tearDownFunctions,
- Multimap<String, DelayedBundleApplication> allResiduals,
+ BundleSplitListener.InMemory splitListener,
PCollectionConsumerRegistry pCollectionConsumerRegistry,
MetricsContainerStepMap metricsContainerRegistry,
ExecutionStateTracker stateTracker,
@@ -574,7 +555,7 @@ public class ProcessBundleHandler {
startFunctionRegistry,
finishFunctionRegistry,
tearDownFunctions,
- allResiduals,
+ splitListener,
pCollectionConsumerRegistry,
metricsContainerRegistry,
stateTracker,
@@ -591,7 +572,7 @@ public class ProcessBundleHandler {
abstract List<ThrowingRunnable> getTearDownFunctions();
- abstract Multimap<String, DelayedBundleApplication> getAllResiduals();
+ abstract BundleSplitListener.InMemory getSplitListener();
abstract PCollectionConsumerRegistry getpCollectionConsumerRegistry();
@@ -616,7 +597,7 @@ public class ProcessBundleHandler {
void reset() {
getStartFunctionRegistry().reset();
getFinishFunctionRegistry().reset();
- getAllResiduals().clear();
+ getSplitListener().clear();
getpCollectionConsumerRegistry().reset();
getMetricsContainerRegistry().reset();
getStateTracker().reset();
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index b37ac2c..0aa2381 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
@@ -37,6 +38,14 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ServiceLoader;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
import org.apache.beam.fn.harness.control.BundleSplitListener;
import org.apache.beam.fn.harness.data.FakeBeamFnTimerClient;
import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
@@ -48,9 +57,11 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
+import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
+import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
@@ -61,6 +72,7 @@ import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
@@ -1019,30 +1031,85 @@ public class FnApiDoFnRunnerTest implements Serializable {
fail("Expected registrar not found.");
}
+ /**
+ * The trySplit testing of this splittable DoFn is done when processing the {@link
+ * TestSplittableDoFn#SPLIT_ELEMENT}.
+ *
+ * <p>The expected thread flow is:
+ *
+ * <ul>
+ * <li>splitting thread: {@link TestSplittableDoFn#waitForSplitElementToBeProcessed()}
+ * <li>process element thread: {@link TestSplittableDoFn#enableAndWaitForTrySplitToHappen()}
+ * <li>splitting thread: perform try split
+ * <li>splitting thread: {@link TestSplittableDoFn#releaseWaitingProcessElementThread()}
+ * </ul>
+ */
static class TestSplittableDoFn extends DoFn<String, String> {
+ private static final ConcurrentMap<String, KV<CountDownLatch, CountDownLatch>>
+ DOFN_INSTANCE_TO_LOCK = new ConcurrentHashMap<>();
+ private static final long SPLIT_ELEMENT = 3;
+
+ private KV<CountDownLatch, CountDownLatch> getLatches() {
+ return DOFN_INSTANCE_TO_LOCK.computeIfAbsent(
+ this.uuid, (uuid) -> KV.of(new CountDownLatch(1), new CountDownLatch(1)));
+ }
+
+ private void enableAndWaitForTrySplitToHappen() throws Exception {
+ KV<CountDownLatch, CountDownLatch> latches = getLatches();
+ latches.getKey().countDown();
+ if (!latches.getValue().await(30, TimeUnit.SECONDS)) {
+ fail("Failed to wait for trySplit to occur.");
+ }
+ }
+
+ private void waitForSplitElementToBeProcessed() throws Exception {
+ KV<CountDownLatch, CountDownLatch> latches = getLatches();
+ if (!latches.getKey().await(30, TimeUnit.SECONDS)) {
+ fail("Failed to wait for split element to be processed.");
+ }
+ }
+
+ private void releaseWaitingProcessElementThread() {
+ KV<CountDownLatch, CountDownLatch> latches = getLatches();
+ latches.getValue().countDown();
+ }
+
private final PCollectionView<String> singletonSideInput;
+ private final String uuid;
private TestSplittableDoFn(PCollectionView<String> singletonSideInput) {
this.singletonSideInput = singletonSideInput;
+ this.uuid = UUID.randomUUID().toString();
}
@ProcessElement
public ProcessContinuation processElement(
ProcessContext context,
RestrictionTracker<OffsetRange, Long> tracker,
- ManualWatermarkEstimator<Instant> watermarkEstimator) {
- int upperBound = Integer.parseInt(context.sideInput(singletonSideInput));
- for (int i = 0; i < upperBound; ++i) {
- if (tracker.tryClaim((long) i)) {
- context.outputWithTimestamp(
- context.element() + ":" + i, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
- watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
+ ManualWatermarkEstimator<Instant> watermarkEstimator)
+ throws Exception {
+ long checkpointUpperBound = Long.parseLong(context.sideInput(singletonSideInput));
+ long position = tracker.currentRestriction().getFrom();
+ boolean claimStatus;
+ while (true) {
+ claimStatus = (tracker.tryClaim(position));
+ if (!claimStatus) {
+ break;
+ } else if (position == SPLIT_ELEMENT) {
+ enableAndWaitForTrySplitToHappen();
+ }
+ context.outputWithTimestamp(
+ context.element() + ":" + position, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+ watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(position));
+ position += 1L;
+ if (position == checkpointUpperBound) {
+ break;
}
}
- if (tracker.currentRestriction().getTo() > upperBound) {
- return ProcessContinuation.resume().withResumeDelay(Duration.millis(42L));
- } else {
+ if (!claimStatus) {
return ProcessContinuation.stop();
+ } else {
+ return ProcessContinuation.resume().withResumeDelay(Duration.millis(54321L));
}
}
@@ -1079,10 +1146,9 @@ public class FnApiDoFnRunnerTest implements Serializable {
Pipeline p = Pipeline.create();
PCollection<String> valuePCollection = p.apply(Create.of("unused"));
PCollectionView<String> singletonSideInputView = valuePCollection.apply(View.asSingleton());
+ TestSplittableDoFn doFn = new TestSplittableDoFn(singletonSideInputView);
valuePCollection.apply(
- TEST_TRANSFORM_ID,
- ParDo.of(new TestSplittableDoFn(singletonSideInputView))
- .withSideInputs(singletonSideInputView));
+ TEST_TRANSFORM_ID, ParDo.of(doFn).withSideInputs(singletonSideInputView));
RunnerApi.Pipeline pProto =
ProtoOverrides.updateTransform(
@@ -1106,12 +1172,32 @@ public class FnApiDoFnRunnerTest implements Serializable {
pProto.getComponents().getTransformsOrThrow(expandedTransformId);
String inputPCollectionId =
pTransform.getInputsOrThrow(ParDoTranslation.getMainInputName(pTransform));
+ RunnerApi.PCollection inputPCollection =
+ pProto.getComponents().getPcollectionsOrThrow(inputPCollectionId);
+ RehydratedComponents rehydratedComponents =
+ RehydratedComponents.forComponents(pProto.getComponents());
+ Coder<WindowedValue> inputCoder =
+ WindowedValue.getFullCoder(
+ CoderTranslation.fromProto(
+ pProto.getComponents().getCodersOrThrow(inputPCollection.getCoderId()),
+ rehydratedComponents),
+ (Coder)
+ CoderTranslation.fromProto(
+ pProto
+ .getComponents()
+ .getCodersOrThrow(
+ pProto
+ .getComponents()
+ .getWindowingStrategiesOrThrow(
+ inputPCollection.getWindowingStrategyId())
+ .getWindowCoderId()),
+ rehydratedComponents));
String outputPCollectionId = pTransform.getOutputsOrThrow("output");
ImmutableMap<StateKey, ByteString> stateData =
ImmutableMap.of(
multimapSideInputKey(singletonSideInputView.getTagInternal().getId(), ByteString.EMPTY),
- encode("3"));
+ encode("8"));
FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(stateData);
@@ -1131,8 +1217,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
new PTransformFunctionRegistry(
mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish");
List<ThrowingRunnable> teardownFunctions = new ArrayList<>();
- List<BundleApplication> primarySplits = new ArrayList<>();
- List<DelayedBundleApplication> residualSplits = new ArrayList<>();
+ BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
new FnApiDoFnRunner.Factory<>()
.createRunnerForPTransform(
@@ -1150,15 +1235,7 @@ public class FnApiDoFnRunnerTest implements Serializable {
startFunctionRegistry,
finishFunctionRegistry,
teardownFunctions::add,
- new BundleSplitListener() {
- @Override
- public void split(
- List<BundleApplication> primaryRoots,
- List<DelayedBundleApplication> residualRoots) {
- primarySplits.addAll(primaryRoots);
- residualSplits.addAll(residualRoots);
- }
- },
+ splitListener,
null /* bundleFinalizer */);
Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run();
@@ -1168,46 +1245,146 @@ public class FnApiDoFnRunnerTest implements Serializable {
FnDataReceiver<WindowedValue<?>> mainInput =
consumers.getMultiplexingConsumer(inputPCollectionId);
- mainInput.accept(
- valueInGlobalWindow(
- KV.of(
- KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0)));
- BundleApplication primaryRoot = Iterables.getOnlyElement(primarySplits);
- DelayedBundleApplication residualRoot = Iterables.getOnlyElement(residualSplits);
- assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
- assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
- assertEquals(
- ParDoTranslation.getMainInputName(pTransform), residualRoot.getApplication().getInputId());
- assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
- Instant expectedOutputWatermark =
- GlobalWindow.TIMESTAMP_MIN_VALUE.plus(
- 2); // side input upperBound is 3 hence we only process the first two elements
- assertEquals(
- ImmutableMap.of(
- "output",
- org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
- .setSeconds(expectedOutputWatermark.getMillis() / 1000)
- .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
- .build()),
- residualRoot.getApplication().getOutputWatermarksMap());
- primarySplits.clear();
- residualSplits.clear();
+ assertThat(mainInput, instanceOf(HandlesSplits.class));
+
+ {
+ mainInput.accept(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 10), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 5.0)));
+ // Since the side input upperBound is 8, we will process 5, 6, and 7 and then checkpoint.
+ // We expect that the watermark advances to MIN + 7.
+ // The primary should represent [5, 8) with the original watermark,
+ // while the residual should represent [8, 10) with the new MIN + 7
+ // watermark.
+ BundleApplication primaryRoot = Iterables.getOnlyElement(splitListener.getPrimaryRoots());
+ DelayedBundleApplication residualRoot =
+ Iterables.getOnlyElement(splitListener.getResidualRoots());
+ assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+ assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRoot.getApplication().getInputId());
+ assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("5", KV.of(new OffsetRange(5, 8), GlobalWindow.TIMESTAMP_MIN_VALUE)), 3.0)),
+ inputCoder.decode(primaryRoot.getElement().newInput()));
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of(
+ "5", KV.of(new OffsetRange(8, 10), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7))),
+ 2.0)),
+ inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+ Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7);
+ assertEquals(
+ ImmutableMap.of(
+ "output",
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+ .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+ .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+ .build()),
+ residualRoot.getApplication().getOutputWatermarksMap());
+ assertEquals(
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Duration.newBuilder()
+ .setSeconds(54)
+ .setNanos(321000000)
+ .build(),
+ residualRoot.getRequestedTimeDelay());
+ splitListener.clear();
+
+ mainInput.accept(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 2.0)));
+ assertThat(
+ mainOutputValues,
+ contains(
+ timestampedValueInGlobalWindow("5:5", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(5)),
+ timestampedValueInGlobalWindow("5:6", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(6)),
+ timestampedValueInGlobalWindow("5:7", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(7)),
+ timestampedValueInGlobalWindow("2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
+ timestampedValueInGlobalWindow("2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))));
+ assertTrue(splitListener.getPrimaryRoots().isEmpty());
+ assertTrue(splitListener.getResidualRoots().isEmpty());
+ mainOutputValues.clear();
+ }
- mainInput.accept(
- valueInGlobalWindow(
- KV.of(
- KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0)));
- assertThat(
- mainOutputValues,
- contains(
- timestampedValueInGlobalWindow("5:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
- timestampedValueInGlobalWindow("5:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)),
- timestampedValueInGlobalWindow("5:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
- timestampedValueInGlobalWindow("2:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
- timestampedValueInGlobalWindow("2:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1))));
- assertTrue(primarySplits.isEmpty());
- assertTrue(residualSplits.isEmpty());
- mainOutputValues.clear();
+ {
+ // Setup and launch the trySplit thread.
+ ExecutorService executorService = Executors.newSingleThreadExecutor();
+ Future<HandlesSplits.SplitResult> trySplitFuture =
+ executorService.submit(
+ () -> {
+ try {
+ doFn.waitForSplitElementToBeProcessed();
+ return ((HandlesSplits) mainInput).trySplit(0);
+ } finally {
+ doFn.releaseWaitingProcessElementThread();
+ }
+ });
+
+ mainInput.accept(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+ 2.0)));
+ // Since the SPLIT_ELEMENT is 3, we will process 0, 1, 2, and 3 and then be split.
+ // We expect that the watermark only advances to MIN + 2, since the split happens before the
+ // manual watermark estimator is updated for the split element. The primary represents [0, 4)
+ // with the original watermark while the residual represents [4, 5) with the MIN + 2 watermark.
+ assertThat(
+ mainOutputValues,
+ contains(
+ timestampedValueInGlobalWindow("7:0", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(0)),
+ timestampedValueInGlobalWindow("7:1", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(1)),
+ timestampedValueInGlobalWindow("7:2", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2)),
+ timestampedValueInGlobalWindow("7:3", GlobalWindow.TIMESTAMP_MIN_VALUE.plus(3))));
+
+ HandlesSplits.SplitResult trySplitResult = trySplitFuture.get();
+ BundleApplication primaryRoot = trySplitResult.getPrimaryRoot();
+ DelayedBundleApplication residualRoot = trySplitResult.getResidualRoot();
+ assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
+ assertEquals(TEST_TRANSFORM_ID, primaryRoot.getTransformId());
+ assertEquals(
+ ParDoTranslation.getMainInputName(pTransform),
+ residualRoot.getApplication().getInputId());
+ assertEquals(TEST_TRANSFORM_ID, residualRoot.getApplication().getTransformId());
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of("7", KV.of(new OffsetRange(0, 4), GlobalWindow.TIMESTAMP_MIN_VALUE)), 4.0)),
+ inputCoder.decode(primaryRoot.getElement().newInput()));
+ assertEquals(
+ valueInGlobalWindow(
+ KV.of(
+ KV.of(
+ "7", KV.of(new OffsetRange(4, 5), GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2))),
+ 1.0)),
+ inputCoder.decode(residualRoot.getApplication().getElement().newInput()));
+ Instant expectedOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE.plus(2);
+ assertEquals(
+ ImmutableMap.of(
+ "output",
+ org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
+ .setSeconds(expectedOutputWatermark.getMillis() / 1000)
+ .setNanos((int) (expectedOutputWatermark.getMillis() % 1000) * 1000000)
+ .build()),
+ residualRoot.getApplication().getOutputWatermarksMap());
+ // We expect 0 resume delay.
+ assertEquals(
+ residualRoot.getRequestedTimeDelay().getDefaultInstanceForType(),
+ residualRoot.getRequestedTimeDelay());
+ // We don't expect the outputs to go to the SDK-initiated checkpointing listener.
+ assertTrue(splitListener.getPrimaryRoots().isEmpty());
+ assertTrue(splitListener.getResidualRoots().isEmpty());
+ mainOutputValues.clear();
+ executorService.shutdown();
+ }
Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
assertThat(mainOutputValues, empty());
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
new file mode 100644
index 0000000..00b4e15
--- /dev/null
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/BundleSplitListenerTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.control;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.empty;
+import static org.junit.Assert.assertThat;
+
+import java.util.Arrays;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.BundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link BundleSplitListener}. */
+@RunWith(JUnit4.class)
+public class BundleSplitListenerTest {
+ private static final BundleApplication TEST_PRIMARY_1 =
+ BundleApplication.newBuilder().setInputId("primary1").build();
+ private static final BundleApplication TEST_PRIMARY_2 =
+ BundleApplication.newBuilder().setInputId("primary2").build();
+ private static final DelayedBundleApplication TEST_RESIDUAL_1 =
+ DelayedBundleApplication.newBuilder()
+ .setApplication(BundleApplication.newBuilder().setInputId("residual1").build())
+ .build();
+ private static final DelayedBundleApplication TEST_RESIDUAL_2 =
+ DelayedBundleApplication.newBuilder()
+ .setApplication(BundleApplication.newBuilder().setInputId("residual2").build())
+ .build();
+
+ @Test
+ public void testInMemory() {
+ BundleSplitListener.InMemory splitListener = BundleSplitListener.InMemory.create();
+ splitListener.split(Arrays.asList(TEST_PRIMARY_1), Arrays.asList(TEST_RESIDUAL_1));
+ splitListener.split(Arrays.asList(TEST_PRIMARY_2), Arrays.asList(TEST_RESIDUAL_2));
+ assertThat(splitListener.getPrimaryRoots(), contains(TEST_PRIMARY_1, TEST_PRIMARY_2));
+ assertThat(splitListener.getResidualRoots(), contains(TEST_RESIDUAL_1, TEST_RESIDUAL_2));
+ splitListener.clear();
+ assertThat(splitListener.getPrimaryRoots(), empty());
+ assertThat(splitListener.getResidualRoots(), empty());
+ }
+}
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index ff97ed9..982fd96 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -92,7 +92,6 @@ import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Message;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.joda.time.Instant;
import org.junit.Before;
@@ -196,8 +195,8 @@ public class ProcessBundleHandlerTest {
}
@Override
- Multimap<String, BeamFnApi.DelayedBundleApplication> getAllResiduals() {
- return wrappedBundleProcessor.getAllResiduals();
+ BundleSplitListener.InMemory getSplitListener() {
+ return wrappedBundleProcessor.getSplitListener();
}
@Override
@@ -535,7 +534,7 @@ public class ProcessBundleHandlerTest {
public void testBundleProcessorReset() {
PTransformFunctionRegistry startFunctionRegistry = mock(PTransformFunctionRegistry.class);
PTransformFunctionRegistry finishFunctionRegistry = mock(PTransformFunctionRegistry.class);
- Multimap<String, BeamFnApi.DelayedBundleApplication> allResiduals = mock(Multimap.class);
+ BundleSplitListener.InMemory splitListener = mock(BundleSplitListener.InMemory.class);
Collection<CallbackRegistration> bundleFinalizationCallbacks = mock(Collection.class);
PCollectionConsumerRegistry pCollectionConsumerRegistry =
mock(PCollectionConsumerRegistry.class);
@@ -549,7 +548,7 @@ public class ProcessBundleHandlerTest {
startFunctionRegistry,
finishFunctionRegistry,
new ArrayList<>(),
- allResiduals,
+ splitListener,
pCollectionConsumerRegistry,
metricsContainerRegistry,
stateTracker,
@@ -560,7 +559,7 @@ public class ProcessBundleHandlerTest {
bundleProcessor.reset();
verify(startFunctionRegistry, times(1)).reset();
verify(finishFunctionRegistry, times(1)).reset();
- verify(allResiduals, times(1)).clear();
+ verify(splitListener, times(1)).clear();
verify(pCollectionConsumerRegistry, times(1)).reset();
verify(metricsContainerRegistry, times(1)).reset();
verify(stateTracker, times(1)).reset();