You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this text version.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/03/19 16:58:45 UTC

[beam] branch master updated: [BEAM-9430] Migrate from ProcessContext#updateWatermark to WatermarkEstimators

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bde44e9  [BEAM-9430] Migrate from ProcessContext#updateWatermark to WatermarkEstimators
     new 4743e13  Merge pull request #11126 from lukecwik/splittabledofn2
bde44e9 is described below

commit bde44e977a27e94acda9a9a46692d9f9fc924379
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Mar 13 14:21:03 2020 -0700

    [BEAM-9430] Migrate from ProcessContext#updateWatermark to WatermarkEstimators
---
 .../runners/apex/translation/ParDoTranslator.java  |   9 +-
 .../translation/operators/ApexParDoOperator.java   |   4 +-
 .../core/construction/ParDoTranslation.java        |  10 +-
 .../runners/core/construction/SplittableParDo.java |  79 ++++++++--
 .../construction/SplittableParDoNaiveBounded.java  | 161 ++++++++++++++++---
 .../org/apache/beam/runners/core/DoFnRunners.java  |   2 +-
 ...TimeBoundedSplittableProcessElementInvoker.java | 130 +++++-----------
 .../apache/beam/runners/core/SimpleDoFnRunner.java |   5 -
 .../core/SplittableParDoViaKeyedWorkItems.java     | 173 ++++++++++++++++++---
 .../core/SplittableProcessElementInvoker.java      |  16 +-
 ...BoundedSplittableProcessElementInvokerTest.java |  30 ++--
 .../runners/core/SplittableParDoProcessFnTest.java |  56 +++++--
 .../SplittableProcessElementsEvaluatorFactory.java |  24 +--
 .../runners/direct/TransformEvaluatorRegistry.java |   4 +-
 .../flink/FlinkStreamingTransformTranslators.java  |  12 +-
 .../dataflow/DataflowPipelineTranslator.java       |  22 ++-
 .../dataflow/PrimitiveParDoSingleFactory.java      |  14 +-
 .../dataflow/DataflowPipelineTranslatorTest.java   |   9 +-
 .../worker/SplittableProcessFnFactory.java         |  21 +--
 .../src/main/java/org/apache/beam/sdk/io/Read.java |  16 +-
 .../java/org/apache/beam/sdk/transforms/DoFn.java  |  12 --
 .../org/apache/beam/sdk/transforms/DoFnTester.java |   5 -
 .../java/org/apache/beam/sdk/transforms/Watch.java |  34 +++-
 .../reflect/ByteBuddyDoFnInvokerFactory.java       |  34 +++-
 .../beam/sdk/transforms/reflect/DoFnInvoker.java   |   7 +-
 .../splittabledofn/ManualWatermarkEstimator.java   |   3 +-
 .../splittabledofn/WatermarkEstimator.java         |   2 +-
 .../splittabledofn/WatermarkEstimators.java        |  61 ++++----
 .../sdk/transforms/reflect/DoFnInvokersTest.java   |  28 +++-
 .../splittabledofn/WatermarkEstimatorsTest.java    |  43 ++---
 .../sdk/fn/splittabledofn/WatermarkEstimators.java | 115 ++++++++++++++
 .../fn/splittabledofn/WatermarkEstimatorsTest.java | 109 +++++++++++++
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 130 ++++++++++------
 .../beam/fn/harness/FnApiDoFnRunnerTest.java       |  63 ++++++--
 34 files changed, 1085 insertions(+), 358 deletions(-)

diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
index 9a35a72..7144a0e 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java
@@ -138,12 +138,15 @@ class ParDoTranslator<InputT, OutputT>
     }
   }
 
