Posted to commits@beam.apache.org by ro...@apache.org on 2017/05/02 00:45:04 UTC

[3/4] beam git commit: Fn API support for Python.

Fn API support for Python.

Also added supporting worker code and a runner
that exercises this code directly.
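
A minimal sketch of how the new runner can be used (not itself part of this
commit; it only exercises the FnApiRunner class added in the diff below):

    import apache_beam as beam
    from apache_beam.runners.portability import fn_api_runner

    p = beam.Pipeline(runner=fn_api_runner.FnApiRunner())
    _ = p | beam.Create(['a', 'b']) | beam.Map(lambda e: e * 2)
    p.run()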


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a856fcf3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a856fcf3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a856fcf3

Branch: refs/heads/master
Commit: a856fcf3f9ca48530078e1d7610e234110cd5890
Parents: b8131fe
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Thu Apr 20 15:59:25 2017 -0500
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Mon May 1 17:44:29 2017 -0700

----------------------------------------------------------------------
 .../apache_beam/runners/portability/__init__.py |  16 +
 .../runners/portability/fn_api_runner.py        | 471 ++++++++++++++
 .../runners/portability/fn_api_runner_test.py   |  40 ++
 .../portability/maptask_executor_runner.py      | 468 +++++++++++++
 .../portability/maptask_executor_runner_test.py | 204 ++++++
 .../apache_beam/runners/worker/__init__.py      |  16 +
 .../apache_beam/runners/worker/data_plane.py    | 288 ++++++++
 .../runners/worker/data_plane_test.py           | 139 ++++
 .../apache_beam/runners/worker/log_handler.py   | 100 +++
 .../runners/worker/log_handler_test.py          | 105 +++
 .../apache_beam/runners/worker/logger.pxd       |  25 +
 .../python/apache_beam/runners/worker/logger.py | 173 +++++
 .../apache_beam/runners/worker/logger_test.py   | 182 ++++++
 .../apache_beam/runners/worker/opcounters.pxd   |  45 ++
 .../apache_beam/runners/worker/opcounters.py    | 162 +++++
 .../runners/worker/opcounters_test.py           | 149 +++++
 .../runners/worker/operation_specs.py           | 368 +++++++++++
 .../apache_beam/runners/worker/operations.pxd   |  89 +++
 .../apache_beam/runners/worker/operations.py    | 651 +++++++++++++++++++
 .../apache_beam/runners/worker/sdk_worker.py    | 451 +++++++++++++
 .../runners/worker/sdk_worker_main.py           |  62 ++
 .../runners/worker/sdk_worker_test.py           | 168 +++++
 .../apache_beam/runners/worker/sideinputs.py    | 166 +++++
 .../runners/worker/sideinputs_test.py           | 150 +++++
 .../apache_beam/runners/worker/statesampler.pyx | 237 +++++++
 .../runners/worker/statesampler_fake.py         |  34 +
 .../runners/worker/statesampler_test.py         | 102 +++
 sdks/python/generate_pydoc.sh                   |   2 +
 sdks/python/setup.py                            |   8 +-
 29 files changed, 5069 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/__init__.py b/sdks/python/apache_beam/runners/portability/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
new file mode 100644
index 0000000..5802c17
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -0,0 +1,471 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A PipelineRunner using the SDK harness.
+"""
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+
+import grpc
+from concurrent import futures
+
+import apache_beam as beam
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.coders.coder_impl import create_InputStream
+from apache_beam.coders.coder_impl import create_OutputStream
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sdk_worker
+
+
+def streaming_rpc_handler(cls, method_name):
+  """Un-inverts the flow of control between the runner and the sdk harness."""
+
+  class StreamingRpcHandler(cls):
+
+    _DONE = object()
+
+    def __init__(self):
+      self._push_queue = queue.Queue()
+      self._pull_queue = queue.Queue()
+      setattr(self, method_name, self.run)
+      self._read_thread = threading.Thread(target=self._read)
+
+    def run(self, iterator, context):
+      self._inputs = iterator
+      # Note: We only support one client for now.
+      self._read_thread.start()
+      while True:
+        to_push = self._push_queue.get()
+        if to_push is self._DONE:
+          return
+        yield to_push
+
+    def _read(self):
+      for data in self._inputs:
+        self._pull_queue.put(data)
+
+    def push(self, item):
+      self._push_queue.put(item)
+
+    def pull(self, timeout=None):
+      return self._pull_queue.get(timeout=timeout)
+
+    def empty(self):
+      return self._pull_queue.empty()
+
+    def done(self):
+      self.push(self._DONE)
+      self._read_thread.join()
+
+  return StreamingRpcHandler()
+
+
+class OldeSourceSplittableDoFn(beam.DoFn):
+  """A DoFn that reads and emits an entire source.
+  """
+
+  # TODO(robertwb): Make this a full SDF with progress splitting, etc.
+  def process(self, source):
+    if isinstance(source, iobase.SourceBundle):
+      for value in source.source.read(source.source.get_range_tracker(
+          source.start_position, source.stop_position)):
+        yield value
+    else:
+      # Dataflow native source
+      with source.reader() as reader:
+        for value in reader:
+          yield value
+
+
+# See DataflowRunner._pardo_fn_data
+OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
+    (OldeSourceSplittableDoFn(), (), {}, [],
+     beam.transforms.core.Windowing(GlobalWindows())))
+
+
+class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
+
+  def __init__(self):
+    super(FnApiRunner, self).__init__()
+    self._last_uid = -1
+
+  def has_metrics_support(self):
+    return False
+
+  def _next_uid(self):
+    self._last_uid += 1
+    return str(self._last_uid)
+
+  def _map_task_registration(self, map_task, state_handler,
+                             data_operation_spec):
+    input_data = {}
+    runner_sinks = {}
+    transforms = []
+    transform_index_to_id = {}
+
+    # Maps coders to new coder objects and references.
+    coders = {}
+
+    def coder_id(coder):
+      if coder not in coders:
+        coders[coder] = beam_fn_api_pb2.Coder(
+            function_spec=sdk_worker.pack_function_spec_data(
+                json.dumps(coder.as_cloud_object()),
+                sdk_worker.PYTHON_CODER_URN, id=self._next_uid()))
+
+      return coders[coder].function_spec.id
+
+    def output_tags(op):
+      return getattr(op, 'output_tags', ['out'])
+
+    def as_target(op_input):
+      input_op_index, input_output_index = op_input
+      input_op = map_task[input_op_index][1]
+      return {
+          'ignored_input_tag':
+              beam_fn_api_pb2.Target.List(target=[
+                  beam_fn_api_pb2.Target(
+                      primitive_transform_reference=transform_index_to_id[
+                          input_op_index],
+                      name=output_tags(input_op)[input_output_index])
+              ])
+      }
+
+    def outputs(op):
+      return {
+          tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
+          for tag, coder in zip(output_tags(op), op.output_coders)
+      }
+
+    for op_ix, (stage_name, operation) in enumerate(map_task):
+      transform_id = transform_index_to_id[op_ix] = self._next_uid()
+      if isinstance(operation, operation_specs.WorkerInMemoryWrite):
+        # Write this data back to the runner.
+        fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_OUTPUT_URN,
+                                          id=self._next_uid())
+        if data_operation_spec:
+          fn.data.Pack(data_operation_spec)
+        inputs = as_target(operation.input)
+        side_inputs = {}
+        runner_sinks[(transform_id, 'out')] = operation
+
+      elif isinstance(operation, operation_specs.WorkerRead):
+        # A Read is either translated to a direct injection of windowed values
+        # into the sdk worker, or an injection of the source object into the
+        # sdk worker as data followed by an SDF that reads that source.
+        if (isinstance(operation.source.source,
+                       maptask_executor_runner.InMemorySource)
+            and isinstance(operation.source.source.default_output_coder(),
+                           WindowedValueCoder)):
+          output_stream = create_OutputStream()
+          element_coder = (
+              operation.source.source.default_output_coder().get_impl())
+          # Re-encode the elements in the nested context and
+          # concatenate them together
+          for element in operation.source.source.read(None):
+            element_coder.encode_to_stream(element, output_stream, True)
+          target_name = self._next_uid()
+          input_data[(transform_id, target_name)] = output_stream.get()
+          fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN,
+                                            id=self._next_uid())
+          if data_operation_spec:
+            fn.data.Pack(data_operation_spec)
+          inputs = {target_name: beam_fn_api_pb2.Target.List()}
+          side_inputs = {}
+        else:
+          # Read the source object from the runner.
+          source_coder = beam.coders.DillCoder()
+          input_transform_id = self._next_uid()
+          output_stream = create_OutputStream()
+          source_coder.get_impl().encode_to_stream(
+              GlobalWindows.windowed_value(operation.source),
+              output_stream,
+              True)
+          target_name = self._next_uid()
+          input_data[(input_transform_id, target_name)] = output_stream.get()
+          input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
+              id=input_transform_id,
+              function_spec=beam_fn_api_pb2.FunctionSpec(
+                  urn=sdk_worker.DATA_INPUT_URN,
+                  id=self._next_uid()),
+              # TODO(robertwb): Possible name collision.
+              step_name=stage_name + '/inject_source',
+              inputs={target_name: beam_fn_api_pb2.Target.List()},
+              outputs={
+                  'out':
+                      beam_fn_api_pb2.PCollection(
+                          coder_reference=coder_id(source_coder))
+              })
+          if data_operation_spec:
+            input_ptransform.function_spec.data.Pack(data_operation_spec)
+          transforms.append(input_ptransform)
+
+          # Read the elements out of the source.
+          fn = sdk_worker.pack_function_spec_data(
+              OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
+              sdk_worker.PYTHON_DOFN_URN,
+              id=self._next_uid())
+          inputs = {
+              'ignored_input_tag':
+                  beam_fn_api_pb2.Target.List(target=[
+                      beam_fn_api_pb2.Target(
+                          primitive_transform_reference=input_transform_id,
+                          name='out')
+                  ])
+          }
+          side_inputs = {}
+
+      elif isinstance(operation, operation_specs.WorkerDoFn):
+        fn = sdk_worker.pack_function_spec_data(
+            operation.serialized_fn,
+            sdk_worker.PYTHON_DOFN_URN,
+            id=self._next_uid())
+        inputs = as_target(operation.input)
+        # Store the contents of each side input for state access.
+        for si in operation.side_inputs:
+          assert isinstance(si.source, iobase.BoundedSource)
+          element_coder = si.source.default_output_coder()
+          view_id = self._next_uid()
+          # TODO(robertwb): Actually flesh out the ViewFn API.
+          side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
+              view_fn=sdk_worker.serialize_and_pack_py_fn(
+                  element_coder, urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
+                  id=view_id))
+          # Re-encode the elements in the nested context and
+          # concatenate them together
+          output_stream = create_OutputStream()
+          for element in si.source.read(
+              si.source.get_range_tracker(None, None)):
+            element_coder.get_impl().encode_to_stream(
+                element, output_stream, True)
+          elements_data = output_stream.get()
+          state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id)
+          state_handler.Clear(state_key)
+          state_handler.Append(
+              beam_fn_api_pb2.SimpleStateAppendRequest(
+                  state_key=state_key, data=[elements_data]))
+
+      elif isinstance(operation, operation_specs.WorkerFlatten):
+        fn = sdk_worker.pack_function_spec_data(
+            operation.serialized_fn,
+            sdk_worker.IDENTITY_DOFN_URN,
+            id=self._next_uid())
+        inputs = {
+            'ignored_input_tag':
+                beam_fn_api_pb2.Target.List(target=[
+                    beam_fn_api_pb2.Target(
+                        primitive_transform_reference=transform_index_to_id[
+                            input_op_index],
+                        name=output_tags(map_task[input_op_index][1])[
+                            input_output_index])
+                    for input_op_index, input_output_index in operation.inputs
+                ])
+        }
+        side_inputs = {}
+
+      else:
+        raise TypeError(operation)
+
+      ptransform = beam_fn_api_pb2.PrimitiveTransform(
+          id=transform_id,
+          function_spec=fn,
+          step_name=stage_name,
+          inputs=inputs,
+          side_inputs=side_inputs,
+          outputs=outputs(operation))
+      transforms.append(ptransform)
+
+    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
+        id=self._next_uid(), coders=coders.values(),
+        primitive_transform=transforms)
+    return beam_fn_api_pb2.InstructionRequest(
+        instruction_id=self._next_uid(),
+        register=beam_fn_api_pb2.RegisterRequest(
+            process_bundle_descriptor=[process_bundle_descriptor
+                                      ])), runner_sinks, input_data
+
+  def _run_map_task(
+      self, map_task, control_handler, state_handler, data_plane_handler,
+      data_operation_spec):
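+    # Execution proceeds in four steps: (1) register the bundle descriptor
+    # with the SDK harness, (2) push any input data over the data plane,
+    # (3) request processing of the bundle, and (4) read the results back
+    # from the data plane into the runner-side sinks.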
+    registration, sinks, input_data = self._map_task_registration(
+        map_task, state_handler, data_operation_spec)
+    control_handler.push(registration)
+    process_bundle = beam_fn_api_pb2.InstructionRequest(
+        instruction_id=self._next_uid(),
+        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+            process_bundle_descriptor_reference=registration.register.
+            process_bundle_descriptor[0].id))
+
+    for (transform_id, name), elements in input_data.items():
+      data_out = data_plane_handler.output_stream(
+          process_bundle.instruction_id, beam_fn_api_pb2.Target(
+              primitive_transform_reference=transform_id, name=name))
+      data_out.write(elements)
+      data_out.close()
+
+    control_handler.push(process_bundle)
+    while True:
+      result = control_handler.pull()
+      if result.instruction_id == process_bundle.instruction_id:
+        if result.error:
+          raise RuntimeError(result.error)
+        expected_targets = [
+            beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+                                   name=output_name)
+            for (transform_id, output_name), _ in sinks.items()]
+        for output in data_plane_handler.input_elements(
+            process_bundle.instruction_id, expected_targets):
+          target_tuple = (
+              output.target.primitive_transform_reference, output.target.name)
+          if target_tuple not in sinks:
+            # Unconsumed output.
+            continue
+          sink_op = sinks[target_tuple]
+          coder = sink_op.output_coders[0]
+          input_stream = create_InputStream(output.data)
+          elements = []
+          while input_stream.size() > 0:
+            elements.append(coder.get_impl().decode_from_stream(
+                input_stream, True))
+          if not sink_op.write_windowed_values:
+            elements = [e.value for e in elements]
+          for e in elements:
+            sink_op.output_buffer.append(e)
+        return
+
+  def execute_map_tasks(self, ordered_map_tasks, direct=True):
+    if direct:
+      controller = FnApiRunner.DirectController()
+    else:
+      controller = FnApiRunner.GrpcController()
+
+    try:
+      for _, map_task in ordered_map_tasks:
+        logging.info('Running %s', map_task)
+        self._run_map_task(
+            map_task, controller.control_handler, controller.state_handler,
+            controller.data_plane_handler, controller.data_operation_spec())
+    finally:
+      controller.close()
+
+  class SimpleState(object):  # TODO(robertwb): Inherit from GRPC servicer.
+
+    def __init__(self):
+      self._all = collections.defaultdict(list)
+
+    def Get(self, state_key):
+      return beam_fn_api_pb2.Elements.Data(
+          data=''.join(self._all[self._to_key(state_key)]))
+
+    def Append(self, append_request):
+      self._all[self._to_key(append_request.state_key)].extend(
+          append_request.data)
+
+    def Clear(self, state_key):
+      try:
+        del self._all[self._to_key(state_key)]
+      except KeyError:
+        pass
+
+    @staticmethod
+    def _to_key(state_key):
+      return (state_key.function_spec_reference, state_key.window,
+              state_key.key)
+
+  class DirectController(object):
+    """An in-memory controller for fn API control, state and data planes."""
+
+    def __init__(self):
+      self._responses = []
+      self.state_handler = FnApiRunner.SimpleState()
+      self.control_handler = self
+      self.data_plane_handler = data_plane.InMemoryDataChannel()
+      self.worker = sdk_worker.SdkWorker(
+          self.state_handler, data_plane.InMemoryDataChannelFactory(
+              self.data_plane_handler.inverse()))
+
+    def push(self, request):
+      logging.info('CONTROL REQUEST %s', request)
+      response = self.worker.do_instruction(request)
+      logging.info('CONTROL RESPONSE %s', response)
+      self._responses.append(response)
+
+    def pull(self):
+      return self._responses.pop(0)
+
+    def done(self):
+      pass
+
+    def close(self):
+      pass
+
+    def data_operation_spec(self):
+      return None
+
+  class GrpcController(object):
+    """An grpc based controller for fn API control, state and data planes."""
+
+    def __init__(self):
+      self.state_handler = FnApiRunner.SimpleState()
+      self.control_server = grpc.server(
+          futures.ThreadPoolExecutor(max_workers=10))
+      self.control_port = self.control_server.add_insecure_port('[::]:0')
+
+      self.data_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+      self.data_port = self.data_server.add_insecure_port('[::]:0')
+
+      self.control_handler = streaming_rpc_handler(
+          beam_fn_api_pb2.BeamFnControlServicer, 'Control')
+      beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
+          self.control_handler, self.control_server)
+
+      self.data_plane_handler = data_plane.GrpcServerDataChannel()
+      beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+          self.data_plane_handler, self.data_server)
+
+      logging.info('starting control server on port %s', self.control_port)
+      logging.info('starting data server on port %s', self.data_port)
+      self.data_server.start()
+      self.control_server.start()
+
+      self.worker = sdk_worker.SdkHarness(
+          grpc.insecure_channel('localhost:%s' % self.control_port))
+      self.worker_thread = threading.Thread(target=self.worker.run)
+      logging.info('starting worker')
+      self.worker_thread.start()
+
+    def data_operation_spec(self):
+      url = 'localhost:%s' % self.data_port
+      remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+      remote_grpc_port.api_service_descriptor.url = url
+      return remote_grpc_port
+
+    def close(self):
+      self.control_handler.done()
+      self.worker_thread.join()
+      self.data_plane_handler.close()
+      self.control_server.stop(5).wait()
+      self.data_server.stop(5).wait()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
new file mode 100644
index 0000000..633602f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.runners.portability import fn_api_runner
+from apache_beam.runners.portability import maptask_executor_runner_test
+
+
+class FnApiRunnerTest(maptask_executor_runner_test.MapTaskExecutorRunnerTest):
+
+  def create_pipeline(self):
+    return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
+
+  def test_combine_per_key(self):
+    # TODO(robertwb): Implement PGBKCV operation.
+    pass
+
+  # Inherits all other tests from
+  # maptask_executor_runner_test.MapTaskExecutorRunnerTest.
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
new file mode 100644
index 0000000..d273e18
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
@@ -0,0 +1,468 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Beam runner for testing/profiling worker code directly.
+"""
+
+import collections
+import logging
+import time
+
+import apache_beam as beam
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.runners import DataflowRunner
+from apache_beam.runners.dataflow.internal.dependency import _dependency_file_copy
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.runners.runner import PipelineRunner
+from apache_beam.runners.runner import PipelineState
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+  from apache_beam.runners.worker import statesampler
+except ImportError:
+  from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.typehints import typehints
+from apache_beam.utils import pipeline_options
+from apache_beam.utils import profiler
+from apache_beam.utils.counters import CounterFactory
+
+
+class MapTaskExecutorRunner(PipelineRunner):
+  """Beam runner translating a pipeline into map tasks that are then executed.
+
+  Primarily intended for testing and profiling the worker code paths.
+  """
+
+  def __init__(self):
+    self.executors = []
+
+  def has_metrics_support(self):
+    """Returns whether this runner supports metrics or not.
+    """
+    return False
+
+  def run(self, pipeline):
+    MetricsEnvironment.set_metrics_supported(self.has_metrics_support())
+    # List of map tasks.  Each map task is a list of
+    # (stage_name, operation_specs.WorkerOperation) instructions.
+    self.map_tasks = []
+
+    # Map of pvalues to
+    # (map_task_index, producer_operation_index, producer_output_index)
+    self.outputs = {}
+
+    # Unique mappings of PCollections to strings.
+    self.side_input_labels = collections.defaultdict(
+        lambda: str(len(self.side_input_labels)))
+
+    # Mapping of map task indices to all map tasks that must precede them.
+    self.dependencies = collections.defaultdict(set)
+
+    # Visit the graph, building up the map_tasks and their metadata.
+    super(MapTaskExecutorRunner, self).run(pipeline)
+
+    # Now run the tasks in topological order.
+    def compute_depth_map(deps):
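+      # Depth of a map task = length of its longest dependency chain, e.g.
+      # deps = {0: set(), 1: {0}, 2: {0, 1}} gives depths {0: 0, 1: 1, 2: 2}.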
+      memoized = {}
+
+      def compute_depth(x):
+        if x not in memoized:
+          memoized[x] = 1 + max([-1] + [compute_depth(y) for y in deps[x]])
+        return memoized[x]
+
+      return {x: compute_depth(x) for x in deps.keys()}
+
+    map_task_depths = compute_depth_map(self.dependencies)
+    ordered_map_tasks = sorted((map_task_depths.get(ix, -1), map_task)
+                               for ix, map_task in enumerate(self.map_tasks))
+
+    profile_options = pipeline.options.view_as(
+        pipeline_options.ProfilingOptions)
+    if profile_options.profile_cpu:
+      with profiler.Profile(
+          profile_id='worker-runner',
+          profile_location=profile_options.profile_location,
+          log_results=True, file_copy_fn=_dependency_file_copy):
+        self.execute_map_tasks(ordered_map_tasks)
+    else:
+      self.execute_map_tasks(ordered_map_tasks)
+
+    return WorkerRunnerResult(PipelineState.UNKNOWN)
+
+  def metrics_containers(self):
+    return [op.metrics_container
+            for ex in self.executors
+            for op in ex.operations()]
+
+  def execute_map_tasks(self, ordered_map_tasks):
+    tt = time.time()
+    for ix, (_, map_task) in enumerate(ordered_map_tasks):
+      logging.info('Running %s', map_task)
+      t = time.time()
+      stage_names, all_operations = zip(*map_task)
+      # TODO(robertwb): The DataflowRunner worker receives system step names
+      # (e.g. "s3") that are used to label the output msec counters.  We use the
+      # operation names here, but this is not the same scheme used by the
+      # DataflowRunner; the result is that the output msec counters are named
+      # differently.
+      system_names = stage_names
+      # Create the CounterFactory and StateSampler for this MapTask.
+      # TODO(robertwb): Output counters produced here are currently ignored.
+      counter_factory = CounterFactory()
+      state_sampler = statesampler.StateSampler('%s-' % ix, counter_factory)
+      map_executor = operations.SimpleMapTaskExecutor(
+          operation_specs.MapTask(
+              all_operations, 'S%02d' % ix,
+              system_names, stage_names, system_names),
+          counter_factory,
+          state_sampler)
+      self.executors.append(map_executor)
+      map_executor.execute()
+      logging.info(
+          'Stage %s finished: %0.3f sec', stage_names[0], time.time() - t)
+    logging.info('Total time: %0.3f sec', time.time() - tt)
+
+  def run_Read(self, transform_node):
+    self._run_read_from(transform_node, transform_node.transform.source)
+
+  def _run_read_from(self, transform_node, source):
+    """Used when this operation is the result of reading source."""
+    if not isinstance(source, iobase.NativeSource):
+      source = iobase.SourceBundle(1.0, source, None, None)
+    output = transform_node.outputs[None]
+    element_coder = self._get_coder(output)
+    read_op = operation_specs.WorkerRead(source, output_coders=[element_coder])
+    self.outputs[output] = len(self.map_tasks), 0, 0
+    self.map_tasks.append([(transform_node.full_label, read_op)])
+    return len(self.map_tasks) - 1
+
+  def run_ParDo(self, transform_node):
+    transform = transform_node.transform
+    output = transform_node.outputs[None]
+    element_coder = self._get_coder(output)
+    map_task_index, producer_index, output_index = self.outputs[
+        transform_node.inputs[0]]
+
+    # If any of this ParDo's side inputs depend on outputs from this map_task,
+    # we can't continue growing this map task.
+    def is_reachable(leaf, root):
+      if leaf == root:
+        return True
+      else:
+        return any(is_reachable(x, root) for x in self.dependencies[leaf])
+
+    if any(is_reachable(self.outputs[side_input.pvalue][0], map_task_index)
+           for side_input in transform_node.side_inputs):
+      # Start a new map task.
+      input_element_coder = self._get_coder(transform_node.inputs[0])
+
+      output_buffer = OutputBuffer(input_element_coder)
+
+      fusion_break_write = operation_specs.WorkerInMemoryWrite(
+          output_buffer=output_buffer,
+          write_windowed_values=True,
+          input=(producer_index, output_index),
+          output_coders=[input_element_coder])
+      self.map_tasks[map_task_index].append(
+          (transform_node.full_label + '/Write', fusion_break_write))
+
+      original_map_task_index = map_task_index
+      map_task_index, producer_index, output_index = len(self.map_tasks), 0, 0
+
+      fusion_break_read = operation_specs.WorkerRead(
+          output_buffer.source_bundle(),
+          output_coders=[input_element_coder])
+      self.map_tasks.append(
+          [(transform_node.full_label + '/Read', fusion_break_read)])
+
+      self.dependencies[map_task_index].add(original_map_task_index)
+
+    def create_side_read(side_input):
+      label = self.side_input_labels[side_input]
+      output_buffer = self.run_side_write(
+          side_input.pvalue, '%s/%s' % (transform_node.full_label, label))
+      return operation_specs.WorkerSideInputSource(
+          output_buffer.source(), label)
+
+    do_op = operation_specs.WorkerDoFn(  #
+        serialized_fn=pickler.dumps(DataflowRunner._pardo_fn_data(
+            transform_node,
+            lambda side_input: self.side_input_labels[side_input])),
+        output_tags=[PropertyNames.OUT] + ['%s_%s' % (PropertyNames.OUT, tag)
+                                           for tag in transform.output_tags
+                                          ],
+        # Same assumption that DataflowRunner has about coders being compatible
+        # across outputs.
+        output_coders=[element_coder] * (len(transform.output_tags) + 1),
+        input=(producer_index, output_index),
+        side_inputs=[create_side_read(side_input)
+                     for side_input in transform_node.side_inputs])
+
+    producer_index = len(self.map_tasks[map_task_index])
+    self.outputs[transform_node.outputs[None]] = (
+        map_task_index, producer_index, 0)
+    for ix, tag in enumerate(transform.output_tags):
+      self.outputs[transform_node.outputs[
+          tag]] = map_task_index, producer_index, ix + 1
+    self.map_tasks[map_task_index].append((transform_node.full_label, do_op))
+
+    for side_input in transform_node.side_inputs:
+      self.dependencies[map_task_index].add(self.outputs[side_input.pvalue][0])
+
+  def run_side_write(self, pcoll, label):
+    map_task_index, producer_index, output_index = self.outputs[pcoll]
+
+    windowed_element_coder = self._get_coder(pcoll)
+    output_buffer = OutputBuffer(windowed_element_coder)
+    write_sideinput_op = operation_specs.WorkerInMemoryWrite(
+        output_buffer=output_buffer,
+        write_windowed_values=True,
+        input=(producer_index, output_index),
+        output_coders=[windowed_element_coder])
+    self.map_tasks[map_task_index].append(
+        (label, write_sideinput_op))
+    return output_buffer
+
+  def run_GroupByKeyOnly(self, transform_node):
+    map_task_index, producer_index, output_index = self.outputs[
+        transform_node.inputs[0]]
+    grouped_element_coder = self._get_coder(transform_node.outputs[None],
+                                            windowed=False)
+    windowed_ungrouped_element_coder = self._get_coder(transform_node.inputs[0])
+
+    output_buffer = GroupingOutputBuffer(grouped_element_coder)
+    shuffle_write = operation_specs.WorkerInMemoryWrite(
+        output_buffer=output_buffer,
+        write_windowed_values=False,
+        input=(producer_index, output_index),
+        output_coders=[windowed_ungrouped_element_coder])
+    self.map_tasks[map_task_index].append(
+        (transform_node.full_label + '/Write', shuffle_write))
+
+    output_map_task_index = self._run_read_from(
+        transform_node, output_buffer.source())
+    self.dependencies[output_map_task_index].add(map_task_index)
+
+  def run_Flatten(self, transform_node):
+    output_buffer = OutputBuffer(self._get_coder(transform_node.outputs[None]))
+    output_map_task = self._run_read_from(transform_node,
+                                          output_buffer.source())
+
+    for input in transform_node.inputs:
+      map_task_index, producer_index, output_index = self.outputs[input]
+      element_coder = self._get_coder(input)
+      flatten_write = operation_specs.WorkerInMemoryWrite(
+          output_buffer=output_buffer,
+          write_windowed_values=True,
+          input=(producer_index, output_index),
+          output_coders=[element_coder])
+      self.map_tasks[map_task_index].append(
+          (transform_node.full_label + '/Write', flatten_write))
+      self.dependencies[output_map_task].add(map_task_index)
+
+  def apply_CombinePerKey(self, transform, input):
+    # TODO(robertwb): Support side inputs.
+    assert not transform.args and not transform.kwargs
+    return (input
+            | PartialGroupByKeyCombineValues(transform.fn)
+            | beam.GroupByKey()
+            | MergeAccumulators(transform.fn)
+            | ExtractOutputs(transform.fn))
+
+  def run_PartialGroupByKeyCombineValues(self, transform_node):
+    element_coder = self._get_coder(transform_node.outputs[None])
+    _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+    combine_op = operation_specs.WorkerPartialGroupByKey(
+        combine_fn=pickler.dumps(
+            (transform_node.transform.combine_fn, (), {}, ())),
+        output_coders=[element_coder],
+        input=(producer_index, output_index))
+    self._run_as_op(transform_node, combine_op)
+
+  def run_MergeAccumulators(self, transform_node):
+    self._run_combine_transform(transform_node, 'merge')
+
+  def run_ExtractOutputs(self, transform_node):
+    self._run_combine_transform(transform_node, 'extract')
+
+  def _run_combine_transform(self, transform_node, phase):
+    transform = transform_node.transform
+    element_coder = self._get_coder(transform_node.outputs[None])
+    _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+    combine_op = operation_specs.WorkerCombineFn(
+        serialized_fn=pickler.dumps(
+            (transform.combine_fn, (), {}, ())),
+        phase=phase,
+        output_coders=[element_coder],
+        input=(producer_index, output_index))
+    self._run_as_op(transform_node, combine_op)
+
+  def _get_coder(self, pvalue, windowed=True):
+    # TODO(robertwb): This should be an attribute of the pvalue itself.
+    return DataflowRunner._get_coder(
+        pvalue.element_type or typehints.Any,
+        pvalue.windowing.windowfn.get_window_coder() if windowed else None)
+
+  def _run_as_op(self, transform_node, op):
+    """Single-output operation in the same map task as its input."""
+    map_task_index, _, _ = self.outputs[transform_node.inputs[0]]
+    op_index = len(self.map_tasks[map_task_index])
+    output = transform_node.outputs[None]
+    self.outputs[output] = map_task_index, op_index, 0
+    self.map_tasks[map_task_index].append((transform_node.full_label, op))
+
+
+class InMemorySource(iobase.BoundedSource):
+  """Source for reading an (as-yet unwritten) set of in-memory encoded elements.
+  """
+
+  def __init__(self, encoded_elements, coder):
+    self._encoded_elements = encoded_elements
+    self._coder = coder
+
+  def get_range_tracker(self, unused_start_position, unused_end_position):
+    return None
+
+  def read(self, unused_range_tracker):
+    for encoded_element in self._encoded_elements:
+      yield self._coder.decode(encoded_element)
+
+  def default_output_coder(self):
+    return self._coder
+
+
+class OutputBuffer(object):
+
+  def __init__(self, coder):
+    self.coder = coder
+    self.elements = []
+    self.encoded_elements = []
+
+  def source(self):
+    return InMemorySource(self.encoded_elements, self.coder)
+
+  def source_bundle(self):
+    return iobase.SourceBundle(
+        1.0, InMemorySource(self.encoded_elements, self.coder), None, None)
+
+  def __repr__(self):
+    return 'OutputBuffer[%r]' % len(self.elements)
+
+  def append(self, value):
+    self.elements.append(value)
+    self.encoded_elements.append(self.coder.encode(value))
+
+
+class GroupingOutputBuffer(object):
+
+  def __init__(self, grouped_coder):
+    self.grouped_coder = grouped_coder
+    self.elements = collections.defaultdict(list)
+    self.frozen = False
+
+  def source(self):
+    return InMemorySource(self.encoded_elements, self.grouped_coder)
+
+  def __repr__(self):
+    return 'GroupingOutputBuffer[%r]' % len(self.elements)
+
+  def append(self, pair):
+    assert not self.frozen
+    k, v = pair
+    self.elements[k].append(v)
+
+  def freeze(self):
+    if not self.frozen:
+      self._encoded_elements = [self.grouped_coder.encode(kv)
+                                for kv in self.elements.iteritems()]
+    self.frozen = True
+    return self._encoded_elements
+
+  @property
+  def encoded_elements(self):
+    return GroupedOutputBuffer(self)
+
+
+class GroupedOutputBuffer(object):
+
+  def __init__(self, buffer):
+    self.buffer = buffer
+
+  def __getitem__(self, ix):
+    return self.buffer.freeze()[ix]
+
+  def __iter__(self):
+    return iter(self.buffer.freeze())
+
+  def __len__(self):
+    return len(self.buffer.freeze())
+
+  def __nonzero__(self):
+    return True
+
+
+class PartialGroupByKeyCombineValues(beam.PTransform):
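+  # With native=True (the default), expand() only declares an output
+  # PCollection and the runner emits the corresponding worker operation
+  # (see run_PartialGroupByKeyCombineValues above); otherwise a plain
+  # per-element Map over the CombineFn is applied.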
+
+  def __init__(self, combine_fn, native=True):
+    self.combine_fn = combine_fn
+    self.native = native
+
+  def expand(self, input):
+    if self.native:
+      return beam.pvalue.PCollection(input.pipeline)
+    else:
+      def to_accumulator(v):
+        return self.combine_fn.add_input(
+            self.combine_fn.create_accumulator(), v)
+      return input | beam.Map(lambda (k, v): (k, to_accumulator(v)))
+
+
+class MergeAccumulators(beam.PTransform):
+
+  def __init__(self, combine_fn, native=True):
+    self.combine_fn = combine_fn
+    self.native = native
+
+  def expand(self, input):
+    if self.native:
+      return beam.pvalue.PCollection(input.pipeline)
+    else:
+      merge_accumulators = self.combine_fn.merge_accumulators
+      return input | beam.Map(lambda (k, vs): (k, merge_accumulators(vs)))
+
+
+class ExtractOutputs(beam.PTransform):
+
+  def __init__(self, combine_fn, native=True):
+    self.combine_fn = combine_fn
+    self.native = native
+
+  def expand(self, input):
+    if self.native:
+      return beam.pvalue.PCollection(input.pipeline)
+    else:
+      extract_output = self.combine_fn.extract_output
+      return input | beam.Map(lambda (k, v): (k, extract_output(v)))
+
+
+class WorkerRunnerResult(PipelineResult):
+
+  def wait_until_finish(self, duration=None):
+    pass

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
new file mode 100644
index 0000000..6e13e73
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import tempfile
+import unittest
+
+import apache_beam as beam
+
+from apache_beam.metrics import Metrics
+from apache_beam.metrics.execution import MetricKey
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics.metricbase import MetricName
+
+from apache_beam.pvalue import AsList
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import BeamAssertException
+from apache_beam.transforms.util import equal_to
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.runners.portability import maptask_executor_runner
+
+
+class MapTaskExecutorRunnerTest(unittest.TestCase):
+
+  def create_pipeline(self):
+    return beam.Pipeline(runner=maptask_executor_runner.MapTaskExecutorRunner())
+
+  def test_assert_that(self):
+    with self.assertRaises(BeamAssertException):
+      with self.create_pipeline() as p:
+        assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+
+  def test_create(self):
+    with self.create_pipeline() as p:
+      assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
+
+  def test_pardo(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create(['a', 'bc'])
+             | beam.Map(lambda e: e * 2)
+             | beam.Map(lambda e: e + 'x'))
+      assert_that(res, equal_to(['aax', 'bcbcx']))
+
+  def test_pardo_metrics(self):
+
+    class MyDoFn(beam.DoFn):
+
+      def start_bundle(self):
+        self.count = Metrics.counter(self.__class__, 'elements')
+
+      def process(self, element):
+        self.count.inc(element)
+        return [element]
+
+    class MyOtherDoFn(beam.DoFn):
+
+      def start_bundle(self):
+        self.count = Metrics.counter(self.__class__, 'elementsplusone')
+
+      def process(self, element):
+        self.count.inc(element + 1)
+        return [element]
+
+    with self.create_pipeline() as p:
+      res = (p | beam.Create([1, 2, 3])
+             | 'mydofn' >> beam.ParDo(MyDoFn())
+             | 'myotherdofn' >> beam.ParDo(MyOtherDoFn()))
+      p.run()
+      if not MetricsEnvironment.METRICS_SUPPORTED:
+        self.skipTest('Metrics are not supported.')
+
+      counter_updates = [{'key': key, 'value': val}
+                         for container in p.runner.metrics_containers()
+                         for key, val in
+                         container.get_updates().counters.items()]
+      counter_values = [update['value'] for update in counter_updates]
+      counter_keys = [update['key'] for update in counter_updates]
+      assert_that(res, equal_to([1, 2, 3]))
+      self.assertEqual(counter_values, [6, 9])
+      self.assertEqual(counter_keys, [
+          MetricKey('mydofn',
+                    MetricName(__name__ + '.MyDoFn', 'elements')),
+          MetricKey('myotherdofn',
+                    MetricName(__name__ + '.MyOtherDoFn', 'elementsplusone'))])
+
+  def test_pardo_side_outputs(self):
+    def tee(elem, *tags):
+      for tag in tags:
+        if tag in elem:
+          yield beam.pvalue.OutputValue(tag, elem)
+    with self.create_pipeline() as p:
+      xy = (p
+            | 'Create' >> beam.Create(['x', 'y', 'xy'])
+            | beam.FlatMap(tee, 'x', 'y').with_outputs())
+      assert_that(xy.x, equal_to(['x', 'xy']), label='x')
+      assert_that(xy.y, equal_to(['y', 'xy']), label='y')
+
+  def test_pardo_side_and_main_outputs(self):
+    def even_odd(elem):
+      yield elem
+      yield beam.pvalue.OutputValue('odd' if elem % 2 else 'even', elem)
+    with self.create_pipeline() as p:
+      ints = p | beam.Create([1, 2, 3])
+      named = ints | 'named' >> beam.FlatMap(
+          even_odd).with_outputs('even', 'odd', main='all')
+      assert_that(named.all, equal_to([1, 2, 3]), label='named.all')
+      assert_that(named.even, equal_to([2]), label='named.even')
+      assert_that(named.odd, equal_to([1, 3]), label='named.odd')
+
+      unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
+      unnamed[None] | beam.Map(id)  # pylint: disable=expression-not-assigned
+      assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all')
+      assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
+      assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
+
+  def test_pardo_side_inputs(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+    with self.create_pipeline() as p:
+      main = p | 'main' >> beam.Create(['a', 'b', 'c'])
+      side = p | 'side' >> beam.Create(['x', 'y'])
+      assert_that(main | beam.FlatMap(cross_product, AsList(side)),
+                  equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
+                            ('a', 'y'), ('b', 'y'), ('c', 'y')]))
+
+  def test_pardo_unfusable_side_inputs(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+    with self.create_pipeline() as p:
+      pcoll = p | beam.Create(['a', 'b'])
+      assert_that(pcoll | beam.FlatMap(cross_product, AsList(pcoll)),
+                  equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+    with self.create_pipeline() as p:
+      pcoll = p | beam.Create(['a', 'b'])
+      derived = ((pcoll,) | beam.Flatten()
+                 | beam.Map(lambda x: (x, x))
+                 | beam.GroupByKey()
+                 | 'Unkey' >> beam.Map(lambda (x, _): x))
+      assert_that(
+          pcoll | beam.FlatMap(cross_product, AsList(derived)),
+          equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+  def test_group_by_key(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+             | beam.GroupByKey()
+             | beam.Map(lambda (k, vs): (k, sorted(vs))))
+      assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
+
+  def test_flatten(self):
+    with self.create_pipeline() as p:
+      res = (p | 'a' >> beam.Create(['a']),
+             p | 'bc' >> beam.Create(['b', 'c']),
+             p | 'd' >> beam.Create(['d'])) | beam.Flatten()
+      assert_that(res, equal_to(['a', 'b', 'c', 'd']))
+
+  def test_combine_per_key(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+             | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
+      assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
+
+  def test_read(self):
+    with tempfile.NamedTemporaryFile() as temp_file:
+      temp_file.write('a\nb\nc')
+      temp_file.flush()
+      with self.create_pipeline() as p:
+        assert_that(p | beam.io.ReadFromText(temp_file.name),
+                    equal_to(['a', 'b', 'c']))
+
+  def test_windowing(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create([1, 2, 100, 101, 102])
+             | beam.Map(lambda t: TimestampedValue(('k', t), t))
+             | beam.WindowInto(beam.transforms.window.Sessions(10))
+             | beam.GroupByKey()
+             | beam.Map(lambda (k, vs): (k, sorted(vs))))
+      assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/__init__.py b/sdks/python/apache_beam/runners/worker/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
new file mode 100644
index 0000000..6425447
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implementation of DataChannels for communicating across the data plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import logging
+import Queue as queue
+import threading
+
+from apache_beam.coders import coder_impl
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class ClosableOutputStream(type(coder_impl.create_OutputStream())):
+  """A Outputstream for use with CoderImpls that has a close() method."""
+
+  def __init__(self, close_callback=None):
+    super(ClosableOutputStream, self).__init__()
+    self._close_callback = close_callback
+
+  def close(self):
+    if self._close_callback:
+      self._close_callback(self.get())
+
+
+class DataChannel(object):
+  """Represents a channel for reading and writing data over the data plane.
+
+  Read from this channel with the input_elements method::
+
+    for elements_data in data_channel.input_elements(instruction_id, targets):
+      [process elements_data]
+
+  Write to this channel using the output_stream method::
+
+    out1 = data_channel.output_stream(instruction_id, target1)
+    out1.write(...)
+    out1.close()
+
+  When all data for all instructions is written, close the channel::
+
+    data_channel.close()
+  """
+
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractmethod
+  def input_elements(self, instruction_id, expected_targets):
+    """Returns an iterable of all Element.Data bundles for instruction_id.
+
+    This iterable terminates only once the full set of data has been received
+    for each of the expected targets. It may block waiting for more data.
+
+    Args:
+        instruction_id: which instruction the results must belong to
+        expected_targets: which targets to wait on for completion
+    """
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def output_stream(self, instruction_id, target):
+    """Returns an output stream writing elements to target.
+
+    Args:
+        instruction_id: which instruction this stream belongs to
+        target: the target of the returned stream
+    """
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def close(self):
+    """Closes this channel, indicating that all data has been written.
+
+    Data can continue to be read.
+
+    If this channel is shared by many instructions, this should only be
+    called on worker shutdown.
+    """
+    raise NotImplementedError(type(self))
+
+
+class InMemoryDataChannel(DataChannel):
+  """An in-memory implementation of a DataChannel.
+
+  This channel is two-sided.  What is written to one side is read by the other.
+  The inverse() method returns the other side of this instance.
+  """
+
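+  # A rough usage sketch (illustrative only; it mirrors data_plane_test.py,
+  # and the instruction id and target values below are hypothetical):
+  #
+  #   channel = InMemoryDataChannel()
+  #   out = channel.output_stream('0', target)
+  #   out.write('abc')
+  #   out.close()
+  #   for data in channel.inverse().input_elements('0', [target]):
+  #     ...  # each data is a beam_fn_api_pb2.Elements.Data
+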
+  def __init__(self, inverse=None):
+    self._inputs = []
+    self._inverse = inverse or InMemoryDataChannel(self)
+
+  def inverse(self):
+    return self._inverse
+
+  def input_elements(self, instruction_id, unused_expected_targets=None):
+    for data in self._inputs:
+      if data.instruction_reference == instruction_id:
+        yield data
+
+  def output_stream(self, instruction_id, target):
+    def add_to_inverse_output(data):
+      self._inverse._inputs.append(  # pylint: disable=protected-access
+          beam_fn_api_pb2.Elements.Data(
+              instruction_reference=instruction_id,
+              target=target,
+              data=data))
+    return ClosableOutputStream(add_to_inverse_output)
+
+  def close(self):
+    pass
+
+
+class _GrpcDataChannel(DataChannel):
+  """Base class for implementing a BeamFnData-based DataChannel."""
+
+  _WRITES_FINISHED = object()
+
+  def __init__(self):
+    self._to_send = queue.Queue()
+    self._received = collections.defaultdict(queue.Queue)
+    self._receive_lock = threading.Lock()
+    self._reads_finished = threading.Event()
+
+  def close(self):
+    self._to_send.put(self._WRITES_FINISHED)
+
+  def wait(self, timeout=None):
+    self._reads_finished.wait(timeout)
+
+  def _receiving_queue(self, instruction_id):
+    with self._receive_lock:
+      return self._received[instruction_id]
+
+  def input_elements(self, instruction_id, expected_targets):
+    received = self._receiving_queue(instruction_id)
+    done_targets = []
+    while len(done_targets) < len(expected_targets):
+      data = received.get()
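+      # An empty data element serves as an end-of-stream marker for its
+      # target (see output_stream below, which sends it on close()).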
+      if not data.data and data.target in expected_targets:
+        done_targets.append(data.target)
+      else:
+        assert data.target not in done_targets
+        yield data
+
+  def output_stream(self, instruction_id, target):
+    def add_to_send_queue(data):
+      self._to_send.put(
+          beam_fn_api_pb2.Elements.Data(
+              instruction_reference=instruction_id,
+              target=target,
+              data=data))
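+      # An empty data element follows as an end-of-stream marker for this
+      # target, letting the reader know all data for it has been sent.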
+      self._to_send.put(
+          beam_fn_api_pb2.Elements.Data(
+              instruction_reference=instruction_id,
+              target=target,
+              data=''))
+    return ClosableOutputStream(add_to_send_queue)
+
+  def _write_outputs(self):
+    done = False
+    while not done:
+      data = [self._to_send.get()]
+      try:
+        # Coalesce up to 100 other items.
+        for _ in range(100):
+          data.append(self._to_send.get_nowait())
+      except queue.Empty:
+        pass
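+      # The _WRITES_FINISHED sentinel, enqueued by close(), terminates the
+      # output stream once any remaining data has been flushed.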
+      if data[-1] is self._WRITES_FINISHED:
+        done = True
+        data.pop()
+      if data:
+        yield beam_fn_api_pb2.Elements(data=data)
+
+  def _read_inputs(self, elements_iterator):
+    # TODO(robertwb): Pushback/throttling to avoid unbounded buffering.
+    try:
+      for elements in elements_iterator:
+        for data in elements.data:
+          self._receiving_queue(data.instruction_reference).put(data)
+    except:  # pylint: disable=broad-except
+      logging.exception('Failed to read inputs in the data plane')
+      raise
+    finally:
+      self._reads_finished.set()
+
+  def _start_reader(self, elements_iterator):
+    reader = threading.Thread(
+        target=lambda: self._read_inputs(elements_iterator),
+        name='read_grpc_client_inputs')
+    reader.daemon = True
+    reader.start()
+
+
+class GrpcClientDataChannel(_GrpcDataChannel):
+  """A DataChannel wrapping the client side of a BeamFnData connection."""
+
+  def __init__(self, data_stub):
+    super(GrpcClientDataChannel, self).__init__()
+    self._start_reader(data_stub.Data(self._write_outputs()))
+
+
+class GrpcServerDataChannel(
+    beam_fn_api_pb2.BeamFnDataServicer, _GrpcDataChannel):
+  """A DataChannel wrapping the server side of a BeamFnData connection."""
+
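+  # Typically registered with a grpc server; an illustrative sketch (it
+  # mirrors data_plane_test.py and assumes concurrent.futures is imported
+  # as futures, which this module does not do):
+  #
+  #   data_channel_service = GrpcServerDataChannel()
+  #   server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+  #   beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+  #       data_channel_service, server)
+  #   port = server.add_insecure_port('[::]:0')
+  #   server.start()
+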
+  def Data(self, elements_iterator, context):
+    self._start_reader(elements_iterator)
+    for elements in self._write_outputs():
+      yield elements
+
+
+class DataChannelFactory(object):
+  """An abstract factory for creating ``DataChannel``."""
+
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractmethod
+  def create_data_channel(self, function_spec):
+    """Returns a ``DataChannel`` from the given function_spec."""
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def close(self):
+    """Close all channels that this factory owns."""
+    raise NotImplementedError(type(self))
+
+
+class GrpcClientDataChannelFactory(DataChannelFactory):
+  """A factory for ``GrpcClientDataChannel``.
+
+  Caches the created channels by the url of their ``api_service_descriptor``.
+  """
+
+  def __init__(self):
+    self._data_channel_cache = {}
+
+  def create_data_channel(self, function_spec):
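+    # function_spec.data is a protobuf Any wrapping a RemoteGrpcPort; unpack
+    # it to find the endpoint of the data service to connect to.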
+    remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+    function_spec.data.Unpack(remote_grpc_port)
+    url = remote_grpc_port.api_service_descriptor.url
+    if url not in self._data_channel_cache:
+      logging.info('Creating channel for %s', url)
+      grpc_channel = grpc.insecure_channel(url)
+      self._data_channel_cache[url] = GrpcClientDataChannel(
+          beam_fn_api_pb2.BeamFnDataStub(grpc_channel))
+    return self._data_channel_cache[url]
+
+  def close(self):
+    logging.info('Closing all cached grpc data channels.')
+    for _, channel in self._data_channel_cache.items():
+      channel.close()
+    self._data_channel_cache.clear()
+
+
+class InMemoryDataChannelFactory(DataChannelFactory):
+  """A singleton factory for ``InMemoryDataChannel``."""
+
+  def __init__(self, in_memory_data_channel):
+    self._in_memory_data_channel = in_memory_data_channel
+
+  def create_data_channel(self, unused_function_spec):
+    return self._in_memory_data_channel
+
+  def close(self):
+    pass

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py
new file mode 100644
index 0000000..7340789
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.runners.worker.data_plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import sys
+import threading
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import data_plane
+
+
+def timeout(timeout_secs):
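+  """Returns a decorator that runs the wrapped test in a daemon thread,
+  re-raises any exception it raised, and fails if the test has not finished
+  within timeout_secs seconds."""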
+  def decorate(fn):
+    exc_info = []
+
+    def wrapper(*args, **kwargs):
+      def call_fn():
+        try:
+          fn(*args, **kwargs)
+        except:  # pylint: disable=bare-except
+          exc_info[:] = sys.exc_info()
+      thread = threading.Thread(target=call_fn)
+      thread.daemon = True
+      thread.start()
+      thread.join(timeout_secs)
+      if exc_info:
+        t, v, tb = exc_info  # pylint: disable=unbalanced-tuple-unpacking
+        raise t, v, tb
+      assert not thread.is_alive(), 'timed out after %s seconds' % timeout_secs
+    return wrapper
+  return decorate
+
+
+class DataChannelTest(unittest.TestCase):
+
+  @timeout(5)
+  def test_grpc_data_channel(self):
+    data_channel_service = data_plane.GrpcServerDataChannel()
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+    beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+        data_channel_service, server)
+    test_port = server.add_insecure_port('[::]:0')
+    server.start()
+
+    data_channel_stub = beam_fn_api_pb2.BeamFnDataStub(
+        grpc.insecure_channel('localhost:%s' % test_port))
+    data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)
+
+    try:
+      self._data_channel_test(data_channel_service, data_channel_client)
+    finally:
+      data_channel_client.close()
+      data_channel_service.close()
+      data_channel_client.wait()
+      data_channel_service.wait()
+
+  def test_in_memory_data_channel(self):
+    channel = data_plane.InMemoryDataChannel()
+    self._data_channel_test(channel, channel.inverse())
+
+  def _data_channel_test(self, server, client):
+    self._data_channel_test_one_direction(server, client)
+    self._data_channel_test_one_direction(client, server)
+
+  def _data_channel_test_one_direction(self, from_channel, to_channel):
+    def send(instruction_id, target, data):
+      stream = from_channel.output_stream(instruction_id, target)
+      stream.write(data)
+      stream.close()
+    target_1 = beam_fn_api_pb2.Target(
+        primitive_transform_reference='1',
+        name='out')
+    target_2 = beam_fn_api_pb2.Target(
+        primitive_transform_reference='2',
+        name='out')
+
+    # Single write.
+    send('0', target_1, 'abc')
+    self.assertEqual(
+        list(to_channel.input_elements('0', [target_1])),
+        [beam_fn_api_pb2.Elements.Data(
+            instruction_reference='0',
+            target=target_1,
+            data='abc')])
+
+    # Multiple interleaved writes to multiple instructions.
+    send('1', target_1, 'abc')
+    send('2', target_1, 'def')
+    self.assertEqual(
+        list(to_channel.input_elements('1', [target_1])),
+        [beam_fn_api_pb2.Elements.Data(
+            instruction_reference='1',
+            target=target_1,
+            data='abc')])
+    send('2', target_2, 'ghi')
+    self.assertEqual(
+        list(to_channel.input_elements('2', [target_1, target_2])),
+        [beam_fn_api_pb2.Elements.Data(
+            instruction_reference='2',
+            target=target_1,
+            data='def'),
+         beam_fn_api_pb2.Elements.Data(
+             instruction_reference='2',
+             target=target_2,
+             data='ghi')])
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py
new file mode 100644
index 0000000..b9e36ad
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Beam fn API log handler."""
+
+import logging
+import math
+import Queue as queue
+import threading
+
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class FnApiLogRecordHandler(logging.Handler):
+  """A handler that writes log records to the fn API."""
+
+  # Maximum number of log entries in a single stream request.
+  _MAX_BATCH_SIZE = 1000
+  # Used to indicate the end of stream.
+  _FINISHED = object()
+
+  # Mapping from logging levels to LogEntry levels.
+  LOG_LEVEL_MAP = {
+      logging.FATAL: beam_fn_api_pb2.LogEntry.CRITICAL,
+      logging.ERROR: beam_fn_api_pb2.LogEntry.ERROR,
+      logging.WARNING: beam_fn_api_pb2.LogEntry.WARN,
+      logging.INFO: beam_fn_api_pb2.LogEntry.INFO,
+      logging.DEBUG: beam_fn_api_pb2.LogEntry.DEBUG
+  }
+
+  def __init__(self, log_service_descriptor):
+    super(FnApiLogRecordHandler, self).__init__()
+    self._log_channel = grpc.insecure_channel(log_service_descriptor.url)
+    self._logging_stub = beam_fn_api_pb2.BeamFnLoggingStub(self._log_channel)
+    self._log_entry_queue = queue.Queue()
+
+    log_control_messages = self._logging_stub.Logging(self._write_log_entries())
+    self._reader = threading.Thread(
+        target=lambda: self._read_log_control_messages(log_control_messages),
+        name='read_log_control_messages')
+    self._reader.daemon = True
+    self._reader.start()
+
+  def emit(self, record):
+    log_entry = beam_fn_api_pb2.LogEntry()
+    log_entry.severity = self.LOG_LEVEL_MAP[record.levelno]
+    log_entry.message = self.format(record)
+    log_entry.thread = record.threadName
+    log_entry.log_location = record.module + '.' + record.funcName
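+    # record.created is a floating point number of seconds since the epoch;
+    # split it into whole seconds and nanoseconds for the timestamp proto.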
+    (fraction, seconds) = math.modf(record.created)
+    nanoseconds = 1e9 * fraction
+    log_entry.timestamp.seconds = int(seconds)
+    log_entry.timestamp.nanos = int(nanoseconds)
+    self._log_entry_queue.put(log_entry)
+
+  def close(self):
+    """Flush out all existing log entries and unregister this handler."""
+    # Acquiring the handler lock ensures ``emit`` is not run until the lock is
+    # released.
+    self.acquire()
+    self._log_entry_queue.put(self._FINISHED)
+    # Wait for the server to close the stream.
+    self._reader.join()
+    self.release()
+    # Unregister this handler.
+    super(FnApiLogRecordHandler, self).close()
+
+  def _write_log_entries(self):
+    done = False
+    while not done:
+      log_entries = [self._log_entry_queue.get()]
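+      # Batch up to _MAX_BATCH_SIZE additional entries that are already
+      # queued; the _FINISHED sentinel, enqueued by close(), ends the stream.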
+      try:
+        for _ in range(self._MAX_BATCH_SIZE):
+          log_entries.append(self._log_entry_queue.get_nowait())
+      except queue.Empty:
+        pass
+      if log_entries[-1] is self._FINISHED:
+        done = True
+        log_entries.pop()
+      if log_entries:
+        yield beam_fn_api_pb2.LogEntry.List(log_entries=log_entries)
+
+  def _read_log_control_messages(self, log_control_iterator):
+    # TODO(vikasrk): Handle control messages.
+    for _ in log_control_iterator:
+      pass
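+
+
+# Typical wiring, as an illustrative sketch (it mirrors log_handler_test.py;
+# the endpoint below is hypothetical): point the handler at a BeamFnLogging
+# service and attach it to the root logger.
+#
+#   descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+#   descriptor.url = 'localhost:12345'
+#   logging.getLogger().addHandler(FnApiLogRecordHandler(descriptor))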

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
new file mode 100644
index 0000000..565bedb
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -0,0 +1,105 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import logging
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import log_handler
+
+
+class BeamFnLoggingServicer(beam_fn_api_pb2.BeamFnLoggingServicer):
+
+  def __init__(self):
+    self.log_records_received = []
+
+  def Logging(self, request_iterator, context):
+
+    for log_record in request_iterator:
+      self.log_records_received.append(log_record)
+
+    yield beam_fn_api_pb2.LogControl()
+
+
+class FnApiLogRecordHandlerTest(unittest.TestCase):
+
+  def setUp(self):
+    self.test_logging_service = BeamFnLoggingServicer()
+    self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    beam_fn_api_pb2.add_BeamFnLoggingServicer_to_server(
+        self.test_logging_service, self.server)
+    self.test_port = self.server.add_insecure_port('[::]:0')
+    self.server.start()
+
+    self.logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+    self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
+    self.fn_log_handler = log_handler.FnApiLogRecordHandler(
+        self.logging_service_descriptor)
+    logging.getLogger().setLevel(logging.INFO)
+    logging.getLogger().addHandler(self.fn_log_handler)
+
+  def tearDown(self):
+    # Stop the server, allowing up to 5 seconds for shutdown.
+    self.server.stop(5)
+
+  def _verify_fn_log_handler(self, num_log_entries):
+    msg = 'Testing fn logging'
+    logging.debug('Debug Message 1')
+    for idx in range(num_log_entries):
+      logging.info('%s: %s', msg, idx)
+    logging.debug('Debug Message 2')
+
+    # Wait for logs to be sent to server.
+    self.fn_log_handler.close()
+
+    num_received_log_entries = 0
+    for outer in self.test_logging_service.log_records_received:
+      for log_entry in outer.log_entries:
+        self.assertEqual(beam_fn_api_pb2.LogEntry.INFO, log_entry.severity)
+        self.assertEqual('%s: %s' % (msg, num_received_log_entries),
+                         log_entry.message)
+        self.assertEqual(u'log_handler_test._verify_fn_log_handler',
+                         log_entry.log_location)
+        self.assertGreater(log_entry.timestamp.seconds, 0)
+        self.assertGreater(log_entry.timestamp.nanos, 0)
+        num_received_log_entries += 1
+
+    self.assertEqual(num_received_log_entries, num_log_entries)
+
+
+# Test cases: log counts chosen to exercise a partial batch, an exact
+# multiple of the batch size, and multiple batches.
+data = {
+    'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
+    'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
+    'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
+}
+
+
+def _create_test(name, num_logs):
+  setattr(FnApiLogRecordHandlerTest, 'test_%s' % name,
+          lambda self: self._verify_fn_log_handler(num_logs))
+
+
+if __name__ == '__main__':
+  for test_name, num_logs_entries in data.iteritems():
+    _create_test(test_name, num_logs_entries)
+
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.pxd b/sdks/python/apache_beam/runners/worker/logger.pxd
new file mode 100644
index 0000000..201daf4
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.pxd
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport LoggingContext
+
+
+cdef class PerThreadLoggingContext(LoggingContext):
+  cdef kwargs
+  cdef list stack

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py
new file mode 100644
index 0000000..217dc58
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Python worker logging."""
+
+import json
+import logging
+import threading
+import traceback
+
+from apache_beam.runners.common import LoggingContext
+
+
+# Per-thread worker information. This is used only for logging to set
+# context information that changes while work items get executed:
+# work_item_id, step_name, stage_name.
+class _PerThreadWorkerData(threading.local):
+
+  def __init__(self):
+    super(_PerThreadWorkerData, self).__init__()
+    # TODO(robertwb): Consider starting with an initial (ignored) ~20 elements
+    # in the list, as going up and down all the way to zero incurs several
+    # reallocations.
+    self.stack = []
+
+  def get_data(self):
+    all_data = {}
+    for datum in self.stack:
+      all_data.update(datum)
+    return all_data
+
+
+per_thread_worker_data = _PerThreadWorkerData()
+
+
+class PerThreadLoggingContext(LoggingContext):
+  """A context manager to add per thread attributes."""
+
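+  # Illustrative use (the attribute values below are hypothetical): entries
+  # pushed inside the context are merged into every record formatted by
+  # JsonLogFormatter on this thread:
+  #
+  #   with PerThreadLoggingContext(work_item_id='w-1', step_name='s2'):
+  #     logging.info('processing')
+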
+  def __init__(self, **kwargs):
+    self.kwargs = kwargs
+    self.stack = per_thread_worker_data.stack
+
+  def __enter__(self):
+    self.enter()
+
+  def enter(self):
+    self.stack.append(self.kwargs)
+
+  def __exit__(self, exn_type, exn_value, exn_traceback):
+    self.exit()
+
+  def exit(self):
+    self.stack.pop()
+
+
+class JsonLogFormatter(logging.Formatter):
+  """A JSON formatter class as expected by the logging standard module."""
+
+  def __init__(self, job_id, worker_id):
+    super(JsonLogFormatter, self).__init__()
+    self.job_id = job_id
+    self.worker_id = worker_id
+
+  def format(self, record):
+    """Returns a JSON string based on a LogRecord instance.
+
+    Args:
+      record: A LogRecord instance. See below for details.
+
+    Returns:
+      A JSON string representing the record.
+
+    A LogRecord instance has the following attributes and is used for
+    formatting the final message.
+
+    Attributes:
+      created: A double representing the timestamp for record creation
+        (e.g., 1438365207.624597). Note that the number also contains msecs and
+        microsecs information. Part of this is also available in the 'msecs'
+        attribute.
+      msecs: A double representing the msecs part of the record creation
+        (e.g., 624.5970726013184).
+      msg: Logging message containing formatting instructions or an arbitrary
+        object. This is the first argument of a log call.
+      args: A tuple containing the positional arguments for the logging call.
+      levelname: A string. Possible values are: INFO, WARNING, ERROR, etc.
+      exc_info: None or a 3-tuple with exception information as it is
+        returned by a call to sys.exc_info().
+      name: Logger's name. Most logging is done using the default root logger
+        and therefore the name will be 'root'.
+      filename: Basename of the file where logging occurred.
+      funcName: Name of the function where logging occurred.
+      process: The PID of the process running the worker.
+      thread:  An id for the thread where the record was logged. This is not a
+        real TID (the one provided by OS) but rather the id (address) of a
+        Python thread object. Nevertheless, having this value makes it
+        possible to filter log statements coming from one specific thread.
+    """
+    output = {}
+    output['timestamp'] = {
+        'seconds': int(record.created),
+        'nanos': int(record.msecs * 1000000)}
+    # The ERROR, INFO and DEBUG log levels translate directly into the
+    # severity property; WARNING becomes WARN.
+    output['severity'] = (
+        record.levelname if record.levelname != 'WARNING' else 'WARN')
+
+    # msg could be an arbitrary object, convert it to a string first.
+    record_msg = str(record.msg)
+
+    # Prepare the actual message using the message formatting string and the
+    # positional arguments as they have been used in the log call.
+    if record.args:
+      try:
+        output['message'] = record_msg % record.args
+      except (TypeError, ValueError):
+        output['message'] = '%s with args (%s)' % (record_msg, record.args)
+    else:
+      output['message'] = record_msg
+
+    # The thread ID is logged as a combination of the process ID and thread ID
+    # since workers can run in multiple processes.
+    output['thread'] = '%s:%s' % (record.process, record.thread)
+    # Job ID and worker ID. These do not change during the lifetime of a worker.
+    output['job'] = self.job_id
+    output['worker'] = self.worker_id
+    # Stage, step and work item ID come from thread local storage since they
+    # change with every new work item leased for execution. If there is no
+    # work item ID then we make sure the step is undefined too.
+    data = per_thread_worker_data.get_data()
+    if 'work_item_id' in data:
+      output['work'] = data['work_item_id']
+    if 'stage_name' in data:
+      output['stage'] = data['stage_name']
+    if 'step_name' in data:
+      output['step'] = data['step_name']
+    # All logging happens using the root logger. We will add the basename of the
+    # file and the function name where the logging happened to make it easier
+    # to identify who generated the record.
+    output['logger'] = '%s:%s:%s' % (
+        record.name, record.filename, record.funcName)
+    # Add exception information if any is available.
+    if record.exc_info:
+      output['exception'] = ''.join(
+          traceback.format_exception(*record.exc_info))
+
+    return json.dumps(output)
+
+
+def initialize(job_id, worker_id, log_path):
+  """Initialize root logger so that we log JSON to a file and text to stdout."""
+
+  file_handler = logging.FileHandler(log_path)
+  file_handler.setFormatter(JsonLogFormatter(job_id, worker_id))
+  logging.getLogger().addHandler(file_handler)
+
+  # Set default level to INFO to avoid logging various DEBUG level log calls
+  # sprinkled throughout the code.
+  logging.getLogger().setLevel(logging.INFO)