You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2017/01/30 23:03:10 UTC

[03/50] [abbrv] beam git commit: Implement Annotation based NewDoFn in python SDK

Implement Annotation based NewDoFn in python SDK


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/9e272ecf
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/9e272ecf
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/9e272ecf

Branch: refs/heads/master
Commit: 9e272ecf639b7b13f23a83868fd101a437159c1c
Parents: 946135f
Author: Sourabh Bajaj <so...@google.com>
Authored: Fri Jan 20 17:17:25 2017 -0800
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Sat Jan 21 20:37:07 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/pipeline_test.py        | 100 ++++++++-
 sdks/python/apache_beam/runners/common.pxd      |   4 +
 sdks/python/apache_beam/runners/common.py       | 221 +++++++++++++------
 .../runners/direct/transform_evaluator.py       |  15 +-
 sdks/python/apache_beam/transforms/core.py      | 113 +++++++++-
 sdks/python/apache_beam/typehints/decorators.py |   2 +-
 sdks/python/apache_beam/typehints/typecheck.py  | 145 ++++++++++++
 7 files changed, 531 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/pipeline_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index 336bf54..93b68d1 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -24,15 +24,23 @@ import unittest
 from apache_beam.pipeline import Pipeline
 from apache_beam.pipeline import PipelineOptions
 from apache_beam.pipeline import PipelineVisitor
+from apache_beam.pvalue import AsSingleton
 from apache_beam.runners.dataflow.native_io.iobase import NativeSource
 from apache_beam.test_pipeline import TestPipeline
 from apache_beam.transforms import CombineGlobally
 from apache_beam.transforms import Create
 from apache_beam.transforms import FlatMap
 from apache_beam.transforms import Map
+from apache_beam.transforms import NewDoFn
+from apache_beam.transforms import ParDo
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import Read
-from apache_beam.transforms.util import assert_that, equal_to
+from apache_beam.transforms import WindowInto
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import equal_to
+from apache_beam.transforms.window import IntervalWindow
+from apache_beam.transforms.window import WindowFn
+from apache_beam.utils.timestamp import MIN_TIMESTAMP
 
 
 class FakeSource(NativeSource):
@@ -241,6 +249,96 @@ class PipelineTest(unittest.TestCase):
     self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x))
 
 
+class NewDoFnTest(unittest.TestCase):
+
+  def setUp(self):
+    self.runner_name = 'DirectRunner'
+
+  def test_element(self):
+    class TestDoFn(NewDoFn):
+      def process(self, element):
+        yield element + 10
+
+    pipeline = TestPipeline(runner=self.runner_name)
+    pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
+    assert_that(pcoll, equal_to([11, 12]))
+    pipeline.run()
+
+  def test_context_param(self):
+    class TestDoFn(NewDoFn):
+      def process(self, element, context=NewDoFn.ContextParam):
+        yield context.element + 10
+
+    pipeline = TestPipeline(runner=self.runner_name)
+    pcoll = pipeline | 'Create' >> Create([1, 2])| 'Do' >> ParDo(TestDoFn())
+    assert_that(pcoll, equal_to([11, 12]))
+    pipeline.run()
+
+  def test_side_input_no_tag(self):
+    class TestDoFn(NewDoFn):
+      def process(self, element, prefix, suffix):
+        return ['%s-%s-%s' % (prefix, element, suffix)]
+
+    pipeline = TestPipeline()
+    words_list = ['aa', 'bb', 'cc']
+    words = pipeline | 'SomeWords' >> Create(words_list)
+    prefix = 'zyx'
+    suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side in
+    result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
+        TestDoFn(), prefix, suffix=AsSingleton(suffix))
+    assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
+    pipeline.run()
+
+  def test_side_input_tagged(self):
+    class TestDoFn(NewDoFn):
+      def process(self, element, prefix, suffix=NewDoFn.SideInputParam):
+        return ['%s-%s-%s' % (prefix, element, suffix)]
+
+    pipeline = TestPipeline()
+    words_list = ['aa', 'bb', 'cc']
+    words = pipeline | 'SomeWords' >> Create(words_list)
+    prefix = 'zyx'
+    suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side in
+    result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
+        TestDoFn(), prefix, suffix=AsSingleton(suffix))
+    assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
+    pipeline.run()
+
+  def test_window_param(self):
+    class TestDoFn(NewDoFn):
+      def process(self, element, window=NewDoFn.WindowParam):
+        yield (float(window.start), float(window.end))
+
+    class TestWindowFn(WindowFn):
+      """Windowing function adding two disjoint windows to each element."""
+
+      def assign(self, assign_context):
+        _ = assign_context
+        return [IntervalWindow(10, 20), IntervalWindow(20, 30)]
+
+      def merge(self, existing_windows):
+        return existing_windows
+
+    pipeline = TestPipeline(runner=self.runner_name)
+    pcoll = (pipeline
+             | 'KVs' >> Create([(1, 10), (2, 20)])
+             | 'W' >> WindowInto(windowfn=TestWindowFn())
+             | 'Do' >> ParDo(TestDoFn()))
+    assert_that(pcoll, equal_to([(10.0, 20.0), (10.0, 20.0),
+                                 (20.0, 30.0), (20.0, 30.0)]))
+    pipeline.run()
+
+  def test_timestamp_param(self):
+    class TestDoFn(NewDoFn):
+      def process(self, element, timestamp=NewDoFn.TimestampParam):
+        yield timestamp
+
+    pipeline = TestPipeline(runner=self.runner_name)
+    pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
+    assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
+    pipeline.run()
+
+
 class Bacon(PipelineOptions):
 
   @classmethod

