You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2017/08/04 06:45:34 UTC

[2/2] beam git commit: Refactor FnApiRunner to operate directly on the runner API protos.

Refactor FnApiRunner to operate directly on the runner API protos.

This allows for optimization and execution of pipelines in other languages
over the Fn API (modulo aligning URNs and using the runner API for Coders).

The only portions of the pipeline that are deserialized are the Coders.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e71d53e
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e71d53e
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e71d53e

Branch: refs/heads/master
Commit: 5e71d53ef8c28ec14b6a282b1fe67489c2b3f243
Parents: 9e6530a
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Thu Aug 3 09:26:49 2017 -0700
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Thu Aug 3 23:44:40 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/coders/stream.pxd       |   2 +-
 sdks/python/apache_beam/coders/stream.pyx       |   2 +-
 .../apache_beam/runners/pipeline_context.py     |   8 +-
 .../runners/portability/fn_api_runner.py        | 569 ++++++++++++++++++-
 .../runners/portability/fn_api_runner_test.py   |   2 +-
 .../runners/worker/bundle_processor.py          | 107 +++-
 sdks/python/apache_beam/transforms/core.py      |   6 +-
 7 files changed, 678 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/coders/stream.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/coders/stream.pxd b/sdks/python/apache_beam/coders/stream.pxd
index 4e01a89..ade9b72 100644
--- a/sdks/python/apache_beam/coders/stream.pxd
+++ b/sdks/python/apache_beam/coders/stream.pxd
@@ -53,7 +53,7 @@ cdef class InputStream(object):
   cdef bytes all
   cdef char* allc
 
-  cpdef size_t size(self) except? -1
+  cpdef ssize_t size(self) except? -1
   cpdef bytes read(self, size_t len)
   cpdef long read_byte(self) except? -1
   cpdef libc.stdint.int64_t read_var_int64(self) except? -1

http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/coders/stream.pyx
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/coders/stream.pyx b/sdks/python/apache_beam/coders/stream.pyx
index 8d97681..7c9521a 100644
--- a/sdks/python/apache_beam/coders/stream.pyx
+++ b/sdks/python/apache_beam/coders/stream.pyx
@@ -167,7 +167,7 @@ cdef class InputStream(object):
     # unsigned char here.
     return <long>(<unsigned char> self.allc[self.pos - 1])
 
-  cpdef size_t size(self) except? -1:
+  cpdef ssize_t size(self) except? -1:  # Signed: no unsigned wrap.
     return len(self.all) - self.pos
 
   cpdef bytes read_all(self, bint nested=False):

http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/runners/pipeline_context.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py
index f4de42a..42d7f5d 100644
--- a/sdks/python/apache_beam/runners/pipeline_context.py
+++ b/sdks/python/apache_beam/runners/pipeline_context.py
@@ -40,7 +40,7 @@ class _PipelineContextMap(object):
     self._obj_type = obj_type
     self._obj_to_id = {}
     self._id_to_obj = {}
-    self._id_to_proto = proto_map if proto_map else {}
+    self._id_to_proto = dict(proto_map) if proto_map else {}
     self._counter = 0
 
   def _unique_ref(self, obj=None, label=None):
@@ -66,6 +66,12 @@ class _PipelineContextMap(object):
           self._id_to_proto[id], self._pipeline_context)
     return self._id_to_obj[id]
 
+  def __getitem__(self, id):
+    return self.get_by_id(id)  # Same lookup as get_by_id; enables map[id].
+
+  def __contains__(self, id):
+    return id in self._id_to_proto  # Checks protos only; nothing is decoded.
+
 
 class PipelineContext(object):
   """For internal use only; no backwards-compatibility guarantees.

http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index f88fe53..3222bcb 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -19,6 +19,7 @@
 """
 import base64
 import collections
+import copy
 import logging
 import Queue as queue
 import threading
@@ -28,21 +29,26 @@ from google.protobuf import wrappers_pb2
 import grpc
 
 import apache_beam as beam  # pylint: disable=ungrouped-imports
+from apache_beam.coders import registry
 from apache_beam.coders import WindowedValueCoder
 from apache_beam.coders.coder_impl import create_InputStream
 from apache_beam.coders.coder_impl import create_OutputStream
 from apache_beam.internal import pickler
 from apache_beam.io import iobase
-from apache_beam.transforms.window import GlobalWindows
+from apache_beam.metrics.execution import MetricsEnvironment
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners import pipeline_context
 from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.runner import PipelineState
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.runners.worker import data_plane
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import sdk_worker
+from apache_beam.transforms.window import GlobalWindows
 from apache_beam.utils import proto_utils
