You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bo...@apache.org on 2019/11/13 17:05:11 UTC

[beam] branch master updated: Update SDF APIs

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b94dca2  Update SDF APIs
     new 352ecf3  Merge pull request #9794 from boyuanzz/watermark
b94dca2 is described below

commit b94dca2b1df4a3aea66d822299a58c97accc0541
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Mon Oct 14 16:23:21 2019 -0700

    Update SDF APIs
---
 .../fn-execution/src/main/proto/beam_fn_api.proto  |   9 +
 sdks/python/apache_beam/io/iobase.py               | 231 +++++++++++++--------
 sdks/python/apache_beam/io/iobase_test.py          |  86 ++++++++
 sdks/python/apache_beam/io/restriction_trackers.py | 128 ++++--------
 .../apache_beam/io/restriction_trackers_test.py    |  12 +-
 sdks/python/apache_beam/runners/common.pxd         |   6 +-
 sdks/python/apache_beam/runners/common.py          |  62 ++++--
 .../runners/direct/sdf_direct_runner_test.py       |  18 +-
 .../runners/portability/flink_runner_test.py       |   3 +
 .../runners/portability/fn_api_runner_test.py      |  75 ++++++-
 .../apache_beam/runners/worker/bundle_processor.py |  47 +++--
 .../apache_beam/testing/synthetic_pipeline.py      |   2 +-
 sdks/python/apache_beam/transforms/core.py         |  75 ++++++-
 sdks/python/apache_beam/transforms/core_test.py    |  54 +++++
 sdks/python/apache_beam/utils/timestamp.py         |   7 +
 sdks/python/apache_beam/utils/timestamp_test.py    |   6 +
 sdks/python/scripts/generate_pydoc.sh              |   1 +
 17 files changed, 591 insertions(+), 231 deletions(-)

diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto
index ed2f013..0ddc48e 100644
--- a/model/fn-execution/src/main/proto/beam_fn_api.proto
+++ b/model/fn-execution/src/main/proto/beam_fn_api.proto
@@ -42,6 +42,7 @@ import "beam_runner_api.proto";
 import "endpoints.proto";
 import "google/protobuf/descriptor.proto";
 import "google/protobuf/timestamp.proto";
+import "google/protobuf/duration.proto";
 import "google/protobuf/wrappers.proto";
 import "metrics.proto";
 
@@ -203,13 +204,21 @@ message BundleApplication {
 }
 
 // An Application should be scheduled for execution after a delay.
+// Either an absolute timestamp or a relative timestamp can represent a
+// scheduled execution time.
 message DelayedBundleApplication {
   // Recommended time at which the application should be scheduled to execute
   // by the runner. Times in the past may be scheduled to execute immediately.
+  // TODO(BEAM-8536): Migrate usage of absolute time to requested_time_delay.
   google.protobuf.Timestamp requested_execution_time = 1;
 
   // (Required) The application that should be scheduled.
   BundleApplication application = 2;
+
+  // Recommended time delay at which the application should be scheduled to
+  // execute by the runner. Time delay that equals 0 may be scheduled to execute
+  // immediately. The unit of time delay should be microsecond.
+  google.protobuf.Duration requested_time_delay = 3;
 }
 
 // A request to process a given bundle.
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 5b66730..e21052f 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -35,6 +35,7 @@ from __future__ import division
 import logging
 import math
 import random
+import threading
 import uuid
 from builtins import object
 from builtins import range
@@ -1104,13 +1105,17 @@ class _RoundRobinKeyFn(core.DoFn):
 class RestrictionTracker(object):
   """Manages concurrent access to a restriction.
 
-  Experimental; no backwards-compatibility guarantees.
-
   Keeps track of the restrictions claimed part for a Splittable DoFn.
 
+  The restriction may be modified by different threads, however the system will
+  ensure sufficient locking such that no methods on the restriction tracker
+  will be called concurrently.
+
   See following documents for more details.
   * https://s.apache.org/splittable-do-fn
   * https://s.apache.org/splittable-do-fn-python-sdk
+
+  Experimental; no backwards-compatibility guarantees.
   """
 
   def current_restriction(self):
@@ -1121,52 +1126,20 @@ class RestrictionTracker(object):
 
     The current restriction returned by method may be updated dynamically due
     to due to concurrent invocation of other methods of the
-    ``RestrictionTracker``, For example, ``checkpoint()``.
-
-    ** Thread safety **
+    ``RestrictionTracker``, For example, ``split()``.
 
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
+    This API is required to be implemented.
 
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    Returns: a restriction object.
     """
     raise NotImplementedError
 
   def current_progress(self):
     """Returns a RestrictionProgress object representing the current progress.
-    """
-    raise NotImplementedError
-
-  def current_watermark(self):
-    """Returns current watermark. By default, not report watermark.
-
-    TODO(BEAM-7473): Provide synchronization guarantee by using a wrapper.
-    """
-    return None
-
-  def checkpoint(self):
-    """Performs a checkpoint of the current restriction.
-
-    Signals that the current ``DoFn.process()`` call should terminate as soon as
-    possible. After this method returns, the tracker MUST refuse all future
-    claim calls, and ``RestrictionTracker.check_done()`` MUST succeed.
-
-    This invocation modifies the value returned by ``current_restriction()``
-    invocation and returns a restriction representing the rest of the work. The
-    old value of ``current_restriction()`` is equivalent to the new value of
-    ``current_restriction()`` and the return value of this method invocation
-    combined.
 
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    This API is recommended to be implemented. The runner can do a better job
+    at parallel processing with better progress signals.
     """
-
     raise NotImplementedError
 
   def check_done(self):
@@ -1179,13 +1152,8 @@ class RestrictionTracker(object):
     remaining in the restriction when this method is invoked. Exception raised
     must have an informative error message.
 
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
+    This API is required to be implemented in order to make sure there is no
+    data loss during SDK processing.
 
     Returns: ``True`` if current restriction has been fully processed.
     Raises:
@@ -1215,8 +1183,12 @@ class RestrictionTracker(object):
     restrictions returned would be [100, 179), [179, 200) (note: current_offset
     + fraction_of_remainder * remaining_work = 130 + 0.7 * 70 = 179).
 
-    It is very important for pipeline scaling and end to end pipeline execution
-    that try_split is implemented well.
+    ``fraction_of_remainder`` = 0 means a checkpoint is required.
+
+    The API is recommended to be implemented for a batch pipeline given that it
+    is very important for pipeline scaling and end to end pipeline execution.
+
+    The API is required to be implemented for a streaming pipeline.
 
     Args:
       fraction_of_remainder: A hint as to the fraction of work the primary
@@ -1226,19 +1198,11 @@ class RestrictionTracker(object):
     Returns:
       (primary_restriction, residual_restriction) if a split was possible,
       otherwise returns ``None``.
-
-    ** Thread safety **
-
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
-
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
     """
     raise NotImplementedError
 
   def try_claim(self, position):