http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/common.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 085fd11..06fe434 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -36,6 +36,10 @@ cdef class DoFnRunner(Receiver):
   cdef object tagged_receivers
   cdef LoggingContext logging_context
   cdef object step_name
+  cdef object is_new_dofn
+  cdef object args
+  cdef object kwargs
+  cdef object side_inputs
   cdef bint has_windowed_side_inputs
 
   cdef Receiver main_receivers

http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/common.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index cc834ba..0f63cbc 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -71,50 +71,21 @@ class DoFnRunner(Receiver):
                # Preferred alternative to context
                # TODO(robertwb): Remove once all runners are updated.
                state=None):
-    self.has_windowed_side_inputs = False  # Set to True in one case below.
-    if not args and not kwargs:
-      self.dofn = fn
-      self.dofn_process = fn.process
-    else:
-      global_window = window.GlobalWindow()
-      # TODO(robertwb): Remove when all runners pass side input maps.
-      side_inputs = [side_input
-                     if isinstance(side_input, sideinputs.SideInputMap)
-                     else {global_window: side_input}
-                     for side_input in side_inputs]
-      if side_inputs and all(
-          isinstance(side_input, dict) or side_input.is_globally_windowed()
-          for side_input in side_inputs):
-        args, kwargs = util.insert_values_in_args(
-            args, kwargs, [side_input[global_window]
-                           for side_input in side_inputs])
-        side_inputs = []
-      if side_inputs:
-        self.has_windowed_side_inputs = True
-
-        def process(context):
-          w = context.windows[0]
-          cur_args, cur_kwargs = util.insert_values_in_args(
-              args, kwargs, [side_input[w] for side_input in side_inputs])
-          return fn.process(context, *cur_args, **cur_kwargs)
-        self.dofn_process = process
-      elif kwargs:
-        self.dofn_process = lambda context: fn.process(context, *args, **kwargs)
-      else:
-        self.dofn_process = lambda context: fn.process(context, *args)
-
-      class CurriedFn(core.DoFn):
+    self.step_name = step_name
+    self.window_fn = windowing.windowfn
+    self.tagged_receivers = tagged_receivers
 
-        start_bundle = staticmethod(fn.start_bundle)
-        process = staticmethod(self.dofn_process)
-        finish_bundle = staticmethod(fn.finish_bundle)
+    global_window = window.GlobalWindow()
 
-      self.dofn = CurriedFn()
+    if logging_context:
+      self.logging_context = logging_context
+    else:
+      self.logging_context = get_logging_context(logger, step_name=step_name)
 
-    self.window_fn = windowing.windowfn
-    self.tagged_receivers = tagged_receivers
-    self.step_name = step_name
+    # Optimize for the common case.
+    self.main_receivers = as_receiver(tagged_receivers[None])
 
+    # TODO(sourabh): Deprecate the use of context
     if state:
       assert context is None
       self.context = DoFnContext(self.step_name, state=state)
@@ -122,48 +93,172 @@ class DoFnRunner(Receiver):
       assert context is not None
       self.context = context
 