+from apache_beam.utils import urns
+
 
 # This module is experimental. No backwards-compatibility guarantees.
 
@@ -113,6 +119,30 @@ OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
      beam.transforms.core.Windowing(GlobalWindows())))
 
 
+class _GroupingBuffer(object):
+  """Accumulates grouped (shuffled) results, keyed by encoded key."""
+  def __init__(self, pre_grouped_coder, post_grouped_coder):
+    self._key_coder = pre_grouped_coder.value_coder().key_coder()
+    self._pre_grouped_coder = pre_grouped_coder
+    self._post_grouped_coder = post_grouped_coder
+    self._table = collections.defaultdict(list)  # encoded key -> values
+
+  def append(self, elements_data):
+    input_stream = create_InputStream(elements_data)
+    while input_stream.size() > 0:
+      key, value = self._pre_grouped_coder.get_impl().decode_from_stream(
+          input_stream, True).value  # Only the (key, value) payload is kept.
+      self._table[self._key_coder.encode(key)].append(value)
+
+  def __iter__(self):
+    output_stream = create_OutputStream()
+    for encoded_key, values in self._table.items():
+      key = self._key_coder.decode(encoded_key)
+      self._post_grouped_coder.get_impl().encode_to_stream(
+          GlobalWindows.windowed_value((key, values)), output_stream, True)
+    return iter([output_stream.get()])  # One chunk holding every group.
+
+
 class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
 
   def __init__(self):
@@ -126,6 +156,520 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
     self._last_uid += 1
     return str(self._last_uid)
 