-    """ Attempts to claim the block of work in the current restriction
+    """Attempts to claim the block of work in the current restriction
     identified by the given position.
 
     If this succeeds, the DoFn MUST execute the entire block of work. If it
@@ -1247,40 +1211,137 @@ class RestrictionTracker(object):
     work from ``DoFn.process()`` is also not allowed before the first call of
     this method).
 
+    The API is required to be implemented.
+
     Args:
       position: current position that wants to be claimed.
 
     Returns: ``True`` if the position can be claimed as current_position.
     Otherwise, returns ``False``.
+    """
+    raise NotImplementedError
 
-    ** Thread safety **
 
-    Methods of the class ``RestrictionTracker`` including this method may get
-    invoked by different threads, hence must be made thread-safe, e.g. by using
-    a single lock object.
+class ThreadsafeRestrictionTracker(object):
+  """A thread-safe wrapper which wraps a `RestrictionTracker`.
 
-    TODO(BEAM-7473): Remove thread safety requirements from API implementation.
-    """
-    raise NotImplementedError
+  This wrapper guarantees synchronization of modifying restrictions across
+  multiple threads.
+  """
+
+  def __init__(self, restriction_tracker):
+    if not isinstance(restriction_tracker, RestrictionTracker):
+      raise ValueError(
+          'Initialize ThreadsafeRestrictionTracker requires'
+          'RestrictionTracker.')
+    self._restriction_tracker = restriction_tracker
+    # Records an absolute timestamp when defer_remainder is called.
+    self._deferred_timestamp = None
+    self._lock = threading.RLock()
+    self._deferred_residual = None
+    self._deferred_watermark = None
 
-  def defer_remainder(self, watermark=None):
-    """ Invokes checkpoint() in an SDF.process().
+  def current_restriction(self):
+    with self._lock:
+      return self._restriction_tracker.current_restriction()
 
-    TODO(BEAM-7472): Remove defer_remainder() once SDF.process() uses
-    ``ProcessContinuation``.
+  def try_claim(self, position):
+    with self._lock:
+      return self._restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    """Performs self-checkpoint on current processing restriction with an
+    expected resuming time.
+
+    Self-checkpoint could happen while processing elements. When executing a
+    DoFn.process(), you may want to stop processing an element and resume it
+    later if the current element has been processed for quite a long time or
+    you also want to have some outputs from other elements.
+    ``defer_remainder()`` can be called on a per-element basis if needed.
 
     Args:
-      watermark
+      deferred_time: A relative ``timestamp.Duration`` that indicates the ideal
+      time gap between now and resuming, or an absolute ``timestamp.Timestamp``
+      for resuming execution time. If deferred_time is None, the deferred work
+      will be executed as soon as possible.
     """
-    raise NotImplementedError
+
+    # Record current time for calculating deferred_time later.
+    self._deferred_timestamp = timestamp.Timestamp.now()
+    if (deferred_time and
+        not isinstance(deferred_time, timestamp.Duration) and
+        not isinstance(deferred_time, timestamp.Timestamp)):
+      raise ValueError('The timestamp of deter_remainder() should be a '
+                       'Duration or a Timestamp, or None.')
+    self._deferred_watermark = deferred_time
+    checkpoint = self.try_split(0)
+    if checkpoint:
+      _, self._deferred_residual = checkpoint
+
+  def check_done(self):
+    with self._lock:
+      return self._restriction_tracker.check_done()
+
+  def current_progress(self):
+    with self._lock:
+      return self._restriction_tracker.current_progress()
+
+  def try_split(self, fraction_of_remainder):
+    with self._lock:
+      return self._restriction_tracker.try_split(fraction_of_remainder)
 
   def deferred_status(self):
-    """ Returns deferred_residual with deferred_watermark.
+    """Returns deferred work which is produced by ``defer_remainder()``.
+
+    When there is a self-checkpoint performed, the system needs to fulfill the
+    DelayedBundleApplication with deferred_work for a ProcessBundleResponse.
+    The system calls this API to get deferred_residual with watermark together
+    to help the runner to schedule a future work.
 
-    TODO(BEAM-7472): Remove defer_status() once SDF.process() uses
-    ``ProcessContinuation``.
+    Returns: (deferred_residual, time_delay) if having any residual, else None.
     """
-    raise NotImplementedError
+    if self._deferred_residual:
+      # If _deferred_watermark is None, create Duration(0).
+      if not self._deferred_watermark:
+        self._deferred_watermark = timestamp.Duration()
+      # If an absolute timestamp is provided, calculate the delta between
+      # the absolute time and the time deferred_status() is called.
+      elif isinstance(self._deferred_watermark, timestamp.Timestamp):
+        self._deferred_watermark = (self._deferred_watermark -
+                                    timestamp.Timestamp.now())
+      # If a Duration is provided, the deferred time should be:
+      # provided duration - the spent time since the defer_remainder() is
+      # called.
+      elif isinstance(self._deferred_watermark, timestamp.Duration):
+        self._deferred_watermark -= (timestamp.Timestamp.now() -
+                                     self._deferred_timestamp)
+      return self._deferred_residual, self._deferred_watermark
+
+
+class RestrictionTrackerView(object):
+  """A DoFn view of thread-safe RestrictionTracker.
+
+  The RestrictionTrackerView wraps a ThreadsafeRestrictionTracker and only
+  exposes APIs that will be called by a ``DoFn.process()``. During execution
+  time, the RestrictionTrackerView will be fed into the ``DoFn.process`` as a
+  restriction_tracker.
+  """
+
+  def __init__(self, threadsafe_restriction_tracker):
+    if not isinstance(threadsafe_restriction_tracker,
+                      ThreadsafeRestrictionTracker):
+      raise ValueError('Initialize RestrictionTrackerView requires '
+                       'ThreadsafeRestrictionTracker.')
+    self._threadsafe_restriction_tracker = threadsafe_restriction_tracker
+
+  def current_restriction(self):
+    return self._threadsafe_restriction_tracker.current_restriction()
+
+  def try_claim(self, position):
+    return self._threadsafe_restriction_tracker.try_claim(position)
+
+  def defer_remainder(self, deferred_time=None):
+    self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
 
 
 class RestrictionProgress(object):