-    if logging_context:
-      self.logging_context = logging_context
+    # TODO(Sourabhbajaj): Remove the usage of OldDoFn
+    if isinstance(fn, core.NewDoFn):
+      self.is_new_dofn = True
+
+      # SideInputs
+      self.side_inputs = [side_input
+                          if isinstance(side_input, sideinputs.SideInputMap)
+                          else {global_window: side_input}
+                          for side_input in side_inputs]
+      self.has_windowed_side_inputs = not all(
+          isinstance(si, dict) or si.is_globally_windowed()
+          for si in self.side_inputs)
+
+      self.args = args if args else []
+      self.kwargs = kwargs if kwargs else {}
+      self.dofn = fn
+
     else:
-      self.logging_context = get_logging_context(logger, step_name=step_name)
+      self.is_new_dofn = False
+      self.has_windowed_side_inputs = False  # Set to True in one case below.
+      if not args and not kwargs:
+        self.dofn = fn
+        self.dofn_process = fn.process
+      else:
+        # TODO(robertwb): Remove when all runners pass side input maps.
+        side_inputs = [side_input
+                       if isinstance(side_input, sideinputs.SideInputMap)
+                       else {global_window: side_input}
+                       for side_input in side_inputs]
+        if side_inputs and all(
+            isinstance(side_input, dict) or side_input.is_globally_windowed()
+            for side_input in side_inputs):
+          args, kwargs = util.insert_values_in_args(
+              args, kwargs, [side_input[global_window]
+                             for side_input in side_inputs])
+          side_inputs = []
+        if side_inputs:
+          self.has_windowed_side_inputs = True
+
+          def process(context):
+            w = context.windows[0]
+            cur_args, cur_kwargs = util.insert_values_in_args(
+                args, kwargs, [side_input[w] for side_input in side_inputs])
+            return fn.process(context, *cur_args, **cur_kwargs)
+          self.dofn_process = process
+        elif kwargs:
+          self.dofn_process = lambda context: fn.process(
+              context, *args, **kwargs)
+        else:
+          self.dofn_process = lambda context: fn.process(context, *args)
 
-    # Optimize for the common case.
-    self.main_receivers = as_receiver(tagged_receivers[None])
+        class CurriedFn(core.DoFn):
+
+          start_bundle = staticmethod(fn.start_bundle)
+          process = staticmethod(self.dofn_process)
+          finish_bundle = staticmethod(fn.finish_bundle)
+
+        self.dofn = CurriedFn()
 
   def receive(self, windowed_value):
     self.process(windowed_value)
 
-  def start(self):
-    self.context.set_element(None)
+  def old_dofn_process(self, element):
+    if self.has_windowed_side_inputs and len(element.windows) > 1:
+      for w in element.windows:
+        self.context.set_element(
+            WindowedValue(element.value, element.timestamp, (w,)))
+        self._process_outputs(element, self.dofn_process(self.context))
+    else:
+      self.context.set_element(element)
+      self._process_outputs(element, self.dofn_process(self.context))
+
+  def new_dofn_process(self, element):
+    self.context.set_element(element)
+    arguments, _, _, defaults = self.dofn.get_function_arguments('process')
+    defaults = defaults if defaults else []
+
+    self_in_args = int(self.dofn.is_process_bounded())
+
+    # Call the process function once per window if there are windowed side
+    # inputs or if process accesses the window parameter. Otherwise we can
+    # call it just once, as none of the arguments change between windows.
+    if self.has_windowed_side_inputs or core.NewDoFn.WindowParam in defaults:
+      windows = element.windows
+    else:
+      windows = [window.GlobalWindow()]
+
+    for w in windows:
+      args, kwargs = util.insert_values_in_args(
+          self.args, self.kwargs,
+          [s[w] for s in self.side_inputs])
+
+      # If there are more arguments than defaults, then the first argument
+      # should be the element and the rest should be picked from the side
+      # inputs, as window and timestamp should always be tagged via defaults.
+      if len(arguments) > len(defaults) + self_in_args:
+        if core.NewDoFn.ElementParam not in defaults:
+          args_to_pick = len(arguments) - len(defaults) - 1 - self_in_args
+          final_args = [element.value] + args[:args_to_pick]
+        else:
+          args_to_pick = len(arguments) - len(defaults) - self_in_args
+          final_args = args[:args_to_pick]
+      else:
+        args_to_pick = 0
+        final_args = []
+      args = iter(args[args_to_pick:])
+
+      for a, d in zip(arguments[-len(defaults):], defaults):
+        if d == core.NewDoFn.ElementParam:
+          final_args.append(element.value)
+        elif d == core.NewDoFn.ContextParam:
+          final_args.append(self.context)
+        elif d == core.NewDoFn.WindowParam:
+          final_args.append(w)
+        elif d == core.NewDoFn.TimestampParam:
+          final_args.append(element.timestamp)
+        elif d == core.NewDoFn.SideInputParam:
+          # If no more args are present then the value must be passed via kwarg
+          try:
+            final_args.append(args.next())
+          except StopIteration:
+            if a not in kwargs:
+              raise
+        else:
+          # If no more args are present then the value must be passed via kwarg
+          try:
+            final_args.append(args.next())
+          except StopIteration:
+            if a not in kwargs:
+              kwargs[a] = d
+      final_args.extend(list(args))
+      self._process_outputs(element, self.dofn.process(*final_args, **kwargs))
+
+  def _invoke_bundle_method(self, method):
     try:
       self.logging_context.enter()
