You are viewing a plain-text version of this content. The canonical link to the original was not preserved in this copy.
Posted to commits@beam.apache.org by pa...@apache.org on 2019/05/08 21:05:54 UTC

[beam] branch master updated: Changes to SDF API to use DoFn Params (#8430)

This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c9bad5d  Changes to SDF API to use DoFn Params (#8430)
c9bad5d is described below

commit c9bad5d023ac4755e47e4b73f5d9a92e402b152c
Author: Pablo <pa...@users.noreply.github.com>
AuthorDate: Wed May 8 14:05:33 2019 -0700

    Changes to SDF API to use DoFn Params (#8430)
    
    * Changes to SDF API to use DoFn Params
    
    * Fix docs tests
    
    * Fix docs again
    
    * Fix test
---
 sdks/python/apache_beam/runners/common.py          |  8 ++++----
 .../runners/direct/sdf_direct_runner_test.py       |  7 +++++--
 .../runners/portability/fn_api_runner_test.py      | 17 +++++++++++++---
 .../apache_beam/testing/synthetic_pipeline.py      |  3 ++-
 sdks/python/apache_beam/transforms/core.py         | 23 ++++++++++++++++++++--
 sdks/python/scripts/generate_pydoc.sh              |  1 +
 6 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 84ac116..3bbfd90 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -239,8 +239,8 @@ class DoFnSignature(object):
 
   def get_restriction_provider(self):
     result = _find_param_with_default(self.process_method,
-                                      default_as_type=RestrictionProvider)
-    return result[1] if result else None
+                                      default_as_type=DoFn.RestrictionParam)
+    return result[1].restriction_provider if result else None
 
   def _validate(self):
     self._validate_process()
@@ -271,7 +271,7 @@ class DoFnSignature(object):
     userstate.validate_stateful_dofn(self.do_fn)
 
   def is_splittable_dofn(self):
-    return any([isinstance(default, RestrictionProvider) for default in
+    return any([isinstance(default, DoFn.RestrictionParam) for default in
                 self.process_method.defaults])
 
   def is_stateful_dofn(self):
@@ -538,7 +538,7 @@ class PerWindowInvoker(DoFnInvoker):
             'SDFs in multiply-windowed values with windowed arguments.')
       restriction_tracker_param = _find_param_with_default(
           self.signature.process_method,
-          default_as_type=core.RestrictionProvider)[0]
+          default_as_type=DoFn.RestrictionParam)[0]
       if not restriction_tracker_param:
         raise ValueError(
             'A RestrictionTracker %r was provided but DoFn does not have a '
diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
index eae38bc..3e1e344 100644
--- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py
@@ -57,7 +57,10 @@ class ReadFiles(DoFn):
     self._resume_count = resume_count
 
   def process(
-      self, element, restriction_tracker=ReadFilesProvider(), *args, **kwargs):
+      self,
+      element,
+      restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()),
+      *args, **kwargs):
     file_name = element
     assert isinstance(restriction_tracker, OffsetRestrictionTracker)
 
@@ -107,7 +110,7 @@ class ExpandStrings(DoFn):
 
   def process(
       self, element, side1, side2, side3, window=beam.DoFn.WindowParam,
-      restriction_tracker=ExpandStringsProvider(),
+      restriction_tracker=DoFn.RestrictionParam(ExpandStringsProvider()),
       *args, **kwargs):
     side = []
     side.extend(side1)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index c584ef1..a807cfa 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -422,7 +422,11 @@ class FnApiRunnerTest(unittest.TestCase):
   def test_sdf(self):
 
     class ExpandingStringsDoFn(beam.DoFn):
-      def process(self, element, restriction_tracker=ExpandStringsProvider()):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider())):
         assert isinstance(
             restriction_tracker,
             restriction_trackers.OffsetRestrictionTracker), restriction_tracker
@@ -442,7 +446,11 @@ class FnApiRunnerTest(unittest.TestCase):
     counter = beam.metrics.Metrics.counter('ns', 'my_counter')
 
     class ExpandStringsDoFn(beam.DoFn):
-      def process(self, element, restriction_tracker=ExpandStringsProvider()):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(
+              ExpandStringsProvider())):
         assert isinstance(
             restriction_tracker,
             restriction_trackers.OffsetRestrictionTracker), restriction_tracker
@@ -1271,7 +1279,10 @@ class FnApiRunnerSplitTest(unittest.TestCase):
         return restriction[1] - restriction[0]
 
     class EnumerateSdf(beam.DoFn):
-      def process(self, element, restriction_tracker=EnumerateProvider()):
+      def process(
+          self,
+          element,
+          restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())):
         to_emit = []
         for k in range(*restriction_tracker.current_restriction()):
           if restriction_tracker.try_claim(k):
diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py
index eb1ec8d..2cace10 100644
--- a/sdks/python/apache_beam/testing/synthetic_pipeline.py
+++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py
@@ -358,7 +358,8 @@ class SyntheticSDFAsSource(beam.DoFn):
   def process(
       self,
       element,
-      restriction_tracker=SyntheticSDFSourceRestrictionProvider()):
+      restriction_tracker=beam.DoFn.RestrictionParam(
+          SyntheticSDFSourceRestrictionProvider())):
     for k in range(*restriction_tracker.current_restriction()):
       if not restriction_tracker.try_claim(k):
         return
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 5eed185..fb93c00 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -345,6 +345,18 @@ class _DoFnParam(object):
     return self.param_id
 
 
+class _RestrictionDoFnParam(_DoFnParam):
+  """Restriction Provider DoFn parameter."""
+
+  def __init__(self, restriction_provider):
+    if not isinstance(restriction_provider, RestrictionProvider):
+      raise ValueError(
+          'DoFn.RestrictionParam expected RestrictionProvider object.')
+    self.restriction_provider = restriction_provider
+    self.param_id = ('RestrictionParam(%s)'
+                     % restriction_provider.__class__.__name__)
+
+
 class _StateDoFnParam(_DoFnParam):
   """State DoFn parameter."""
 
@@ -421,6 +433,8 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   StateParam = _StateDoFnParam
   TimerParam = _TimerDoFnParam
 
+  RestrictionParam = _RestrictionDoFnParam
+
   @staticmethod
   def from_callable(fn):
     return CallableWrapperDoFn(fn)
@@ -441,8 +455,13 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
     ``DoFn.SideInputParam``: a side input that may be used when processing.
     ``DoFn.TimestampParam``: timestamp of the input element.
     ``DoFn.WindowParam``: ``Window`` the input element belongs to.
-    A ``RestrictionProvider`` instance: an ``iobase.RestrictionTracker`` will be
-    provided here to allow treatment as a Splittable `DoFn``.
+    ``DoFn.TimerParam``: a ``userstate.RuntimeTimer`` object defined by the spec
+    of the parameter.
+    ``DoFn.StateParam``: a ``userstate.RuntimeState`` object defined by the spec
+    of the parameter.
+    ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
+    provided here to allow treatment as a Splittable ``DoFn``. The restriction
+    tracker will be derived from the restriction provider in the parameter.
     ``DoFn.WatermarkReporterParam``: a function that can be used to report
     output watermark of Splittable ``DoFn`` implementations.
 
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index dc9d74b..7564b49 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -178,6 +178,7 @@ ignore_identifiers = [
   '_StateDoFnParam',
   '_TimerDoFnParam',
   '_BundleFinalizerParam',
+  '_RestrictionDoFnParam',
 
   # Sphinx cannot find this py:class reference target
   'typing.Generic',