+  def run(self, pipeline):
+    MetricsEnvironment.set_metrics_supported(self.has_metrics_support())
+    if pipeline._verify_runner_api_compatible():
+      return self.run_via_runner_api(pipeline.to_runner_api())
+    else:  # Fall back to the MapTaskExecutorRunner execution path.
+      return super(FnApiRunner, self).run(pipeline)
+
+  def run_via_runner_api(self, pipeline_proto):  # Plan stages, then run them.
+    return self.run_stages(*self.create_stages(pipeline_proto))
+
+  def create_stages(self, pipeline_proto):
+
+    # First define a couple of helpers.
+
+    def union(a, b):
+      # Return an operand as-is when possible to limit distinct sets.
+      if not a or a == b:
+        return b
+      elif not b:
+        return a
+      else:
+        return frozenset.union(a, b)
+
+    class Stage(object):
+      """A set of Transforms that can be sent to the worker for processing."""
+      def __init__(self, name, transforms,
+                   downstream_side_inputs=None, must_follow=frozenset()):
+        self.name, self.transforms = name, transforms
+        # Normalize None so union() and can_fuse() always see a set.
+        self.downstream_side_inputs = downstream_side_inputs or frozenset()
+        self.must_follow = must_follow
+
+      def __repr__(self):
+        must_follow = ', '.join(prev.name for prev in self.must_follow)
+        return "%s\n    %s\n    must follow: %s" % (
+            self.name,
+            '\n'.join(["%s:%s" % (transform.unique_name, transform.spec.urn)
+                       for transform in self.transforms]),
+            must_follow)
+
+      def can_fuse(self, consumer):
+        def no_overlap(a, b):
+          return not a.intersection(b)
+        return (
+            self not in consumer.must_follow
+            and not self.is_flatten() and not consumer.is_flatten()
+            and no_overlap(self.downstream_side_inputs, consumer.side_inputs()))
+
+      def fuse(self, other):
+        return Stage(
+            "(%s)+(%s)" % (self.name, other.name),
+            self.transforms + other.transforms,
+            union(self.downstream_side_inputs, other.downstream_side_inputs),
+            union(self.must_follow, other.must_follow))
+
+      def is_flatten(self):
+        return any(transform.spec.urn == urns.FLATTEN_TRANSFORM
+                   for transform in self.transforms)
+
+      def side_inputs(self):
+        for transform in self.transforms:
+          if transform.spec.urn == urns.PARDO_TRANSFORM:
+            payload = proto_utils.unpack_Any(
+                transform.spec.parameter, beam_runner_api_pb2.ParDoPayload)
+            for side_input in payload.side_inputs:
+              yield transform.inputs[side_input]
+
+      def has_as_main_input(self, pcoll):
+        for transform in self.transforms:
+          if transform.spec.urn == urns.PARDO_TRANSFORM:
+            payload = proto_utils.unpack_Any(
+                transform.spec.parameter, beam_runner_api_pb2.ParDoPayload)
+            local_side_inputs = payload.side_inputs
+          else:
+            local_side_inputs = {}
+          for local_id, pipeline_id in transform.inputs.items():
+            if pcoll == pipeline_id and local_id not in local_side_inputs:
+              return True
+
+      def deduplicate_read(self):
+        seen_pcolls = set()
+        new_transforms = []
+        for transform in self.transforms:
+          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+            pcoll = only_element(transform.outputs.items())[1]
+            if pcoll in seen_pcolls:
+              continue
+            seen_pcolls.add(pcoll)
+          new_transforms.append(transform)
+        self.transforms = new_transforms
+
+    # Now define the "optimization" phases.
+
+    def expand_gbk(stages):
+      """Transforms each GBK into a grouping write followed by a read.
+      """
+      for stage in stages:
+        assert len(stage.transforms) == 1
+        transform = stage.transforms[0]
+        if transform.spec.urn == urns.GROUP_BY_KEY_ONLY_TRANSFORM:
+          # This is used later to correlate the read and write.
+          param = proto_utils.pack_Any(
+              wrappers_pb2.BytesValue(
+                  value=str("group:%s" % stage.name)))
+          gbk_write = Stage(
+              transform.unique_name + '/Write',
+              [beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/Write',
+                  inputs=transform.inputs,
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_OUTPUT_URN,
+                      parameter=param))],
+              downstream_side_inputs=frozenset(),
+              must_follow=stage.must_follow)
+          yield gbk_write
+          # The read replays the buffer filled by the write above.
+          yield Stage(
+              transform.unique_name + '/Read',
+              [beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/Read',
+                  outputs=transform.outputs,
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_INPUT_URN,
+                      parameter=param))],
+              downstream_side_inputs=frozenset(),
+              must_follow=union(frozenset([gbk_write]), stage.must_follow))
+        else:
+          yield stage
+
+    def sink_flattens(stages):
+      """Sink flattens and remove them from the graph.
+
+      A flatten that cannot be sunk/fused away becomes multiple writes (to the
+      same logical sink) followed by a read.
+      """
+      # TODO(robertwb): Actually attempt to sink rather than always materialize.
+      # TODO(robertwb): Possibly fuse this into one of the stages.
+      pcollections = pipeline_components.pcollections
+      for stage in stages:
+        assert len(stage.transforms) == 1
+        transform = stage.transforms[0]
+        if transform.spec.urn == urns.FLATTEN_TRANSFORM:
+          # This is used later to correlate the read and writes.
+          param = proto_utils.pack_Any(
+              wrappers_pb2.BytesValue(
+                  value=str("materialize:%s" % transform.unique_name)))
+          output_pcoll_id, = transform.outputs.values()
+          output_coder_id = pcollections[output_pcoll_id].coder_id
+          flatten_writes = []
+          for local_in, pcoll_in in transform.inputs.items():
+            # NOTE(review): on mismatch pcoll_in's own coder is rewritten too.
+            if pcollections[pcoll_in].coder_id != output_coder_id:
+              # Flatten inputs must all be written with the same coder as is
+              # used to read them.
+              pcollections[pcoll_in].coder_id = output_coder_id
+              transcoded_pcollection = (
+                  transform.unique_name + '/Transcode/' + local_in + '/out')
+              yield Stage(
+                  transform.unique_name + '/Transcode/' + local_in,
+                  [beam_runner_api_pb2.PTransform(
+                      unique_name=
+                      transform.unique_name + '/Transcode/' + local_in,
+                      inputs={local_in: pcoll_in},
+                      outputs={'out': transcoded_pcollection},
+                      spec=beam_runner_api_pb2.FunctionSpec(
+                          urn=bundle_processor.IDENTITY_DOFN_URN))],
+                  downstream_side_inputs=frozenset(),
+                  must_follow=stage.must_follow)
+              pcollections[transcoded_pcollection].CopyFrom(
+                  pcollections[pcoll_in])
+              pcollections[transcoded_pcollection].coder_id = output_coder_id
+            else:
+              transcoded_pcollection = pcoll_in
+            # Each (possibly transcoded) input is written to the shared sink.
+            flatten_write = Stage(
+                transform.unique_name + '/Write/' + local_in,
+                [beam_runner_api_pb2.PTransform(
+                    unique_name=transform.unique_name + '/Write/' + local_in,
+                    inputs={local_in: transcoded_pcollection},
+                    spec=beam_runner_api_pb2.FunctionSpec(
+                        urn=bundle_processor.DATA_OUTPUT_URN,
+                        parameter=param))],
+                downstream_side_inputs=frozenset(),
+                must_follow=stage.must_follow)
+            flatten_writes.append(flatten_write)
+            yield flatten_write
+
+          yield Stage(
+              transform.unique_name + '/Read',
+              [beam_runner_api_pb2.PTransform(
+                  unique_name=transform.unique_name + '/Read',
+                  outputs=transform.outputs,
+                  spec=beam_runner_api_pb2.FunctionSpec(
+                      urn=bundle_processor.DATA_INPUT_URN,
+                      parameter=param))],
+              downstream_side_inputs=frozenset(),
+              must_follow=union(frozenset(flatten_writes), stage.must_follow))
+
+        else:
+          yield stage
+
+    def annotate_downstream_side_inputs(stages):
+      """Annotate each stage with fusion-prohibiting information.
+
+      Each stage is annotated with the (transitive) set of pcollections that
+      depend on this stage and are also used later in the pipeline as a
+      side input.
+
+      While theoretically this could result in O(n^2) annotations, the size of
+      each set is bounded by the number of side inputs (typically much smaller
+      than the number of total nodes) and the number of *distinct* side-input
+      sets is also generally small (and shared due to the use of union
+      defined above).
+
+      This representation is also amenable to simple recomputation on fusion.
+      """
+      consumers = collections.defaultdict(list)
+      all_side_inputs = set()
+      for stage in stages:
+        for transform in stage.transforms:
+          for input_pcoll in transform.inputs.values():
+            consumers[input_pcoll].append(stage)
+        for si in stage.side_inputs():
+          all_side_inputs.add(si)
+      all_side_inputs = frozenset(all_side_inputs)
+
+      downstream_side_inputs_by_stage = {}
+
+      def compute_downstream_side_inputs(stage):
+        if stage not in downstream_side_inputs_by_stage:
+          result = frozenset()
+          for transform in stage.transforms:
+            for output in transform.outputs.values():
+              if output in all_side_inputs:
+                result = union(result, frozenset([output]))
+              # Recurse through every output, not just side inputs.
+              for consumer in consumers[output]:
+                result = union(
+                    result, compute_downstream_side_inputs(consumer))
+          downstream_side_inputs_by_stage[stage] = result
+        return downstream_side_inputs_by_stage[stage]
+
+      for stage in stages:
+        stage.downstream_side_inputs = compute_downstream_side_inputs(stage)
+      return stages
+
+    def greedily_fuse(stages):
+      """Places transforms sharing an edge in the same stage, whenever possible.
+      """
+      producers_by_pcoll = {}
+      consumers_by_pcoll = collections.defaultdict(list)
+
+      # Used to always reference the correct stage as the producer and
+      # consumer maps are not updated when stages are fused away.
+      replacements = {}
+
+      def replacement(s):
+        old_ss = []
+        while s in replacements:
+          old_ss.append(s)
+          s = replacements[s]
+        for old_s in old_ss[:-1]:  # Path compression.
+          replacements[old_s] = s
+        return s
+
+      def fuse(producer, consumer):
+        fused = producer.fuse(consumer)
+        replacements[producer] = fused
+        replacements[consumer] = fused
+
+      # First record the producers and consumers of each PCollection.
+      for stage in stages:
+        for transform in stage.transforms:
+          for input in transform.inputs.values():
+            consumers_by_pcoll[input].append(stage)
+          for output in transform.outputs.values():
+            producers_by_pcoll[output] = stage
+
+      logging.debug('consumers\n%s', consumers_by_pcoll)
+      logging.debug('producers\n%s', producers_by_pcoll)
+
+      # Now try to fuse away all pcollections.
+      for pcoll, producer in producers_by_pcoll.items():
+        pcoll_as_param = proto_utils.pack_Any(
+            wrappers_pb2.BytesValue(
+                value=str("materialize:%s" % pcoll)))
+        write_pcoll = None  # Created on the first consumer that can't fuse.
+        for consumer in consumers_by_pcoll[pcoll]:
+          producer = replacement(producer)
+          consumer = replacement(consumer)
+          # Update consumer.must_follow set, as it's used in can_fuse.
+          consumer.must_follow = set(
+              replacement(s) for s in consumer.must_follow)
+          if producer.can_fuse(consumer):
+            fuse(producer, consumer)
+          else:
+            # If we can't fuse, do a read + write.
+            if write_pcoll is None:
+              write_pcoll = Stage(
+                  pcoll + '/Write',
+                  [beam_runner_api_pb2.PTransform(
+                      unique_name=pcoll + '/Write',
+                      inputs={'in': pcoll},
+                      spec=beam_runner_api_pb2.FunctionSpec(
+                          urn=bundle_processor.DATA_OUTPUT_URN,
+                          parameter=pcoll_as_param))])
+              fuse(producer, write_pcoll)
+            if consumer.has_as_main_input(pcoll):
+              read_pcoll = Stage(
+                  pcoll + '/Read',
+                  [beam_runner_api_pb2.PTransform(
+                      unique_name=pcoll + '/Read',
+                      outputs={'out': pcoll},
+                      spec=beam_runner_api_pb2.FunctionSpec(
+                          urn=bundle_processor.DATA_INPUT_URN,
+                          parameter=pcoll_as_param))],
+                  must_follow={write_pcoll})
+              fuse(read_pcoll, consumer)
+
+      # Everything that was originally a stage or a replacement, but wasn't
+      # replaced, should be in the final graph.
+      final_stages = frozenset(stages).union(replacements.values()).difference(
+          replacements.keys())
+
+      for stage in final_stages:
+        # Update all references to their final values before throwing
+        # the replacement data away.
+        stage.must_follow = frozenset(replacement(s) for s in stage.must_follow)
+        # Two reads of the same stage may have been fused.  This is unneeded.
+        stage.deduplicate_read()
+      return final_stages
+
+    def sort_stages(stages):
+      """Orders stages so each one follows every stage in its must_follow.
+      """
+      seen = set()
+      ordered = []
+
+      def process(stage):
+        if stage not in seen:
+          seen.add(stage)
+          for prev in stage.must_follow:
+            process(prev)
+          ordered.append(stage)
+      for stage in stages:
+        process(stage)
+      return ordered
+
+    # Now actually apply the operations.
+    # Work on a copy, since coder reification and sink_flattens mutate it.
+    pipeline_components = copy.deepcopy(pipeline_proto.components)
+
+    # Reify coders.
+    # TODO(BEAM-2717): Remove once Coders are already in proto.
+    coders = pipeline_context.PipelineContext(pipeline_components).coders
+    for pcoll in pipeline_components.pcollections.values():
+      if pcoll.coder_id not in coders:
+        window_coder = coders[
+            pipeline_components.windowing_strategies[
+                pcoll.windowing_strategy_id].window_coder_id]
+        coder = WindowedValueCoder(
+            registry.get_coder(pickler.loads(pcoll.coder_id)),
+            window_coder=window_coder)
+        pcoll.coder_id = coders.get_id(coder)
+    coders.populate_map(pipeline_components.coders)
+
+    # The initial stages are singleton (leaf) transforms.
+    stages = [
+        Stage(name, [transform])
+        for name, transform in pipeline_components.transforms.items()
+        if not transform.subtransforms]
+
+    # Apply each phase in order.
+    for phase in [
+        annotate_downstream_side_inputs, expand_gbk, sink_flattens,
+        greedily_fuse, sort_stages]:
+      logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
+      stages = list(phase(stages))
+      logging.debug('Stages: %s', [str(s) for s in stages])
+
+    # Return the (possibly mutated) components and the ordered list of stages.
+    return pipeline_components, stages
+
+  def run_stages(self, pipeline_components, stages, direct=True):
+    # Runs the stages in the given order through a single controller.
+    if direct:
+      controller = FnApiRunner.DirectController()
+    else:
+      controller = FnApiRunner.GrpcController()
+    # Buffers written by one stage are read back by later stages.
+    try:
+      pcoll_buffers = collections.defaultdict(list)
+      for stage in stages:
+        self.run_stage(controller, pipeline_components, stage, pcoll_buffers)
+    finally:
+      controller.close()
+
+    return maptask_executor_runner.WorkerRunnerResult(PipelineState.DONE)
+
+  def run_stage(self, controller, pipeline_components, stage, pcoll_buffers):
+
+    coders = pipeline_context.PipelineContext(pipeline_components).coders
+    data_operation_spec = controller.data_operation_spec()
+
+    def extract_endpoints(stage):
+      # Returns maps of (transform name, tag) targets to buffer ids.
+      # Also mutates IO stages to point to the data_operation_spec.
+      data_input = {}
+      data_side_input = {}
+      data_output = {}
+      for transform in stage.transforms:
+        if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+                                  bundle_processor.DATA_OUTPUT_URN):
+          pcoll_id = proto_utils.unpack_Any(
+              transform.spec.parameter, wrappers_pb2.BytesValue).value
+          if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+            target = transform.unique_name, only_element(transform.outputs)
+            data_input[target] = pcoll_id
+          elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
+            target = transform.unique_name, only_element(transform.inputs)
+            data_output[target] = pcoll_id
+          else:
+            raise NotImplementedError
+          if data_operation_spec:
+            transform.spec.parameter.CopyFrom(data_operation_spec)
+          else:
+            transform.spec.parameter.Clear()
+      return data_input, data_side_input, data_output
+
+    logging.info('Running %s', stage.name)
+    logging.debug('       %s', stage)
+    data_input, data_side_input, data_output = extract_endpoints(stage)
+    if data_side_input:
+      raise NotImplementedError('Side inputs.')
+
+    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
+        id=self._next_uid(),
+        transforms={transform.unique_name: transform
+                    for transform in stage.transforms},
+        pcollections=dict(pipeline_components.pcollections.items()),
+        coders=dict(pipeline_components.coders.items()),
+        windowing_strategies=dict(
+            pipeline_components.windowing_strategies.items()),
+        environments=dict(pipeline_components.environments.items()))
+
+    process_bundle_registration = beam_fn_api_pb2.InstructionRequest(
+        instruction_id=self._next_uid(),
+        register=beam_fn_api_pb2.RegisterRequest(
+            process_bundle_descriptor=[process_bundle_descriptor]))
+
+    process_bundle = beam_fn_api_pb2.InstructionRequest(
+        instruction_id=self._next_uid(),
+        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+            process_bundle_descriptor_reference=
+            process_bundle_descriptor.id))
+
+    # Write all the input data to the channel.
+    for (transform_id, name), pcoll_id in data_input.items():
+      data_out = controller.data_plane_handler.output_stream(
+          process_bundle.instruction_id, beam_fn_api_pb2.Target(
+              primitive_transform_reference=transform_id, name=name))
+      for element_data in pcoll_buffers[pcoll_id]:
+        data_out.write(element_data)
+      data_out.close()
+
+    # Register and start running the bundle.
+    controller.control_handler.push(process_bundle_registration)
+    controller.control_handler.push(process_bundle)
+
+    # Wait for the bundle to finish, surfacing any instruction's error.
+    while True:
+      result = controller.control_handler.pull()
+      if result.error:
+        raise RuntimeError(result.error)
+      if result.instruction_id == process_bundle.instruction_id:
+        break
+
+    # Gather all output data.
+    expected_targets = [
+        beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+                               name=output_name)
+        for (transform_id, output_name), _ in data_output.items()]
+    for output in controller.data_plane_handler.input_elements(
+        process_bundle.instruction_id, expected_targets):
+      target_tuple = (
+          output.target.primitive_transform_reference, output.target.name)
+      if target_tuple in data_output:
+        pcoll_id = data_output[target_tuple]
+        if pcoll_id.startswith('materialize:'):
+          # Just store the data chunks for replay.
+          pcoll_buffers[pcoll_id].append(output.data)
+        elif pcoll_id.startswith('group:'):
+          # This is a grouping write, create a grouping buffer if needed.
+          if pcoll_id not in pcoll_buffers:
+            original_gbk_transform = pcoll_id.split(':', 1)[1]
+            transform_proto = pipeline_components.transforms[
+                original_gbk_transform]
+            input_pcoll = only_element(transform_proto.inputs.values())
+            output_pcoll = only_element(transform_proto.outputs.values())
+            pre_gbk_coder = coders[
+                pipeline_components.pcollections[input_pcoll].coder_id]
+            post_gbk_coder = coders[
+                pipeline_components.pcollections[output_pcoll].coder_id]
+            pcoll_buffers[pcoll_id] = _GroupingBuffer(
+                pre_gbk_coder, post_gbk_coder)
+          pcoll_buffers[pcoll_id].append(output.data)
+        else:
+          # These should be the only two identifiers we produce for now,
+          # but special side input writes may go here.
+          raise NotImplementedError(pcoll_id)
+
+  # This is the "old" way of executing pipelines.
+  # TODO(robertwb): Remove once runner API supports side inputs.
+
   def _map_task_registration(self, map_task, state_handler,
                              data_operation_spec):
     input_data, side_input_data, runner_sinks, process_bundle_descriptor = (
@@ -175,10 +719,6 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
       return {tag: pcollection_id(op_ix, out_ix)
               for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))}
 