@@ -1400,17 +1461,8 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
                 SourceBundle(residual_weight, self._source, split_pos,
                              stop_pos))
 
-    def deferred_status(self):
-      return None
-
-    def current_watermark(self):
-      return None
-
-    def get_delegate_range_tracker(self):
-      return self._delegate_range_tracker
-
-    def get_tracking_source(self):
-      return self._source
+    def check_done(self):
+      return self._delegate_range_tracker.fraction_consumed() >= 1.0
 
   class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
     """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
@@ -1463,8 +1515,13 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
           restriction_tracker=core.DoFn.RestrictionParam(
               _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
                   source, chunk_size))):
-        return restriction_tracker.get_tracking_source().read(
-            restriction_tracker.get_delegate_range_tracker())
+        current_restriction = restriction_tracker.current_restriction()
+        assert isinstance(current_restriction, SourceBundle)
+        tracking_source = current_restriction.source
+        start = current_restriction.start_position
+        stop = current_restriction.stop_position
+        return tracking_source.read(tracking_source.get_range_tracker(start,
+                                                                      stop))
 
     return SDFBoundedSourceDoFn(self.source)
 
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 7adb764..0a6afae 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -19,6 +19,7 @@
 
 from __future__ import absolute_import
 
+import time
 import unittest
 
 import mock
@@ -28,6 +29,9 @@ from apache_beam.io.concat_source import ConcatSource
 from apache_beam.io.concat_source_test import RangeSource
 from apache_beam.io import iobase
 from apache_beam.io.iobase import SourceBundle
+from apache_beam.io.restriction_trackers import OffsetRange
+from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
+from apache_beam.utils import timestamp
 from apache_beam.options.pipeline_options import DebugOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -191,5 +195,87 @@ class UseSdfBoundedSourcesTests(unittest.TestCase):
     self._run_sdf_wrapper_pipeline(RangeSource(0, 4), [0, 1, 2, 3])
 
 
+class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
+
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      iobase.ThreadsafeRestrictionTracker(RangeSource(0, 1))
+
+  def test_defer_remainder_with_wrong_time_type(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    with self.assertRaises(ValueError):
+      threadsafe_tracker.defer_remainder(10)
+
+  def test_self_checkpoint_immediately(self):
+    restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        restriction_tracker)
+    threadsafe_tracker.defer_remainder()
+    deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
+    expected_residual = OffsetRange(0, 10)
+    self.assertEqual(deferred_residual, expected_residual)
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    self.assertEqual(deferred_time, 0)
+
+  def test_self_checkpoint_with_relative_time(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    threadsafe_tracker.defer_remainder(timestamp.Duration(100))
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation = 100 - 2 - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+  def test_self_checkpoint_with_absolute_time(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    now = timestamp.Timestamp.now()
+    schedule_time = now + timestamp.Duration(100)
+    self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
+    threadsafe_tracker.defer_remainder(schedule_time)
+    time.sleep(2)
+    _, deferred_time = threadsafe_tracker.deferred_status()
+    self.assertTrue(isinstance(deferred_time, timestamp.Duration))
+    # The expectation =
+    # schedule_time - the time when deferred_status is called - some_delta
+    self.assertTrue(deferred_time <= 98)
+
+
+class RestrictionTrackerViewTest(unittest.TestCase):
+
+  def test_initialization(self):
+    with self.assertRaises(ValueError):
+      iobase.RestrictionTrackerView(
+          OffsetRestrictionTracker(OffsetRange(0, 10)))
+
+  def test_api_expose(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+    current_restriction = tracker_view.current_restriction()
+    self.assertEqual(current_restriction, OffsetRange(0, 10))
+    self.assertTrue(tracker_view.try_claim(0))
+    tracker_view.defer_remainder()
+    deferred_remainder, deferred_watermark = (
+        threadsafe_tracker.deferred_status())
+    self.assertEqual(deferred_remainder, OffsetRange(1, 10))
+    self.assertEqual(deferred_watermark, timestamp.Duration())
+
+  def test_non_expose_apis(self):
+    threadsafe_tracker = iobase.ThreadsafeRestrictionTracker(
+        OffsetRestrictionTracker(OffsetRange(0, 10)))
+    tracker_view = iobase.RestrictionTrackerView(threadsafe_tracker)
+    with self.assertRaises(AttributeError):
+      tracker_view.check_done()
+    with self.assertRaises(AttributeError):
+      tracker_view.current_progress()
+    with self.assertRaises(AttributeError):
+      tracker_view.try_split()
+    with self.assertRaises(AttributeError):
+      tracker_view.deferred_status()
+
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 0ba5b23..20bb5c1 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -19,7 +19,6 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import threading
 from builtins import object
 
 from apache_beam.io.iobase import RestrictionProgress
@@ -86,104 +85,69 @@ class OffsetRestrictionTracker(RestrictionTracker):
     assert isinstance(offset_range, OffsetRange)
     self._range = offset_range
     self._current_position = None
-    self._current_watermark = None
     self._last_claim_attempt = None
-    self._deferred_residual = None
     self._checkpointed = False
-    self._lock = threading.RLock()
 
   def check_done(self):
-    with self._lock:
-      if self._last_claim_attempt < self._range.stop - 1:
-        raise ValueError(
-            'OffsetRestrictionTracker is not done since work in range [%s, %s) '
-            'has not been claimed.'
-            % (self._last_claim_attempt if self._last_claim_attempt is not None
-               else self._range.start,
-               self._range.stop))
+    if self._last_claim_attempt < self._range.stop - 1:
+      raise ValueError(
+          'OffsetRestrictionTracker is not done since work in range [%s, %s) '
+          'has not been claimed.'
+          % (self._last_claim_attempt if self._last_claim_attempt is not None
+             else self._range.start,
+             self._range.stop))
 
   def current_restriction(self):
-    with self._lock:
-      return self._range
-
-  def current_watermark(self):
-    return self._current_watermark
+    return self._range
 
   def current_progress(self):
-    with self._lock:
-      if self._current_position is None:
-        fraction = 0.0
-      elif self._range.stop == self._range.start:
-        # If self._current_position is not None, we must be done.
-        fraction = 1.0
-      else:
-        fraction = (
-            float(self._current_position - self._range.start)
-            / (self._range.stop - self._range.start))
+    if self._current_position is None:
+      fraction = 0.0
+    elif self._range.stop == self._range.start:
+      # If self._current_position is not None, we must be done.
+      fraction = 1.0
+    else:
+      fraction = (
+          float(self._current_position - self._range.start)
+          / (self._range.stop - self._range.start))
     return RestrictionProgress(fraction=fraction)
 
   def start_position(self):
-    with self._lock:
-      return self._range.start
+    return self._range.start
 
   def stop_position(self):
-    with self._lock:
-      return self._range.stop
-
-  def default_size(self):
-    return self._range.size()
+    return self._range.stop
 
   def try_claim(self, position):
-    with self._lock:
-      if self._last_claim_attempt and position <= self._last_claim_attempt:
-        raise ValueError(
-            'Positions claimed should strictly increase. Trying to claim '
-            'position %d while last claim attempt was %d.'
-            % (position, self._last_claim_attempt))
-
-      self._last_claim_attempt = position
-      if position < self._range.start:
-        raise ValueError(
-            'Position to be claimed cannot be smaller than the start position '
-            'of the range. Tried to claim position %r for the range [%r, %r)'
-            % (position, self._range.start, self._range.stop))
-
-      if position >= self._range.start and position < self._range.stop:
-        self._current_position = position
-        return True
+    if self._last_claim_attempt and position <= self._last_claim_attempt:
+      raise ValueError(
+          'Positions claimed should strictly increase. Trying to claim '
+          'position %d while last claim attempt was %d.'
+          % (position, self._last_claim_attempt))
 
-      return False
+    self._last_claim_attempt = position
+    if position < self._range.start:
+      raise ValueError(
+          'Position to be claimed cannot be smaller than the start position '
+          'of the range. Tried to claim position %r for the range [%r, %r)'
+          % (position, self._range.start, self._range.stop))
+
+    if position >= self._range.start and position < self._range.stop:
+      self._current_position = position
+      return True
+
+    return False
 
   def try_split(self, fraction_of_remainder):
-    with self._lock:
-      if not self._checkpointed:
-        if self._current_position is None:
-          cur = self._range.start - 1
-        else:
-          cur = self._current_position
-        split_point = (
-            cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
-        if split_point < self._range.stop:
-          self._range, residual_range = self._range.split_at(split_point)
-          return self._range, residual_range
-
-  # TODO(SDF): Replace all calls with try_claim(0).
-  def checkpoint(self):
-    with self._lock:
-      # If self._current_position is 'None' no records have been claimed so
-      # residual should start from self._range.start.
+    if not self._checkpointed:
       if self._current_position is None:
-        end_position = self._range.start
+        cur = self._range.start - 1
       else:
-        end_position = self._current_position + 1
-      self._range, residual_range = self._range.split_at(end_position)
-      return residual_range
-
-  def defer_remainder(self, watermark=None):
-    with self._lock:
-      self._deferred_watermark = watermark or self._current_watermark
-      self._deferred_residual = self.checkpoint()
-
-  def deferred_status(self):
-    if self._deferred_residual:
-      return (self._deferred_residual, self._deferred_watermark)
+        cur = self._current_position
+      split_point = (
+          cur + int(max(1, (self._range.stop - cur) * fraction_of_remainder)))
+      if split_point < self._range.stop:
+        if fraction_of_remainder == 0:
+          self._checkpointed = True
+        self._range, residual_range = self._range.split_at(split_point)
+        return self._range, residual_range
diff --git a/sdks/python/apache_beam/io/restriction_trackers_test.py b/sdks/python/apache_beam/io/restriction_trackers_test.py
index 459b039..4a57d98 100644
--- a/sdks/python/apache_beam/io/restriction_trackers_test.py
+++ b/sdks/python/apache_beam/io/restriction_trackers_test.py
@@ -81,14 +81,14 @@ class OffsetRestrictionTrackerTest(unittest.TestCase):
 
   def test_checkpoint_unstarted(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 100), tracker.current_restriction())
     self.assertEqual(OffsetRange(100, 200), checkpoint)
 
   def test_checkpoint_just_started(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
     self.assertTrue(tracker.try_claim(100))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 101), tracker.current_restriction())
     self.assertEqual(OffsetRange(101, 200), checkpoint)
 
@@ -96,7 +96,7 @@ class OffsetRestrictionTrackerTest(unittest.TestCase):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
     self.assertTrue(tracker.try_claim(105))
     self.assertTrue(tracker.try_claim(110))
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 111), tracker.current_restriction())
     self.assertEqual(OffsetRange(111, 200), checkpoint)
 
@@ -105,9 +105,9 @@ class OffsetRestrictionTrackerTest(unittest.TestCase):
     self.assertTrue(tracker.try_claim(105))
     self.assertTrue(tracker.try_claim(110))
     self.assertTrue(tracker.try_claim(199))
-    checkpoint = tracker.checkpoint()
+    checkpoint = tracker.try_split(0)
     self.assertEqual(OffsetRange(100, 200), tracker.current_restriction())
-    self.assertEqual(OffsetRange(200, 200), checkpoint)
+    self.assertEqual(None, checkpoint)
 
   def test_checkpoint_after_failed_claim(self):
     tracker = OffsetRestrictionTracker(OffsetRange(100, 200))
@@ -116,7 +116,7 @@ class OffsetRestrictionTrackerTest(unittest.TestCase):
     self.assertTrue(tracker.try_claim(160))
     self.assertFalse(tracker.try_claim(240))
 
-    checkpoint = tracker.checkpoint()
+    _, checkpoint = tracker.try_split(0)
     self.assertTrue(OffsetRange(100, 161), tracker.current_restriction())
     self.assertTrue(OffsetRange(161, 200), checkpoint)
 
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 2ffe432..37e05bf 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -42,6 +42,8 @@ cdef class MethodWrapper(object):
   cdef object key_arg_name
   cdef object restriction_provider
   cdef object restriction_provider_arg_name
+  cdef object watermark_estimator
+  cdef object watermark_estimator_arg_name
 
 
 cdef class DoFnSignature(object):
@@ -91,7 +93,9 @@ cdef class PerWindowInvoker(DoFnInvoker):
   cdef bint cache_globally_windowed_args
   cdef object process_method
   cdef bint is_splittable
-  cdef object restriction_tracker
+  cdef object threadsafe_restriction_tracker
+  cdef object watermark_estimator
+  cdef object watermark_estimator_param
   cdef WindowedValue current_windowed_value
   cdef bint is_key_param_required
 
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 3e14f3b..8632cfd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 # cython: profile=True
 
 """Worker operations executor.
