Posted to commits@beam.apache.org by bo...@apache.org on 2020/11/11 23:00:56 UTC

[beam] branch master updated: Add sdf initiated checkpoint support to portable Flink.

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 28914c2  Add sdf initiated checkpoint support to portable Flink.
     new cf6fd1c  Merge pull request #13105 from [BEAM-10940] Add sdf initiated checkpoint support to portable Flink.
28914c2 is described below

commit 28914c2679feaae8bf00955229f64bb46d3970cd
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Tue Oct 13 19:37:26 2020 -0700

    Add sdf initiated checkpoint support to portable Flink.
---
 runners/flink/job-server/flink_job_server.gradle   |  10 +-
 .../FlinkBatchPortablePipelineTranslator.java      |   3 +-
 .../FlinkStreamingPortablePipelineTranslator.java  |  43 ++-
 .../functions/FlinkExecutableStageFunction.java    |  84 +++++-
 .../wrappers/streaming/DoFnOperator.java           |  31 ++-
 .../streaming/ExecutableStageDoFnOperator.java     | 288 ++++++++++++++++++---
 .../streaming/SdfByteBufferKeySelector.java        |  61 +++++
 .../FlinkExecutableStageFunctionTest.java          |  14 +-
 .../streaming/ExecutableStageDoFnOperatorTest.java |  18 +-
 .../control/BundleCheckpointHandlers.java          | 142 ++++++++++
 .../control/DefaultJobBundleFactory.java           |   6 +-
 .../fnexecution/control/SdkHarnessClient.java      |  17 +-
 .../SingleEnvironmentInstanceJobBundleFactory.java |   6 +-
 .../fnexecution/control/StageBundleFactory.java    |  35 ++-
 .../fnexecution/control/RemoteExecutionTest.java   |   1 +
 .../SparkExecutableStageFunctionTest.java          |   4 +-
 .../runners/portability/flink_runner_test.py       |   3 -
 17 files changed, 673 insertions(+), 93 deletions(-)

diff --git a/runners/flink/job-server/flink_job_server.gradle b/runners/flink/job-server/flink_job_server.gradle
index 24b0838..1ca09a1 100644
--- a/runners/flink/job-server/flink_job_server.gradle
+++ b/runners/flink/job-server/flink_job_server.gradle
@@ -146,8 +146,6 @@ def portableValidatesRunnerTask(String name, Boolean streaming, Boolean checkpoi
       if (streaming && checkpointing) {
         includeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer'
         excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
-        // TODO(BEAM-10940): Enable this test suite once we have support.
-        excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
         // TestStreamSource does not support checkpointing
         excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
       } else {
@@ -169,18 +167,16 @@ def portableValidatesRunnerTask(String name, Boolean streaming, Boolean checkpoi
         excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer'
         excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
         if (streaming) {
+          excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
           excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
           excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithMultipleStages'
           excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithOutputTimestamp'
         } else {
+          excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
+          excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
           excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedPCollections'
           excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream'
         }
-        //SplitableDoFnTests
-        excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
-        excludeCategories 'org.apache.beam.sdk.testing.UsesSplittableParDoWithWindowedSideInputs'
-        excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo'
-
       }
     },
     testFilter: {
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
index 25a0ed0..32726f3 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchPortablePipelineTranslator.java
@@ -342,7 +342,8 @@ public class FlinkBatchPortablePipelineTranslator
             context.getJobInfo(),
             outputMap,
             FlinkExecutableStageContextFactory.getInstance(),
-            getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder());
+            getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder(),
+            windowedInputCoder);
 
     final String operatorName = generateNameFromStagePayload(stagePayload);
 
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
index 2112941..c1d2583 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
@@ -56,6 +56,7 @@ import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.DoFnOperator;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.ExecutableStageDoFnOperator;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.KvToByteBufferKeySelector;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.SdfByteBufferKeySelector;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItem;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.WindowDoFnOperator;
@@ -682,10 +683,20 @@ public class FlinkStreamingPortablePipelineTranslator
 
     final boolean stateful =
         stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
+    final boolean hasSdfProcessFn =
+        stagePayload.getComponents().getTransformsMap().values().stream()
+            .anyMatch(
+                pTransform ->
+                    pTransform
+                        .getSpec()
+                        .getUrn()
+                        .equals(
+                            PTransformTranslation
+                                .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
     Coder keyCoder = null;
     KeySelector<WindowedValue<InputT>, ?> keySelector = null;
-    if (stateful) {
-      // Stateful stages are only allowed of KV input
+    if (stateful || hasSdfProcessFn) {
+      // Stateful/SDF stages are only allowed with KV input.
       Coder valueCoder =
           ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
       if (!(valueCoder instanceof KvCoder)) {
@@ -696,10 +707,28 @@ public class FlinkStreamingPortablePipelineTranslator
                 inputPCollectionId,
                 valueCoder.getClass().getSimpleName()));
       }
-      keyCoder = ((KvCoder) valueCoder).getKeyCoder();
-      keySelector =
-          new KvToByteBufferKeySelector(
-              keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
+      if (stateful) {
+        keyCoder = ((KvCoder) valueCoder).getKeyCoder();
+        keySelector =
+            new KvToByteBufferKeySelector(
+                keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
+      } else {
+        // For an SDF, we know that the input element should be
+        // KV<KV<element, KV<restriction, watermarkState>>, size>. We are going to use the element
+        // as the key.
+        if (!(((KvCoder) valueCoder).getKeyCoder() instanceof KvCoder)) {
+          throw new IllegalStateException(
+              String.format(
+                  Locale.ENGLISH,
+                  "The element coder for splittable DoFn '%s' must be KVCoder(KvCoder, DoubleCoder) but is: %s",
+                  inputPCollectionId,
+                  valueCoder.getClass().getSimpleName()));
+        }
+        keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder();
+        keySelector =
+            new SdfByteBufferKeySelector(
+                keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
+      }
       inputDataStream = inputDataStream.keyBy(keySelector);
     }
 
@@ -738,7 +767,7 @@ public class FlinkStreamingPortablePipelineTranslator
     } else {
       DataStream<RawUnionValue> sideInputStream =
           transformedSideInputs.unionedSideInputs.broadcast();
-      if (stateful) {
+      if (stateful || hasSdfProcessFn) {
         // We have to manually construct the two-input transform because we're not
         // allowed to have only one input keyed, normally. Since Flink 1.5.0 it's
         // possible to use the Broadcast State Pattern which provides a more elegant
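
For orientation: in the SDF branch above, the stage's value coder is expected to be a
KvCoder whose key coder is itself a KvCoder, i.e. the element has the shape
KV<KV<element, KV<restriction, watermarkState>>, size>, and the stream is keyed by the
innermost element alone. A minimal sketch of that unwrapping follows, using placeholder
coders that are assumptions for illustration only; the real coders come from the stage's
pipeline proto.

    // Standard Beam SDK coders (org.apache.beam.sdk.coders.*), chosen only to show the shape:
    // element = String, restriction + watermark state collapsed to byte[], size = Double.
    KvCoder<String, byte[]> elementWithRestriction =
        KvCoder.of(StringUtf8Coder.of(), ByteArrayCoder.of());
    KvCoder<KV<String, byte[]>, Double> valueCoder =
        KvCoder.of(elementWithRestriction, DoubleCoder.of());

    // Same double unwrap the translator performs: the Flink key coder is the element coder.
    Coder<String> keyCoder =
        ((KvCoder<String, byte[]>) valueCoder.getKeyCoder()).getKeyCoder();
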
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index cb5a2a6..6c9c984 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -18,7 +18,9 @@
 package org.apache.beam.runners.flink.translation.functions;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.EnumMap;
+import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import javax.annotation.concurrent.GuardedBy;
@@ -26,13 +28,19 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressRespo
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.runners.core.InMemoryStateInternals;
 import org.apache.beam.runners.core.InMemoryTimerInternals;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateTags;
 import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.flink.FlinkPipelineOptions;
 import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
 import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -96,6 +104,7 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
   private final Map<String, Integer> outputMap;
   private final FlinkExecutableStageContextFactory contextFactory;
   private final Coder windowCoder;
+  private final Coder<WindowedValue<InputT>> inputCoder;
   // Unique name for namespacing metrics
   private final String stepName;
 
@@ -107,6 +116,9 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
   private transient StageBundleFactory stageBundleFactory;
   private transient BundleProgressHandler progressHandler;
   private transient BundleFinalizationHandler finalizationHandler;
+  private transient BundleCheckpointHandler bundleCheckpointHandler;
+  private transient InMemoryTimerInternals sdfTimerInternals;
+  private transient StateInternals sdfStateInternals;
   // Only initialized when the ExecutableStage is stateful
   private transient InMemoryBagUserStateFactory bagUserStateHandlerFactory;
   private transient ExecutableStage executableStage;
@@ -120,7 +132,8 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
       JobInfo jobInfo,
       Map<String, Integer> outputMap,
       FlinkExecutableStageContextFactory contextFactory,
-      Coder windowCoder) {
+      Coder windowCoder,
+      Coder<WindowedValue<InputT>> inputCoder) {
     this.stepName = stepName;
     this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
     this.stagePayload = stagePayload;
@@ -128,6 +141,7 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
     this.outputMap = outputMap;
     this.contextFactory = contextFactory;
     this.windowCoder = windowCoder;
+    this.inputCoder = inputCoder;
   }
 
   @Override
@@ -165,6 +179,35 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
           throw new UnsupportedOperationException(
               "Portable Flink runner doesn't support bundle finalization in batch mode. For more details, please refer to https://issues.apache.org/jira/browse/BEAM-11021.");
         };
+    bundleCheckpointHandler = getBundleCheckpointHandler(executableStage);
+  }
+
+  private boolean hasSDF(ExecutableStage executableStage) {
+    return executableStage.getTransforms().stream()
+        .anyMatch(
+            pTransformNode ->
+                pTransformNode
+                    .getTransform()
+                    .getSpec()
+                    .getUrn()
+                    .equals(
+                        PTransformTranslation
+                            .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
+  }
+
+  private BundleCheckpointHandler getBundleCheckpointHandler(ExecutableStage executableStage) {
+    if (!hasSDF(executableStage)) {
+      sdfTimerInternals = null;
+      sdfStateInternals = null;
+      return response -> {
+        throw new UnsupportedOperationException(
+            "Self-checkpoint is only supported on splittable DoFn.");
+      };
+    }
+    sdfTimerInternals = new InMemoryTimerInternals();
+    sdfStateInternals = InMemoryStateInternals.forKey("sdf_state");
+    return new BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler(
+        key -> sdfTimerInternals, key -> sdfStateInternals, inputCoder, windowCoder);
   }
 
   private StateRequestHandler getStateRequestHandler(
@@ -210,11 +253,47 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
       throws Exception {
 
     ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
+    if (sdfStateInternals != null) {
+      sdfTimerInternals.advanceProcessingTime(Instant.now());
+      sdfTimerInternals.advanceSynchronizedProcessingTime(Instant.now());
+    }
     try (RemoteBundle bundle =
         stageBundleFactory.getBundle(
-            receiverFactory, stateRequestHandler, progressHandler, finalizationHandler)) {
+            receiverFactory,
+            stateRequestHandler,
+            progressHandler,
+            finalizationHandler,
+            bundleCheckpointHandler)) {
       processElements(iterable, bundle);
     }
+    if (sdfTimerInternals != null) {
+      // Finally, advance the processing time to infinity to fire any timers.
+      sdfTimerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+      sdfTimerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+      // Now we fire the SDF timers and process elements generated by timers.
+      while (sdfTimerInternals.hasPendingTimers()) {
+        try (RemoteBundle bundle =
+            stageBundleFactory.getBundle(
+                receiverFactory,
+                stateRequestHandler,
+                progressHandler,
+                finalizationHandler,
+                bundleCheckpointHandler)) {
+          List<WindowedValue<InputT>> residuals = new ArrayList<>();
+          TimerInternals.TimerData timer;
+          while ((timer = sdfTimerInternals.removeNextProcessingTimer()) != null) {
+            WindowedValue stateValue =
+                sdfStateInternals
+                    .state(timer.getNamespace(), StateTags.value(timer.getTimerId(), inputCoder))
+                    .read();
+
+            residuals.add(stateValue);
+          }
+          processElements(residuals, bundle);
+        }
+      }
+    }
   }
 
   /** For stateful and timer processing via a GroupReduceFunction. */
@@ -267,7 +346,6 @@ public class FlinkExecutableStageFunction<InputT> extends AbstractRichFunction
       try (RemoteBundle bundle =
           stageBundleFactory.getBundle(
               receiverFactory, timerReceiverFactory, stateRequestHandler, progressHandler)) {
-
         PipelineTranslatorUtils.fireEligibleTimers(
             timerInternals, bundle.getTimerReceivers(), currentTimerKey);
       }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
index 06c9398..09eacd0 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java
@@ -64,6 +64,7 @@ import org.apache.beam.runners.flink.translation.utils.Workarounds;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.stableinput.BufferingDoFnRunner;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkBroadcastStateInternals;
 import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.StructuredCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
@@ -1332,7 +1333,7 @@ public class DoFnOperator<InputT, OutputT>
         keyedStateBackend.setCurrentKey(internalTimer.getKey());
         TimerData timer = internalTimer.getNamespace();
         checkInvokeStartBundle();
-        fireTimer(timer);
+        fireTimerInternal((ByteBuffer) internalTimer.getKey(), timer);
       }
     }
 
@@ -1347,16 +1348,13 @@ public class DoFnOperator<InputT, OutputT>
       }
     }
 
-    private void onRemovedEventTimer(TimerData removedTimer) {
+    /** Holds the watermark when there is an SDF timer. */
+    private void onNewSdfTimer(TimerData newTimer) {
       Preconditions.checkState(
-          removedTimer.getDomain() == TimeDomain.EVENT_TIME,
-          "Timer with id %s is not an event time timer!",
-          removedTimer.getTimerId());
-      // Remove the first occurrence of the output timestamp, if cached
-      // Note: There may be duplicate timestamps from other timers, that's ok.
-      if (timerUsesOutputTimestamp(removedTimer)) {
-        keyedStateInternals.removeWatermarkHoldUsage(removedTimer.getOutputTimestamp());
-      }
+          StateAndTimerBundleCheckpointHandler.isSdfTimer(newTimer.getTimerId()));
+      // An SDF timer should hold the watermark for further output.
+      Preconditions.checkState(timerUsesOutputTimestamp(newTimer));
+      keyedStateInternals.addWatermarkHoldUsage(newTimer.getOutputTimestamp());
     }
 
     private void populateOutputTimestampQueue() {
@@ -1369,7 +1367,8 @@ public class DoFnOperator<InputT, OutputT>
               keyedStateBackend.setCurrentKey(key);
               try {
                 for (TimerData timerData : pendingTimersById.values()) {
-                  if (timerData.getDomain() == TimeDomain.EVENT_TIME) {
+                  if (timerData.getDomain() == TimeDomain.EVENT_TIME
+                      || StateAndTimerBundleCheckpointHandler.isSdfTimer(timerData.getTimerId())) {
                     if (timerUsesOutputTimestamp(timerData)) {
                       keyedStateInternals.addWatermarkHoldUsage(timerData.getOutputTimestamp());
                     }
@@ -1441,6 +1440,9 @@ public class DoFnOperator<InputT, OutputT>
         case PROCESSING_TIME:
         case SYNCHRONIZED_PROCESSING_TIME:
           timerService.registerProcessingTimeTimer(timer, adjustTimestampForFlink(time));
+          if (StateAndTimerBundleCheckpointHandler.isSdfTimer(timer.getTimerId())) {
+            onNewSdfTimer(timer);
+          }
           break;
         default:
           throw new UnsupportedOperationException("Unsupported time domain: " + timer.getDomain());
@@ -1466,8 +1468,11 @@ public class DoFnOperator<InputT, OutputT>
     void onFiredOrDeletedTimer(TimerData timer) {
       try {
         pendingTimersById.remove(getContextTimerId(timer.getTimerId(), timer.getNamespace()));
-        if (timer.getDomain() == TimeDomain.EVENT_TIME) {
-          onRemovedEventTimer(timer);
+        if (timer.getDomain() == TimeDomain.EVENT_TIME
+            || StateAndTimerBundleCheckpointHandler.isSdfTimer(timer.getTimerId())) {
+          if (timerUsesOutputTimestamp(timer)) {
+            keyedStateInternals.removeWatermarkHoldUsage(timer.getOutputTimestamp());
+          }
         }
       } catch (Exception e) {
         throw new RuntimeException("Failed to cleanup pending timers state.", e);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
index a27c966..d479a3c 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java
@@ -40,6 +40,7 @@ import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase;
@@ -47,6 +48,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.LateDataUtils;
 import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateInternalsFactory;
 import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateNamespaces;
 import org.apache.beam.runners.core.StateTag;
@@ -54,6 +56,8 @@ import org.apache.beam.runners.core.StateTags;
 import org.apache.beam.runners.core.StatefulDoFnRunner;
 import org.apache.beam.runners.core.StepContext;
 import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternalsFactory;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
@@ -61,6 +65,10 @@ import org.apache.beam.runners.core.construction.graph.UserStateReference;
 import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory;
 import org.apache.beam.runners.flink.translation.functions.FlinkStreamingSideInputHandlerFactory;
 import org.apache.beam.runners.flink.translation.types.CoderTypeSerializer;
+import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandlers;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandlers.InMemoryFinalizer;
@@ -80,6 +88,8 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.function.ThrowingFunction;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.State;
+import org.apache.beam.sdk.state.StateContext;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
@@ -141,10 +151,15 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
 
   private final boolean isStateful;
 
+  private final Coder windowCoder;
+  private final Coder<WindowedValue<InputT>> inputCoder;
+
   private transient ExecutableStageContext stageContext;
   private transient StateRequestHandler stateRequestHandler;
   private transient BundleProgressHandler progressHandler;
   private transient InMemoryFinalizer finalizationHandler;
+  private transient BundleCheckpointHandler checkpointHandler;
+  private transient boolean hasSdfProcessFn;
   private transient StageBundleFactory stageBundleFactory;
   private transient ExecutableStage executableStage;
   private transient SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner;
@@ -203,6 +218,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     this.outputMap = outputMap;
     this.sideInputIds = sideInputIds;
     this.stateBackendLock = new ReentrantLock();
+    this.windowCoder = (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder();
+    this.inputCoder = windowedInputCoder;
     this.pipelineOptions = new SerializablePipelineOptions(options);
   }
 
@@ -214,6 +231,7 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
   @Override
   public void open() throws Exception {
     executableStage = ExecutableStage.fromPayload(payload);
+    hasSdfProcessFn = hasSDF(executableStage);
     initializeUserState(executableStage, getKeyedStateBackend(), pipelineOptions);
     // TODO: Wire this into the distributed cache and make it pluggable.
     // TODO: Do we really want this layer of indirection when accessing the stage bundle factory?
@@ -244,6 +262,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
         BundleFinalizationHandlers.inMemoryFinalizer(
             stageBundleFactory.getInstructionRequestHandler());
 
+    checkpointHandler = getBundleCheckpointHandler(hasSdfProcessFn);
+
     minEventTimeTimerTimestampInCurrentBundle = Long.MAX_VALUE;
     minEventTimeTimerTimestampInLastBundle = Long.MAX_VALUE;
     super.setPreBundleCallback(this::preBundleStartCallback);
@@ -259,6 +279,34 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     super.notifyCheckpointComplete(checkpointId);
   }
 
+  private BundleCheckpointHandler getBundleCheckpointHandler(boolean hasSDF) {
+    if (!hasSDF) {
+      return response -> {
+        throw new UnsupportedOperationException(
+            "Self-checkpoint is only supported on splittable DoFn.");
+      };
+    }
+
+    return new BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler(
+        new SdfFlinkTimerInternalsFactory(),
+        new SdfFlinkStateInternalsFactory(),
+        inputCoder,
+        windowCoder);
+  }
+
+  private boolean hasSDF(ExecutableStage executableStage) {
+    return executableStage.getTransforms().stream()
+        .anyMatch(
+            pTransformNode ->
+                pTransformNode
+                    .getTransform()
+                    .getSpec()
+                    .getUrn()
+                    .equals(
+                        PTransformTranslation
+                            .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
+  }
+
   private StateRequestHandler getStateRequestHandler(ExecutableStage executableStage) {
 
     final StateRequestHandler sideInputStateHandler;
@@ -491,6 +539,148 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     }
   }
 
+  /**
+   * A {@link TimerInternalsFactory} for Flink operator to create a {@link
+   * StateAndTimerBundleCheckpointHandler} to handle {@link
+   * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}.
+   */
+  class SdfFlinkTimerInternalsFactory implements TimerInternalsFactory<InputT> {
+    @Override
+    public TimerInternals timerInternalsForKey(InputT key) {
+      try {
+        ByteBuffer encodedKey =
+            (ByteBuffer) keySelector.getKey(WindowedValue.valueInGlobalWindow(key));
+        return new SdfFlinkTimerInternals(encodedKey);
+      } catch (Exception e) {
+        throw new RuntimeException("Couldn't get a timer internals", e);
+      }
+    }
+  }
+
+  /**
+   * A {@link TimerInternals} for rescheduling {@link
+   * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}.
+   */
+  class SdfFlinkTimerInternals implements TimerInternals {
+    private final ByteBuffer key;
+
+    SdfFlinkTimerInternals(ByteBuffer key) {
+      this.key = key;
+    }
+
+    @Override
+    public void setTimer(
+        StateNamespace namespace,
+        String timerId,
+        String timerFamilyId,
+        Instant target,
+        Instant outputTimestamp,
+        TimeDomain timeDomain) {
+      setTimer(
+          TimerData.of(timerId, timerFamilyId, namespace, target, outputTimestamp, timeDomain));
+    }
+
+    @Override
+    public void setTimer(TimerData timerData) {
+      try {
+        try (Locker locker = Locker.locked(stateBackendLock)) {
+          getKeyedStateBackend().setCurrentKey(key);
+          timerInternals.setTimer(timerData);
+          minEventTimeTimerTimestampInCurrentBundle =
+              Math.min(
+                  minEventTimeTimerTimestampInCurrentBundle,
+                  adjustTimestampForFlink(timerData.getOutputTimestamp().getMillis()));
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Couldn't set timer", e);
+      }
+    }
+
+    @Override
+    public void deleteTimer(StateNamespace namespace, String timerId, TimeDomain timeDomain) {
+      throw new UnsupportedOperationException(
+          "It is not expected to use SdfFlinkTimerInternals to delete a timer");
+    }
+
+    @Override
+    public void deleteTimer(StateNamespace namespace, String timerId, String timerFamilyId) {
+      throw new UnsupportedOperationException(
+          "It is not expected to use SdfFlinkTimerInternals to delete a timer");
+    }
+
+    @Override
+    public void deleteTimer(TimerData timerKey) {
+      throw new UnsupportedOperationException(
+          "It is not expected to use SdfFlinkTimerInternals to delete a timer");
+    }
+
+    @Override
+    public Instant currentProcessingTime() {
+      return timerInternals.currentProcessingTime();
+    }
+
+    @Override
+    public @Nullable Instant currentSynchronizedProcessingTime() {
+      return timerInternals.currentSynchronizedProcessingTime();
+    }
+
+    @Override
+    public Instant currentInputWatermarkTime() {
+      return timerInternals.currentInputWatermarkTime();
+    }
+
+    @Override
+    public @Nullable Instant currentOutputWatermarkTime() {
+      return timerInternals.currentOutputWatermarkTime();
+    }
+  }
+
+  /**
+   * A {@link StateInternalsFactory} for Flink operator to create a {@link
+   * StateAndTimerBundleCheckpointHandler} to handle {@link
+   * org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication}.
+   */
+  class SdfFlinkStateInternalsFactory implements StateInternalsFactory<InputT> {
+    @Override
+    public StateInternals stateInternalsForKey(InputT key) {
+      try {
+        ByteBuffer encodedKey =
+            (ByteBuffer) keySelector.getKey(WindowedValue.valueInGlobalWindow(key));
+        return new SdfFlinkStateInternals(encodedKey);
+      } catch (Exception e) {
+        throw new RuntimeException("Couldn't get a state internals", e);
+      }
+    }
+  }
+
+  /** A {@link StateInternals} for keeping {@link DelayedBundleApplication}s as states. */
+  class SdfFlinkStateInternals implements StateInternals {
+
+    private final ByteBuffer key;
+
+    SdfFlinkStateInternals(ByteBuffer key) {
+      this.key = key;
+    }
+
+    @Override
+    public Object getKey() {
+      return key;
+    }
+
+    @Override
+    public <T extends State> T state(
+        StateNamespace namespace, StateTag<T> address, StateContext<?> c) {
+      try {
+        try (Locker locker = Locker.locked(stateBackendLock)) {
+          getKeyedStateBackend().setCurrentKey(key);
+          return keyedStateInternals.state(namespace, address);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException("Couldn't set state", e);
+      }
+    }
+  }
+
   @Override
   protected void fireTimerInternal(ByteBuffer key, TimerInternals.TimerData timer) {
     // We have to synchronize to ensure the state backend is not concurrently accessed by the state
@@ -509,6 +699,17 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     processWatermark1(Watermark.MAX_WATERMARK);
     while (getCurrentOutputWatermark() < Watermark.MAX_WATERMARK.getTimestamp()) {
       invokeFinishBundle();
+      if (hasSdfProcessFn) {
+        // Manually drain processing time timers since Flink will ignore pending
+        // processing-time timers when upstream operators have shut down and will also
+        // shut down this operator with pending processing-time timers.
+        // TODO(BEAM-11210, FLINK-18647): It doesn't work efficiently when the
+        // watermark of upstream advances to MAX_TIMESTAMP immediately.
+        if (numProcessingTimeTimers() > 0) {
+          timerInternals.processPendingProcessingTimeTimers();
+        }
+      }
     }
     super.close();
   }
@@ -549,11 +750,14 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
             stateRequestHandler,
             progressHandler,
             finalizationHandler,
+            checkpointHandler,
             outputManager,
             outputMap,
-            (Coder<BoundedWindow>) windowingStrategy.getWindowFn().windowCoder(),
+            windowCoder,
+            inputCoder,
             this::setTimer,
-            () -> FlinkKeyUtils.decodeKey(getCurrentKey(), keyCoder));
+            () -> FlinkKeyUtils.decodeKey(getCurrentKey(), keyCoder),
+            keyedStateInternals);
 
     return ensureStateDoFnRunner(sdkHarnessRunner, payload, stepContext);
   }
@@ -658,10 +862,12 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
     private final StateRequestHandler stateRequestHandler;
     private final BundleProgressHandler progressHandler;
     private final BundleFinalizationHandler finalizationHandler;
+    private final BundleCheckpointHandler checkpointHandler;
     private final BufferedOutputManager<OutputT> outputManager;
     private final Map<String, TupleTag<?>> outputMap;
-
+    private final FlinkStateInternals<?> keyedStateInternals;
     private final Coder<BoundedWindow> windowCoder;
+    private final Coder<WindowedValue<InputT>> residualCoder;
     private final BiConsumer<Timer<?>, TimerInternals.TimerData> timerRegistration;
     private final Supplier<Object> keyForTimer;
 
@@ -682,23 +888,29 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
         StateRequestHandler stateRequestHandler,
         BundleProgressHandler progressHandler,
         BundleFinalizationHandler finalizationHandler,
+        BundleCheckpointHandler checkpointHandler,
         BufferedOutputManager<OutputT> outputManager,
         Map<String, TupleTag<?>> outputMap,
         Coder<BoundedWindow> windowCoder,
+        Coder<WindowedValue<InputT>> residualCoder,
         BiConsumer<Timer<?>, TimerInternals.TimerData> timerRegistration,
-        Supplier<Object> keyForTimer) {
+        Supplier<Object> keyForTimer,
+        FlinkStateInternals<?> keyedStateInternals) {
 
       this.doFn = doFn;
       this.stageBundleFactory = stageBundleFactory;
       this.stateRequestHandler = stateRequestHandler;
       this.progressHandler = progressHandler;
       this.finalizationHandler = finalizationHandler;
+      this.checkpointHandler = checkpointHandler;
       this.outputManager = outputManager;
       this.outputMap = outputMap;
       this.timerRegistration = timerRegistration;
       this.keyForTimer = keyForTimer;
       this.windowCoder = windowCoder;
+      this.residualCoder = residualCoder;
       this.outputQueue = new LinkedBlockingQueue<>();
+      this.keyedStateInternals = keyedStateInternals;
     }
 
     @Override
@@ -723,7 +935,8 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
                 timerReceiverFactory,
                 stateRequestHandler,
                 progressHandler,
-                finalizationHandler);
+                finalizationHandler,
+                checkpointHandler);
         mainInputReceiver = Iterables.getOnlyElement(remoteBundle.getInputReceivers().values());
       } catch (Exception e) {
         throw new RuntimeException("Failed to start remote bundle", e);
@@ -753,35 +966,42 @@ public class ExecutableStageDoFnOperator<InputT, OutputT> extends DoFnOperator<I
       Object timerKey = keyForTimer.get();
       Preconditions.checkNotNull(timerKey, "Key for timer needs to be set before calling onTimer");
       Preconditions.checkNotNull(remoteBundle, "Call to onTimer outside of a bundle");
-      KV<String, String> transformAndTimerFamilyId =
-          TimerReceiverFactory.decodeTimerDataTimerId(timerId);
-      LOG.debug(
-          "timer callback: {} {} {} {} {}",
-          transformAndTimerFamilyId.getKey(),
-          transformAndTimerFamilyId.getValue(),
-          window,
-          timestamp,
-          timeDomain);
-      FnDataReceiver<Timer> timerReceiver =
-          Preconditions.checkNotNull(
-              remoteBundle.getTimerReceivers().get(transformAndTimerFamilyId),
-              "No receiver found for timer %s %s",
-              transformAndTimerFamilyId.getKey(),
-              transformAndTimerFamilyId.getValue());
-      Timer<?> timerValue =
-          Timer.of(
-              timerKey,
-              "",
-              Collections.singletonList(window),
-              timestamp,
-              outputTimestamp,
-              // TODO: Support propagating the PaneInfo through.
-              PaneInfo.NO_FIRING);
-      try {
-        timerReceiver.accept(timerValue);
-      } catch (Exception e) {
-        throw new RuntimeException(
-            String.format(Locale.ENGLISH, "Failed to process timer %s", timerReceiver), e);
+      if (StateAndTimerBundleCheckpointHandler.isSdfTimer(timerId)) {
+        StateNamespace namespace = StateNamespaces.window(windowCoder, window);
+        WindowedValue stateValue =
+            keyedStateInternals.state(namespace, StateTags.value(timerId, residualCoder)).read();
+        processElement(stateValue);
+      } else {
+        KV<String, String> transformAndTimerFamilyId =
+            TimerReceiverFactory.decodeTimerDataTimerId(timerId);
+        LOG.debug(
+            "timer callback: {} {} {} {} {}",
+            transformAndTimerFamilyId.getKey(),
+            transformAndTimerFamilyId.getValue(),
+            window,
+            timestamp,
+            timeDomain);
+        FnDataReceiver<Timer> timerReceiver =
+            Preconditions.checkNotNull(
+                remoteBundle.getTimerReceivers().get(transformAndTimerFamilyId),
+                "No receiver found for timer %s %s",
+                transformAndTimerFamilyId.getKey(),
+                transformAndTimerFamilyId.getValue());
+        Timer<?> timerValue =
+            Timer.of(
+                timerKey,
+                "",
+                Collections.singletonList(window),
+                timestamp,
+                outputTimestamp,
+                // TODO: Support propagating the PaneInfo through.
+                PaneInfo.NO_FIRING);
+        try {
+          timerReceiver.accept(timerValue);
+        } catch (Exception e) {
+          throw new RuntimeException(
+              String.format(Locale.ENGLISH, "Failed to process timer %s", timerReceiver), e);
+        }
       }
     }
 
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java
new file mode 100644
index 0000000..29af81d
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SdfByteBufferKeySelector.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers.streaming;
+
+import java.nio.ByteBuffer;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
+
+/**
+ * {@link KeySelector} that retrieves a key from a {@code KV<KV<element, KV<restriction,
+ * watermarkState>>, size>}. This will return the element as encoded by the provided {@link Coder}
+ * in a {@link ByteBuffer}. This ensures that all key comparisons/hashing happen on the encoded
+ * form. Note that we don't use the whole {@code KV<KV<element, KV<restriction,
+ * watermarkState>>, Double>} as the key because a checkpoint produces a different
+ * restriction/watermarkState/size, which Flink would treat as a new key. Using a new key to set
+ * state and timers would cause undefined behavior.
+ */
+public class SdfByteBufferKeySelector<K, V>
+    implements KeySelector<WindowedValue<KV<KV<K, V>, Double>>, ByteBuffer>,
+        ResultTypeQueryable<ByteBuffer> {
+
+  private final Coder<K> keyCoder;
+  private final SerializablePipelineOptions pipelineOptions;
+
+  public SdfByteBufferKeySelector(Coder<K> keyCoder, SerializablePipelineOptions pipelineOptions) {
+    this.keyCoder = keyCoder;
+    this.pipelineOptions = pipelineOptions;
+  }
+
+  @Override
+  public ByteBuffer getKey(WindowedValue<KV<KV<K, V>, Double>> value) {
+    K key = value.getValue().getKey().getKey();
+    return FlinkKeyUtils.encodeKey(key, keyCoder);
+  }
+
+  @Override
+  public TypeInformation<ByteBuffer> getProducedType() {
+    return new CoderTypeInformation<>(FlinkKeyUtils.ByteBufferCoder.of(), pipelineOptions.get());
+  }
+}
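
A short usage sketch of the new selector; the element value, restriction bytes, and coder
below are made up for illustration. Keying happens on the encoded element only, so the key
stays stable across checkpoints even though the restriction, watermark state, and size change:

    SdfByteBufferKeySelector<String, byte[]> selector =
        new SdfByteBufferKeySelector<>(
            StringUtf8Coder.of(),
            new SerializablePipelineOptions(PipelineOptionsFactory.create()));

    WindowedValue<KV<KV<String, byte[]>, Double>> element =
        WindowedValue.valueInGlobalWindow(
            KV.of(KV.of("user-42", new byte[] {1, 2, 3}), 10.0));

    // Only "user-42", encoded with the key coder, determines the Flink key group.
    ByteBuffer key = selector.getKey(element);
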
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
index 58ac315..67b772d 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunctionTest.java
@@ -34,6 +34,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.runners.flink.metrics.FlinkMetricContainer;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
 import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -123,7 +124,8 @@ public class FlinkExecutableStageFunctionTest {
             any(),
             any(StateRequestHandler.class),
             any(BundleProgressHandler.class),
-            any(BundleFinalizationHandler.class)))
+            any(BundleFinalizationHandler.class),
+            any(BundleCheckpointHandler.class)))
         .thenReturn(remoteBundle);
     when(stageBundleFactory.getBundle(
             any(),
@@ -148,7 +150,8 @@ public class FlinkExecutableStageFunctionTest {
             any(),
             any(StateRequestHandler.class),
             any(BundleProgressHandler.class),
-            any(BundleFinalizationHandler.class)))
+            any(BundleFinalizationHandler.class),
+            any(BundleCheckpointHandler.class)))
         .thenReturn(bundle);
 
     @SuppressWarnings("unchecked")
@@ -172,7 +175,8 @@ public class FlinkExecutableStageFunctionTest {
             any(),
             any(StateRequestHandler.class),
             any(BundleProgressHandler.class),
-            any(BundleFinalizationHandler.class)))
+            any(BundleFinalizationHandler.class),
+            any(BundleCheckpointHandler.class)))
         .thenReturn(bundle);
 
     @SuppressWarnings("unchecked")
@@ -213,7 +217,8 @@ public class FlinkExecutableStageFunctionTest {
               TimerReceiverFactory timerReceiverFactory,
               StateRequestHandler stateRequestHandler,
               BundleProgressHandler progressHandler,
-              BundleFinalizationHandler finalizationHandler) {
+              BundleFinalizationHandler finalizationHandler,
+              BundleCheckpointHandler checkpointHandler) {
             return new RemoteBundle() {
               @Override
               public String getId() {
@@ -333,6 +338,7 @@ public class FlinkExecutableStageFunctionTest {
             jobInfo,
             outputMap,
             contextFactory,
+            null,
             null);
     function.setRuntimeContext(runtimeContext);
     Whitebox.setInternalState(function, "stateRequestHandler", stateRequestHandler);
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
index bc5c277..336b414 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperatorTest.java
@@ -67,6 +67,7 @@ import org.apache.beam.runners.flink.metrics.DoFnRunnerWithMetricsUpdate;
 import org.apache.beam.runners.flink.streaming.FlinkStateInternalsTest;
 import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageContextFactory;
 import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
 import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -213,7 +214,7 @@ public class ExecutableStageDoFnOperatorTest {
 
     @SuppressWarnings("unchecked")
     RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     @SuppressWarnings("unchecked")
     FnDataReceiver<WindowedValue<?>> receiver = Mockito.mock(FnDataReceiver.class);
@@ -240,7 +241,7 @@ public class ExecutableStageDoFnOperatorTest {
 
     @SuppressWarnings("unchecked")
     RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     @SuppressWarnings("unchecked")
     FnDataReceiver<WindowedValue<?>> receiver = Mockito.mock(FnDataReceiver.class);
@@ -323,7 +324,8 @@ public class ExecutableStageDoFnOperatorTest {
               TimerReceiverFactory timerReceiverFactory,
               StateRequestHandler stateRequestHandler,
               BundleProgressHandler progressHandler,
-              BundleFinalizationHandler finalizationHandler) {
+              BundleFinalizationHandler finalizationHandler,
+              BundleCheckpointHandler checkpointHandler) {
             return new RemoteBundle() {
               @Override
               public String getId() {
@@ -458,7 +460,7 @@ public class ExecutableStageDoFnOperatorTest {
                 .put(KV.of("transform", "timer2"), Mockito.mock(FnDataReceiver.class))
                 .put(KV.of("transform", "timer3"), Mockito.mock(FnDataReceiver.class))
                 .build());
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     testHarness.open();
     assertThat(
@@ -576,7 +578,7 @@ public class ExecutableStageDoFnOperatorTest {
             ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
                 .put("input", Mockito.mock(FnDataReceiver.class))
                 .build());
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     testHarness.open();
     testHarness.close();
@@ -624,7 +626,7 @@ public class ExecutableStageDoFnOperatorTest {
             ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
                 .put("input", Mockito.mock(FnDataReceiver.class))
                 .build());
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     testHarness.open();
 
@@ -736,7 +738,7 @@ public class ExecutableStageDoFnOperatorTest {
 
     @SuppressWarnings("unchecked")
     RemoteBundle bundle = Mockito.mock(RemoteBundle.class);
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     KV<String, String> timerInputKey = KV.of("transformId", "timerId");
     AtomicBoolean timerInputReceived = new AtomicBoolean();
@@ -921,7 +923,7 @@ public class ExecutableStageDoFnOperatorTest {
             ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
                 .put("input", Mockito.mock(FnDataReceiver.class))
                 .build());
-    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any())).thenReturn(bundle);
+    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any())).thenReturn(bundle);
 
     testHarness.open();
 
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandlers.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandlers.java
new file mode 100644
index 0000000..6ed6127
--- /dev/null
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/BundleCheckpointHandlers.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.fnexecution.control;
+
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateInternalsFactory;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaces;
+import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.runners.core.TimerInternals;
+import org.apache.beam.runners.core.TimerInternalsFactory;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.fn.IdGenerator;
+import org.apache.beam.sdk.fn.IdGenerators;
+import org.apache.beam.sdk.state.TimeDomain;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Utility methods for creating {@link BundleCheckpointHandler}s. */
+@SuppressWarnings({
+  "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556)
+  "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+})
+public class BundleCheckpointHandlers {
+
+  /**
+   * A {@link BundleCheckpointHandler} which uses {@link
+   * org.apache.beam.runners.core.TimerInternals.TimerData} and {@link
+   * org.apache.beam.sdk.state.ValueState} to reschedule {@link DelayedBundleApplication}.
+   */
+  public static class StateAndTimerBundleCheckpointHandler<T> implements BundleCheckpointHandler {
+    private static final Logger LOG =
+        LoggerFactory.getLogger(StateAndTimerBundleCheckpointHandler.class);
+    private final TimerInternalsFactory<T> timerInternalsFactory;
+    private final StateInternalsFactory<T> stateInternalsFactory;
+    private final Coder<WindowedValue<T>> residualCoder;
+    private final Coder windowCoder;
+    private final IdGenerator idGenerator = IdGenerators.incrementingLongs();
+    public static final String SDF_PREFIX = "sdf_checkpoint";
+
+    public StateAndTimerBundleCheckpointHandler(
+        TimerInternalsFactory<T> timerInternalsFactory,
+        StateInternalsFactory<T> stateInternalsFactory,
+        Coder<WindowedValue<T>> residualCoder,
+        Coder windowCoder) {
+      this.residualCoder = residualCoder;
+      this.windowCoder = windowCoder;
+      this.timerInternalsFactory = timerInternalsFactory;
+      this.stateInternalsFactory = stateInternalsFactory;
+    }
+
+    /**
+     * A helper to check whether the given timer is one set for rescheduling a {@link
+     * DelayedBundleApplication}.
+     */
+    public static boolean isSdfTimer(String timerId) {
+      return timerId.startsWith(SDF_PREFIX);
+    }
+
+    private static String constructSdfCheckpointId(String id, int index) {
+      return SDF_PREFIX + ":" + id + ":" + index;
+    }
+
+    @Override
+    public void onCheckpoint(ProcessBundleResponse response) {
+      String id = idGenerator.getId();
+      for (int index = 0; index < response.getResidualRootsCount(); index++) {
+        DelayedBundleApplication residual = response.getResidualRoots(index);
+        if (!residual.hasApplication()) {
+          continue;
+        }
+        String tag = constructSdfCheckpointId(id, index);
+        try {
+          WindowedValue<T> stateValue =
+              CoderUtils.decodeFromByteArray(
+                  residualCoder, residual.getApplication().getElement().toByteArray());
+          TimerInternals timerInternals =
+              timerInternalsFactory.timerInternalsForKey(stateValue.getValue());
+          StateInternals stateInternals =
+              stateInternalsFactory.stateInternalsForKey(stateValue.getValue());
+          // Calculate the firing timestamp: now plus the requested delay (seconds -> millis).
+          Instant timestamp = Instant.now();
+          if (residual.hasRequestedTimeDelay()) {
+            timestamp = timestamp.plus(residual.getRequestedTimeDelay().getSeconds() * 1000);
+          }
+          // Calculate the watermark hold: the minimum output watermark, else TIMESTAMP_MIN_VALUE.
+          long outputTimestamp = BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis();
+          if (!residual.getApplication().getOutputWatermarksMap().isEmpty()) {
+            for (org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp outputWatermark :
+                residual.getApplication().getOutputWatermarksMap().values()) {
+              outputTimestamp = Math.min(outputTimestamp, outputWatermark.getSeconds() * 1000);
+            }
+          } else {
+            outputTimestamp = BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis();
+          }
+          for (BoundedWindow window : stateValue.getWindows()) {
+            StateNamespace stateNamespace = StateNamespaces.window(windowCoder, window);
+            timerInternals.setTimer(
+                stateNamespace,
+                tag,
+                "",
+                timestamp,
+                Instant.ofEpochMilli(outputTimestamp),
+                TimeDomain.PROCESSING_TIME);
+            stateInternals
+                .state(stateNamespace, StateTags.value(tag, residualCoder))
+                .write(
+                    WindowedValue.of(
+                        stateValue.getValue(),
+                        stateValue.getTimestamp(),
+                        ImmutableList.of(window),
+                        stateValue.getPane()));
+          }
+        } catch (Exception e) {
+          throw new RuntimeException("Failed to set timer/state for the residual", e);
+        }
+      }
+    }
+  }
+}
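For orientation, the sketch below shows one way a runner could construct the new StateAndTimerBundleCheckpointHandler. It is a minimal, self-contained illustration only: the in-memory internals, the String element type, and the global-window coder are assumptions for the example, not what the Flink operator in this change uses.

    import org.apache.beam.runners.core.InMemoryStateInternals;
    import org.apache.beam.runners.core.InMemoryTimerInternals;
    import org.apache.beam.runners.core.StateInternalsFactory;
    import org.apache.beam.runners.core.TimerInternalsFactory;
    import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
    import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandlers.StateAndTimerBundleCheckpointHandler;
    import org.apache.beam.sdk.coders.Coder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
    import org.apache.beam.sdk.util.WindowedValue;

    public class CheckpointHandlerWiringSketch {
      public static BundleCheckpointHandler create() {
        // In-memory internals keep the sketch runnable on its own; a real runner supplies
        // factories backed by its keyed state backend and timer service.
        TimerInternalsFactory<String> timerInternalsFactory = key -> new InMemoryTimerInternals();
        StateInternalsFactory<String> stateInternalsFactory = InMemoryStateInternals::forKey;
        // The residual coder must match the windowed coder of the executable stage's input elements.
        Coder<WindowedValue<String>> residualCoder =
            WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
        return new StateAndTimerBundleCheckpointHandler<>(
            timerInternalsFactory, stateInternalsFactory, residualCoder, GlobalWindow.Coder.INSTANCE);
      }
    }

When one of these timers later fires, the runner can recognize it with StateAndTimerBundleCheckpointHandler.isSdfTimer(timerId), read the stored residual back out of the value state, and resubmit it as a new bundle.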
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
index 87d1e1e..a0d370e 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java
@@ -459,7 +459,8 @@ public class DefaultJobBundleFactory implements JobBundleFactory {
         TimerReceiverFactory timerReceiverFactory,
         StateRequestHandler stateRequestHandler,
         BundleProgressHandler progressHandler,
-        BundleFinalizationHandler finalizationHandler)
+        BundleFinalizationHandler finalizationHandler,
+        BundleCheckpointHandler checkpointHandler)
         throws Exception {
       // TODO: Consider having BundleProcessor#newBundle take in an OutputReceiverFactory rather
       // than constructing the receiver map here. Every bundle factory will need this.
@@ -520,7 +521,8 @@ public class DefaultJobBundleFactory implements JobBundleFactory {
               getTimerReceivers(currentClient.processBundleDescriptor, timerReceiverFactory),
               stateRequestHandler,
               progressHandler,
-              finalizationHandler);
+              finalizationHandler,
+              checkpointHandler);
       return new RemoteBundle() {
         @Override
         public String getId() {
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
index 9fcfb36..7f51081 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SdkHarnessClient.java
@@ -192,19 +192,22 @@ public class SdkHarnessClient implements AutoCloseable {
         Map<KV<String, String>, RemoteOutputReceiver<Timer<?>>> timerReceivers,
         StateRequestHandler stateRequestHandler,
         BundleProgressHandler progressHandler,
-        BundleFinalizationHandler finalizationHandler) {
+        BundleFinalizationHandler finalizationHandler,
+        BundleCheckpointHandler checkpointHandler) {
       return newBundle(
           outputReceivers,
           timerReceivers,
           stateRequestHandler,
           progressHandler,
           BundleSplitHandler.unsupported(),
-          request -> {
-            throw new UnsupportedOperationException(
-                String.format(
-                    "The %s does not have a registered bundle checkpoint handler.",
-                    ActiveBundle.class.getSimpleName()));
-          },
+          checkpointHandler == null
+              ? request -> {
+                throw new UnsupportedOperationException(
+                    String.format(
+                        "The %s does not have a registered bundle checkpoint handler.",
+                        ActiveBundle.class.getSimpleName()));
+              }
+              : checkpointHandler,
           finalizationHandler == null
               ? bundleId -> {
                 throw new UnsupportedOperationException(
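With this overload, callers pass the checkpoint handler explicitly, and the fail-fast handler is installed only when null is given. A minimal sketch of a handler a caller could supply follows; the logger name is purely illustrative. BundleCheckpointHandler has a single onCheckpoint(ProcessBundleResponse) method, so a lambda suffices:

    BundleCheckpointHandler checkpointHandler =
        response ->
            LoggerFactory.getLogger("checkpoint-demo")
                .info(
                    "Bundle checkpointed with {} residual root(s)",
                    response.getResidualRootsCount());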
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java
index af4cb52..6e696c0 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactory.java
@@ -162,7 +162,8 @@ public class SingleEnvironmentInstanceJobBundleFactory implements JobBundleFacto
         TimerReceiverFactory timerReceiverFactory,
         StateRequestHandler stateRequestHandler,
         BundleProgressHandler progressHandler,
-        BundleFinalizationHandler finalizationHandler) {
+        BundleFinalizationHandler finalizationHandler,
+        BundleCheckpointHandler checkpointHandler) {
       Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
       for (Map.Entry<String, Coder> remoteOutputCoder :
           descriptor.getRemoteOutputCoders().entrySet()) {
@@ -195,7 +196,8 @@ public class SingleEnvironmentInstanceJobBundleFactory implements JobBundleFacto
           timerReceivers,
           stateRequestHandler,
           progressHandler,
-          finalizationHandler);
+          finalizationHandler,
+          checkpointHandler);
     }
 
     @Override
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java
index 3b0911e..ee2b8e5 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/StageBundleFactory.java
@@ -50,6 +50,22 @@ public interface StageBundleFactory extends AutoCloseable {
         outputReceiverFactory, null, stateRequestHandler, progressHandler, finalizationHandler);
   }
 
+  default RemoteBundle getBundle(
+      OutputReceiverFactory outputReceiverFactory,
+      StateRequestHandler stateRequestHandler,
+      BundleProgressHandler progressHandler,
+      BundleFinalizationHandler finalizationHandler,
+      BundleCheckpointHandler checkpointHandler)
+      throws Exception {
+    return getBundle(
+        outputReceiverFactory,
+        null,
+        stateRequestHandler,
+        progressHandler,
+        finalizationHandler,
+        checkpointHandler);
+  }
+
   /** Get a new {@link RemoteBundle bundle} for processing the data in an executable stage. */
   default RemoteBundle getBundle(
       OutputReceiverFactory outputReceiverFactory,
@@ -61,12 +77,29 @@ public interface StageBundleFactory extends AutoCloseable {
         outputReceiverFactory, timerReceiverFactory, stateRequestHandler, progressHandler, null);
   }
 
-  RemoteBundle getBundle(
+  default RemoteBundle getBundle(
       OutputReceiverFactory outputReceiverFactory,
       TimerReceiverFactory timerReceiverFactory,
       StateRequestHandler stateRequestHandler,
       BundleProgressHandler progressHandler,
       BundleFinalizationHandler finalizationHandler)
+      throws Exception {
+    return getBundle(
+        outputReceiverFactory,
+        timerReceiverFactory,
+        stateRequestHandler,
+        progressHandler,
+        finalizationHandler,
+        null);
+  }
+
+  RemoteBundle getBundle(
+      OutputReceiverFactory outputReceiverFactory,
+      TimerReceiverFactory timerReceiverFactory,
+      StateRequestHandler stateRequestHandler,
+      BundleProgressHandler progressHandler,
+      BundleFinalizationHandler finalizationHandler,
+      BundleCheckpointHandler checkpointHandler)
       throws Exception;
 
   ProcessBundleDescriptors.ExecutableProcessBundleDescriptor getProcessBundleDescriptor();
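The new default methods keep existing StageBundleFactory call sites source-compatible while letting runners thread a BundleCheckpointHandler down to the SDK harness client. A hedged usage fragment follows; stageBundleFactory, the receiver factory, the handlers, and element are assumed to be supplied by the enclosing runner (compare the Flink operator changes above):

    // Fragment: obtaining and closing a bundle through the new checkpoint-aware overload.
    try (RemoteBundle bundle =
        stageBundleFactory.getBundle(
            outputReceiverFactory,
            stateRequestHandler,
            progressHandler,
            finalizationHandler,
            checkpointHandler)) {
      // Feed input to the SDK harness; if the SDF checkpoints, the residual roots in the
      // ProcessBundleResponse are delivered to checkpointHandler.onCheckpoint(...).
      bundle.getInputReceivers().values().iterator().next().accept(element);
    }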
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 46a5d60..dcea277 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -1178,6 +1178,7 @@ public class RemoteExecutionTest implements Serializable {
             timerReceivers,
             StateRequestHandler.unsupported(),
             BundleProgressHandler.ignored(),
+            null,
             null)) {
       Iterables.getOnlyElement(bundle.getInputReceivers().values())
           .accept(valueInGlobalWindow(KV.of("X", "X")));
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
index 10d5835..97e05cf 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SparkExecutableStageFunctionTest.java
@@ -40,6 +40,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
+import org.apache.beam.runners.fnexecution.control.BundleCheckpointHandler;
 import org.apache.beam.runners.fnexecution.control.BundleFinalizationHandler;
 import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
 import org.apache.beam.runners.fnexecution.control.ExecutableStageContext;
@@ -162,7 +163,8 @@ public class SparkExecutableStageFunctionTest {
               TimerReceiverFactory timerReceiverFactory,
               StateRequestHandler stateRequestHandler,
               BundleProgressHandler progressHandler,
-              BundleFinalizationHandler finalizationHandler) {
+              BundleFinalizationHandler finalizationHandler,
+              BundleCheckpointHandler checkpointHandler) {
             return new RemoteBundle() {
               @Override
               public String getId() {
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index 3e45f4a..303fe96 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -389,9 +389,6 @@ class FlinkRunnerTest(portable_runner_test.PortableRunnerTest):
   def test_sdf_with_watermark_tracking(self):
     raise unittest.SkipTest("BEAM-2939")
 
-  def test_sdf_with_sdf_initiated_checkpointing(self):
-    raise unittest.SkipTest("BEAM-2939")
-
   def test_callbacks_with_exception(self):
     raise unittest.SkipTest("BEAM-11021")