-    def only_element(iterable):
-      element, = iterable
-      return element
-
     for op_ix, (stage_name, operation) in enumerate(map_task):
       transform_id = uniquify(stage_name)
 
@@ -332,6 +872,15 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
     finally:
       controller.close()
 
+  @staticmethod
+  def _reencode_elements(elements, element_coder):
+    output_stream = create_OutputStream()
+    for element in elements:
+      element_coder.get_impl().encode_to_stream(element, output_stream, True)
+    return output_stream.get()  # All encoded elements, concatenated.
+
+  # These classes are used to interact with the worker.
+
   class SimpleState(object):  # TODO(robertwb): Inherit from GRPC servicer.
 
     def __init__(self):
@@ -429,9 +978,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
       self.control_server.stop(5).wait()
       self.data_server.stop(5).wait()
 
-  @staticmethod
-  def _reencode_elements(elements, element_coder):
-    output_stream = create_OutputStream()
-    for element in elements:
-      element_coder.get_impl().encode_to_stream(element, output_stream, True)
-    return output_stream.get()
+
+def only_element(iterable):  # Sole item; ValueError if not exactly one.
+  element, = iterable
+  return element

http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 163e980..ba21954 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -51,7 +51,7 @@ class FnApiRunnerTest(
   def test_assert_that(self):
     # TODO: figure out a way for fn_api_runner to parse and raise the
     # underlying exception.
-    with self.assertRaisesRegexp(RuntimeError, 'BeamAssertException'):
+    with self.assertRaisesRegexp(Exception, 'Failed assert'):
       with self.create_pipeline() as p:
         assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
 

http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/runners/worker/bundle_processor.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 2669bfc..9474eda 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -28,17 +28,20 @@ import logging
 
 from google.protobuf import wrappers_pb2
 
+import apache_beam as beam
 from apache_beam.coders import coder_impl
 from apache_beam.coders import WindowedValueCoder
 from apache_beam.internal import pickler
 from apache_beam.io import iobase
 from apache_beam.portability.api import beam_fn_api_pb2
+from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.runners.dataflow.native_io import iobase as native_iobase
 from apache_beam.runners import pipeline_context
 from apache_beam.runners.worker import operation_specs
 from apache_beam.runners.worker import operations
 from apache_beam.utils import counters
 from apache_beam.utils import proto_utils
+from apache_beam.utils import urns
 
 # This module is experimental. No backwards-compatibility guarantees.
 
@@ -374,6 +377,24 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
       consumers)
 
 