@@ -167,6 +166,8 @@ class MethodWrapper(object):
     self.key_arg_name = None
     self.restriction_provider = None
     self.restriction_provider_arg_name = None
+    self.watermark_estimator = None
+    self.watermark_estimator_arg_name = None
 
     for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
       if isinstance(v, core.DoFn.StateParam):
@@ -184,6 +185,9 @@ class MethodWrapper(object):
       elif isinstance(v, core.DoFn.RestrictionParam):
         self.restriction_provider = v.restriction_provider
         self.restriction_provider_arg_name = kw
+      elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
+        self.watermark_estimator = v.watermark_estimator
+        self.watermark_estimator_arg_name = kw
 
   def invoke_timer_callback(self,
                             user_state_context,
@@ -264,6 +268,9 @@ class DoFnSignature(object):
   def get_restriction_provider(self):
     return self.process_method.restriction_provider
 
+  def get_watermark_estimator(self):
+    return self.process_method.watermark_estimator
+
   def _validate(self):
     self._validate_process()
     self._validate_bundle_method(self.start_bundle_method)
@@ -458,7 +465,11 @@ class PerWindowInvoker(DoFnInvoker):
         signature.is_stateful_dofn())
     self.user_state_context = user_state_context
     self.is_splittable = signature.is_splittable_dofn()
-    self.restriction_tracker = None
+    self.watermark_estimator = self.signature.get_watermark_estimator()
+    self.watermark_estimator_param = (
+        self.signature.process_method.watermark_estimator_arg_name
+        if self.watermark_estimator else None)
+    self.threadsafe_restriction_tracker = None
     self.current_windowed_value = None
     self.bundle_finalizer_param = bundle_finalizer_param
     self.is_key_param_required = False
@@ -569,15 +580,24 @@ class PerWindowInvoker(DoFnInvoker):
         raise ValueError(
             'A RestrictionTracker %r was provided but DoFn does not have a '
             'RestrictionTrackerParam defined' % restriction_tracker)
-      additional_kwargs[restriction_tracker_param] = restriction_tracker
+      from apache_beam.io import iobase
+      self.threadsafe_restriction_tracker = iobase.ThreadsafeRestrictionTracker(
+          restriction_tracker)
+      additional_kwargs[restriction_tracker_param] = (
+          iobase.RestrictionTrackerView(self.threadsafe_restriction_tracker))
+
+      if self.watermark_estimator:
+        # The watermark estimator needs to be reset for every element.
+        self.watermark_estimator.reset()
+        additional_kwargs[self.watermark_estimator_param] = (
+            self.watermark_estimator)
       try:
         self.current_windowed_value = windowed_value
-        self.restriction_tracker = restriction_tracker
         return self._invoke_process_per_window(
             windowed_value, additional_args, additional_kwargs,
             output_processor)
       finally:
-        self.restriction_tracker = None
+        self.threadsafe_restriction_tracker = None
         self.current_windowed_value = windowed_value
 
     elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
@@ -664,24 +684,34 @@ class PerWindowInvoker(DoFnInvoker):
           windowed_value, self.process_method(*args_for_process))
 
     if self.is_splittable:
-      deferred_status = self.restriction_tracker.deferred_status()
+      # TODO: Consider calling check_done right after SDF.Process() finishes.
+      # In order to do this, we need to know that the currently invoked dofn is
+      # ProcessSizedElementAndRestriction.
+      self.threadsafe_restriction_tracker.check_done()
+      deferred_status = self.threadsafe_restriction_tracker.deferred_status()
+      output_watermark = None
+      if self.watermark_estimator:
+        output_watermark = self.watermark_estimator.current_watermark()
       if deferred_status:
         deferred_restriction, deferred_watermark = deferred_status
         element = windowed_value.value
         size = self.signature.get_restriction_provider().restriction_size(
             element, deferred_restriction)
-        return (
+        return ((
             windowed_value.with_value(((element, deferred_restriction), size)),
-            deferred_watermark)
+            output_watermark), deferred_watermark)
 
   def try_split(self, fraction):
-    restriction_tracker = self.restriction_tracker
+    restriction_tracker = self.threadsafe_restriction_tracker
     current_windowed_value = self.current_windowed_value
     if restriction_tracker and current_windowed_value:
       # Temporary workaround for [BEAM-7473]: get current_watermark before
       # split, in case watermark gets advanced before getting split results.
       # In worst case, current_watermark is always stale, which is ok.
-      current_watermark = restriction_tracker.current_watermark()
+      if self.watermark_estimator:
+        current_watermark = self.watermark_estimator.current_watermark()
+      else:
+        current_watermark = None
       split = restriction_tracker.try_split(fraction)
       if split:
         primary, residual = split
@@ -690,15 +720,13 @@ class PerWindowInvoker(DoFnInvoker):
         primary_size = restriction_provider.restriction_size(element, primary)
         residual_size = restriction_provider.restriction_size(element, residual)
         return (
-            (self.current_windowed_value.with_value(
-                ((element, primary), primary_size)),
-             None),
-            (self.current_windowed_value.with_value(
-                ((element, residual), residual_size)),
-             current_watermark))
+            ((self.current_windowed_value.with_value((
+                (element, primary), primary_size)), None), None),
+            ((self.current_windowed_value.with_value((
+                (element, residual), residual_size)), current_watermark), None))
 
   def current_element_progress(self):