-      self._process_outputs(None, self.dofn.start_bundle(self.context))
+      self.context.set_element(None)
+      f = getattr(self.dofn, method)
+
+      # TODO(Sourabhbajaj): Remove this if-else
+      if self.is_new_dofn:
+        _, _, _, defaults = self.dofn.get_function_arguments(method)
+        defaults = defaults if defaults else []
+        args = [self.context if d == core.NewDoFn.ContextParam else d
+                for d in defaults]
+        self._process_outputs(None, f(*args))
+      else:
+        self._process_outputs(None, f(self.context))
     except BaseException as exn:
       self.reraise_augmented(exn)
     finally:
       self.logging_context.exit()
 
+  def start(self):
+    self._invoke_bundle_method('start_bundle')
+
   def finish(self):
-    self.context.set_element(None)
-    try:
-      self.logging_context.enter()
-      self._process_outputs(None, self.dofn.finish_bundle(self.context))
-    except BaseException as exn:
-      self.reraise_augmented(exn)
-    finally:
-      self.logging_context.exit()
+    self._invoke_bundle_method('finish_bundle')
 
   def process(self, element):
     try:
       self.logging_context.enter()
-      if self.has_windowed_side_inputs and len(element.windows) > 1:
-        for w in element.windows:
-          self.context.set_element(
-              WindowedValue(element.value, element.timestamp, (w,)))
-          self._process_outputs(element, self.dofn_process(self.context))
+      if self.is_new_dofn:
+        self.new_dofn_process(element)
       else:
-        self.context.set_element(element)
-        self._process_outputs(element, self.dofn_process(self.context))
+        self.old_dofn_process(element)
     except BaseException as exn:
       self.reraise_augmented(exn)
     finally:

http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/runners/direct/transform_evaluator.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/direct/transform_evaluator.py b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
index b4c43ba..ec2b3a1 100644
--- a/sdks/python/apache_beam/runners/direct/transform_evaluator.py
+++ b/sdks/python/apache_beam/runners/direct/transform_evaluator.py
@@ -35,8 +35,10 @@ from apache_beam.transforms import sideinputs
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.transforms.window import WindowedValue
 from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn
+from apache_beam.typehints.typecheck import OutputCheckWrapperNewDoFn
 from apache_beam.typehints.typecheck import TypeCheckError
 from apache_beam.typehints.typecheck import TypeCheckWrapperDoFn
+from apache_beam.typehints.typecheck import TypeCheckWrapperNewDoFn
 from apache_beam.utils import counters
 from apache_beam.utils.pipeline_options import TypeOptions
 
@@ -344,9 +346,18 @@ class _ParDoEvaluator(_TransformEvaluator):
     pipeline_options = self._evaluation_context.pipeline_options
     if (pipeline_options is not None
         and pipeline_options.view_as(TypeOptions).runtime_type_check):
-      dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())
+      # TODO(sourabhbajaj): Remove this if-else
+      if isinstance(dofn, core.NewDoFn):
+        dofn = TypeCheckWrapperNewDoFn(dofn, transform.get_type_hints())
+      else:
+        dofn = TypeCheckWrapperDoFn(dofn, transform.get_type_hints())
 