+@BeamTransformFactory.register_urn(
+    urns.READ_TRANSFORM, beam_runner_api_pb2.ReadPayload)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+  # Rebuild the SDK source object from the ReadPayload's runner API proto.
+  source = iobase.SourceBase.from_runner_api(parameter.source, factory.context)
+  spec = operation_specs.WorkerRead(
+      iobase.SourceBundle(1.0, source, None, None),
+      [WindowedValueCoder(source.default_output_coder())])
+  return factory.augment_oldstyle_op(
+      operations.ReadOperation(
+          transform_proto.unique_name,
+          spec,
+          factory.counter_factory,
+          factory.state_sampler),
+      transform_proto.unique_name,
+      consumers)
+
+
 @BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue)
 def create(factory, transform_id, transform_proto, parameter, consumers):
   dofn_data = pickler.loads(parameter.value)
@@ -383,7 +404,32 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
   else:
     # No side input data.
     serialized_fn, side_input_data = parameter.value, []
+  return _create_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      serialized_fn, side_input_data)
+
+
+@BeamTransformFactory.register_urn(
+    urns.PARDO_TRANSFORM, beam_runner_api_pb2.ParDoPayload)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+  if parameter.do_fn.spec.urn != urns.PICKLED_DO_FN_INFO:
+    raise ValueError('Unexpected DoFn URN: %s' % parameter.do_fn.spec.urn)
+  serialized_fn = proto_utils.unpack_Any(
+      parameter.do_fn.spec.parameter, wrappers_pb2.BytesValue).value
+  dofn_data = pickler.loads(serialized_fn)
+  if len(dofn_data) == 2:
+    # Pickled as (serialized_fn, side_input_data).
+    serialized_fn, side_input_data = dofn_data
+  else:
+    side_input_data = []  # No side input data.
+  return _create_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      serialized_fn, side_input_data)
+
 