-    restriction_tracker = self.restriction_tracker
+    restriction_tracker = self.threadsafe_restriction_tracker
     if restriction_tracker:
       return restriction_tracker.current_progress()
 
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index 946ef34..fd04d4c 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -51,6 +51,9 @@ class ReadFilesProvider(RestrictionProvider):
   def create_tracker(self, restriction):
     return OffsetRestrictionTracker(restriction)
 
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
 
 class ReadFiles(DoFn):
 
@@ -63,12 +66,11 @@ class ReadFiles(DoFn):
       restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
       *args, **kwargs):
     file_name = element
-    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
 
     with open(file_name, 'rb') as file:
-      pos = restriction_tracker.start_position()
-      if restriction_tracker.start_position() > 0:
-        file.seek(restriction_tracker.start_position() - 1)
+      pos = restriction_tracker.current_restriction().start
+      if restriction_tracker.current_restriction().start > 0:
+        file.seek(restriction_tracker.current_restriction().start - 1)
         line = file.readline()
         pos = pos - 1 + len(line)
 
@@ -104,6 +106,9 @@ class ExpandStringsProvider(RestrictionProvider):
   def split(self, element, restriction):
     return [restriction,]
 
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
 
 class ExpandStrings(DoFn):
 
@@ -118,10 +123,9 @@ class ExpandStrings(DoFn):
     side.extend(side1)
     side.extend(side2)
     side.extend(side3)
-    assert isinstance(restriction_tracker, OffsetRestrictionTracker)
     side = list(side)
-    for i in range(restriction_tracker.start_position(),
-                   restriction_tracker.stop_position()):
+    for i in range(restriction_tracker.current_restriction().start,
+                   restriction_tracker.current_restriction().stop):
       if restriction_tracker.try_claim(i):
         if not side:
           yield (
diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index e97d65e..377ceb7 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -319,6 +319,9 @@ if __name__ == '__main__':
           line = f.readline()
       self.assertSetEqual(lines_actual, lines_expected)
 
+    def test_sdf_with_watermark_tracking(self):
+      raise unittest.SkipTest("BEAM-2939")
+
     def test_sdf_with_sdf_initiated_checkpointing(self):
       raise unittest.SkipTest("BEAM-2939")
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 2204a24..b7929cb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -41,6 +41,7 @@ from tenacity import retry
 from tenacity import stop_after_attempt
 
 import apache_beam as beam
+from apache_beam.io import iobase
 from apache_beam.io import restriction_trackers
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricKey
@@ -56,9 +57,11 @@ from apache_beam.testing.test_stream import TestStream
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.tools import utils
+from apache_beam.transforms import core
 from apache_beam.transforms import environments
 from apache_beam.transforms import userstate
 from apache_beam.transforms import window
+from apache_beam.utils import timestamp
 
 if statesampler.FAST_SAMPLER:
   DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
@@ -423,20 +426,67 @@ class FnApiRunnerTest(unittest.TestCase):
       assert_that(actual, is_buffered_correctly)
 
   def test_sdf(self):
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider())):
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
+        while restriction_tracker.try_claim(cur):
+          yield element[cur]
+          cur += 1
 
+    with self.create_pipeline() as p:
+      data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+      actual = (
+          p
+          | beam.Create(data)
+          | beam.ParDo(ExpandingStringsDoFn()))
+      assert_that(actual, equal_to(list(''.join(data))))
+
+  def test_sdf_with_check_done_failed(self):
     class ExpandingStringsDoFn(beam.DoFn):
       def process(
           self,
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(
-            restriction_tracker,
-            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
-        cur = restriction_tracker.start_position()
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           yield element[cur]
           cur += 1
+          return
+    with self.assertRaises(Exception):
+      with self.create_pipeline() as p:
+        data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+        _ = (
+            p
+            | beam.Create(data)
+            | beam.ParDo(ExpandingStringsDoFn()))
+
+  def test_sdf_with_watermark_tracking(self):
+
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider()),
+          watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
+              core.WatermarkEstimator())):
+        cur = restriction_tracker.current_restriction().start
+        start = cur
+        while restriction_tracker.try_claim(cur):
+          watermark_estimator.set_watermark(timestamp.Timestamp(micros=cur))
+          assert watermark_estimator.current_watermark().micros == start
+          yield element[cur]
+          if cur % 2 == 1:
+            restriction_tracker.defer_remainder(timestamp.Duration(micros=5))
+            return
+          cur += 1
 
     with self.create_pipeline() as p:
       data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
@@ -456,10 +506,8 @@ class FnApiRunnerTest(unittest.TestCase):
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(
               ExpandStringsProvider())):
-        assert isinstance(
-            restriction_tracker,
-            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
-        cur = restriction_tracker.start_position()
+        assert isinstance(restriction_tracker, iobase.RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           counter.inc()
           yield element[cur]
@@ -1123,6 +1171,9 @@ class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
 
@@ -1142,6 +1193,9 @@ class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
 
@@ -1172,6 +1226,9 @@ class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest):
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_sdf_with_watermark_tracking(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
 
 class FnApiRunnerSplitTest(unittest.TestCase):
 
@@ -1340,7 +1397,7 @@ class FnApiRunnerSplitTest(unittest.TestCase):
           element,
           restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
         to_emit = []
-        cur = restriction_tracker.start_position()
+        cur = restriction_tracker.current_restriction().start
         while restriction_tracker.try_claim(cur):
           to_emit.append((element, cur))
           element_counter.increment()
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 8439c8f..b3440df 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -32,6 +32,7 @@ from builtins import next
 from builtins import object
 
 from future.utils import itervalues
+from google.protobuf import duration_pb2
 from google.protobuf import timestamp_pb2
 
 import apache_beam as beam
@@ -704,8 +705,7 @@ class BundleProcessor(object):
               ) = split
               if element_primary:
                 split_response.primary_roots.add().CopyFrom(
-                    self.delayed_bundle_application(
-                        *element_primary).application)
+                    self.bundle_application(*element_primary))
               if element_residual:
                 split_response.residual_roots.add().CopyFrom(
                     self.delayed_bundle_application(*element_residual))