-    dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
+    # TODO(sourabhbajaj): Remove this if-else
+    if isinstance(dofn, core.NewDoFn):
+      dofn = OutputCheckWrapperNewDoFn(
+          dofn, self._applied_ptransform.full_label)
+    else:
+      dofn = OutputCheckWrapperDoFn(dofn, self._applied_ptransform.full_label)
     self.runner = DoFnRunner(dofn, transform.args, transform.kwargs,
                              self._side_inputs,
                              self._applied_ptransform.inputs[0].windowing,

http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/transforms/core.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 72f7cd4..70a03ae 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -109,6 +109,7 @@ class DoFnProcessContext(DoFnContext):
       self.timestamp = windowed_value.timestamp
       self.windows = windowed_value.windows
 
+  # TODO(sourabhbajaj): Move as we're trying to deprecate the use of context
   def aggregate_to(self, aggregator, input_value):
     """Provide a new input value for the aggregator.
 
@@ -119,6 +120,112 @@ class DoFnProcessContext(DoFnContext):
     self.state.counter_for(aggregator).update(input_value)
 
 
+class NewDoFn(WithTypeHints, HasDisplayData):
+  """A function object used by a transform with custom processing.
+
+  The ParDo transform is such a transform. The ParDo.apply
+  method will take an object of type DoFn and apply it to all elements of a
+  PCollection object.
+
+  In order to have concrete DoFn objects one has to subclass from DoFn and
+  define the desired behavior (start_bundle/finish_bundle and process) or wrap a
+  callable object using the CallableWrapperDoFn class.
+  """
+
+  ElementParam = 'ElementParam'
+  ContextParam = 'ContextParam'
+  SideInputParam = 'SideInputParam'
+  TimestampParam = 'TimestampParam'
+  WindowParam = 'WindowParam'
+
+  @staticmethod
+  def from_callable(fn):
+    return CallableWrapperDoFn(fn)
+
+  def default_label(self):
+    return self.__class__.__name__
+
+  def process(self, element, *args, **kwargs):
+    """Called for each element of a pipeline. The default arguments are needed
+    for the DoFnRunner to be able to pass the parameters correctly.
+
+    Args:
+      element: The element to be processed
+      context: a DoFnProcessContext object. See the
+        DoFnProcessContext documentation for details.
+      *args: side inputs
+      **kwargs: keyword side inputs
+    """
+    raise NotImplementedError
+
+  def start_bundle(self):
+    """Called before a bundle of elements is processed on a worker.
+
+    Elements to be processed are split into bundles and distributed
+    to workers. Before a worker calls process() on the first element
+    of its bundle, it calls this method.
+    """
+    pass
+
+  def finish_bundle(self):
+    """Called after a bundle of elements is processed on a worker.
+    """
+    pass
+
+  def get_function_arguments(self, func):
+    """Return the function arguments based on the name provided. If they have
+    a method named _inspect_<func> attached to the class then use that,
+    otherwise default to the python inspect library.
+    """
+    func_name = '_inspect_%s' % func
+    if hasattr(self, func_name):
+      f = getattr(self, func_name)
+      return f()
+    else:
+      f = getattr(self, func)
+      return inspect.getargspec(f)
+
+  # TODO(sourabhbajaj): Do we want to remove the responsibility of these from
+  # the DoFn or maybe the runner
+  def infer_output_type(self, input_type):
+    # TODO(robertwb): Side inputs types.
+    # TODO(robertwb): Assert compatibility with input type hint?
+    return self._strip_output_annotations(
+        trivial_inference.infer_return_type(self.process, [input_type]))
+
+  def _strip_output_annotations(self, type_hint):
+    annotations = (window.TimestampedValue, window.WindowedValue,
+                   pvalue.SideOutputValue)
+    # TODO(robertwb): These should be parameterized types that the
+    # type inferencer understands.
+    if (type_hint in annotations
+        or trivial_inference.element_type(type_hint) in annotations):
+      return Any
+    else:
+      return type_hint
+
+  def process_argspec_fn(self):
+    """Returns the Python callable that will eventually be invoked.
+
+    This should ideally be the user-level function that is called with
+    the main and (if any) side inputs, and is used to relate the type
+    hint parameters with the input parameters (e.g., by argument name).
+    """
+    return self.process
+
+  def is_process_bounded(self):
+    """Checks if an object is a bound method on an instance."""
+    if not isinstance(self.process, types.MethodType):
+      return False # Not a method
+    if self.process.im_self is None:
+      return False # Method is not bound
+    if issubclass(self.process.im_class, type) or \
+        self.process.im_class is types.ClassType:
+      return False # Method is a classmethod
+    return True
+
+
+# TODO(Sourabh): Remove after migration to NewDoFn
 class DoFn(WithTypeHints, HasDisplayData):
   """A function object used by a transform with custom processing.
 
@@ -577,7 +684,7 @@ class ParDo(PTransformWithSideInputs):
   def __init__(self, fn_or_label, *args, **kwargs):
     super(ParDo, self).__init__(fn_or_label, *args, **kwargs)
 
-    if not isinstance(self.fn, DoFn):
+    if not isinstance(self.fn, (DoFn, NewDoFn)):
       raise TypeError('ParDo must be called with a DoFn instance.')
 
   def default_type_hints(self):
@@ -588,7 +695,9 @@ class ParDo(PTransformWithSideInputs):
         self.fn.infer_output_type(input_type))
 
   def make_fn(self, fn):
