You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bo...@apache.org on 2020/07/18 23:18:10 UTC
[beam] branch master updated: Insert TruncateSizedRestriction when
pipeline starts to drain.
This is an automated email from the ASF dual-hosted git repository.
boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4decedd Insert TruncateSizedRestriction when pipeline starts to drain.
new f841d60 Merge pull request #12289 from boyuanzz/drain_py
4decedd is described below
commit 4decedd9cb6e8e4eef3a41a349a44466eea50f06
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Mon Jun 15 10:21:33 2020 -0700
Insert TruncateSizedRestriction when pipeline starts to drain.
---
sdks/python/apache_beam/io/iobase.py | 20 ++
sdks/python/apache_beam/io/restriction_trackers.py | 3 +
sdks/python/apache_beam/runners/common.pxd | 1 +
sdks/python/apache_beam/runners/common.py | 8 +
sdks/python/apache_beam/runners/common_test.py | 15 ++
.../runners/portability/fn_api_runner/execution.py | 14 ++
.../runners/portability/fn_api_runner/fn_runner.py | 8 +-
.../portability/fn_api_runner/fn_runner_test.py | 214 +++++++++++++++++----
.../portability/fn_api_runner/translations.py | 66 +++++--
.../runners/portability/portable_runner_test.py | 14 +-
sdks/python/apache_beam/runners/sdf_utils.py | 6 +
.../apache_beam/runners/worker/bundle_processor.py | 37 +++-
.../apache_beam/runners/worker/operations.pxd | 2 +
.../apache_beam/runners/worker/operations.py | 14 ++
sdks/python/apache_beam/transforms/core.py | 28 +++
sdks/python/apache_beam/transforms/environments.py | 2 +
.../apache_beam/transforms/environments_test.py | 4 +
17 files changed, 400 insertions(+), 56 deletions(-)
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 95cc568..c23f358 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -1244,6 +1244,23 @@ class RestrictionTracker(object):
"""
raise NotImplementedError
+ def is_bounded(self):
+ """Returns whether the amount of work represented by the current restriction
+ is bounded.
+
+ The boundedness of the restriction is used to determine the default behavior
+ of how to truncate restrictions when a pipeline is being
+ `drained <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#>`_. # pylint: disable=line-too-long
+ If the restriction is bounded, then the entire restriction will be processed;
+ otherwise, the restriction will be processed until a checkpoint is possible.
+
+ The API is required to be implemented.
+
+ Returns: ``True`` if the restriction represents a finite amount of work.
+ Otherwise, returns ``False``.
+ """
+ raise NotImplementedError
+
class WatermarkEstimator(object):
"""A WatermarkEstimator which is used for estimating output_watermark based on
@@ -1442,6 +1459,9 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
def check_done(self):
return self.restriction.range_tracker().fraction_consumed() >= 1.0
+ def is_bounded(self):
+ return True
+
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
"""A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
def __init__(self, source, desired_chunk_size=None):
diff --git a/sdks/python/apache_beam/io/restriction_trackers.py b/sdks/python/apache_beam/io/restriction_trackers.py
index 9a53cc0..2420c0b 100644
--- a/sdks/python/apache_beam/io/restriction_trackers.py
+++ b/sdks/python/apache_beam/io/restriction_trackers.py
@@ -161,3 +161,6 @@ class OffsetRestrictionTracker(RestrictionTracker):
self._checkpointed = True
self._range, residual_range = self._range.split_at(split_point)
return self._range, residual_range
+
+ def is_bounded(self):
+ return True
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index e83ffe1..05f7a99 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -44,6 +44,7 @@ cdef class MethodWrapper(object):
cdef object restriction_provider_arg_name
cdef object watermark_estimator_provider
cdef object watermark_estimator_provider_arg_name
+ cdef bint unbounded_per_element
cdef class DoFnSignature(object):
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 6ba7771..1c4a25bd 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -199,6 +199,11 @@ class MethodWrapper(object):
self.watermark_estimator_provider = None
self.watermark_estimator_provider_arg_name = None
+ if hasattr(self.method_value, 'unbounded_per_element'):
+ self.unbounded_per_element = True
+ else:
+ self.unbounded_per_element = False
+
for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
if isinstance(v, core.DoFn.StateParam):
self.state_args_to_replace[kw] = v.state_spec
@@ -307,6 +312,9 @@ class DoFnSignature(object):
# type: () -> WatermarkEstimatorProvider
return self.process_method.watermark_estimator_provider
+ def is_unbounded_per_element(self):
+ return self.process_method.unbounded_per_element
+
def _validate(self):
# type: () -> None
self._validate_process()
diff --git a/sdks/python/apache_beam/runners/common_test.py b/sdks/python/apache_beam/runners/common_test.py
index 1e0120d..c9860e6 100644
--- a/sdks/python/apache_beam/runners/common_test.py
+++ b/sdks/python/apache_beam/runners/common_test.py
@@ -74,6 +74,21 @@ class DoFnSignatureTest(unittest.TestCase):
with self.assertRaises(ValueError):
DoFnSignature(MyDoFn())
+ def test_unbounded_element_process_fn(self):
+ class UnboundedDoFn(DoFn):
+ @DoFn.unbounded_per_element()
+ def process(self, element):
+ pass
+
+ class BoundedDoFn(DoFn):
+ def process(self, element):
+ pass
+
+ signature = DoFnSignature(UnboundedDoFn())
+ self.assertTrue(signature.is_unbounded_per_element())
+ signature = DoFnSignature(BoundedDoFn())
+ self.assertFalse(signature.is_unbounded_per_element())
+
class DoFnProcessTest(unittest.TestCase):
# pylint: disable=expression-not-assigned
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 251a696..5b8e91c 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -683,7 +683,21 @@ class BundleContextManager(object):
input_pcoll = self.process_bundle_descriptor.transforms[
transform_id].inputs[input_id]
for read_id, proto in self.process_bundle_descriptor.transforms.items():
+ # The GrpcRead is followed by the SDF/Process.
if (proto.spec.urn == bundle_processor.DATA_INPUT_URN and
input_pcoll in proto.outputs.values()):
return read_id
+ # The GrpcRead is followed by the SDF/Truncate -> SDF/Process.
+ if (proto.spec.urn ==
+ common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn and
+ input_pcoll in proto.outputs.values()):
+ read_input = list(
+ self.process_bundle_descriptor.transforms[read_id].inputs.values()
+ )[0]
+ for (grpc_read,
+ transform_proto) in self.process_bundle_descriptor.transforms.items(): # pylint: disable=line-too-long
+ if (transform_proto.spec.urn == bundle_processor.DATA_INPUT_URN and
+ read_input in transform_proto.outputs.values()):
+ return grpc_read
+
raise RuntimeError('No IO transform feeds %s' % transform_id)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index bbbab4d..e53cafe 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -91,7 +91,8 @@ class FnApiRunner(runner.PipelineRunner):
bundle_repeat=0,
use_state_iterables=False,
provision_info=None, # type: Optional[ExtendedProvisionInfo]
- progress_request_frequency=None):
+ progress_request_frequency=None,
+ is_drain=False):
# type: (...) -> None
"""Creates a new Fn API Runner.
@@ -105,6 +106,7 @@ class FnApiRunner(runner.PipelineRunner):
provision_info: provisioning info to make available to workers, or None
progress_request_frequency: The frequency (in seconds) that the runner
waits before requesting progress from the SDK.
+ is_drain: whether to expand the SDF graph in drain mode.
"""
super(FnApiRunner, self).__init__()
self._default_environment = (
@@ -114,6 +116,7 @@ class FnApiRunner(runner.PipelineRunner):
self._progress_frequency = progress_request_frequency
self._profiler_factory = None # type: Optional[Callable[..., profiler.Profile]]
self._use_state_iterables = use_state_iterables
+ self._is_drain = is_drain
self._provision_info = provision_info or ExtendedProvisionInfo(
beam_provision_api_pb2.ProvisionInfo(
retrieval_token='unused-retrieval-token'))
@@ -304,7 +307,8 @@ class FnApiRunner(runner.PipelineRunner):
common_urns.primitives.FLATTEN.urn,
common_urns.primitives.GROUP_BY_KEY.urn
]),
- use_state_iterables=self._use_state_iterables)
+ use_state_iterables=self._use_state_iterables,
+ is_drain=self._is_drain)
def run_stages(self,
stage_context, # type: translations.TransformContext
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
index 95e3b5f..eca6491 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
@@ -98,8 +98,8 @@ def has_urn_and_labels(mi, urn, labels):
class FnApiRunnerTest(unittest.TestCase):
- def create_pipeline(self):
- return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
+ def create_pipeline(self, is_drain=False):
+ return beam.Pipeline(runner=fn_api_runner.FnApiRunner(is_drain=is_drain))
def test_assert_that(self):
# TODO: figure out a way for fn_api_runner to parse and raise the
@@ -559,8 +559,7 @@ class FnApiRunnerTest(unittest.TestCase):
actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
assert_that(actual, equal_to(list(''.join(data))))
- def test_sdf_with_sdf_initiated_checkpointing(self):
-
+ def run_sdf_initiated_checkpointing(self, is_drain=False):
counter = beam.metrics.Metrics.counter('ns', 'my_counter')
class ExpandStringsDoFn(beam.DoFn):
@@ -579,7 +578,7 @@ class FnApiRunnerTest(unittest.TestCase):
return
cur += 1
- with self.create_pipeline() as p:
+ with self.create_pipeline(is_drain=is_drain) as p:
data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
actual = (p | beam.Create(data) | beam.ParDo(ExpandStringsDoFn()))
@@ -591,6 +590,63 @@ class FnApiRunnerTest(unittest.TestCase):
self.assertEqual(1, len(counters))
self.assertEqual(counters[0].committed, len(''.join(data)))
+ def test_sdf_with_sdf_initiated_checkpointing(self):
+ self.run_sdf_initiated_checkpointing(is_drain=False)
+
+ def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+ self.run_sdf_initiated_checkpointing(is_drain=True)
+
+ def test_sdf_default_truncate_when_bounded(self):
+ class SimleSDF(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ OffsetRangeProvider(use_bounded_offset_range=True))):
+ assert isinstance(restriction_tracker, RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
+ while restriction_tracker.try_claim(cur):
+ yield cur
+ cur += 1
+
+ with self.create_pipeline(is_drain=True) as p:
+ actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF()))
+ assert_that(actual, equal_to(range(10)))
+
+ def test_sdf_default_truncate_when_unbounded(self):
+ class SimleSDF(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ OffsetRangeProvider(use_bounded_offset_range=False))):
+ assert isinstance(restriction_tracker, RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
+ while restriction_tracker.try_claim(cur):
+ yield cur
+ cur += 1
+
+ with self.create_pipeline(is_drain=True) as p:
+ actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF()))
+ assert_that(actual, equal_to([]))
+
+ def test_sdf_with_truncate(self):
+ class SimleSDF(beam.DoFn):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ OffsetRangeProviderWithTruncate())):
+ assert isinstance(restriction_tracker, RestrictionTrackerView)
+ cur = restriction_tracker.current_restriction().start
+ while restriction_tracker.try_claim(cur):
+ yield cur
+ cur += 1
+
+ with self.create_pipeline(is_drain=True) as p:
+ actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF()))
+ assert_that(actual, equal_to(range(5)))
+
def test_group_by_key(self):
with self.create_pipeline() as p:
res = (
@@ -1240,25 +1296,28 @@ class FnApiRunnerMetricsTest(unittest.TestCase):
class FnApiRunnerTestWithGrpc(FnApiRunnerTest):
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
- default_environment=environments.EmbeddedPythonGrpcEnvironment()))
+ default_environment=environments.EmbeddedPythonGrpcEnvironment(),
+ is_drain=is_drain))
class FnApiRunnerTestWithDisabledCaching(FnApiRunnerTest):
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
default_environment=environments.EmbeddedPythonGrpcEnvironment(
- state_cache_size=0, data_buffer_time_limit_ms=0)))
+ state_cache_size=0, data_buffer_time_limit_ms=0),
+ is_drain=is_drain))
class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
pipeline_options = PipelineOptions(direct_num_workers=2)
p = beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(), options=pipeline_options)
+ runner=fn_api_runner.FnApiRunner(is_drain=is_drain),
+ options=pipeline_options)
#TODO(BEAM-8444): Fix these tests.
p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
return p
@@ -1269,16 +1328,20 @@ class FnApiRunnerTestWithMultiWorkers(FnApiRunnerTest):
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("This test is for a single worker only.")
class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
pipeline_options = PipelineOptions(
direct_num_workers=2, direct_running_mode='multi_threading')
p = beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(), options=pipeline_options)
+ runner=fn_api_runner.FnApiRunner(is_drain=is_drain),
+ options=pipeline_options)
#TODO(BEAM-8444): Fix these tests.
p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
return p
@@ -1289,23 +1352,27 @@ class FnApiRunnerTestWithGrpcAndMultiWorkers(FnApiRunnerTest):
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("This test is for a single worker only.")
class FnApiRunnerTestWithBundleRepeat(FnApiRunnerTest):
- def create_pipeline(self):
- return beam.Pipeline(runner=fn_api_runner.FnApiRunner(bundle_repeat=3))
+ def create_pipeline(self, is_drain=False):
+ return beam.Pipeline(
+ runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain))
def test_register_finalizations(self):
raise unittest.SkipTest("TODO: Avoid bundle finalizations on repeat.")
class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest):
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
pipeline_options = PipelineOptions(direct_num_workers=2)
p = beam.Pipeline(
- runner=fn_api_runner.FnApiRunner(bundle_repeat=3),
+ runner=fn_api_runner.FnApiRunner(bundle_repeat=3, is_drain=is_drain),
options=pipeline_options)
#TODO(BEAM-8444): Fix these tests.
p._options.view_as(DebugOptions).experiments.remove('beam_fn_api')
@@ -1320,17 +1387,21 @@ class FnApiRunnerTestWithBundleRepeatAndMultiWorkers(FnApiRunnerTest):
def test_sdf_with_sdf_initiated_checkpointing(self):
raise unittest.SkipTest("This test is for a single worker only.")
+ def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+ raise unittest.SkipTest("This test is for a single worker only.")
+
def test_sdf_with_watermark_tracking(self):
raise unittest.SkipTest("This test is for a single worker only.")
class FnApiRunnerSplitTest(unittest.TestCase):
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
# Must be GRPC so we can send data and split requests concurrent
# to the bundle process request.
return beam.Pipeline(
runner=fn_api_runner.FnApiRunner(
- default_environment=environments.EmbeddedPythonGrpcEnvironment()))
+ default_environment=environments.EmbeddedPythonGrpcEnvironment(),
+ is_drain=is_drain))
def test_checkpoint(self):
# This split manager will get re-invoked on each smaller split,
@@ -1387,16 +1458,7 @@ class FnApiRunnerSplitTest(unittest.TestCase):
| beam.Map(lambda x: element_counter.increment() or x))
assert_that(res, equal_to(elements))
- def test_nosplit_sdf(self):
- def split_manager(num_elements):
- yield
-
- elements = [1, 2, 3]
- expected_groups = [[(e, k) for k in range(e)] for e in elements]
- self.run_sdf_split_pipeline(
- split_manager, elements, ElementCounter(), expected_groups)
-
- def test_checkpoint_sdf(self):
+ def run_sdf_checkpoint(self, is_drain=False):
element_counter = ElementCounter()
def split_manager(num_elements):
@@ -1409,13 +1471,17 @@ class FnApiRunnerSplitTest(unittest.TestCase):
breakpoint.clear()
# Everything should be perfectly split.
+
elements = [2, 3]
expected_groups = [[(2, 0)], [(2, 1)], [(3, 0)], [(3, 1)], [(3, 2)]]
self.run_sdf_split_pipeline(
- split_manager, elements, element_counter, expected_groups)
-
- def test_split_half_sdf(self):
+ split_manager,
+ elements,
+ element_counter,
+ expected_groups,
+ is_drain=is_drain)
+ def run_sdf_split_half(self, is_drain=False):
element_counter = ElementCounter()
is_first_bundle = [True] # emulate nonlocal for Python 2
@@ -1438,9 +1504,13 @@ class FnApiRunnerSplitTest(unittest.TestCase):
(4, 2), (4, 3)]]
self.run_sdf_split_pipeline(
- split_manager, elements, element_counter, expected_groups)
+ split_manager,
+ elements,
+ element_counter,
+ expected_groups,
+ is_drain=is_drain)
- def test_split_crazy_sdf(self, seed=None):
+ def run_split_crazy_sdf(self, seed=None, is_drain=False):
if seed is None:
seed = random.randrange(1 << 20)
r = random.Random(seed)
@@ -1459,13 +1529,46 @@ class FnApiRunnerSplitTest(unittest.TestCase):
try:
elements = [r.randrange(5, 10) for _ in range(5)]
- self.run_sdf_split_pipeline(split_manager, elements, element_counter)
+ self.run_sdf_split_pipeline(
+ split_manager, elements, element_counter, is_drain=is_drain)
except Exception:
_LOGGER.error('test_split_crazy_sdf.seed = %s', seed)
raise
+ def test_nosplit_sdf(self):
+ def split_manager(num_elements):
+ yield
+
+ elements = [1, 2, 3]
+ expected_groups = [[(e, k) for k in range(e)] for e in elements]
+ self.run_sdf_split_pipeline(
+ split_manager, elements, ElementCounter(), expected_groups)
+
+ def test_checkpoint_sdf(self):
+ self.run_sdf_checkpoint(is_drain=False)
+
+ def test_checkpoint_draining_sdf(self):
+ self.run_sdf_checkpoint(is_drain=True)
+
+ def test_split_half_sdf(self):
+ self.run_sdf_split_half(is_drain=False)
+
+ def test_split_half_draining_sdf(self):
+ self.run_sdf_split_half(is_drain=True)
+
+ def test_split_crazy_sdf(self, seed=None):
+ self.run_split_crazy_sdf(seed=seed, is_drain=False)
+
+ def test_split_crazy_draining_sdf(self, seed=None):
+ self.run_split_crazy_sdf(seed=seed, is_drain=True)
+
def run_sdf_split_pipeline(
- self, split_manager, elements, element_counter, expected_groups=None):
+ self,
+ split_manager,
+ elements,
+ element_counter,
+ expected_groups=None,
+ is_drain=False):
# Define an SDF that for each input x produces [(x, k) for k in range(x)].
class EnumerateProvider(beam.transforms.core.RestrictionProvider):
@@ -1482,6 +1585,9 @@ class FnApiRunnerSplitTest(unittest.TestCase):
def restriction_size(self, element, restriction):
return restriction.size()
+ def is_bounded(self):
+ return True
+
class EnumerateSdf(beam.DoFn):
def process(
self,
@@ -1499,7 +1605,7 @@ class FnApiRunnerSplitTest(unittest.TestCase):
expected = [(e, k) for e in elements for k in range(e)]
with fn_runner.split_manager('SDF', split_manager):
- with self.create_pipeline() as p:
+ with self.create_pipeline(is_drain=is_drain) as p:
grouped = (
p
| beam.Create(elements, reshuffle=False)
@@ -1627,6 +1733,40 @@ class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
return restriction.size()
+class UnboundedOffsetRestrictionTracker(
+ restriction_trackers.OffsetRestrictionTracker):
+ def is_bounded(self):
+ return False
+
+
+class OffsetRangeProvider(beam.transforms.core.RestrictionProvider):
+ def __init__(self, use_bounded_offset_range):
+ self.use_bounded_offset_range = use_bounded_offset_range
+
+ def initial_restriction(self, element):
+ return restriction_trackers.OffsetRange(0, element)
+
+ def create_tracker(self, restriction):
+ if self.use_bounded_offset_range:
+ return restriction_trackers.OffsetRestrictionTracker(restriction)
+ return UnboundedOffsetRestrictionTracker(restriction)
+
+ def split(self, element, restriction):
+ return [restriction]
+
+ def restriction_size(self, element, restriction):
+ return restriction.size()
+
+
+class OffsetRangeProviderWithTruncate(OffsetRangeProvider):
+ def __init__(self):
+ super(OffsetRangeProviderWithTruncate, self).__init__(True)
+
+ def truncate(self, element, restriction):
+ return restriction_trackers.OffsetRange(
+ restriction.start, restriction.stop // 2)
+
+
class FnApiBasedLullLoggingTest(unittest.TestCase):
def create_pipeline(self):
return beam.Pipeline(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index 34dcf48..00c440e 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -69,6 +69,7 @@ PAR_DO_URNS = frozenset([
common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
+ common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
])
IMPULSE_BUFFER = b'impulse'
@@ -343,13 +344,15 @@ class TransformContext(object):
def __init__(self,
components, # type: beam_runner_api_pb2.Components
known_runner_urns, # type: FrozenSet[str]
- use_state_iterables=False
+ use_state_iterables=False,
+ is_drain=False
):
self.components = components
self.known_runner_urns = known_runner_urns
self.runner_only_urns = known_runner_urns - frozenset(
[common_urns.primitives.FLATTEN.urn])
self.use_state_iterables = use_state_iterables
+ self.is_drain = is_drain
# ok to pass None for context because BytesCoder has no components
coder_proto = coders.BytesCoder().to_runner_api(
None) # type: ignore[arg-type]
@@ -547,7 +550,8 @@ def pipeline_from_stages(pipeline_proto, # type: beam_runner_api_pb2.Pipeline
def create_and_optimize_stages(pipeline_proto, # type: beam_runner_api_pb2.Pipeline
phases,
known_runner_urns, # type: FrozenSet[str]
- use_state_iterables=False
+ use_state_iterables=False,
+ is_drain=False
):
# type: (...) -> Tuple[TransformContext, List[Stage]]
@@ -568,7 +572,8 @@ def create_and_optimize_stages(pipeline_proto, # type: beam_runner_api_pb2.Pipe
pipeline_context = TransformContext(
pipeline_proto.components,
known_runner_urns,
- use_state_iterables=use_state_iterables)
+ use_state_iterables=use_state_iterables,
+ is_drain=is_drain)
# Initial set of stages are singleton leaf transforms.
stages = list(
@@ -973,6 +978,7 @@ def expand_sdf(stages, context):
inputs=dict(transform.inputs, **{main_input_tag: paired_pcoll_id}),
outputs={'out': split_pcoll_id})
+ reshuffle_stage = None
if common_urns.composites.RESHUFFLE.urn in context.known_runner_urns:
reshuffle_pcoll_id = copy_like(
context.components.pcollections,
@@ -987,25 +993,57 @@ def expand_sdf(stages, context):
payload=b'',
inputs=dict(transform.inputs, **{main_input_tag: split_pcoll_id}),
outputs={'out': reshuffle_pcoll_id})
- yield make_stage(stage, reshuffle_transform_id)
+ reshuffle_stage = make_stage(stage, reshuffle_transform_id)
else:
reshuffle_pcoll_id = split_pcoll_id
reshuffle_transform_id = None
- process_transform_id = copy_like(
- context.components.transforms,
- transform,
- unique_name=transform.unique_name + '/Process',
- urn=common_urns.sdf_components.
- PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
- inputs=dict(
- transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))
+ if context.is_drain:
+ truncate_pcoll_id = copy_like(
+ context.components.pcollections,
+ main_input_id,
+ '_truncate_restriction',
+ coder_id=sized_coder_id)
+ # Length-prefix the truncate output.
+ context.length_prefix_pcoll_coders(truncate_pcoll_id)
+ truncate_transform_id = copy_like(
+ context.components.transforms,
+ transform,
+ unique_name=transform.unique_name + '/TruncateAndSizeRestriction',
+ urn=common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
+ inputs=dict(
+ transform.inputs, **{main_input_tag: reshuffle_pcoll_id}),
+ outputs={'out': truncate_pcoll_id})
+ process_transform_id = copy_like(
+ context.components.transforms,
+ transform,
+ unique_name=transform.unique_name + '/Process',
+ urn=common_urns.sdf_components.
+ PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
+ inputs=dict(
+ transform.inputs, **{main_input_tag: truncate_pcoll_id}))
+ else:
+ process_transform_id = copy_like(
+ context.components.transforms,
+ transform,
+ unique_name=transform.unique_name + '/Process',
+ urn=common_urns.sdf_components.
+ PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
+ inputs=dict(
+ transform.inputs, **{main_input_tag: reshuffle_pcoll_id}))
yield make_stage(stage, pair_transform_id)
split_stage = make_stage(stage, split_transform_id)
yield split_stage
- yield make_stage(
- stage, process_transform_id, extra_must_follow=[split_stage])
+ if reshuffle_stage:
+ yield reshuffle_stage
+ if context.is_drain:
+ yield make_stage(
+ stage, truncate_transform_id, extra_must_follow=[split_stage])
+ yield make_stage(stage, process_transform_id)
+ else:
+ yield make_stage(
+ stage, process_transform_id, extra_must_follow=[split_stage])
else:
yield stage
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner_test.py b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
index 5264f7f..c9a33c1 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner_test.py
@@ -197,7 +197,7 @@ class PortableRunnerTest(fn_runner_test.FnApiRunnerTest):
'data_buffer_time_limit_ms=1000')
return options
- def create_pipeline(self):
+ def create_pipeline(self, is_drain=False):
return beam.Pipeline(self.get_runner(), self.create_options())
def test_pardo_state_with_custom_key_coder(self):
@@ -242,6 +242,18 @@ class PortableRunnerTest(fn_runner_test.FnApiRunnerTest):
# Inherits all other tests from fn_api_runner_test.FnApiRunnerTest
+ def test_sdf_default_truncate_when_bounded(self):
+ raise unittest.SkipTest("Portable runners don't support drain yet.")
+
+ def test_sdf_default_truncate_when_unbounded(self):
+ raise unittest.SkipTest("Portable runners don't support drain yet.")
+
+ def test_sdf_with_truncate(self):
+ raise unittest.SkipTest("Portable runners don't support drain yet.")
+
+ def test_draining_sdf_with_sdf_initiated_checkpointing(self):
+ raise unittest.SkipTest("Portable runners don't support drain yet.")
+
@unittest.skip("BEAM-7248")
class PortableRunnerOptimized(PortableRunnerTest):
diff --git a/sdks/python/apache_beam/runners/sdf_utils.py b/sdks/python/apache_beam/runners/sdf_utils.py
index 1d92c7a..aa91bda 100644
--- a/sdks/python/apache_beam/runners/sdf_utils.py
+++ b/sdks/python/apache_beam/runners/sdf_utils.py
@@ -151,6 +151,9 @@ class ThreadsafeRestrictionTracker(object):
return self._deferred_residual, self._deferred_timestamp
return None
+ def is_bounded(self):
+ return self._restriction_tracker.is_bounded()
+
class RestrictionTrackerView(object):
"""A DoFn view of thread-safe RestrictionTracker.
@@ -178,6 +181,9 @@ class RestrictionTrackerView(object):
def defer_remainder(self, deferred_time=None):
self._threadsafe_restriction_tracker.defer_remainder(deferred_time)
+ def is_bounded(self):
+ self._threadsafe_restriction_tracker.is_bounded()
+
class ThreadsafeWatermarkEstimator(object):
"""A threadsafe wrapper which wraps a WatermarkEstimator with locking
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 3e57046..57a5c6b 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -1079,6 +1079,7 @@ class BundleProcessor(object):
):
# type: (...) -> beam_fn_api_pb2.BundleApplication
transform_id, main_input_tag, main_input_coder, outputs = op.input_info
+
if output_watermark:
proto_output_watermark = proto_utils.from_micros(
timestamp_pb2.Timestamp, output_watermark.micros)
@@ -1426,6 +1427,31 @@ def create_split_and_size_restrictions(*args):
@BeamTransformFactory.register_urn(
+ common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
+ beam_runner_api_pb2.ParDoPayload)
+def create_truncate_sized_restriction(*args):
+ class TruncateAndSizeRestriction(beam.DoFn):
+ def __init__(self, fn, restriction_provider, watermark_estimator_provider):
+ self.restriction_provider = restriction_provider
+
+ def process(self, element_restriction, *args, **kwargs):
+ ((element, (restriction, estimator_state)), _) = element_restriction
+ truncated_restriction = self.restriction_provider.truncate(
+ element, restriction)
+ if truncated_restriction:
+ truncated_restriction_size = (
+ self.restriction_provider.restriction_size(
+ element, truncated_restriction))
+ yield ((element, (truncated_restriction, estimator_state)),
+ truncated_restriction_size)
+
+ return _create_sdf_operation(
+ TruncateAndSizeRestriction,
+ *args,
+ operation_cls=operations.SdfTruncateSizedRestrictions)
+
+
+@BeamTransformFactory.register_urn(
common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
beam_runner_api_pb2.ParDoPayload)
def create_process_sized_elements_and_restrictions(
@@ -1448,7 +1474,13 @@ def create_process_sized_elements_and_restrictions(
def _create_sdf_operation(
- proxy_dofn, factory, transform_id, transform_proto, parameter, consumers):
+ proxy_dofn,
+ factory,
+ transform_id,
+ transform_proto,
+ parameter,
+ consumers,
+ operation_cls=operations.DoOperation):
dofn_data = pickler.loads(parameter.do_fn.payload)
dofn = dofn_data[0]
@@ -1464,7 +1496,8 @@ def _create_sdf_operation(
transform_proto,
consumers,
serialized_fn,
- parameter)
+ parameter,
+ operation_cls=operation_cls)
@BeamTransformFactory.register_urn(
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
index 36fa809..800e587 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -102,6 +102,8 @@ cdef class SdfProcessSizedElements(DoOperation):
cdef object lock
cdef object element_start_output_bytes
+cdef class SdfTruncateSizedRestrictions(DoOperation):
+ pass
cdef class CombineOperation(Operation):
cdef object phased_combine_fn
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
index 3585ad2..51973f0 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -743,6 +743,20 @@ class DoOperation(Operation):
return infos
+class SdfTruncateSizedRestrictions(DoOperation):
+ def __init__(self, *args, **kwargs):
+ super(SdfTruncateSizedRestrictions, self).__init__(*args, **kwargs)
+
+ def current_element_progress(self):
+ # type: () -> Optional[iobase.RestrictionProgress]
+ return self.receivers[0].current_element_progress()
+
+ def try_split(
+ self, fraction_of_remainder
+ ): # type: (...) -> Optional[Tuple[Iterable[SdfSplitResultsPrimary], Iterable[SdfSplitResultsResidual]]]
+ return self.receivers[0].try_split(fraction_of_remainder)
+
+
class SdfProcessSizedElements(DoOperation):
def __init__(self, *args, **kwargs):
super(SdfProcessSizedElements, self).__init__(*args, **kwargs)
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index f1d6dce..30b05e0 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -320,6 +320,24 @@ class RestrictionProvider(object):
for part in self.split(element, restriction):
yield part, self.restriction_size(element, part)
+ def truncate(self, element, restriction):
+ """Truncates the provided restriction into a restriction representing a
+ finite amount of work when the pipeline is
+ `draining <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#>`_. See the linked document for additional details about drain. # pylint: disable=line-too-long
+ By default, if the restriction is bounded then the restriction will be
+ returned; otherwise, None will be returned.
+
+ This API is optional and should only be implemented if more granularity is
+ required.
+
+ Return a truncated finite restriction if further processing is required;
+ otherwise return None to represent that no further processing of this
+ restriction is required.
+ """
+ restriction_tracker = self.create_tracker(restriction)
+ if restriction_tracker.is_bounded():
+ return restriction
+
def get_function_arguments(obj, func):
# type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]]
@@ -622,6 +640,16 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
def from_callable(fn):
return CallableWrapperDoFn(fn)
+ @staticmethod
+ def unbounded_per_element():
+ """A decorator on process fn specifying that the fn performs an unbounded
+ amount of work per input element."""
+ def wrapper(process_fn):
+ process_fn.unbounded_per_element = True
+ return process_fn
+
+ return wrapper
+
def default_label(self):
return self.__class__.__name__
diff --git a/sdks/python/apache_beam/transforms/environments.py b/sdks/python/apache_beam/transforms/environments.py
index 6061987..3d5325f 100644
--- a/sdks/python/apache_beam/transforms/environments.py
+++ b/sdks/python/apache_beam/transforms/environments.py
@@ -595,6 +595,8 @@ def _python_sdk_capabilities_iter():
yield common_urns.protocols.LEGACY_PROGRESS_REPORTING.urn
yield common_urns.protocols.WORKER_STATUS.urn
yield 'beam:version:sdk_base:' + DockerEnvironment.default_docker_image()
+ #TODO(BEAM-10530): Add truncate capability.
+ # yield common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn
def python_sdk_dependencies(options, tmp_dir=None):
diff --git a/sdks/python/apache_beam/transforms/environments_test.py b/sdks/python/apache_beam/transforms/environments_test.py
index 6259d71..46be840 100644
--- a/sdks/python/apache_beam/transforms/environments_test.py
+++ b/sdks/python/apache_beam/transforms/environments_test.py
@@ -73,6 +73,10 @@ class RunnerApiTest(unittest.TestCase):
sdk_capabilities = environments.python_sdk_capabilities()
self.assertIn(common_urns.coders.LENGTH_PREFIX.urn, sdk_capabilities)
self.assertIn(common_urns.protocols.WORKER_STATUS.urn, sdk_capabilities)
+ #TODO(BEAM-10530): Add truncate capability.
+ # self.assertIn(
+ # common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
+ # sdk_capabilities)
def test_default_capabilities(self):
environment = DockerEnvironment.from_options(