+def _create_pardo_operation(
+    factory, transform_id, transform_proto, consumers,
+    serialized_fn, side_input_data):
   def create_side_input(tag, coder):
     # TODO(robertwb): Extract windows (and keys) out of element data.
     # TODO(robertwb): Extract state key from ParDoPayload.
@@ -395,10 +441,27 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
                 key=side_input_tag(transform_id, tag)),
             coder=coder))
   output_tags = list(transform_proto.outputs.keys())
+
+  # Hack to match out prefix injected by dataflow runner.
+  def mutate_tag(tag):
+    if 'None' in output_tags:
+      if tag == 'None':
+        return 'out'
+      else:
+        return 'out_' + tag
+    else:
+      return tag
+  dofn_data = pickler.loads(serialized_fn)
+  if not dofn_data[-1]:
+    # Windowing not set.
+    pcoll_id, = transform_proto.inputs.values()
+    windowing = factory.context.windowing_strategies.get_by_id(
+        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
+    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
   output_coders = factory.get_output_coders(transform_proto)
   spec = operation_specs.WorkerDoFn(
       serialized_fn=serialized_fn,
-      output_tags=output_tags,
+      output_tags=[mutate_tag(tag) for tag in output_tags],
       input=None,
       side_inputs=[
           create_side_input(tag, coder) for tag, coder in side_input_data],
@@ -414,12 +477,52 @@ def create(factory, transform_id, transform_proto, parameter, consumers):
       output_tags)
 
 
+def _create_simple_pardo_operation(
+    factory, transform_id, transform_proto, consumers, dofn):
+  # A None windowing (last element) is filled in by _create_pardo_operation.
+  serialized_fn = pickler.dumps((dofn, (), {}, [], None))
+  return _create_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      serialized_fn, side_input_data=[])
+
+
+@BeamTransformFactory.register_urn(
+    urns.GROUP_ALSO_BY_WINDOW_TRANSFORM, wrappers_pb2.BytesValue)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+  # Perhaps this hack can go away once all apply overloads are gone.
+  from apache_beam.transforms.core import _GroupAlsoByWindowDoFn
+  windowing = factory.context.windowing_strategies.get_by_id(parameter.value)
+  return _create_simple_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      _GroupAlsoByWindowDoFn(windowing))
+
+
+@BeamTransformFactory.register_urn(
+    urns.WINDOW_INTO_TRANSFORM, beam_runner_api_pb2.WindowingStrategy)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+  # Function-local imports; WindowFn/WindowedValue are used by process().
+  from apache_beam.transforms.core import Windowing
+  from apache_beam.transforms.window import WindowFn, WindowedValue
+  class WindowIntoDoFn(beam.DoFn):
+    def __init__(self, windowing):
+      self.windowing = windowing
+
+    def process(self, element, timestamp=beam.DoFn.TimestampParam):
+      context = WindowFn.AssignContext(timestamp, element=element)
+      yield WindowedValue(
+          element, timestamp, self.windowing.windowfn.assign(context))
+  return _create_simple_pardo_operation(
+      factory, transform_id, transform_proto, consumers,
+      WindowIntoDoFn(Windowing.from_runner_api(parameter, factory.context)))
+
+
 @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
 def create(factory, transform_id, transform_proto, unused_parameter, consumers):
   return factory.augment_oldstyle_op(
       operations.FlattenOperation(
           transform_proto.unique_name,
-          None,
+          operation_specs.WorkerFlatten(
+              None, [factory.get_only_output_coder(transform_proto)]),
           factory.counter_factory,
           factory.state_sampler),
       transform_proto.unique_name,

http://git-wip-us.apache.org/repos/asf/beam/blob/5e71d53e/sdks/python/apache_beam/transforms/core.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index cff6dbe..3f92ce9 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1461,6 +1461,7 @@ PTransform.register_urn(
     # (Right now only WindowFn is used, but we need this to reconstitute the
     # WindowInto transform, and in the future will need it at runtime to
     # support meta-data driven triggers.)
+    # TODO(robertwb): Use a reference rather than embedding?
     beam_runner_api_pb2.WindowingStrategy,
     WindowInto.from_runner_api_parameter)
 
@@ -1500,7 +1501,10 @@ class Flatten(PTransform):
   def expand(self, pcolls):
     for pcoll in pcolls:
       self._check_pcollection(pcoll)
-    return pvalue.PCollection(self.pipeline)
+    input_types = tuple(pcoll.element_type for pcoll in pcolls)
+    result = pvalue.PCollection(self.pipeline)
+    result.element_type = typehints.Union[input_types]
+    return result
 
   def get_windowing(self, inputs):
     if not inputs: