You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by pa...@apache.org on 2019/05/08 21:05:54 UTC
[beam] branch master updated: Changes to SDF API to use DoFn Params
(#8430)
This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c9bad5d Changes to SDF API to use DoFn Params (#8430)
c9bad5d is described below
commit c9bad5d023ac4755e47e4b73f5d9a92e402b152c
Author: Pablo <pa...@users.noreply.github.com>
AuthorDate: Wed May 8 14:05:33 2019 -0700
Changes to SDF API to use DoFn Params (#8430)
* Changes to SDF API to use DoFn Params
* Fix docs tests
* Fix docs again
* Fix test
---
sdks/python/apache_beam/runners/common.py | 8 ++++----
.../runners/direct/sdf_direct_runner_test.py | 7 +++++--
.../runners/portability/fn_api_runner_test.py | 17 +++++++++++++---
.../apache_beam/testing/synthetic_pipeline.py | 3 ++-
sdks/python/apache_beam/transforms/core.py | 23 ++++++++++++++++++++--
sdks/python/scripts/generate_pydoc.sh | 1 +
6 files changed, 47 insertions(+), 12 deletions(-)
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 84ac116..3bbfd90 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -239,8 +239,8 @@ class DoFnSignature(object):
def get_restriction_provider(self):
result = _find_param_with_default(self.process_method,
- default_as_type=RestrictionProvider)
- return result[1] if result else None
+ default_as_type=DoFn.RestrictionParam)
+ return result[1].restriction_provider if result else None
def _validate(self):
self._validate_process()
@@ -271,7 +271,7 @@ class DoFnSignature(object):
userstate.validate_stateful_dofn(self.do_fn)
def is_splittable_dofn(self):
- return any([isinstance(default, RestrictionProvider) for default in
+ return any([isinstance(default, DoFn.RestrictionParam) for default in
self.process_method.defaults])
def is_stateful_dofn(self):
@@ -538,7 +538,7 @@ class PerWindowInvoker(DoFnInvoker):
'SDFs in multiply-windowed values with windowed arguments.')
restriction_tracker_param = _find_param_with_default(
self.signature.process_method,
- default_as_type=core.RestrictionProvider)[0]
+ default_as_type=DoFn.RestrictionParam)[0]
if not restriction_tracker_param:
raise ValueError(
'A RestrictionTracker %r was provided but DoFn does not have a '
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index eae38bc..3e1e344 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -57,7 +57,10 @@ class ReadFiles(DoFn):
self._resume_count = resume_count
def process(
- self, element, restriction_tracker=ReadFilesProvider(), *args, **kwargs):
+ self,
+ element,
+ restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
+ *args, **kwargs):
file_name = element
assert isinstance(restriction_tracker, OffsetRestrictionTracker)
@@ -107,7 +110,7 @@ class ExpandStrings(DoFn):
def process(
self, element, side1, side2, side3, window=beam.DoFn.WindowParam,
- restriction_tracker=ExpandStringsProvider(),
+ restriction_tracker=DoFn.RestrictionParam(ExpandStringsProvider()),
*args, **kwargs):
side = []
side.extend(side1)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index c584ef1..a807cfa 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -422,7 +422,11 @@ class FnApiRunnerTest(unittest.TestCase):
def test_sdf(self):
class ExpandingStringsDoFn(beam.DoFn):
- def process(self, element, restriction_tracker=ExpandStringsProvider()):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ ExpandStringsProvider())):
assert isinstance(
restriction_tracker,
restriction_trackers.OffsetRestrictionTracker), restriction_tracker
@@ -442,7 +446,11 @@ class FnApiRunnerTest(unittest.TestCase):
counter = beam.metrics.Metrics.counter('ns', 'my_counter')
class ExpandStringsDoFn(beam.DoFn):
- def process(self, element, restriction_tracker=ExpandStringsProvider()):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ ExpandStringsProvider())):
assert isinstance(
restriction_tracker,
restriction_trackers.OffsetRestrictionTracker), restriction_tracker
@@ -1271,7 +1279,10 @@ class FnApiRunnerSplitTest(unittest.TestCase):
return restriction[1] - restriction[0]
class EnumerateSdf(beam.DoFn):
- def process(self, element, restriction_tracker=EnumerateProvider()):
+ def process(
+ self,
+ element,
+ restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
to_emit = []
for k in range(*restriction_tracker.current_restriction()):
if restriction_tracker.try_claim(k):
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index eb1ec8d..2cace10 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -358,7 +358,8 @@ class SyntheticSDFAsSource(beam.DoFn):
def process(
self,
element,
- restriction_tracker=SyntheticSDFSourceRestrictionProvider()):
+ restriction_tracker=beam.DoFn.RestrictionParam(
+ SyntheticSDFSourceRestrictionProvider())):
for k in range(*restriction_tracker.current_restriction()):
if not restriction_tracker.try_claim(k):
return
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 5eed185..fb93c00 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -345,6 +345,18 @@ class _DoFnParam(object):
return self.param_id
+class _RestrictionDoFnParam(_DoFnParam):
+ """Restriction Provider DoFn parameter."""
+
+ def __init__(self, restriction_provider):
+ if not isinstance(restriction_provider, RestrictionProvider):
+ raise ValueError(
+ 'DoFn.RestrictionParam expected RestrictionProvider object.')
+ self.restriction_provider = restriction_provider
+ self.param_id = ('RestrictionParam(%s)'
+ % restriction_provider.__class__.__name__)
+
+
class _StateDoFnParam(_DoFnParam):
"""State DoFn parameter."""
@@ -421,6 +433,8 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
StateParam = _StateDoFnParam
TimerParam = _TimerDoFnParam
+ RestrictionParam = _RestrictionDoFnParam
+
@staticmethod
def from_callable(fn):
return CallableWrapperDoFn(fn)
@@ -441,8 +455,13 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
``DoFn.SideInputParam``: a side input that may be used when processing.
``DoFn.TimestampParam``: timestamp of the input element.
``DoFn.WindowParam``: ``Window`` the input element belongs to.
- A ``RestrictionProvider`` instance: an ``iobase.RestrictionTracker`` will be
- provided here to allow treatment as a Splittable `DoFn``.
+ ``DoFn.TimerParam``: a ``userstate.RuntimeTimer`` object defined by the spec
+ of the parameter.
+ ``DoFn.StateParam``: a ``userstate.RuntimeState`` object defined by the spec
+ of the parameter.
+ ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
+ provided here to allow treatment as a Splittable ``DoFn``. The restriction
+ tracker will be derived from the restriction provider in the parameter.
``DoFn.WatermarkReporterParam``: a function that can be used to report
output watermark of Splittable ``DoFn`` implementations.
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index dc9d74b..7564b49 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -178,6 +178,7 @@ ignore_identifiers = [
'_StateDoFnParam',
'_TimerDoFnParam',
'_BundleFinalizerParam',
+ '_RestrictionDoFnParam',
# Sphinx cannot find this py:class reference target
'typing.Generic',