Posted to commits@beam.apache.org by bo...@apache.org on 2020/07/18 23:18:10 UTC

[beam] branch master updated: Insert TruncateSizedRestriction when pipeline starts to drain.

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4decedd  Insert TruncateSizedRestriction when pipeline starts to drain.
     new f841d60  Merge pull request #12289 from boyuanzz/drain_py
4decedd is described below

commit 4decedd9cb6e8e4eef3a41a349a44466eea50f06
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Mon Jun 15 10:21:33 2020 -0700

    Insert TruncateSizedRestriction when pipeline starts to drain.
---
 sdks/python/apache_beam/io/iobase.py               |  20 ++
 sdks/python/apache_beam/io/restriction_trackers.py |   3 +
 sdks/python/apache_beam/runners/common.pxd         |   1 +
 sdks/python/apache_beam/runners/common.py          |   8 +
 sdks/python/apache_beam/runners/common_test.py     |  15 ++
 .../runners/portability/fn_api_runner/execution.py |  14 ++
 .../runners/portability/fn_api_runner/fn_runner.py |   8 +-
 .../portability/fn_api_runner/fn_runner_test.py    | 214 +++++++++++++++++----
 .../portability/fn_api_runner/translations.py      |  66 +++++--
 .../runners/portability/portable_runner_test.py    |  14 +-
 sdks/python/apache_beam/runners/sdf_utils.py       |   6 +
 .../apache_beam/runners/worker/bundle_processor.py |  37 +++-
 .../apache_beam/runners/worker/operations.pxd      |   2 +
 .../apache_beam/runners/worker/operations.py       |  14 ++
 sdks/python/apache_beam/transforms/core.py         |  28 +++
 sdks/python/apache_beam/transforms/environments.py |   2 +
 .../apache_beam/transforms/environments_test.py    |   4 +
 17 files changed, 400 insertions(+), 56 deletions(-)

diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 95cc568..c23f358 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -1244,6 +1244,23 @@ class RestrictionTracker(object):
     """
     raise NotImplementedError
 
+  def is_bounded(self):
+    """Returns whether the amount of work represented by the current restriction
+    is bounded.
+
+    The boundedness of the restriction is used to determine the default
+    behavior for truncating restrictions when a pipeline is being
+    `drained <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#>`_.  # pylint: disable=line-too-long
+    If the restriction is bounded, the entire restriction will be processed;
+    otherwise, the restriction will be processed until a checkpoint is possible.
+
+    This API is required to be implemented.
+
+    Returns: ``True`` if the restriction represents a finite amount of work.
+    Otherwise, returns ``False``.
+    """
+    raise NotImplementedError
+
 
 class WatermarkEstimator(object):
   """A WatermarkEstimator which is used for estimating output_watermark based on
@@ -1442,6 +1459,9 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
     def check_done(self):
       return self.restriction.range_tracker().fraction_consumed() >= 1.0
 
+    def is_bounded(self):
+      return True
+
   class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
     """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
     def __init__(self, source, desired_chunk_size=None):
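
For trackers whose remaining work never completes on its own, is_bounded() can
simply return False; the default truncate() then drops the restriction when
draining. A minimal sketch, mirroring the UnboundedOffsetRestrictionTracker
added to the fn_runner tests further below:

from apache_beam.io import restriction_trackers

class UnboundedOffsetRestrictionTracker(
    restriction_trackers.OffsetRestrictionTracker):
  def is_bounded(self):
    # With the default RestrictionProvider.truncate(), an unbounded
    # restriction is discarded entirely when the pipeline drains.
    return False
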
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 9a53cc0..2420c0b 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -161,3 +161,6 @@ class OffsetRestrictionTracker(RestrictionTracker):
           self._checkpointed = True
         self._range, residual_range = self._range.split_at(split_point)
         return self._range, residual_range
+
+  def is_bounded(self):
+    return True
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index e83ffe1..05f7a99 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -44,6 +44,7 @@ cdef class MethodWrapper(object):
   cdef object restriction_provider_arg_name
   cdef object watermark_estimator_provider
   cdef object watermark_estimator_provider_arg_name
+  cdef bint unbounded_per_element
 
 
 cdef class DoFnSignature(object):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 6ba7771..1c4a25bd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -199,6 +199,11 @@ class MethodWrapper(object):
     self.watermark_estimator_provider = None
     self.watermark_estimator_provider_arg_name = None
 
+    if hasattr(self.method_value, 'unbounded_per_element'):
+      self.unbounded_per_element = True
+    else:
+      self.unbounded_per_element = False
+
     for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
       if isinstance(v, core.DoFn.StateParam):
         self.state_args_to_replace[kw] = v.state_spec
@@ -307,6 +312,9 @@ class DoFnSignature(object):
     # type: () -> WatermarkEstimatorProvider
     return self.process_method.watermark_estimator_provider
 
+  def is_unbounded_per_element(self):
+    return self.process_method.unbounded_per_element
+
   def _validate(self):
     # type: () -> None
     self._validate_process()
diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py
index 1e0120d..c9860e6 100644
--- a/sdks/python/apache_beam/runners/common_test.py
+++ b/sdks/python/apache_beam/runners/common_test.py
@@ -74,6 +74,21 @@ class DoFnSignatureTest(unittest.TestCase):
     with self.assertRaises(ValueError):
       DoFnSignature(MyDoFn())
 
+  def test_unbounded_element_process_fn(self):
+    class UnboundedDoFn(DoFn):
+      @DoFn.unbounded_per_element()
+      def process(self, element):
+        pass
+
+    class BoundedDoFn(DoFn):
+      def process(self, element):
+        pass
+
+    signature = DoFnSignature(UnboundedDoFn())
+    self.assertTrue(signature.is_unbounded_per_element())
+    signature = DoFnSignature(BoundedDoFn())
+    self.assertFalse(signature.is_unbounded_per_element())
+
 
 class DoFnProcessTest(unittest.TestCase):
   # pylint: disable=expression-not-assigned
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 251a696..5b8e91c 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -683,7 +683,21 @@ class BundleContextManager(object):
     input_pcoll = self.process_bundle_descriptor.transforms[
         transform_id].inputs[input_id]
     for read_id, proto in self.process_bundle_descriptor.transforms.items():
+      # The GrpcRead is followed by the SDF/Process.
       if (proto.spec.urn == bundle_processor.DATA_INPUT_URN and
           input_pcoll in proto.outputs.values()):
         return read_id
+      # The GrpcRead is followed by the SDF/Truncate -> SDF/Process.
+      if (proto.spec.urn ==
+          common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn and
+          input_pcoll in proto.outputs.values()):
+        read_input = list(
+            self.process_bundle_descriptor.transforms[read_id].inputs.values()
+        )[0]
+        for (grpc_read,
+             transform_proto) in self.process_bundle_descriptor.transforms.items():  # pylint: disable=line-too-long
+          if (transform_proto.spec.urn == bundle_processor.DATA_INPUT_URN and
+              read_input in transform_proto.outputs.values()):
+            return grpc_read
+
     raise RuntimeError('No IO transform feeds %s' % transform_id)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index bbbab4d..e53cafe 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -91,7 +91,8 @@ class FnApiRunner(runner.PipelineRunner):
       bundle_repeat=0,
       use_state_iterables=False,
       provision_info=None,  # type: Optional[ExtendedProvisionInfo]
-      progress_request_frequency=None):
+      progress_request_frequency=None,
+      is_drain=False):
     # type: (...) -> None
 
     """Creates a new Fn API Runner.
@@ -105,6 +106,7 @@ class FnApiRunner(runner.PipelineRunner):
       provision_info: provisioning info to make available to workers, or None
       progress_request_frequency: The frequency (in seconds) that the runner
           waits before requesting progress from the SDK.
+      is_drain: whether to expand the SDF graph in drain mode.
     """
     super(FnApiRunner, self).__init__()
     self._default_environment = (
@@ -114,6 +116,7 @@ class FnApiRunner(runner.PipelineRunner):
     self._progress_frequency = progress_request_frequency
     self._profiler_factory = None  # type: Optional[Callable[..., profiler.Profile]]
     self._use_state_iterables = use_state_iterables
+    self._is_drain = is_drain
     self._provision_info = provision_info or ExtendedProvisionInfo(
         beam_provision_api_pb2.ProvisionInfo(
             retrieval_token='unused-retrieval-token'))
@@ -304,7 +307,8 @@ class FnApiRunner(runner.PipelineRunner):
             common_urns.primitives.FLATTEN.urn,
             common_urns.primitives.GROUP_BY_KEY.urn
         ]),
-        use_state_iterables=self._use_state_iterables)
+        use_state_iterables=self._use_state_iterables,
+        is_drain=self._is_drain)
 
   def run_stages(self,
                  stage_context,  # type: translations.TransformContext
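
For context, a minimal self-contained sketch of exercising the new is_drain
option end to end, based on the tests added to fn_runner_test.py below (the
provider and DoFn names here are illustrative only):

import apache_beam as beam
from apache_beam.io import restriction_trackers
from apache_beam.runners.portability import fn_api_runner
from apache_beam.testing.util import assert_that, equal_to


class OffsetRangeProvider(beam.transforms.core.RestrictionProvider):
  def initial_restriction(self, element):
    return restriction_trackers.OffsetRange(0, element)

  def create_tracker(self, restriction):
    # OffsetRestrictionTracker.is_bounded() now returns True, so the default
    # truncate() keeps the whole restriction when draining.
    return restriction_trackers.OffsetRestrictionTracker(restriction)

  def split(self, element, restriction):
    return [restriction]

  def restriction_size(self, element, restriction):
    return restriction.size()


class SimpleSDF(beam.DoFn):
  def process(
      self,
      element,
      restriction_tracker=beam.DoFn.RestrictionParam(OffsetRangeProvider())):
    cur = restriction_tracker.current_restriction().start
    while restriction_tracker.try_claim(cur):
      yield cur
      cur += 1


# is_drain=True makes the runner insert a TruncateSizedRestriction transform
# between splitting and processing; the bounded restriction is kept whole.
with beam.Pipeline(runner=fn_api_runner.FnApiRunner(is_drain=True)) as p:
  actual = p | beam.Create([10]) | beam.ParDo(SimpleSDF())
  assert_that(actual, equal_to(list(range(10))))
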
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
index 95e3b5f..eca6491 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
@@ -98,8 +98,8 @@ def has_urn_and_labels(mi, urn, labels):
 
 
 class FnApiRunnerTest(unittest.TestCase):
-  def create_pipeline(self):
-    return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
+  def create_pipeline(self, is_drain=False):
+    return beam.Pipeline(runner=fn_api_runner.FnApiRunner(is_drain=is_drain))
 
   def test_assert_that(self):
     # TODO: figure out a way for fn_api_runner to parse and raise the
@@ -559,8 +559,7 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
       assert_that(actual, equal_to(list(''.join(data))))
 
-  def test_sdf_with_sdf_initiated_checkpointing(self):
-
+  def run_sdf_initiated_checkpointing(self, is_drain=False):
     counter = beam.metrics.Metrics.counter('ns', 'my_counter')
 
     class ExpandStringsDoFn(beam.DoFn):
@@ -579,7 +578,7 @@ class FnApiRunnerTest(unittest.TestCase):
             return
           cur += 1
 
-    with self.create_pipeline() as p:
+    with self.create_pipeline(is_drain=is_drain) as p:
       data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
       actual = (p | beam.Create(data) | beam.ParDo(ExpandStringsDoFn()))
 
@@ -591,6 +590,63 @@ class FnApiRunnerTest(unittest.TestCase):
       self.assertEqual(1, len(counters))
       self.assertEqual(counters[0].committed, len(''.join(data)))
 
+  def test_sdf_with_sdf_initiated_checkpointing(self):
+    self.run_sdf_initiated_checkpointing(is_drain=False)
+
+  def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+    self.run_sdf_initiated_checkpointing(is_drain=True)
+
+  def test_sdf_default_truncate_when_bounded(self):
+    class SimpleSDF(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              OffsetRangeProvider(use_bounded_offset_range=True))):
+        assert isinstance(restriction_tracker, RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
+        while restriction_tracker.try_claim(cur):
+          yield cur
+          cur += 1
+
+    with self.create_pipeline(is_drain=True) as p:
+      actual = (p | beam.Create([10]) | beam.ParDo(SimpleSDF()))
+      assert_that(actual, equal_to(range(10)))
+
+  def test_sdf_default_truncate_when_unbounded(self):
+    class SimpleSDF(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              OffsetRangeProvider(use_bounded_offset_range=False))):
+        assert isinstance(restriction_tracker, RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
+        while restriction_tracker.try_claim(cur):
+          yield cur
+          cur += 1
+
+    with self.create_pipeline(is_drain=True) as p:
+      actual = (p | beam.Create([10]) | beam.ParDo(SimpleSDF()))
+      assert_that(actual, equal_to([]))
+
+  def test_sdf_with_truncate(self):
+    class SimpleSDF(beam.DoFn):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              OffsetRangeProviderWithTruncate())):
+        assert isinstance(restriction_tracker, RestrictionTrackerView)
+        cur = restriction_tracker.current_restriction().start
+        while restriction_tracker.try_claim(cur):
+          yield cur
+          cur += 1
+
+    with self.create_pipeline(is_drain=True) as p:
+      actual = (p | beam.Create([10]) | beam.ParDo(SimpleSDF()))
+      assert_that(actual, equal_to(range(5)))
+
   def test_group_by_key(self):
     with self.create_pipeline() as p:
       res = (
@@ -1240,25 +1296,28 @@ class FnApiRunnerMetricsTest(unittest.TestCase):
 
 
 class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     return beam.Pipeline(
         runner=fn_api_runner.FnApiRunner(
-            default_environment=environments.EmbeddedPythonGrpcEnvironment()))
+            default_environment=environments.EmbeddedPythonGrpcEnvironment(),
+            is_drain=is_drain))
 
 
 class FnApiRunnerTestWithDisabledCaching(FnApiRunnerTest):
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     return beam.Pipeline(
         runner=fn_api_runner.FnApiRunner(
             default_environment=environments.EmbeddedPythonGrpcEnvironment(
-                state_cache_size=0, data_buffer_time_limit_ms=0)))
+                state_cache_size=0, data_buffer_time_limit_ms=0),
+            is_drain=is_drain))
 
 
 class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     pipeline_options = PipelineOptions(direct_num_workers=2)
     p = beam.Pipeline(
-        runner=fn_api_runner.FnApiRunner(), options=pipeline_options)
+        runner=fn_api_runner.FnApiRunner(is_drain=is_drain),
+        options=pipeline_options)
     #TODO(BEAM-8444): Fix these tests.
     p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
     return p
@@ -1269,16 +1328,20 @@ class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
   def test_sdf_with_watermark_tracking(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
 
 class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     pipeline_options = PipelineOptions(
         direct_num_workers=2, direct_running_mode='multi_threading')
     p = beam.Pipeline(
-        runner=fn_api_runner.FnApiRunner(), options=pipeline_options)
+        runner=fn_api_runner.FnApiRunner(is_drain=is_drain),
+        options=pipeline_options)
     #TODO(BEAM-8444): Fix these tests.
     p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
     return p
@@ -1289,23 +1352,27 @@ class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
   def test_sdf_with_watermark_tracking(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
 
 class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
-  def create_pipeline(self):
-    return beam.Pipeline(runner=fn_api_runner.FnApiRunner(bundle_repeat=3))
+  def create_pipeline(self, is_drain=False):
+    return beam.Pipeline(
+        runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain))
 
   def test_register_finalizations(self):
     raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.")
 
 
 class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest):
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     pipeline_options = PipelineOptions(direct_num_workers=2)
     p = beam.Pipeline(
-        runner=fn_api_runner.FnApiRunner(bundle_repeat=3),
+        runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain),
         options=pipeline_options)
     #TODO(BEAM-8444): Fix these tests.
     p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
@@ -1320,17 +1387,21 @@ class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest):
   def test_sdf_with_sdf_initiated_checkpointing(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
+  def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+    raise unittest.SkipTest("This test is for a single worker only.")
+
   def test_sdf_with_watermark_tracking(self):
     raise unittest.SkipTest("This test is for a single worker only.")
 
 
 class FnApiRunnerSplitTest(unittest.TestCase):
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     # Must be GRPC so we can send data and split requests concurrent
     # to the bundle process request.
     return beam.Pipeline(
         runner=fn_api_runner.FnApiRunner(
-            default_environment=environments.EmbeddedPythonGrpcEnvironment()))
+            default_environment=environments.EmbeddedPythonGrpcEnvironment(),
+            is_drain=is_drain))
 
   def test_checkpoint(self):
     # This split manager will get re-invoked on each smaller split,
@@ -1387,16 +1458,7 @@ class FnApiRunnerSplitTest(unittest.TestCase):
             | beam.Map(lambda x: element_counter.increment() or x))
         assert_that(res, equal_to(elements))
 
-  def test_nosplit_sdf(self):
-    def split_manager(num_elements):
-      yield
-
-    elements = [1, 2, 3]
-    expected_groups = [[(e, k) for k in range(e)] for e in elements]
-    self.run_sdf_split_pipeline(
-        split_manager, elements, ElementCounter(), expected_groups)
-
-  def test_checkpoint_sdf(self):
+  def run_sdf_checkpoint(self, is_drain=False):
     element_counter = ElementCounter()
 
     def split_manager(num_elements):
@@ -1409,13 +1471,17 @@ class FnApiRunnerSplitTest(unittest.TestCase):
         breakpoint.clear()
 
     # Everything should be perfectly split.
+
     elements = [2, 3]
     expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]]
     self.run_sdf_split_pipeline(
-        split_manager, elements, element_counter, expected_groups)
-
-  def test_split_half_sdf(self):
+        split_manager,
+        elements,
+        element_counter,
+        expected_groups,
+        is_drain=is_drain)
 
+  def run_sdf_split_half(self, is_drain=False):
     element_counter = ElementCounter()
     is_first_bundle = [True]  # emulate nonlocal for Python 2
 
@@ -1438,9 +1504,13 @@ class FnApiRunnerSplitTest(unittest.TestCase):
                                                               (4, 2), (4, 3)]]
 
     self.run_sdf_split_pipeline(
-        split_manager, elements, element_counter, expected_groups)
+        split_manager,
+        elements,
+        element_counter,
+        expected_groups,
+        is_drain=is_drain)
 
-  def test_split_crazy_sdf(self, seed=None):
+  def run_split_crazy_sdf(self, seed=None, is_drain=False):
     if seed is None:
       seed = random.randrange(1 << 20)
     r = random.Random(seed)
@@ -1459,13 +1529,46 @@ class FnApiRunnerSplitTest(unittest.TestCase):
 
     try:
       elements = [r.randrange(5, 10) for _ in range(5)]
-      self.run_sdf_split_pipeline(split_manager, elements, element_counter)
+      self.run_sdf_split_pipeline(
+          split_manager, elements, element_counter, is_drain=is_drain)
     except Exception:
       _LOGGER.error('test_split_crazy_sdf.seed = %s', seed)
       raise
 
+  def test_nosplit_sdf(self):
+    def split_manager(num_elements):
+      yield
+
+    elements = [1, 2, 3]
+    expected_groups = [[(e, k) for k in range(e)] for e in elements]
+    self.run_sdf_split_pipeline(
+        split_manager, elements, ElementCounter(), expected_groups)
+
+  def test_checkpoint_sdf(self):
+    self.run_sdf_checkpoint(is_drain=False)
+
+  def test_checkpoint_draining_sdf(self):
+    self.run_sdf_checkpoint(is_drain=True)
+
+  def test_split_half_sdf(self):
+    self.run_sdf_split_half(is_drain=False)
+
+  def test_split_half_draining_sdf(self):
+    self.run_sdf_split_half(is_drain=True)
+
+  def test_split_crazy_sdf(self, seed=None):
+    self.run_split_crazy_sdf(seed=seed, is_drain=False)
+
+  def test_split_crazy_draining_sdf(self, seed=None):
+    self.run_split_crazy_sdf(seed=seed, is_drain=True)
+
   def run_sdf_split_pipeline(
-      self, split_manager, elements, element_counter, expected_groups=None):
+      self,
+      split_manager,
+      elements,
+      element_counter,
+      expected_groups=None,
+      is_drain=False):
     # Define an SDF that for each input x produces [(x, k) for k in range(x)].
 
     class EnumerateProvider(beam.transforms.core.RestrictionProvider):
@@ -1482,6 +1585,9 @@ class FnApiRunnerSplitTest(unittest.TestCase):
       def restriction_size(self, element, restriction):
         return restriction.size()
 
+      def is_bounded(self):
+        return True
+
     class EnumerateSdf(beam.DoFn):
       def process(
           self,
@@ -1499,7 +1605,7 @@ class FnApiRunnerSplitTest(unittest.TestCase):
     expected = [(e, k) for e in elements for k in range(e)]
 
     with fn_runner.split_manager('SDF', split_manager):
-      with self.create_pipeline() as p:
+      with self.create_pipeline(is_drain=is_drain) as p:
         grouped = (
             p
             | beam.Create(elements, reshuffle=False)
@@ -1627,6 +1733,40 @@ class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
     return restriction.size()
 
 
+class UnboundedOffsetRestrictionTracker(
+    restriction_trackers.OffsetRestrictionTracker):
+  def is_bounded(self):
+    return False
+
+
+class OffsetRangeProvider(beam.transforms.core.RestrictionProvider):
+  def __init__(self, use_bounded_offset_range):
+    self.use_bounded_offset_range = use_bounded_offset_range
+
+  def initial_restriction(self, element):
+    return restriction_trackers.OffsetRange(0, element)
+
+  def create_tracker(self, restriction):
+    if self.use_bounded_offset_range:
+      return restriction_trackers.OffsetRestrictionTracker(restriction)
+    return UnboundedOffsetRestrictionTracker(restriction)
+
+  def split(self, element, restriction):
+    return [restriction]
+
+  def restriction_size(self, element, restriction):
+    return restriction.size()
+
+
+class OffsetRangeProviderWithTruncate(OffsetRangeProvider):
+  def __init__(self):
+    super(OffsetRangeProviderWithTruncate, self).__init__(True)
+
+  def truncate(self, element, restriction):
+    return restriction_trackers.OffsetRange(
+        restriction.start, restriction.stop // 2)
+
+
 class FnApiBasedLullLoggingTest(unittest.TestCase):
   def create_pipeline(self):
     return beam.Pipeline(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index 34dcf48..00c440e 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -69,6 +69,7 @@ PAR_DO_URNS = frozenset([
     common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
     common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
     common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
+    common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
 ])
 
 IMPULSE_BUFFER = b'impulse'
@@ -343,13 +344,15 @@ class TransformContext(object):
   def __init__(self,
                components,  # type: beam_runner_api_pb2.Components
                known_runner_urns,  # type: FrozenSet[str]
-               use_state_iterables=False
+               use_state_iterables=False,
+               is_drain=False
               ):
     self.components = components
     self.known_runner_urns = known_runner_urns
     self.runner_only_urns = known_runner_urns - frozenset(
         [common_urns.primitives.FLATTEN.urn])
     self.use_state_iterables = use_state_iterables
+    self.is_drain = is_drain
     # ok to pass None for context because BytesCoder has no components
     coder_proto = coders.BytesCoder().to_runner_api(
         None)  # type: ignore[arg-type]
@@ -547,7 +550,8 @@ def pipeline_from_stages(pipeline_proto,  # type: beam_runner_api_pb2.Pipeline
 def create_and_optimize_stages(pipeline_proto,  # type: beam_runner_api_pb2.Pipeline
                                phases,
                                known_runner_urns,  # type: FrozenSet[str]
-                               use_state_iterables=False
+                               use_state_iterables=False,
+                               is_drain=False
                               ):
   # type: (...) -> Tuple[TransformContext, List[Stage]]
 
@@ -568,7 +572,8 @@ def create_and_optimize_stages(pipeline_proto,  # type: beam_runner_api_pb2.Pipe
   pipeline_context = TransformContext(
       pipeline_proto.components,
       known_runner_urns,
-      use_state_iterables=use_state_iterables)
+      use_state_iterables=use_state_iterables,
+      is_drain=is_drain)
 
   # Initial set of stages are singleton leaf transforms.
   stages = list(
@@ -973,6 +978,7 @@ def expand_sdf(stages, context):
             inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
             outputs={'out': split_pcoll_id})
 
+        reshuffle_stage = None
         if common_urns.composites.RESHUFFLE.urn in context.known_runner_urns:
           reshuffle_pcoll_id = copy_like(
               context.components.pcollections,
@@ -987,25 +993,57 @@ def expand_sdf(stages, context):
               payload=b'',
               inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}),
               outputs={'out': reshuffle_pcoll_id})
-          yield make_stage(stage, reshuffle_transform_id)
+          reshuffle_stage = make_stage(stage, reshuffle_transform_id)
         else:
           reshuffle_pcoll_id = split_pcoll_id
           reshuffle_transform_id = None
 
-        process_transform_id = copy_like(
-            context.components.transforms,
-            transform,
-            unique_name=transform.unique_name + '/Process',
-            urn=common_urns.sdf_components.
-            PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
-            inputs=dict(
-                transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))
+        if context.is_drain:
+          truncate_pcoll_id = copy_like(
+              context.components.pcollections,
+              main_input_id,
+              '_truncate_restriction',
+              coder_id=sized_coder_id)
+          # Lengthprefix the truncate output.
+          context.length_prefix_pcoll_coders(truncate_pcoll_id)
+          truncate_transform_id = copy_like(
+              context.components.transforms,
+              transform,
+              unique_name=transform.unique_name + '/TruncateAndSizeRestriction',
+              urn=common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
+              inputs=dict(
+                  transform.inputs, **{main_input_tag: reshuffle_pcoll_id}),
+              outputs={'out': truncate_pcoll_id})
+          process_transform_id = copy_like(
+              context.components.transforms,
+              transform,
+              unique_name=transform.unique_name + '/Process',
+              urn=common_urns.sdf_components.
+              PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
+              inputs=dict(
+                  transform.inputs, **{main_input_tag: truncate_pcoll_id}))
+        else:
+          process_transform_id = copy_like(
+              context.components.transforms,
+              transform,
+              unique_name=transform.unique_name + '/Process',
+              urn=common_urns.sdf_components.
+              PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
+              inputs=dict(
+                  transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))
 
         yield make_stage(stage, pair_transform_id)
         split_stage = make_stage(stage, split_transform_id)
         yield split_stage
-        yield make_stage(
-            stage, process_transform_id, extra_must_follow=[split_stage])
+        if reshuffle_stage:
+          yield reshuffle_stage
+        if context.is_drain:
+          yield make_stage(
+              stage, truncate_transform_id, extra_must_follow=[split_stage])
+          yield make_stage(stage, process_transform_id)
+        else:
+          yield make_stage(
+              stage, process_transform_id, extra_must_follow=[split_stage])
 
       else:
         yield stage
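
For orientation, the SDF expansion produced here has roughly this shape, in
terms of the sdf_components URNs used above (the Reshuffle appears only when
the runner recognizes it, and the Truncate step only when is_drain=True):

  PAIR_WITH_RESTRICTION
    -> SPLIT_AND_SIZE_RESTRICTIONS
    -> RESHUFFLE (if supported)
    -> TRUNCATE_SIZED_RESTRICTION (drain only)
    -> PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS
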
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 5264f7f..c9a33c1 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -197,7 +197,7 @@ class PortableRunnerTest(fn_runner_test.FnApiRunnerTest):
         'data_buffer_time_limit_ms=1000')
     return options
 
-  def create_pipeline(self):
+  def create_pipeline(self, is_drain=False):
     return beam.Pipeline(self.get_runner(), self.create_options())
 
   def test_pardo_state_with_custom_key_coder(self):
@@ -242,6 +242,18 @@ class PortableRunnerTest(fn_runner_test.FnApiRunnerTest):
 
   # Inherits all other tests from fn_api_runner_test.FnApiRunnerTest
 
+  def test_sdf_default_truncate_when_bounded(self):
+    raise unittest.SkipTest("Portable runners don't support drain yet.")
+
+  def test_sdf_default_truncate_when_unbounded(self):
+    raise unittest.SkipTest("Portable runners don't support drain yet.")
+
+  def test_sdf_with_truncate(self):
+    raise unittest.SkipTest("Portable runners don't support drain yet.")
+
+  def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+    raise unittest.SkipTest("Portable runners don't support drain yet.")
+
 
 @unittest.skip("BEAM-7248")
 class PortableRunnerOptimized(PortableRunnerTest):
diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py
index 1d92c7a..aa91bda 100644
--- a/sdks/python/apache_beam/runners/sdf_utils.py
+++ b/sdks/python/apache_beam/runners/sdf_utils.py
@@ -151,6 +151,9 @@ class ThreadsafeRestrictionTracker(object):
       return self._deferred_residual, self._deferred_timestamp
     return None
 
+  def is_bounded(self):
+    return self._restriction_tracker.is_bounded()
+
 
 class RestrictionTrackerView(object):
   """A DoFn view of thread-safe RestrictionTracker.
@@ -178,6 +181,9 @@ class RestrictionTrackerView(object):
   def defer_remainder(self, deferred_time=None):
     self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
 
+  def is_bounded(self):
+    return self._threadsafe_restriction_tracker.is_bounded()
+
 
 class ThreadsafeWatermarkEstimator(object):
   """A threadsafe wrapper which wraps a WatermarkEstimator with locking
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 3e57046..57a5c6b 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -1079,6 +1079,7 @@ class BundleProcessor(object):
                                   ):
     # type: (...) -> beam_fn_api_pb2.BundleApplication
     transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+
     if output_watermark:
       proto_output_watermark = proto_utils.from_micros(
           timestamp_pb2.Timestamp, output_watermark.micros)
@@ -1426,6 +1427,31 @@ def create_split_and_size_restrictions(*args):
 
 
 @BeamTransformFactory.register_urn(
+    common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
+    beam_runner_api_pb2.ParDoPayload)
+def create_truncate_sized_restriction(*args):
+  class TruncateAndSizeRestriction(beam.DoFn):
+    def __init__(self, fn, restriction_provider, watermark_estimator_provider):
+      self.restriction_provider = restriction_provider
+
+    def process(self, element_restriction, *args, **kwargs):
+      ((element, (restriction, estimator_state)), _) = element_restriction
+      truncated_restriction = self.restriction_provider.truncate(
+          element, restriction)
+      if truncated_restriction:
+        truncated_restriction_size = (
+            self.restriction_provider.restriction_size(
+                element, truncated_restriction))
+        yield ((element, (truncated_restriction, estimator_state)),
+               truncated_restriction_size)
+
+  return _create_sdf_operation(
+      TruncateAndSizeRestriction,
+      *args,
+      operation_cls=operations.SdfTruncateSizedRestrictions)
+
+
+@BeamTransformFactory.register_urn(
     common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
     beam_runner_api_pb2.ParDoPayload)
 def create_process_sized_elements_and_restrictions(
@@ -1448,7 +1474,13 @@ def create_process_sized_elements_and_restrictions(
 
 
 def _create_sdf_operation(
-    proxy_dofn, factory, transform_id, transform_proto, parameter, consumers):
+    proxy_dofn,
+    factory,
+    transform_id,
+    transform_proto,
+    parameter,
+    consumers,
+    operation_cls=operations.DoOperation):
 
   dofn_data = pickler.loads(parameter.do_fn.payload)
   dofn = dofn_data[0]
@@ -1464,7 +1496,8 @@ def _create_sdf_operation(
       transform_proto,
       consumers,
       serialized_fn,
-      parameter)
+      parameter,
+      operation_cls=operation_cls)
 
 
 @BeamTransformFactory.register_urn(
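
As a standalone sketch of what the TruncateAndSizeRestriction wrapper above
does to a single sized element/restriction pair during drain (plain Python,
with a hypothetical helper name):

def truncate_sized(element_restriction, restriction_provider):
  ((element, (restriction, estimator_state)), _old_size) = element_restriction
  truncated = restriction_provider.truncate(element, restriction)
  if not truncated:
    # Nothing left to process for this element while draining.
    return None
  new_size = restriction_provider.restriction_size(element, truncated)
  return ((element, (truncated, estimator_state)), new_size)
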
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index 36fa809..800e587 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -102,6 +102,8 @@ cdef class SdfProcessSizedElements(DoOperation):
   cdef object lock
   cdef object element_start_output_bytes
 
+cdef class SdfTruncateSizedRestrictions(DoOperation):
+  pass
 
 cdef class CombineOperation(Operation):
   cdef object phased_combine_fn
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 3585ad2..51973f0 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -743,6 +743,20 @@ class DoOperation(Operation):
     return infos
 
 
+class SdfTruncateSizedRestrictions(DoOperation):
+  def __init__(self, *args, **kwargs):
+    super(SdfTruncateSizedRestrictions, self).__init__(*args, **kwargs)
+
+  def current_element_progress(self):
+    # type: () -> Optional[iobase.RestrictionProgress]
+    return self.receivers[0].current_element_progress()
+
+  def try_split(
+      self, fraction_of_remainder
+  ):  # type: (...) -> Optional[Tuple[Iterable[SdfSplitResultsPrimary], Iterable[SdfSplitResultsResidual]]]
+    return self.receivers[0].try_split(fraction_of_remainder)
+
+
 class SdfProcessSizedElements(DoOperation):
   def __init__(self, *args, **kwargs):
     super(SdfProcessSizedElements, self).__init__(*args, **kwargs)
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index f1d6dce..30b05e0 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -320,6 +320,24 @@ class RestrictionProvider(object):
     for part in self.split(element, restriction):
       yield part, self.restriction_size(element, part)
 
+  def truncate(self, element, restriction):
+    """Truncates the provided restriction into a restriction representing a
+    finite amount of work when the pipeline is
+    `drained <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#>`_.  # pylint: disable=line-too-long
+    By default, if the restriction is bounded then the restriction itself will
+    be returned; otherwise, None will be returned.
+
+    This API is optional and should only be implemented if more granularity is
+    required.
+
+    Return a truncated finite restriction if further processing is required;
+    otherwise, return None to indicate that no further processing of this
+    restriction is required.
+    """
+    restriction_tracker = self.create_tracker(restriction)
+    if restriction_tracker.is_bounded():
+      return restriction
+
 
 def get_function_arguments(obj, func):
   # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]]
@@ -622,6 +640,16 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   def from_callable(fn):
     return CallableWrapperDoFn(fn)
 
+  @staticmethod
+  def unbounded_per_element():
+    """A decorator on process fn specifying that the fn performs an unbounded
+    amount of work per input element."""
+    def wrapper(process_fn):
+      process_fn.unbounded_per_element = True
+      return process_fn
+
+    return wrapper
+
   def default_label(self):
     return self.__class__.__name__
 
diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py
index 6061987..3d5325f 100644
--- a/sdks/python/apache_beam/transforms/environments.py
+++ b/sdks/python/apache_beam/transforms/environments.py
@@ -595,6 +595,8 @@ def _python_sdk_capabilities_iter():
   yield common_urns.protocols.LEGACY_PROGRESS_REPORTING.urn
   yield common_urns.protocols.WORKER_STATUS.urn
   yield 'beam:version:sdk_base:' + DockerEnvironment.default_docker_image()
+  #TODO(BEAM-10530): Add truncate capability.
+  # yield common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn
 
 
 def python_sdk_dependencies(options, tmp_dir=None):
diff --git a/sdks/python/apache_beam/transforms/environments_test.py b/sdks/python/apache_beam/transforms/environments_test.py
index 6259d71..46be840 100644
--- a/sdks/python/apache_beam/transforms/environments_test.py
+++ b/sdks/python/apache_beam/transforms/environments_test.py
@@ -73,6 +73,10 @@ class RunnerApiTest(unittest.TestCase):
     sdk_capabilities = environments.python_sdk_capabilities()
     self.assertIn(common_urns.coders.LENGTH_PREFIX.urn, sdk_capabilities)
     self.assertIn(common_urns.protocols.WORKER_STATUS.urn, sdk_capabilities)
+    #TODO(BEAM-10530): Add truncate capability.
+    # self.assertIn(
+    #     common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
+    #     sdk_capabilities)
 
   def test_default_capabilities(self):
     environment = DockerEnvironment.from_options(