@@ -718,22 +718,39 @@ class BundleProcessor(object):
     return split_response
 
   def delayed_bundle_application(self, op, deferred_remainder):
-    transform_id, main_input_tag, main_input_coder, outputs = op.input_info
     # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
-    element_and_restriction, watermark = deferred_remainder
-    if watermark:
-      proto_watermark = timestamp_pb2.Timestamp()
-      proto_watermark.FromMicroseconds(watermark.micros)
-      output_watermarks = {output: proto_watermark for output in outputs}
+    ((element_and_restriction, output_watermark),
+     deferred_watermark) = deferred_remainder
+    if deferred_watermark:
+      assert isinstance(deferred_watermark, timestamp.Duration)
+      proto_deferred_watermark = duration_pb2.Duration()
+      proto_deferred_watermark.FromMicroseconds(deferred_watermark.micros)
     else:
-      output_watermarks = None
+      proto_deferred_watermark = None
     return beam_fn_api_pb2.DelayedBundleApplication(
-        application=beam_fn_api_pb2.BundleApplication(
-            transform_id=transform_id,
-            input_id=main_input_tag,
-            output_watermarks=output_watermarks,
-            element=main_input_coder.get_impl().encode_nested(
-                element_and_restriction)))
+        requested_time_delay=proto_deferred_watermark,
+        application=self.construct_bundle_application(
+            op, output_watermark, element_and_restriction))
+
+  def bundle_application(self, op, primary):
+    ((element_and_restriction, output_watermark),
+     _) = primary
+    return self.construct_bundle_application(
+        op, output_watermark, element_and_restriction)
+
+  def construct_bundle_application(self, op, output_watermark, element):
+    transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+    if output_watermark:
+      proto_output_watermark = timestamp_pb2.Timestamp()
+      proto_output_watermark.FromMicroseconds(output_watermark.micros)
+      output_watermarks = {output: proto_output_watermark for output in outputs}
+    else:
+      output_watermarks = None
+    return beam_fn_api_pb2.BundleApplication(
+        transform_id=transform_id,
+        input_id=main_input_tag,
+        output_watermarks=output_watermarks,
+        element=main_input_coder.get_impl().encode_nested(element))
 
   def metrics(self):
     # DEPRECATED
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index 50740ba..fbef112 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -523,7 +523,7 @@ class SyntheticSDFAsSource(beam.DoFn):
       element,
       restriction_tracker=beam.DoFn.RestrictionParam(
           SyntheticSDFSourceRestrictionProvider())):
-    cur = restriction_tracker.start_position()
+    cur = restriction_tracker.current_restriction().start
     while restriction_tracker.try_claim(cur):
       r = np.random.RandomState(cur)
       time.sleep(element['sleep_per_input_record_sec'])
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 148caae..06fd201 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -63,6 +63,7 @@ from apache_beam.typehints.decorators import get_signature
 from apache_beam.typehints.decorators import get_type_hints
 from apache_beam.typehints.trivial_inference import element_type
 from apache_beam.typehints.typehints import is_consistent_with
+from apache_beam.utils import timestamp
 from apache_beam.utils import urns
 
 try:
@@ -91,7 +92,8 @@ __all__ = [
     'Flatten',
     'Create',
     'Impulse',
-    'RestrictionProvider'
+    'RestrictionProvider',
+    'WatermarkEstimator'
     ]
 
 # Type variables
@@ -242,6 +244,8 @@ class RestrictionProvider(object):
   def create_tracker(self, restriction):
     """Produces a new ``RestrictionTracker`` for the given restriction.
 
+    This API is required to be implemented.
+
     Args:
       restriction: an object that defines a restriction as identified by a
         Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``.
@@ -252,7 +256,10 @@ class RestrictionProvider(object):
     raise NotImplementedError
 
   def initial_restriction(self, element):
-    """Produces an initial restriction for the given element."""
+    """Produces an initial restriction for the given element.
+
+    This API is required to be implemented.
+    """
     raise NotImplementedError
 
   def split(self, element, restriction):
@@ -262,6 +269,9 @@ class RestrictionProvider(object):
     reading input element for each of the returned restrictions should be the
     same as the total set of elements produced by reading the input element for
     the input restriction.
+
+    This API is optional if ``split_and_size`` has been implemented.
+
     """
     yield restriction
 
@@ -281,11 +291,16 @@ class RestrictionProvider(object):
 
     By default, asks a newly-created restriction tracker for the default size
     of the restriction.
+
+    This API is required to be implemented.
     """
-    return self.create_tracker(restriction).default_size()
+    raise NotImplementedError
 
   def split_and_size(self, element, restriction):
     """Like split, but also does sizing, returning (restriction, size) pairs.
+
+    This API is optional if ``split`` and ``restriction_size`` have been
+    implemented.
     """
     for part in self.split(element, restriction):
       yield part, self.restriction_size(element, part)
@@ -379,6 +394,43 @@ class RunnerAPIPTransformHolder(PTransform):
     return None
 
 
+class WatermarkEstimator(object):
+  """A WatermarkEstimator which is used for tracking output_watermark in a
+  DoFn.process(), typically tracking per <element, restriction> pair in SDF in
+  streaming.
+
+  There are 3 APIs in this class: set_watermark, current_watermark and reset
+  with default implementations.
+
+  TODO(BEAM-8537): Create WatermarkEstimatorProvider to support different types.
+  """
+  def __init__(self):
+    self._watermark = None
+
+  def set_watermark(self, watermark):
+    """Update tracking output_watermark with latest output_watermark.
+    This function is called inside an SDF.Process() to track the watermark of
+    output element.
+
+    Args:
+      watermark: the `timestamp.Timestamp` of current output element.
+    """
+    if not isinstance(watermark, timestamp.Timestamp):
+      raise ValueError('watermark should be a object of timestamp.Timestamp')
+    if self._watermark is None:
+      self._watermark = watermark
+    else:
+      self._watermark = min(self._watermark, watermark)
+
+  def current_watermark(self):
+    """Get current output_watermark. This function is called by system."""
+    return self._watermark
+
+  def reset(self):
+    """ Reset current tracking watermark to None."""
+    self._watermark = None
+
+
 class _DoFnParam(object):
   """DoFn parameter."""
 
@@ -459,6 +511,17 @@ class _BundleFinalizerParam(_DoFnParam):
     del self._callbacks[:]
 
 
+class _WatermarkEstimatorParam(_DoFnParam):
+  """WatermarkEstimator DoFn parameter."""
+
+  def __init__(self, watermark_estimator):
+    if not isinstance(watermark_estimator, WatermarkEstimator):
+      raise ValueError('DoFn.WatermarkEstimatorParam expected'
+                       'WatermarkEstimator object.')
+    self.watermark_estimator = watermark_estimator
+    self.param_id = 'WatermarkEstimator'
+
+
 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   """A function object used by a transform with custom processing.
 
