Posted to commits@beam.apache.org by ro...@apache.org on 2017/05/02 00:45:02 UTC

[1/4] beam git commit: Fn API support for Python.

Repository: beam
Updated Branches:
  refs/heads/master b8131fe93 -> 7c425b097


http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
new file mode 100644
index 0000000..996f44c
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -0,0 +1,168 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.runners.worker.sdk_worker."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.io.concat_source_test import RangeSource
+from apache_beam.io.iobase import SourceBundle
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import sdk_worker
+
+
+class BeamFnControlServicer(beam_fn_api_pb2.BeamFnControlServicer):
+
+  def __init__(self, requests, raise_errors=True):
+    self.requests = requests
+    self.instruction_ids = set(r.instruction_id for r in requests)
+    self.responses = {}
+    self.raise_errors = raise_errors
+
+  def Control(self, response_iterator, context):
+    for request in self.requests:
+      logging.info("Sending request %s", request)
+      yield request
+    for response in response_iterator:
+      logging.info("Got response %s", response)
+      if response.instruction_id != -1:
+        assert response.instruction_id in self.instruction_ids
+        assert response.instruction_id not in self.responses
+        self.responses[response.instruction_id] = response
+        if self.raise_errors and response.error:
+          raise RuntimeError(response.error)
+        elif len(self.responses) == len(self.requests):
+          logging.info("All %s instructions finished.", len(self.requests))
+          return
+    raise RuntimeError("Missing responses: %s" %
+                       (self.instruction_ids - set(self.responses.keys())))
+
+
+class SdkWorkerTest(unittest.TestCase):
+
+  def test_fn_registration(self):
+    fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]
+
+    process_bundle_descriptors = [beam_fn_api_pb2.ProcessBundleDescriptor(
+        id=str(100+ix),
+        primitive_transform=[
+            beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)])
+                                  for ix, fn in enumerate(fns)]
+
+    test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest(
+        register=beam_fn_api_pb2.RegisterRequest(
+            process_bundle_descriptor=process_bundle_descriptors))])
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
+    test_port = server.add_insecure_port("[::]:0")
+    server.start()
+
+    channel = grpc.insecure_channel("localhost:%s" % test_port)
+    harness = sdk_worker.SdkHarness(channel)
+    harness.run()
+    self.assertEqual(
+        harness.worker.fns,
+        {item.id: item for item in fns + process_bundle_descriptors})
+
+  @unittest.skip("initial splitting not in proto")
+  def test_source_split(self):
+    source = RangeSource(0, 100)
+    expected_splits = list(source.split(30))
+
+    worker = sdk_harness.SdkWorker(
+        None, data_plane.GrpcClientDataChannelFactory())
+    worker.register(
+        beam_fn_api_pb2.RegisterRequest(
+            process_bundle_descriptor=[beam_fn_api_pb2.ProcessBundleDescriptor(
+                primitive_transform=[beam_fn_api_pb2.PrimitiveTransform(
+                    function_spec=sdk_harness.serialize_and_pack_py_fn(
+                        SourceBundle(1.0, source, None, None),
+                        sdk_harness.PYTHON_SOURCE_URN,
+                        id="src"))])]))
+    split_response = worker.initial_source_split(
+        beam_fn_api_pb2.InitialSourceSplitRequest(
+            desired_bundle_size_bytes=30,
+            source_reference="src"))
+
+    self.assertEqual(
+        expected_splits,
+        [sdk_harness.unpack_and_deserialize_py_fn(s.source)
+         for s in split_response.splits])
+
+    self.assertEqual(
+        [s.weight for s in expected_splits],
+        [s.relative_size for s in split_response.splits])
+
+  @unittest.skip("initial splitting not in proto")
+  def test_source_split_via_instruction(self):
+
+    source = RangeSource(0, 100)
+    expected_splits = list(source.split(30))
+
+    test_controller = BeamFnControlServicer([
+        beam_fn_api_pb2.InstructionRequest(
+            instruction_id="register_request",
+            register=beam_fn_api_pb2.RegisterRequest(
+                process_bundle_descriptor=[
+                    beam_fn_api_pb2.ProcessBundleDescriptor(
+                        primitive_transform=[beam_fn_api_pb2.PrimitiveTransform(
+                            function_spec=sdk_harness.serialize_and_pack_py_fn(
+                                SourceBundle(1.0, source, None, None),
+                                sdk_harness.PYTHON_SOURCE_URN,
+                                id="src"))])])),
+        beam_fn_api_pb2.InstructionRequest(
+            instruction_id="split_request",
+            initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
+                desired_bundle_size_bytes=30,
+                source_reference="src"))
+        ])
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
+    test_port = server.add_insecure_port("[::]:0")
+    server.start()
+
+    channel = grpc.insecure_channel("localhost:%s" % test_port)
+    harness = sdk_harness.SdkHarness(channel)
+    harness.run()
+
+    split_response = test_controller.responses[
+        "split_request"].initial_source_split
+
+    self.assertEqual(
+        expected_splits,
+        [sdk_harness.unpack_and_deserialize_py_fn(s.source)
+         for s in split_response.splits])
+
+    self.assertEqual(
+        [s.weight for s in expected_splits],
+        [s.relative_size for s in split_response.splits])
+
+
+if __name__ == "__main__":
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sideinputs.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sideinputs.py b/sdks/python/apache_beam/runners/worker/sideinputs.py
new file mode 100644
index 0000000..3bac3d9
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sideinputs.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Utilities for handling side inputs."""
+
+import collections
+import logging
+import Queue
+import threading
+import traceback
+
+from apache_beam.io import iobase
+from apache_beam.transforms import window
+
+
+# Maximum number of reader threads for reading side input sources, per side
+# input.
+MAX_SOURCE_READER_THREADS = 15
+
+# Number of slots for elements in the side input element queue.  This value is
+# intentionally smaller than MAX_SOURCE_READER_THREADS to reduce the memory
+# pressure of holding potentially-large elements in memory.  The number of
+# pending elements held in memory is bounded by the sum of
+# MAX_SOURCE_READER_THREADS and ELEMENT_QUEUE_SIZE.
+ELEMENT_QUEUE_SIZE = 10
+
+# Special element value sentinel for signaling reader state.
+READER_THREAD_IS_DONE_SENTINEL = object()
+
+# Used to efficiently window the values of non-windowed side inputs.
+_globally_windowed = window.GlobalWindows.windowed_value(None).with_value
+
+
+class PrefetchingSourceSetIterable(object):
+  """Value iterator that reads concurrently from a set of sources."""
+
+  def __init__(self, sources,
+               max_reader_threads=MAX_SOURCE_READER_THREADS):
+    self.sources = sources
+    self.num_reader_threads = min(max_reader_threads, len(self.sources))
+
+    # Queue for sources that are to be read.
+    self.sources_queue = Queue.Queue()
+    for source in sources:
+      self.sources_queue.put(source)
+    # Queue for elements that have been read.
+    self.element_queue = Queue.Queue(ELEMENT_QUEUE_SIZE)
+    # Queue for exceptions encountered in reader threads; to be rethrown.
+    self.reader_exceptions = Queue.Queue()
+    # Whether we have already iterated; this iterable can only be used once.
+    self.already_iterated = False
+    # Whether an error was encountered in any source reader.
+    self.has_errored = False
+
+    self.reader_threads = []
+    self._start_reader_threads()
+
+  def _start_reader_threads(self):
+    for _ in range(0, self.num_reader_threads):
+      t = threading.Thread(target=self._reader_thread)
+      t.daemon = True
+      t.start()
+      self.reader_threads.append(t)
+
+  def _reader_thread(self):
+    # pylint: disable=too-many-nested-blocks
+    try:
+      while True:
+        try:
+          source = self.sources_queue.get_nowait()
+          if isinstance(source, iobase.BoundedSource):
+            for value in source.read(source.get_range_tracker(None, None)):
+              if self.has_errored:
+                # If any reader has errored, just return.
+                return
+              if isinstance(value, window.WindowedValue):
+                self.element_queue.put(value)
+              else:
+                self.element_queue.put(_globally_windowed(value))
+          else:
+            # Native dataflow source.
+            with source.reader() as reader:
+              returns_windowed_values = reader.returns_windowed_values
+              for value in reader:
+                if self.has_errored:
+                  # If any reader has errored, just return.
+                  return
+                if returns_windowed_values:
+                  self.element_queue.put(value)
+                else:
+                  self.element_queue.put(_globally_windowed(value))
+        except Queue.Empty:
+          return
+    except Exception as e:  # pylint: disable=broad-except
+      logging.error('Encountered exception in PrefetchingSourceSetIterable '
+                    'reader thread: %s', traceback.format_exc())
+      self.reader_exceptions.put(e)
+      self.has_errored = True
+    finally:
+      self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
+
+  def __iter__(self):
+    if self.already_iterated:
+      raise RuntimeError(
+          'Can only iterate once over PrefetchingSourceSetIterable instance.')
+    self.already_iterated = True
+
+    # The invariants during execution are:
+    # 1) A worker thread always posts the sentinel as the last thing it does
+    #    before exiting.
+    # 2) We always wait for all sentinels and then join all threads.
+    num_readers_finished = 0
+    try:
+      while True:
+        element = self.element_queue.get()
+        if element is READER_THREAD_IS_DONE_SENTINEL:
+          num_readers_finished += 1
+          if num_readers_finished == self.num_reader_threads:
+            return
+        elif self.has_errored:
+          raise self.reader_exceptions.get()
+        else:
+          yield element
+    except GeneratorExit:
+      self.has_errored = True
+      raise
+    finally:
+      while num_readers_finished < self.num_reader_threads:
+        element = self.element_queue.get()
+        if element is READER_THREAD_IS_DONE_SENTINEL:
+          num_readers_finished += 1
+      for t in self.reader_threads:
+        t.join()
+
+
+def get_iterator_fn_for_sources(
+    sources, max_reader_threads=MAX_SOURCE_READER_THREADS):
+  """Returns callable that returns iterator over elements for given sources."""
+  def _inner():
+    return iter(PrefetchingSourceSetIterable(
+        sources, max_reader_threads=max_reader_threads))
+  return _inner
+
+
+class EmulatedIterable(collections.Iterable):
+  """Emulates an iterable for a side input."""
+
+  def __init__(self, iterator_fn):
+    self.iterator_fn = iterator_fn
+
+  def __iter__(self):
+    return self.iterator_fn()
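
For reference, a minimal usage sketch of the prefetching iterable above.  The
InMemorySource and InMemoryReader classes below are hypothetical stand-ins that
only mimic the reader() protocol this module expects (much like the FakeSource
helper in sideinputs_test.py further down); they are not part of this patch.

    from apache_beam.runners.worker import sideinputs


    class InMemorySource(object):
      """Hypothetical source exposing the reader() protocol used above."""

      def __init__(self, items):
        self.items = items

      def reader(self):
        return InMemoryReader(self.items)


    class InMemoryReader(object):
      returns_windowed_values = False

      def __init__(self, items):
        self.items = items

      def __enter__(self):
        return self

      def __exit__(self, *unused_exc_info):
        pass

      def __iter__(self):
        return iter(self.items)


    # While iterating, at most max_reader_threads + ELEMENT_QUEUE_SIZE elements
    # are pending in memory at any given time.
    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        [InMemorySource([1, 2, 3]), InMemorySource([4, 5])],
        max_reader_threads=2)
    print sorted(wv.value for wv in iterator_fn())  # prints [1, 2, 3, 4, 5]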

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sideinputs_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sideinputs_test.py b/sdks/python/apache_beam/runners/worker/sideinputs_test.py
new file mode 100644
index 0000000..d243bbe
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sideinputs_test.py
@@ -0,0 +1,150 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for side input utilities."""
+
+import logging
+import time
+import unittest
+
+from apache_beam.runners.worker import sideinputs
+
+
+def strip_windows(iterator):
+  return [wv.value for wv in iterator]
+
+
+class FakeSource(object):
+
+  def __init__(self, items):
+    self.items = items
+
+  def reader(self):
+    return FakeSourceReader(self.items)
+
+
+class FakeSourceReader(object):
+
+  def __init__(self, items):
+    self.items = items
+    self.entered = False
+    self.exited = False
+
+  def __iter__(self):
+    return iter(self.items)
+
+  def __enter__(self):
+    self.entered = True
+    return self
+
+  def __exit__(self, exception_type, exception_value, traceback):
+    self.exited = True
+
+  @property
+  def returns_windowed_values(self):
+    return False
+
+
+class PrefetchingSourceIteratorTest(unittest.TestCase):
+
+  def test_single_source_iterator_fn(self):
+    sources = [
+        FakeSource([0, 1, 2, 3, 4, 5]),
+    ]
+    iterator_fn = sideinputs.get_iterator_fn_for_sources(
+        sources, max_reader_threads=2)
+    assert list(strip_windows(iterator_fn())) == range(6)
+
+  def test_multiple_sources_iterator_fn(self):
+    sources = [
+        FakeSource([0]),
+        FakeSource([1, 2, 3, 4, 5]),
+        FakeSource([]),
+        FakeSource([6, 7, 8, 9, 10]),
+    ]
+    iterator_fn = sideinputs.get_iterator_fn_for_sources(
+        sources, max_reader_threads=3)
+    assert sorted(strip_windows(iterator_fn())) == range(11)
+
+  def test_multiple_sources_single_reader_iterator_fn(self):
+    sources = [
+        FakeSource([0]),
+        FakeSource([1, 2, 3, 4, 5]),
+        FakeSource([]),
+        FakeSource([6, 7, 8, 9, 10]),
+    ]
+    iterator_fn = sideinputs.get_iterator_fn_for_sources(
+        sources, max_reader_threads=1)
+    assert list(strip_windows(iterator_fn())) == range(11)
+
+  def test_source_iterator_fn_exception(self):
+    class MyException(Exception):
+      pass
+
+    def exception_generator():
+      yield 0
+      time.sleep(0.1)
+      raise MyException('I am an exception!')
+
+    def perpetual_generator(value):
+      while True:
+        yield value
+
+    sources = [
+        FakeSource(perpetual_generator(1)),
+        FakeSource(perpetual_generator(2)),
+        FakeSource(perpetual_generator(3)),
+        FakeSource(perpetual_generator(4)),
+        FakeSource(exception_generator()),
+    ]
+    iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
+    seen = set()
+    with self.assertRaises(MyException):
+      for value in iterator_fn():
+        seen.add(value.value)
+    self.assertEqual(sorted(seen), range(5))
+
+
+class EmulatedCollectionsTest(unittest.TestCase):
+
+  def test_emulated_iterable(self):
+    def _iterable_fn():
+      for i in range(10):
+        yield i
+    iterable = sideinputs.EmulatedIterable(_iterable_fn)
+    # Check that multiple iterations are supported.
+    for _ in range(0, 5):
+      for i, j in enumerate(iterable):
+        self.assertEqual(i, j)
+
+  def test_large_iterable_values(self):
+    # Here, we create a large collection that would be too big for memory-
+    # constrained test environments, but should be under the memory limit if
+    # materialized one at a time.
+    def _iterable_fn():
+      for i in range(10):
+        yield ('%d' % i) * (200 * 1024 * 1024)
+    iterable = sideinputs.EmulatedIterable(_iterable_fn)
+    # Check that multiple iterations are supported.
+    for _ in range(0, 3):
+      for i, j in enumerate(iterable):
+        self.assertEqual(('%d' % i) * (200 * 1024 * 1024), j)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/statesampler.pyx
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/statesampler.pyx b/sdks/python/apache_beam/runners/worker/statesampler.pyx
new file mode 100644
index 0000000..3ff6c20
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/statesampler.pyx
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""State sampler for tracking time spent in execution steps.
+
+The state sampler profiles the time spent in each step of a pipeline.
+Operations (defined in executor.py) which are executed as part of a MapTask are
+instrumented with context managers provided by StateSampler.scoped_state().
+These context managers change the internal state of the StateSampler during each
+relevant Operation's .start(), .process() and .finish() methods.  State is
+sampled by a raw C thread, not holding the Python Global Interpreter Lock, which
+queries the StateSampler's internal state at a defined sampling frequency.  In a
+common example, a ReadOperation during its .start() method reads an element and
+calls a DoOperation's .process() method, which can call a WriteOperation's
+.process() method.  Each element processed causes the current state to
+transition between these states of different Operations.  Each time the sampling
+thread queries the current state, the time spent since the previous sample is
+attributed to that state and accumulated.  Over time, this allows a granular
+runtime profile to be produced.
+"""
+
+import threading
+import time
+
+
+from apache_beam.utils.counters import Counter
+
+
+cimport cython
+from cpython cimport pythread
+from libc.stdint cimport int32_t, int64_t
+
+cdef extern from "Python.h":
+  # This typically requires the GIL, but modifications of the lists we use
+  # this on are synchronized via our own lock.
+  cdef void* PyList_GET_ITEM(list, Py_ssize_t index) nogil
+
+cdef extern from "unistd.h" nogil:
+  void usleep(long)
+
+cdef extern from "<time.h>" nogil:
+  struct timespec:
+    long tv_sec  # seconds
+    long tv_nsec  # nanoseconds
+  int clock_gettime(int clock_id, timespec *result)
+
+cdef inline int64_t get_nsec_time() nogil:
+  """Get current time as microseconds since Unix epoch."""
+  cdef timespec current_time
+  # First argument value of 0 corresponds to CLOCK_REALTIME.
+  clock_gettime(0, &current_time)
+  return (
+      (<int64_t> current_time.tv_sec) * 1000000000 +  # seconds to nanoseconds
+      current_time.tv_nsec)
+
+
+class StateSamplerInfo(object):
+  """Info for current state and transition statistics of StateSampler."""
+
+  def __init__(self, state_name, transition_count):
+    self.state_name = state_name
+    self.transition_count = transition_count
+
+  def __repr__(self):
+    return '<StateSamplerInfo %s %d>' % (self.state_name, self.transition_count)
+
+
+# Default period for sampling current state of pipeline execution.
+DEFAULT_SAMPLING_PERIOD_MS = 200
+
+
+cdef class StateSampler(object):
+  """Tracks time spent in states during pipeline execution."""
+
+  cdef object prefix
+  cdef object counter_factory
+  cdef int sampling_period_ms
+
+  cdef dict scoped_states_by_name
+  cdef list scoped_states_by_index
+
+  cdef bint started
+  cdef bint finished
+  cdef object sampling_thread
+
+  # This lock guards members that are shared between threads, specifically
+  # finished, scoped_states_by_index, and the nsecs field of each state therein.
+  cdef pythread.PyThread_type_lock lock
+
+  cdef public int64_t state_transition_count
+
+  cdef int32_t current_state_index
+
+  def __init__(self, prefix, counter_factory,
+      sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS):
+
+    self.prefix = prefix
+    self.counter_factory = counter_factory
+    self.sampling_period_ms = sampling_period_ms
+
+    self.lock = pythread.PyThread_allocate_lock()
+    self.scoped_states_by_name = {}
+
+    self.current_state_index = 0
+    unknown_state = ScopedState(self, 'unknown', self.current_state_index)
+    pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+    self.scoped_states_by_index = [unknown_state]
+    self.finished = False
+    pythread.PyThread_release_lock(self.lock)
+
+    # Assert that the compiler correctly aligned the current_state_index
+    # field.  This is necessary for reads and writes to this variable to be
+    # atomic across threads without additional synchronization.
+    # States are referenced via an index rather than, say, a pointer because
+    # of better support for 32-bit atomic reads and writes.
+    assert (<int64_t> &self.current_state_index) % sizeof(int32_t) == 0, (
+        'Address of StateSampler.current_state_index is not word-aligned.')
+
+  def __dealloc__(self):
+    pythread.PyThread_free_lock(self.lock)
+
+  def run(self):
+    cdef int64_t last_nsecs = get_nsec_time()
+    cdef int64_t elapsed_nsecs
+    with nogil:
+      while True:
+        usleep(self.sampling_period_ms * 1000)
+        pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+        try:
+          if self.finished:
+            break
+          elapsed_nsecs = get_nsec_time() - last_nsecs
+          # Take an address as we can't create a reference to the scope
+          # without the GIL.
+          nsecs_ptr = &(<ScopedState>PyList_GET_ITEM(
+              self.scoped_states_by_index, self.current_state_index)).nsecs
+          nsecs_ptr[0] += elapsed_nsecs
+          last_nsecs += elapsed_nsecs
+        finally:
+          pythread.PyThread_release_lock(self.lock)
+
+  def start(self):
+    assert not self.started
+    self.started = True
+    self.sampling_thread = threading.Thread(target=self.run)
+    self.sampling_thread.start()
+
+  def stop(self):
+    assert not self.finished
+    pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+    self.finished = True
+    pythread.PyThread_release_lock(self.lock)
+    # May have to wait up to sampling_period_ms, but the platform-independent
+    # pythread doesn't support conditions.
+    self.sampling_thread.join()
+
+  def stop_if_still_running(self):
+    if self.started and not self.finished:
+      self.stop()
+
+  def get_info(self):
+    """Returns StateSamplerInfo with transition statistics."""
+    return StateSamplerInfo(
+        self.scoped_states_by_index[self.current_state_index].name,
+        self.state_transition_count)
+
+  def scoped_state(self, name):
+    """Returns a context manager managing transitions for a given state."""
+    cdef ScopedState scoped_state = self.scoped_states_by_name.get(name, None)
+    if scoped_state is None:
+      output_counter = self.counter_factory.get_counter(
+          '%s%s-msecs' % (self.prefix, name), Counter.SUM)
+      new_state_index = len(self.scoped_states_by_index)
+      scoped_state = ScopedState(self, name, new_state_index, output_counter)
+      # Both scoped_states_by_index and scoped_state.nsecs are accessed
+      # by the sampling thread; initialize them under the lock.
+      pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+      self.scoped_states_by_index.append(scoped_state)
+      scoped_state.nsecs = 0
+      pythread.PyThread_release_lock(self.lock)
+      self.scoped_states_by_name[name] = scoped_state
+    return scoped_state
+
+  def commit_counters(self):
+    """Updates output counters with latest state statistics."""
+    for state in self.scoped_states_by_name.values():
+      state_msecs = int(1e-6 * state.nsecs)
+      state.counter.update(state_msecs - state.counter.value())
+
+
+cdef class ScopedState(object):
+  """Context manager class managing transitions for a given sampler state."""
+
+  cdef readonly StateSampler sampler
+  cdef readonly int32_t state_index
+  cdef readonly object counter
+  cdef readonly object name
+  cdef readonly int64_t nsecs
+  cdef int32_t old_state_index
+
+  def __init__(self, sampler, name, state_index, counter=None):
+    self.sampler = sampler
+    self.name = name
+    self.state_index = state_index
+    self.counter = counter
+
+  cpdef __enter__(self):
+    self.old_state_index = self.sampler.current_state_index
+    pythread.PyThread_acquire_lock(self.sampler.lock, pythread.WAIT_LOCK)
+    self.sampler.current_state_index = self.state_index
+    pythread.PyThread_release_lock(self.sampler.lock)
+    self.sampler.state_transition_count += 1
+
+  cpdef __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
+    pythread.PyThread_acquire_lock(self.sampler.lock, pythread.WAIT_LOCK)
+    self.sampler.current_state_index = self.old_state_index
+    pythread.PyThread_release_lock(self.sampler.lock)
+    self.sampler.state_transition_count += 1
+
+  def __repr__(self):
+    return "ScopedState[%s, %s, %s]" % (self.name, self.state_index, self.nsecs)

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/statesampler_fake.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fake.py b/sdks/python/apache_beam/runners/worker/statesampler_fake.py
new file mode 100644
index 0000000..efd7f2d
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fake.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class StateSampler(object):
+
+  def __init__(self, *args, **kwargs):
+    pass
+
+  def scoped_state(self, name):
+    return _FakeScopedState()
+
+
+class _FakeScopedState(object):
+
+  def __enter__(self):
+    pass
+
+  def __exit__(self, *unused_args):
+    pass

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/statesampler_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py
new file mode 100644
index 0000000..663cdec
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for state sampler."""
+
+import logging
+import time
+import unittest
+
+from nose.plugins.skip import SkipTest
+
+from apache_beam.utils.counters import CounterFactory
+
+
+class StateSamplerTest(unittest.TestCase):
+
+  def setUp(self):
+    try:
+      # pylint: disable=global-variable-not-assigned
+      global statesampler
+      import statesampler
+    except ImportError:
+      raise SkipTest('State sampler not compiled.')
+    super(StateSamplerTest, self).setUp()
+
+  def test_basic_sampler(self):
+    # Set up state sampler.
+    counter_factory = CounterFactory()
+    sampler = statesampler.StateSampler('basic-', counter_factory,
+                                        sampling_period_ms=1)
+
+    # Run basic workload transitioning between 3 states.
+    sampler.start()
+    with sampler.scoped_state('statea'):
+      time.sleep(0.1)
+      with sampler.scoped_state('stateb'):
+        time.sleep(0.2 / 2)
+        with sampler.scoped_state('statec'):
+          time.sleep(0.3)
+        time.sleep(0.2 / 2)
+    sampler.stop()
+    sampler.commit_counters()
+
+    # Test that sampled state timings are close to their expected values.
+    expected_counter_values = {
+        'basic-statea-msecs': 100,
+        'basic-stateb-msecs': 200,
+        'basic-statec-msecs': 300,
+    }
+    for counter in counter_factory.get_counters():
+      self.assertIn(counter.name, expected_counter_values)
+      expected_value = expected_counter_values[counter.name]
+      actual_value = counter.value()
+      self.assertGreater(actual_value, expected_value * 0.75)
+      self.assertLess(actual_value, expected_value * 1.25)
+
+  def test_sampler_transition_overhead(self):
+    # Set up state sampler.
+    counter_factory = CounterFactory()
+    sampler = statesampler.StateSampler('overhead-', counter_factory,
+                                        sampling_period_ms=10)
+
+    # Run basic workload transitioning between 3 states.
+    state_a = sampler.scoped_state('statea')
+    state_b = sampler.scoped_state('stateb')
+    state_c = sampler.scoped_state('statec')
+    start_time = time.time()
+    sampler.start()
+    for _ in range(100000):
+      with state_a:
+        with state_b:
+          for _ in range(10):
+            with state_c:
+              pass
+    sampler.stop()
+    elapsed_time = time.time() - start_time
+    state_transition_count = sampler.get_info().transition_count
+    overhead_us = 1000000.0 * elapsed_time / state_transition_count
+    logging.info('Overhead per transition: %fus', overhead_us)
+    # Conservative upper bound on overhead in microseconds (we expect this to
+    # take 0.17us when compiled in opt mode or 0.48us when compiled in
+    # debug mode).
+    self.assertLess(overhead_us, 10.0)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/generate_pydoc.sh
----------------------------------------------------------------------
diff --git a/sdks/python/generate_pydoc.sh b/sdks/python/generate_pydoc.sh
index b04e27a..6039942 100755
--- a/sdks/python/generate_pydoc.sh
+++ b/sdks/python/generate_pydoc.sh
@@ -44,6 +44,8 @@ python $(type -p sphinx-apidoc) -f -o target/docs/source apache_beam \
 # Remove Cython modules from doc template; they won't load
 sed -i -e '/.. automodule:: apache_beam.coders.stream/d' \
     target/docs/source/apache_beam.coders.rst
+sed -i -e '/.. automodule:: apache_beam.runners.worker.statesampler/d' \
+    target/docs/source/apache_beam.runners.worker.rst
 
 # Create the configuration and index files
 cat > target/docs/source/conf.py <<'EOF'

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/setup.py
----------------------------------------------------------------------
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 615931b..f527362 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -121,12 +121,16 @@ setuptools.setup(
     author=PACKAGE_AUTHOR,
     author_email=PACKAGE_EMAIL,
     packages=setuptools.find_packages(),
-    package_data={'apache_beam': ['**/*.pyx', '**/*.pxd', 'tests/data/*']},
+    package_data={'apache_beam': [
+        '*/*.pyx', '*/*/*.pyx', '*/*.pxd', '*/*/*.pxd', 'tests/data/*']},
     ext_modules=cythonize([
         '**/*.pyx',
         'apache_beam/coders/coder_impl.py',
-        'apache_beam/runners/common.py',
         'apache_beam/metrics/execution.py',
+        'apache_beam/runners/common.py',
+        'apache_beam/runners/worker/logger.py',
+        'apache_beam/runners/worker/opcounters.py',
+        'apache_beam/runners/worker/operations.py',
         'apache_beam/transforms/cy_combiners.py',
         'apache_beam/utils/counters.py',
         'apache_beam/utils/windowed_value.py',


[2/4] beam git commit: Fn API support for Python.

Posted by ro...@apache.org.
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger_test.py b/sdks/python/apache_beam/runners/worker/logger_test.py
new file mode 100644
index 0000000..cf3f692
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger_test.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for worker logging utilities."""
+
+import json
+import logging
+import sys
+import threading
+import unittest
+
+from apache_beam.runners.worker import logger
+
+
+class PerThreadLoggingContextTest(unittest.TestCase):
+
+  def thread_check_attribute(self, name):
+    self.assertFalse(name in logger.per_thread_worker_data.get_data())
+    with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
+      self.assertEqual(
+          logger.per_thread_worker_data.get_data()[name], 'thread-value')
+    self.assertFalse(name in logger.per_thread_worker_data.get_data())
+
+  def test_per_thread_attribute(self):
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+    with logger.PerThreadLoggingContext(xyz='value'):
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+      thread = threading.Thread(
+          target=self.thread_check_attribute, args=('xyz',))
+      thread.start()
+      thread.join()
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+
+  def test_set_when_undefined(self):
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+    with logger.PerThreadLoggingContext(xyz='value'):
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+
+  def test_set_when_already_defined(self):
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+    with logger.PerThreadLoggingContext(xyz='value'):
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+      with logger.PerThreadLoggingContext(xyz='value2'):
+        self.assertEqual(
+            logger.per_thread_worker_data.get_data()['xyz'], 'value2')
+      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+
+
+class JsonLogFormatterTest(unittest.TestCase):
+
+  SAMPLE_RECORD = {
+      'created': 123456.789, 'msecs': 789.654321,
+      'msg': '%s:%d:%.2f', 'args': ('xyz', 4, 3.14),
+      'levelname': 'WARNING',
+      'process': 'pid', 'thread': 'tid',
+      'name': 'name', 'filename': 'file', 'funcName': 'func',
+      'exc_info': None}
+
+  SAMPLE_OUTPUT = {
+      'timestamp': {'seconds': 123456, 'nanos': 789654321},
+      'severity': 'WARN', 'message': 'xyz:4:3.14', 'thread': 'pid:tid',
+      'job': 'jobid', 'worker': 'workerid', 'logger': 'name:file:func'}
+
+  def create_log_record(self, **kwargs):
+
+    class Record(object):
+
+      def __init__(self, **kwargs):
+        for k, v in kwargs.iteritems():
+          setattr(self, k, v)
+
+    return Record(**kwargs)
+
+  def test_basic_record(self):
+    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+    record = self.create_log_record(**self.SAMPLE_RECORD)
+    self.assertEqual(json.loads(formatter.format(record)), self.SAMPLE_OUTPUT)
+
+  def execute_multiple_cases(self, test_cases):
+    record = self.SAMPLE_RECORD
+    output = self.SAMPLE_OUTPUT
+    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+
+    for case in test_cases:
+      record['msg'] = case['msg']
+      record['args'] = case['args']
+      output['message'] = case['expected']
+
+      self.assertEqual(
+          json.loads(formatter.format(self.create_log_record(**record))),
+          output)
+
+  def test_record_with_format_character(self):
+    test_cases = [
+        {'msg': '%A', 'args': (), 'expected': '%A'},
+        {'msg': '%s', 'args': (), 'expected': '%s'},
+        {'msg': '%A%s', 'args': ('xy'), 'expected': '%A%s with args (xy)'},
+        {'msg': '%s%s', 'args': (1), 'expected': '%s%s with args (1)'},
+    ]
+
+    self.execute_multiple_cases(test_cases)
+
+  def test_record_with_arbitrary_messages(self):
+    test_cases = [
+        {'msg': ImportError('abc'), 'args': (), 'expected': 'abc'},
+        {'msg': TypeError('abc %s'), 'args': ('def'), 'expected': 'abc def'},
+    ]
+
+    self.execute_multiple_cases(test_cases)
+
+  def test_record_with_per_thread_info(self):
+    with logger.PerThreadLoggingContext(
+        work_item_id='workitem', stage_name='stage', step_name='step'):
+      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+      record = self.create_log_record(**self.SAMPLE_RECORD)
+      log_output = json.loads(formatter.format(record))
+    expected_output = dict(self.SAMPLE_OUTPUT)
+    expected_output.update(
+        {'work': 'workitem', 'stage': 'stage', 'step': 'step'})
+    self.assertEqual(log_output, expected_output)
+
+  def test_nested_with_per_thread_info(self):
+    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+    with logger.PerThreadLoggingContext(
+        work_item_id='workitem', stage_name='stage', step_name='step1'):
+      record = self.create_log_record(**self.SAMPLE_RECORD)
+      log_output1 = json.loads(formatter.format(record))
+
+      with logger.PerThreadLoggingContext(step_name='step2'):
+        record = self.create_log_record(**self.SAMPLE_RECORD)
+        log_output2 = json.loads(formatter.format(record))
+
+      record = self.create_log_record(**self.SAMPLE_RECORD)
+      log_output3 = json.loads(formatter.format(record))
+
+    record = self.create_log_record(**self.SAMPLE_RECORD)
+    log_output4 = json.loads(formatter.format(record))
+
+    self.assertEqual(log_output1, dict(
+        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
+    self.assertEqual(log_output2, dict(
+        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
+    self.assertEqual(log_output3, dict(
+        self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
+    self.assertEqual(log_output4, self.SAMPLE_OUTPUT)
+
+  def test_exception_record(self):
+    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+    try:
+      raise ValueError('Something')
+    except ValueError:
+      attribs = dict(self.SAMPLE_RECORD)
+      attribs.update({'exc_info': sys.exc_info()})
+      record = self.create_log_record(**attribs)
+    log_output = json.loads(formatter.format(record))
+    # Check if exception type, its message, and stack trace information are in.
+    exn_output = log_output.pop('exception')
+    self.assertNotEqual(exn_output.find('ValueError: Something'), -1)
+    self.assertNotEqual(exn_output.find('logger_test.py'), -1)
+    self.assertEqual(log_output, self.SAMPLE_OUTPUT)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters.pxd b/sdks/python/apache_beam/runners/worker/opcounters.pxd
new file mode 100644
index 0000000..5c1079f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters.pxd
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+from apache_beam.utils.counters cimport Counter
+
+
+cdef class SumAccumulator(object):
+  cdef libc.stdint.int64_t _value
+  cpdef update(self, libc.stdint.int64_t value)
+  cpdef libc.stdint.int64_t value(self)
+
+
+cdef class OperationCounters(object):
+  cdef public _counter_factory
+  cdef public Counter element_counter
+  cdef public Counter mean_byte_counter
+  cdef public coder_impl
+  cdef public SumAccumulator active_accumulator
+  cdef public libc.stdint.int64_t _sample_counter
+  cdef public libc.stdint.int64_t _next_sample
+
+  cpdef update_from(self, windowed_value)
+  cdef inline do_sample(self, windowed_value)
+  cpdef update_collect(self)
+
+  cdef libc.stdint.int64_t _compute_next_sample(self, libc.stdint.int64_t i)
+  cdef inline bint _should_sample(self)
+  cpdef bint should_sample(self)

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py
new file mode 100644
index 0000000..56ce0db
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""Counters collect the progress of the Worker for reporting to the service."""
+
+from __future__ import absolute_import
+import math
+import random
+
+from apache_beam.utils.counters import Counter
+
+
+class SumAccumulator(object):
+  """Accumulator for collecting byte counts."""
+
+  def __init__(self):
+    self._value = 0
+
+  def update(self, value):
+    self._value += value
+
+  def value(self):
+    return self._value
+
+
+class OperationCounters(object):
+  """The set of basic counters to attach to an Operation."""
+
+  def __init__(self, counter_factory, step_name, coder, output_index):
+    self._counter_factory = counter_factory
+    self.element_counter = counter_factory.get_counter(
+        '%s-out%d-ElementCount' % (step_name, output_index), Counter.SUM)
+    self.mean_byte_counter = counter_factory.get_counter(
+        '%s-out%d-MeanByteCount' % (step_name, output_index), Counter.MEAN)
+    self.coder_impl = coder.get_impl()
+    self.active_accumulator = None
+    self._sample_counter = 0
+    self._next_sample = 0
+
+  def update_from(self, windowed_value):
+    """Add one value to this counter."""
+    self.element_counter.update(1)
+    if self._should_sample():
+      self.do_sample(windowed_value)
+
+  def _observable_callback(self, inner_coder_impl, accumulator):
+    def _observable_callback_inner(value, is_encoded=False):
+      # TODO(ccy): If this stream is large, sample it as well.
+      # To do this, we'll need to compute the average size of elements
+      # in this stream to add the *total* size of this stream to the accumulator.
+      # We'll also want to make sure we sample at least some of this stream
+      # (as self.should_sample() may be sampling very sparsely by now).
+      if is_encoded:
+        size = len(value)
+        accumulator.update(size)
+      else:
+        accumulator.update(inner_coder_impl.estimate_size(value))
+    return _observable_callback_inner
+
+  def do_sample(self, windowed_value):
+    size, observables = (
+        self.coder_impl.get_estimated_size_and_observables(windowed_value))
+    if not observables:
+      self.mean_byte_counter.update(size)
+    else:
+      self.active_accumulator = SumAccumulator()
+      self.active_accumulator.update(size)
+      for observable, inner_coder_impl in observables:
+        observable.register_observer(
+            self._observable_callback(
+                inner_coder_impl, self.active_accumulator))
+
+  def update_collect(self):
+    """Collects the accumulated size estimates.
+
+    Now that the element has been processed, we ask our accumulator
+    for the total and store the result in a counter.
+    """
+    if self.active_accumulator is not None:
+      self.mean_byte_counter.update(self.active_accumulator.value())
+      self.active_accumulator = None
+
+  def _compute_next_sample(self, i):
+    # https://en.wikipedia.org/wiki/Reservoir_sampling#Fast_Approximation
+    gap = math.log(1.0 - random.random()) / math.log(1.0 - 10.0/i)
+    return i + math.floor(gap)
+
+  def _should_sample(self):
+    """Determines whether to sample the next element.
+
+    Size calculation can be expensive, so we don't do it for each element.
+    Because we need only an estimate of average size, we sample.
+
+    We always sample the first 10 elements, then the sampling rate
+    is approximately 10/N.  After reading N elements, of the next N,
+    we will sample approximately 10*ln(2) (about 7) elements.
+
+    This algorithm samples at the same rate as Reservoir Sampling, but
+    it never throws away early results.  (Because we keep only a
+    running accumulation, storage is not a problem, so there is no
+    need to discard earlier calculations.)
+
+    Because we accumulate and do not replace, our statistics are
+    biased toward early data.  If the data are distributed uniformly,
+    this is not a problem.  If the data change over time (i.e., the
+    element size tends to grow or shrink over time), our estimate will
+    show the bias.  We could correct this by giving weight N to each
+    sample, since each sample is a stand-in for the N/(10*ln(2))
+    samples around it, which is proportional to N.  Since we do not
+    expect biased data, for efficiency we omit the extra multiplication.
+    We could reduce the early-data bias by putting a lower bound on
+    the sampling rate.
+
+    Computing random.randint(1, self._sample_counter) for each element
+    is too slow, so when the sample size is big enough (we estimate 30
+    is big enough), we estimate the size of the gap after each sample.
+    This estimation allows us to call random much less often.
+
+    Returns:
+      True if it is time to compute another element's size.
+    """
+
+    self._sample_counter += 1
+    if self._next_sample == 0:
+      if random.randint(1, self._sample_counter) <= 10:
+        if self._sample_counter > 30:
+          self._next_sample = self._compute_next_sample(self._sample_counter)
+        return True
+      return False
+    elif self._sample_counter >= self._next_sample:
+      self._next_sample = self._compute_next_sample(self._sample_counter)
+      return True
+    return False
+
+  def should_sample(self):
+    # We create this separate method because the above "_should_sample()" method
+    # is marked as inline in Cython and thus can't be exposed to Python code.
+    return self._should_sample()
+
+  def __str__(self):
+    return '<%s [%s]>' % (self.__class__.__name__,
+                          ', '.join([str(x) for x in self.__iter__()]))
+
+  def __repr__(self):
+    return '<%s %s at %s>' % (self.__class__.__name__,
+                              [x for x in self.__iter__()], hex(id(self)))
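
The sampling policy documented in _should_sample() above can be exercised on
its own.  The sketch below re-implements the same decision rule as a
standalone closure, purely to illustrate the ~10/N sampling rate; it is not
part of this patch.

    import math
    import random


    def make_should_sample():
      """Standalone copy of the _should_sample() rule, for illustration."""
      state = {'count': 0, 'next_sample': 0}

      def compute_next_sample(i):
        # https://en.wikipedia.org/wiki/Reservoir_sampling#Fast_Approximation
        gap = math.log(1.0 - random.random()) / math.log(1.0 - 10.0 / i)
        return i + math.floor(gap)

      def should_sample():
        state['count'] += 1
        if state['next_sample'] == 0:
          if random.randint(1, state['count']) <= 10:
            if state['count'] > 30:
              state['next_sample'] = compute_next_sample(state['count'])
            return True
          return False
        elif state['count'] >= state['next_sample']:
          state['next_sample'] = compute_next_sample(state['count'])
          return True
        return False

      return should_sample


    # With a ~10/N sampling rate, only on the order of a hundred of the
    # 100000 elements below should be sampled.
    should_sample = make_should_sample()
    print sum(1 for _ in xrange(100000) if should_sample())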

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters_test.py b/sdks/python/apache_beam/runners/worker/opcounters_test.py
new file mode 100644
index 0000000..74561b8
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters_test.py
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import math
+import random
+import unittest
+
+from apache_beam import coders
+from apache_beam.runners.worker.opcounters import OperationCounters
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.utils.counters import CounterFactory
+
+
+# Classes to test that we can handle a variety of objects.
+# These have to be at top level so the pickler can find them.
+
+
+class OldClassThatDoesNotImplementLen:  # pylint: disable=old-style-class
+
+  def __init__(self):
+    pass
+
+
+class ObjectThatDoesNotImplementLen(object):
+
+  def __init__(self):
+    pass
+
+
+class OperationCountersTest(unittest.TestCase):
+
+  def verify_counters(self, opcounts, expected_elements, expected_size=None):
+    self.assertEqual(expected_elements, opcounts.element_counter.value())
+    if expected_size is not None:
+      if math.isnan(expected_size):
+        self.assertTrue(math.isnan(opcounts.mean_byte_counter.value()))
+      else:
+        self.assertEqual(expected_size, opcounts.mean_byte_counter.value())
+
+  def test_update_int(self):
+    opcounts = OperationCounters(CounterFactory(), 'some-name',
+                                 coders.PickleCoder(), 0)
+    self.verify_counters(opcounts, 0)
+    opcounts.update_from(GlobalWindows.windowed_value(1))
+    self.verify_counters(opcounts, 1)
+
+  def test_update_str(self):
+    coder = coders.PickleCoder()
+    opcounts = OperationCounters(CounterFactory(), 'some-name',
+                                 coder, 0)
+    self.verify_counters(opcounts, 0, float('nan'))
+    value = GlobalWindows.windowed_value('abcde')
+    opcounts.update_from(value)
+    estimated_size = coder.estimate_size(value)
+    self.verify_counters(opcounts, 1, estimated_size)
+
+  def test_update_old_object(self):
+    coder = coders.PickleCoder()
+    opcounts = OperationCounters(CounterFactory(), 'some-name',
+                                 coder, 0)
+    self.verify_counters(opcounts, 0, float('nan'))
+    obj = OldClassThatDoesNotImplementLen()
+    value = GlobalWindows.windowed_value(obj)
+    opcounts.update_from(value)
+    estimated_size = coder.estimate_size(value)
+    self.verify_counters(opcounts, 1, estimated_size)
+
+  def test_update_new_object(self):
+    coder = coders.PickleCoder()
+    opcounts = OperationCounters(CounterFactory(), 'some-name',
+                                 coder, 0)
+    self.verify_counters(opcounts, 0, float('nan'))
+
+    obj = ObjectThatDoesNotImplementLen()
+    value = GlobalWindows.windowed_value(obj)
+    opcounts.update_from(value)
+    estimated_size = coder.estimate_size(value)
+    self.verify_counters(opcounts, 1, estimated_size)
+
+  def test_update_multiple(self):
+    coder = coders.PickleCoder()
+    total_size = 0
+    opcounts = OperationCounters(CounterFactory(), 'some-name',
+                                 coder, 0)
+    self.verify_counters(opcounts, 0, float('nan'))
+    value = GlobalWindows.windowed_value('abcde')
+    opcounts.update_from(value)
+    total_size += coder.estimate_size(value)
+    value = GlobalWindows.windowed_value('defghij')
+    opcounts.update_from(value)
+    total_size += coder.estimate_size(value)
+    self.verify_counters(opcounts, 2, float(total_size) / 2)
+    value = GlobalWindows.windowed_value('klmnop')
+    opcounts.update_from(value)
+    total_size += coder.estimate_size(value)
+    self.verify_counters(opcounts, 3, float(total_size) / 3)
+
+  def test_should_sample(self):
+    # Order of magnitude more buckets than highest constant in code under test.
+    buckets = [0] * 300
+    # The seed is arbitrary and exists just to ensure this test is robust.
+    # If you don't like this seed, try your own; the test should still pass.
+    random.seed(1717)
+    # Do enough runs that the expected hits even in the last buckets
+    # is big enough to expect some statistical smoothing.
+    total_runs = 10 * len(buckets)
+
+    # Fill the buckets.
+    for _ in xrange(total_runs):
+      opcounts = OperationCounters(CounterFactory(), 'some-name',
+                                   coders.PickleCoder(), 0)
+      for i in xrange(len(buckets)):
+        if opcounts.should_sample():
+          buckets[i] += 1
+
+    # Look at the buckets to see if they are likely.
+    for i in xrange(10):
+      self.assertEqual(total_runs, buckets[i])
+    for i in xrange(10, len(buckets)):
+      self.assertTrue(buckets[i] > 7 * total_runs / i,
+                      'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % (
+                          i, buckets[i],
+                          10 * total_runs / i,
+                          buckets[i] / (10.0 * total_runs / i)))
+      self.assertTrue(buckets[i] < 14 * total_runs / i,
+                      'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % (
+                          i, buckets[i],
+                          10 * total_runs / i,
+                          buckets[i] / (10.0 * total_runs / i)))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operation_specs.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py
new file mode 100644
index 0000000..977e165
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operation_specs.py
@@ -0,0 +1,368 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Worker utilities for representing MapTasks.
+
+Each MapTask represents a sequence of ParallelInstruction(s): read from a
+source, write to a sink, parallel do, etc.
+"""
+
+import collections
+
+from apache_beam import coders
+
+
+def build_worker_instruction(*args):
+  """Create an object representing a ParallelInstruction protobuf.
+
+  This will be a collections.namedtuple with custom __str__ and __repr__
+  methods.
+
+  Alas, this wrapper is not known to pylint, which thinks it creates
+  constants.  You may have to put a disable=invalid-name pylint
+  annotation on any use of this, depending on your names.
+
+  Args:
+    *args: The first argument is the name of the type to create.  It should
+      start with "Worker".  The second argument is a list of the
+      attributes of this object.
+  Returns:
+    A new class, a subclass of tuple, that represents the protobuf.
+  """
+  tuple_class = collections.namedtuple(*args)
+  tuple_class.__str__ = worker_object_to_string
+  tuple_class.__repr__ = worker_object_to_string
+  return tuple_class
+
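+# Illustrative usage (the names below are hypothetical, not instructions
+# defined in this module):
+#
+#   WorkerExample = build_worker_instruction(
+#       'WorkerExample', ['source', 'output_coders'])
+#   print WorkerExample(source=some_source, output_coders=(some_coder,))
+#   # Prints something like "WorkerExample(source=<SomeSource ...>)";
+#   # output_coders is filtered out by worker_printable_fields below.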
+
+def worker_printable_fields(workerproto):
+  """Returns the interesting fields of a Worker* object."""
+  return ['%s=%s' % (name, value)
+          # Using _asdict is the only option, as we cannot subclass this
+          # generated class.
+          # pylint: disable=protected-access
+          for name, value in workerproto._asdict().iteritems()
+          # want to output value 0 but not None nor []
+          if (value or value == 0)
+          and name not in
+          ('coder', 'coders', 'output_coders',
+           'elements',
+           'combine_fn', 'serialized_fn', 'window_fn',
+           'append_trailing_newlines', 'strip_trailing_newlines',
+           'compression_type', 'context',
+           'start_shuffle_position', 'end_shuffle_position',
+           'shuffle_reader_config', 'shuffle_writer_config')]
+
+
+def worker_object_to_string(worker_object):
+  """Returns a string compactly representing a Worker* object."""
+  return '%s(%s)' % (worker_object.__class__.__name__,
+                     ', '.join(worker_printable_fields(worker_object)))
+
+
+# All the following Worker* definitions will have these lint problems:
+# pylint: disable=invalid-name
+# pylint: disable=pointless-string-statement
+
+
+WorkerRead = build_worker_instruction(
+    'WorkerRead', ['source', 'output_coders'])
+"""Worker details needed to read from a source.
+
+Attributes:
+  source: a source object.
+  output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerSideInputSource = build_worker_instruction(
+    'WorkerSideInputSource', ['source', 'tag'])
+"""Worker details needed to read from a side input source.
+
+Attributes:
+  source: a source object.
+  tag: string tag for this side input.
+"""
+
+
+WorkerGroupingShuffleRead = build_worker_instruction(
+    'WorkerGroupingShuffleRead',
+    ['start_shuffle_position', 'end_shuffle_position',
+     'shuffle_reader_config', 'coder', 'output_coders'])
+"""Worker details needed to read from a grouping shuffle source.
+
+Attributes:
+  start_shuffle_position: An opaque string to be passed to the shuffle
+    source to indicate where to start reading.
+  end_shuffle_position: An opaque string to be passed to the shuffle
+    source to indicate where to stop reading.
+  shuffle_reader_config: An opaque string used to initialize the shuffle
+    reader. Contains things like connection endpoints for the shuffle
+    server appliance and various options.
+  coder: The KV coder used to decode shuffle entries.
+  output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerUngroupedShuffleRead = build_worker_instruction(
+    'WorkerUngroupedShuffleRead',
+    ['start_shuffle_position', 'end_shuffle_position',
+     'shuffle_reader_config', 'coder', 'output_coders'])
+"""Worker details needed to read from an ungrouped shuffle source.
+
+Attributes:
+  start_shuffle_position: An opaque string to be passed to the shuffle
+    source to indicate where to start reading.
+  end_shuffle_position: An opaque string to be passed to the shuffle
+    source to indicate where to stop reading.
+  shuffle_reader_config: An opaque string used to initialize the shuffle
+    reader. Contains things like connection endpoints for the shuffle
+    server appliance and various options.
+  coder: The value coder used to decode shuffle entries.
+  output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerWrite = build_worker_instruction(
+    'WorkerWrite', ['sink', 'input', 'output_coders'])
+"""Worker details needed to write to a sink.
+
+Attributes:
+  sink: a sink object.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  output_coders: 1-tuple, coder to use to estimate bytes written.
+"""
+
+
+WorkerInMemoryWrite = build_worker_instruction(
+    'WorkerInMemoryWrite',
+    ['output_buffer', 'write_windowed_values', 'input', 'output_coders'])
+"""Worker details needed to write to a in-memory sink.
+
+Used only for unit testing. It makes worker tests less cluttered with code like
+"write to a file and then check file contents".
+
+Attributes:
+  output_buffer: list to which output elements will be appended
+  write_windowed_values: whether to record the entire WindowedValue outputs,
+    or just the raw (unwindowed) value
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  output_coders: 1-tuple, coder to use to estimate bytes written.
+"""
+
+
+WorkerShuffleWrite = build_worker_instruction(
+    'WorkerShuffleWrite',
+    ['shuffle_kind', 'shuffle_writer_config', 'input', 'output_coders'])
+"""Worker details needed to write to a shuffle sink.
+
+Attributes:
+  shuffle_kind: A string describing the shuffle kind. This can control the
+    way the worker interacts with the shuffle sink. The possible values are:
+    'ungrouped', 'group_keys', and 'group_keys_and_sort_values'.
+  shuffle_writer_config: An opaque string used to initialize the shuffle
+    write. Contains things like connection endpoints for the shuffle
+    server appliance and various options.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  output_coders: 1-tuple of the coder for input elements. If the
+    shuffle_kind is grouping, this is expected to be a KV coder.
+"""
+
+
+WorkerDoFn = build_worker_instruction(
+    'WorkerDoFn',
+    ['serialized_fn', 'output_tags', 'input', 'side_inputs', 'output_coders'])
+"""Worker details needed to run a DoFn.
+Attributes:
+  serialized_fn: A serialized DoFn object to be run for each input element.
+  output_tags: The string tags used to identify the outputs of a ParDo
+    operation. The tag is present even if the ParDo has just one output
+    (e.g., ['out']).
+  output_coders: array of coders, one for each output.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  side_inputs: A list of Worker...Read instances describing sources to be
+    used for getting values. The types supported right now are
+    WorkerInMemoryRead and WorkerTextRead.
+"""
+
+
+WorkerReifyTimestampAndWindows = build_worker_instruction(
+    'WorkerReifyTimestampAndWindows',
+    ['output_tags', 'input', 'output_coders'])
+"""Worker details needed to run a WindowInto.
+Attributes:
+  output_tags: The string tags used to identify the outputs of a ParDo
+    operation. The tag is present even if the ParDo has just one output
+    (e.g., ['out']).
+  output_coders: array of coders, one for each output.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+"""
+
+
+WorkerMergeWindows = build_worker_instruction(
+    'WorkerMergeWindows',
+    ['window_fn', 'combine_fn', 'phase', 'output_tags', 'input', 'coders',
+     'context', 'output_coders'])
+"""Worker details needed to run a MergeWindows (aka. GroupAlsoByWindows).
+Attributes:
+  window_fn: A serialized Windowing object representing the windowing strategy.
+  combine_fn: A serialized CombineFn object to be used after executing the
+    GroupAlsoByWindows operation. May be None if not a combining operation.
+  phase: Possible values are 'all', 'add', 'merge', and 'extract'.
+    A runner optimizer may split the user combiner into 3 separate
+    phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees
+    fit. The phase attribute dictates which DoFn is actually running in
+    the worker. May be None if not a combining operation.
+  output_tags: The string tags used to identify the outputs of a ParDo
+    operation. The tag is present even if the ParDo has just one output
+    (e.g., ['out']).
+  output_coders: array of coders, one for each output.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  coders: A 2-tuple of coders (key, value) to encode shuffle entries.
+  context: The ExecutionContext object for the current work item.
+"""
+
+
+WorkerCombineFn = build_worker_instruction(
+    'WorkerCombineFn',
+    ['serialized_fn', 'phase', 'input', 'output_coders'])
+"""Worker details needed to run a CombineFn.
+Attributes:
+  serialized_fn: A serialized CombineFn object to be used.
+  phase: Possible values are 'all', 'add', 'merge', and 'extract'.
+    A runner optimizer may split the user combiner into 3 separate
+    phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees
+    fit. The phase attribute dictates which DoFn is actually running in
+    the worker.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerPartialGroupByKey = build_worker_instruction(
+    'WorkerPartialGroupByKey',
+    ['combine_fn', 'input', 'output_coders'])
+"""Worker details needed to run a partial group-by-key.
+Attributes:
+  combine_fn: A serialized CombineFn object to be used.
+  input: A (producer index, output index) tuple representing the
+    ParallelInstruction operation whose output feeds into this operation.
+    The output index is 0 except for multi-output operations (like ParDo).
+  output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerFlatten = build_worker_instruction(
+    'WorkerFlatten',
+    ['inputs', 'output_coders'])
+"""Worker details needed to run a Flatten.
+Attributes:
+  inputs: A list of tuples, each (producer index, output index), representing
+    the ParallelInstruction operations whose output feeds into this operation.
+    The output index is 0 unless the input is from a multi-output
+    operation (such as ParDo).
+  output_coders: 1-tuple of the coder for the output.
+"""
+
+
+def get_coder_from_spec(coder_spec):
+  """Return a coder instance from a coder spec.
+
+  Args:
+    coder_spec: A dict whose '@type' value is either a well-known 'kind:*'
+      string or a serialized Coder instance.
+
+  Returns:
+    A coder instance (has encode/decode methods).
+  """
+  assert coder_spec is not None
+
+  # Ignore the wrappers in these encodings.
+  # TODO(silviuc): Make sure with all the renamings that names below are ok.
+  if coder_spec['@type'] in ignored_wrappers:
+    assert len(coder_spec['component_encodings']) == 1
+    coder_spec = coder_spec['component_encodings'][0]
+    return get_coder_from_spec(coder_spec)
+
+  # Handle a few well known types of coders.
+  if coder_spec['@type'] == 'kind:pair':
+    assert len(coder_spec['component_encodings']) == 2
+    component_coders = [
+        get_coder_from_spec(c) for c in coder_spec['component_encodings']]
+    return coders.TupleCoder(component_coders)
+  elif coder_spec['@type'] == 'kind:stream':
+    assert len(coder_spec['component_encodings']) == 1
+    return coders.IterableCoder(
+        get_coder_from_spec(coder_spec['component_encodings'][0]))
+  elif coder_spec['@type'] == 'kind:windowed_value':
+    assert len(coder_spec['component_encodings']) == 2
+    value_coder, window_coder = [
+        get_coder_from_spec(c) for c in coder_spec['component_encodings']]
+    return coders.WindowedValueCoder(value_coder, window_coder=window_coder)
+  elif coder_spec['@type'] == 'kind:interval_window':
+    assert ('component_encodings' not in coder_spec
+            or len(coder_spec['component_encodings']) == 0)
+    return coders.IntervalWindowCoder()
+  elif coder_spec['@type'] == 'kind:global_window':
+    assert ('component_encodings' not in coder_spec
+            or not coder_spec['component_encodings'])
+    return coders.GlobalWindowCoder()
+  elif coder_spec['@type'] == 'kind:length_prefix':
+    assert len(coder_spec['component_encodings']) == 1
+    return coders.LengthPrefixCoder(
+        get_coder_from_spec(coder_spec['component_encodings'][0]))
+
+  # We pass coders in the form "<coder_name>$<pickled_data>" to make the job
+  # description JSON more readable.
+  return coders.deserialize_coder(coder_spec['@type'])
+
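+# For illustration, a windowed-value spec of the shape handled above (the
+# nested value_spec and window_spec here are placeholders) decodes
+# recursively:
+#
+#   spec = {'@type': 'kind:windowed_value',
+#           'component_encodings': [value_spec, window_spec]}
+#   get_coder_from_spec(spec)
+#   # -> WindowedValueCoder(get_coder_from_spec(value_spec),
+#   #                       window_coder=get_coder_from_spec(window_spec))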
+
+class MapTask(object):
+  """A map task decoded into operations and ready to be executed.
+
+  Attributes:
+    operations: A list of Worker* objects created by parsing the instructions
+      within the map task.
+    stage_name: The name of this map task execution stage.
+    system_names: The system names of the steps corresponding to each map task
+      operation in the execution graph.
+    step_names: The names of the steps corresponding to each map task
+      operation.
+    original_names: The internal names of the steps in the original workflow
+      graph.
+  """
+
+  def __init__(
+      self, operations, stage_name, system_names, step_names, original_names):
+    self.operations = operations
+    self.stage_name = stage_name
+    self.system_names = system_names
+    self.step_names = step_names
+    self.original_names = original_names
+
+  def __str__(self):
+    return '<%s %s steps=%s>' % (self.__class__.__name__, self.stage_name,
+                                 '+'.join(self.step_names))

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
new file mode 100644
index 0000000..2b4e526
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport Receiver
+from apache_beam.runners.worker cimport opcounters
+from apache_beam.utils.windowed_value cimport WindowedValue
+from apache_beam.metrics.execution cimport ScopedMetricsContainer
+
+
+cdef WindowedValue _globally_windowed_value
+cdef type _global_window_type
+
+cdef class ConsumerSet(Receiver):
+  cdef list consumers
+  cdef opcounters.OperationCounters opcounter
+  cdef public step_name
+  cdef public output_index
+  cdef public coder
+
+  cpdef receive(self, WindowedValue windowed_value)
+  cpdef update_counters_start(self, WindowedValue windowed_value)
+  cpdef update_counters_finish(self)
+
+
+cdef class Operation(object):
+  cdef readonly operation_name
+  cdef readonly spec
+  cdef object consumers
+  cdef readonly counter_factory
+  cdef public metrics_container
+  cdef public ScopedMetricsContainer scoped_metrics_container
+  # Public for access by Fn harness operations.
+  # TODO(robertwb): Cythonize FnHarness.
+  cdef public list receivers
+  cdef readonly bint debug_logging_enabled
+
+  cdef public step_name  # initialized lazily
+
+  cdef readonly object state_sampler
+
+  cdef readonly object scoped_start_state
+  cdef readonly object scoped_process_state
+  cdef readonly object scoped_finish_state
+
+  cpdef start(self)
+  cpdef process(self, WindowedValue windowed_value)
+  cpdef finish(self)
+
+  cpdef output(self, WindowedValue windowed_value, int output_index=*)
+
+cdef class ReadOperation(Operation):
+  @cython.locals(windowed_value=WindowedValue)
+  cpdef start(self)
+
+cdef class DoOperation(Operation):
+  cdef object dofn_runner
+  cdef Receiver dofn_receiver
+
+cdef class CombineOperation(Operation):
+  cdef object phased_combine_fn
+
+cdef class FlattenOperation(Operation):
+  pass
+
+cdef class PGBKCVOperation(Operation):
+  cdef public object combine_fn
+  cdef public object combine_fn_add_input
+  cdef dict table
+  cdef long max_keys
+  cdef long key_count
+
+  cpdef output_key(self, tuple wkey, value)
+

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
new file mode 100644
index 0000000..5dbe57e
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -0,0 +1,651 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""Worker operations executor."""
+
+import collections
+import itertools
+import logging
+
+from apache_beam import pvalue
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsContainer
+from apache_beam.metrics.execution import ScopedMetricsContainer
+from apache_beam.runners import common
+from apache_beam.runners.common import Receiver
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.worker import logger
+from apache_beam.runners.worker import opcounters
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sideinputs
+from apache_beam.transforms import combiners
+from apache_beam.transforms import core
+from apache_beam.transforms import sideinputs as apache_sideinputs
+from apache_beam.transforms.combiners import curry_combine_fn
+from apache_beam.transforms.combiners import PhasedCombineFnExecutor
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.utils.windowed_value import WindowedValue
+
+# Allow some "pure mode" declarations.
+try:
+  import cython
+except ImportError:
+  class FakeCython(object):
+    @staticmethod
+    def cast(type, value):
+      return value
+  globals()['cython'] = FakeCython()
+
+
+_globally_windowed_value = GlobalWindows.windowed_value(None)
+_global_window_type = type(_globally_windowed_value.windows[0])
+
+
+class ConsumerSet(Receiver):
+  """A ConsumerSet represents a graph edge between two Operation nodes.
+
+  The ConsumerSet object collects information from the output of the
+  Operation at one end of its edge and the input of the Operation at
+  the other end.
+  ConsumerSets are attached to the outputting Operation.
+  """
+
+  def __init__(
+      self, counter_factory, step_name, output_index, consumers, coder):
+    self.consumers = consumers
+    self.opcounter = opcounters.OperationCounters(
+        counter_factory, step_name, coder, output_index)
+    # Used in repr.
+    self.step_name = step_name
+    self.output_index = output_index
+    self.coder = coder
+
+  def output(self, windowed_value):  # For old SDKs.
+    self.receive(windowed_value)
+
+  def receive(self, windowed_value):
+    self.update_counters_start(windowed_value)
+    for consumer in self.consumers:
+      cython.cast(Operation, consumer).process(windowed_value)
+    self.update_counters_finish()
+
+  def update_counters_start(self, windowed_value):
+    self.opcounter.update_from(windowed_value)
+
+  def update_counters_finish(self):
+    self.opcounter.update_collect()
+
+  def __repr__(self):
+    return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % (
+        self.__class__.__name__, self.step_name, self.output_index, self.coder,
+        len(self.consumers))
+
+
+class Operation(object):
+  """An operation representing the live version of a work item specification.
+
+  An operation can have one or more outputs, and for each output it can have
+  one or more receiver operations that will take that output as input.
+  """
+
+  def __init__(self, operation_name, spec, counter_factory, state_sampler):
+    """Initializes a worker operation instance.
+
+    Args:
+      operation_name: The system name assigned by the runner for this
+        operation.
+      spec: An operation_specs.Worker* instance.
+      counter_factory: The CounterFactory to use for our counters.
+      state_sampler: The StateSampler for the current operation.
+    """
+    self.operation_name = operation_name
+    self.spec = spec
+    self.counter_factory = counter_factory
+    self.consumers = collections.defaultdict(list)
+
+    self.state_sampler = state_sampler
+    self.scoped_start_state = self.state_sampler.scoped_state(
+        self.operation_name + '-start')
+    self.scoped_process_state = self.state_sampler.scoped_state(
+        self.operation_name + '-process')
+    self.scoped_finish_state = self.state_sampler.scoped_state(
+        self.operation_name + '-finish')
+    # TODO(ccy): the '-abort' state can be added when the abort is supported in
+    # Operations.
+
+  def start(self):
+    """Start operation."""
+    self.debug_logging_enabled = logging.getLogger().isEnabledFor(
+        logging.DEBUG)
+    # Everything except WorkerSideInputSource, which is not a
+    # top-level operation, should have output_coders
+    if getattr(self.spec, 'output_coders', None):
+      self.receivers = [ConsumerSet(self.counter_factory, self.step_name,
+                                    i, self.consumers[i], coder)
+                        for i, coder in enumerate(self.spec.output_coders)]
+
+  def finish(self):
+    """Finish operation."""
+    pass
+
+  def process(self, o):
+    """Process element in operation."""
+    pass
+
+  def output(self, windowed_value, output_index=0):
+    cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value)
+
+  def add_receiver(self, operation, output_index=0):
+    """Adds a receiver operation for the specified output."""
+    self.consumers[output_index].append(operation)
+
+  def __str__(self):
+    """Generates a useful string for this object.
+
+    Compactly displays interesting fields.  In particular, pickled
+    fields are not displayed.  Note that we collapse the fields of the
+    contained Worker* object into this object, since there is a 1-1
+    mapping between Operation and operation_specs.Worker*.
+
+    Returns:
+      Compact string representing this object.
+    """
+    return self.str_internal()
+
+  def str_internal(self, is_recursive=False):
+    """Internal helper for __str__ that supports recursion.
+
+    When recursing on receivers, keep the output short.
+
+    Args:
+      is_recursive: whether to omit some details, particularly receivers.
+    Returns:
+      Compact string representing this object.
+    """
+    printable_name = self.__class__.__name__
+    if hasattr(self, 'step_name'):
+      printable_name += ' %s' % self.step_name
+      if is_recursive:
+        # If we have a step name, stop here, no more detail needed.
+        return '<%s>' % printable_name
+
+    if self.spec is None:
+      printable_fields = []
+    else:
+      printable_fields = operation_specs.worker_printable_fields(self.spec)
+
+    if not is_recursive and getattr(self, 'receivers', []):
+      printable_fields.append('receivers=[%s]' % ', '.join([
+          str(receiver) for receiver in self.receivers]))
+
+    return '<%s %s>' % (printable_name, ', '.join(printable_fields))
+
+
+class ReadOperation(Operation):
+
+  def start(self):
+    with self.scoped_start_state:
+      super(ReadOperation, self).start()
+      range_tracker = self.spec.source.source.get_range_tracker(
+          self.spec.source.start_position, self.spec.source.stop_position)
+      for value in self.spec.source.source.read(range_tracker):
+        if isinstance(value, WindowedValue):
+          windowed_value = value
+        else:
+          windowed_value = _globally_windowed_value.with_value(value)
+        self.output(windowed_value)
+
+
+class InMemoryWriteOperation(Operation):
+  """A write operation that will write to an in-memory sink."""
+
+  def process(self, o):
+    with self.scoped_process_state:
+      if self.debug_logging_enabled:
+        logging.debug('Processing [%s] in %s', o, self)
+      self.spec.output_buffer.append(
+          o if self.spec.write_windowed_values else o.value)
+
+
+class _TaggedReceivers(dict):
+
+  class NullReceiver(Receiver):
+
+    def receive(self, element):
+      pass
+
+    # For old SDKs.
+    def output(self, element):
+      pass
+
+  def __missing__(self, unused_key):
+    if not getattr(self, '_null_receiver', None):
+      self._null_receiver = _TaggedReceivers.NullReceiver()
+    return self._null_receiver
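+  # Note that, e.g., tagged_receivers['some_undeclared_tag'] returns the
+  # cached NullReceiver, so emitting to an output tag with no consumers is a
+  # silent no-op.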
+
+
+class DoOperation(Operation):
+  """A Do operation that will execute a custom DoFn for each input element."""
+
+  def _read_side_inputs(self, tags_and_types):
+    """Generator reading side inputs in the order prescribed by tags_and_types.
+
+    Args:
+      tags_and_types: List of (tag, view_class, view_options) tuples, as
+        unpacked in the loop below. Each side input has a string tag that is
+        specified in the worker instruction; view_class and view_options
+        determine whether the side input is read in singleton mode (just the
+        first value) or in collection mode (all values).
+
+    Yields:
+      With each iteration it yields the result of reading an entire side source
+      either in singleton or collection mode according to the tags_and_types
+      argument.
+    """
+    # We will read the side inputs in the order prescribed by the
+    # tags_and_types argument because this is exactly the order needed to
+    # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
+    # getting the side inputs.
+    #
+    # Note that for each tag there could be several read operations in the
+    # specification. This can happen for instance if the source has been
+    # sharded into several files.
+    for side_tag, view_class, view_options in tags_and_types:
+      sources = []
+      # Using the side_tag in the lambda below will trigger a pylint warning.
+      # However in this case it is fine because the lambda is used right away
+      # while the variable has the value assigned by the current iteration of
+      # the for loop.
+      # pylint: disable=cell-var-from-loop
+      for si in itertools.ifilter(
+          lambda o: o.tag == side_tag, self.spec.side_inputs):
+        if not isinstance(si, operation_specs.WorkerSideInputSource):
+          raise NotImplementedError('Unknown side input type: %r' % si)
+        sources.append(si.source)
+      iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
+
+      # Backwards compatibility for pre BEAM-733 SDKs.
+      if isinstance(view_options, tuple):
+        if view_class == pvalue.SingletonPCollectionView:
+          has_default, default = view_options
+          view_options = {'default': default} if has_default else {}
+        else:
+          view_options = {}
+
+      yield apache_sideinputs.SideInputMap(
+          view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
+
+  def start(self):
+    with self.scoped_start_state:
+      super(DoOperation, self).start()
+
+      # See fn_data in dataflow_runner.py
+      fn, args, kwargs, tags_and_types, window_fn = (
+          pickler.loads(self.spec.serialized_fn))
+
+      state = common.DoFnState(self.counter_factory)
+      state.step_name = self.step_name
+
+      # TODO(silviuc): What is the proper label here? PCollection being
+      # processed?
+      context = common.DoFnContext('label', state=state)
+      # Tag to output index map used to dispatch the side output values emitted
+      # by the DoFn function to the appropriate receivers. The main output is
+      # tagged with None and is associated with its corresponding index.
+      tagged_receivers = _TaggedReceivers()
+
+      output_tag_prefix = PropertyNames.OUT + '_'
+      for index, tag in enumerate(self.spec.output_tags):
+        if tag == PropertyNames.OUT:
+          original_tag = None
+        elif tag.startswith(output_tag_prefix):
+          original_tag = tag[len(output_tag_prefix):]
+        else:
+          raise ValueError('Unexpected output name for operation: %s' % tag)
+        tagged_receivers[original_tag] = self.receivers[index]
+
+      self.dofn_runner = common.DoFnRunner(
+          fn, args, kwargs, self._read_side_inputs(tags_and_types),
+          window_fn, context, tagged_receivers,
+          logger, self.step_name,
+          scoped_metrics_container=self.scoped_metrics_container)
+      self.dofn_receiver = (self.dofn_runner
+                            if isinstance(self.dofn_runner, Receiver)
+                            else DoFnRunnerReceiver(self.dofn_runner))
+
+      self.dofn_runner.start()
+
+  def finish(self):
+    with self.scoped_finish_state:
+      self.dofn_runner.finish()
+
+  def process(self, o):
+    with self.scoped_process_state:
+      self.dofn_receiver.receive(o)
+
+
+class DoFnRunnerReceiver(Receiver):
+
+  def __init__(self, dofn_runner):
+    self.dofn_runner = dofn_runner
+
+  def receive(self, windowed_value):
+    self.dofn_runner.process(windowed_value)
+
+
+class CombineOperation(Operation):
+  """A Combine operation executing a CombineFn for each input element."""
+
+  def __init__(self, operation_name, spec, counter_factory, state_sampler):
+    super(CombineOperation, self).__init__(
+        operation_name, spec, counter_factory, state_sampler)
+    # Combiners do not accept deferred side-inputs (the ignored fourth argument)
+    # and therefore the code to handle the extra args/kwargs is simpler than for
+    # the DoFns of ParDo.
+    fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
+    self.phased_combine_fn = (
+        PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
+
+  def finish(self):
+    logging.debug('Finishing %s', self)
+
+  def process(self, o):
+    if self.debug_logging_enabled:
+      logging.debug('Processing [%s] in %s', o, self)
+    key, values = o.value
+    with self.scoped_metrics_container:
+      self.output(
+          o.with_value((key, self.phased_combine_fn.apply(values))))
+
+
+def create_pgbk_op(step_name, spec, counter_factory, state_sampler):
+  if spec.combine_fn:
+    return PGBKCVOperation(step_name, spec, counter_factory, state_sampler)
+  else:
+    return PGBKOperation(step_name, spec, counter_factory, state_sampler)
+
+
+class PGBKOperation(Operation):
+  """Partial group-by-key operation.
+
+  This takes (windowed) input (key, value) tuples and outputs
+  (key, [value]) tuples, performing a best effort group-by-key for
+  values in this bundle, memory permitting.
+  """
+
+  def __init__(self, operation_name, spec, counter_factory, state_sampler):
+    super(PGBKOperation, self).__init__(
+        operation_name, spec, counter_factory, state_sampler)
+    assert not self.spec.combine_fn
+    self.table = collections.defaultdict(list)
+    self.size = 0
+    # TODO(robertwb) Make this configurable.
+    self.max_size = 10 * 1000
+
+  def process(self, o):
+    # TODO(robertwb): Structural (hashable) values.
+    key = o.value[0], tuple(o.windows)
+    self.table[key].append(o)
+    self.size += 1
+    if self.size > self.max_size:
+      self.flush(9 * self.max_size // 10)
+
+  def finish(self):
+    self.flush(0)
+
+  def flush(self, target):
+    limit = self.size - target
+    for ix, (kw, vs) in enumerate(self.table.items()):
+      if ix >= limit:
+        break
+      del self.table[kw]
+      key, windows = kw
+      output_value = [v.value[1] for v in vs]
+      windowed_value = WindowedValue(
+          (key, output_value),
+          vs[0].timestamp, windows)
+      self.output(windowed_value)
+
+
+class PGBKCVOperation(Operation):
+
+  def __init__(self, operation_name, spec, counter_factory, state_sampler):
+    super(PGBKCVOperation, self).__init__(
+        operation_name, spec, counter_factory, state_sampler)
+    # Combiners do not accept deferred side-inputs (the ignored fourth
+    # argument) and therefore the code to handle the extra args/kwargs is
+    # simpler than for the DoFns of ParDo.
+    fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
+    self.combine_fn = curry_combine_fn(fn, args, kwargs)
+    if (getattr(fn.add_input, 'im_func', None)
+        is core.CombineFn.add_input.im_func):
+      # Old versions of the SDK have CombineFns that don't implement add_input.
+      self.combine_fn_add_input = (
+          lambda a, e: self.combine_fn.add_inputs(a, [e]))
+    else:
+      self.combine_fn_add_input = self.combine_fn.add_input
+    # Optimization for combine functions with known-tiny accumulators and
+    # often-wide keyspaces.
+    # TODO(b/36567833): Bound by in-memory size rather than key count.
+    self.max_keys = (
+        1000 * 1000 if
+        isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
+        # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
+        # combiners to the short list above.
+        (isinstance(fn, core.CallableWrapperCombineFn) and
+         fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
+    self.key_count = 0
+    self.table = {}
+
+  def process(self, wkv):
+    key, value = wkv.value
+    # pylint: disable=unidiomatic-typecheck
+    # Optimization for the global window case.
+    if len(wkv.windows) == 1 and type(wkv.windows[0]) is _global_window_type:
+      wkey = 0, key
+    else:
+      wkey = tuple(wkv.windows), key
+    entry = self.table.get(wkey, None)
+    if entry is None:
+      if self.key_count >= self.max_keys:
+        target = self.key_count * 9 // 10
+        old_wkeys = []
+        # TODO(robertwb): Use an LRU cache?
+        for old_wkey, old_wvalue in self.table.iteritems():
+          old_wkeys.append(old_wkey)  # Can't mutate while iterating.
+          self.output_key(old_wkey, old_wvalue[0])
+          self.key_count -= 1
+          if self.key_count <= target:
+            break
+        for old_wkey in reversed(old_wkeys):
+          del self.table[old_wkey]
+      self.key_count += 1
+      # We save the accumulator as a one element list so we can efficiently
+      # mutate when new values are added without searching the cache again.
+      entry = self.table[wkey] = [self.combine_fn.create_accumulator()]
+    entry[0] = self.combine_fn_add_input(entry[0], value)
+
+  def finish(self):
+    for wkey, value in self.table.iteritems():
+      self.output_key(wkey, value[0])
+    self.table = {}
+    self.key_count = 0
+
+  def output_key(self, wkey, value):
+    windows, key = wkey
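+    # In process(), wkey is (0, key) on the global-window fast path and
+    # (tuple_of_windows, key) otherwise.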
+    if windows is 0:
+      self.output(_globally_windowed_value.with_value((key, value)))
+    else:
+      self.output(WindowedValue((key, value), windows[0].end, windows))
+
+
+class FlattenOperation(Operation):
+  """Flatten operation.
+
+  Receives input from one or more producer operations and forwards every
+  element to a single output.
+  """
+
+  def process(self, o):
+    if self.debug_logging_enabled:
+      logging.debug('Processing [%s] in %s', o, self)
+    self.output(o)
+
+
+def create_operation(operation_name, spec, counter_factory, step_name,
+                     state_sampler, test_shuffle_source=None,
+                     test_shuffle_sink=None, is_streaming=False):
+  """Create Operation object for given operation specification."""
+  if isinstance(spec, operation_specs.WorkerRead):
+    if isinstance(spec.source, iobase.SourceBundle):
+      op = ReadOperation(
+          operation_name, spec, counter_factory, state_sampler)
+    else:
+      from dataflow_worker.native_operations import NativeReadOperation
+      op = NativeReadOperation(
+          operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerWrite):
+    from dataflow_worker.native_operations import NativeWriteOperation
+    op = NativeWriteOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerCombineFn):
+    op = CombineOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerPartialGroupByKey):
+    op = create_pgbk_op(operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerDoFn):
+    op = DoOperation(operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerGroupingShuffleRead):
+    from dataflow_worker.shuffle_operations import GroupedShuffleReadOperation
+    op = GroupedShuffleReadOperation(
+        operation_name, spec, counter_factory, state_sampler,
+        shuffle_source=test_shuffle_source)
+  elif isinstance(spec, operation_specs.WorkerUngroupedShuffleRead):
+    from dataflow_worker.shuffle_operations import UngroupedShuffleReadOperation
+    op = UngroupedShuffleReadOperation(
+        operation_name, spec, counter_factory, state_sampler,
+        shuffle_source=test_shuffle_source)
+  elif isinstance(spec, operation_specs.WorkerInMemoryWrite):
+    op = InMemoryWriteOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerShuffleWrite):
+    from dataflow_worker.shuffle_operations import ShuffleWriteOperation
+    op = ShuffleWriteOperation(
+        operation_name, spec, counter_factory, state_sampler,
+        shuffle_sink=test_shuffle_sink)
+  elif isinstance(spec, operation_specs.WorkerFlatten):
+    op = FlattenOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerMergeWindows):
+    from dataflow_worker.shuffle_operations import BatchGroupAlsoByWindowsOperation
+    from dataflow_worker.shuffle_operations import StreamingGroupAlsoByWindowsOperation
+    if is_streaming:
+      op = StreamingGroupAlsoByWindowsOperation(
+          operation_name, spec, counter_factory, state_sampler)
+    else:
+      op = BatchGroupAlsoByWindowsOperation(
+          operation_name, spec, counter_factory, state_sampler)
+  elif isinstance(spec, operation_specs.WorkerReifyTimestampAndWindows):
+    from dataflow_worker.shuffle_operations import ReifyTimestampAndWindowsOperation
+    op = ReifyTimestampAndWindowsOperation(
+        operation_name, spec, counter_factory, state_sampler)
+  else:
+    raise TypeError('Expected an instance of operation_specs.Worker* class '
+                    'instead of %s' % (spec,))
+  op.step_name = step_name
+  op.metrics_container = MetricsContainer(step_name)
+  op.scoped_metrics_container = ScopedMetricsContainer(op.metrics_container)
+  return op
+
+
+class SimpleMapTaskExecutor(object):
+  """An executor for map tasks.
+
+  Stores progress of the read operation that is the first operation of a map
+  task.
+  """
+
+  def __init__(
+      self, map_task, counter_factory, state_sampler,
+      test_shuffle_source=None, test_shuffle_sink=None):
+    """Initializes SimpleMapTaskExecutor.
+
+    Args:
+      map_task: The map task we are to run.
+      counter_factory: The CounterFactory instance for the work item.
+      state_sampler: The StateSampler tracking the execution step.
+      test_shuffle_source: Used during tests for dependency injection into
+        shuffle read operation objects.
+      test_shuffle_sink: Used during tests for dependency injection into
+        shuffle write operation objects.
+    """
+
+    self._map_task = map_task
+    self._counter_factory = counter_factory
+    self._ops = []
+    self._state_sampler = state_sampler
+    self._test_shuffle_source = test_shuffle_source
+    self._test_shuffle_sink = test_shuffle_sink
+
+  def operations(self):
+    return self._ops[:]
+
+  def execute(self):
+    """Executes all the operation_specs.Worker* instructions in a map task.
+
+    We update the map_task with the execution status, expressed as counters.
+
+    Raises:
+      RuntimeError: if we find more than one read instruction in the task
+        spec.
+      TypeError: if the spec parameter is not an instance of the recognized
+        operation_specs.Worker* classes.
+    """
+
+    # operations is a list of operation_specs.Worker* instances.
+    # The order of the elements is important because the inputs use
+    # list indexes as references.
+
+    step_names = (
+        self._map_task.step_names or [None] * len(self._map_task.operations))
+    for ix, spec in enumerate(self._map_task.operations):
+      # This is used for logging and assigning names to counters.
+      operation_name = self._map_task.system_names[ix]
+      step_name = step_names[ix]
+      op = create_operation(
+          operation_name, spec, self._counter_factory, step_name,
+          self._state_sampler,
+          test_shuffle_source=self._test_shuffle_source,
+          test_shuffle_sink=self._test_shuffle_sink)
+      self._ops.append(op)
+
+      # Add receiver operations to the appropriate producers.
+      if hasattr(op.spec, 'input'):
+        producer, output_index = op.spec.input
+        self._ops[producer].add_receiver(op, output_index)
+      # Flatten has 'inputs', not 'input'
+      if hasattr(op.spec, 'inputs'):
+        for producer, output_index in op.spec.inputs:
+          self._ops[producer].add_receiver(op, output_index)
+
+    for ix, op in reversed(list(enumerate(self._ops))):
+      logging.debug('Starting op %d %s', ix, op)
+      with op.scoped_metrics_container:
+        op.start()
+    for op in self._ops:
+      with op.scoped_metrics_container:
+        op.finish()

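As a rough end-to-end illustration of the classes above, here is a sketch
that wires a two-operation MapTask through SimpleMapTaskExecutor.  It assumes
the RangeSource helper from apache_beam.io.concat_source_test enumerates its
range when read, and it substitutes trivial stand-ins for the state sampler:

from apache_beam import coders
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io.iobase import SourceBundle
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import operations
from apache_beam.utils.counters import CounterFactory


class FakeScopedState(object):
  def __enter__(self):
    return self

  def __exit__(self, *unused_exc_info):
    pass


class FakeStateSampler(object):
  def scoped_state(self, unused_name):
    return FakeScopedState()


# Read the range [0, 5) and append the raw values to an in-memory buffer.
output_buffer = []
map_task = operation_specs.MapTask(
    [operation_specs.WorkerRead(
        SourceBundle(1.0, RangeSource(0, 5), None, None),
        output_coders=[coders.PickleCoder()]),
     operation_specs.WorkerInMemoryWrite(
         output_buffer=output_buffer,
         write_windowed_values=False,
         input=(0, 0),
         output_coders=[coders.PickleCoder()])],
    'stage1',
    system_names=['read', 'write'],
    step_names=['read', 'write'],
    original_names=['read', 'write'])

executor = operations.SimpleMapTaskExecutor(
    map_task, CounterFactory(), FakeStateSampler())
executor.execute()
print output_buffer  # Expected: [0, 1, 2, 3, 4]
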
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
new file mode 100644
index 0000000..6907f6e
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -0,0 +1,451 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK harness for executing Python Fns via the Fn API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+import traceback
+import zlib
+
+import dill
+from google.protobuf import wrappers_pb2
+
+from apache_beam.coders import coder_impl
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.internal import pickler
+from apache_beam.runners.dataflow.native_io import iobase
+from apache_beam.utils import counters
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+  from apache_beam.runners.worker import statesampler
+except ImportError:
+  from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory
+
+
+DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1'
+DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1'
+IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1'
+PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1'
+PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1'
+# TODO(vikasrk): Fix this once runner sends appropriate python urns.
+PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1'
+PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'
+
+
+class RunnerIOOperation(operations.Operation):
+  """Common baseclass for runner harness IO operations."""
+
+  def __init__(self, operation_name, step_name, consumers, counter_factory,
+               state_sampler, windowed_coder, target, data_channel):
+    super(RunnerIOOperation, self).__init__(
+        operation_name, None, counter_factory, state_sampler)
+    self.windowed_coder = windowed_coder
+    self.step_name = step_name
+    # target represents the consumer for the bytes in the data plane for a
+    # DataInputOperation or a producer of these bytes for a DataOutputOperation.
+    self.target = target
+    self.data_channel = data_channel
+    for _, consumer_ops in consumers.items():
+      for consumer in consumer_ops:
+        self.add_receiver(consumer, 0)
+
+
+class DataOutputOperation(RunnerIOOperation):
+  """A sink-like operation that gathers outputs to be sent back to the runner.
+  """
+
+  def set_output_stream(self, output_stream):
+    self.output_stream = output_stream
+
+  def process(self, windowed_value):
+    self.windowed_coder.get_impl().encode_to_stream(
+        windowed_value, self.output_stream, True)
+
+  def finish(self):
+    self.output_stream.close()
+    super(DataOutputOperation, self).finish()
+
+
+class DataInputOperation(RunnerIOOperation):
+  """A source-like operation that gathers input from the runner.
+  """
+
+  def __init__(self, operation_name, step_name, consumers, counter_factory,
+               state_sampler, windowed_coder, input_target, data_channel):
+    super(DataInputOperation, self).__init__(
+        operation_name, step_name, consumers, counter_factory, state_sampler,
+        windowed_coder, target=input_target, data_channel=data_channel)
+    # We must do this manually as we don't have a spec or spec.output_coders.
+    self.receivers = [
+        operations.ConsumerSet(self.counter_factory, self.step_name, 0,
+                               consumers.itervalues().next(),
+                               self.windowed_coder)]
+
+  def process(self, windowed_value):
+    self.output(windowed_value)
+
+  def process_encoded(self, encoded_windowed_values):
+    input_stream = coder_impl.create_InputStream(encoded_windowed_values)
+    while input_stream.size() > 0:
+      decoded_value = self.windowed_coder.get_impl().decode_from_stream(
+          input_stream, True)
+      self.output(decoded_value)
+
+
+# TODO(robertwb): Revise side input API to not be in terms of native sources.
+# This will enable lookups, but there's an open question as to how to handle
+# custom sources without forcing intermediate materialization.  This seems very
+# related to the desire to inject key and window preserving [Splittable]DoFns
+# into the view computation.
+class SideInputSource(iobase.NativeSource, iobase.NativeSourceReader):
+  """A 'source' for reading side inputs via state API calls.
+  """
+
+  def __init__(self, state_handler, state_key, coder):
+    self._state_handler = state_handler
+    self._state_key = state_key
+    self._coder = coder
+
+  def reader(self):
+    return self
+
+  @property
+  def returns_windowed_values(self):
+    return True
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, *exn_info):
+    pass
+
+  def __iter__(self):
+    # TODO(robertwb): Support pagination.
+    input_stream = coder_impl.create_InputStream(
+        self._state_handler.Get(self._state_key).data)
+    while input_stream.size() > 0:
+      yield self._coder.get_impl().decode_from_stream(input_stream, True)
+
+
+def unpack_and_deserialize_py_fn(function_spec):
+  """Returns unpacked and deserialized object from function spec proto."""
+  return pickler.loads(unpack_function_spec_data(function_spec))
+
+
+def unpack_function_spec_data(function_spec):
+  """Returns unpacked data from function spec proto."""
+  data = wrappers_pb2.BytesValue()
+  function_spec.data.Unpack(data)
+  return data.value
+
+
+# pylint: disable=redefined-builtin
+def serialize_and_pack_py_fn(fn, urn, id=None):
+  """Returns serialized and packed function in a function spec proto."""
+  return pack_function_spec_data(pickler.dumps(fn), urn, id)
+# pylint: enable=redefined-builtin
+
+
+# pylint: disable=redefined-builtin
+def pack_function_spec_data(value, urn, id=None):
+  """Returns packed data in a function spec proto."""
+  data = wrappers_pb2.BytesValue(value=value)
+  fn_proto = beam_fn_api_pb2.FunctionSpec(urn=urn)
+  fn_proto.data.Pack(data)
+  if id:
+    fn_proto.id = id
+  return fn_proto
+# pylint: enable=redefined-builtin
+
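+# These two helpers are inverses of each other; e.g. (sketch, arbitrary id):
+#
+#   fn_proto = pack_function_spec_data('payload', PYTHON_CODER_URN, id='42')
+#   assert unpack_function_spec_data(fn_proto) == 'payload'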
+
+# TODO(vikasrk): move this method to ``coders.py`` in the SDK.
+def load_compressed(compressed_data):
+  """Returns a decompressed and deserialized python object."""
+  # Note: SDK uses ``pickler.dumps`` to serialize certain python objects
+  # (like sources), which involves serialization, compression and base64
+  # encoding. We cannot directly use ``pickler.loads`` for
+  # deserialization, as the runner would have already base64 decoded the
+  # data. So we only need to decompress and deserialize.
+
+  data = zlib.decompress(compressed_data)
+  try:
+    return dill.loads(data)
+  except Exception:          # pylint: disable=broad-except
+    dill.dill._trace(True)   # pylint: disable=protected-access
+    return dill.loads(data)
+  finally:
+    dill.dill._trace(False)  # pylint: disable=protected-access
+
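+# For example (sketch), load_compressed(zlib.compress(dill.dumps(obj)))
+# round-trips obj, with no base64 step needed on this side.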
+
+class SdkHarness(object):
+
+  def __init__(self, control_channel):
+    self._control_channel = control_channel
+    self._data_channel_factory = GrpcClientDataChannelFactory()
+
+  def run(self):
+    control_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel)
+    # TODO(robertwb): Wire up to new state api.
+    state_stub = None
+    self.worker = SdkWorker(state_stub, self._data_channel_factory)
+
+    responses = queue.Queue()
+    no_more_work = object()
+
+    def get_responses():
+      while True:
+        response = responses.get()
+        if response is no_more_work:
+          return
+        yield response
+
+    def process_requests():
+      for work_request in control_stub.Control(get_responses()):
+        logging.info('Got work %s', work_request.instruction_id)
+        try:
+          response = self.worker.do_instruction(work_request)
+        except Exception:  # pylint: disable=broad-except
+          response = beam_fn_api_pb2.InstructionResponse(
+              instruction_id=work_request.instruction_id,
+              error=traceback.format_exc())
+        responses.put(response)
+    t = threading.Thread(target=process_requests)
+    t.start()
+    t.join()
+    # get_responses may be blocked on responses.get(), but we need to return
+    # control to its caller.
+    responses.put(no_more_work)
+    self._data_channel_factory.close()
+    logging.info('Done consuming work.')
+
+
+class SdkWorker(object):
+
+  def __init__(self, state_handler, data_channel_factory):
+    self.fns = {}
+    self.state_handler = state_handler
+    self.data_channel_factory = data_channel_factory
+
+  def do_instruction(self, request):
+    request_type = request.WhichOneof('request')
+    if request_type:
+      # E.g. if register is set, this will construct
+      # InstructionResponse(register=self.register(request.register))
+      return beam_fn_api_pb2.InstructionResponse(**{
+          'instruction_id': request.instruction_id,
+          request_type: getattr(self, request_type)
+                        (getattr(request, request_type), request.instruction_id)
+      })
+    else:
+      raise NotImplementedError
+
+  def register(self, request, unused_instruction_id=None):
+    for process_bundle_descriptor in request.process_bundle_descriptor:
+      self.fns[process_bundle_descriptor.id] = process_bundle_descriptor
+      for p_transform in list(process_bundle_descriptor.primitive_transform):
+        self.fns[p_transform.function_spec.id] = p_transform.function_spec
+    return beam_fn_api_pb2.RegisterResponse()
+
+  def initial_source_split(self, request, unused_instruction_id=None):
+    source_spec = self.fns[request.source_reference]
+    assert source_spec.urn == PYTHON_SOURCE_URN
+    source_bundle = unpack_and_deserialize_py_fn(
+        self.fns[request.source_reference])
+    splits = source_bundle.source.split(request.desired_bundle_size_bytes,
+                                        source_bundle.start_position,
+                                        source_bundle.stop_position)
+    response = beam_fn_api_pb2.InitialSourceSplitResponse()
+    response.splits.extend([
+        beam_fn_api_pb2.SourceSplit(
+            source=serialize_and_pack_py_fn(split, PYTHON_SOURCE_URN),
+            relative_size=split.weight,
+        )
+        for split in splits
+    ])
+    return response
+
+  def create_execution_tree(self, descriptor):
+    # TODO(vikasrk): Add an id field to Coder proto and use that instead.
+    coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
+        json.loads(unpack_function_spec_data(coder.function_spec)))
+              for coder in descriptor.coders}
+
+    counter_factory = counters.CounterFactory()
+    # TODO(robertwb): Figure out the correct prefix to use for output counters
+    # from StateSampler.
+    state_sampler = statesampler.StateSampler(
+        'fnapi-step%s-' % descriptor.id, counter_factory)
+    consumers = collections.defaultdict(lambda: collections.defaultdict(list))
+    ops_by_id = {}
+    reversed_ops = []
+
+    for transform in reversed(descriptor.primitive_transform):
+      # TODO(robertwb): Figure out how to plumb the operation name (e.g. "s3")
+      # from the service through the Fn API so that msec counters can be
+      # reported and correctly surfaced in the service and the UI.
+      operation_name = 'fnapis%s' % transform.id
+
+      def only_element(iterable):
+        element, = iterable
+        return element
+
+      if transform.function_spec.urn == DATA_OUTPUT_URN:
+        target = beam_fn_api_pb2.Target(
+            primitive_transform_reference=transform.id,
+            name=only_element(transform.outputs.keys()))
+
+        op = DataOutputOperation(
+            operation_name,
+            transform.step_name,
+            consumers[transform.id],
+            counter_factory,
+            state_sampler,
+            coders[only_element(transform.outputs.values()).coder_reference],
+            target,
+            self.data_channel_factory.create_data_channel(
+                transform.function_spec))
+
+      elif transform.function_spec.urn == DATA_INPUT_URN:
+        target = beam_fn_api_pb2.Target(
+            primitive_transform_reference=transform.id,
+            name=only_element(transform.inputs.keys()))
+        op = DataInputOperation(
+            operation_name,
+            transform.step_name,
+            consumers[transform.id],
+            counter_factory,
+            state_sampler,
+            coders[only_element(transform.outputs.values()).coder_reference],
+            target,
+            self.data_channel_factory.create_data_channel(
+                transform.function_spec))
+
+      elif transform.function_spec.urn == PYTHON_DOFN_URN:
+        def create_side_input(tag, si):
+          # TODO(robertwb): Extract windows (and keys) out of element data.
+          return operation_specs.WorkerSideInputSource(
+              tag=tag,
+              source=SideInputSource(
+                  self.state_handler,
+                  beam_fn_api_pb2.StateKey(
+                      function_spec_reference=si.view_fn.id),
+                  coder=unpack_and_deserialize_py_fn(si.view_fn)))
+        output_tags = list(transform.outputs.keys())
+        spec = operation_specs.WorkerDoFn(
+            serialized_fn=unpack_function_spec_data(transform.function_spec),
+            output_tags=output_tags,
+            input=None,
+            side_inputs=[create_side_input(tag, si)
+                         for tag, si in transform.side_inputs.items()],
+            output_coders=[coders[transform.outputs[out].coder_reference]
+                           for out in output_tags])
+
+        op = operations.DoOperation(operation_name, spec, counter_factory,
+                                    state_sampler)
+        # TODO(robertwb): Move these to the constructor.
+        op.step_name = transform.step_name
+        for tag, op_consumers in consumers[transform.id].items():
+          for consumer in op_consumers:
+            op.add_receiver(
+                consumer, output_tags.index(tag))
+
+      elif transform.function_spec.urn == IDENTITY_DOFN_URN:
+        op = operations.FlattenOperation(operation_name, None, counter_factory,
+                                         state_sampler)
+        # TODO(robertwb): Move these to the constructor.
+        op.step_name = transform.step_name
+        for tag, op_consumers in consumers[transform.id].items():
+          for consumer in op_consumers:
+            op.add_receiver(consumer, 0)
+
+      elif transform.function_spec.urn == PYTHON_SOURCE_URN:
+        source = load_compressed(unpack_function_spec_data(
+            transform.function_spec))
+        # TODO(vikasrk): Remove this once custom source is implemented with
+        # splittable dofn via the data plane.
+        spec = operation_specs.WorkerRead(
+            iobase.SourceBundle(1.0, source, None, None),
+            [WindowedValueCoder(source.default_output_coder())])
+        op = operations.ReadOperation(operation_name, spec, counter_factory,
+                                      state_sampler)
+        op.step_name = transform.step_name
+        output_tags = list(transform.outputs.keys())
+        for tag, op_consumers in consumers[transform.id].items():
+          for consumer in op_consumers:
+            op.add_receiver(
+                consumer, output_tags.index(tag))
+
+      else:
+        raise NotImplementedError
+
+      # Record consumers.
+      for _, inputs in transform.inputs.items():
+        for target in inputs.target:
+          consumers[target.primitive_transform_reference][target.name].append(
+              op)
+
+      reversed_ops.append(op)
+      ops_by_id[transform.id] = op
+
+    return list(reversed(reversed_ops)), ops_by_id
+
+  def process_bundle(self, request, instruction_id):
+    ops, ops_by_id = self.create_execution_tree(
+        self.fns[request.process_bundle_descriptor_reference])
+
+    expected_inputs = []
+    for _, op in ops_by_id.items():
+      if isinstance(op, DataOutputOperation):
+        # TODO(robertwb): Is there a better way to pass the instruction id to
+        # the operation?
+        op.set_output_stream(op.data_channel.output_stream(
+            instruction_id, op.target))
+      elif isinstance(op, DataInputOperation):
+        # We must wait until we receive "end of stream" for each of these ops.
+        expected_inputs.append(op)
+
+    # Start all operations.
+    for op in reversed(ops):
+      logging.info('start %s', op)
+      op.start()
+
+    # Inject inputs from data plane.
+    for input_op in expected_inputs:
+      for data in input_op.data_channel.input_elements(
+          instruction_id, [input_op.target]):
+        # ignores input name
+        target_op = ops_by_id[data.target.primitive_transform_reference]
+        # lacks coder for non-input ops
+        target_op.process_encoded(data.data)
+
+    # Finish all operations.
+    for op in ops:
+      logging.info('finish %s', op)
+      op.finish()
+
+    return beam_fn_api_pb2.ProcessBundleResponse()
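
A minimal usage sketch of the FunctionSpec packing helpers defined above. The
URN and id strings are illustrative placeholders rather than the URNs the
runner actually registers, and it assumes apache_beam and the generated
beam_fn_api_pb2 protos are importable:

from apache_beam.runners.worker import sdk_worker

# Pickle a callable and pack it into a FunctionSpec proto via a BytesValue.
fn_spec = sdk_worker.serialize_and_pack_py_fn(
    lambda x: x + 1, urn='urn:example:placeholder', id='fn-0')
assert fn_spec.urn == 'urn:example:placeholder'
assert fn_spec.id == 'fn-0'

# Reverse the process: unpack the BytesValue and unpickle the callable.
restored = sdk_worker.unpack_and_deserialize_py_fn(fn_spec)
assert restored(1) == 2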

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
new file mode 100644
index 0000000..28828c3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK Fn Harness entry point."""
+
+import logging
+import os
+import sys
+
+import grpc
+from google.protobuf import text_format
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler
+from apache_beam.runners.worker.sdk_worker import SdkHarness
+
+
+def main(unused_argv):
+  """Main entry point for SDK Fn Harness."""
+  logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+  text_format.Merge(os.environ['LOGGING_API_SERVICE_DESCRIPTOR'],
+                    logging_service_descriptor)
+
+  # Send all logs to the runner.
+  fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
+  # TODO(vikasrk): This should be picked up from pipeline options.
+  logging.getLogger().setLevel(logging.INFO)
+  logging.getLogger().addHandler(fn_log_handler)
+
+  try:
+    logging.info('Python sdk harness started.')
+    service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+    text_format.Merge(os.environ['CONTROL_API_SERVICE_DESCRIPTOR'],
+                      service_descriptor)
+    # TODO(robertwb): Support credentials.
+    assert not service_descriptor.oauth2_client_credentials_grant.url
+    channel = grpc.insecure_channel(service_descriptor.url)
+    SdkHarness(channel).run()
+    logging.info('Python sdk harness exiting.')
+  except:  # pylint: disable=broad-except
+    logging.exception('Python sdk harness failed: ')
+    raise
+  finally:
+    fn_log_handler.close()
+
+
+if __name__ == '__main__':
+  main(sys.argv)
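
For reference, a hedged sketch of how a runner-side process might launch this
entry point. The two environment variables carry text-format
ApiServiceDescriptor protos; the localhost ports below are placeholders that a
real runner would replace with its live logging and control endpoints:

import os
import subprocess

env = dict(os.environ)
# Text-format ApiServiceDescriptor protos (placeholder endpoints).
env['LOGGING_API_SERVICE_DESCRIPTOR'] = 'url: "localhost:50051"'
env['CONTROL_API_SERVICE_DESCRIPTOR'] = 'url: "localhost:50052"'

# Launch the SDK harness as a child process; it connects back to the
# runner's logging and control services and blocks until the control
# stream is closed.
subprocess.check_call(
    ['python', '-m', 'apache_beam.runners.worker.sdk_worker_main'], env=env)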


[4/4] beam git commit: Closes #2644

Posted by ro...@apache.org.
Closes #2644


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/7c425b09
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/7c425b09
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/7c425b09

Branch: refs/heads/master
Commit: 7c425b097fdadcc160cdaa7d8992416b991e37c4
Parents: b8131fe a856fcf
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Mon May 1 17:44:30 2017 -0700
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Mon May 1 17:44:30 2017 -0700

----------------------------------------------------------------------
 .../apache_beam/runners/portability/__init__.py |  16 +
 .../runners/portability/fn_api_runner.py        | 471 ++++++++++++++
 .../runners/portability/fn_api_runner_test.py   |  40 ++
 .../portability/maptask_executor_runner.py      | 468 +++++++++++++
 .../portability/maptask_executor_runner_test.py | 204 ++++++
 .../apache_beam/runners/worker/__init__.py      |  16 +
 .../apache_beam/runners/worker/data_plane.py    | 288 ++++++++
 .../runners/worker/data_plane_test.py           | 139 ++++
 .../apache_beam/runners/worker/log_handler.py   | 100 +++
 .../runners/worker/log_handler_test.py          | 105 +++
 .../apache_beam/runners/worker/logger.pxd       |  25 +
 .../python/apache_beam/runners/worker/logger.py | 173 +++++
 .../apache_beam/runners/worker/logger_test.py   | 182 ++++++
 .../apache_beam/runners/worker/opcounters.pxd   |  45 ++
 .../apache_beam/runners/worker/opcounters.py    | 162 +++++
 .../runners/worker/opcounters_test.py           | 149 +++++
 .../runners/worker/operation_specs.py           | 368 +++++++++++
 .../apache_beam/runners/worker/operations.pxd   |  89 +++
 .../apache_beam/runners/worker/operations.py    | 651 +++++++++++++++++++
 .../apache_beam/runners/worker/sdk_worker.py    | 451 +++++++++++++
 .../runners/worker/sdk_worker_main.py           |  62 ++
 .../runners/worker/sdk_worker_test.py           | 168 +++++
 .../apache_beam/runners/worker/sideinputs.py    | 166 +++++
 .../runners/worker/sideinputs_test.py           | 150 +++++
 .../apache_beam/runners/worker/statesampler.pyx | 237 +++++++
 .../runners/worker/statesampler_fake.py         |  34 +
 .../runners/worker/statesampler_test.py         | 102 +++
 sdks/python/generate_pydoc.sh                   |   2 +
 sdks/python/setup.py                            |   8 +-
 29 files changed, 5069 insertions(+), 2 deletions(-)
----------------------------------------------------------------------



[3/4] beam git commit: Fn API support for Python.

Posted by ro...@apache.org.
Fn API support for Python.

Also added supporting worker code and a runner
that exercises this code directly.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a856fcf3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a856fcf3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a856fcf3

Branch: refs/heads/master
Commit: a856fcf3f9ca48530078e1d7610e234110cd5890
Parents: b8131fe
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Thu Apr 20 15:59:25 2017 -0500
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Mon May 1 17:44:29 2017 -0700

----------------------------------------------------------------------
 .../apache_beam/runners/portability/__init__.py |  16 +
 .../runners/portability/fn_api_runner.py        | 471 ++++++++++++++
 .../runners/portability/fn_api_runner_test.py   |  40 ++
 .../portability/maptask_executor_runner.py      | 468 +++++++++++++
 .../portability/maptask_executor_runner_test.py | 204 ++++++
 .../apache_beam/runners/worker/__init__.py      |  16 +
 .../apache_beam/runners/worker/data_plane.py    | 288 ++++++++
 .../runners/worker/data_plane_test.py           | 139 ++++
 .../apache_beam/runners/worker/log_handler.py   | 100 +++
 .../runners/worker/log_handler_test.py          | 105 +++
 .../apache_beam/runners/worker/logger.pxd       |  25 +
 .../python/apache_beam/runners/worker/logger.py | 173 +++++
 .../apache_beam/runners/worker/logger_test.py   | 182 ++++++
 .../apache_beam/runners/worker/opcounters.pxd   |  45 ++
 .../apache_beam/runners/worker/opcounters.py    | 162 +++++
 .../runners/worker/opcounters_test.py           | 149 +++++
 .../runners/worker/operation_specs.py           | 368 +++++++++++
 .../apache_beam/runners/worker/operations.pxd   |  89 +++
 .../apache_beam/runners/worker/operations.py    | 651 +++++++++++++++++++
 .../apache_beam/runners/worker/sdk_worker.py    | 451 +++++++++++++
 .../runners/worker/sdk_worker_main.py           |  62 ++
 .../runners/worker/sdk_worker_test.py           | 168 +++++
 .../apache_beam/runners/worker/sideinputs.py    | 166 +++++
 .../runners/worker/sideinputs_test.py           | 150 +++++
 .../apache_beam/runners/worker/statesampler.pyx | 237 +++++++
 .../runners/worker/statesampler_fake.py         |  34 +
 .../runners/worker/statesampler_test.py         | 102 +++
 sdks/python/generate_pydoc.sh                   |   2 +
 sdks/python/setup.py                            |   8 +-
 29 files changed, 5069 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/__init__.py b/sdks/python/apache_beam/runners/portability/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
new file mode 100644
index 0000000..5802c17
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -0,0 +1,471 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A PipelineRunner using the SDK harness.
+"""
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+
+import grpc
+from concurrent import futures
+
+import apache_beam as beam
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.coders.coder_impl import create_InputStream
+from apache_beam.coders.coder_impl import create_OutputStream
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sdk_worker
+
+
+def streaming_rpc_handler(cls, method_name):
+  """Un-inverts the flow of control between the runner and the sdk harness."""
+
+  class StreamingRpcHandler(cls):
+
+    _DONE = object()
+
+    def __init__(self):
+      self._push_queue = queue.Queue()
+      self._pull_queue = queue.Queue()
+      setattr(self, method_name, self.run)
+      self._read_thread = threading.Thread(target=self._read)
+
+    def run(self, iterator, context):
+      self._inputs = iterator
+      # Note: We only support one client for now.
+      self._read_thread.start()
+      while True:
+        to_push = self._push_queue.get()
+        if to_push is self._DONE:
+          return
+        yield to_push
+
+    def _read(self):
+      for data in self._inputs:
+        self._pull_queue.put(data)
+
+    def push(self, item):
+      self._push_queue.put(item)
+
+    def pull(self, timeout=None):
+      return self._pull_queue.get(timeout=timeout)
+
+    def empty(self):
+      return self._pull_queue.empty()
+
+    def done(self):
+      self.push(self._DONE)
+      self._read_thread.join()
+
+  return StreamingRpcHandler()
+
+
+class OldeSourceSplittableDoFn(beam.DoFn):
+  """A DoFn that reads and emits an entire source.
+  """
+
+  # TODO(robertwb): Make this a full SDF with progress splitting, etc.
+  def process(self, source):
+    if isinstance(source, iobase.SourceBundle):
+      for value in source.source.read(source.source.get_range_tracker(
+          source.start_position, source.stop_position)):
+        yield value
+    else:
+      # Dataflow native source
+      with source.reader() as reader:
+        for value in reader:
+          yield value
+
+
+# See DataflowRunner._pardo_fn_data
+OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
+    (OldeSourceSplittableDoFn(), (), {}, [],
+     beam.transforms.core.Windowing(GlobalWindows())))
+
+
+class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
+
+  def __init__(self):
+    super(FnApiRunner, self).__init__()
+    self._last_uid = -1
+
+  def has_metrics_support(self):
+    return False
+
+  def _next_uid(self):
+    self._last_uid += 1
+    return str(self._last_uid)
+
+  def _map_task_registration(self, map_task, state_handler,
+                             data_operation_spec):
+    input_data = {}
+    runner_sinks = {}
+    transforms = []
+    transform_index_to_id = {}
+
+    # Maps coders to new coder objects and references.
+    coders = {}
+
+    def coder_id(coder):
+      if coder not in coders:
+        coders[coder] = beam_fn_api_pb2.Coder(
+            function_spec=sdk_worker.pack_function_spec_data(
+                json.dumps(coder.as_cloud_object()),
+                sdk_worker.PYTHON_CODER_URN, id=self._next_uid()))
+
+      return coders[coder].function_spec.id
+
+    def output_tags(op):
+      return getattr(op, 'output_tags', ['out'])
+
+    def as_target(op_input):
+      input_op_index, input_output_index = op_input
+      input_op = map_task[input_op_index][1]
+      return {
+          'ignored_input_tag':
+              beam_fn_api_pb2.Target.List(target=[
+                  beam_fn_api_pb2.Target(
+                      primitive_transform_reference=transform_index_to_id[
+                          input_op_index],
+                      name=output_tags(input_op)[input_output_index])
+              ])
+      }
+
+    def outputs(op):
+      return {
+          tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
+          for tag, coder in zip(output_tags(op), op.output_coders)
+      }
+
+    for op_ix, (stage_name, operation) in enumerate(map_task):
+      transform_id = transform_index_to_id[op_ix] = self._next_uid()
+      if isinstance(operation, operation_specs.WorkerInMemoryWrite):
+        # Write this data back to the runner.
+        fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_OUTPUT_URN,
+                                          id=self._next_uid())
+        if data_operation_spec:
+          fn.data.Pack(data_operation_spec)
+        inputs = as_target(operation.input)
+        side_inputs = {}
+        runner_sinks[(transform_id, 'out')] = operation
+
+      elif isinstance(operation, operation_specs.WorkerRead):
+        # A Read is either translated to a direct injection of windowed values
+        # into the sdk worker, or an injection of the source object into the
+        # sdk worker as data followed by an SDF that reads that source.
+        if (isinstance(operation.source.source,
+                       maptask_executor_runner.InMemorySource)
+            and isinstance(operation.source.source.default_output_coder(),
+                           WindowedValueCoder)):
+          output_stream = create_OutputStream()
+          element_coder = (
+              operation.source.source.default_output_coder().get_impl())
+          # Re-encode the elements in the nested context and
+          # concatenate them together
+          for element in operation.source.source.read(None):
+            element_coder.encode_to_stream(element, output_stream, True)
+          target_name = self._next_uid()
+          input_data[(transform_id, target_name)] = output_stream.get()
+          fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN,
+                                            id=self._next_uid())
+          if data_operation_spec:
+            fn.data.Pack(data_operation_spec)
+          inputs = {target_name: beam_fn_api_pb2.Target.List()}
+          side_inputs = {}
+        else:
+          # Read the source object from the runner.
+          source_coder = beam.coders.DillCoder()
+          input_transform_id = self._next_uid()
+          output_stream = create_OutputStream()
+          source_coder.get_impl().encode_to_stream(
+              GlobalWindows.windowed_value(operation.source),
+              output_stream,
+              True)
+          target_name = self._next_uid()
+          input_data[(input_transform_id, target_name)] = output_stream.get()
+          input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
+              id=input_transform_id,
+              function_spec=beam_fn_api_pb2.FunctionSpec(
+                  urn=sdk_worker.DATA_INPUT_URN,
+                  id=self._next_uid()),
+              # TODO(robertwb): Possible name collision.
+              step_name=stage_name + '/inject_source',
+              inputs={target_name: beam_fn_api_pb2.Target.List()},
+              outputs={
+                  'out':
+                      beam_fn_api_pb2.PCollection(
+                          coder_reference=coder_id(source_coder))
+              })
+          if data_operation_spec:
+            input_ptransform.function_spec.data.Pack(data_operation_spec)
+          transforms.append(input_ptransform)
+
+          # Read the elements out of the source.
+          fn = sdk_worker.pack_function_spec_data(
+              OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
+              sdk_worker.PYTHON_DOFN_URN,
+              id=self._next_uid())
+          inputs = {
+              'ignored_input_tag':
+                  beam_fn_api_pb2.Target.List(target=[
+                      beam_fn_api_pb2.Target(
+                          primitive_transform_reference=input_transform_id,
+                          name='out')
+                  ])
+          }
+          side_inputs = {}
+
+      elif isinstance(operation, operation_specs.WorkerDoFn):
+        fn = sdk_worker.pack_function_spec_data(
+            operation.serialized_fn,
+            sdk_worker.PYTHON_DOFN_URN,
+            id=self._next_uid())
+        inputs = as_target(operation.input)
+        # Store the contents of each side input for state access.
+        for si in operation.side_inputs:
+          assert isinstance(si.source, iobase.BoundedSource)
+          element_coder = si.source.default_output_coder()
+          view_id = self._next_uid()
+          # TODO(robertwb): Actually flesh out the ViewFn API.
+          side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
+              view_fn=sdk_worker.serialize_and_pack_py_fn(
+                  element_coder, urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
+                  id=view_id))
+          # Re-encode the elements in the nested context and
+          # concatenate them together
+          output_stream = create_OutputStream()
+          for element in si.source.read(
+              si.source.get_range_tracker(None, None)):
+            element_coder.get_impl().encode_to_stream(
+                element, output_stream, True)
+          elements_data = output_stream.get()
+          state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id)
+          state_handler.Clear(state_key)
+          state_handler.Append(
+              beam_fn_api_pb2.SimpleStateAppendRequest(
+                  state_key=state_key, data=[elements_data]))
+
+      elif isinstance(operation, operation_specs.WorkerFlatten):
+        fn = sdk_worker.pack_function_spec_data(
+            operation.serialized_fn,
+            sdk_worker.IDENTITY_DOFN_URN,
+            id=self._next_uid())
+        inputs = {
+            'ignored_input_tag':
+                beam_fn_api_pb2.Target.List(target=[
+                    beam_fn_api_pb2.Target(
+                        primitive_transform_reference=transform_index_to_id[
+                            input_op_index],
+                        name=output_tags(map_task[input_op_index][1])[
+                            input_output_index])
+                    for input_op_index, input_output_index in operation.inputs
+                ])
+        }
+        side_inputs = {}
+
+      else:
+        raise TypeError(operation)
+
+      ptransform = beam_fn_api_pb2.PrimitiveTransform(
+          id=transform_id,
+          function_spec=fn,
+          step_name=stage_name,
+          inputs=inputs,
+          side_inputs=side_inputs,
+          outputs=outputs(operation))
+      transforms.append(ptransform)
+
+    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
+        id=self._next_uid(), coders=coders.values(),
+        primitive_transform=transforms)
+    return beam_fn_api_pb2.InstructionRequest(
+        instruction_id=self._next_uid(),
+        register=beam_fn_api_pb2.RegisterRequest(
+            process_bundle_descriptor=[process_bundle_descriptor
+                                      ])), runner_sinks, input_data
+
+  def _run_map_task(
+      self, map_task, control_handler, state_handler, data_plane_handler,
+      data_operation_spec):
+    registration, sinks, input_data = self._map_task_registration(
+        map_task, state_handler, data_operation_spec)
+    control_handler.push(registration)
+    process_bundle = beam_fn_api_pb2.InstructionRequest(
+        instruction_id=self._next_uid(),
+        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+            process_bundle_descriptor_reference=registration.register.
+            process_bundle_descriptor[0].id))
+
+    for (transform_id, name), elements in input_data.items():
+      data_out = data_plane_handler.output_stream(
+          process_bundle.instruction_id, beam_fn_api_pb2.Target(
+              primitive_transform_reference=transform_id, name=name))
+      data_out.write(elements)
+      data_out.close()
+
+    control_handler.push(process_bundle)
+    while True:
+      result = control_handler.pull()
+      if result.instruction_id == process_bundle.instruction_id:
+        if result.error:
+          raise RuntimeError(result.error)
+        expected_targets = [
+            beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+                                   name=output_name)
+            for (transform_id, output_name), _ in sinks.items()]
+        for output in data_plane_handler.input_elements(
+            process_bundle.instruction_id, expected_targets):
+          target_tuple = (
+              output.target.primitive_transform_reference, output.target.name)
+          if target_tuple not in sinks:
+            # Unconsumed output.
+            continue
+          sink_op = sinks[target_tuple]
+          coder = sink_op.output_coders[0]
+          input_stream = create_InputStream(output.data)
+          elements = []
+          while input_stream.size() > 0:
+            elements.append(coder.get_impl().decode_from_stream(
+                input_stream, True))
+          if not sink_op.write_windowed_values:
+            elements = [e.value for e in elements]
+          for e in elements:
+            sink_op.output_buffer.append(e)
+        return
+
+  def execute_map_tasks(self, ordered_map_tasks, direct=True):
+    if direct:
+      controller = FnApiRunner.DirectController()
+    else:
+      controller = FnApiRunner.GrpcController()
+
+    try:
+      for _, map_task in ordered_map_tasks:
+        logging.info('Running %s', map_task)
+        self._run_map_task(
+            map_task, controller.control_handler, controller.state_handler,
+            controller.data_plane_handler, controller.data_operation_spec())
+    finally:
+      controller.close()
+
+  class SimpleState(object):  # TODO(robertwb): Inherit from GRPC servicer.
+
+    def __init__(self):
+      self._all = collections.defaultdict(list)
+
+    def Get(self, state_key):
+      return beam_fn_api_pb2.Elements.Data(
+          data=''.join(self._all[self._to_key(state_key)]))
+
+    def Append(self, append_request):
+      self._all[self._to_key(append_request.state_key)].extend(
+          append_request.data)
+
+    def Clear(self, state_key):
+      try:
+        del self._all[self._to_key(state_key)]
+      except KeyError:
+        pass
+
+    @staticmethod
+    def _to_key(state_key):
+      return (state_key.function_spec_reference, state_key.window,
+              state_key.key)
+
+  class DirectController(object):
+    """An in-memory controller for fn API control, state and data planes."""
+
+    def __init__(self):
+      self._responses = []
+      self.state_handler = FnApiRunner.SimpleState()
+      self.control_handler = self
+      self.data_plane_handler = data_plane.InMemoryDataChannel()
+      self.worker = sdk_worker.SdkWorker(
+          self.state_handler, data_plane.InMemoryDataChannelFactory(
+              self.data_plane_handler.inverse()))
+
+    def push(self, request):
+      logging.info('CONTROL REQUEST %s', request)
+      response = self.worker.do_instruction(request)
+      logging.info('CONTROL RESPONSE %s', response)
+      self._responses.append(response)
+
+    def pull(self):
+      return self._responses.pop(0)
+
+    def done(self):
+      pass
+
+    def close(self):
+      pass
+
+    def data_operation_spec(self):
+      return None
+
+  class GrpcController(object):
+    """A gRPC-based controller for fn API control, state and data planes."""
+
+    def __init__(self):
+      self.state_handler = FnApiRunner.SimpleState()
+      self.control_server = grpc.server(
+          futures.ThreadPoolExecutor(max_workers=10))
+      self.control_port = self.control_server.add_insecure_port('[::]:0')
+
+      self.data_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+      self.data_port = self.data_server.add_insecure_port('[::]:0')
+
+      self.control_handler = streaming_rpc_handler(
+          beam_fn_api_pb2.BeamFnControlServicer, 'Control')
+      beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
+          self.control_handler, self.control_server)
+
+      self.data_plane_handler = data_plane.GrpcServerDataChannel()
+      beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+          self.data_plane_handler, self.data_server)
+
+      logging.info('starting control server on port %s', self.control_port)
+      logging.info('starting data server on port %s', self.data_port)
+      self.data_server.start()
+      self.control_server.start()
+
+      self.worker = sdk_worker.SdkHarness(
+          grpc.insecure_channel('localhost:%s' % self.control_port))
+      self.worker_thread = threading.Thread(target=self.worker.run)
+      logging.info('starting worker')
+      self.worker_thread.start()
+
+    def data_operation_spec(self):
+      url = 'localhost:%s' % self.data_port
+      remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+      remote_grpc_port.api_service_descriptor.url = url
+      return remote_grpc_port
+
+    def close(self):
+      self.control_handler.done()
+      self.worker_thread.join()
+      self.data_plane_handler.close()
+      self.control_server.stop(5).wait()
+      self.data_server.stop(5).wait()
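
To make the control-plane flow concrete, a small sketch that drives the
in-process DirectController defined above: InstructionRequest protos are
pushed in and the matching InstructionResponse is pulled back, with no gRPC
involved. The request here is an empty register call; a real invocation
attaches ProcessBundleDescriptors as _map_task_registration does:

from apache_beam.runners.api import beam_fn_api_pb2
from apache_beam.runners.portability import fn_api_runner

controller = fn_api_runner.FnApiRunner.DirectController()

# An empty RegisterRequest; descriptors are omitted for brevity.
request = beam_fn_api_pb2.InstructionRequest(
    instruction_id='0',
    register=beam_fn_api_pb2.RegisterRequest())
controller.push(request)

response = controller.pull()
assert response.instruction_id == '0'
assert not response.error
controller.close()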

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
new file mode 100644
index 0000000..633602f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.runners.portability import fn_api_runner
+from apache_beam.runners.portability import maptask_executor_runner_test
+
+
+class FnApiRunnerTest(
+    maptask_executor_runner_test.MapTaskExecutorRunnerTest):
+
+  def create_pipeline(self):
+    return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
+
+  def test_combine_per_key(self):
+    # TODO(robertwb): Implement PGBKCV operation.
+    pass
+
+  # Inherits all tests from
+  # maptask_executor_runner_test.MapTaskExecutorRunnerTest.
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
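
For context, a minimal sketch of the kind of pipeline the inherited tests run
against this runner (mirroring test_create in maptask_executor_runner_test.py
further down; assumes apache_beam and the modules added in this commit are
importable):

import apache_beam as beam
from apache_beam.runners.portability import fn_api_runner
from apache_beam.transforms.util import assert_that
from apache_beam.transforms.util import equal_to

# Build and run a trivial pipeline directly on the Fn API runner.
with beam.Pipeline(runner=fn_api_runner.FnApiRunner()) as p:
  assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))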

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
new file mode 100644
index 0000000..d273e18
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
@@ -0,0 +1,468 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Beam runner for testing/profiling worker code directly.
+"""
+
+import collections
+import logging
+import time
+
+import apache_beam as beam
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.runners import DataflowRunner
+from apache_beam.runners.dataflow.internal.dependency import _dependency_file_copy
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.runners.runner import PipelineRunner
+from apache_beam.runners.runner import PipelineState
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+  from apache_beam.runners.worker import statesampler
+except ImportError:
+  from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.typehints import typehints
+from apache_beam.utils import pipeline_options
+from apache_beam.utils import profiler
+from apache_beam.utils.counters import CounterFactory
+
+
+class MapTaskExecutorRunner(PipelineRunner):
+  """Beam runner translating a pipeline into map tasks that are then executed.
+
+  Primarily intended for testing and profiling the worker code paths.
+  """
+
+  def __init__(self):
+    self.executors = []
+
+  def has_metrics_support(self):
+    """Returns whether this runner supports metrics or not.
+    """
+    return False
+
+  def run(self, pipeline):
+    MetricsEnvironment.set_metrics_supported(self.has_metrics_support())
+    # List of map tasks.  Each map task is a list of
+    # (stage_name, operation_specs.WorkerOperation) instructions.
+    self.map_tasks = []
+
+    # Map of pvalues to
+    # (map_task_index, producer_operation_index, producer_output_index)
+    self.outputs = {}
+
+    # Unique mappings of PCollections to strings.
+    self.side_input_labels = collections.defaultdict(
+        lambda: str(len(self.side_input_labels)))
+
+    # Mapping of map task indices to all map tasks that must precede them.
+    self.dependencies = collections.defaultdict(set)
+
+    # Visit the graph, building up the map_tasks and their metadata.
+    super(MapTaskExecutorRunner, self).run(pipeline)
+
+    # Now run the tasks in topological order.
+    def compute_depth_map(deps):
+      memoized = {}
+
+      def compute_depth(x):
+        if x not in memoized:
+          memoized[x] = 1 + max([-1] + [compute_depth(y) for y in deps[x]])
+        return memoized[x]
+
+      return {x: compute_depth(x) for x in deps.keys()}
+
+    map_task_depths = compute_depth_map(self.dependencies)
+    ordered_map_tasks = sorted((map_task_depths.get(ix, -1), map_task)
+                               for ix, map_task in enumerate(self.map_tasks))
+
+    profile_options = pipeline.options.view_as(
+        pipeline_options.ProfilingOptions)
+    if profile_options.profile_cpu:
+      with profiler.Profile(
+          profile_id='worker-runner',
+          profile_location=profile_options.profile_location,
+          log_results=True, file_copy_fn=_dependency_file_copy):
+        self.execute_map_tasks(ordered_map_tasks)
+    else:
+      self.execute_map_tasks(ordered_map_tasks)
+
+    return WorkerRunnerResult(PipelineState.UNKNOWN)
+
+  def metrics_containers(self):
+    return [op.metrics_container
+            for ex in self.executors
+            for op in ex.operations()]
+
+  def execute_map_tasks(self, ordered_map_tasks):
+    tt = time.time()
+    for ix, (_, map_task) in enumerate(ordered_map_tasks):
+      logging.info('Running %s', map_task)
+      t = time.time()
+      stage_names, all_operations = zip(*map_task)
+      # TODO(robertwb): The DataflowRunner worker receives system step names
+      # (e.g. "s3") that are used to label the output msec counters.  We use the
+      # operation names here, but this is not the same scheme used by the
+      # DataflowRunner; the result is that the output msec counters are named
+      # differently.
+      system_names = stage_names
+      # Create the CounterFactory and StateSampler for this MapTask.
+      # TODO(robertwb): Output counters produced here are currently ignored.
+      counter_factory = CounterFactory()
+      state_sampler = statesampler.StateSampler('%s-' % ix, counter_factory)
+      map_executor = operations.SimpleMapTaskExecutor(
+          operation_specs.MapTask(
+              all_operations, 'S%02d' % ix,
+              system_names, stage_names, system_names),
+          counter_factory,
+          state_sampler)
+      self.executors.append(map_executor)
+      map_executor.execute()
+      logging.info(
+          'Stage %s finished: %0.3f sec', stage_names[0], time.time() - t)
+    logging.info('Total time: %0.3f sec', time.time() - tt)
+
+  def run_Read(self, transform_node):
+    self._run_read_from(transform_node, transform_node.transform.source)
+
+  def _run_read_from(self, transform_node, source):
+    """Used when this operation is the result of reading source."""
+    if not isinstance(source, iobase.NativeSource):
+      source = iobase.SourceBundle(1.0, source, None, None)
+    output = transform_node.outputs[None]
+    element_coder = self._get_coder(output)
+    read_op = operation_specs.WorkerRead(source, output_coders=[element_coder])
+    self.outputs[output] = len(self.map_tasks), 0, 0
+    self.map_tasks.append([(transform_node.full_label, read_op)])
+    return len(self.map_tasks) - 1
+
+  def run_ParDo(self, transform_node):
+    transform = transform_node.transform
+    output = transform_node.outputs[None]
+    element_coder = self._get_coder(output)
+    map_task_index, producer_index, output_index = self.outputs[
+        transform_node.inputs[0]]
+
+    # If any of this ParDo's side inputs depend on outputs from this map_task,
+    # we can't continue growing this map task.
+    def is_reachable(leaf, root):
+      if leaf == root:
+        return True
+      else:
+        return any(is_reachable(x, root) for x in self.dependencies[leaf])
+
+    if any(is_reachable(self.outputs[side_input.pvalue][0], map_task_index)
+           for side_input in transform_node.side_inputs):
+      # Start a new map task.
+      input_element_coder = self._get_coder(transform_node.inputs[0])
+
+      output_buffer = OutputBuffer(input_element_coder)
+
+      fusion_break_write = operation_specs.WorkerInMemoryWrite(
+          output_buffer=output_buffer,
+          write_windowed_values=True,
+          input=(producer_index, output_index),
+          output_coders=[input_element_coder])
+      self.map_tasks[map_task_index].append(
+          (transform_node.full_label + '/Write', fusion_break_write))
+
+      original_map_task_index = map_task_index
+      map_task_index, producer_index, output_index = len(self.map_tasks), 0, 0
+
+      fusion_break_read = operation_specs.WorkerRead(
+          output_buffer.source_bundle(),
+          output_coders=[input_element_coder])
+      self.map_tasks.append(
+          [(transform_node.full_label + '/Read', fusion_break_read)])
+
+      self.dependencies[map_task_index].add(original_map_task_index)
+
+    def create_side_read(side_input):
+      label = self.side_input_labels[side_input]
+      output_buffer = self.run_side_write(
+          side_input.pvalue, '%s/%s' % (transform_node.full_label, label))
+      return operation_specs.WorkerSideInputSource(
+          output_buffer.source(), label)
+
+    do_op = operation_specs.WorkerDoFn(  #
+        serialized_fn=pickler.dumps(DataflowRunner._pardo_fn_data(
+            transform_node,
+            lambda side_input: self.side_input_labels[side_input])),
+        output_tags=[PropertyNames.OUT] + ['%s_%s' % (PropertyNames.OUT, tag)
+                                           for tag in transform.output_tags
+                                          ],
+        # Same assumption that DataflowRunner has about coders being compatible
+        # across outputs.
+        output_coders=[element_coder] * (len(transform.output_tags) + 1),
+        input=(producer_index, output_index),
+        side_inputs=[create_side_read(side_input)
+                     for side_input in transform_node.side_inputs])
+
+    producer_index = len(self.map_tasks[map_task_index])
+    self.outputs[transform_node.outputs[None]] = (
+        map_task_index, producer_index, 0)
+    for ix, tag in enumerate(transform.output_tags):
+      self.outputs[transform_node.outputs[
+          tag]] = map_task_index, producer_index, ix + 1
+    self.map_tasks[map_task_index].append((transform_node.full_label, do_op))
+
+    for side_input in transform_node.side_inputs:
+      self.dependencies[map_task_index].add(self.outputs[side_input.pvalue][0])
+
+  def run_side_write(self, pcoll, label):
+    map_task_index, producer_index, output_index = self.outputs[pcoll]
+
+    windowed_element_coder = self._get_coder(pcoll)
+    output_buffer = OutputBuffer(windowed_element_coder)
+    write_sideinput_op = operation_specs.WorkerInMemoryWrite(
+        output_buffer=output_buffer,
+        write_windowed_values=True,
+        input=(producer_index, output_index),
+        output_coders=[windowed_element_coder])
+    self.map_tasks[map_task_index].append(
+        (label, write_sideinput_op))
+    return output_buffer
+
+  def run_GroupByKeyOnly(self, transform_node):
+    map_task_index, producer_index, output_index = self.outputs[
+        transform_node.inputs[0]]
+    grouped_element_coder = self._get_coder(transform_node.outputs[None],
+                                            windowed=False)
+    windowed_ungrouped_element_coder = self._get_coder(transform_node.inputs[0])
+
+    output_buffer = GroupingOutputBuffer(grouped_element_coder)
+    shuffle_write = operation_specs.WorkerInMemoryWrite(
+        output_buffer=output_buffer,
+        write_windowed_values=False,
+        input=(producer_index, output_index),
+        output_coders=[windowed_ungrouped_element_coder])
+    self.map_tasks[map_task_index].append(
+        (transform_node.full_label + '/Write', shuffle_write))
+
+    output_map_task_index = self._run_read_from(
+        transform_node, output_buffer.source())
+    self.dependencies[output_map_task_index].add(map_task_index)
+
+  def run_Flatten(self, transform_node):
+    output_buffer = OutputBuffer(self._get_coder(transform_node.outputs[None]))
+    output_map_task = self._run_read_from(transform_node,
+                                          output_buffer.source())
+
+    for input in transform_node.inputs:
+      map_task_index, producer_index, output_index = self.outputs[input]
+      element_coder = self._get_coder(input)
+      flatten_write = operation_specs.WorkerInMemoryWrite(
+          output_buffer=output_buffer,
+          write_windowed_values=True,
+          input=(producer_index, output_index),
+          output_coders=[element_coder])
+      self.map_tasks[map_task_index].append(
+          (transform_node.full_label + '/Write', flatten_write))
+      self.dependencies[output_map_task].add(map_task_index)
+
+  def apply_CombinePerKey(self, transform, input):
+    # TODO(robertwb): Support side inputs.
+    assert not transform.args and not transform.kwargs
+    return (input
+            | PartialGroupByKeyCombineValues(transform.fn)
+            | beam.GroupByKey()
+            | MergeAccumulators(transform.fn)
+            | ExtractOutputs(transform.fn))
+
+  def run_PartialGroupByKeyCombineValues(self, transform_node):
+    element_coder = self._get_coder(transform_node.outputs[None])
+    _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+    combine_op = operation_specs.WorkerPartialGroupByKey(
+        combine_fn=pickler.dumps(
+            (transform_node.transform.combine_fn, (), {}, ())),
+        output_coders=[element_coder],
+        input=(producer_index, output_index))
+    self._run_as_op(transform_node, combine_op)
+
+  def run_MergeAccumulators(self, transform_node):
+    self._run_combine_transform(transform_node, 'merge')
+
+  def run_ExtractOutputs(self, transform_node):
+    self._run_combine_transform(transform_node, 'extract')
+
+  def _run_combine_transform(self, transform_node, phase):
+    transform = transform_node.transform
+    element_coder = self._get_coder(transform_node.outputs[None])
+    _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+    combine_op = operation_specs.WorkerCombineFn(
+        serialized_fn=pickler.dumps(
+            (transform.combine_fn, (), {}, ())),
+        phase=phase,
+        output_coders=[element_coder],
+        input=(producer_index, output_index))
+    self._run_as_op(transform_node, combine_op)
+
+  def _get_coder(self, pvalue, windowed=True):
+    # TODO(robertwb): This should be an attribute of the pvalue itself.
+    return DataflowRunner._get_coder(
+        pvalue.element_type or typehints.Any,
+        pvalue.windowing.windowfn.get_window_coder() if windowed else None)
+
+  def _run_as_op(self, transform_node, op):
+    """Single-output operation in the same map task as its input."""
+    map_task_index, _, _ = self.outputs[transform_node.inputs[0]]
+    op_index = len(self.map_tasks[map_task_index])
+    output = transform_node.outputs[None]
+    self.outputs[output] = map_task_index, op_index, 0
+    self.map_tasks[map_task_index].append((transform_node.full_label, op))
+
+
+class InMemorySource(iobase.BoundedSource):
+  """Source for reading an (as-yet unwritten) set of in-memory encoded elements.
+  """
+
+  def __init__(self, encoded_elements, coder):
+    self._encoded_elements = encoded_elements
+    self._coder = coder
+
+  def get_range_tracker(self, unused_start_position, unused_end_position):
+    return None
+
+  def read(self, unused_range_tracker):
+    for encoded_element in self._encoded_elements:
+      yield self._coder.decode(encoded_element)
+
+  def default_output_coder(self):
+    return self._coder
+
+
+class OutputBuffer(object):
+
+  def __init__(self, coder):
+    self.coder = coder
+    self.elements = []
+    self.encoded_elements = []
+
+  def source(self):
+    return InMemorySource(self.encoded_elements, self.coder)
+
+  def source_bundle(self):
+    return iobase.SourceBundle(
+        1.0, InMemorySource(self.encoded_elements, self.coder), None, None)
+
+  def __repr__(self):
+    return 'OutputBuffer[%r]' % len(self.elements)
+
+  def append(self, value):
+    self.elements.append(value)
+    self.encoded_elements.append(self.coder.encode(value))
+
+
+class GroupingOutputBuffer(object):
+
+  def __init__(self, grouped_coder):
+    self.grouped_coder = grouped_coder
+    self.elements = collections.defaultdict(list)
+    self.frozen = False
+
+  def source(self):
+    return InMemorySource(self.encoded_elements, self.grouped_coder)
+
+  def __repr__(self):
+    return 'GroupingOutputBuffer[%r]' % len(self.elements)
+
+  def append(self, pair):
+    assert not self.frozen
+    k, v = pair
+    self.elements[k].append(v)
+
+  def freeze(self):
+    if not self.frozen:
+      self._encoded_elements = [self.grouped_coder.encode(kv)
+                                for kv in self.elements.iteritems()]
+    self.frozen = True
+    return self._encoded_elements
+
+  @property
+  def encoded_elements(self):
+    return GroupedOutputBuffer(self)
+
+
+class GroupedOutputBuffer(object):
+
+  def __init__(self, buffer):
+    self.buffer = buffer
+
+  def __getitem__(self, ix):
+    return self.buffer.freeze()[ix]
+
+  def __iter__(self):
+    return iter(self.buffer.freeze())
+
+  def __len__(self):
+    return len(self.buffer.freeze())
+
+  def __nonzero__(self):
+    return True
+
+
+class PartialGroupByKeyCombineValues(beam.PTransform):
+
+  def __init__(self, combine_fn, native=True):
+    self.combine_fn = combine_fn
+    self.native = native
+
+  def expand(self, input):
+    if self.native:
+      return beam.pvalue.PCollection(input.pipeline)
+    else:
+      def to_accumulator(v):
+        return self.combine_fn.add_input(
+            self.combine_fn.create_accumulator(), v)
+      return input | beam.Map(lambda (k, v): (k, to_accumulator(v)))
+
+
+class MergeAccumulators(beam.PTransform):
+
+  def __init__(self, combine_fn, native=True):
+    self.combine_fn = combine_fn
+    self.native = native
+
+  def expand(self, input):
+    if self.native:
+      return beam.pvalue.PCollection(input.pipeline)
+    else:
+      merge_accumulators = self.combine_fn.merge_accumulators
+      return input | beam.Map(lambda (k, vs): (k, merge_accumulators(vs)))
+
+
+class ExtractOutputs(beam.PTransform):
+
+  def __init__(self, combine_fn, native=True):
+    self.combine_fn = combine_fn
+    self.native = native
+
+  def expand(self, input):
+    if self.native:
+      return beam.pvalue.PCollection(input.pipeline)
+    else:
+      extract_output = self.combine_fn.extract_output
+      return input | beam.Map(lambda (k, v): (k, extract_output(v)))
+
+
+class WorkerRunnerResult(PipelineResult):
+
+  def wait_until_finish(self, duration=None):
+    pass

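Note: the OutputBuffer and InMemorySource classes above implement a simple encode-on-append,
decode-on-read round trip: appended values are eagerly encoded with the supplied coder, and the
source later decodes them back. A minimal usage sketch (the module path and the PickleCoder
choice are assumptions based on this patch, not part of it):

    from apache_beam.coders import coders
    from apache_beam.runners.portability.maptask_executor_runner import OutputBuffer

    buf = OutputBuffer(coders.PickleCoder())
    for value in [('k', 1), ('k', 2)]:
      buf.append(value)          # stores both the value and its encoded form
    source = buf.source()        # InMemorySource over the encoded bytes
    assert list(source.read(None)) == [('k', 1), ('k', 2)]
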
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
new file mode 100644
index 0000000..6e13e73
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import tempfile
+import unittest
+
+import apache_beam as beam
+
+from apache_beam.metrics import Metrics
+from apache_beam.metrics.execution import MetricKey
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics.metricbase import MetricName
+
+from apache_beam.pvalue import AsList
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import BeamAssertException
+from apache_beam.transforms.util import equal_to
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.runners.portability import maptask_executor_runner
+
+
+class MapTaskExecutorRunnerTest(unittest.TestCase):
+
+  def create_pipeline(self):
+    return beam.Pipeline(runner=maptask_executor_runner.MapTaskExecutorRunner())
+
+  def test_assert_that(self):
+    with self.assertRaises(BeamAssertException):
+      with self.create_pipeline() as p:
+        assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+
+  def test_create(self):
+    with self.create_pipeline() as p:
+      assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
+
+  def test_pardo(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create(['a', 'bc'])
+             | beam.Map(lambda e: e * 2)
+             | beam.Map(lambda e: e + 'x'))
+      assert_that(res, equal_to(['aax', 'bcbcx']))
+
+  def test_pardo_metrics(self):
+
+    class MyDoFn(beam.DoFn):
+
+      def start_bundle(self):
+        self.count = Metrics.counter(self.__class__, 'elements')
+
+      def process(self, element):
+        self.count.inc(element)
+        return [element]
+
+    class MyOtherDoFn(beam.DoFn):
+
+      def start_bundle(self):
+        self.count = Metrics.counter(self.__class__, 'elementsplusone')
+
+      def process(self, element):
+        self.count.inc(element + 1)
+        return [element]
+
+    with self.create_pipeline() as p:
+      res = (p | beam.Create([1, 2, 3])
+             | 'mydofn' >> beam.ParDo(MyDoFn())
+             | 'myotherdofn' >> beam.ParDo(MyOtherDoFn()))
+      p.run()
+      if not MetricsEnvironment.METRICS_SUPPORTED:
+        self.skipTest('Metrics are not supported.')
+
+      counter_updates = [{'key': key, 'value': val}
+                         for container in p.runner.metrics_containers()
+                         for key, val in
+                         container.get_updates().counters.items()]
+      counter_values = [update['value'] for update in counter_updates]
+      counter_keys = [update['key'] for update in counter_updates]
+      assert_that(res, equal_to([1, 2, 3]))
+      self.assertEqual(counter_values, [6, 9])
+      self.assertEqual(counter_keys, [
+          MetricKey('mydofn',
+                    MetricName(__name__ + '.MyDoFn', 'elements')),
+          MetricKey('myotherdofn',
+                    MetricName(__name__ + '.MyOtherDoFn', 'elementsplusone'))])
+
+  def test_pardo_side_outputs(self):
+    def tee(elem, *tags):
+      for tag in tags:
+        if tag in elem:
+          yield beam.pvalue.OutputValue(tag, elem)
+    with self.create_pipeline() as p:
+      xy = (p
+            | 'Create' >> beam.Create(['x', 'y', 'xy'])
+            | beam.FlatMap(tee, 'x', 'y').with_outputs())
+      assert_that(xy.x, equal_to(['x', 'xy']), label='x')
+      assert_that(xy.y, equal_to(['y', 'xy']), label='y')
+
+  def test_pardo_side_and_main_outputs(self):
+    def even_odd(elem):
+      yield elem
+      yield beam.pvalue.OutputValue('odd' if elem % 2 else 'even', elem)
+    with self.create_pipeline() as p:
+      ints = p | beam.Create([1, 2, 3])
+      named = ints | 'named' >> beam.FlatMap(
+          even_odd).with_outputs('even', 'odd', main='all')
+      assert_that(named.all, equal_to([1, 2, 3]), label='named.all')
+      assert_that(named.even, equal_to([2]), label='named.even')
+      assert_that(named.odd, equal_to([1, 3]), label='named.odd')
+
+      unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
+      unnamed[None] | beam.Map(id)  # pylint: disable=expression-not-assigned
+      assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all')
+      assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
+      assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
+
+  def test_pardo_side_inputs(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+    with self.create_pipeline() as p:
+      main = p | 'main' >> beam.Create(['a', 'b', 'c'])
+      side = p | 'side' >> beam.Create(['x', 'y'])
+      assert_that(main | beam.FlatMap(cross_product, AsList(side)),
+                  equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
+                            ('a', 'y'), ('b', 'y'), ('c', 'y')]))
+
+  def test_pardo_unfusable_side_inputs(self):
+    def cross_product(elem, sides):
+      for side in sides:
+        yield elem, side
+    with self.create_pipeline() as p:
+      pcoll = p | beam.Create(['a', 'b'])
+      assert_that(pcoll | beam.FlatMap(cross_product, AsList(pcoll)),
+                  equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+    with self.create_pipeline() as p:
+      pcoll = p | beam.Create(['a', 'b'])
+      derived = ((pcoll,) | beam.Flatten()
+                 | beam.Map(lambda x: (x, x))
+                 | beam.GroupByKey()
+                 | 'Unkey' >> beam.Map(lambda (x, _): x))
+      assert_that(
+          pcoll | beam.FlatMap(cross_product, AsList(derived)),
+          equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+  def test_group_by_key(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+             | beam.GroupByKey()
+             | beam.Map(lambda (k, vs): (k, sorted(vs))))
+      assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
+
+  def test_flatten(self):
+    with self.create_pipeline() as p:
+      res = (p | 'a' >> beam.Create(['a']),
+             p | 'bc' >> beam.Create(['b', 'c']),
+             p | 'd' >> beam.Create(['d'])) | beam.Flatten()
+      assert_that(res, equal_to(['a', 'b', 'c', 'd']))
+
+  def test_combine_per_key(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+             | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
+      assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
+
+  def test_read(self):
+    with tempfile.NamedTemporaryFile() as temp_file:
+      temp_file.write('a\nb\nc')
+      temp_file.flush()
+      with self.create_pipeline() as p:
+        assert_that(p | beam.io.ReadFromText(temp_file.name),
+                    equal_to(['a', 'b', 'c']))
+
+  def test_windowing(self):
+    with self.create_pipeline() as p:
+      res = (p
+             | beam.Create([1, 2, 100, 101, 102])
+             | beam.Map(lambda t: TimestampedValue(('k', t), t))
+             | beam.WindowInto(beam.transforms.window.Sessions(10))
+             | beam.GroupByKey()
+             | beam.Map(lambda (k, vs): (k, sorted(vs))))
+      assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

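Note: these tests all follow the same pattern: build a small pipeline against
MapTaskExecutorRunner and check its output with assert_that. A condensed sketch of that pattern
outside of unittest (imports mirror the test module above):

    import apache_beam as beam
    from apache_beam.runners.portability import maptask_executor_runner
    from apache_beam.transforms.util import assert_that
    from apache_beam.transforms.util import equal_to

    runner = maptask_executor_runner.MapTaskExecutorRunner()
    with beam.Pipeline(runner=runner) as p:
      res = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * 10)
      assert_that(res, equal_to([10, 20, 30]))
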
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/__init__.py b/sdks/python/apache_beam/runners/worker/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
new file mode 100644
index 0000000..6425447
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implementation of DataChannels for communicating across the data plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import logging
+import Queue as queue
+import threading
+
+from apache_beam.coders import coder_impl
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class ClosableOutputStream(type(coder_impl.create_OutputStream())):
+  """A Outputstream for use with CoderImpls that has a close() method."""
+
+  def __init__(self, close_callback=None):
+    super(ClosableOutputStream, self).__init__()
+    self._close_callback = close_callback
+
+  def close(self):
+    if self._close_callback:
+      self._close_callback(self.get())
+
+
+class DataChannel(object):
+  """Represents a channel for reading and writing data over the data plane.
+
+  Read from this channel with the input_elements method::
+
+    for elements_data in data_channel.input_elements(instruction_id, targets):
+      [process elements_data]
+
+  Write to this channel using the output_stream method::
+
+    out1 = data_channel.output_stream(instruction_id, target1)
+    out1.write(...)
+    out1.close()
+
+  When all data for all instructions is written, close the channel::
+
+    data_channel.close()
+  """
+
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractmethod
+  def input_elements(self, instruction_id, expected_targets):
+    """Returns an iterable of all Element.Data bundles for instruction_id.
+
+    This iterable terminates only once the full set of data has been received
+    for each of the expected targets. It may block waiting for more data.
+
+    Args:
+        instruction_id: which instruction the results must belong to
+        expected_targets: which targets to wait on for completion
+    """
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def output_stream(self, instruction_id, target):
+    """Returns an output stream writing elements to target.
+
+    Args:
+        instruction_id: which instruction this stream belongs to
+        target: the target of the returned stream
+    """
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def close(self):
+    """Closes this channel, indicating that all data has been written.
+
+    Data can continue to be read.
+
+    If this channel is shared by many instructions, it should only be called
+    on worker shutdown.
+    """
+    raise NotImplementedError(type(self))
+
+
+class InMemoryDataChannel(DataChannel):
+  """An in-memory implementation of a DataChannel.
+
+  This channel is two-sided.  What is written to one side is read by the other.
+  The inverse() method returns the other side of an instance.
+  """
+
+  def __init__(self, inverse=None):
+    self._inputs = []
+    self._inverse = inverse or InMemoryDataChannel(self)
+
+  def inverse(self):
+    return self._inverse
+
+  def input_elements(self, instruction_id, unused_expected_targets=None):
+    for data in self._inputs:
+      if data.instruction_reference == instruction_id:
+        yield data
+
+  def output_stream(self, instruction_id, target):
+    def add_to_inverse_output(data):
+      self._inverse._inputs.append(  # pylint: disable=protected-access
+          beam_fn_api_pb2.Elements.Data(
+              instruction_reference=instruction_id,
+              target=target,
+              data=data))
+    return ClosableOutputStream(add_to_inverse_output)
+
+  def close(self):
+    pass
+
+
+class _GrpcDataChannel(DataChannel):
+  """Base class for implementing a BeamFnData-based DataChannel."""
+
+  _WRITES_FINISHED = object()
+
+  def __init__(self):
+    self._to_send = queue.Queue()
+    self._received = collections.defaultdict(queue.Queue)
+    self._receive_lock = threading.Lock()
+    self._reads_finished = threading.Event()
+
+  def close(self):
+    self._to_send.put(self._WRITES_FINISHED)
+
+  def wait(self, timeout=None):
+    self._reads_finished.wait(timeout)
+
+  def _receiving_queue(self, instruction_id):
+    with self._receive_lock:
+      return self._received[instruction_id]
+
+  def input_elements(self, instruction_id, expected_targets):
+    received = self._receiving_queue(instruction_id)
+    done_targets = []
+    while len(done_targets) < len(expected_targets):
+      data = received.get()
+      if not data.data and data.target in expected_targets:
+        done_targets.append(data.target)
+      else:
+        assert data.target not in done_targets
+        yield data
+
+  def output_stream(self, instruction_id, target):
+    def add_to_send_queue(data):
+      self._to_send.put(
+          beam_fn_api_pb2.Elements.Data(
+              instruction_reference=instruction_id,
+              target=target,
+              data=data))
+      self._to_send.put(
+          beam_fn_api_pb2.Elements.Data(
+              instruction_reference=instruction_id,
+              target=target,
+              data=''))
+    return ClosableOutputStream(add_to_send_queue)
+
+  def _write_outputs(self):
+    done = False
+    while not done:
+      data = [self._to_send.get()]
+      try:
+        # Coalesce up to 100 other items.
+        for _ in range(100):
+          data.append(self._to_send.get_nowait())
+      except queue.Empty:
+        pass
+      if data[-1] is self._WRITES_FINISHED:
+        done = True
+        data.pop()
+      if data:
+        yield beam_fn_api_pb2.Elements(data=data)
+
+  def _read_inputs(self, elements_iterator):
+    # TODO(robertwb): Pushback/throttling to avoid unbounded buffering.
+    try:
+      for elements in elements_iterator:
+        for data in elements.data:
+          self._receiving_queue(data.instruction_reference).put(data)
+    except:  # pylint: disable=broad-except
+      logging.exception('Failed to read inputs in the data plane')
+      raise
+    finally:
+      self._reads_finished.set()
+
+  def _start_reader(self, elements_iterator):
+    reader = threading.Thread(
+        target=lambda: self._read_inputs(elements_iterator),
+        name='read_grpc_client_inputs')
+    reader.daemon = True
+    reader.start()
+
+
+class GrpcClientDataChannel(_GrpcDataChannel):
+  """A DataChannel wrapping the client side of a BeamFnData connection."""
+
+  def __init__(self, data_stub):
+    super(GrpcClientDataChannel, self).__init__()
+    self._start_reader(data_stub.Data(self._write_outputs()))
+
+
+class GrpcServerDataChannel(
+    beam_fn_api_pb2.BeamFnDataServicer, _GrpcDataChannel):
+  """A DataChannel wrapping the server side of a BeamFnData connection."""
+
+  def Data(self, elements_iterator, context):
+    self._start_reader(elements_iterator)
+    for elements in self._write_outputs():
+      yield elements
+
+
+class DataChannelFactory(object):
+  """An abstract factory for creating ``DataChannel``."""
+
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractmethod
+  def create_data_channel(self, function_spec):
+    """Returns a ``DataChannel`` from the given function_spec."""
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def close(self):
+    """Close all channels that this factory owns."""
+    raise NotImplementedError(type(self))
+
+
+class GrpcClientDataChannelFactory(DataChannelFactory):
+  """A factory for ``GrpcClientDataChannel``.
+
+  Caches the created channels by ``data descriptor url``.
+  """
+
+  def __init__(self):
+    self._data_channel_cache = {}
+
+  def create_data_channel(self, function_spec):
+    remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+    function_spec.data.Unpack(remote_grpc_port)
+    url = remote_grpc_port.api_service_descriptor.url
+    if url not in self._data_channel_cache:
+      logging.info('Creating channel for %s', url)
+      grpc_channel = grpc.insecure_channel(url)
+      self._data_channel_cache[url] = GrpcClientDataChannel(
+          beam_fn_api_pb2.BeamFnDataStub(grpc_channel))
+    return self._data_channel_cache[url]
+
+  def close(self):
+    logging.info('Closing all cached grpc data channels.')
+    for _, channel in self._data_channel_cache.items():
+      channel.close()
+    self._data_channel_cache.clear()
+
+
+class InMemoryDataChannelFactory(DataChannelFactory):
+  """A singleton factory for ``InMemoryDataChannel``."""
+
+  def __init__(self, in_memory_data_channel):
+    self._in_memory_data_channel = in_memory_data_channel
+
+  def create_data_channel(self, unused_function_spec):
+    return self._in_memory_data_channel
+
+  def close(self):
+    pass

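Note: as a quick illustration of the two-sided InMemoryDataChannel defined above, data written to
one side's output_stream becomes readable from the inverse side's input_elements. A sketch
(assumes the data_plane module and generated beam_fn_api_pb2 protos from this patch; the
instruction id and payload are made up):

    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import data_plane

    channel = data_plane.InMemoryDataChannel()
    target = beam_fn_api_pb2.Target(
        primitive_transform_reference='1', name='out')
    out = channel.output_stream('instruction-0', target)
    out.write('encoded elements')   # raw bytes, already encoded by a coder
    out.close()
    for data in channel.inverse().input_elements('instruction-0', [target]):
      print(data.data)              # 'encoded elements'
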
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py
new file mode 100644
index 0000000..7340789
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.runners.worker.data_plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import sys
+import threading
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import data_plane
+
+
+def timeout(timeout_secs):
+  def decorate(fn):
+    exc_info = []
+
+    def wrapper(*args, **kwargs):
+      def call_fn():
+        try:
+          fn(*args, **kwargs)
+        except:  # pylint: disable=bare-except
+          exc_info[:] = sys.exc_info()
+      thread = threading.Thread(target=call_fn)
+      thread.daemon = True
+      thread.start()
+      thread.join(timeout_secs)
+      if exc_info:
+        t, v, tb = exc_info  # pylint: disable=unbalanced-tuple-unpacking
+        raise t, v, tb
+      assert not thread.is_alive(), 'timed out after %s seconds' % timeout_secs
+    return wrapper
+  return decorate
+
+
+class DataChannelTest(unittest.TestCase):
+
+  @timeout(5)
+  def test_grpc_data_channel(self):
+    data_channel_service = data_plane.GrpcServerDataChannel()
+
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+    beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+        data_channel_service, server)
+    test_port = server.add_insecure_port('[::]:0')
+    server.start()
+
+    data_channel_stub = beam_fn_api_pb2.BeamFnDataStub(
+        grpc.insecure_channel('localhost:%s' % test_port))
+    data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)
+
+    try:
+      self._data_channel_test(data_channel_service, data_channel_client)
+    finally:
+      data_channel_client.close()
+      data_channel_service.close()
+      data_channel_client.wait()
+      data_channel_service.wait()
+
+  def test_in_memory_data_channel(self):
+    channel = data_plane.InMemoryDataChannel()
+    self._data_channel_test(channel, channel.inverse())
+
+  def _data_channel_test(self, server, client):
+    self._data_channel_test_one_direction(server, client)
+    self._data_channel_test_one_direction(client, server)
+
+  def _data_channel_test_one_direction(self, from_channel, to_channel):
+    def send(instruction_id, target, data):
+      stream = from_channel.output_stream(instruction_id, target)
+      stream.write(data)
+      stream.close()
+    target_1 = beam_fn_api_pb2.Target(
+        primitive_transform_reference='1',
+        name='out')
+    target_2 = beam_fn_api_pb2.Target(
+        primitive_transform_reference='2',
+        name='out')
+
+    # Single write.
+    send('0', target_1, 'abc')
+    self.assertEqual(
+        list(to_channel.input_elements('0', [target_1])),
+        [beam_fn_api_pb2.Elements.Data(
+            instruction_reference='0',
+            target=target_1,
+            data='abc')])
+
+    # Multiple interleaved writes to multiple instructions.
+    target_2 = beam_fn_api_pb2.Target(
+        primitive_transform_reference='2',
+        name='out')
+
+    send('1', target_1, 'abc')
+    send('2', target_1, 'def')
+    self.assertEqual(
+        list(to_channel.input_elements('1', [target_1])),
+        [beam_fn_api_pb2.Elements.Data(
+            instruction_reference='1',
+            target=target_1,
+            data='abc')])
+    send('2', target_2, 'ghi')
+    self.assertEqual(
+        list(to_channel.input_elements('2', [target_1, target_2])),
+        [beam_fn_api_pb2.Elements.Data(
+            instruction_reference='2',
+            target=target_1,
+            data='def'),
+         beam_fn_api_pb2.Elements.Data(
+             instruction_reference='2',
+             target=target_2,
+             data='ghi')])
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py
new file mode 100644
index 0000000..b9e36ad
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Beam fn API log handler."""
+
+import logging
+import math
+import Queue as queue
+import threading
+
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class FnApiLogRecordHandler(logging.Handler):
+  """A handler that writes log records to the fn API."""
+
+  # Maximum number of log entries in a single stream request.
+  _MAX_BATCH_SIZE = 1000
+  # Used to indicate the end of stream.
+  _FINISHED = object()
+
+  # Mapping from logging levels to LogEntry levels.
+  LOG_LEVEL_MAP = {
+      logging.FATAL: beam_fn_api_pb2.LogEntry.CRITICAL,
+      logging.ERROR: beam_fn_api_pb2.LogEntry.ERROR,
+      logging.WARNING: beam_fn_api_pb2.LogEntry.WARN,
+      logging.INFO: beam_fn_api_pb2.LogEntry.INFO,
+      logging.DEBUG: beam_fn_api_pb2.LogEntry.DEBUG
+  }
+
+  def __init__(self, log_service_descriptor):
+    super(FnApiLogRecordHandler, self).__init__()
+    self._log_channel = grpc.insecure_channel(log_service_descriptor.url)
+    self._logging_stub = beam_fn_api_pb2.BeamFnLoggingStub(self._log_channel)
+    self._log_entry_queue = queue.Queue()
+
+    log_control_messages = self._logging_stub.Logging(self._write_log_entries())
+    self._reader = threading.Thread(
+        target=lambda: self._read_log_control_messages(log_control_messages),
+        name='read_log_control_messages')
+    self._reader.daemon = True
+    self._reader.start()
+
+  def emit(self, record):
+    log_entry = beam_fn_api_pb2.LogEntry()
+    log_entry.severity = self.LOG_LEVEL_MAP[record.levelno]
+    log_entry.message = self.format(record)
+    log_entry.thread = record.threadName
+    log_entry.log_location = record.module + '.' + record.funcName
+    (fraction, seconds) = math.modf(record.created)
+    nanoseconds = 1e9 * fraction
+    log_entry.timestamp.seconds = int(seconds)
+    log_entry.timestamp.nanos = int(nanoseconds)
+    self._log_entry_queue.put(log_entry)
+
+  def close(self):
+    """Flush out all existing log entries and unregister this handler."""
+    # Acquiring the handler lock ensures ``emit`` is not run until the lock is
+    # released.
+    self.acquire()
+    self._log_entry_queue.put(self._FINISHED)
+    # Wait for the server to close the stream.
+    self._reader.join()
+    self.release()
+    # Unregister this handler.
+    super(FnApiLogRecordHandler, self).close()
+
+  def _write_log_entries(self):
+    done = False
+    while not done:
+      log_entries = [self._log_entry_queue.get()]
+      try:
+        for _ in range(self._MAX_BATCH_SIZE):
+          log_entries.append(self._log_entry_queue.get_nowait())
+      except queue.Empty:
+        pass
+      if log_entries[-1] is self._FINISHED:
+        done = True
+        log_entries.pop()
+      if log_entries:
+        yield beam_fn_api_pb2.LogEntry.List(log_entries=log_entries)
+
+  def _read_log_control_messages(self, log_control_iterator):
+    # TODO(vikasrk): Handle control messages.
+    for _ in log_control_iterator:
+      pass

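Note: the handler above forwards batched LogEntry messages over a BeamFnLogging stream. Wiring it
up looks roughly as follows; this is a sketch that assumes a BeamFnLogging service is already
listening at the given address (which is made up here):

    import logging

    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import log_handler

    descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
    descriptor.url = 'localhost:12345'   # assumed address of a running logging service
    handler = log_handler.FnApiLogRecordHandler(descriptor)
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(handler)
    logging.info('This record is forwarded over the Fn API.')
    handler.close()   # flushes queued entries and waits for the stream to finish
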
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
new file mode 100644
index 0000000..565bedb
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -0,0 +1,105 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import logging
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import log_handler
+
+
+class BeamFnLoggingServicer(beam_fn_api_pb2.BeamFnLoggingServicer):
+
+  def __init__(self):
+    self.log_records_received = []
+
+  def Logging(self, request_iterator, context):
+
+    for log_record in request_iterator:
+      self.log_records_received.append(log_record)
+
+    yield beam_fn_api_pb2.LogControl()
+
+
+class FnApiLogRecordHandlerTest(unittest.TestCase):
+
+  def setUp(self):
+    self.test_logging_service = BeamFnLoggingServicer()
+    self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    beam_fn_api_pb2.add_BeamFnLoggingServicer_to_server(
+        self.test_logging_service, self.server)
+    self.test_port = self.server.add_insecure_port('[::]:0')
+    self.server.start()
+
+    self.logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+    self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
+    self.fn_log_handler = log_handler.FnApiLogRecordHandler(
+        self.logging_service_descriptor)
+    logging.getLogger().setLevel(logging.INFO)
+    logging.getLogger().addHandler(self.fn_log_handler)
+
+  def tearDown(self):
+    # Wait up to 5 seconds for the server to stop.
+    self.server.stop(5)
+
+  def _verify_fn_log_handler(self, num_log_entries):
+    msg = 'Testing fn logging'
+    logging.debug('Debug Message 1')
+    for idx in range(num_log_entries):
+      logging.info('%s: %s', msg, idx)
+    logging.debug('Debug Message 2')
+
+    # Wait for logs to be sent to server.
+    self.fn_log_handler.close()
+
+    num_received_log_entries = 0
+    for outer in self.test_logging_service.log_records_received:
+      for log_entry in outer.log_entries:
+        self.assertEqual(beam_fn_api_pb2.LogEntry.INFO, log_entry.severity)
+        self.assertEqual('%s: %s' % (msg, num_received_log_entries),
+                         log_entry.message)
+        self.assertEqual(u'log_handler_test._verify_fn_log_handler',
+                         log_entry.log_location)
+        self.assertGreater(log_entry.timestamp.seconds, 0)
+        self.assertGreater(log_entry.timestamp.nanos, 0)
+        num_received_log_entries += 1
+
+    self.assertEqual(num_received_log_entries, num_log_entries)
+
+
+# Test cases.
+data = {
+    'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
+    'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
+    'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
+}
+
+
+def _create_test(name, num_logs):
+  setattr(FnApiLogRecordHandlerTest, 'test_%s' % name,
+          lambda self: self._verify_fn_log_handler(num_logs))
+
+
+if __name__ == '__main__':
+  for test_name, num_logs_entries in data.iteritems():
+    _create_test(test_name, num_logs_entries)
+
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.pxd b/sdks/python/apache_beam/runners/worker/logger.pxd
new file mode 100644
index 0000000..201daf4
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.pxd
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport LoggingContext
+
+
+cdef class PerThreadLoggingContext(LoggingContext):
+  cdef kwargs
+  cdef list stack

http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py
new file mode 100644
index 0000000..217dc58
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Python worker logging."""
+
+import json
+import logging
+import threading
+import traceback
+
+from apache_beam.runners.common import LoggingContext
+
+
+# Per-thread worker information. This is used only for logging to set
+# context information that changes while work items get executed:
+# work_item_id, step_name, stage_name.
+class _PerThreadWorkerData(threading.local):
+
+  def __init__(self):
+    super(_PerThreadWorkerData, self).__init__()
+    # TODO(robertwb): Consider starting with an initial (ignored) ~20 elements
+    # in the list, as going up and down all the way to zero incurs several
+    # reallocations.
+    self.stack = []
+
+  def get_data(self):
+    all_data = {}
+    for datum in self.stack:
+      all_data.update(datum)
+    return all_data
+
+
+per_thread_worker_data = _PerThreadWorkerData()
+
+
+class PerThreadLoggingContext(LoggingContext):
+  """A context manager to add per thread attributes."""
+
+  def __init__(self, **kwargs):
+    self.kwargs = kwargs
+    self.stack = per_thread_worker_data.stack
+
+  def __enter__(self):
+    self.enter()
+
+  def enter(self):
+    self.stack.append(self.kwargs)
+
+  def __exit__(self, exn_type, exn_value, exn_traceback):
+    self.exit()
+
+  def exit(self):
+    self.stack.pop()
+
+
+class JsonLogFormatter(logging.Formatter):
+  """A JSON formatter class as expected by the logging standard module."""
+
+  def __init__(self, job_id, worker_id):
+    super(JsonLogFormatter, self).__init__()
+    self.job_id = job_id
+    self.worker_id = worker_id
+
+  def format(self, record):
+    """Returns a JSON string based on a LogRecord instance.
+
+    Args:
+      record: A LogRecord instance. See below for details.
+
+    Returns:
+      A JSON string representing the record.
+
+    A LogRecord instance has the following attributes and is used for
+    formatting the final message.
+
+    Attributes:
+      created: A double representing the timestamp for record creation
+        (e.g., 1438365207.624597). Note that the number also contains msecs and
+        microsecs information. Part of this is also available in the 'msecs'
+        attribute.
+      msecs: A double representing the msecs part of the record creation
+        (e.g., 624.5970726013184).
+      msg: Logging message containing formatting instructions or an arbitrary
+        object. This is the first argument of a log call.
+      args: A tuple containing the positional arguments for the logging call.
+      levelname: A string. Possible values are: INFO, WARNING, ERROR, etc.
+      exc_info: None or a 3-tuple with exception information as it is
+        returned by a call to sys.exc_info().
+      name: Logger's name. Most logging is done using the default root logger
+        and therefore the name will be 'root'.
+      filename: Basename of the file where logging occurred.
+      funcName: Name of the function where logging occurred.
+      process: The PID of the process running the worker.
+      thread:  An id for the thread where the record was logged. This is not a
+        real TID (the one provided by OS) but rather the id (address) of a
+        Python thread object. Nevertheless, having this value makes it
+        possible to filter log statements from one specific thread.
+    """
+    output = {}
+    output['timestamp'] = {
+        'seconds': int(record.created),
+        'nanos': int(record.msecs * 1000000)}
+    # ERROR, INFO and DEBUG map to the same names for the severity property;
+    # WARNING becomes WARN.
+    output['severity'] = (
+        record.levelname if record.levelname != 'WARNING' else 'WARN')
+
+    # msg could be an arbitrary object, convert it to a string first.
+    record_msg = str(record.msg)
+
+    # Prepare the actual message using the message formatting string and the
+    # positional arguments as they have been used in the log call.
+    if record.args:
+      try:
+        output['message'] = record_msg % record.args
+      except (TypeError, ValueError):
+        output['message'] = '%s with args (%s)' % (record_msg, record.args)
+    else:
+      output['message'] = record_msg
+
+    # The thread ID is logged as a combination of the process ID and thread ID
+    # since workers can run in multiple processes.
+    output['thread'] = '%s:%s' % (record.process, record.thread)
+    # Job ID and worker ID. These do not change during the lifetime of a worker.
+    output['job'] = self.job_id
+    output['worker'] = self.worker_id
+    # Stage, step and work item ID come from thread local storage since they
+    # change with every new work item leased for execution. If there is no
+    # work item ID then we make sure the step is undefined too.
+    data = per_thread_worker_data.get_data()
+    if 'work_item_id' in data:
+      output['work'] = data['work_item_id']
+    if 'stage_name' in data:
+      output['stage'] = data['stage_name']
+    if 'step_name' in data:
+      output['step'] = data['step_name']
+    # All logging happens using the root logger. We will add the basename of the
+    # file and the function name where the logging happened to make it easier
+    # to identify who generated the record.
+    output['logger'] = '%s:%s:%s' % (
+        record.name, record.filename, record.funcName)
+    # Add exception information if any is available.
+    if record.exc_info:
+      output['exception'] = ''.join(
+          traceback.format_exception(*record.exc_info))
+
+    return json.dumps(output)
+
+
+def initialize(job_id, worker_id, log_path):
+  """Initialize root logger so that we log JSON to a file and text to stdout."""
+
+  file_handler = logging.FileHandler(log_path)
+  file_handler.setFormatter(JsonLogFormatter(job_id, worker_id))
+  logging.getLogger().addHandler(file_handler)
+
+  # Set default level to INFO to avoid logging various DEBUG level log calls
+  # sprinkled throughout the code.
+  logging.getLogger().setLevel(logging.INFO)
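+
Note: putting the pieces of this module together, the JSON formatter picks up whatever the current
PerThreadLoggingContext has pushed onto the per-thread stack. A small sketch (the job/worker ids
and context values are made up):

    import logging

    from apache_beam.runners.worker import logger

    handler = logging.StreamHandler()
    handler.setFormatter(
        logger.JsonLogFormatter(job_id='job-1', worker_id='worker-1'))
    logging.getLogger().addHandler(handler)

    with logger.PerThreadLoggingContext(
        work_item_id='w-1', stage_name='S01', step_name='s1'):
      # The emitted JSON includes "work", "stage" and "step" fields.
      logging.warning('inside a work item')
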