You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bo...@apache.org on 2020/11/11 23:00:56 UTC
[beam] branch master updated: Add sdf initiated checkpoint support
to portable Flink.
This is an automated email from the ASF dual-hosted git repository.
boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 28914c2 Add sdf initiated checkpoint support to portable Flink.
new cf6fd1c Merge pull request #13105 from [BEAM-10940] Add sdf initiated checkpoint support to portable Flink.
28914c2 is described below
commit 28914c2679feaae8bf00955229f64bb46d3970cd
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Tue Oct 13 19:37:26 2020 -0700
Add sdf initiated checkpoint support to portable Flink.
---
runners/flink/job-server/flink_job_server.gradle | 10 +-
.../FlinkBatchPortablePipelineTranslator.java | 3 +-
.../FlinkStreamingPortablePipelineTranslator.java | 43 ++-
.../functions/FlinkExecutableStageFunction.java | 84 +++++-
.../wrappers/streaming/DoFnOperator.java | 31 ++-
.../streaming/ExecutableStageDoFnOperator.java | 288 ++++++++++++++++++---
.../streaming/SdfByteBufferKeySelector.java | 61 +++++
.../FlinkExecutableStageFunctionTest.java | 14 +-
.../streaming/ExecutableStageDoFnOperatorTest.java | 18 +-
.../control/BundleCheckpointHandlers.java | 142 ++++++++++
.../control/DefaultJobBundleFactory.java | 6 +-
.../fnexecution/control/SdkHarnessClient.java | 17 +-
.../SingleEnvironmentInstanceJobBundleFactory.java | 6 +-
.../fnexecution/control/StageBundleFactory.java | 35 ++-
.../fnexecution/control/RemoteExecutionTest.java | 1 +
.../SparkExecutableStageFunctionTest.java | 4 +-
.../runners/portability/flink_runner_test.py | 3 -
17 files changed, 673 insertions(+), 93 deletions(-)
diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle
index 24b0838..1ca09a1 100644
--- a/runners/flink/job-server/flink_job_server.gradle
+++ b/runners/flink/job-server/flink_job_server.gradle
@@ -146,8 +146,6 @@ def portableValidatesRunnerTask(String name, Boolean streaming, Boolean checkpoi
if (streaming && checkpointing) {
includeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer'
excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
- // TODO(BEAM-10940): Enable this test suite once we have support.
- excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
// TestStreamSource does not support checkpointing
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
} else {
@@ -169,18 +167,16 @@ def portableValidatesRunnerTask(String name, Boolean streaming, Boolean checkpoi
excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer'
excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
if (streaming) {
+ excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithOutputTimestamp'
} else {
+ excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
+ excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedPCollections'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
}
- //SplitableDoFnTests
- excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
- excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
- excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
-
}
},
testFilter: {
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index 25a0ed0..32726f3 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -342,7 +342,8 @@ public class FlinkBatchPortablePipelineTranslator
context.getJobInfo(),
outputMap,
FlinkExecutableStageContextFactory.getInstance(),
- getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder());
+ getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder(),
+ windowedInputCoder);
final String operatorName = generateNameFromStagePayload(stagePayload);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
index 2112941..c1d2583 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
@@ -56,6 +56,7 @@ import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
import org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator;
import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.SdfByteBufferKeySelector;
import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItem;
import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder;
import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator;
@@ -682,10 +683,20 @@ public class FlinkStreamingPortablePipelineTranslator
final boolean stateful =
stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
+ final boolean hasSdfProcessFn =
+ stagePayload.getComponents().getTransformsMap().values().stream()
+ .anyMatch(
+ pTransform ->
+ pTransform
+ .getSpec()
+ .getUrn()
+ .equals(
+ PTransformTranslation
+ .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
Coder keyCoder = null;
KeySelector<WindowedValue<InputT>, ?> keySelector = null;
- if (stateful) {
- // Stateful stages are only allowed of KV input
+ if (stateful || hasSdfProcessFn) {
+ // Stateful/SDF stages are only allowed of KV input.
Coder valueCoder =
((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
if (!(valueCoder instanceof KvCoder)) {
@@ -696,10 +707,28 @@ public class FlinkStreamingPortablePipelineTranslator
inputPCollectionId,
valueCoder.getClass().getSimpleName()));
}
- keyCoder = ((KvCoder) valueCoder).getKeyCoder();
- keySelector =
- new KvToByteBufferKeySelector(
- keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
+ if (stateful) {
+ keyCoder = ((KvCoder) valueCoder).getKeyCoder();
+ keySelector =
+ new KvToByteBufferKeySelector(
+ keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
+ } else {
+ // For an SDF, we know that the input element should be
+ // KV<KV<element, KV<restriction, watermarkState>>, size>. We are going to use the element
+ // as the key.
+ if (!(((KvCoder) valueCoder).getKeyCoder() instanceof KvCoder)) {
+ throw new IllegalStateException(
+ String.format(
+ Locale.ENGLISH,
+ "The element coder for splittable DoFn '%s' must be KVCoder(KvCoder, DoubleCoder) but is: %s",
+ inputPCollectionId,
+ valueCoder.getClass().getSimpleName()));
+ }
+ keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder();
+ keySelector =
+ new SdfByteBufferKeySelector(
+ keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
+ }
inputDataStream = inputDataStream.keyBy(keySelector);
}
@@ -738,7 +767,7 @@ public class FlinkStreamingPortablePipelineTranslator
} else {
DataStream<RawUnionValue> sideInputStream =
transformedSideInputs.unionedSideInputs.broadcast();
- if (stateful) {
+ if (stateful || hasSdfProcessFn) {
// We have to manually construct the two-input transform because we're not
// allowed to have only one input keyed, normally. Since Flink 1.5.0 it's
// possible to use the Broadcast State Pattern which provides a more elegant
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index cb5a2a6..6c9c984 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -18,7 +18,9 @@
package org.apache.beam.runners.flink.translation.functions;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.EnumMap;
+import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.annotation.concurrent.GuardedBy;
@@ -26,13 +28,19 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressRespo
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.InMemoryStateInternals;
import org.apache.beam.runners.core.InMemoryTimerInternals;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.Timer;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -96,6 +104,7 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
private final Map<String, Integer> outputMap;
private final FlinkExecutableStageContextFactory contextFactory;
private final Coder windowCoder;
+ private final Coder<WindowedValue<InputT>> inputCoder;
// Unique name for namespacing metrics
private final String stepName;
@@ -107,6 +116,9 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
private transient StageBundleFactory stageBundleFactory;
private transient BundleProgressHandler progressHandler;
private transient BundleFinalizationHandler finalizationHandler;
+ private transient BundleCheckpointHandler bundleCheckpointHandler;
+ private transient InMemoryTimerInternals sdfTimerInternals;
+ private transient StateInternals sdfStateInternals;
// Only initialized when the ExecutableStage is stateful
private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory;
private transient ExecutableStage executableStage;
@@ -120,7 +132,8 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
JobInfo jobInfo,
Map<String, Integer> outputMap,
FlinkExecutableStageContextFactory contextFactory,
- Coder windowCoder) {
+ Coder windowCoder,
+ Coder<WindowedValue<InputT>> inputCoder) {
this.stepName = stepName;
this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
this.stagePayload = stagePayload;
@@ -128,6 +141,7 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
this.outputMap = outputMap;
this.contextFactory = contextFactory;
this.windowCoder = windowCoder;
+ this.inputCoder = inputCoder;
}
@Override
@@ -165,6 +179,35 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
throw new UnsupportedOperationException(
"Portable Flink runner doesn't support bundle finalization in batch mode. For more details, please refer to https://issues.apache.org/jira/browse/BEAM-11021.");
};
+ bundleCheckpointHandler = getBundleCheckpointHandler(executableStage);
+ }
+
+ private boolean hasSDF(ExecutableStage executableStage) {
+ return executableStage.getTransforms().stream()
+ .anyMatch(
+ pTransformNode ->
+ pTransformNode
+ .getTransform()
+ .getSpec()
+ .getUrn()
+ .equals(
+ PTransformTranslation
+ .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
+ }
+
+  /**
+   * Builds the {@link BundleCheckpointHandler} passed to every bundle of this stage.
+   *
+   * <p>For a stage without a splittable DoFn, self-checkpointing is rejected with an {@link
+   * UnsupportedOperationException}. In that case both {@code sdfTimerInternals} and {@code
+   * sdfStateInternals} are cleared to {@code null}, because the bundle-processing code checks
+   * these fields for {@code null} to decide whether to drain SDF residuals.
+   *
+   * <p>For an SDF stage, residuals are recorded in a fresh {@link InMemoryTimerInternals} and
+   * {@link InMemoryStateInternals}. The same instances are returned for every key.
+   */
+  private BundleCheckpointHandler getBundleCheckpointHandler(ExecutableStage executableStage) {
+    if (!hasSDF(executableStage)) {
+      // Clear BOTH SDF fields so that every later null check consistently skips SDF handling.
+      sdfTimerInternals = null;
+      sdfStateInternals = null;
+      return response -> {
+        throw new UnsupportedOperationException(
+            "Self-checkpoint is only supported on splittable DoFn.");
+      };
+    }
+    sdfTimerInternals = new InMemoryTimerInternals();
+    sdfStateInternals = InMemoryStateInternals.forKey("sdf_state");
+    // Key-independent factories: all residuals share one in-memory timer/state store.
+    return new BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler(
+        key -> sdfTimerInternals, key -> sdfStateInternals, inputCoder, windowCoder);
}
private StateRequestHandler getStateRequestHandler(
@@ -210,11 +253,47 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
throws Exception {
ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
+ if (sdfStateInternals != null) {
+ sdfTimerInternals.advanceProcessingTime(Instant.now());
+ sdfTimerInternals.advanceSynchronizedProcessingTime(Instant.now());
+ }
try (RemoteBundle bundle =
stageBundleFactory.getBundle(
- receiverFactory, stateRequestHandler, progressHandler, finalizationHandler)) {
+ receiverFactory,
+ stateRequestHandler,
+ progressHandler,
+ finalizationHandler,
+ bundleCheckpointHandler)) {
processElements(iterable, bundle);
}
+ if (sdfTimerInternals != null) {
+ // Finally, advance the processing time to infinity to fire any timers.
+ sdfTimerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+ sdfTimerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+ // Now we fire the SDF timers and process elements generated by timers.
+ while (sdfTimerInternals.hasPendingTimers()) {
+ try (RemoteBundle bundle =
+ stageBundleFactory.getBundle(
+ receiverFactory,
+ stateRequestHandler,
+ progressHandler,
+ finalizationHandler,
+ bundleCheckpointHandler)) {
+ List<WindowedValue<InputT>> residuals = new ArrayList<>();
+ TimerInternals.TimerData timer;
+ while ((timer = sdfTimerInternals.removeNextProcessingTimer()) != null) {
+ WindowedValue stateValue =
+ sdfStateInternals
+ .state(timer.getNamespace(), StateTags.value(timer.getTimerId(), inputCoder))
+ .read();
+
+ residuals.add(stateValue);
+ }
+ processElements(residuals, bundle);
+ }
+ }
+ }
}
/** For stateful and timer processing via a GroupReduceFunction. */
@@ -267,7 +346,6 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
try (RemoteBundle bundle =
stageBundleFactory.getBundle(
receiverFactory, timerReceiverFactory, stateRequestHandler, progressHandler)) {
-
PipelineTranslatorUtils.fireEligibleTimers(
timerInternals, bundle.getTimerReceivers(), currentTimerKey);
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index 06c9398..09eacd0 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -64,6 +64,7 @@ import org.apache.beam.runners.flink.translation.utils.Workarounds;
import org.apache.beam.runners.flink.translation.wrappers.streaming.stableinput.BufferingDoFnRunner;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StructuredCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
@@ -1332,7 +1333,7 @@ public class DoFnOperator<InputT, OutputT>
keyedStateBackend.setCurrentKey(internalTimer.getKey());
TimerData timer = internalTimer.getNamespace();
checkInvokeStartBundle();
- fireTimer(timer);
+ fireTimerInternal((ByteBuffer) internalTimer.getKey(), timer);
}
}
@@ -1347,16 +1348,13 @@ public class DoFnOperator<InputT, OutputT>
}
}
- private void onRemovedEventTimer(TimerData removedTimer) {
+ /** Holds the watermark when there is an sdf timer. */
+ private void onNewSdfTimer(TimerData newTimer) {
Preconditions.checkState(
- removedTimer.getDomain() == TimeDomain.EVENT_TIME,
- "Timer with id %s is not an event time timer!",
- removedTimer.getTimerId());
- // Remove the first occurrence of the output timestamp, if cached
- // Note: There may be duplicate timestamps from other timers, that's ok.
- if (timerUsesOutputTimestamp(removedTimer)) {
- keyedStateInternals.removeWatermarkHoldUsage(removedTimer.getOutputTimestamp());
- }
+ StateAndTimerBundleCheckpointHandler.isSdfTimer(newTimer.getTimerId()));
+ // An SDF timer should hold the watermark for further output.
+ Preconditions.checkState(timerUsesOutputTimestamp(newTimer));
+ keyedStateInternals.addWatermarkHoldUsage(newTimer.getOutputTimestamp());
}
private void populateOutputTimestampQueue() {
@@ -1369,7 +1367,8 @@ public class DoFnOperator<InputT, OutputT>
keyedStateBackend.setCurrentKey(key);
try {
for (TimerData timerData : pendingTimersById.values()) {
- if (timerData.getDomain() == TimeDomain.EVENT_TIME) {
+ if (timerData.getDomain() == TimeDomain.EVENT_TIME
+ || StateAndTimerBundleCheckpointHandler.isSdfTimer(timerData.getTimerId())) {
if (timerUsesOutputTimestamp(timerData)) {
keyedStateInternals.addWatermarkHoldUsage(timerData.getOutputTimestamp());
}
@@ -1441,6 +1440,9 @@ public class DoFnOperator<InputT, OutputT>
case PROCESSING_TIME:
case SYNCHRONIZED_PROCESSING_TIME:
timerService.registerProcessingTimeTimer(timer, adjustTimestampForFlink(time));
+ if (StateAndTimerBundleCheckpointHandler.isSdfTimer(timer.getTimerId())) {
+ onNewSdfTimer(timer);
+ }
break;
default:
throw new UnsupportedOperationException("Unsupported time domain: " + timer.getDomain());
@@ -1466,8 +1468,11 @@ public class DoFnOperator<InputT, OutputT>
void onFiredOrDeletedTimer(TimerData timer) {
try {
pendingTimersById.remove(getContextTimerId(timer.getTimerId(), timer.getNamespace()));
- if (timer.getDomain() == TimeDomain.EVENT_TIME) {
- onRemovedEventTimer(timer);
+ if (timer.getDomain() == TimeDomain.EVENT_TIME
+ || StateAndTimerBundleCheckpointHandler.isSdfTimer(timer.getTimerId())) {
+ if (timerUsesOutputTimestamp(timer)) {
+ keyedStateInternals.removeWatermarkHoldUsage(timer.getOutputTimestamp());
+ }
}
} catch (Exception e) {
throw new RuntimeException("Failed to cleanup pending timers state.", e);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index a27c966..d479a3c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -40,6 +40,7 @@ import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase;
@@ -47,6 +48,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.LateDataUtils;
import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces;
import org.apache.beam.runners.core.StateTag;
@@ -54,6 +56,8 @@ import org.apache.beam.runners.core.StateTags;
import org.apache.beam.runners.core.StatefulDoFnRunner;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternalsFactory;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.Timer;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
@@ -61,6 +65,10 @@ import org.apache.beam.runners.core.construction.graph.UserStateReference;
import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory;
import org.apache.beam.runners.flink.translation.functions.FlinkStreamingSideInputHandlerFactory;
import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandlers;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandlers.InMemoryFinalizer;
@@ -80,6 +88,8 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.function.ThrowingFunction;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.State;
+import org.apache.beam.sdk.state.StateContext;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
@@ -141,10 +151,15 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
private final boolean isStateful;
+ private final Coder windowCoder;
+ private final Coder<WindowedValue<InputT>> inputCoder;
+
private transient ExecutableStageContext stageContext;
private transient StateRequestHandler stateRequestHandler;
private transient BundleProgressHandler progressHandler;
private transient InMemoryFinalizer finalizationHandler;
+ private transient BundleCheckpointHandler checkpointHandler;
+ private transient boolean hasSdfProcessFn;
private transient StageBundleFactory stageBundleFactory;
private transient ExecutableStage executableStage;
private transient SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner;
@@ -203,6 +218,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
this.outputMap = outputMap;
this.sideInputIds = sideInputIds;
this.stateBackendLock = new ReentrantLock();
+ this.windowCoder = (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder();
+ this.inputCoder = windowedInputCoder;
this.pipelineOptions = new SerializablePipelineOptions(options);
}
@@ -214,6 +231,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
@Override
public void open() throws Exception {
executableStage = ExecutableStage.fromPayload(payload);
+ hasSdfProcessFn = hasSDF(executableStage);
initializeUserState(executableStage, getKeyedStateBackend(), pipelineOptions);
// TODO: Wire this into the distributed cache and make it pluggable.
// TODO: Do we really want this layer of indirection when accessing the stage bundle factory?
@@ -244,6 +262,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
BundleFinalizationHandlers.inMemoryFinalizer(
stageBundleFactory.getInstructionRequestHandler());
+ checkpointHandler = getBundleCheckpointHandler(hasSdfProcessFn);
+
minEventTimeTimerTimestampInCurrentBundle = Long.MAX_VALUE;
minEventTimeTimerTimestampInLastBundle = Long.MAX_VALUE;
super.setPreBundleCallback(this::preBundleStartCallback);
@@ -259,6 +279,34 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
super.notifyCheckpointComplete(checkpointId);
}
+  /**
+   * Selects the checkpoint handler for this operator's bundles.
+   *
+   * <p>When the stage contains a splittable DoFn, residuals are handled by a {@link
+   * StateAndTimerBundleCheckpointHandler}. That handler is backed by this operator's keyed
+   * state and timer internals through {@link SdfFlinkStateInternalsFactory} and {@link
+   * SdfFlinkTimerInternalsFactory}. Otherwise any self-checkpoint attempt fails with an {@link
+   * UnsupportedOperationException}.
+   */
+  private BundleCheckpointHandler getBundleCheckpointHandler(boolean hasSDF) {
+    if (hasSDF) {
+      TimerInternalsFactory<InputT> timerFactory = new SdfFlinkTimerInternalsFactory();
+      StateInternalsFactory<InputT> stateFactory = new SdfFlinkStateInternalsFactory();
+      return new BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler(
+          timerFactory, stateFactory, inputCoder, windowCoder);
+    }
+    return response -> {
+      throw new UnsupportedOperationException(
+          "Self-checkpoint is only supported on splittable DoFn.");
+    };
+  }
+
+ private boolean hasSDF(ExecutableStage executableStage) {
+ return executableStage.getTransforms().stream()
+ .anyMatch(
+ pTransformNode ->
+ pTransformNode
+ .getTransform()
+ .getSpec()
+ .getUrn()
+ .equals(
+ PTransformTranslation
+ .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
+ }
+
private StateRequestHandler getStateRequestHandler(ExecutableStage executableStage) {
final StateRequestHandler sideInputStateHandler;
@@ -491,6 +539,148 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
}
}
+  /**
+   * A {@link TimerInternalsFactory} that gives the operator's {@link
+   * StateAndTimerBundleCheckpointHandler} per-key timer internals.
+   *
+   * <p>Those timer internals are used to reschedule {@link
+   * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}s.
+   */
+  class SdfFlinkTimerInternalsFactory implements TimerInternalsFactory<InputT> {
+    @Override
+    public TimerInternals timerInternalsForKey(InputT key) {
+      final ByteBuffer encodedKey;
+      try {
+        // The keyed state backend is addressed by the ByteBuffer produced by the key selector.
+        encodedKey = (ByteBuffer) keySelector.getKey(WindowedValue.valueInGlobalWindow(key));
+      } catch (Exception e) {
+        throw new RuntimeException("Couldn't get a timer internals", e);
+      }
+      return new SdfFlinkTimerInternals(encodedKey);
+    }
+  }
+
+ /**
+ * A {@link TimerInternals} for rescheduling {@link
+ * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}.
+ */
+ class SdfFlinkTimerInternals implements TimerInternals {
+ private final ByteBuffer key;
+
+ SdfFlinkTimerInternals(ByteBuffer key) {
+ this.key = key;
+ }
+
+ @Override
+ public void setTimer(
+ StateNamespace namespace,
+ String timerId,
+ String timerFamilyId,
+ Instant target,
+ Instant outputTimestamp,
+ TimeDomain timeDomain) {
+ setTimer(
+ TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain));
+ }
+
+ @Override
+ public void setTimer(TimerData timerData) {
+ try {
+ try (Locker locker = Locker.locked(stateBackendLock)) {
+ getKeyedStateBackend().setCurrentKey(key);
+ timerInternals.setTimer(timerData);
+ minEventTimeTimerTimestampInCurrentBundle =
+ Math.min(
+ minEventTimeTimerTimestampInCurrentBundle,
+ adjustTimestampForFlink(timerData.getOutputTimestamp().getMillis()));
+ }
+ } catch (Exception e) {
+ throw new RuntimeException("Couldn't set timer", e);
+ }
+ }
+
+ @Override
+ public void deleteTimer(StateNamespace namespace, String timerId, TimeDomain timeDomain) {
+ throw new UnsupportedOperationException(
+ "It is not expected to use SdfFlinkTimerInternals to delete a timer");
+ }
+
+ @Override
+ public void deleteTimer(StateNamespace namespace, String timerId, String timerFamilyId) {
+ throw new UnsupportedOperationException(
+ "It is not expected to use SdfFlinkTimerInternals to delete a timer");
+ }
+
+ @Override
+ public void deleteTimer(TimerData timerKey) {
+ throw new UnsupportedOperationException(
+ "It is not expected to use SdfFlinkTimerInternals to delete a timer");
+ }
+
+ @Override
+ public Instant currentProcessingTime() {
+ return timerInternals.currentProcessingTime();
+ }
+
+ @Override
+ public @Nullable Instant currentSynchronizedProcessingTime() {
+ return timerInternals.currentSynchronizedProcessingTime();
+ }
+
+ @Override
+ public Instant currentInputWatermarkTime() {
+ return timerInternals.currentInputWatermarkTime();
+ }
+
+ @Override
+ public @Nullable Instant currentOutputWatermarkTime() {
+ return timerInternals.currentOutputWatermarkTime();
+ }
+ }
+
+  /**
+   * A {@link StateInternalsFactory} that gives the operator's {@link
+   * StateAndTimerBundleCheckpointHandler} per-key state internals.
+   *
+   * <p>Those state internals hold {@link
+   * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}s.
+   */
+  class SdfFlinkStateInternalsFactory implements StateInternalsFactory<InputT> {
+    @Override
+    public StateInternals stateInternalsForKey(InputT key) {
+      final ByteBuffer encodedKey;
+      try {
+        // The keyed state backend is addressed by the ByteBuffer produced by the key selector.
+        encodedKey = (ByteBuffer) keySelector.getKey(WindowedValue.valueInGlobalWindow(key));
+      } catch (Exception e) {
+        throw new RuntimeException("Couldn't get a state internals", e);
+      }
+      return new SdfFlinkStateInternals(encodedKey);
+    }
+  }
+
+  /**
+   * A {@link StateInternals} bound to a single encoded key, used for keeping {@link
+   * DelayedBundleApplication}s as states.
+   *
+   * <p>Every lookup switches the keyed state backend to this key while holding {@code
+   * stateBackendLock}, then delegates to the operator's {@code keyedStateInternals}. The {@link
+   * StateContext} argument is ignored.
+   */
+  class SdfFlinkStateInternals implements StateInternals {
+
+    private final ByteBuffer key;
+
+    SdfFlinkStateInternals(ByteBuffer key) {
+      this.key = key;
+    }
+
+    @Override
+    public Object getKey() {
+      return key;
+    }
+
+    @Override
+    public <T extends State> T state(
+        StateNamespace namespace, StateTag<T> address, StateContext<?> c) {
+      try (Locker ignored = Locker.locked(stateBackendLock)) {
+        getKeyedStateBackend().setCurrentKey(key);
+        // NOTE(review): the returned state object is used after the lock is released. If it
+        // resolves the backend's current key lazily, callers must use it before the key is
+        // switched again; confirm against FlinkStateInternals.
+        return keyedStateInternals.state(namespace, address);
+      } catch (Exception e) {
+        // This is a state lookup, not a write, so the message says "get".
+        throw new RuntimeException("Couldn't get state", e);
+      }
+    }
+  }
+
@Override
protected void fireTimerInternal(ByteBuffer key, TimerInternals.TimerData timer) {
// We have to synchronize to ensure the state backend is not concurrently accessed by the state
@@ -509,6 +699,17 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
processWatermark1(Watermark.MAX_WATERMARK);
while (getCurrentOutputWatermark() < Watermark.MAX_WATERMARK.getTimestamp()) {
invokeFinishBundle();
+ if (hasSdfProcessFn) {
+ // Manually drain processing time timers since Flink will ignore pending
+ // processing-time timers when upstream operators have shut down and will also
+ // shut down this operator with pending processing-time timers.
+ // TODO(BEAM-11210, FLINK-18647): It doesn't work efficiently when the watermark of upstream
+ // advances
+ // to MAX_TIMESTAMP immediately.
+ if (numProcessingTimeTimers() > 0) {
+ timerInternals.processPendingProcessingTimeTimers();
+ }
+ }
}
super.close();
}
@@ -549,11 +750,14 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
stateRequestHandler,
progressHandler,
finalizationHandler,
+ checkpointHandler,
outputManager,
outputMap,
- (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder(),
+ windowCoder,
+ inputCoder,
this::setTimer,
- () -> FlinkKeyUtils.decodeKey(getCurrentKey(), keyCoder));
+ () -> FlinkKeyUtils.decodeKey(getCurrentKey(), keyCoder),
+ keyedStateInternals);
return ensureStateDoFnRunner(sdkHarnessRunner, payload, stepContext);
}
@@ -658,10 +862,12 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
private final StateRequestHandler stateRequestHandler;
private final BundleProgressHandler progressHandler;
private final BundleFinalizationHandler finalizationHandler;
+ private final BundleCheckpointHandler checkpointHandler;
private final BufferedOutputManager<OutputT> outputManager;
private final Map<String, TupleTag<?>> outputMap;
-
+ private final FlinkStateInternals<?> keyedStateInternals;
private final Coder<BoundedWindow> windowCoder;
+ private final Coder<WindowedValue<InputT>> residualCoder;
private final BiConsumer<Timer<?>, TimerInternals.TimerData> timerRegistration;
private final Supplier<Object> keyForTimer;
@@ -682,23 +888,29 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler,
BufferedOutputManager<OutputT> outputManager,
Map<String, TupleTag<?>> outputMap,
Coder<BoundedWindow> windowCoder,
+ Coder<WindowedValue<InputT>> residualCoder,
BiConsumer<Timer<?>, TimerInternals.TimerData> timerRegistration,
- Supplier<Object> keyForTimer) {
+ Supplier<Object> keyForTimer,
+ FlinkStateInternals<?> keyedStateInternals) {
this.doFn = doFn;
this.stageBundleFactory = stageBundleFactory;
this.stateRequestHandler = stateRequestHandler;
this.progressHandler = progressHandler;
this.finalizationHandler = finalizationHandler;
+ this.checkpointHandler = checkpointHandler;
this.outputManager = outputManager;
this.outputMap = outputMap;
this.timerRegistration = timerRegistration;
this.keyForTimer = keyForTimer;
this.windowCoder = windowCoder;
+ this.residualCoder = residualCoder;
this.outputQueue = new LinkedBlockingQueue<>();
+ this.keyedStateInternals = keyedStateInternals;
}
@Override
@@ -723,7 +935,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
timerReceiverFactory,
stateRequestHandler,
progressHandler,
- finalizationHandler);
+ finalizationHandler,
+ checkpointHandler);
mainInputReceiver = Iterables.getOnlyElement(remoteBundle.getInputReceivers().values());
} catch (Exception e) {
throw new RuntimeException("Failed to start remote bundle", e);
@@ -753,35 +966,42 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
Object timerKey = keyForTimer.get();
Preconditions.checkNotNull(timerKey, "Key for timer needs to be set before calling onTimer");
Preconditions.checkNotNull(remoteBundle, "Call to onTimer outside of a bundle");
- KV<String, String> transformAndTimerFamilyId =
- TimerReceiverFactory.decodeTimerDataTimerId(timerId);
- LOG.debug(
- "timer callback: {} {} {} {} {}",
- transformAndTimerFamilyId.getKey(),
- transformAndTimerFamilyId.getValue(),
- window,
- timestamp,
- timeDomain);
- FnDataReceiver<Timer> timerReceiver =
- Preconditions.checkNotNull(
- remoteBundle.getTimerReceivers().get(transformAndTimerFamilyId),
- "No receiver found for timer %s %s",
- transformAndTimerFamilyId.getKey(),
- transformAndTimerFamilyId.getValue());
- Timer<?> timerValue =
- Timer.of(
- timerKey,
- "",
- Collections.singletonList(window),
- timestamp,
- outputTimestamp,
- // TODO: Support propagating the PaneInfo through.
- PaneInfo.NO_FIRING);
- try {
- timerReceiver.accept(timerValue);
- } catch (Exception e) {
- throw new RuntimeException(
- String.format(Locale.ENGLISH, "Failed to process timer %s", timerReceiver), e);
+ if (StateAndTimerBundleCheckpointHandler.isSdfTimer(timerId)) {
+ StateNamespace namespace = StateNamespaces.window(windowCoder, window);
+ WindowedValue stateValue =
+ keyedStateInternals.state(namespace, StateTags.value(timerId, residualCoder)).read();
+ processElement(stateValue);
+ } else {
+ KV<String, String> transformAndTimerFamilyId =
+ TimerReceiverFactory.decodeTimerDataTimerId(timerId);
+ LOG.debug(
+ "timer callback: {} {} {} {} {}",
+ transformAndTimerFamilyId.getKey(),
+ transformAndTimerFamilyId.getValue(),
+ window,
+ timestamp,
+ timeDomain);
+ FnDataReceiver<Timer> timerReceiver =
+ Preconditions.checkNotNull(
+ remoteBundle.getTimerReceivers().get(transformAndTimerFamilyId),
+ "No receiver found for timer %s %s",
+ transformAndTimerFamilyId.getKey(),
+ transformAndTimerFamilyId.getValue());
+ Timer<?> timerValue =
+ Timer.of(
+ timerKey,
+ "",
+ Collections.singletonList(window),
+ timestamp,
+ outputTimestamp,
+ // TODO: Support propagating the PaneInfo through.
+ PaneInfo.NO_FIRING);
+ try {
+ timerReceiver.accept(timerValue);
+ } catch (Exception e) {
+ throw new RuntimeException(
+ String.format(Locale.ENGLISH, "Failed to process timer %s", timerReceiver), e);
+ }
}
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java
new file mode 100644
index 0000000..29af81d
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming;
+
+import java.nio.ByteBuffer;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+
+/**
+ * {@link KeySelector} for splittable DoFn inputs of the form {@code KV<KV<element,
+ * KV<restriction, watermarkState>>, size>}, where the size is a {@code Double}. Only the element
+ * is used as the key: it is encoded with the provided {@link Coder} and returned in a {@link
+ * ByteBuffer}, so that all key comparisons and hashing happen on the encoded form.
+ *
+ * <p>The restriction, watermark state and size are intentionally left out of the key. They change
+ * whenever a checkpoint happens, and Flink would treat the changed value as a new key; setting
+ * state and timers under such a new key may cause undefined behavior.
+ */
+public class SdfByteBufferKeySelector<K, V>
+    implements KeySelector<WindowedValue<KV<KV<K, V>, Double>>, ByteBuffer>,
+        ResultTypeQueryable<ByteBuffer> {
+
+  /** Coder used to encode the element into the key bytes. */
+  private final Coder<K> keyCoder;
+  /** Options handed to the {@link TypeInformation} returned by {@link #getProducedType()}. */
+  private final SerializablePipelineOptions pipelineOptions;
+
+  public SdfByteBufferKeySelector(Coder<K> keyCoder, SerializablePipelineOptions pipelineOptions) {
+    this.pipelineOptions = pipelineOptions;
+    this.keyCoder = keyCoder;
+  }
+
+  /** Pulls the element out of the sized element/restriction pair and encodes it as the key. */
+  @Override
+  public ByteBuffer getKey(WindowedValue<KV<KV<K, V>, Double>> value) {
+    KV<K, V> elementAndRestriction = value.getValue().getKey();
+    K element = elementAndRestriction.getKey();
+    return FlinkKeyUtils.encodeKey(element, keyCoder);
+  }
+
+  /** Keys are {@link ByteBuffer}s, described with {@code FlinkKeyUtils.ByteBufferCoder}. */
+  @Override
+  public TypeInformation<ByteBuffer> getProducedType() {
+    return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions.get());
+  }
+}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
index 58ac315..67b772d 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
@@ -34,6 +34,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.runners.core.construction.Timer;
import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -123,7 +124,8 @@ public class FlinkExecutableStageFunctionTest {
any(),
any(StateRequestHandler.class),
any(BundleProgressHandler.class),
- any(BundleFinalizationHandler.class)))
+ any(BundleFinalizationHandler.class),
+ any(BundleCheckpointHandler.class)))
.thenReturn(remoteBundle);
when(stageBundleFactory.getBundle(
any(),
@@ -148,7 +150,8 @@ public class FlinkExecutableStageFunctionTest {
any(),
any(StateRequestHandler.class),
any(BundleProgressHandler.class),
- any(BundleFinalizationHandler.class)))
+ any(BundleFinalizationHandler.class),
+ any(BundleCheckpointHandler.class)))
.thenReturn(bundle);
@SuppressWarnings("unchecked")
@@ -172,7 +175,8 @@ public class FlinkExecutableStageFunctionTest {
any(),
any(StateRequestHandler.class),
any(BundleProgressHandler.class),
- any(BundleFinalizationHandler.class)))
+ any(BundleFinalizationHandler.class),
+ any(BundleCheckpointHandler.class)))
.thenReturn(bundle);
@SuppressWarnings("unchecked")
@@ -213,7 +217,8 @@ public class FlinkExecutableStageFunctionTest {
TimerReceiverFactory timerReceiverFactory,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
- BundleFinalizationHandler finalizationHandler) {
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler) {
return new RemoteBundle() {
@Override
public String getId() {
@@ -333,6 +338,7 @@ public class FlinkExecutableStageFunctionTest {
jobInfo,
outputMap,
contextFactory,
+ null,
null);
function.setRuntimeContext(runtimeContext);
Whitebox.setInternalState(function, "stateRequestHandler", stateRequestHandler);
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index bc5c277..336b414 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -67,6 +67,7 @@ import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
import org.apache.beam.runners.flink.streaming.FlinkStateInternalsTest;
import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory;
import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -213,7 +214,7 @@ public class ExecutableStageDoFnOperatorTest {
@SuppressWarnings("unchecked")
RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
@SuppressWarnings("unchecked")
FnDataReceiver<WindowedValue<?>> receiver = Mockito.mock(FnDataReceiver.class);
@@ -240,7 +241,7 @@ public class ExecutableStageDoFnOperatorTest {
@SuppressWarnings("unchecked")
RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
@SuppressWarnings("unchecked")
FnDataReceiver<WindowedValue<?>> receiver = Mockito.mock(FnDataReceiver.class);
@@ -323,7 +324,8 @@ public class ExecutableStageDoFnOperatorTest {
TimerReceiverFactory timerReceiverFactory,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
- BundleFinalizationHandler finalizationHandler) {
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler) {
return new RemoteBundle() {
@Override
public String getId() {
@@ -458,7 +460,7 @@ public class ExecutableStageDoFnOperatorTest {
.put(KV.of("transform", "timer2"), Mockito.mock(FnDataReceiver.class))
.put(KV.of("transform", "timer3"), Mockito.mock(FnDataReceiver.class))
.build());
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
testHarness.open();
assertThat(
@@ -576,7 +578,7 @@ public class ExecutableStageDoFnOperatorTest {
ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
.put("input", Mockito.mock(FnDataReceiver.class))
.build());
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
testHarness.open();
testHarness.close();
@@ -624,7 +626,7 @@ public class ExecutableStageDoFnOperatorTest {
ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
.put("input", Mockito.mock(FnDataReceiver.class))
.build());
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
testHarness.open();
@@ -736,7 +738,7 @@ public class ExecutableStageDoFnOperatorTest {
@SuppressWarnings("unchecked")
RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
KV<String, String> timerInputKey = KV.of("transformId", "timerId");
AtomicBoolean timerInputReceived = new AtomicBoolean();
@@ -921,7 +923,7 @@ public class ExecutableStageDoFnOperatorTest {
ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
.put("input", Mockito.mock(FnDataReceiver.class))
.build());
- when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+ when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
testHarness.open();
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandlers.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandlers.java
new file mode 100644
index 0000000..6ed6127
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandlers.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateInternalsFactory;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaces;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternalsFactory;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Utility methods for creating {@link BundleCheckpointHandler}s. */
+@SuppressWarnings({
+  "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public class BundleCheckpointHandlers {
+
+  /**
+   * A {@link BundleCheckpointHandler} which uses {@link
+   * org.apache.beam.runners.core.TimerInternals.TimerData} and {@link
+   * org.apache.beam.sdk.state.ValueState} to reschedule {@link DelayedBundleApplication}.
+   *
+   * <p>For every residual root that carries an application, the decoded residual is written to
+   * value state in each of its windows and a processing-time timer is set under the same tag.
+   * Tags start with {@link #SDF_PREFIX}, which {@link #isSdfTimer(String)} uses to tell these
+   * timers apart from other timers.
+   */
+  public static class StateAndTimerBundleCheckpointHandler<T> implements BundleCheckpointHandler {
+    private static final Logger LOG =
+        LoggerFactory.getLogger(StateAndTimerBundleCheckpointHandler.class);
+    private final TimerInternalsFactory<T> timerInternalsFactory;
+    private final StateInternalsFactory<T> stateInternalsFactory;
+    private final Coder<WindowedValue<T>> residualCoder;
+    private final Coder windowCoder;
+    private final IdGenerator idGenerator = IdGenerators.incrementingLongs();
+    public static final String SDF_PREFIX = "sdf_checkpoint";
+
+    public StateAndTimerBundleCheckpointHandler(
+        TimerInternalsFactory<T> timerInternalsFactory,
+        StateInternalsFactory<T> stateInternalsFactory,
+        Coder<WindowedValue<T>> residualCoder,
+        Coder windowCoder) {
+      this.residualCoder = residualCoder;
+      this.windowCoder = windowCoder;
+      this.timerInternalsFactory = timerInternalsFactory;
+      this.stateInternalsFactory = stateInternalsFactory;
+    }
+
+    /**
+     * A helper function to help check whether the given timer is the timer which is set for
+     * rescheduling {@link DelayedBundleApplication}.
+     */
+    public static boolean isSdfTimer(String timerId) {
+      return timerId.startsWith(SDF_PREFIX);
+    }
+
+    private static String constructSdfCheckpointId(String id, int index) {
+      return SDF_PREFIX + ":" + id + ":" + index;
+    }
+
+    /**
+     * Converts a protobuf {@code seconds}/{@code nanos} pair, as carried by {@code Duration} and
+     * {@code Timestamp} messages, to milliseconds. Reading only {@code seconds} would drop the
+     * sub-second part, e.g. turning a requested 500ms resume delay into an immediate resume.
+     */
+    private static long toMillis(long seconds, int nanos) {
+      return seconds * 1000 + nanos / 1_000_000;
+    }
+
+    /**
+     * Reschedules every residual root in {@code response} that carries an application.
+     *
+     * <p>Each residual is decoded with the residual coder. For each of its windows, the residual
+     * is written to value state in that window's namespace and a processing-time timer is set
+     * under the same tag. The timer fires at the current wall-clock time plus the requested delay
+     * (if any); its output timestamp is the minimum reported output watermark, or {@link
+     * BoundedWindow#TIMESTAMP_MIN_VALUE} when no output watermarks are reported.
+     *
+     * @throws RuntimeException if decoding the residual or setting its timer/state fails
+     */
+    @Override
+    public void onCheckpoint(ProcessBundleResponse response) {
+      String id = idGenerator.getId();
+      for (int index = 0; index < response.getResidualRootsCount(); index++) {
+        DelayedBundleApplication residual = response.getResidualRoots(index);
+        if (!residual.hasApplication()) {
+          continue;
+        }
+        String tag = constructSdfCheckpointId(id, index);
+        try {
+          WindowedValue<T> stateValue =
+              CoderUtils.decodeFromByteArray(
+                  residualCoder, residual.getApplication().getElement().toByteArray());
+          TimerInternals timerInternals =
+              timerInternalsFactory.timerInternalsForKey(stateValue.getValue());
+          StateInternals stateInternals =
+              stateInternalsFactory.stateInternalsForKey(stateValue.getValue());
+          // Calculate the timestamp for the timer, keeping sub-second precision of the delay.
+          Instant timestamp = Instant.now();
+          if (residual.hasRequestedTimeDelay()) {
+            timestamp =
+                timestamp.plus(
+                    toMillis(
+                        residual.getRequestedTimeDelay().getSeconds(),
+                        residual.getRequestedTimeDelay().getNanos()));
+          }
+          // Calculate the watermark hold for the timer.
+          long outputTimestamp = BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis();
+          if (!residual.getApplication().getOutputWatermarksMap().isEmpty()) {
+            for (org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp outputWatermark :
+                residual.getApplication().getOutputWatermarksMap().values()) {
+              outputTimestamp =
+                  Math.min(
+                      outputTimestamp,
+                      toMillis(outputWatermark.getSeconds(), outputWatermark.getNanos()));
+            }
+          } else {
+            outputTimestamp = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis();
+          }
+          for (BoundedWindow window : stateValue.getWindows()) {
+            StateNamespace stateNamespace = StateNamespaces.window(windowCoder, window);
+            timerInternals.setTimer(
+                stateNamespace,
+                tag,
+                "",
+                timestamp,
+                Instant.ofEpochMilli(outputTimestamp),
+                TimeDomain.PROCESSING_TIME);
+            stateInternals
+                .state(stateNamespace, StateTags.value(tag, residualCoder))
+                .write(
+                    WindowedValue.of(
+                        stateValue.getValue(),
+                        stateValue.getTimestamp(),
+                        ImmutableList.of(window),
+                        stateValue.getPane()));
+          }
+        } catch (Exception e) {
+          throw new RuntimeException("Failed to set timer/state for the residual", e);
+        }
+      }
+    }
+  }
+}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
index 87d1e1e..a0d370e 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
@@ -459,7 +459,8 @@ public class DefaultJobBundleFactory implements JobBundleFactory {
TimerReceiverFactory timerReceiverFactory,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
- BundleFinalizationHandler finalizationHandler)
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler)
throws Exception {
// TODO: Consider having BundleProcessor#newBundle take in an OutputReceiverFactory rather
// than constructing the receiver map here. Every bundle factory will need this.
@@ -520,7 +521,8 @@ public class DefaultJobBundleFactory implements JobBundleFactory {
getTimerReceivers(currentClient.processBundleDescriptor, timerReceiverFactory),
stateRequestHandler,
progressHandler,
- finalizationHandler);
+ finalizationHandler,
+ checkpointHandler);
return new RemoteBundle() {
@Override
public String getId() {
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
index 9fcfb36..7f51081 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
@@ -192,19 +192,22 @@ public class SdkHarnessClient implements AutoCloseable {
Map<KV<String, String>, RemoteOutputReceiver<Timer<?>>> timerReceivers,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
- BundleFinalizationHandler finalizationHandler) {
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler) {
return newBundle(
outputReceivers,
timerReceivers,
stateRequestHandler,
progressHandler,
BundleSplitHandler.unsupported(),
- request -> {
- throw new UnsupportedOperationException(
- String.format(
- "The %s does not have a registered bundle checkpoint handler.",
- ActiveBundle.class.getSimpleName()));
- },
+ checkpointHandler == null
+ ? request -> {
+ throw new UnsupportedOperationException(
+ String.format(
+ "The %s does not have a registered bundle checkpoint handler.",
+ ActiveBundle.class.getSimpleName()));
+ }
+ : checkpointHandler,
finalizationHandler == null
? bundleId -> {
throw new UnsupportedOperationException(
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java
index af4cb52..6e696c0 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java
@@ -162,7 +162,8 @@ public class SingleEnvironmentInstanceJobBundleFactory implements JobBundleFacto
TimerReceiverFactory timerReceiverFactory,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
- BundleFinalizationHandler finalizationHandler) {
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler) {
Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
for (Map.Entry<String, Coder> remoteOutputCoder :
descriptor.getRemoteOutputCoders().entrySet()) {
@@ -195,7 +196,8 @@ public class SingleEnvironmentInstanceJobBundleFactory implements JobBundleFacto
timerReceivers,
stateRequestHandler,
progressHandler,
- finalizationHandler);
+ finalizationHandler,
+ checkpointHandler);
}
@Override
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java
index 3b0911e..ee2b8e5 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java
@@ -50,6 +50,22 @@ public interface StageBundleFactory extends AutoCloseable {
outputReceiverFactory, null, stateRequestHandler, progressHandler, finalizationHandler);
}
+ default RemoteBundle getBundle(
+ OutputReceiverFactory outputReceiverFactory,
+ StateRequestHandler stateRequestHandler,
+ BundleProgressHandler progressHandler,
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler)
+ throws Exception {
+ return getBundle(
+ outputReceiverFactory,
+ null,
+ stateRequestHandler,
+ progressHandler,
+ finalizationHandler,
+ checkpointHandler);
+ }
+
/** Get a new {@link RemoteBundle bundle} for processing the data in an executable stage. */
default RemoteBundle getBundle(
OutputReceiverFactory outputReceiverFactory,
@@ -61,12 +77,29 @@ public interface StageBundleFactory extends AutoCloseable {
outputReceiverFactory, timerReceiverFactory, stateRequestHandler, progressHandler, null);
}
- RemoteBundle getBundle(
+ default RemoteBundle getBundle(
OutputReceiverFactory outputReceiverFactory,
TimerReceiverFactory timerReceiverFactory,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
BundleFinalizationHandler finalizationHandler)
+ throws Exception {
+ return getBundle(
+ outputReceiverFactory,
+ timerReceiverFactory,
+ stateRequestHandler,
+ progressHandler,
+ finalizationHandler,
+ null);
+ }
+
+ RemoteBundle getBundle(
+ OutputReceiverFactory outputReceiverFactory,
+ TimerReceiverFactory timerReceiverFactory,
+ StateRequestHandler stateRequestHandler,
+ BundleProgressHandler progressHandler,
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler)
throws Exception;
ProcessBundleDescriptors.ExecutableProcessBundleDescriptor getProcessBundleDescriptor();
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 46a5d60..dcea277 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -1178,6 +1178,7 @@ public class RemoteExecutionTest implements Serializable {
timerReceivers,
StateRequestHandler.unsupported(),
BundleProgressHandler.ignored(),
+ null,
null)) {
Iterables.getOnlyElement(bundle.getInputReceivers().values())
.accept(valueInGlobalWindow(KV.of("X", "X")));
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
index 10d5835..97e05cf 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
@@ -40,6 +40,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.runners.core.construction.Timer;
import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -162,7 +163,8 @@ public class SparkExecutableStageFunctionTest {
TimerReceiverFactory timerReceiverFactory,
StateRequestHandler stateRequestHandler,
BundleProgressHandler progressHandler,
- BundleFinalizationHandler finalizationHandler) {
+ BundleFinalizationHandler finalizationHandler,
+ BundleCheckpointHandler checkpointHandler) {
return new RemoteBundle() {
@Override
public String getId() {
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 3e45f4a..303fe96 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -389,9 +389,6 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("BEAM-2939")
- def test_sdf_with_sdf_initiated_checkpointing(self):
- raise unittest.SkipTest("BEAM-2939")
-
def test_callbacks_with_exception(self):
raise unittest.SkipTest("BEAM-11021")