@@ -477,7 +540,7 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   TimestampParam = _DoFnParam('TimestampParam')
   WindowParam = _DoFnParam('WindowParam')
   PaneInfoParam = _DoFnParam('PaneInfoParam')
-  WatermarkReporterParam = _DoFnParam('WatermarkReporterParam')
+  WatermarkEstimatorParam = _WatermarkEstimatorParam
   BundleFinalizerParam = _BundleFinalizerParam
   KeyParam = _DoFnParam('KeyParam')
 
@@ -489,7 +552,7 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   TimerParam = _TimerDoFnParam
 
   DoFnProcessParams = [ElementParam, SideInputParam, TimestampParam,
-                       WindowParam, WatermarkReporterParam, PaneInfoParam,
+                       WindowParam, WatermarkEstimatorParam, PaneInfoParam,
                        BundleFinalizerParam, KeyParam, StateParam, TimerParam]
 
   RestrictionParam = _RestrictionDoFnParam
@@ -522,7 +585,7 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
     ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
     provided here to allow treatment as a Splittable ``DoFn``. The restriction
     tracker will be derived from the restriction provider in the parameter.
-    ``DoFn.WatermarkReporterParam``: a function that can be used to report
+    ``DoFn.WatermarkEstimatorParam``: an object that can be used to track
     output watermark of Splittable ``DoFn`` implementations.
 
     Args:
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
new file mode 100644
index 0000000..1a27bd2
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for core module."""
+
+from __future__ import absolute_import
+
+import unittest
+
+from apache_beam.transforms.core import WatermarkEstimator
+from apache_beam.utils.timestamp import Timestamp
+
+
+class WatermarkEstimatorTest(unittest.TestCase):
+
+  def test_set_watermark(self):
+    watermark_estimator = WatermarkEstimator()
+    self.assertEqual(watermark_estimator.current_watermark(), None)
+    # set_watermark should only accept timestamp.Timestamp.
+    with self.assertRaises(ValueError):
+      watermark_estimator.set_watermark(0)
+
+    # watermark_estimator should always keep minimal timestamp.
+    watermark_estimator.set_watermark(Timestamp(100))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.set_watermark(Timestamp(150))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.set_watermark(Timestamp(50))
+    self.assertEqual(watermark_estimator.current_watermark(), 50)
+
+  def test_reset(self):
+    watermark_estimator = WatermarkEstimator()
+    watermark_estimator.set_watermark(Timestamp(100))
+    self.assertEqual(watermark_estimator.current_watermark(), 100)
+    watermark_estimator.reset()
+    self.assertEqual(watermark_estimator.current_watermark(), None)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/utils/timestamp.py b/sdks/python/apache_beam/utils/timestamp.py
index 9bccdfd..a3f3abf 100644
--- a/sdks/python/apache_beam/utils/timestamp.py
+++ b/sdks/python/apache_beam/utils/timestamp.py
@@ -25,6 +25,7 @@ from __future__ import division
 
 import datetime
 import functools
+import time
 from builtins import object
 
 import dateutil.parser
@@ -76,6 +77,10 @@ class Timestamp(object):
     return Timestamp(seconds)
 
   @staticmethod
+  def now():
+    return Timestamp(seconds=time.time())
+
+  @staticmethod
   def _epoch_datetime_utc():
     return datetime.datetime.fromtimestamp(0, pytz.utc)
 
@@ -173,6 +178,8 @@ class Timestamp(object):
     return self + other
 
   def __sub__(self, other):
+    if isinstance(other, Timestamp):
+      return Duration(micros=self.micros - other.micros)
     other = Duration.of(other)
     return Timestamp(micros=self.micros - other.micros)
 
diff --git a/sdks/python/apache_beam/utils/timestamp_test.py b/sdks/python/apache_beam/utils/timestamp_test.py
index d26d561..2a4d454 100644
--- a/sdks/python/apache_beam/utils/timestamp_test.py
+++ b/sdks/python/apache_beam/utils/timestamp_test.py
@@ -100,6 +100,7 @@ class TimestampTest(unittest.TestCase):
     self.assertEqual(Timestamp(123) - Duration(456), -333)
     self.assertEqual(Timestamp(1230) % 456, 318)
     self.assertEqual(Timestamp(1230) % Duration(456), 318)
+    self.assertEqual(Timestamp(123) - Timestamp(100), 23)
 
     # Check that direct comparison of Timestamp and Duration is allowed.
     self.assertTrue(Duration(123) == Timestamp(123))
@@ -116,6 +117,7 @@ class TimestampTest(unittest.TestCase):
     self.assertEqual((Timestamp(123) - Duration(456)).__class__, Timestamp)
     self.assertEqual((Timestamp(1230) % 456).__class__, Duration)
     self.assertEqual((Timestamp(1230) % Duration(456)).__class__, Duration)
+    self.assertEqual((Timestamp(123) - Timestamp(100)).__class__, Duration)
 
     # Unsupported operations.
     with self.assertRaises(TypeError):
@@ -159,6 +161,10 @@ class TimestampTest(unittest.TestCase):
     self.assertEqual('Timestamp(-999999999)',
                      str(Timestamp(-999999999)))
 
+  def test_now(self):
+    now = Timestamp.now()
+    self.assertTrue(isinstance(now, Timestamp))
+
 
 class DurationTest(unittest.TestCase):
 
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index e3794ba..bc77fd8 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -183,6 +183,7 @@ ignore_identifiers = [
   '_TimerDoFnParam',
   '_BundleFinalizerParam',
   '_RestrictionDoFnParam',
+  '_WatermarkEstimatorParam',
 
   # Sphinx cannot find this py:class reference target
   'typing.Generic',