-    return fn if isinstance(fn, DoFn) else CallableWrapperDoFn(fn)
+    if isinstance(fn, (DoFn, NewDoFn)):
+      return fn
+    return CallableWrapperDoFn(fn)
 
   def process_argspec_fn(self):
     return self.fn.process_argspec_fn()

http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/typehints/decorators.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py
index a300a3b..df15f1b 100644
--- a/sdks/python/apache_beam/typehints/decorators.py
+++ b/sdks/python/apache_beam/typehints/decorators.py
@@ -263,7 +263,7 @@ def getcallargs_forhints(func, *typeargs, **typekwargs):
     for k, var in enumerate(reversed(argspec.args)):
       if k >= len(argspec.defaults):
         break
-      if callargs.get(var, None) is argspec.defaults[-k]:
+      if callargs.get(var, None) is argspec.defaults[-k-1]:
         callargs[var] = typehints.Any
   # Patch up varargs and keywords
   if argspec.varargs:

http://git-wip-us.apache.org/repos/asf/beam/blob/9e272ecf/sdks/python/apache_beam/typehints/typecheck.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/typehints/typecheck.py b/sdks/python/apache_beam/typehints/typecheck.py
index d759d55..7a10a5a 100644
--- a/sdks/python/apache_beam/typehints/typecheck.py
+++ b/sdks/python/apache_beam/typehints/typecheck.py
@@ -24,6 +24,7 @@ import types
 
 from apache_beam.pvalue import SideOutputValue
 from apache_beam.transforms.core import DoFn
+from apache_beam.transforms.core import NewDoFn
 from apache_beam.transforms.window import WindowedValue
 from apache_beam.typehints import check_constraint
 from apache_beam.typehints import CompositeTypeHintError
@@ -162,3 +163,147 @@ class OutputCheckWrapperDoFn(DoFn):
                            'iterable. %s was returned instead.'
                            % type(output))
     return output
