You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/01/22 04:37:28 UTC
[1/2] beam git commit: Implement annotation-based NewDoFn in Python SDK
Repository: beam
Updated Branches:
refs/heads/python-sdk 946135f6a -> d0474ab5b
Implement annotation-based NewDoFn in Python SDK
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9e272ecf
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9e272ecf
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9e272ecf
Branch: refs/heads/python-sdk
Commit: 9e272ecf639b7b13f23a83868fd101a437159c1c
Parents: 946135f
Author: Sourabh Bajaj <so...@google.com>
Authored: Fri Jan 20 17:17:25 2017 -0800
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Sat Jan 21 20:37:07 2017 -0800
----------------------------------------------------------------------
sdks/python/apache_beam/pipeline_test.py | 100 ++++++++-
sdks/python/apache_beam/runners/common.pxd | 4 +
sdks/python/apache_beam/runners/common.py | 221 +++++++++++++------
.../runners/direct/transform_evaluator.py | 15 +-
sdks/python/apache_beam/transforms/core.py | 113 +++++++++-
sdks/python/apache_beam/typehints/decorators.py | 2 +-
sdks/python/apache_beam/typehints/typecheck.py | 145 ++++++++++++
7 files changed, 531 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/pipeline_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 336bf54..93b68d1 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -24,15 +24,23 @@ import unittest
from apache_beam.pipeline import Pipeline
from apache_beam.pipeline import PipelineOptions
from apache_beam.pipeline import PipelineVisitor
+from apache_beam.pvalue import AsSingleton
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.test_pipeline import TestPipeline
from apache_beam.transforms import CombineGlobally
from apache_beam.transforms import Create
from apache_beam.transforms import FlatMap
from apache_beam.transforms import Map
+from apache_beam.transforms import NewDoFn
+from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms import Read
-from apache_beam.transforms.util import assert_that, equal_to
+from apache_beam.transforms import WindowInto
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import equal_to
+from apache_beam.transforms.window import IntervalWindow
+from apache_beam.transforms.window import WindowFn
+from apache_beam.utils.timestamp import MIN_TIMESTAMP
class FakeSource(NativeSource):
@@ -241,6 +249,96 @@ class PipelineTest(unittest.TestCase):
self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x))
+class NewDoFnTest(unittest.TestCase):
+
+ def setUp(self):
+ self.runner_name = 'DirectRunner'
+
+ def test_element(self):
+ class TestDoFn(NewDoFn):
+ def process(self, element):
+ yield element + 10
+
+ pipeline = TestPipeline(runner=self.runner_name)
+ pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
+ assert_that(pcoll, equal_to([11, 12]))
+ pipeline.run()
+
+ def test_context_param(self):
+ class TestDoFn(NewDoFn):
+ def process(self, element, context=NewDoFn.ContextParam):
+ yield context.element + 10
+
+ pipeline = TestPipeline(runner=self.runner_name)
+ pcoll = pipeline | 'Create' >> Create([1, 2])| 'Do' >> ParDo(TestDoFn())
+ assert_that(pcoll, equal_to([11, 12]))
+ pipeline.run()
+
+ def test_side_input_no_tag(self):
+ class TestDoFn(NewDoFn):
+ def process(self, element, prefix, suffix):
+ return ['%s-%s-%s' % (prefix, element, suffix)]
+
+ pipeline = TestPipeline()
+ words_list = ['aa', 'bb', 'cc']
+ words = pipeline | 'SomeWords' >> Create(words_list)
+ prefix = 'zyx'
+ suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in
+ result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
+ TestDoFn(), prefix, suffix=AsSingleton(suffix))
+ assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
+ pipeline.run()
+
+ def test_side_input_tagged(self):
+ class TestDoFn(NewDoFn):
+ def process(self, element, prefix, suffix=NewDoFn.SideInputParam):
+ return ['%s-%s-%s' % (prefix, element, suffix)]
+
+ pipeline = TestPipeline()
+ words_list = ['aa', 'bb', 'cc']
+ words = pipeline | 'SomeWords' >> Create(words_list)
+ prefix = 'zyx'
+ suffix = pipeline | 'SomeString' >> Create(['xyz']) # side in
+ result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
+ TestDoFn(), prefix, suffix=AsSingleton(suffix))
+ assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
+ pipeline.run()
+
+ def test_window_param(self):
+ class TestDoFn(NewDoFn):
+ def process(self, element, window=NewDoFn.WindowParam):
+ yield (float(window.start), float(window.end))
+
+ class TestWindowFn(WindowFn):
+ """Windowing function adding two disjoint windows to each element."""
+
+ def assign(self, assign_context):
+ _ = assign_context
+ return [IntervalWindow(10, 20), IntervalWindow(20, 30)]
+
+ def merge(self, existing_windows):
+ return existing_windows
+
+ pipeline = TestPipeline(runner=self.runner_name)
+ pcoll = (pipeline
+ | 'KVs' >> Create([(1, 10), (2, 20)])
+ | 'W' >> WindowInto(windowfn=TestWindowFn())
+ | 'Do' >> ParDo(TestDoFn()))
+ assert_that(pcoll, equal_to([(10.0, 20.0), (10.0, 20.0),
+ (20.0, 30.0), (20.0, 30.0)]))
+ pipeline.run()
+
+ def test_timestamp_param(self):
+ class TestDoFn(NewDoFn):
+ def process(self, element, timestamp=NewDoFn.TimestampParam):
+ yield timestamp
+
+ pipeline = TestPipeline(runner=self.runner_name)
+ pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
+ assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
+ pipeline.run()
+
+
class Bacon(PipelineOptions):
@classmethod
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/common.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 085fd11..06fe434 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -36,6 +36,10 @@ cdef class DoFnRunner(Receiver):
cdef object tagged_receivers
cdef LoggingContext logging_context
cdef object step_name
+ cdef object is_new_dofn
+ cdef object args
+ cdef object kwargs
+ cdef object side_inputs
cdef bint has_windowed_side_inputs
cdef Receiver main_receivers
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/common.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index cc834ba..0f63cbc 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -71,50 +71,21 @@ class DoFnRunner(Receiver):
# Preferred alternative to context
# TODO(robertwb): Remove once all runners are updated.
state=None):
- self.has_windowed_side_inputs = False # Set to True in one case below.
- if not args and not kwargs:
- self.dofn = fn
- self.dofn_process = fn.process
- else:
- global_window = window.GlobalWindow()
- # TODO(robertwb): Remove when all runners pass side input maps.
- side_inputs = [side_input
- if isinstance(side_input, sideinputs.SideInputMap)
- else {global_window: side_input}
- for side_input in side_inputs]
- if side_inputs and all(
- isinstance(side_input, dict) or side_input.is_globally_windowed()
- for side_input in side_inputs):
- args, kwargs = util.insert_values_in_args(
- args, kwargs, [side_input[global_window]
- for side_input in side_inputs])
- side_inputs = []
- if side_inputs:
- self.has_windowed_side_inputs = True
-
- def process(context):
- w = context.windows[0]
- cur_args, cur_kwargs = util.insert_values_in_args(
- args, kwargs, [side_input[w] for side_input in side_inputs])
- return fn.process(context, *cur_args, **cur_kwargs)
- self.dofn_process = process
- elif kwargs:
- self.dofn_process = lambda context: fn.process(context, *args, **kwargs)
- else:
- self.dofn_process = lambda context: fn.process(context, *args)
-
- class CurriedFn(core.DoFn):
+ self.step_name = step_name
+ self.window_fn = windowing.windowfn
+ self.tagged_receivers = tagged_receivers
- start_bundle = staticmethod(fn.start_bundle)
- process = staticmethod(self.dofn_process)
- finish_bundle = staticmethod(fn.finish_bundle)
+ global_window = window.GlobalWindow()
- self.dofn = CurriedFn()
+ if logging_context:
+ self.logging_context = logging_context
+ else:
+ self.logging_context = get_logging_context(logger, step_name=step_name)
- self.window_fn = windowing.windowfn
- self.tagged_receivers = tagged_receivers
- self.step_name = step_name
+ # Optimize for the common case.
+ self.main_receivers = as_receiver(tagged_receivers[None])
+ # TODO(sourabh): Deprecate the use of context
if state:
assert context is None
self.context = DoFnContext(self.step_name, state=state)
@@ -122,48 +93,172 @@ class DoFnRunner(Receiver):
assert context is not None
self.context = context
- if logging_context:
- self.logging_context = logging_context
+ # TODO(Sourabhbajaj): Remove the usage of OldDoFn
+ if isinstance(fn, core.NewDoFn):
+ self.is_new_dofn = True
+
+ # SideInputs
+ self.side_inputs = [side_input
+ if isinstance(side_input, sideinputs.SideInputMap)
+ else {global_window: side_input}
+ for side_input in side_inputs]
+ self.has_windowed_side_inputs = not all(
+ isinstance(si, dict) or si.is_globally_windowed()
+ for si in self.side_inputs)
+
+ self.args = args if args else []
+ self.kwargs = kwargs if kwargs else {}
+ self.dofn = fn
+
else:
- self.logging_context = get_logging_context(logger, step_name=step_name)
+ self.is_new_dofn = False
+ self.has_windowed_side_inputs = False # Set to True in one case below.
+ if not args and not kwargs:
+ self.dofn = fn
+ self.dofn_process = fn.process
+ else:
+ # TODO(robertwb): Remove when all runners pass side input maps.
+ side_inputs = [side_input
+ if isinstance(side_input, sideinputs.SideInputMap)
+ else {global_window: side_input}
+ for side_input in side_inputs]
+ if side_inputs and all(
+ isinstance(side_input, dict) or side_input.is_globally_windowed()
+ for side_input in side_inputs):
+ args, kwargs = util.insert_values_in_args(
+ args, kwargs, [side_input[global_window]
+ for side_input in side_inputs])
+ side_inputs = []
+ if side_inputs:
+ self.has_windowed_side_inputs = True
+
+ def process(context):
+ w = context.windows[0]
+ cur_args, cur_kwargs = util.insert_values_in_args(
+ args, kwargs, [side_input[w] for side_input in side_inputs])
+ return fn.process(context, *cur_args, **cur_kwargs)
+ self.dofn_process = process
+ elif kwargs:
+ self.dofn_process = lambda context: fn.process(
+ context, *args, **kwargs)
+ else:
+ self.dofn_process = lambda context: fn.process(context, *args)
- # Optimize for the common case.
- self.main_receivers = as_receiver(tagged_receivers[None])
+ class CurriedFn(core.DoFn):
+
+ start_bundle = staticmethod(fn.start_bundle)
+ process = staticmethod(self.dofn_process)
+ finish_bundle = staticmethod(fn.finish_bundle)
+
+ self.dofn = CurriedFn()
def receive(self, windowed_value):
self.process(windowed_value)
- def start(self):
- self.context.set_element(None)
+ def old_dofn_process(self, element):
+ if self.has_windowed_side_inputs and len(element.windows) > 1:
+ for w in element.windows:
+ self.context.set_element(
+ WindowedValue(element.value, element.timestamp, (w,)))
+ self._process_outputs(element, self.dofn_process(self.context))
+ else:
+ self.context.set_element(element)
+ self._process_outputs(element, self.dofn_process(self.context))
+
+ def new_dofn_process(self, element):
+ self.context.set_element(element)
+ arguments, _, _, defaults = self.dofn.get_function_arguments('process')
+ defaults = defaults if defaults else []
+
+ self_in_args = int(self.dofn.is_process_bounded())
+
+ # Call the process function once per window if there are windowed side
+ # inputs or if process accesses the window parameter. Otherwise it is
+ # enough to call it once, since none of the arguments change per window.
+ if self.has_windowed_side_inputs or core.NewDoFn.WindowParam in defaults:
+ windows = element.windows
+ else:
+ windows = [window.GlobalWindow()]
+
+ for w in windows:
+ args, kwargs = util.insert_values_in_args(
+ self.args, self.kwargs,
+ [s[w] for s in self.side_inputs])
+
+ # If there are more arguments than defaults, the first positional argument
+ # is the element (unless it is tagged with ElementParam) and the rest are
+ # picked from the side inputs; window and timestamp must always be tagged.
+ if len(arguments) > len(defaults) + self_in_args:
+ if core.NewDoFn.ElementParam not in defaults:
+ args_to_pick = len(arguments) - len(defaults) - 1 - self_in_args
+ final_args = [element.value] + args[:args_to_pick]
+ else:
+ args_to_pick = len(arguments) - len(defaults) - self_in_args
+ final_args = args[:args_to_pick]
+ else:
+ args_to_pick = 0
+ final_args = []
+ args = iter(args[args_to_pick:])
+
+ for a, d in zip(arguments[-len(defaults):], defaults):
+ if d == core.NewDoFn.ElementParam:
+ final_args.append(element.value)
+ elif d == core.NewDoFn.ContextParam:
+ final_args.append(self.context)
+ elif d == core.NewDoFn.WindowParam:
+ final_args.append(w)
+ elif d == core.NewDoFn.TimestampParam:
+ final_args.append(element.timestamp)
+ elif d == core.NewDoFn.SideInputParam:
+ # If no more args are present then the value must be passed via kwarg
+ try:
+ final_args.append(args.next())
+ except StopIteration:
+ if a not in kwargs:
+ raise
+ else:
+ # If no more args are present then the value must be passed via kwarg
+ try:
+ final_args.append(args.next())
+ except StopIteration:
+ if a not in kwargs:
+ kwargs[a] = d
+ final_args.extend(list(args))
+ self._process_outputs(element, self.dofn.process(*final_args, **kwargs))
+
+ def _invoke_bundle_method(self, method):
try:
self.logging_context.enter()
- self._process_outputs(None, self.dofn.start_bundle(self.context))
+ self.context.set_element(None)
+ f = getattr(self.dofn, method)
+
+ # TODO(Sourabhbajaj): Remove this if-else
+ if self.is_new_dofn:
+ _, _, _, defaults = self.dofn.get_function_arguments(method)
+ defaults = defaults if defaults else []
+ args = [self.context if d == core.NewDoFn.ContextParam else d
+ for d in defaults]
+ self._process_outputs(None, f(*args))
+ else:
+ self._process_outputs(None, f(self.context))
except BaseException as exn:
self.reraise_augmented(exn)
finally:
self.logging_context.exit()
+ def start(self):
+ self._invoke_bundle_method('start_bundle')
+
def finish(self):
- self.context.set_element(None)
- try:
- self.logging_context.enter()
- self._process_outputs(None, self.dofn.finish_bundle(self.context))
- except BaseException as exn:
- self.reraise_augmented(exn)
- finally:
- self.logging_context.exit()
+ self._invoke_bundle_method('finish_bundle')
def process(self, element):
try:
self.logging_context.enter()
- if self.has_windowed_side_inputs and len(element.windows) > 1:
- for w in element.windows:
- self.context.set_element(
- WindowedValue(element.value, element.timestamp, (w,)))
- self._process_outputs(element, self.dofn_process(self.context))
+ if self.is_new_dofn:
+ self.new_dofn_process(element)
else:
- self.context.set_element(element)
- self._process_outputs(element, self.dofn_process(self.context))
+ self.old_dofn_process(element)
except BaseException as exn:
self.reraise_augmented(exn)
finally:
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/direct/transform_evaluator.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index b4c43ba..ec2b3a1 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -35,8 +35,10 @@ from apache_beam.transforms import sideinputs
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import WindowedValue
from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn
+from apache_beam.typehints.typecheck import OutputCheckWrapperNewDoFn
from apache_beam.typehints.typecheck import TypeCheckError
from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn
+from apache_beam.typehints.typecheck import TypeCheckWrapperNewDoFn
from apache_beam.utils import counters
from apache_beam.utils.pipeline_options import TypeOptions
@@ -344,9 +346,18 @@ class _ParDoEvaluator(_TransformEvaluator):
pipeline_options = self._evaluation_context.pipeline_options
if (pipeline_options is not None
and pipeline_options.view_as(TypeOptions).runtime_type_check):
- dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())
+ # TODO(sourabhbajaj): Remove this if-else
+ if isinstance(dofn, core.NewDoFn):
+ dofn = TypeCheckWrapperNewDoFn(dofn, transform.get_type_hints())
+ else:
+ dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())
- dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
+ # TODO(sourabhbajaj): Remove this if-else
+ if isinstance(dofn, core.NewDoFn):
+ dofn = OutputCheckWrapperNewDoFn(
+ dofn, self._applied_ptransform.full_label)
+ else:
+ dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
self.runner = DoFnRunner(dofn, transform.args, transform.kwargs,
self._side_inputs,
self._applied_ptransform.inputs[0].windowing,
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/transforms/core.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 72f7cd4..70a03ae 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -109,6 +109,7 @@ class DoFnProcessContext(DoFnContext):
self.timestamp = windowed_value.timestamp
self.windows = windowed_value.windows
+ # TODO(sourabhbajaj): Move as we're trying to deprecate the use of context
def aggregate_to(self, aggregator, input_value):
"""Provide a new input value for the aggregator.
@@ -119,6 +120,112 @@ class DoFnProcessContext(DoFnContext):
self.state.counter_for(aggregator).update(input_value)
+class NewDoFn(WithTypeHints, HasDisplayData):
+ """A function object used by a transform with custom processing.
+
+ The ParDo transform is such a transform. The ParDo.apply
+ method will take an object of type NewDoFn and apply it to all elements of a
+ PCollection object.
+
+ In order to have concrete NewDoFn objects one has to subclass from NewDoFn and
+ define the desired behavior (start_bundle/finish_bundle and process) or wrap a
+ callable object using the CallableWrapperDoFn class.
+ """
+
+ ElementParam = 'ElementParam'
+ ContextParam = 'ContextParam'
+ SideInputParam = 'SideInputParam'
+ TimestampParam = 'TimestampParam'
+ WindowParam = 'WindowParam'
+
+ @staticmethod
+ def from_callable(fn):
+ return CallableWrapperDoFn(fn)
+
+ def default_label(self):
+ return self.__class__.__name__
+
+ def process(self, element, *args, **kwargs):
+ """Called for each element of a pipeline. The default arguments are needed
+ for the DoFnRunner to be able to pass the parameters correctly.
+
+ Args:
+ element: The element to be processed
+ context: a DoFnProcessContext object, passed only if requested via a
+ NewDoFn.ContextParam default. See DoFnProcessContext for details.
+ *args: side inputs
+ **kwargs: keyword side inputs
+ """
+ raise NotImplementedError
+
+ def start_bundle(self):
+ """Called before a bundle of elements is processed on a worker.
+
+ Elements to be processed are split into bundles and distributed
+ to workers. Before a worker calls process() on the first element
+ of its bundle, it calls this method.
+ """
+ pass
+
+ def finish_bundle(self):
+ """Called after a bundle of elements is processed on a worker.
+ """
+ pass
+
+ def get_function_arguments(self, func):
+ """Return the argspec of the method named `func`. If the object defines
+ an ``_inspect_<func>`` method, return its result; otherwise fall back to
+ Python's ``inspect.getargspec`` on the method itself.
+ """
+ func_name = '_inspect_%s' % func
+ if hasattr(self, func_name):
+ f = getattr(self, func_name)
+ return f()
+ else:
+ f = getattr(self, func)
+ return inspect.getargspec(f)
+
+ # TODO(sourabhbajaj): Do we want to remove the responsibility for these from
+ # the DoFn, or maybe from the runner?
+ def infer_output_type(self, input_type):
+ # TODO(robertwb): Side inputs types.
+ # TODO(robertwb): Assert compatibility with input type hint?
+ return self._strip_output_annotations(
+ trivial_inference.infer_return_type(self.process, [input_type]))
+
+ def _strip_output_annotations(self, type_hint):
+ annotations = (window.TimestampedValue, window.WindowedValue,
+ pvalue.SideOutputValue)
+ # TODO(robertwb): These should be parameterized types that the
+ # type inferencer understands.
+ if (type_hint in annotations
+ or trivial_inference.element_type(type_hint) in annotations):
+ return Any
+ else:
+ return type_hint
+
+ def process_argspec_fn(self):
+ """Returns the Python callable that will eventually be invoked.
+
+ This should ideally be the user-level function that is called with
+ the main and (if any) side inputs, and is used to relate the type
+ hint parameters with the input parameters (e.g., by argument name).
+ """
+ return self.process
+
+ def is_process_bounded(self):
+ """Checks whether self.process is a method bound to an instance."""
+ if not isinstance(self.process, types.MethodType):
+ return False # Not a method
+ if self.process.im_self is None:
+ return False # Method is not bound
+ if issubclass(self.process.im_class, type) or \
+ self.process.im_class is types.ClassType:
+ return False # Method is a classmethod
+ return True
+
+
+# TODO(Sourabh): Remove after migration to NewDoFn
class DoFn(WithTypeHints, HasDisplayData):
"""A function object used by a transform with custom processing.
@@ -577,7 +684,7 @@ class ParDo(PTransformWithSideInputs):
def __init__(self, fn_or_label, *args, **kwargs):
super(ParDo, self).__init__(fn_or_label, *args, **kwargs)
- if not isinstance(self.fn, DoFn):
+ if not isinstance(self.fn, (DoFn, NewDoFn)):
raise TypeError('ParDo must be called with a DoFn instance.')
def default_type_hints(self):
@@ -588,7 +695,9 @@ class ParDo(PTransformWithSideInputs):
self.fn.infer_output_type(input_type))
def make_fn(self, fn):
- return fn if isinstance(fn, DoFn) else CallableWrapperDoFn(fn)
+ if isinstance(fn, (DoFn, NewDoFn)):
+ return fn
+ return CallableWrapperDoFn(fn)
def process_argspec_fn(self):
return self.fn.process_argspec_fn()
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/typehints/decorators.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index a300a3b..df15f1b 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -263,7 +263,7 @@ def getcallargs_forhints(func, *typeargs, **typekwargs):
for k, var in enumerate(reversed(argspec.args)):
if k >= len(argspec.defaults):
break
- if callargs.get(var, None) is argspec.defaults[-k]:
+ if callargs.get(var, None) is argspec.defaults[-k-1]:
callargs[var] = typehints.Any
# Patch up varargs and keywords
if argspec.varargs:
http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/typehints/typecheck.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/typehints/typecheck.py b/sdks/python/apache_beam/typehints/typecheck.py
index d759d55..7a10a5a 100644
--- a/sdks/python/apache_beam/typehints/typecheck.py
+++ b/sdks/python/apache_beam/typehints/typecheck.py
@@ -24,6 +24,7 @@ import types
from apache_beam.pvalue import SideOutputValue
from apache_beam.transforms.core import DoFn
+from apache_beam.transforms.core import NewDoFn
from apache_beam.transforms.window import WindowedValue
from apache_beam.typehints import check_constraint
from apache_beam.typehints import CompositeTypeHintError
@@ -162,3 +163,147 @@ class OutputCheckWrapperDoFn(DoFn):
'iterable. %s was returned instead.'
% type(output))
return output
+
+
+class AbstractDoFnWrapper(NewDoFn):
+ """An abstract base class for creating wrappers around a NewDoFn."""
+
+ def __init__(self, dofn):
+ super(AbstractDoFnWrapper, self).__init__()
+ self.dofn = dofn
+
+ def _inspect_start_bundle(self):
+ return self.dofn.get_function_arguments('start_bundle')
+
+ def _inspect_process(self):
+ return self.dofn.get_function_arguments('process')
+
+ def _inspect_finish_bundle(self):
+ return self.dofn.get_function_arguments('finish_bundle')
+
+ def wrapper(self, method, args, kwargs):
+ return method(*args, **kwargs)
+
+ def start_bundle(self, *args, **kwargs):
+ return self.wrapper(self.dofn.start_bundle, args, kwargs)
+
+ def process(self, *args, **kwargs):
+ return self.wrapper(self.dofn.process, args, kwargs)
+
+ def finish_bundle(self, *args, **kwargs):
+ return self.wrapper(self.dofn.finish_bundle, args, kwargs)
+
+ def is_process_bounded(self):
+ return self.dofn.is_process_bounded()
+
+
+class OutputCheckWrapperNewDoFn(AbstractDoFnWrapper):
+ """A DoFn that verifies against common errors in the output type."""
+
+ def __init__(self, dofn, full_label):
+ super(OutputCheckWrapperNewDoFn, self).__init__(dofn)
+ self.full_label = full_label
+
+ def wrapper(self, method, args, kwargs):
+ try:
+ result = method(*args, **kwargs)
+ except TypeCheckError as e:
+ error_msg = ('Runtime type violation detected within ParDo(%s): '
+ '%s' % (self.full_label, e))
+ raise TypeCheckError, error_msg, sys.exc_info()[2]
+ else:
+ return self._check_type(result)
+
+ def _check_type(self, output):
+ if output is None:
+ return output
+ elif isinstance(output, (dict, basestring)):
+ object_type = type(output).__name__
+ raise TypeCheckError('Returning a %s from a ParDo or FlatMap is '
+ 'discouraged. Please use list("%s") if you really '
+ 'want this behavior.' %
+ (object_type, output))
+ elif not isinstance(output, collections.Iterable):
+ raise TypeCheckError('FlatMap and ParDo must return an '
+ 'iterable. %s was returned instead.'
+ % type(output))
+ return output
+
+
+class TypeCheckWrapperNewDoFn(AbstractDoFnWrapper):
+ """A wrapper around a DoFn which performs type-checking of input and output.
+ """
+
+ def __init__(self, dofn, type_hints, label=None):
+ super(TypeCheckWrapperNewDoFn, self).__init__(dofn)
+ self.dofn = dofn
+ self._process_fn = self.dofn.process_argspec_fn()
+ if type_hints.input_types:
+ input_args, input_kwargs = type_hints.input_types
+ self._input_hints = getcallargs_forhints(
+ self._process_fn, *input_args, **input_kwargs)
+ else:
+ self._input_hints = None
+ # TODO(robertwb): Multi-output.
+ self._output_type_hint = type_hints.simple_output_type(label)
+
+ def wrapper(self, method, args, kwargs):
+ result = method(*args, **kwargs)
+ return self._type_check_result(result)
+
+ def process(self, *args, **kwargs):
+ if self._input_hints:
+ actual_inputs = inspect.getcallargs(self._process_fn, *args, **kwargs)
+ for var, hint in self._input_hints.items():
+ if hint is actual_inputs[var]:
+ # self parameter
+ continue
+ _check_instance_type(hint, actual_inputs[var], var, True)
+ return self._type_check_result(self.dofn.process(*args, **kwargs))
+
+ def _type_check_result(self, transform_results):
+ if self._output_type_hint is None or transform_results is None:
+ return transform_results
+
+ def type_check_output(o):
+ # TODO(robertwb): Multi-output.
+ x = o.value if isinstance(o, (SideOutputValue, WindowedValue)) else o
+ self._type_check(self._output_type_hint, x, is_input=False)
+
+ # If the return type is a generator, then we will need to interleave our
+ # type-checking with its normal iteration so we don't deplete the
+ # generator initially just by type-checking its yielded contents.
+ if isinstance(transform_results, types.GeneratorType):
+ return GeneratorWrapper(transform_results, type_check_output)
+ else:
+ for o in transform_results:
+ type_check_output(o)
+ return transform_results
+
+ def _type_check(self, type_constraint, datum, is_input):
+ """Typecheck a PTransform related datum according to a type constraint.
+
+ This function is used to optionally type-check either an input or an output
+ to a PTransform.
+
+ Args:
+ type_constraint: An instance of a typehints.TypeConstraint, one of the
+ white-listed builtin Python types, or a custom user class.
+ datum: An instance of a Python object.
+ is_input: True if 'datum' is an input to a PTransform's DoFn. False
+ otherwise.
+
+ Raises:
+ TypeError: If 'datum' fails to type-check according to 'type_constraint'.
+ """
+ datum_type = 'input' if is_input else 'output'
+
+ try:
+ check_constraint(type_constraint, datum)
+ except CompositeTypeHintError as e:
+ raise TypeCheckError, e.message, sys.exc_info()[2]
+ except SimpleTypeHintError:
+ error_msg = ("According to type-hint expected %s should be of type %s. "
+ "Instead, received '%s', an instance of type %s."
+ % (datum_type, type_constraint, datum, type(datum)))
+ raise TypeCheckError, error_msg, sys.exc_info()[2]
[2/2] beam git commit: Closes #1805
Posted by ro...@apache.org.
Closes #1805
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/d0474ab5
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/d0474ab5
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/d0474ab5
Branch: refs/heads/python-sdk
Commit: d0474ab5bf01339fd95b0ec6c1db4b226f868d61
Parents: 946135f 9e272ec
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Sat Jan 21 20:37:08 2017 -0800
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Sat Jan 21 20:37:08 2017 -0800
----------------------------------------------------------------------
sdks/python/apache_beam/pipeline_test.py | 100 ++++++++-
sdks/python/apache_beam/runners/common.pxd | 4 +
sdks/python/apache_beam/runners/common.py | 221 +++++++++++++------
.../runners/direct/transform_evaluator.py | 15 +-
sdks/python/apache_beam/transforms/core.py | 113 +++++++++-
sdks/python/apache_beam/typehints/decorators.py | 2 +-
sdks/python/apache_beam/typehints/typecheck.py | 145 ++++++++++++
7 files changed, 531 insertions(+), 69 deletions(-)
----------------------------------------------------------------------