-  static class SplittableProcessElementsTranslator<InputT, OutputT, RestrictionT, PositionT>
-      implements TransformTranslator<ProcessElements<InputT, OutputT, RestrictionT, PositionT>> {
+  static class SplittableProcessElementsTranslator<
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+      implements TransformTranslator<
+          ProcessElements<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> {
 
     @Override
     public void translate(
-        ProcessElements<InputT, OutputT, RestrictionT, PositionT> transform,
+        ProcessElements<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+            transform,
         TranslationContext context) {
 
       Map<TupleTag<?>, PValue> outputs = context.getOutputs();
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
index 8df7997..111d997 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java
@@ -75,7 +75,6 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.UserCodeException;
@@ -531,8 +530,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator
           (StateInternalsFactory<byte[]>) this.currentKeyStateInternals.getFactory();
 
       @SuppressWarnings({"rawtypes", "unchecked"})
-      ProcessFn<InputT, OutputT, Object, RestrictionTracker<Object, Object>> splittableDoFn =
-          (ProcessFn) doFn;
+      ProcessFn<InputT, OutputT, Object, Object, Object> splittableDoFn = (ProcessFn) doFn;
       splittableDoFn.setStateInternalsFactory(stateInternalsFactory);
       TimerInternalsFactory<byte[]> timerInternalsFactory = key -> currentKeyTimerInternals;
       splittableDoFn.setTimerInternalsFactory(timerInternalsFactory);
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index da8dae0..ebf7440 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -62,6 +62,7 @@ import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
 import org.apache.beam.sdk.transforms.ViewFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
@@ -202,9 +203,12 @@ public class ParDoTranslation {
     final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
     final String restrictionCoderId;
     if (signature.processElement().isSplittable()) {
-      final Coder<?> restrictionCoder =
-          DoFnInvokers.invokerFor(doFn).invokeGetRestrictionCoder(pipeline.getCoderRegistry());
-      restrictionCoderId = components.registerCoder(restrictionCoder);
+      DoFnInvoker<?, ?> doFnInvoker = DoFnInvokers.invokerFor(doFn);
+      final Coder<?> restrictionAndWatermarkStateCoder =
+          KvCoder.of(
+              doFnInvoker.invokeGetRestrictionCoder(pipeline.getCoderRegistry()),
+              doFnInvoker.invokeGetWatermarkEstimatorStateCoder(pipeline.getCoderRegistry()));
+      restrictionCoderId = components.registerCoder(restrictionAndWatermarkStateCoder);
     } else {
       restrictionCoderId = "";
     }
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
index 84ed10a..d76ab8a 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java
@@ -40,6 +40,7 @@ import org.apache.beam.runners.core.construction.ReadTranslation.BoundedReadPayl
 import org.apache.beam.runners.core.construction.ReadTranslation.UnboundedReadPayloadTranslator;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -55,7 +56,9 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
@@ -86,7 +89,7 @@ import org.joda.time.Instant;
  * <p>This transform is intended as a helper for internal use by runners when implementing {@code
  * ParDo.of(splittable DoFn)}, but not for direct use by pipeline writers.
  */
-public class SplittableParDo<InputT, OutputT, RestrictionT>
+public class SplittableParDo<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
     extends PTransform<PCollection<InputT>, PCollectionTuple> {
   /**
    * A {@link PTransformOverrideFactory} that overrides a <a
@@ -145,7 +148,7 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
    * ParDo}. Instead {@link ParDoTranslation} will be used to extract fields.
    */
   @SuppressWarnings({"unchecked", "rawtypes"})
-  public static <InputT, OutputT> SplittableParDo<InputT, OutputT, ?> forAppliedParDo(
+  public static <InputT, OutputT> SplittableParDo<InputT, OutputT, ?, ?> forAppliedParDo(
       AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> parDo) {
     checkArgument(parDo != null, "parDo must not be null");
 
@@ -170,6 +173,9 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
     Coder<RestrictionT> restrictionCoder =
         DoFnInvokers.invokerFor(doFn)
             .invokeGetRestrictionCoder(input.getPipeline().getCoderRegistry());
+    Coder<WatermarkEstimatorStateT> watermarkEstimatorStateCoder =
+        DoFnInvokers.invokerFor(doFn)
+            .invokeGetWatermarkEstimatorStateCoder(input.getPipeline().getCoderRegistry());
     Coder<KV<InputT, RestrictionT>> splitCoder = KvCoder.of(input.getCoder(), restrictionCoder);
 
     PCollection<KV<byte[], KV<InputT, RestrictionT>>> keyedRestrictions =
@@ -192,6 +198,7 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
             doFn,
             input.getCoder(),
             restrictionCoder,
+            watermarkEstimatorStateCoder,
             (WindowingStrategy<InputT, ?>) input.getWindowingStrategy(),
             sideInputs,
             mainOutputTag,
@@ -221,11 +228,12 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
    * method for a splittable {@link DoFn} on each {@link KV} of the input {@link PCollection} of
    * {@link KV KVs} keyed with arbitrary but globally unique keys.
    */
-  public static class ProcessKeyedElements<InputT, OutputT, RestrictionT>
+  public static class ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
       extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
     private final DoFn<InputT, OutputT> fn;
     private final Coder<InputT> elementCoder;
     private final Coder<RestrictionT> restrictionCoder;
+    private final Coder<WatermarkEstimatorStateT> watermarkEstimatorStateCoder;
     private final WindowingStrategy<InputT, ?> windowingStrategy;
     private final List<PCollectionView<?>> sideInputs;
     private final TupleTag<OutputT> mainOutputTag;
@@ -245,6 +253,7 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
         DoFn<InputT, OutputT> fn,
         Coder<InputT> elementCoder,
         Coder<RestrictionT> restrictionCoder,
+        Coder<WatermarkEstimatorStateT> watermarkEstimatorStateCoder,
         WindowingStrategy<InputT, ?> windowingStrategy,
         List<PCollectionView<?>> sideInputs,
         TupleTag<OutputT> mainOutputTag,
@@ -253,6 +262,7 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
       this.fn = fn;
       this.elementCoder = elementCoder;
       this.restrictionCoder = restrictionCoder;
+      this.watermarkEstimatorStateCoder = watermarkEstimatorStateCoder;
       this.windowingStrategy = windowingStrategy;
       this.sideInputs = sideInputs;
       this.mainOutputTag = mainOutputTag;
@@ -272,6 +282,10 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
       return restrictionCoder;
     }
 
+    public Coder<WatermarkEstimatorStateT> getWatermarkEstimatorStateCoder() {
+      return watermarkEstimatorStateCoder;
+    }
+
     public WindowingStrategy<InputT, ?> getInputWindowingStrategy() {
       return windowingStrategy;
     }
@@ -340,7 +354,8 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
 
   /** A translator for {@link ProcessKeyedElements}. */
   public static class ProcessKeyedElementsTranslator
-      implements PTransformTranslation.TransformPayloadTranslator<ProcessKeyedElements<?, ?, ?>> {
+      implements PTransformTranslation.TransformPayloadTranslator<
+          ProcessKeyedElements<?, ?, ?, ?>> {
 
     public static TransformPayloadTranslator create() {
       return new ProcessKeyedElementsTranslator();
@@ -349,15 +364,16 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
     private ProcessKeyedElementsTranslator() {}
 
     @Override
-    public String getUrn(ProcessKeyedElements<?, ?, ?> transform) {
+    public String getUrn(ProcessKeyedElements<?, ?, ?, ?> transform) {
       return PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN;
     }
 
     @Override
     public FunctionSpec translate(
-        AppliedPTransform<?, ?, ProcessKeyedElements<?, ?, ?>> transform, SdkComponents components)
+        AppliedPTransform<?, ?, ProcessKeyedElements<?, ?, ?, ?>> transform,
+        SdkComponents components)
         throws IOException {
-      ProcessKeyedElements<?, ?, ?> pke = transform.getTransform();
+      ProcessKeyedElements<?, ?, ?, ?> pke = transform.getTransform();
       final DoFn<?, ?> fn = pke.getFn();
       final DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
       final String restrictionCoderId = components.registerCoder(pke.getRestrictionCoder());
@@ -483,7 +499,7 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
     }
 
     @ProcessElement
-    public void processElement(ProcessContext context) {
+    public void processElement(ProcessContext context, BoundedWindow w) {
       context.output(
           KV.of(
               context.element(),
@@ -495,6 +511,26 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
                     }
 
                     @Override
+                    public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                      return context.timestamp();
+                    }
+
+                    @Override
+                    public PipelineOptions pipelineOptions() {
+                      return context.getPipelineOptions();
+                    }
+
+                    @Override
+                    public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                      return context.pane();
+                    }
+
+                    @Override
+                    public BoundedWindow window() {
+                      return w;
+                    }
+
+                    @Override
                     public String getErrorContext() {
                       return PairWithRestrictionFn.class.getSimpleName()
                           + ".invokeGetInitialRestriction";
@@ -528,7 +564,7 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
     }
 
     @ProcessElement
-    public void processElement(final ProcessContext c) {
+    public void processElement(final ProcessContext c, BoundedWindow w) {
       invoker.invokeSplitRestriction(
           (ArgumentProvider)
               new BaseArgumentProvider<InputT, RestrictionT>() {
@@ -543,6 +579,31 @@ public class SplittableParDo<InputT, OutputT, RestrictionT>
                 }
 
                 @Override
+                public RestrictionTracker<?, ?> restrictionTracker() {
+                  return invoker.invokeNewTracker((DoFnInvoker.BaseArgumentProvider) this);
+                }
+
+                @Override
+                public Instant timestamp(DoFn<InputT, RestrictionT> doFn) {
+                  return c.timestamp();
+                }
+
+                @Override
+                public PipelineOptions pipelineOptions() {
+                  return c.getPipelineOptions();
+                }
+
+                @Override
+                public PaneInfo paneInfo(DoFn<InputT, RestrictionT> doFn) {
+                  return c.pane();
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return w;
+                }
+
+                @Override
                 public OutputReceiver<RestrictionT> outputReceiver(
                     DoFn<InputT, RestrictionT> doFn) {
                   return new OutputReceiver<RestrictionT>() {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
index 17d435c..b002169 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java
@@ -62,11 +62,11 @@ import org.joda.time.Instant;
  */
 public class SplittableParDoNaiveBounded {
   /** Overrides a {@link ProcessKeyedElements} into {@link SplittableProcessNaive}. */
-  public static class OverrideFactory<InputT, OutputT, RestrictionT>
+  public static class OverrideFactory<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
       implements PTransformOverrideFactory<
           PCollection<KV<byte[], KV<InputT, RestrictionT>>>,
           PCollectionTuple,
-          ProcessKeyedElements<InputT, OutputT, RestrictionT>> {
+          ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>> {
     @Override
     public PTransformReplacement<
             PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple>
@@ -74,7 +74,7 @@ public class SplittableParDoNaiveBounded {
             AppliedPTransform<
                     PCollection<KV<byte[], KV<InputT, RestrictionT>>>,
                     PCollectionTuple,
-                    ProcessKeyedElements<InputT, OutputT, RestrictionT>>
+                    ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>>
                 transform) {
       checkArgument(
           DoFnSignatures.signatureForDoFn(transform.getTransform().getFn()).isBoundedPerElement()
@@ -93,11 +93,13 @@ public class SplittableParDoNaiveBounded {
   }
 
   static class SplittableProcessNaive<
-          InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
-    private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
+    private final ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+        original;
 
-    SplittableProcessNaive(ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
+    SplittableProcessNaive(
+        ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT> original) {
       this.original = original;
     }
 
@@ -109,13 +111,15 @@ public class SplittableParDoNaiveBounded {
           .apply(
               "NaiveProcess",
               ParDo.of(
-                      new NaiveProcessFn<InputT, OutputT, RestrictionT, TrackerT>(original.getFn()))
+                      new NaiveProcessFn<
+                          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>(
+                          original.getFn()))
                   .withSideInputs(original.getSideInputs())
                   .withOutputTags(original.getMainOutputTag(), original.getAdditionalOutputTags()));
     }
   }
 
-  static class NaiveProcessFn<InputT, OutputT, RestrictionT, PositionT>
+  static class NaiveProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       extends DoFn<KV<InputT, RestrictionT>, OutputT> {
     private final DoFn<InputT, OutputT> fn;
 
@@ -160,15 +164,84 @@ public class SplittableParDoNaiveBounded {
 
     @ProcessElement
     public void process(ProcessContext c, BoundedWindow w) {
+      WatermarkEstimatorStateT initialWatermarkEstimatorState =
+          (WatermarkEstimatorStateT)
+              invoker.invokeGetInitialWatermarkEstimatorState(
+                  new BaseArgumentProvider<InputT, OutputT>() {
+                    @Override
+                    public InputT element(DoFn<InputT, OutputT> doFn) {
+                      return c.element().getKey();
+                    }
+
+                    @Override
+                    public Object restriction() {
+                      return c.element().getValue();
+                    }
+
+                    @Override
+                    public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                      return c.timestamp();
+                    }
+
+                    @Override
+                    public PipelineOptions pipelineOptions() {
+                      return c.getPipelineOptions();
+                    }
+
+                    @Override
+                    public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                      return c.pane();
+                    }
+
+                    @Override
+                    public BoundedWindow window() {
+                      return w;
+                    }
+
+                    @Override
+                    public String getErrorContext() {
+                      return NaiveProcessFn.class.getSimpleName()
+                          + ".invokeGetInitialWatermarkEstimatorState";
+                    }
+                  });
+
       RestrictionT restriction = c.element().getValue();
+      WatermarkEstimatorStateT watermarkEstimatorState = initialWatermarkEstimatorState;
       while (true) {
-        RestrictionT finalRestriction = restriction;
+        RestrictionT currentRestriction = restriction;
+        WatermarkEstimatorStateT currentWatermarkEstimatorState = watermarkEstimatorState;
+
         RestrictionTracker<RestrictionT, PositionT> tracker =
             invoker.invokeNewTracker(
                 new BaseArgumentProvider<InputT, OutputT>() {
                   @Override
+                  public InputT element(DoFn<InputT, OutputT> doFn) {
+                    return c.element().getKey();
+                  }
+
+                  @Override
                   public RestrictionT restriction() {
-                    return finalRestriction;
+                    return currentRestriction;
+                  }
+
+                  @Override
+                  public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                    return c.timestamp();
+                  }
+
+                  @Override
+                  public PipelineOptions pipelineOptions() {
+                    return c.getPipelineOptions();
+                  }
+
+                  @Override
+                  public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                    return c.pane();
+                  }
+
+                  @Override
+                  public BoundedWindow window() {
+                    return w;
                   }
 
                   @Override
@@ -176,10 +249,57 @@ public class SplittableParDoNaiveBounded {
                     return NaiveProcessFn.class.getSimpleName() + ".invokeNewTracker";
                   }
                 });
+        WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator =
+            invoker.invokeNewWatermarkEstimator(
+                new BaseArgumentProvider<InputT, OutputT>() {
+                  @Override
+                  public InputT element(DoFn<InputT, OutputT> doFn) {
+                    return c.element().getKey();
+                  }
+
+                  @Override
+                  public RestrictionT restriction() {
+                    return currentRestriction;
+                  }
+
+                  @Override
+                  public WatermarkEstimatorStateT watermarkEstimatorState() {
+                    return currentWatermarkEstimatorState;
+                  }
+
+                  @Override
+                  public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                    return c.timestamp();
+                  }
+
+                  @Override
+                  public PipelineOptions pipelineOptions() {
+                    return c.getPipelineOptions();
+                  }
+
+                  @Override
+                  public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                    return c.pane();
+                  }
+
+                  @Override
+                  public BoundedWindow window() {
+                    return w;
+                  }
+
+                  @Override
+                  public String getErrorContext() {
+                    return NaiveProcessFn.class.getSimpleName() + ".invokeNewWatermarkEstimator";
+                  }
+                });
         ProcessContinuation continuation =
             invoker.invokeProcessElement(
-                new NestedProcessContext<>(fn, c, c.element().getKey(), w, tracker));
+                new NestedProcessContext<>(
+                    fn, c, c.element().getKey(), w, tracker, watermarkEstimator));
         if (continuation.shouldResume()) {
+          // Fetch the watermark before splitting to ensure that the watermark applies to both
+          // the primary and the residual.
+          watermarkEstimatorState = watermarkEstimator.getState();
           restriction = tracker.trySplit(0).getResidual();
           Uninterruptibles.sleepUninterruptibly(
               continuation.resumeDelay().getMillis(), TimeUnit.MILLISECONDS);
@@ -236,25 +356,33 @@ public class SplittableParDoNaiveBounded {
     }
 
     private static class NestedProcessContext<
-            InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker<RestrictionT, ?>>
+            InputT,
+            OutputT,
+            RestrictionT,
+            TrackerT extends RestrictionTracker<RestrictionT, ?>,
+            WatermarkEstimatorStateT,
+            WatermarkEstimatorT extends WatermarkEstimator<WatermarkEstimatorStateT>>
         extends DoFn<InputT, OutputT>.ProcessContext implements ArgumentProvider<InputT, OutputT> {
 
       private final BoundedWindow window;
       private final DoFn<KV<InputT, RestrictionT>, OutputT>.ProcessContext outerContext;
       private final InputT element;
       private final TrackerT tracker;
+      private final WatermarkEstimatorT watermarkEstimator;
 
       private NestedProcessContext(
           DoFn<InputT, OutputT> fn,
           DoFn<KV<InputT, RestrictionT>, OutputT>.ProcessContext outerContext,
           InputT element,
           BoundedWindow window,
-          TrackerT tracker) {
+          TrackerT tracker,
+          WatermarkEstimatorT watermarkEstimator) {
         fn.super();
         this.window = window;
         this.outerContext = outerContext;
         this.element = element;
         this.tracker = tracker;
+        this.watermarkEstimator = watermarkEstimator;
       }
 
       @Override
@@ -413,11 +541,6 @@ public class SplittableParDoNaiveBounded {
       }
 
       @Override
-      public void updateWatermark(Instant watermark) {
-        // Ignore watermark updates
-      }
-
-      @Override
       public Object watermarkEstimatorState() {
         throw new UnsupportedOperationException(
             "@WatermarkEstimatorState parameters are not supported.");
@@ -425,7 +548,7 @@ public class SplittableParDoNaiveBounded {
 
       @Override
       public WatermarkEstimator<?> watermarkEstimator() {
-        throw new UnsupportedOperationException("WatermarkEstimator parameters are not supported.");
+        return watermarkEstimator;
       }
 
       // ----------- Unsupported methods --------------------
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java
index 88ba954..d7b87dd 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java
@@ -163,7 +163,7 @@ public class DoFnRunners {
 
   public static <InputT, OutputT, RestrictionT>
       ProcessFnRunner<InputT, OutputT, RestrictionT> newProcessFnRunner(
-          ProcessFn<InputT, OutputT, RestrictionT, ?> fn,
+          ProcessFn<InputT, OutputT, RestrictionT, ?, ?> fn,
           PipelineOptions options,
           Collection<PCollectionView<?>> views,
           ReadyCheckingSideInputReader sideInputReader,
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
index 9c0e79a..f3c3c39 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java
@@ -26,13 +26,10 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.fn.splittabledofn.RestrictionTrackers;
+import org.apache.beam.sdk.fn.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.state.State;
 import org.apache.beam.sdk.state.TimeDomain;
-import org.apache.beam.sdk.state.Timer;
-import org.apache.beam.sdk.state.TimerMap;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
 import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
 import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
@@ -41,7 +38,8 @@ import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.splittabledofn.TimestampObservingWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
@@ -59,8 +57,9 @@ import org.joda.time.Instant;
  * outputs), or runs for the given duration.
  */
 public class OutputAndTimeBoundedSplittableProcessElementInvoker<
-        InputT, OutputT, RestrictionT, PositionT>
-    extends SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT> {
+        InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+    extends SplittableProcessElementInvoker<
+        InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> {
   private final DoFn<InputT, OutputT> fn;
   private final PipelineOptions pipelineOptions;
   private final OutputWindowedValue<OutputT> output;
@@ -106,8 +105,9 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
   public Result invokeProcessElement(
       DoFnInvoker<InputT, OutputT> invoker,
       final WindowedValue<InputT> element,
-      final RestrictionTracker<RestrictionT, PositionT> tracker) {
-    final ProcessContext processContext = new ProcessContext(element, tracker);
+      final RestrictionTracker<RestrictionT, PositionT> tracker,
+      final WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator) {
+    final ProcessContext processContext = new ProcessContext(element, tracker, watermarkEstimator);
 
     DoFn.ProcessContinuation cont =
         invoker.invokeProcessElement(
@@ -134,16 +134,6 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
               }
 
               @Override
-              public Object sideInput(String tagId) {
-                throw new UnsupportedOperationException("Not supported in SplittableDoFn");
-              }
-
-              @Override
-              public Object schemaElement(int index) {
-                throw new UnsupportedOperationException("Not supported in SplittableDoFn");
-              }
-
-              @Override
               public Instant timestamp(DoFn<InputT, OutputT> doFn) {
                 return processContext.timestamp();
               }
@@ -176,28 +166,13 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
               }
 
               @Override
-              public BundleFinalizer bundleFinalizer() {
-                throw new UnsupportedOperationException(
-                    "Not supported in non-portable SplittableDoFn");
-              }
-
-              @Override
               public RestrictionTracker<?, ?> restrictionTracker() {
                 return processContext.tracker;
               }
 
-              // Unsupported methods below.
-
-              @Override
-              public BoundedWindow window() {
-                throw new UnsupportedOperationException(
-                    "Access to window of the element not supported in Splittable DoFn");
-              }
-
               @Override
-              public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
-                throw new UnsupportedOperationException(
-                    "Access to pane of the element not supported in Splittable DoFn");
+              public WatermarkEstimator<?> watermarkEstimator() {
+                return processContext.watermarkEstimator;
               }
 
               @Override
@@ -205,6 +180,8 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
                 return pipelineOptions;
               }
 
+              // Unsupported methods below.
+
               @Override
               public StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) {
                 throw new IllegalStateException(
@@ -218,34 +195,11 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
                     "Should not access finishBundleContext() from @"
                         + DoFn.ProcessElement.class.getSimpleName());
               }
-
-              @Override
-              public DoFn<InputT, OutputT>.OnTimerContext onTimerContext(
-                  DoFn<InputT, OutputT> doFn) {
-                throw new UnsupportedOperationException(
-                    "Access to timers not supported in Splittable DoFn");
-              }
-
-              @Override
-              public State state(String stateId, boolean alwaysFetched) {
-                throw new UnsupportedOperationException(
-                    "Access to state not supported in Splittable DoFn");
-              }
-
-              @Override
-              public Timer timer(String timerId) {
-                throw new UnsupportedOperationException(
-                    "Access to timers not supported in Splittable DoFn");
-              }
-
-              @Override
-              public TimerMap timerFamily(String tagId) {
-                throw new UnsupportedOperationException(
-                    "Access to timerFamily not supported in Splittable DoFn");
-              }
             });
     processContext.cancelScheduledCheckpoint();
-    @Nullable KV<RestrictionT, Instant> residual = processContext.getTakenCheckpoint();
+    @Nullable
+    KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> residual =
+        processContext.getTakenCheckpoint();
     if (cont.shouldResume()) {
       checkState(
           !processContext.hasClaimFailed,
@@ -274,7 +228,10 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
           // "a runner repeatedly checkpoints the DoFn before it has a chance to even attempt
           // claiming work": the former is valid, and the latter would be a bug, and is addressed
           // by not checkpointing the tracker until it attempts to claim some work.
-          residual = KV.of(tracker.currentRestriction(), processContext.getLastReportedWatermark());
+          residual =
+              KV.of(
+                  tracker.currentRestriction(),
+                  KV.of(watermarkEstimator.currentWatermark(), watermarkEstimator.getState()));
           // Don't call tracker.checkDone() - it's not done.
         }
       } else {
@@ -300,15 +257,18 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
       // Can only be true if cont.shouldResume() is false and no checkpoint was taken.
       // This means the restriction has been fully processed.
       checkState(!cont.shouldResume());
-      return new Result(null, cont, BoundedWindow.TIMESTAMP_MAX_VALUE);
+      return new Result(null, cont, null, null);
     }
-    return new Result(residual.getKey(), cont, residual.getValue());
+    return new Result(
+        residual.getKey(), cont, residual.getValue().getKey(), residual.getValue().getValue());
   }
 
   private class ProcessContext extends DoFn<InputT, OutputT>.ProcessContext
       implements RestrictionTrackers.ClaimObserver<PositionT> {
     private final WindowedValue<InputT> element;
     private final RestrictionTracker<RestrictionT, PositionT> tracker;
+    private final WatermarkEstimators.WatermarkAndStateObserver<WatermarkEstimatorStateT>
+        watermarkEstimator;
     private int numClaimedBlocks;
     private boolean hasClaimFailed;
 
@@ -320,17 +280,20 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
     // the call completed before reaching the given number of outputs or duration.
     private @Nullable RestrictionT checkpoint;
     // Watermark captured at the moment before checkpoint was taken, describing a lower bound
-    // on the output from "checkpoint".
-    private @Nullable Instant residualWatermark;
+    // on the output from "checkpoint" and its associated watermark estimator state.
+    private @Nullable KV<Instant, WatermarkEstimatorStateT> residualWatermarkAndState;
+
     // A handle on the scheduled action to take a checkpoint.
     private @Nullable Future<?> scheduledCheckpoint;
-    private @Nullable Instant lastReportedWatermark;
 
     public ProcessContext(
-        WindowedValue<InputT> element, RestrictionTracker<RestrictionT, PositionT> tracker) {
+        WindowedValue<InputT> element,
+        RestrictionTracker<RestrictionT, PositionT> tracker,
+        WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator) {
       fn.super();
       this.element = element;
       this.tracker = RestrictionTrackers.observe(tracker, this);
+      this.watermarkEstimator = WatermarkEstimators.threadSafe(watermarkEstimator);
     }
 
     @Override
@@ -368,11 +331,11 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
       }
     }
 
-    synchronized KV<RestrictionT, Instant> takeCheckpointNow() {
+    synchronized KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> takeCheckpointNow() {
       // This method may be entered either via .output(), or via scheduledCheckpoint.
       // Only one of them "wins" - tracker.checkpoint() must be called only once.
       if (checkpoint == null) {
-        residualWatermark = lastReportedWatermark;
+        residualWatermarkAndState = watermarkEstimator.getWatermarkAndState();
         SplitResult<RestrictionT> split = tracker.trySplit(0);
         if (split != null) {
           checkpoint = checkNotNull(split.getResidual());
@@ -382,9 +345,9 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
     }
 
     @Nullable
-    synchronized KV<RestrictionT, Instant> getTakenCheckpoint() {
+    synchronized KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> getTakenCheckpoint() {
       // The checkpoint may or may not have been taken.
-      return (checkpoint == null) ? null : KV.of(checkpoint, residualWatermark);
+      return (checkpoint == null) ? null : KV.of(checkpoint, residualWatermarkAndState);
     }
 
     @Override
@@ -411,21 +374,6 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
     }
 
     @Override
-    public synchronized void updateWatermark(Instant watermark) {
-      // Updating the watermark without any claimed blocks is allowed.
-      // The watermark is a promise about the timestamps of output from future claimed blocks.
-      // Such a promise can be made even if there are no claimed blocks. E.g. imagine reading
-      // from a streaming source that currently has no new data: there are no blocks to claim, but
-      // we may still want to advance the watermark if we have information about what timestamps
-      // of future elements in the source will be like.
-      lastReportedWatermark = watermark;
-    }
-
-    synchronized Instant getLastReportedWatermark() {
-      return lastReportedWatermark;
-    }
-
-    @Override
     public PipelineOptions getPipelineOptions() {
       return pipelineOptions;
     }
@@ -438,6 +386,9 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
     @Override
     public void outputWithTimestamp(OutputT value, Instant timestamp) {
       noteOutput();
+      if (watermarkEstimator instanceof TimestampObservingWatermarkEstimator) {
+        ((TimestampObservingWatermarkEstimator) watermarkEstimator).observeTimestamp(timestamp);
+      }
       output.outputWindowedValue(value, timestamp, element.getWindows(), element.getPane());
     }
 
@@ -449,6 +400,9 @@ public class OutputAndTimeBoundedSplittableProcessElementInvoker<
     @Override
     public <T> void outputWithTimestamp(TupleTag<T> tag, T value, Instant timestamp) {
       noteOutput();
+      if (watermarkEstimator instanceof TimestampObservingWatermarkEstimator) {
+        ((TimestampObservingWatermarkEstimator) watermarkEstimator).observeTimestamp(timestamp);
+      }
       output.outputWindowedValue(tag, value, timestamp, element.getWindows(), element.getPane());
     }
 
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
index aec8365..75f9e14 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
@@ -389,11 +389,6 @@ public class SimpleDoFnRunner<InputT, OutputT> implements DoFnRunner<InputT, Out
     }
 
     @Override
-    public void updateWatermark(Instant watermark) {
-      throw new UnsupportedOperationException("Only splittable DoFn's can use updateWatermark()");
-    }
-
-    @Override
     public void output(OutputT output) {
       output(mainOutputTag, output);
     }
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
index 28277f9..01612f0 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java
@@ -36,17 +36,16 @@ import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.state.WatermarkHoldState;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
-import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
-import org.apache.beam.sdk.transforms.DoFn.StartBundleContext;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.BaseArgumentProvider;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
@@ -106,11 +105,11 @@ public class SplittableParDoViaKeyedWorkItems {
   }
 
   /** Overrides a {@link ProcessKeyedElements} into {@link SplittableProcessViaKeyedWorkItems}. */
-  public static class OverrideFactory<InputT, OutputT, RestrictionT>
+  public static class OverrideFactory<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
       implements PTransformOverrideFactory<
           PCollection<KV<byte[], KV<InputT, RestrictionT>>>,
           PCollectionTuple,
-          ProcessKeyedElements<InputT, OutputT, RestrictionT>> {
+          ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>> {
     @Override
     public PTransformReplacement<
             PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple>
@@ -118,7 +117,7 @@ public class SplittableParDoViaKeyedWorkItems {
             AppliedPTransform<
                     PCollection<KV<byte[], KV<InputT, RestrictionT>>>,
                     PCollectionTuple,
-                    ProcessKeyedElements<InputT, OutputT, RestrictionT>>
+                    ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>>
                 transform) {
       return PTransformReplacement.of(
           PTransformReplacements.getSingletonMainInput(transform),
@@ -136,12 +135,14 @@ public class SplittableParDoViaKeyedWorkItems {
    * Runner-specific primitive {@link PTransform} that invokes the {@link DoFn.ProcessElement}
    * method for a splittable {@link DoFn}.
    */
-  public static class SplittableProcessViaKeyedWorkItems<InputT, OutputT, RestrictionT>
+  public static class SplittableProcessViaKeyedWorkItems<
+          InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
       extends PTransform<PCollection<KV<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
-    private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
+    private final ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+        original;
 
     public SplittableProcessViaKeyedWorkItems(
-        ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
+        ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT> original) {
       this.original = original;
     }
 
@@ -159,21 +160,25 @@ public class SplittableParDoViaKeyedWorkItems {
   }
 
   /** A primitive transform wrapping around {@link ProcessFn}. */
-  public static class ProcessElements<InputT, OutputT, RestrictionT, PositionT>
+  public static class ProcessElements<
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       extends PTransform<
           PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>, PCollectionTuple> {
-    private final ProcessKeyedElements<InputT, OutputT, RestrictionT> original;
+    private final ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+        original;
 
-    public ProcessElements(ProcessKeyedElements<InputT, OutputT, RestrictionT> original) {
+    public ProcessElements(
+        ProcessKeyedElements<InputT, OutputT, RestrictionT, WatermarkEstimatorStateT> original) {
       this.original = original;
     }
 
-    public ProcessFn<InputT, OutputT, RestrictionT, PositionT> newProcessFn(
-        DoFn<InputT, OutputT> fn) {
+    public ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+        newProcessFn(DoFn<InputT, OutputT> fn) {
       return new ProcessFn<>(
           fn,
           original.getElementCoder(),
           original.getRestrictionCoder(),
+          original.getWatermarkEstimatorStateCoder(),
           original.getInputWindowingStrategy());
     }
 
@@ -219,7 +224,7 @@ public class SplittableParDoViaKeyedWorkItems {
    * <p>See also: https://issues.apache.org/jira/browse/BEAM-1983
    */
   @VisibleForTesting
-  public static class ProcessFn<InputT, OutputT, RestrictionT, PositionT>
+  public static class ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       extends DoFn<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
     /**
      * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is
@@ -248,6 +253,8 @@ public class SplittableParDoViaKeyedWorkItems {
      */
     private StateTag<ValueState<RestrictionT>> restrictionTag;
 
+    private StateTag<ValueState<WatermarkEstimatorStateT>> watermarkEstimatorStateTag;
+
     private final DoFn<InputT, OutputT> fn;
     private final Coder<InputT> elementCoder;
     private final Coder<RestrictionT> restrictionCoder;
@@ -256,7 +263,7 @@ public class SplittableParDoViaKeyedWorkItems {
     private transient @Nullable StateInternalsFactory<byte[]> stateInternalsFactory;
     private transient @Nullable TimerInternalsFactory<byte[]> timerInternalsFactory;
     private transient @Nullable SplittableProcessElementInvoker<
-            InputT, OutputT, RestrictionT, PositionT>
+            InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
         processElementInvoker;
 
     private transient @Nullable DoFnInvoker<InputT, OutputT> invoker;
@@ -265,6 +272,7 @@ public class SplittableParDoViaKeyedWorkItems {
         DoFn<InputT, OutputT> fn,
         Coder<InputT> elementCoder,
         Coder<RestrictionT> restrictionCoder,
+        Coder<WatermarkEstimatorStateT> watermarkEstimatorStateCoder,
         WindowingStrategy<InputT, ?> inputWindowingStrategy) {
       this.fn = fn;
       this.elementCoder = elementCoder;
@@ -276,6 +284,8 @@ public class SplittableParDoViaKeyedWorkItems {
               WindowedValue.getFullCoder(
                   elementCoder, inputWindowingStrategy.getWindowFn().windowCoder()));
       this.restrictionTag = StateTags.value("restriction", restrictionCoder);
+      this.watermarkEstimatorStateTag =
+          StateTags.value("watermarkEstimatorState", watermarkEstimatorStateCoder);
     }
 
     public void setStateInternalsFactory(StateInternalsFactory<byte[]> stateInternalsFactory) {
@@ -287,7 +297,9 @@ public class SplittableParDoViaKeyedWorkItems {
     }
 
     public void setProcessElementInvoker(
-        SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT> invoker) {
+        SplittableProcessElementInvoker<
+                InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+            invoker) {
       this.processElementInvoker = invoker;
     }
 
@@ -336,7 +348,7 @@ public class SplittableParDoViaKeyedWorkItems {
      * <p>Uses a watermark hold to control watermark advancement.
      */
     @ProcessElement
-    public void processElement(final ProcessContext c) {
+    public void processElement(final ProcessContext c, BoundedWindow boundedWindow) {
       byte[] key = c.element().key();
       StateInternals stateInternals = stateInternalsFactory.stateInternalsForKey(key);
       TimerInternals timerInternals = timerInternalsFactory.timerInternalsForKey(key);
@@ -362,49 +374,168 @@ public class SplittableParDoViaKeyedWorkItems {
           stateInternals.state(stateNamespace, elementTag);
       ValueState<RestrictionT> restrictionState =
           stateInternals.state(stateNamespace, restrictionTag);
+      ValueState<WatermarkEstimatorStateT> watermarkEstimatorState =
+          stateInternals.state(stateNamespace, watermarkEstimatorStateTag);
       WatermarkHoldState holdState = stateInternals.state(stateNamespace, watermarkHoldTag);
 
       KV<WindowedValue<InputT>, RestrictionT> elementAndRestriction;
+      WatermarkEstimatorStateT watermarkEstimatorStateT;
       if (isSeedCall) {
         WindowedValue<KV<InputT, RestrictionT>> windowedValue =
             Iterables.getOnlyElement(c.element().elementsIterable());
         WindowedValue<InputT> element = windowedValue.withValue(windowedValue.getValue().getKey());
         elementState.write(element);
         elementAndRestriction = KV.of(element, windowedValue.getValue().getValue());
+        watermarkEstimatorStateT =
+            invoker.invokeGetInitialWatermarkEstimatorState(
+                new BaseArgumentProvider<InputT, OutputT>() {
+                  @Override
+                  public InputT element(DoFn<InputT, OutputT> doFn) {
+                    return elementAndRestriction.getKey().getValue();
+                  }
+
+                  @Override
+                  public Object restriction() {
+                    return elementAndRestriction.getValue();
+                  }
+
+                  @Override
+                  public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                    return c.timestamp();
+                  }
+
+                  @Override
+                  public PipelineOptions pipelineOptions() {
+                    return c.getPipelineOptions();
+                  }
+
+                  @Override
+                  public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                    return c.pane();
+                  }
+
+                  @Override
+                  public BoundedWindow window() {
+                    return boundedWindow;
+                  }
+
+                  @Override
+                  public String getErrorContext() {
+                    return ProcessFn.class.getSimpleName()
+                        + ".invokeGetInitialWatermarkEstimatorState";
+                  }
+                });
       } else {
         // This is not the first ProcessElement call for this element/restriction - rather,
         // this is a timer firing, so we need to fetch the element and restriction from state.
         elementState.readLater();
         restrictionState.readLater();
+        watermarkEstimatorState.readLater();
         elementAndRestriction = KV.of(elementState.read(), restrictionState.read());
+        watermarkEstimatorStateT = watermarkEstimatorState.read();
       }
 
+      final WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator =
+          invoker.invokeNewWatermarkEstimator(
+              new BaseArgumentProvider<InputT, OutputT>() {
+                @Override
+                public InputT element(DoFn<InputT, OutputT> doFn) {
+                  return elementAndRestriction.getKey().getValue();
+                }
+
+                @Override
+                public Object restriction() {
+                  return elementAndRestriction.getValue();
+                }
+
+                @Override
+                public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                  return c.timestamp();
+                }
+
+                @Override
+                public PipelineOptions pipelineOptions() {
+                  return c.getPipelineOptions();
+                }
+
+                @Override
+                public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                  return c.pane();
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return boundedWindow;
+                }
+
+                @Override
+                public Object watermarkEstimatorState() {
+                  return watermarkEstimatorStateT;
+                }
+
+                @Override
+                public String getErrorContext() {
+                  return ProcessFn.class.getSimpleName() + ".invokeNewWatermarkEstimator";
+                }
+              });
+
       final RestrictionTracker<RestrictionT, PositionT> tracker =
           invoker.invokeNewTracker(
               new BaseArgumentProvider<InputT, OutputT>() {
                 @Override
+                public InputT element(DoFn<InputT, OutputT> doFn) {
+                  return elementAndRestriction.getKey().getValue();
+                }
+
+                @Override
                 public Object restriction() {
                   return elementAndRestriction.getValue();
                 }
 
                 @Override
+                public Instant timestamp(DoFn<InputT, OutputT> doFn) {
+                  return c.timestamp();
+                }
+
+                @Override
+                public PipelineOptions pipelineOptions() {
+                  return c.getPipelineOptions();
+                }
+
+                @Override
+                public PaneInfo paneInfo(DoFn<InputT, OutputT> doFn) {
+                  return c.pane();
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return boundedWindow;
+                }
+
+                @Override
                 public String getErrorContext() {
                   return ProcessFn.class.getSimpleName() + ".invokeNewTracker";
                 }
               });
-      SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT>.Result result =
-          processElementInvoker.invokeProcessElement(
-              invoker, elementAndRestriction.getKey(), tracker);
+      SplittableProcessElementInvoker<
+                  InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+              .Result
+          result =
+              processElementInvoker.invokeProcessElement(
+                  invoker, elementAndRestriction.getKey(), tracker, watermarkEstimator);
 
       // Save state for resuming.
       if (result.getResidualRestriction() == null) {
         // All work for this element/restriction is completed. Clear state and release hold.
         elementState.clear();
         restrictionState.clear();
+        watermarkEstimatorState.clear();
         holdState.clear();
         return;
       }
+
       restrictionState.write(result.getResidualRestriction());
+      watermarkEstimatorState.write(result.getFutureWatermarkEstimatorState());
       @Nullable Instant futureOutputWatermark = result.getFutureOutputWatermark();
       if (futureOutputWatermark == null) {
         futureOutputWatermark = elementAndRestriction.getKey().getTimestamp();
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java
index 68f0c1f..111493c 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java
@@ -24,6 +24,7 @@ import javax.annotation.Nullable;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.joda.time.Instant;
 
@@ -31,12 +32,14 @@ import org.joda.time.Instant;
  * A runner-specific hook for invoking a {@link DoFn.ProcessElement} method for a splittable {@link
  * DoFn}, in particular, allowing the runner to access the {@link RestrictionTracker}.
  */
-public abstract class SplittableProcessElementInvoker<InputT, OutputT, RestrictionT, PositionT> {
+public abstract class SplittableProcessElementInvoker<
+    InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> {
   /** Specifies how to resume a splittable {@link DoFn.ProcessElement} call. */
   public class Result {
     @Nullable private final RestrictionT residualRestriction;
     private final DoFn.ProcessContinuation continuation;
     private final @Nullable Instant futureOutputWatermark;
+    private final @Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState;
 
     @SuppressFBWarnings(
         value = "NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE",
@@ -44,7 +47,8 @@ public abstract class SplittableProcessElementInvoker<InputT, OutputT, Restricti
     public Result(
         @Nullable RestrictionT residualRestriction,
         DoFn.ProcessContinuation continuation,
-        @Nullable Instant futureOutputWatermark) {
+        @Nullable Instant futureOutputWatermark,
+        @Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState) {
       checkArgument(continuation != null, "continuation must not be null");
       this.continuation = continuation;
       if (continuation.shouldResume()) {
@@ -54,6 +58,7 @@ public abstract class SplittableProcessElementInvoker<InputT, OutputT, Restricti
       }
       this.residualRestriction = residualRestriction;
       this.futureOutputWatermark = futureOutputWatermark;
+      this.futureWatermarkEstimatorState = futureWatermarkEstimatorState;
     }
 
     /**
@@ -73,6 +78,10 @@ public abstract class SplittableProcessElementInvoker<InputT, OutputT, Restricti
     public @Nullable Instant getFutureOutputWatermark() {
       return futureOutputWatermark;
     }
+
+    public @Nullable WatermarkEstimatorStateT getFutureWatermarkEstimatorState() {
+      return futureWatermarkEstimatorState;
+    }
   }
 
   /**
@@ -85,5 +94,6 @@ public abstract class SplittableProcessElementInvoker<InputT, OutputT, Restricti
   public abstract Result invokeProcessElement(
       DoFnInvoker<InputT, OutputT> invoker,
       WindowedValue<InputT> element,
-      RestrictionTracker<RestrictionT, PositionT> tracker);
+      RestrictionTracker<RestrictionT, PositionT> tracker,
+      WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator);
 }
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
index 81b8131..bf005bc 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
@@ -36,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -90,7 +91,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
     }
   }
 
-  private SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result runTest(
+  private SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result runTest(
       int totalNumOutputs,
       Duration sleepBeforeFirstClaim,
       int numOutputsPerProcessCall,
@@ -100,9 +101,9 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
     return runTest(fn, initialRestriction);
   }
 
-  private SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result runTest(
+  private SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result runTest(
       DoFn<Void, String> fn, OffsetRange initialRestriction) {
-    SplittableProcessElementInvoker<Void, String, OffsetRange, Long> invoker =
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void> invoker =
         new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
             fn,
             PipelineOptionsFactory.create(),
@@ -130,12 +131,23 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
     return invoker.invokeProcessElement(
         DoFnInvokers.invokerFor(fn),
         WindowedValue.of(null, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING),
-        new OffsetRangeTracker(initialRestriction));
+        new OffsetRangeTracker(initialRestriction),
+        new WatermarkEstimator<Void>() {
+          @Override
+          public Instant currentWatermark() {
+            return GlobalWindow.TIMESTAMP_MIN_VALUE;
+          }
+
+          @Override
+          public Void getState() {
+            return null;
+          }
+        });
   }
 
   @Test
   public void testInvokeProcessElementOutputBounded() throws Exception {
-    SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result res =
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result res =
         runTest(10000, Duration.ZERO, Integer.MAX_VALUE, Duration.ZERO);
     assertFalse(res.getContinuation().shouldResume());
     OffsetRange residualRange = res.getResidualRestriction();
@@ -146,7 +158,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
 
   @Test
   public void testInvokeProcessElementTimeBounded() throws Exception {
-    SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result res =
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result res =
         runTest(10000, Duration.ZERO, Integer.MAX_VALUE, Duration.millis(100));
     assertFalse(res.getContinuation().shouldResume());
     OffsetRange residualRange = res.getResidualRestriction();
@@ -159,7 +171,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
 
   @Test
   public void testInvokeProcessElementTimeBoundedWithStartupDelay() throws Exception {
-    SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result res =
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result res =
         runTest(10000, Duration.standardSeconds(3), Integer.MAX_VALUE, Duration.millis(100));
     assertFalse(res.getContinuation().shouldResume());
     OffsetRange residualRange = res.getResidualRestriction();
@@ -171,7 +183,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
 
   @Test
   public void testInvokeProcessElementVoluntaryReturnStop() throws Exception {
-    SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result res =
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result res =
         runTest(5, Duration.ZERO, Integer.MAX_VALUE, Duration.millis(100));
     assertFalse(res.getContinuation().shouldResume());
     assertNull(res.getResidualRestriction());
@@ -179,7 +191,7 @@ public class OutputAndTimeBoundedSplittableProcessElementInvokerTest {
 
   @Test
   public void testInvokeProcessElementVoluntaryReturnResume() throws Exception {
-    SplittableProcessElementInvoker<Void, String, OffsetRange, Long>.Result res =
+    SplittableProcessElementInvoker<Void, String, OffsetRange, Long, Void>.Result res =
         runTest(10, Duration.ZERO, 5, Duration.millis(100));
     assertTrue(res.getContinuation().shouldResume());
     assertEquals(new OffsetRange(5, 10), res.getResidualRestriction());
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
index 7c43311..78e54ed 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoProcessFnTest.java
@@ -44,15 +44,18 @@ import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.range.OffsetRange;
 import org.apache.beam.sdk.testing.ResetDateTimeProvider;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnTester;
 import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -120,7 +123,8 @@ public class SplittableParDoProcessFnTest {
    * A helper for testing {@link ProcessFn} on 1 element (but possibly over multiple {@link
    * DoFn.ProcessElement} calls).
    */
-  private static class ProcessFnTester<InputT, OutputT, RestrictionT, PositionT>
+  private static class ProcessFnTester<
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       implements AutoCloseable {
     private final DoFnTester<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> tester;
     private Instant currentProcessingTime;
@@ -133,6 +137,7 @@ public class SplittableParDoProcessFnTest {
         final DoFn<InputT, OutputT> fn,
         Coder<InputT> inputCoder,
         Coder<RestrictionT> restrictionCoder,
+        Coder<WatermarkEstimatorStateT> watermarkEstimatorStateCoder,
         int maxOutputsPerBundle,
         Duration maxBundleDuration)
         throws Exception {
@@ -140,8 +145,14 @@ public class SplittableParDoProcessFnTest {
       // encode IntervalWindow's because that's what all tests here use.
       WindowingStrategy<InputT, BoundedWindow> windowingStrategy =
           (WindowingStrategy) WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(1)));
-      final ProcessFn<InputT, OutputT, RestrictionT, PositionT> processFn =
-          new ProcessFn<>(fn, inputCoder, restrictionCoder, windowingStrategy);
+      final ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+          processFn =
+              new ProcessFn<>(
+                  fn,
+                  inputCoder,
+                  restrictionCoder,
+                  watermarkEstimatorStateCoder,
+                  windowingStrategy);
       this.tester = DoFnTester.of(processFn);
       this.timerInternals = new InMemoryTimerInternals();
       this.stateInternals = new TestInMemoryStateInternals<>("dummy");
@@ -294,12 +305,13 @@ public class SplittableParDoProcessFnTest {
         new IntervalWindow(
             base.minus(Duration.standardMinutes(1)), base.plus(Duration.standardMinutes(1)));
 
-    ProcessFnTester<Integer, String, SomeRestriction, Void> tester =
+    ProcessFnTester<Integer, String, SomeRestriction, Void, Void> tester =
         new ProcessFnTester<>(
             base,
             fn,
             BigEndianIntegerCoder.of(),
             SerializableCoder.of(SomeRestriction.class),
+            VoidCoder.of(),
             MAX_OUTPUTS_PER_BUNDLE,
             MAX_BUNDLE_DURATION);
     tester.startElement(
@@ -319,9 +331,12 @@ public class SplittableParDoProcessFnTest {
 
   private static class WatermarkUpdateFn extends DoFn<Instant, String> {
     @ProcessElement
-    public void process(ProcessContext c, RestrictionTracker<OffsetRange, Long> tracker) {
+    public void process(
+        ProcessContext c,
+        RestrictionTracker<OffsetRange, Long> tracker,
+        ManualWatermarkEstimator<Instant> watermarkEstimator) {
       for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) {
-        c.updateWatermark(c.element().plus(Duration.standardSeconds(i)));
+        watermarkEstimator.setWatermark(c.element().plus(Duration.standardSeconds(i)));
         c.output(String.valueOf(i));
       }
     }
@@ -335,6 +350,17 @@ public class SplittableParDoProcessFnTest {
     public OffsetRangeTracker newTracker(@Restriction OffsetRange range) {
       return new OffsetRangeTracker(range);
     }
+
+    @GetInitialWatermarkEstimatorState
+    public Instant getInitialWatermarkEstimatorState() {
+      return GlobalWindow.TIMESTAMP_MIN_VALUE;
+    }
+
+    @NewWatermarkEstimator
+    public WatermarkEstimators.Manual newWatermarkEstimator(
+        @WatermarkEstimatorState Instant watermarkEstimatorState) {
+      return new WatermarkEstimators.Manual(watermarkEstimatorState);
+    }
   }
 
   @Test
@@ -342,12 +368,13 @@ public class SplittableParDoProcessFnTest {
     DoFn<Instant, String> fn = new WatermarkUpdateFn();
     Instant base = Instant.now();
 
-    ProcessFnTester<Instant, String, OffsetRange, Long> tester =
+    ProcessFnTester<Instant, String, OffsetRange, Long, Instant> tester =
         new ProcessFnTester<>(
             base,
             fn,
             InstantCoder.of(),
             SerializableCoder.of(OffsetRange.class),
+            InstantCoder.of(),
             3,
             MAX_BUNDLE_DURATION);
 
@@ -385,12 +412,13 @@ public class SplittableParDoProcessFnTest {
     DoFn<Integer, String> fn = new SelfInitiatedResumeFn();
     Instant base = Instant.now();
     dateTimeProvider.setDateTimeFixed(base.getMillis());
-    ProcessFnTester<Integer, String, SomeRestriction, Void> tester =
+    ProcessFnTester<Integer, String, SomeRestriction, Void, Void> tester =
         new ProcessFnTester<>(
             base,
             fn,
             BigEndianIntegerCoder.of(),
             SerializableCoder.of(SomeRestriction.class),
+            VoidCoder.of(),
             MAX_OUTPUTS_PER_BUNDLE,
             MAX_BUNDLE_DURATION);
 
@@ -447,12 +475,13 @@ public class SplittableParDoProcessFnTest {
     DoFn<Integer, String> fn = new CounterFn(1);
     Instant base = Instant.now();
     dateTimeProvider.setDateTimeFixed(base.getMillis());
-    ProcessFnTester<Integer, String, OffsetRange, Long> tester =
+    ProcessFnTester<Integer, String, OffsetRange, Long, Void> tester =
         new ProcessFnTester<>(
             base,
             fn,
             BigEndianIntegerCoder.of(),
             SerializableCoder.of(OffsetRange.class),
+            VoidCoder.of(),
             MAX_OUTPUTS_PER_BUNDLE,
             MAX_BUNDLE_DURATION);
 
@@ -476,12 +505,13 @@ public class SplittableParDoProcessFnTest {
     Instant base = Instant.now();
     int baseIndex = 42;
 
-    ProcessFnTester<Integer, String, OffsetRange, Long> tester =
+    ProcessFnTester<Integer, String, OffsetRange, Long, Void> tester =
         new ProcessFnTester<>(
             base,
             fn,
             BigEndianIntegerCoder.of(),
             SerializableCoder.of(OffsetRange.class),
+            VoidCoder.of(),
             max,
             MAX_BUNDLE_DURATION);
 
@@ -522,12 +552,13 @@ public class SplittableParDoProcessFnTest {
     Instant base = Instant.now();
     int baseIndex = 42;
 
-    ProcessFnTester<Integer, String, OffsetRange, Long> tester =
+    ProcessFnTester<Integer, String, OffsetRange, Long, Void> tester =
         new ProcessFnTester<>(
             base,
             fn,
             BigEndianIntegerCoder.of(),
             SerializableCoder.of(OffsetRange.class),
+            VoidCoder.of(),
             max,
             maxBundleDuration);
 
@@ -591,12 +622,13 @@ public class SplittableParDoProcessFnTest {
   @Test
   public void testInvokesLifecycleMethods() throws Exception {
     DoFn<Integer, String> fn = new LifecycleVerifyingFn();
-    try (ProcessFnTester<Integer, String, SomeRestriction, Void> tester =
+    try (ProcessFnTester<Integer, String, SomeRestriction, Void, Void> tester =
         new ProcessFnTester<>(
             Instant.now(),
             fn,
             BigEndianIntegerCoder.of(),
             SerializableCoder.of(SomeRestriction.class),
+            VoidCoder.of(),
             MAX_OUTPUTS_PER_BUNDLE,
             MAX_BUNDLE_DURATION)) {
       tester.startElement(42, new SomeRestriction());
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
index b0ff2d3..fa99584 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java
@@ -47,7 +47,8 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 
-class SplittableProcessElementsEvaluatorFactory<InputT, OutputT, RestrictionT, PositionT>
+class SplittableProcessElementsEvaluatorFactory<
+        InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
     implements TransformEvaluatorFactory {
   private final ParDoEvaluatorFactory<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT>
       delegateFactory;
@@ -70,9 +71,12 @@ class SplittableProcessElementsEvaluatorFactory<InputT, OutputT, RestrictionT, P
                 checkArgument(
                     ProcessElements.class.isInstance(application.getTransform()),
                     "No know extraction of the fn from " + application);
-                final ProcessElements<InputT, OutputT, RestrictionT, PositionT> transform =
-                    (ProcessElements<InputT, OutputT, RestrictionT, PositionT>)
-                        application.getTransform();
+                final ProcessElements<
+                        InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+                    transform =
+                        (ProcessElements<
+                                InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>)
+                            application.getTransform();
                 return DoFnLifecycleManager.of(transform.newProcessFn(transform.getFn()));
               }
             },
@@ -107,12 +111,12 @@ class SplittableProcessElementsEvaluatorFactory<InputT, OutputT, RestrictionT, P
       AppliedPTransform<
               PCollection<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>,
               PCollectionTuple,
-              ProcessElements<InputT, OutputT, RestrictionT, PositionT>>
+              ProcessElements<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>>
           application,
       CommittedBundle<InputT> inputBundle)
       throws Exception {
-    final ProcessElements<InputT, OutputT, RestrictionT, PositionT> transform =
-        application.getTransform();
+    final ProcessElements<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
+        transform = application.getTransform();
 
     final DoFnLifecycleManagerRemovingTransformEvaluator<
             KeyedWorkItem<byte[], KV<InputT, RestrictionT>>>
@@ -129,8 +133,8 @@ class SplittableProcessElementsEvaluatorFactory<InputT, OutputT, RestrictionT, P
                 Collections.emptyMap());
     final ParDoEvaluator<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>> pde =
         evaluator.getParDoEvaluator();
-    final ProcessFn<InputT, OutputT, RestrictionT, PositionT> processFn =
-        (ProcessFn<InputT, OutputT, RestrictionT, PositionT>)
+    final ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> processFn =
+        (ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>)
             ProcessFnRunner.class.cast(pde.getFnRunner()).getFn();
 
     final DirectExecutionContext.DirectStepContext stepContext = pde.getStepContext();
@@ -192,7 +196,7 @@ class SplittableProcessElementsEvaluatorFactory<InputT, OutputT, RestrictionT, P
         windowingStrategy,
         doFnSchemaInformation,
         sideInputMapping) -> {
-      ProcessFn<InputT, OutputT, RestrictionT, ?> processFn = (ProcessFn) fn;
+      ProcessFn<InputT, OutputT, RestrictionT, ?, ?> processFn = (ProcessFn) fn;
       return DoFnRunners.newProcessFnRunner(
           processFn,
           options,
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index a03cb38..a2272af 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -136,12 +136,12 @@ class TransformEvaluatorRegistry {
    * once SDF is reorganized appropriately.
    */
   private static class SplittableParDoProcessElementsTranslator
-      extends TransformPayloadTranslator.NotSerializable<ProcessElements<?, ?, ?, ?>> {
+      extends TransformPayloadTranslator.NotSerializable<ProcessElements<?, ?, ?, ?, ?>> {
 
     private SplittableParDoProcessElementsTranslator() {}
 
     @Override
-    public String getUrn(ProcessElements<?, ?, ?, ?> transform) {
+    public String getUrn(ProcessElements<?, ?, ?, ?, ?> transform) {
       return SPLITTABLE_PROCESS_URN;
     }
   }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
index 4efdc34..6c4f25b 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java
@@ -717,14 +717,15 @@ class FlinkStreamingTransformTranslators {
   }
 
   private static class SplittableProcessElementsStreamingTranslator<
-          InputT, OutputT, RestrictionT, PositionT>
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       extends FlinkStreamingPipelineTranslator.StreamTransformTranslator<
           SplittableParDoViaKeyedWorkItems.ProcessElements<
-              InputT, OutputT, RestrictionT, PositionT>> {
+              InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>> {
 
     @Override
     public void translateNode(
-        SplittableParDoViaKeyedWorkItems.ProcessElements<InputT, OutputT, RestrictionT, PositionT>
+        SplittableParDoViaKeyedWorkItems.ProcessElements<
+                InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
             transform,
         FlinkStreamingTranslationContext context) {
 
@@ -1277,12 +1278,13 @@ class FlinkStreamingTransformTranslators {
    */
   private static class SplittableParDoProcessElementsTranslator
       extends PTransformTranslation.TransformPayloadTranslator.NotSerializable<
-          SplittableParDoViaKeyedWorkItems.ProcessElements<?, ?, ?, ?>> {
+          SplittableParDoViaKeyedWorkItems.ProcessElements<?, ?, ?, ?, ?>> {
 
     private SplittableParDoProcessElementsTranslator() {}
 
     @Override
-    public String getUrn(SplittableParDoViaKeyedWorkItems.ProcessElements<?, ?, ?, ?> transform) {
+    public String getUrn(
+        SplittableParDoViaKeyedWorkItems.ProcessElements<?, ?, ?, ?, ?> transform) {
       return SPLITTABLE_PROCESS_URN;
     }
   }
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index 57b47c4..5f7ec26 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -72,6 +72,7 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -90,6 +91,7 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
@@ -1037,12 +1039,16 @@ public class DataflowPipelineTranslator {
             if (context.isFnApi()) {
               DoFnSignature signature = DoFnSignatures.signatureForDoFn(transform.getFn());
               if (signature.processElement().isSplittable()) {
-                Coder<?> restrictionCoder =
-                    DoFnInvokers.invokerFor(transform.getFn())
-                        .invokeGetRestrictionCoder(
-                            context.getInput(transform).getPipeline().getCoderRegistry());
+                DoFnInvoker<?, ?> doFnInvoker = DoFnInvokers.invokerFor(transform.getFn());
+                Coder<?> restrictionAndWatermarkStateCoder =
+                    KvCoder.of(
+                        doFnInvoker.invokeGetRestrictionCoder(
+                            context.getInput(transform).getPipeline().getCoderRegistry()),
+                        doFnInvoker.invokeGetWatermarkEstimatorStateCoder(
+                            context.getInput(transform).getPipeline().getCoderRegistry()));
                 stepContext.addInput(
-                    PropertyNames.RESTRICTION_ENCODING, translateCoder(restrictionCoder, context));
+                    PropertyNames.RESTRICTION_ENCODING,
+                    translateCoder(restrictionAndWatermarkStateCoder, context));
               }
             }
           }
@@ -1147,8 +1153,10 @@ public class DataflowPipelineTranslator {
             translateTyped(transform, context);
           }
 
-          private <InputT, OutputT, RestrictionT> void translateTyped(
-              SplittableParDo.ProcessKeyedElements<InputT, OutputT, RestrictionT> transform,
+          private <InputT, OutputT, RestrictionT, WatermarkEstimatorStateT> void translateTyped(
+              SplittableParDo.ProcessKeyedElements<
+                      InputT, OutputT, RestrictionT, WatermarkEstimatorStateT>
+                  transform,
               TranslationContext context) {
             DoFnSchemaInformation doFnSchemaInformation;
             doFnSchemaInformation =
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
index 2f99372..5e12e08 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
@@ -44,6 +44,7 @@ import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.core.construction.TransformPayloadTranslatorRegistrar;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -51,6 +52,7 @@ import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.SingleOutput;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
@@ -290,11 +292,15 @@ public class PrimitiveParDoSingleFactory<InputT, OutputT>
             @Override
             public String translateRestrictionCoderId(SdkComponents newComponents) {
               if (signature.processElement().isSplittable()) {
-                Coder<?> restrictionCoder =
-                    DoFnInvokers.invokerFor(doFn)
-                        .invokeGetRestrictionCoder(transform.getPipeline().getCoderRegistry());
+                DoFnInvoker<?, ?> doFnInvoker = DoFnInvokers.invokerFor(doFn);
+                final Coder<?> restrictionAndWatermarkStateCoder =
+                    KvCoder.of(
+                        doFnInvoker.invokeGetRestrictionCoder(
+                            transform.getPipeline().getCoderRegistry()),
+                        doFnInvoker.invokeGetWatermarkEstimatorStateCoder(
+                            transform.getPipeline().getCoderRegistry()));
                 try {
-                  return newComponents.registerCoder(restrictionCoder);
+                  return newComponents.registerCoder(restrictionAndWatermarkStateCoder);
                 } catch (IOException e) {
                   throw new IllegalStateException(
                       String.format(
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index f8769f4..d48f060 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -65,6 +65,7 @@ import org.apache.beam.runners.dataflow.util.PropertyNames;
 import org.apache.beam.runners.dataflow.util.Structs;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
@@ -789,8 +790,10 @@ public class DataflowPipelineTranslatorTest implements Serializable {
     assertThat(
         ParDoTranslation.doFnWithExecutionInformationFromProto(payload.getDoFn()).getDoFn(),
         instanceOf(TestSplittableFn.class));
-    assertThat(
-        components.getCoder(payload.getRestrictionCoderId()), instanceOf(SerializableCoder.class));
+    Coder expectedRestrictionAndStateCoder =
+        KvCoder.of(SerializableCoder.of(OffsetRange.class), VoidCoder.of());
+    assertEquals(
+        expectedRestrictionAndStateCoder, components.getCoder(payload.getRestrictionCoderId()));
 
     // In the Fn API case, we still translate the restriction coder into the RESTRICTION_CODER
     // property as a CloudObject, and it gets passed through the Dataflow backend, but in the end
@@ -800,7 +803,7 @@ public class DataflowPipelineTranslatorTest implements Serializable {
             (CloudObject)
                 Structs.getObject(
                     splittableParDo.getProperties(), PropertyNames.RESTRICTION_ENCODING));
-    assertEquals(SerializableCoder.of(OffsetRange.class), restrictionCoder);
+    assertEquals(expectedRestrictionAndStateCoder, restrictionCoder);
   }
 
   @Test
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java
index b840c68..dc1e73e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java
@@ -22,6 +22,7 @@ import static org.apache.beam.runners.dataflow.util.CloudObjects.coderFromCloudO
 import static org.apache.beam.runners.dataflow.util.Structs.getBytes;
 import static org.apache.beam.runners.dataflow.util.Structs.getObject;
 import static org.apache.beam.sdk.util.SerializableUtils.deserializeFromByteArray;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 
 import java.util.Collection;
 import java.util.List;
@@ -47,7 +48,6 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.DoFnInfo;
@@ -75,15 +75,22 @@ class SplittableProcessFnFactory {
           (DoFnInfo<?, ?>)
               deserializeFromByteArray(
                   getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), "Serialized DoFnInfo");
-      Coder restrictionCoder =
+      Coder restrictionAndStateCoder =
           coderFromCloudObject(
               fromSpec(getObject(cloudUserFn, WorkerPropertyNames.RESTRICTION_CODER)));
+      checkState(
+          restrictionAndStateCoder instanceof KvCoder,
+          "Expected pair coder with restriction as key coder and watermark estimator state as value coder, but received %s.",
+          restrictionAndStateCoder);
+      Coder restrictionCoder = ((KvCoder) restrictionAndStateCoder).getKeyCoder();
+      Coder watermarkEstimatorStateCoder = ((KvCoder) restrictionAndStateCoder).getValueCoder();
 
       ProcessFn processFn =
           new ProcessFn(
               doFnInfo.getDoFn(),
               doFnInfo.getInputCoder(),
               restrictionCoder,
+              watermarkEstimatorStateCoder,
               doFnInfo.getWindowingStrategy());
 
       return DoFnInfo.forFn(
@@ -102,11 +109,7 @@ class SplittableProcessFnFactory {
   }
 
   private static class SplittableDoFnRunnerFactory<
-          InputT,
-          OutputT,
-          RestrictionT,
-          PositionT,
-          TrackerT extends RestrictionTracker<RestrictionT, PositionT>>
+          InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>
       implements DoFnRunnerFactory<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> {
     @Override
     public DoFnRunner<KeyedWorkItem<byte[], KV<InputT, RestrictionT>>, OutputT> createRunner(
@@ -124,8 +127,8 @@ class SplittableProcessFnFactory {
         OutputManager outputManager,
         DoFnSchemaInformation doFnSchemaInformation,
         Map<String, PCollectionView<?>> sideInputMapping) {
-      ProcessFn<InputT, OutputT, RestrictionT, TrackerT> processFn =
-          (ProcessFn<InputT, OutputT, RestrictionT, TrackerT>) fn;
+      ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT> processFn =
+          (ProcessFn<InputT, OutputT, RestrictionT, PositionT, WatermarkEstimatorStateT>) fn;
       processFn.setStateInternalsFactory(key -> (StateInternals) stepContext.stateInternals());
       processFn.setTimerInternalsFactory(key -> stepContext.timerInternals());
       processFn.setProcessElementInvoker(
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
index 6c51686..835dcbf 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -41,8 +41,10 @@ import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.util.NameUtils;
 import org.apache.beam.sdk.util.SerializableUtils;
 import org.apache.beam.sdk.values.KV;
@@ -518,6 +520,7 @@ public class Read {
         RestrictionTracker<
                 KV<UnboundedSource<OutputT, CheckpointT>, CheckpointT>, UnboundedSourceValue[]>
             tracker,
+        ManualWatermarkEstimator<Instant> watermarkEstimator,
         OutputReceiver<ValueWithRecordId<OutputT>> receiver,
         BundleFinalizer bundleFinalizer)
         throws IOException {
@@ -525,7 +528,7 @@ public class Read {
       while (tracker.tryClaim(out)) {
         receiver.outputWithTimestamp(
             new ValueWithRecordId<>(out[0].getValue(), out[0].getId()), out[0].getTimestamp());
-        context.updateWatermark(out[0].getWatermark());
+        watermarkEstimator.setWatermark(out[0].getWatermark());
       }
 
       // Add the checkpoint mark to be finalized if the checkpoint mark isn't trivial.
@@ -547,6 +550,17 @@ public class Read {
       return ProcessContinuation.resume();
     }
 
+    @GetInitialWatermarkEstimatorState
+    public Instant getInitialWatermarkEstimatorState(@Timestamp Instant elementTimestamp) {
+      return elementTimestamp; // The watermark starts at the input element's timestamp.
+    }
+
+    @NewWatermarkEstimator
+    public WatermarkEstimators.Manual newWatermarkEstimator(
+        @WatermarkEstimatorState Instant initialWatermark) {
+      return new WatermarkEstimators.Manual(initialWatermark); // Advanced via setWatermark.
+    }
+
     @GetRestrictionCoder
     public Coder<KV<UnboundedSource<OutputT, CheckpointT>, CheckpointT>> restrictionCoder() {
       return KvCoder.of(
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index 1790ecc..7e7f3da 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -265,18 +265,6 @@ public abstract class DoFn<InputT, OutputT> implements Serializable, HasDisplayD
      * data has been explicitly requested. See {@link Window} for more information.
      */
     public abstract PaneInfo pane();
-
-    /**
-     * Gives the runner a (best-effort) lower bound about the timestamps of future output associated
-     * with the current element.
-     *
-     * <p>If the {@link DoFn} has multiple outputs, the watermark applies to all of them.
-     *
-     * <p>Only splittable {@link DoFn DoFns} are allowed to call this method. It is safe to call
-     * this method from a different thread than the one running {@link ProcessElement}, but all
-     * calls must finish before {@link ProcessElement} returns.
-     */
-    public abstract void updateWatermark(Instant watermark);
   }
 
   /** Information accessible when running a {@link DoFn.OnTimer} method. */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
index 6c445ec..4d9f81f 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java
@@ -555,11 +555,6 @@ public class DoFnTester<InputT, OutputT> implements AutoCloseable {
     }
 
     @Override
-    public void updateWatermark(Instant watermark) {
-      throw new UnsupportedOperationException();
-    }
-
-    @Override
     public PipelineOptions getPipelineOptions() {
       return options;
     }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
index d59e5c4..6f4420d 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
@@ -56,9 +56,11 @@ import org.apache.beam.sdk.transforms.Contextful.Fn;
 import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement;
 import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement;
 import org.apache.beam.sdk.transforms.Watch.Growth.PollResult;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.VarInt;
 import org.apache.beam.sdk.values.KV;
@@ -754,11 +756,21 @@ public class Watch {
       while (tracker.tryClaim(position)) {
         TimestampedValue<OutputT> value = c.element().getValue().get((int) position);
         c.outputWithTimestamp(KV.of(c.element().getKey(), value.getValue()), value.getTimestamp());
-        c.updateWatermark(value.getTimestamp());
         position += 1L;
       }
     }
 
+    @GetInitialWatermarkEstimatorState
+    public Instant getInitialWatermarkEstimatorState(@Timestamp Instant elementTimestamp) {
+      return elementTimestamp; // The watermark starts at the input element's timestamp.
+    }
+
+    @NewWatermarkEstimator
+    public WatermarkEstimators.MonotonicallyIncreasing newWatermarkEstimator(
+        @WatermarkEstimatorState Instant initialWatermark) {
+      return new WatermarkEstimators.MonotonicallyIncreasing(initialWatermark);
+    }
+
     @GetInitialRestriction
     public OffsetRange getInitialRestriction(
         @Element KV<InputT, List<TimestampedValue<OutputT>>> element) {
@@ -806,10 +818,22 @@ public class Watch {
           };
     }
 
+    @GetInitialWatermarkEstimatorState
+    public Instant getInitialWatermarkEstimatorState(@Timestamp Instant elementTimestamp) {
+      return elementTimestamp; // The watermark starts at the input element's timestamp.
+    }
+
+    @NewWatermarkEstimator
+    public WatermarkEstimators.Manual newWatermarkEstimator(
+        @WatermarkEstimatorState Instant initialWatermark) {
+      return new WatermarkEstimators.Manual(initialWatermark); // Advanced via setWatermark.
+    }
+
     @ProcessElement
     public ProcessContinuation process(
         ProcessContext c,
-        RestrictionTracker<GrowthState, KV<Growth.PollResult<OutputT>, TerminationStateT>> tracker)
+        RestrictionTracker<GrowthState, KV<Growth.PollResult<OutputT>, TerminationStateT>> tracker,
+        ManualWatermarkEstimator<Instant> watermarkEstimator)
         throws Exception {
 
       GrowthState currentRestriction = tracker.currentRestriction();
@@ -824,9 +848,7 @@ public class Watch {
                 priorPoll.getOutputs().size());
             c.output(KV.of(c.element(), priorPoll.getOutputs()));
           }
-          if (priorPoll.getWatermark() != null) {
-            c.updateWatermark(priorPoll.getWatermark());
-          }
+          watermarkEstimator.setWatermark(priorPoll.getWatermark());
         }
         return stop();
       }
@@ -868,7 +890,7 @@ public class Watch {
       }
 
       if (newResults.getWatermark() != null) {
-        c.updateWatermark(newResults.getWatermark());
+        watermarkEstimator.setWatermark(newResults.getWatermark());
       }
 
       Instant currentTime = Instant.now();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index 014e6dd..f4a5dca 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -341,13 +341,31 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
     }
   }
 
+  /**
+   * Default implementation for {@link DoFn.GetInitialWatermarkEstimatorState}, for delegation by
+   * bytebuddy.
+   */
+  public static class DefaultGetInitialWatermarkEstimatorState {
+    /** The default watermark estimator state is {@code null}. */
+    @SuppressWarnings("unused")
+    public static <InputT, OutputT> Object invokeGetInitialWatermarkEstimatorState(
+        DoFnInvoker.ArgumentProvider<InputT, OutputT> argumentProvider) {
+      // Returns a state value (not a WatermarkEstimator), mirroring the delegated invoker method.
+      return null;
+    }
+  }
+
   /** Default implementation of {@link DoFn.NewWatermarkEstimator}, for delegation by bytebuddy. */
   public static class DefaultNewWatermarkEstimator {
 
-    /** Returns a watermark estimator that always reports the minimum watermark. */
+    /**
+     * Constructs a new watermark estimator from the state if it implements {@link
+     * HasDefaultWatermarkEstimator}, otherwise returns a watermark estimator that always reports
+     * the minimum watermark.
+     */
     @SuppressWarnings("unused")
     public static <InputT, OutputT, WatermarkEstimatorStateT>
-        WatermarkEstimator<WatermarkEstimatorStateT> invokeNewTracker(
+        WatermarkEstimator<WatermarkEstimatorStateT> invokeNewWatermarkEstimator(
             DoFnInvoker.ArgumentProvider<InputT, OutputT> argumentProvider) {
       if (argumentProvider.watermarkEstimatorState() instanceof HasDefaultWatermarkEstimator) {
         return ((HasDefaultWatermarkEstimator) argumentProvider.watermarkEstimatorState())
@@ -451,7 +469,7 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
             .intercept(getWatermarkEstimatorStateCoderDelegation(clazzDescription, signature))
             .method(ElementMatchers.named("invokeGetInitialWatermarkEstimatorState"))
             .intercept(
-                delegateMethodWithExtraParametersOrNoop(
+                getInitialWatermarkEstimatorStateDelegation(
                     clazzDescription, signature.getInitialWatermarkEstimatorState()))
             .method(ElementMatchers.named("invokeNewWatermarkEstimator"))
             .intercept(
@@ -512,6 +530,16 @@ class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
     }
   }
 
+  /** Picks the DoFn-declared method if present, otherwise the {@code null}-returning default. */
+  private static Implementation getInitialWatermarkEstimatorStateDelegation(
+      TypeDescription doFnType,
+      @Nullable DoFnSignature.GetInitialWatermarkEstimatorStateMethod signature) {
+    if (signature != null) {
+      return new DoFnMethodWithExtraParametersDelegation(doFnType, signature);
+    }
+    return MethodDelegation.to(DefaultGetInitialWatermarkEstimatorState.class);
+  }
+
   private static Implementation newWatermarkEstimatorDelegation(
       TypeDescription doFnType, @Nullable DoFnSignature.NewWatermarkEstimatorMethod signature) {
     if (signature == null) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
index 412d5f7..3053775 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
@@ -112,13 +112,16 @@ public interface DoFnInvoker<InputT, OutputT> {
           ArgumentProvider<InputT, OutputT> arguments);
 
   /** Invoke the {@link DoFn.GetInitialWatermarkEstimatorState} method on the bound {@link DoFn}. */
-  Object invokeGetInitialWatermarkEstimatorState(ArgumentProvider<InputT, OutputT> arguments);
+  @SuppressWarnings("TypeParameterUnusedInFormals")
+  <WatermarkEstimatorStateT> WatermarkEstimatorStateT invokeGetInitialWatermarkEstimatorState(
+      ArgumentProvider<InputT, OutputT> arguments);
 
   /**
    * Invoke the {@link DoFn.GetWatermarkEstimatorStateCoder} method on the bound {@link DoFn}.
    * Called only during pipeline construction time.
    */
-  Coder<?> invokeGetWatermarkEstimatorStateCoder(CoderRegistry coderRegistry);
+  <WatermarkEstimatorStateT> Coder<WatermarkEstimatorStateT> invokeGetWatermarkEstimatorStateCoder(
+      CoderRegistry coderRegistry);
 
   /** Get the bound {@link DoFn}. */
   DoFn<InputT, OutputT> getFn();
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/ManualWatermarkEstimator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/ManualWatermarkEstimator.java
index a6e2797..0121af0 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/ManualWatermarkEstimator.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/ManualWatermarkEstimator.java
@@ -24,7 +24,8 @@ import org.joda.time.Instant;
 
 /**
  * A {@link WatermarkEstimator} which is controlled manually from within a {@link DoFn}. The {@link
- * DoFn} must invoke {@link #setWatermark} to advance the watermark.
+ * DoFn} must invoke {@link #setWatermark} to advance the watermark. See {@link
+ * WatermarkEstimators.Manual} for a concrete implementation.
  */
 @Experimental(Kind.SPLITTABLE_DO_FN)
 public interface ManualWatermarkEstimator<WatermarkEstimatorStateT>
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimator.java
index 343fea6..92d9248 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimator.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimator.java
@@ -24,7 +24,7 @@ import org.joda.time.Instant;
 
 /**
  * A {@link WatermarkEstimator} which is used for estimating output watermarks of a splittable
- * {@link DoFn}.
+ * {@link DoFn}. See {@link WatermarkEstimators} for commonly used watermark estimators.
  */
 @Experimental(Kind.SPLITTABLE_DO_FN)
 public interface WatermarkEstimator<WatermarkEstimatorStateT> {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimators.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimators.java
index 40992f1..bf31abe 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimators.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimators.java
@@ -34,6 +34,7 @@ public class WatermarkEstimators {
   /** Concrete implementation of a {@link ManualWatermarkEstimator}. */
   public static class Manual implements ManualWatermarkEstimator<Instant> {
     private Instant watermark;
+    private Instant lastReportedWatermark;
 
     public Manual(Instant watermark) {
       this.watermark = checkNotNull(watermark, "watermark must not be null.");
@@ -48,25 +49,17 @@ public class WatermarkEstimators {
 
     @Override
     public void setWatermark(Instant watermark) {
-      if (watermark.isBefore(GlobalWindow.TIMESTAMP_MIN_VALUE)
-          || watermark.isAfter(GlobalWindow.TIMESTAMP_MAX_VALUE)) {
-        throw new IllegalArgumentException(
-            String.format(
-                "Provided watermark %s must be within bounds [%s, %s].",
-                watermark, GlobalWindow.TIMESTAMP_MIN_VALUE, GlobalWindow.TIMESTAMP_MAX_VALUE));
-      }
-      if (watermark.isBefore(this.watermark)) {
-        throw new IllegalArgumentException(
-            String.format(
-                "Watermark must be monotonically increasing. Provided watermark %s is less then "
-                    + "current watermark %s.",
-                watermark, this.watermark));
-      }
-      this.watermark = watermark;
+      this.lastReportedWatermark = watermark;
     }
 
     @Override
     public Instant currentWatermark() {
+      // Beyond bounds error checking isn't important since the runner is expected to perform
+      // watermark bounds checking.
+      if (lastReportedWatermark != null && lastReportedWatermark.isAfter(watermark)) {
+        watermark = lastReportedWatermark;
+      }
+
       return watermark;
     }
 
@@ -76,7 +69,14 @@ public class WatermarkEstimators {
     }
   }
 
-  /** A watermark estimator that tracks wall time. */
+  /**
+   * A watermark estimator that tracks wall time.
+   *
+   * <p>Note that this watermark estimator expects wall times of all machines performing the
+   * processing to be close to each other. Any machine with a wall clock that is far in the past may
+   * cause the pipeline to perform poorly while a watermark far in the future may cause records to
+   * be marked as late.
+   */
   public static class WallTime implements WatermarkEstimator<Instant> {
     private Instant watermark;
 
@@ -93,6 +93,8 @@ public class WatermarkEstimators {
 
     @Override
     public Instant currentWatermark() {
+      // Beyond bounds error checking isn't important since the runner is expected to perform
+      // watermark bounds checking.
       Instant now = Instant.now();
       this.watermark = now.isAfter(watermark) ? now : watermark;
       return watermark;
@@ -105,15 +107,17 @@ public class WatermarkEstimators {
   }
 
   /**
-   * A watermark estimator that observes and timestamps of records output from a DoFn reporting the
+   * A watermark estimator that observes timestamps of records output from a DoFn reporting the
    * timestamp of the last element seen as the current watermark.
    *
-   * <p>Note that this watermark estimator requires output timestamps in monotonically increasing
-   * order.
+   * <p>Note that this watermark estimator expects output timestamps in monotonically increasing
+   * order. If they are not, then the watermark will advance based upon the last observed timestamp
+   * as long as it is greater than any previously reported watermark.
    */
   public static class MonotonicallyIncreasing
       implements TimestampObservingWatermarkEstimator<Instant> {
     private Instant watermark;
+    private Instant lastObservedTimestamp;
 
     public MonotonicallyIncreasing(Instant watermark) {
       this.watermark = checkNotNull(watermark, "timestamp must not be null.");
@@ -128,20 +132,16 @@ public class WatermarkEstimators {
 
     @Override
     public void observeTimestamp(Instant timestamp) {
-      // Beyond bounds error checking isn't important since the system is expected to perform output
-      // timestamp bounds checking already.
-      if (timestamp.isBefore(this.watermark)) {
-        throw new IllegalArgumentException(
-            String.format(
-                "Timestamp must be monotonically increasing. Provided timestamp %s is less then "
-                    + "previously provided timestamp %s.",
-                timestamp, this.watermark));
-      }
-      this.watermark = timestamp;
+      this.lastObservedTimestamp = timestamp;
     }
 
     @Override
     public Instant currentWatermark() {
+      // Beyond bounds error checking isn't important since the runner is expected to perform
+      // watermark bounds checking.
+      if (lastObservedTimestamp != null && lastObservedTimestamp.isAfter(watermark)) {
+        watermark = lastObservedTimestamp;
+      }
       return watermark;
     }
 
@@ -150,4 +150,7 @@ public class WatermarkEstimators {
       return watermark;
     }
   }
+
+  // prevent instantiation
+  private WatermarkEstimators() {}
 }
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
index 9dcc86a..dc97ec8 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
@@ -23,6 +23,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
@@ -46,6 +47,7 @@ import org.apache.beam.sdk.coders.CoderProviders;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
 import org.apache.beam.sdk.state.TimeDomain;
@@ -662,7 +664,7 @@ public class DoFnInvokersTest {
   }
 
   @Test
-  public void testSplittableDoFnDefaultMethods() throws Exception {
+  public void testSplittableDoFnWithHasDefaultMethods() throws Exception {
     class MockFn extends DoFn<String, String> {
       @ProcessElement
       public void processElement(
@@ -753,6 +755,30 @@ public class DoFnInvokersTest {
         instanceOf(DefaultWatermarkEstimator.class));
   }
 
+  @Test
+  public void testDefaultWatermarkEstimatorStateAndCoder() throws Exception {
+    // A splittable DoFn that declares neither a watermark estimator state nor its coder.
+    class MockFn extends DoFn<String, String> {
+      @ProcessElement
+      public void processElement(
+          ProcessContext c, RestrictionTracker<RestrictionWithDefaultTracker, Void> tracker) {}
+
+      @GetInitialRestriction
+      public RestrictionWithDefaultTracker getInitialRestriction(@Element String element) {
+        return null;
+      }
+    }
+
+    DoFnInvoker<String, String> invoker = DoFnInvokers.invokerFor(mock(MockFn.class));
+    CoderRegistry registry = CoderRegistry.createDefault();
+    registry.registerCoderProvider(
+        CoderProviders.fromStaticMethods(
+            RestrictionWithDefaultTracker.class, CoderForDefaultTracker.class));
+    // Defaults: a VoidCoder for the state and a null initial state.
+    assertEquals(VoidCoder.of(), invoker.invokeGetWatermarkEstimatorStateCoder(registry));
+    assertNull(invoker.invokeGetInitialWatermarkEstimatorState(new FakeArgumentProvider<>()));
+  }
+
   // ---------------------------------------------------------------------------------------
   // Tests for ability to invoke @OnTimer for private, inner and anonymous classes.
   // ---------------------------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimatorsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimatorsTest.java
index 3fd0dec..a30302d 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimatorsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/WatermarkEstimatorsTest.java
@@ -18,7 +18,6 @@
 package org.apache.beam.sdk.transforms.splittabledofn;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThrows;
 
 import org.apache.beam.sdk.testing.ResetDateTimeProvider;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -41,22 +40,19 @@ public class WatermarkEstimatorsTest {
     assertEquals(GlobalWindow.TIMESTAMP_MIN_VALUE, watermarkEstimator.currentWatermark());
     watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE);
     watermarkEstimator.setWatermark(
-        watermarkEstimator.currentWatermark().plus(Duration.standardHours(1)));
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(2)));
     assertEquals(
-        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(1)),
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(2)),
         watermarkEstimator.currentWatermark());
-    assertThrows(
-        "must be within bounds",
-        IllegalArgumentException.class,
-        () -> watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.minus(1)));
-    assertThrows(
-        "must be within bounds",
-        IllegalArgumentException.class,
-        () -> watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MAX_VALUE.plus(1)));
-    assertThrows(
-        "monotonically increasing",
-        IllegalArgumentException.class,
-        () -> watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE));
+
+    // Make sure that even if the watermark goes backwards we report the "greatest" value we have
+    // reported so far.
+    watermarkEstimator.setWatermark(
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(1)));
+    assertEquals(
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(2)),
+        watermarkEstimator.currentWatermark());
+
     watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MAX_VALUE);
     assertEquals(GlobalWindow.TIMESTAMP_MAX_VALUE, watermarkEstimator.currentWatermark());
   }
@@ -88,14 +84,19 @@ public class WatermarkEstimatorsTest {
     assertEquals(GlobalWindow.TIMESTAMP_MIN_VALUE, watermarkEstimator.currentWatermark());
     watermarkEstimator.observeTimestamp(GlobalWindow.TIMESTAMP_MIN_VALUE);
     watermarkEstimator.observeTimestamp(
-        watermarkEstimator.currentWatermark().plus(Duration.standardHours(1)));
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(2)));
     assertEquals(
-        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(1)),
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(2)),
         watermarkEstimator.currentWatermark());
-    assertThrows(
-        "monotonically increasing",
-        IllegalArgumentException.class,
-        () -> watermarkEstimator.observeTimestamp(GlobalWindow.TIMESTAMP_MIN_VALUE));
+
+    // Make sure that even if the watermark goes backwards we report the "greatest" value we have
+    // reported so far.
+    watermarkEstimator.observeTimestamp(
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(1)));
+    assertEquals(
+        GlobalWindow.TIMESTAMP_MIN_VALUE.plus(Duration.standardHours(2)),
+        watermarkEstimator.currentWatermark());
+
     watermarkEstimator.observeTimestamp(GlobalWindow.TIMESTAMP_MAX_VALUE);
     assertEquals(GlobalWindow.TIMESTAMP_MAX_VALUE, watermarkEstimator.currentWatermark());
   }
diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/splittabledofn/WatermarkEstimators.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/splittabledofn/WatermarkEstimators.java
new file mode 100644
index 0000000..fc91d87
--- /dev/null
+++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/splittabledofn/WatermarkEstimators.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.fn.splittabledofn;
+
+import javax.annotation.concurrent.ThreadSafe;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.TimestampObservingWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Instant;
+
+/** Support utilities for interacting with {@link WatermarkEstimator}s. */
+public class WatermarkEstimators {
+  /** Interface which allows for accessing the current watermark and watermark estimator state. */
+  public interface WatermarkAndStateObserver<WatermarkEstimatorStateT>
+      extends WatermarkEstimator<WatermarkEstimatorStateT> {
+    KV<Instant, WatermarkEstimatorStateT> getWatermarkAndState();
+  }
+
+  /**
+   * Returns a thread safe {@link WatermarkEstimator} which allows getting a snapshot of the current
+   * watermark and watermark estimator state.
+   */
+  public static <WatermarkEstimatorStateT>
+      WatermarkAndStateObserver<WatermarkEstimatorStateT> threadSafe(
+          WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator) {
+    if (watermarkEstimator instanceof TimestampObservingWatermarkEstimator) {
+      return new ThreadSafeTimestampObservingWatermarkEstimator<>(watermarkEstimator);
+    } else if (watermarkEstimator instanceof ManualWatermarkEstimator) {
+      return new ThreadSafeManualWatermarkEstimator<>(watermarkEstimator);
+    }
+    return new ThreadSafeWatermarkEstimator<>(watermarkEstimator);
+  }
+
+  /**
+   * Thread safe wrapper for {@link WatermarkEstimator}s that allows one to get a snapshot of the
+   * current watermark and the watermark estimator state.
+   */
+  @ThreadSafe
+  private static class ThreadSafeWatermarkEstimator<WatermarkEstimatorStateT>
+      implements WatermarkAndStateObserver<WatermarkEstimatorStateT> {
+    protected final WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator;
+
+    ThreadSafeWatermarkEstimator(WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator) {
+      this.watermarkEstimator = watermarkEstimator;
+    }
+
+    @Override
+    public synchronized Instant currentWatermark() {
+      return watermarkEstimator.currentWatermark();
+    }
+
+    @Override
+    public synchronized WatermarkEstimatorStateT getState() {
+      return watermarkEstimator.getState();
+    }
+
+    @Override
+    public synchronized KV<Instant, WatermarkEstimatorStateT> getWatermarkAndState() {
+      // The order of these calls is important. We want to get the watermark and then its
+      // associated state representation since state is not allowed to mutate the internal
+      // representation.
+      return KV.of(watermarkEstimator.currentWatermark(), watermarkEstimator.getState());
+    }
+  }
+
+  /** Thread safe wrapper for {@link TimestampObservingWatermarkEstimator}s. */
+  @ThreadSafe
+  private static class ThreadSafeTimestampObservingWatermarkEstimator<WatermarkEstimatorStateT>
+      extends ThreadSafeWatermarkEstimator<WatermarkEstimatorStateT>
+      implements TimestampObservingWatermarkEstimator<WatermarkEstimatorStateT> {
+    ThreadSafeTimestampObservingWatermarkEstimator(
+        WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator) {
+      super(watermarkEstimator);
+    }
+
+    @Override
+    public synchronized void observeTimestamp(Instant timestamp) {
+      ((TimestampObservingWatermarkEstimator<?>) watermarkEstimator).observeTimestamp(timestamp);
+    }
+  }
+
+  /** Thread safe wrapper for {@link ManualWatermarkEstimator}s. */
+  @ThreadSafe
+  private static class ThreadSafeManualWatermarkEstimator<WatermarkEstimatorStateT>
+      extends ThreadSafeWatermarkEstimator<WatermarkEstimatorStateT>
+      implements ManualWatermarkEstimator<WatermarkEstimatorStateT> {
+    ThreadSafeManualWatermarkEstimator(
+        WatermarkEstimator<WatermarkEstimatorStateT> watermarkEstimator) {
+      super(watermarkEstimator);
+    }
+
+    @Override
+    public synchronized void setWatermark(Instant watermark) {
+      ((ManualWatermarkEstimator<?>) watermarkEstimator).setWatermark(watermark);
+    }
+  }
+
+  private WatermarkEstimators() {} // Utility class; prevent instantiation.
+}
diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/splittabledofn/WatermarkEstimatorsTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/splittabledofn/WatermarkEstimatorsTest.java
new file mode 100644
index 0000000..3e12c7e
--- /dev/null
+++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/splittabledofn/WatermarkEstimatorsTest.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.fn.splittabledofn;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Consumer;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.TimestampObservingWatermarkEstimator;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.KV;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link WatermarkEstimators}. */
+@RunWith(JUnit4.class)
+public class WatermarkEstimatorsTest {
+
+  @Test
+  public void testThreadSafeWatermarkEstimator() throws Exception {
+    Instant[] reference = new Instant[] {GlobalWindow.TIMESTAMP_MIN_VALUE};
+    WatermarkEstimator<Instant> watermarkEstimator =
+        new WatermarkEstimator<Instant>() {
+
+          @Override
+          public Instant currentWatermark() {
+            return reference[0];
+          }
+
+          @Override
+          public Instant getState() {
+            return reference[0];
+          }
+        };
+    WatermarkEstimators.WatermarkAndStateObserver<Instant> threadSafeWatermarkEstimator =
+        WatermarkEstimators.threadSafe(watermarkEstimator);
+    // A plain WatermarkEstimator has no mutator for the wrapper to guard, so the updater locks
+    // the wrapper itself to keep each write from interleaving with a snapshot read.
+    // NOTE(review): assumes the wrapper's getWatermarkAndState() is synchronized on itself, like
+    // the synchronized mutators of its subclasses.
+    testWatermarkEstimatorSnapshotsStateWithCompetingThread(
+        threadSafeWatermarkEstimator,
+        (instant) -> {
+          synchronized (threadSafeWatermarkEstimator) {
+            reference[0] = instant;
+          }
+        });
+  }
+
+  @Test
+  public void testThreadSafeManualWatermarkEstimator() throws Exception {
+    WatermarkEstimators.WatermarkAndStateObserver<Instant> threadSafeWatermarkEstimator =
+        WatermarkEstimators.threadSafe(
+            new org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual(
+                GlobalWindow.TIMESTAMP_MIN_VALUE));
+    // Update through the thread safe wrapper so that writes are serialized with snapshot reads.
+    ManualWatermarkEstimator<?> manualWatermarkEstimator =
+        (ManualWatermarkEstimator<?>) threadSafeWatermarkEstimator;
+    testWatermarkEstimatorSnapshotsStateWithCompetingThread(
+        threadSafeWatermarkEstimator, manualWatermarkEstimator::setWatermark);
+  }
+
+  @Test
+  public void testThreadSafeTimestampObservingWatermarkEstimator() throws Exception {
+    WatermarkEstimators.WatermarkAndStateObserver<Instant> threadSafeWatermarkEstimator =
+        WatermarkEstimators.threadSafe(
+            new org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators
+                .MonotonicallyIncreasing(GlobalWindow.TIMESTAMP_MIN_VALUE));
+    // Update through the thread safe wrapper so that writes are serialized with snapshot reads.
+    TimestampObservingWatermarkEstimator<?> timestampObservingWatermarkEstimator =
+        (TimestampObservingWatermarkEstimator<?>) threadSafeWatermarkEstimator;
+    testWatermarkEstimatorSnapshotsStateWithCompetingThread(
+        threadSafeWatermarkEstimator, timestampObservingWatermarkEstimator::observeTimestamp);
+  }
+
+  /**
+   * Races a thread that feeds increasing timestamps to {@code watermarkUpdater} against snapshot
+   * reads from {@code threadSafeWatermarkEstimator}.
+   *
+   * <p>Asserts that every snapshot has state equal to its watermark (true for the estimators used
+   * here), that snapshots never move backwards, and that the updater thread finishes cleanly.
+   */
+  private static void testWatermarkEstimatorSnapshotsStateWithCompetingThread(
+      WatermarkEstimators.WatermarkAndStateObserver<Instant> threadSafeWatermarkEstimator,
+      Consumer<Instant> watermarkUpdater)
+      throws Exception {
+    CountDownLatch countDownLatch = new CountDownLatch(1);
+    AtomicReference<Throwable> updaterFailure = new AtomicReference<>();
+    Thread t =
+        new Thread(
+            () -> {
+              countDownLatch.countDown();
+              for (int i = 0; i < 1000; ++i) {
+                watermarkUpdater.accept(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
+              }
+            });
+    // Surface exceptions from the updater thread; otherwise they are lost and the test passes.
+    t.setUncaughtExceptionHandler((thread, throwable) -> updaterFailure.set(throwable));
+    t.start();
+
+    // Ensure the thread has started before we start fetching values.
+    countDownLatch.await();
+    Instant currentMinimum = GlobalWindow.TIMESTAMP_MIN_VALUE;
+    for (int i = 0; i < 100; ++i) {
+      KV<Instant, Instant> value = threadSafeWatermarkEstimator.getWatermarkAndState();
+      // The watermark estimators that we use ensure that state == current watermark so test that
+      // they are equal here.
+      assertEquals(value.getKey(), value.getValue());
+      // Also ensure that the watermark never declines, i.e. we never observe an "old" read.
+      assertFalse(currentMinimum.isAfter(value.getKey()));
+      // Remember this read; otherwise the check above only ever compares against the minimum
+      // timestamp and can never fail.
+      currentMinimum = value.getKey();
+    }
+
+    t.join(10000);
+    assertFalse("Watermark updater thread did not finish in time", t.isAlive());
+    Throwable failure = updaterFailure.get();
+    if (failure != null) {
+      throw new AssertionError("Watermark updater thread failed", failure);
+    }
+  }
+}
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 681d041..ec26195 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -27,8 +27,8 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
-import java.util.function.Function;
 import java.util.function.Supplier;
 import javax.annotation.Nullable;
 import org.apache.beam.fn.harness.control.BundleSplitListener;
@@ -55,6 +55,7 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
+import org.apache.beam.sdk.fn.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.function.ThrowingRunnable;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.schemas.SchemaCoder;
@@ -81,6 +82,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.SplitResult;
+import org.apache.beam.sdk.transforms.splittabledofn.TimestampObservingWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -112,7 +114,7 @@ import org.joda.time.Instant;
  * abstraction caused by StateInternals/TimerInternals since they model state and timer concepts
  * differently.
  */
-public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
+public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT, OutputT> {
   /** A registrar which provides a factory to handle Java {@link DoFn}s. */
   @AutoService(PTransformRunnerFactory.Registrar.class)
   public static class Registrar implements PTransformRunnerFactory.Registrar {
@@ -131,12 +133,12 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
     }
   }
 
-  static class Factory<InputT, RestrictionT, PositionT, OutputT>
+  static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT, OutputT>
       implements PTransformRunnerFactory<
-          FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT>> {
+          FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT, OutputT>> {
 
     @Override
-    public final FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT>
+    public final FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT, OutputT>
         createRunnerForPTransform(
             PipelineOptions pipelineOptions,
             BeamFnDataClient beamFnDataClient,
@@ -154,7 +156,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
             BundleSplitListener splitListener,
             BundleFinalizer bundleFinalizer) {
 
-      FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> runner =
+      FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT, OutputT> runner =
           new FnApiDoFnRunner<>(
               pipelineOptions,
               beamFnStateClient,
@@ -259,7 +261,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
    * PTransformTranslation#SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN} transforms. Can
    * only be invoked from within {@code processElement...} methods.
    */
-  private final Function<SplitResult<RestrictionT>, WindowedSplitResult>
+  private final BiFunction<SplitResult<RestrictionT>, WatermarkEstimatorStateT, WindowedSplitResult>
       convertSplitResultToWindowedSplitResult;
 
   private final DoFnSchemaInformation doFnSchemaInformation;
@@ -269,14 +271,26 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
   /** Only valid during {@code processElement...} methods, null otherwise. */
   private WindowedValue<InputT> currentElement;
 
-  /** Only valid during {@link #processElementForSplitRestriction}, null otherwise. */
+  /**
+   * Only valid during {@link #processElementForPairWithRestriction}, {@link
+   * #processElementForSplitRestriction}, {@link #processElementForElementAndRestriction} and {@link
+   * #processElementForSizedElementAndRestriction}, null otherwise.
+   */
   private RestrictionT currentRestriction;
 
   /**
+   * Only valid during {@link #processElementForSplitRestriction}, {@link
+   * #processElementForElementAndRestriction} and {@link
+   * #processElementForSizedElementAndRestriction}, null otherwise.
+   */
+  private WatermarkEstimatorStateT currentWatermarkEstimatorState;
+
+  /**
    * Only valid during {@link #processElementForElementAndRestriction} and {@link
    * #processElementForSizedElementAndRestriction}, null otherwise.
    */
-  private Instant currentOutputWatermark;
+  private WatermarkEstimators.WatermarkAndStateObserver<WatermarkEstimatorStateT>
+      currentWatermarkEstimator;
 
   /**
    * Only valid during {@code processElement...} and {@link #processTimer} methods, null otherwise.
@@ -455,7 +469,9 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
                     mainOutputConsumers,
                     (WindowedValue<OutputT>)
                         WindowedValue.of(
-                            KV.of(currentElement.getValue(), output),
+                            KV.of(
+                                currentElement.getValue(),
+                                KV.of(output, currentWatermarkEstimatorState)),
                             timestamp,
                             currentWindow,
                             currentElement.getPane()));
@@ -494,7 +510,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
                     mainOutputConsumers,
                     (WindowedValue<OutputT>)
                         WindowedValue.of(
-                            KV.of(KV.of(currentElement.getValue(), output), size),
+                            KV.of(
+                                KV.of(
+                                    currentElement.getValue(),
+                                    KV.of(output, currentWatermarkEstimatorState)),
+                                size),
                             timestamp,
                             currentWindow,
                             currentElement.getPane()));
@@ -517,16 +537,20 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
     switch (pTransform.getSpec().getUrn()) {
       case PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN:
         this.convertSplitResultToWindowedSplitResult =
-            (splitResult) ->
+            (splitResult, watermarkEstimatorState) ->
                 WindowedSplitResult.forRoots(
                     currentElement.withValue(
-                        KV.of(currentElement.getValue(), splitResult.getPrimary())),
+                        KV.of(
+                            currentElement.getValue(),
+                            KV.of(splitResult.getPrimary(), watermarkEstimatorState))),
                     currentElement.withValue(
-                        KV.of(currentElement.getValue(), splitResult.getResidual())));
+                        KV.of(
+                            currentElement.getValue(),
+                            KV.of(splitResult.getResidual(), watermarkEstimatorState))));
         break;
       case PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN:
         this.convertSplitResultToWindowedSplitResult =
-            (splitResult) -> {
+            (splitResult, watermarkEstimatorState) -> {
               double primarySize =
                   doFnInvoker.invokeGetSize(
                       new DelegatingArgumentProvider<InputT, OutputT>(
@@ -564,16 +588,21 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
               return WindowedSplitResult.forRoots(
                   currentElement.withValue(
                       KV.of(
-                          KV.of(currentElement.getValue(), splitResult.getPrimary()), primarySize)),
+                          KV.of(
+                              currentElement.getValue(),
+                              KV.of(splitResult.getPrimary(), watermarkEstimatorState)),
+                          primarySize)),
                   currentElement.withValue(
                       KV.of(
-                          KV.of(currentElement.getValue(), splitResult.getResidual()),
+                          KV.of(
+                              currentElement.getValue(),
+                              KV.of(splitResult.getResidual(), watermarkEstimatorState)),
                           residualSize)));
             };
         break;
       default:
         this.convertSplitResultToWindowedSplitResult =
-            (splitResult) -> {
+            (splitResult, watermarkEstimatorStateT) -> {
               throw new IllegalStateException(
                   String.format(
                       "Unimplemented split conversion handler for %s.",
@@ -620,22 +649,29 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
       while (windowIterator.hasNext()) {
         currentWindow = windowIterator.next();
+        currentRestriction = doFnInvoker.invokeGetInitialRestriction(processContext);
         outputTo(
             mainOutputConsumers,
             (WindowedValue)
                 elem.withValue(
                     KV.of(
-                        elem.getValue(), doFnInvoker.invokeGetInitialRestriction(processContext))));
+                        elem.getValue(),
+                        KV.of(
+                            currentRestriction,
+                            doFnInvoker.invokeGetInitialWatermarkEstimatorState(processContext)))));
       }
     } finally {
       currentElement = null;
       currentWindow = null;
+      currentRestriction = null;
     }
   }
 
-  public void processElementForSplitRestriction(WindowedValue<KV<InputT, RestrictionT>> elem) {
+  public void processElementForSplitRestriction(
+      WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey());
-    currentRestriction = elem.getValue().getValue();
+    currentRestriction = elem.getValue().getValue().getKey();
+    currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
     try {
       Iterator<BoundedWindow> windowIterator =
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
@@ -646,6 +682,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
     } finally {
       currentElement = null;
       currentRestriction = null;
+      currentWatermarkEstimatorState = null;
       currentWindow = null;
     }
   }
@@ -664,22 +701,23 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
   }
 
   public void processElementForSizedElementAndRestriction(
-      WindowedValue<KV<KV<InputT, RestrictionT>, Double>> elem) {
+      WindowedValue<KV<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>, Double>> elem) {
     processElementForElementAndRestriction(elem.withValue(elem.getValue().getKey()));
   }
 
-  public void processElementForElementAndRestriction(WindowedValue<KV<InputT, RestrictionT>> elem) {
+  public void processElementForElementAndRestriction(
+      WindowedValue<KV<InputT, KV<RestrictionT, WatermarkEstimatorStateT>>> elem) {
     currentElement = elem.withValue(elem.getValue().getKey());
     try {
       Iterator<BoundedWindow> windowIterator =
           (Iterator<BoundedWindow>) elem.getWindows().iterator();
       while (windowIterator.hasNext()) {
-        // TODO(BEAM-2939): Ensure that the watermark we use as the lower bound comes from
-        // the previously reported watermark and doesn't reset to -infinity on each element.
-        currentOutputWatermark = GlobalWindow.TIMESTAMP_MIN_VALUE;
-        currentRestriction = elem.getValue().getValue();
+        currentRestriction = elem.getValue().getValue().getKey();
+        currentWatermarkEstimatorState = elem.getValue().getValue().getValue();
         currentWindow = windowIterator.next();
         currentTracker = doFnInvoker.invokeNewTracker(processContext);
+        currentWatermarkEstimator =
+            WatermarkEstimators.threadSafe(doFnInvoker.invokeNewWatermarkEstimator(processContext));
         DoFn.ProcessContinuation continuation = doFnInvoker.invokeProcessElement(processContext);
         // Ensure that all the work is done if the user tells us that they don't want to
         // resume processing.
@@ -688,6 +726,10 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
           continue;
         }
 
+        // Make sure to get the output watermark before we split to ensure that the lower bound
+        // applies to both the primary and residual.
+        KV<Instant, WatermarkEstimatorStateT> watermarkAndState =
+            currentWatermarkEstimator.getWatermarkAndState();
         SplitResult<RestrictionT> result = currentTracker.trySplit(0);
         // After the user has chosen to resume processing later, the Runner may have stolen
         // the remainder of work through a split call so the above trySplit may fail. If so,
@@ -699,7 +741,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
 
         // Otherwise we have a successful self checkpoint.
         WindowedSplitResult windowedSplitResult =
-            convertSplitResultToWindowedSplitResult.apply(result);
+            convertSplitResultToWindowedSplitResult.apply(result, watermarkAndState.getValue());
         ByteString.Output primaryBytes = ByteString.newOutput();
         ByteString.Output residualBytes = ByteString.newOutput();
         try {
@@ -719,13 +761,14 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
                 .setTransformId(pTransformId)
                 .setInputId(mainInputId)
                 .setElement(residualBytes.toByteString());
-        if (!currentOutputWatermark.equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
+
+        if (!watermarkAndState.getKey().equals(GlobalWindow.TIMESTAMP_MIN_VALUE)) {
           for (String outputId : pTransform.getOutputsMap().keySet()) {
             residualApplication.putOutputWatermarks(
                 outputId,
                 org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.Timestamp.newBuilder()
-                    .setSeconds(currentOutputWatermark.getMillis() / 1000)
-                    .setNanos((int) (currentOutputWatermark.getMillis() % 1000) * 1000000)
+                    .setSeconds(watermarkAndState.getKey().getMillis() / 1000)
+                    .setNanos((int) (watermarkAndState.getKey().getMillis() % 1000) * 1000000)
                     .build());
           }
         }
@@ -741,9 +784,10 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
     } finally {
       currentElement = null;
       currentRestriction = null;
+      currentWatermarkEstimatorState = null;
       currentWindow = null;
       currentTracker = null;
-      currentOutputWatermark = null;
+      currentWatermarkEstimator = null;
     }
   }
 
@@ -780,6 +824,10 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
   /** Outputs the given element to the specified set of consumers wrapping any exceptions. */
   private <T> void outputTo(
       Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
+    if (currentWatermarkEstimator instanceof TimestampObservingWatermarkEstimator) {
+      ((TimestampObservingWatermarkEstimator) currentWatermarkEstimator)
+          .observeTimestamp(output.getTimestamp());
+    }
     try {
       for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
         consumer.accept(output);
@@ -1232,27 +1280,13 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, OutputT> {
     }
 
     @Override
-    public void updateWatermark(Instant watermark) {
-      checkState(
-          currentOutputWatermark != null,
-          "Updating the watermark is only allowed for Splittable DoFns.");
-      checkArgument(
-          !watermark.isBefore(currentOutputWatermark),
-          "Watermark must be monotonically increasing. Provided watermark %s is less then current watermark %s.",
-          watermark,
-          currentOutputWatermark);
-      currentOutputWatermark = watermark;
-    }
-
-    @Override
     public Object watermarkEstimatorState() {
-      throw new UnsupportedOperationException(
-          "@WatermarkEstimatorState parameters are not supported.");
+      return currentWatermarkEstimatorState;
     }
 
     @Override
     public WatermarkEstimator<?> watermarkEstimator() {
-      throw new UnsupportedOperationException("WatermarkEstimator parameters are not supported.");
+      return currentWatermarkEstimator;
     }
   }
 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index f211b6d..426bf85 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -84,8 +84,10 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
 import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -1044,13 +1046,15 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
     @ProcessElement
     public ProcessContinuation processElement(
-        ProcessContext context, RestrictionTracker<OffsetRange, Long> tracker) {
+        ProcessContext context,
+        RestrictionTracker<OffsetRange, Long> tracker,
+        ManualWatermarkEstimator<Instant> watermarkEstimator) {
       int upperBound = Integer.parseInt(context.sideInput(singletonSideInput));
       for (int i = 0; i < upperBound; ++i) {
         if (tracker.tryClaim((long) i)) {
           context.outputWithTimestamp(
               context.element() + ":" + i, GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
-          context.updateWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
+          watermarkEstimator.setWatermark(GlobalWindow.TIMESTAMP_MIN_VALUE.plus(i));
         }
       }
       if (tracker.currentRestriction().getTo() > upperBound) {
@@ -1075,6 +1079,17 @@ public class FnApiDoFnRunnerTest implements Serializable {
       receiver.output(new OffsetRange(range.getFrom(), (range.getFrom() + range.getTo()) / 2));
       receiver.output(new OffsetRange((range.getFrom() + range.getTo()) / 2, range.getTo()));
     }
+
+    @GetInitialWatermarkEstimatorState
+    public Instant getInitialWatermarkEstimatorState() {
+      return GlobalWindow.TIMESTAMP_MIN_VALUE;
+    }
+
+    @NewWatermarkEstimator
+    public WatermarkEstimators.Manual newWatermarkEstimator(
+        @WatermarkEstimatorState Instant watermark) {
+      return new WatermarkEstimators.Manual(watermark);
+    }
   }
 
   @Test
@@ -1170,7 +1185,10 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
     FnDataReceiver<WindowedValue<?>> mainInput =
         consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInGlobalWindow(KV.of(KV.of("5", new OffsetRange(0, 5)), 5.0)));
+    mainInput.accept(
+        valueInGlobalWindow(
+            KV.of(
+                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)), 5.0)));
     BundleApplication primaryRoot = Iterables.getOnlyElement(primarySplits);
     DelayedBundleApplication residualRoot = Iterables.getOnlyElement(residualSplits);
     assertEquals(ParDoTranslation.getMainInputName(pTransform), primaryRoot.getInputId());
@@ -1192,7 +1210,10 @@ public class FnApiDoFnRunnerTest implements Serializable {
     primarySplits.clear();
     residualSplits.clear();
 
-    mainInput.accept(valueInGlobalWindow(KV.of(KV.of("2", new OffsetRange(0, 2)), 2.0)));
+    mainInput.accept(
+        valueInGlobalWindow(
+            KV.of(
+                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)), 2.0)));
     assertThat(
         mainOutputValues,
         contains(
@@ -1293,8 +1314,10 @@ public class FnApiDoFnRunnerTest implements Serializable {
     assertThat(
         mainOutputValues,
         contains(
-            valueInGlobalWindow(KV.of("5", new OffsetRange(0, 5))),
-            valueInGlobalWindow(KV.of("2", new OffsetRange(0, 2)))));
+            valueInGlobalWindow(
+                KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE))),
+            valueInGlobalWindow(
+                KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)))));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();
@@ -1378,15 +1401,31 @@ public class FnApiDoFnRunnerTest implements Serializable {
 
     FnDataReceiver<WindowedValue<?>> mainInput =
         consumers.getMultiplexingConsumer(inputPCollectionId);
-    mainInput.accept(valueInGlobalWindow(KV.of("5", new OffsetRange(0, 5))));
-    mainInput.accept(valueInGlobalWindow(KV.of("2", new OffsetRange(0, 2))));
+    mainInput.accept(
+        valueInGlobalWindow(
+            KV.of("5", KV.of(new OffsetRange(0, 5), GlobalWindow.TIMESTAMP_MIN_VALUE))));
+    mainInput.accept(
+        valueInGlobalWindow(
+            KV.of("2", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE))));
     assertThat(
         mainOutputValues,
         contains(
-            valueInGlobalWindow(KV.of(KV.of("5", new OffsetRange(0, 2)), 2.0)),
-            valueInGlobalWindow(KV.of(KV.of("5", new OffsetRange(2, 5)), 3.0)),
-            valueInGlobalWindow(KV.of(KV.of("2", new OffsetRange(0, 1)), 1.0)),
-            valueInGlobalWindow(KV.of(KV.of("2", new OffsetRange(1, 2)), 1.0))));
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(0, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    2.0)),
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("5", KV.of(new OffsetRange(2, 5), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    3.0)),
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(0, 1), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0)),
+            valueInGlobalWindow(
+                KV.of(
+                    KV.of("2", KV.of(new OffsetRange(1, 2), GlobalWindow.TIMESTAMP_MIN_VALUE)),
+                    1.0))));
     mainOutputValues.clear();
 
     Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run();