+
+
+class AbstractDoFnWrapper(NewDoFn):
+  """An abstract class to create wrapper around NewDoFn"""
+
+  def __init__(self, dofn):
+    super(AbstractDoFnWrapper, self).__init__()
+    self.dofn = dofn
+
+  def _inspect_start_bundle(self):
+    return self.dofn.get_function_arguments('start_bundle')
+
+  def _inspect_process(self):
+    return self.dofn.get_function_arguments('process')
+
+  def _inspect_finish_bundle(self):
+    return self.dofn.get_function_arguments('finish_bundle')
+
+  def wrapper(self, method, args, kwargs):
+    return method(*args, **kwargs)
+
+  def start_bundle(self, *args, **kwargs):
+    return self.wrapper(self.dofn.start_bundle, args, kwargs)
+
+  def process(self, *args, **kwargs):
+    return self.wrapper(self.dofn.process, args, kwargs)
+
+  def finish_bundle(self, *args, **kwargs):
+    return self.wrapper(self.dofn.finish_bundle, args, kwargs)
+
+  def is_process_bounded(self):
+    return self.dofn.is_process_bounded()
+
+
+class OutputCheckWrapperNewDoFn(AbstractDoFnWrapper):
+  """A DoFn that verifies against common errors in the output type."""
+
+  def __init__(self, dofn, full_label):
+    super(OutputCheckWrapperNewDoFn, self).__init__(dofn)
+    self.full_label = full_label
+
+  def wrapper(self, method, args, kwargs):
+    try:
+      result = method(*args, **kwargs)
+    except TypeCheckError as e:
+      error_msg = ('Runtime type violation detected within ParDo(%s): '
+                   '%s' % (self.full_label, e))
+      raise TypeCheckError, error_msg, sys.exc_info()[2]
+    else:
+      return self._check_type(result)
+
+  def _check_type(self, output):
+    if output is None:
+      return output
+    elif isinstance(output, (dict, basestring)):
+      object_type = type(output).__name__
+      raise TypeCheckError('Returning a %s from a ParDo or FlatMap is '
+                           'discouraged. Please use list("%s") if you really '
+                           'want this behavior.' %
+                           (object_type, output))
+    elif not isinstance(output, collections.Iterable):
+      raise TypeCheckError('FlatMap and ParDo must return an '
+                           'iterable. %s was returned instead.'
+                           % type(output))
+    return output
+
+
+class TypeCheckWrapperNewDoFn(AbstractDoFnWrapper):
+  """A wrapper around a DoFn which performs type-checking of input and output.
+  """
+
+  def __init__(self, dofn, type_hints, label=None):
+    super(TypeCheckWrapperNewDoFn, self).__init__(dofn)
+    self.dofn = dofn
+    self._process_fn = self.dofn.process_argspec_fn()
+    if type_hints.input_types:
+      input_args, input_kwargs = type_hints.input_types
+      self._input_hints = getcallargs_forhints(
+          self._process_fn, *input_args, **input_kwargs)
+    else:
+      self._input_hints = None
+    # TODO(robertwb): Multi-output.
+    self._output_type_hint = type_hints.simple_output_type(label)
+
+  def wrapper(self, method, args, kwargs):
+    result = method(*args, **kwargs)
+    return self._type_check_result(result)
+
+  def process(self, *args, **kwargs):
+    if self._input_hints:
+      actual_inputs = inspect.getcallargs(self._process_fn, *args, **kwargs)
+      for var, hint in self._input_hints.items():
+        if hint is actual_inputs[var]:
+          # self parameter
+          continue
+        _check_instance_type(hint, actual_inputs[var], var, True)
+    return self._type_check_result(self.dofn.process(*args, **kwargs))
+
+  def _type_check_result(self, transform_results):
+    if self._output_type_hint is None or transform_results is None:
+      return transform_results
+
+    def type_check_output(o):
+      # TODO(robertwb): Multi-output.
+      x = o.value if isinstance(o, (SideOutputValue, WindowedValue)) else o
+      self._type_check(self._output_type_hint, x, is_input=False)
+
+    # If the return type is a generator, then we will need to interleave our
+    # type-checking with its normal iteration so we don't deplete the
+    # generator initially just by type-checking its yielded contents.
+    if isinstance(transform_results, types.GeneratorType):
+      return GeneratorWrapper(transform_results, type_check_output)
+    else:
+      for o in transform_results:
+        type_check_output(o)
+      return transform_results
+
+  def _type_check(self, type_constraint, datum, is_input):
+    """Typecheck a PTransform related datum according to a type constraint.
+
+    This function is used to optionally type-check either an input or an output
+    to a PTransform.
+
+    Args:
+        type_constraint: An instance of a typehints.TypeConstraint, one of the
+          white-listed builtin Python types, or a custom user class.
+        datum: An instance of a Python object.
+        is_input: True if 'datum' is an input to a PTransform's DoFn. False
+          otherwise.
+
+    Raises:
+      TypeError: If 'datum' fails to type-check according to 'type_constraint'.
+    """
+    datum_type = 'input' if is_input else 'output'
+
+    try:
+      check_constraint(type_constraint, datum)
+    except CompositeTypeHintError as e:
+      raise TypeCheckError, e.message, sys.exc_info()[2]
+    except SimpleTypeHintError:
+      error_msg = ("According to type-hint expected %s should be of type %s. "
+                   "Instead, received '%s', an instance of type %s."
+                   % (datum_type, type_constraint, datum, type(datum)))
+      raise TypeCheckError, error_msg, sys.exc_info()[2]