Posted to commits@beam.apache.org by ro...@apache.org on 2017/05/02 00:45:02 UTC
[1/4] beam git commit: Fn API support for Python.
Repository: beam
Updated Branches:
refs/heads/master b8131fe93 -> 7c425b097
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
new file mode 100644
index 0000000..996f44c
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py
@@ -0,0 +1,168 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.runners.worker.sdk_worker."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.io.concat_source_test import RangeSource
+from apache_beam.io.iobase import SourceBundle
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import sdk_worker
+
+
+class BeamFnControlServicer(beam_fn_api_pb2.BeamFnControlServicer):
+
+ def __init__(self, requests, raise_errors=True):
+ self.requests = requests
+ self.instruction_ids = set(r.instruction_id for r in requests)
+ self.responses = {}
+ self.raise_errors = raise_errors
+
+ def Control(self, response_iterator, context):
+ for request in self.requests:
+ logging.info("Sending request %s", request)
+ yield request
+ for response in response_iterator:
+ logging.info("Got response %s", response)
+ if response.instruction_id != -1:
+ assert response.instruction_id in self.instruction_ids
+ assert response.instruction_id not in self.responses
+ self.responses[response.instruction_id] = response
+ if self.raise_errors and response.error:
+ raise RuntimeError(response.error)
+ elif len(self.responses) == len(self.requests):
+ logging.info("All %s instructions finished.", len(self.requests))
+ return
+ raise RuntimeError("Missing responses: %s" %
+ (self.instruction_ids - set(self.responses.keys())))
+
+
+class SdkWorkerTest(unittest.TestCase):
+
+ def test_fn_registration(self):
+ fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)]
+
+ process_bundle_descriptors = [beam_fn_api_pb2.ProcessBundleDescriptor(
+ id=str(100+ix),
+ primitive_transform=[
+ beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)])
+ for ix, fn in enumerate(fns)]
+
+ test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest(
+ register=beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=process_bundle_descriptors))])
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
+ test_port = server.add_insecure_port("[::]:0")
+ server.start()
+
+ channel = grpc.insecure_channel("localhost:%s" % test_port)
+ harness = sdk_worker.SdkHarness(channel)
+ harness.run()
+ self.assertEqual(
+ harness.worker.fns,
+ {item.id: item for item in fns + process_bundle_descriptors})
+
+ @unittest.skip("initial splitting not in proto")
+ def test_source_split(self):
+ source = RangeSource(0, 100)
+ expected_splits = list(source.split(30))
+
+ worker = sdk_worker.SdkWorker(
+ None, data_plane.GrpcClientDataChannelFactory())
+ worker.register(
+ beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=[beam_fn_api_pb2.ProcessBundleDescriptor(
+ primitive_transform=[beam_fn_api_pb2.PrimitiveTransform(
+ function_spec=sdk_worker.serialize_and_pack_py_fn(
+ SourceBundle(1.0, source, None, None),
+ sdk_worker.PYTHON_SOURCE_URN,
+ id="src"))])]))
+ split_response = worker.initial_source_split(
+ beam_fn_api_pb2.InitialSourceSplitRequest(
+ desired_bundle_size_bytes=30,
+ source_reference="src"))
+
+ self.assertEqual(
+ expected_splits,
+ [sdk_worker.unpack_and_deserialize_py_fn(s.source)
+ for s in split_response.splits])
+
+ self.assertEqual(
+ [s.weight for s in expected_splits],
+ [s.relative_size for s in split_response.splits])
+
+ @unittest.skip("initial splitting not in proto")
+ def test_source_split_via_instruction(self):
+
+ source = RangeSource(0, 100)
+ expected_splits = list(source.split(30))
+
+ test_controller = BeamFnControlServicer([
+ beam_fn_api_pb2.InstructionRequest(
+ instruction_id="register_request",
+ register=beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=[
+ beam_fn_api_pb2.ProcessBundleDescriptor(
+ primitive_transform=[beam_fn_api_pb2.PrimitiveTransform(
+ function_spec=sdk_worker.serialize_and_pack_py_fn(
+ SourceBundle(1.0, source, None, None),
+ sdk_worker.PYTHON_SOURCE_URN,
+ id="src"))])])),
+ beam_fn_api_pb2.InstructionRequest(
+ instruction_id="split_request",
+ initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
+ desired_bundle_size_bytes=30,
+ source_reference="src"))
+ ])
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
+ test_port = server.add_insecure_port("[::]:0")
+ server.start()
+
+ channel = grpc.insecure_channel("localhost:%s" % test_port)
+ harness = sdk_worker.SdkHarness(channel)
+ harness.run()
+
+ split_response = test_controller.responses[
+ "split_request"].initial_source_split
+
+ self.assertEqual(
+ expected_splits,
+ [sdk_worker.unpack_and_deserialize_py_fn(s.source)
+ for s in split_response.splits])
+
+ self.assertEqual(
+ [s.weight for s in expected_splits],
+ [s.relative_size for s in split_response.splits])
+
+
+if __name__ == "__main__":
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sideinputs.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sideinputs.py b/sdks/python/apache_beam/runners/worker/sideinputs.py
new file mode 100644
index 0000000..3bac3d9
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sideinputs.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Utilities for handling side inputs."""
+
+import collections
+import logging
+import Queue
+import threading
+import traceback
+
+from apache_beam.io import iobase
+from apache_beam.transforms import window
+
+
+# Maximum number of reader threads for reading side input sources, per side
+# input.
+MAX_SOURCE_READER_THREADS = 15
+
+# Number of slots for elements in side input element queue. Note that this
+# value is intentionally smaller than MAX_SOURCE_READER_THREADS so as to reduce
+# memory pressure of holding potentially-large elements in memory. Note that
+# the number of pending elements in memory is equal to the sum of
+# MAX_SOURCE_READER_THREADS and ELEMENT_QUEUE_SIZE.
+ELEMENT_QUEUE_SIZE = 10
+
+# Special element value sentinel for signaling reader state.
+READER_THREAD_IS_DONE_SENTINEL = object()
+
+# Used to efficiently window the values of non-windowed side inputs.
+_globally_windowed = window.GlobalWindows.windowed_value(None).with_value
+
+
+class PrefetchingSourceSetIterable(object):
+ """Value iterator that reads concurrently from a set of sources."""
+
+ def __init__(self, sources,
+ max_reader_threads=MAX_SOURCE_READER_THREADS):
+ self.sources = sources
+ self.num_reader_threads = min(max_reader_threads, len(self.sources))
+
+ # Queue for sources that are to be read.
+ self.sources_queue = Queue.Queue()
+ for source in sources:
+ self.sources_queue.put(source)
+ # Queue for elements that have been read.
+ self.element_queue = Queue.Queue(ELEMENT_QUEUE_SIZE)
+ # Queue for exceptions encountered in reader threads; to be rethrown.
+ self.reader_exceptions = Queue.Queue()
+ # Whether we have already iterated; this iterable can only be used once.
+ self.already_iterated = False
+ # Whether an error was encountered in any source reader.
+ self.has_errored = False
+
+ self.reader_threads = []
+ self._start_reader_threads()
+
+ def _start_reader_threads(self):
+ for _ in range(0, self.num_reader_threads):
+ t = threading.Thread(target=self._reader_thread)
+ t.daemon = True
+ t.start()
+ self.reader_threads.append(t)
+
+ def _reader_thread(self):
+ # pylint: disable=too-many-nested-blocks
+ try:
+ while True:
+ try:
+ source = self.sources_queue.get_nowait()
+ if isinstance(source, iobase.BoundedSource):
+ for value in source.read(source.get_range_tracker(None, None)):
+ if self.has_errored:
+ # If any reader has errored, just return.
+ return
+ if isinstance(value, window.WindowedValue):
+ self.element_queue.put(value)
+ else:
+ self.element_queue.put(_globally_windowed(value))
+ else:
+ # Native dataflow source.
+ with source.reader() as reader:
+ returns_windowed_values = reader.returns_windowed_values
+ for value in reader:
+ if self.has_errored:
+ # If any reader has errored, just return.
+ return
+ if returns_windowed_values:
+ self.element_queue.put(value)
+ else:
+ self.element_queue.put(_globally_windowed(value))
+ except Queue.Empty:
+ return
+ except Exception as e: # pylint: disable=broad-except
+ logging.error('Encountered exception in PrefetchingSourceSetIterable '
+ 'reader thread: %s', traceback.format_exc())
+ self.reader_exceptions.put(e)
+ self.has_errored = True
+ finally:
+ self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
+
+ def __iter__(self):
+ if self.already_iterated:
+ raise RuntimeError(
+ 'Can only iterate once over PrefetchingSourceSetIterable instance.')
+ self.already_iterated = True
+
+ # The invariants during execution are:
+ # 1) A worker thread always posts the sentinel as the last thing it does
+ # before exiting.
+ # 2) We always wait for all sentinels and then join all threads.
+ num_readers_finished = 0
+ try:
+ while True:
+ element = self.element_queue.get()
+ if element is READER_THREAD_IS_DONE_SENTINEL:
+ num_readers_finished += 1
+ if num_readers_finished == self.num_reader_threads:
+ return
+ elif self.has_errored:
+ raise self.reader_exceptions.get()
+ else:
+ yield element
+ except GeneratorExit:
+ self.has_errored = True
+ raise
+ finally:
+ while num_readers_finished < self.num_reader_threads:
+ element = self.element_queue.get()
+ if element is READER_THREAD_IS_DONE_SENTINEL:
+ num_readers_finished += 1
+ for t in self.reader_threads:
+ t.join()
+
+
+def get_iterator_fn_for_sources(
+ sources, max_reader_threads=MAX_SOURCE_READER_THREADS):
+ """Returns callable that returns iterator over elements for given sources."""
+ def _inner():
+ return iter(PrefetchingSourceSetIterable(
+ sources, max_reader_threads=max_reader_threads))
+ return _inner
+
+
+class EmulatedIterable(collections.Iterable):
+ """Emulates an iterable for a side input."""
+
+ def __init__(self, iterator_fn):
+ self.iterator_fn = iterator_fn
+
+ def __iter__(self):
+ return self.iterator_fn()
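
As a quick orientation before the tests, here is how these pieces are meant to compose: get_iterator_fn_for_sources() wraps a set of sources in a PrefetchingSourceSetIterable, and EmulatedIterable re-invokes that factory on every iteration so a side input can be traversed more than once. A minimal, self-contained sketch (the _ListSource/_ListReader pair is a hypothetical stand-in for a real source, mirroring the FakeSource used in sideinputs_test.py below; nothing here is new API):

    from apache_beam.runners.worker import sideinputs

    class _ListReader(object):
      # Minimal reader protocol expected by the non-BoundedSource branch.
      returns_windowed_values = False
      def __init__(self, items):
        self.items = items
      def __enter__(self):
        return self
      def __exit__(self, *unused_exc_info):
        pass
      def __iter__(self):
        return iter(self.items)

    class _ListSource(object):
      def __init__(self, items):
        self.items = items
      def reader(self):
        return _ListReader(self.items)

    iterator_fn = sideinputs.get_iterator_fn_for_sources(
        [_ListSource([1, 2, 3]), _ListSource([4, 5])], max_reader_threads=2)
    side_input = sideinputs.EmulatedIterable(iterator_fn)
    # Each iteration re-runs iterator_fn, so the side input is reusable;
    # elements arrive as (globally windowed) WindowedValues.
    values = sorted(wv.value for wv in side_input)  # [1, 2, 3, 4, 5]
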
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sideinputs_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sideinputs_test.py b/sdks/python/apache_beam/runners/worker/sideinputs_test.py
new file mode 100644
index 0000000..d243bbe
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sideinputs_test.py
@@ -0,0 +1,150 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for side input utilities."""
+
+import logging
+import time
+import unittest
+
+from apache_beam.runners.worker import sideinputs
+
+
+def strip_windows(iterator):
+ return [wv.value for wv in iterator]
+
+
+class FakeSource(object):
+
+ def __init__(self, items):
+ self.items = items
+
+ def reader(self):
+ return FakeSourceReader(self.items)
+
+
+class FakeSourceReader(object):
+
+ def __init__(self, items):
+ self.items = items
+ self.entered = False
+ self.exited = False
+
+ def __iter__(self):
+ return iter(self.items)
+
+ def __enter__(self):
+ self.entered = True
+ return self
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self.exited = True
+
+ @property
+ def returns_windowed_values(self):
+ return False
+
+
+class PrefetchingSourceIteratorTest(unittest.TestCase):
+
+ def test_single_source_iterator_fn(self):
+ sources = [
+ FakeSource([0, 1, 2, 3, 4, 5]),
+ ]
+ iterator_fn = sideinputs.get_iterator_fn_for_sources(
+ sources, max_reader_threads=2)
+ assert list(strip_windows(iterator_fn())) == range(6)
+
+ def test_multiple_sources_iterator_fn(self):
+ sources = [
+ FakeSource([0]),
+ FakeSource([1, 2, 3, 4, 5]),
+ FakeSource([]),
+ FakeSource([6, 7, 8, 9, 10]),
+ ]
+ iterator_fn = sideinputs.get_iterator_fn_for_sources(
+ sources, max_reader_threads=3)
+ assert sorted(strip_windows(iterator_fn())) == range(11)
+
+ def test_multiple_sources_single_reader_iterator_fn(self):
+ sources = [
+ FakeSource([0]),
+ FakeSource([1, 2, 3, 4, 5]),
+ FakeSource([]),
+ FakeSource([6, 7, 8, 9, 10]),
+ ]
+ iterator_fn = sideinputs.get_iterator_fn_for_sources(
+ sources, max_reader_threads=1)
+ assert list(strip_windows(iterator_fn())) == range(11)
+
+ def test_source_iterator_fn_exception(self):
+ class MyException(Exception):
+ pass
+
+ def exception_generator():
+ yield 0
+ time.sleep(0.1)
+ raise MyException('I am an exception!')
+
+ def perpetual_generator(value):
+ while True:
+ yield value
+
+ sources = [
+ FakeSource(perpetual_generator(1)),
+ FakeSource(perpetual_generator(2)),
+ FakeSource(perpetual_generator(3)),
+ FakeSource(perpetual_generator(4)),
+ FakeSource(exception_generator()),
+ ]
+ iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
+ seen = set()
+ with self.assertRaises(MyException):
+ for value in iterator_fn():
+ seen.add(value.value)
+ self.assertEqual(sorted(seen), range(5))
+
+
+class EmulatedCollectionsTest(unittest.TestCase):
+
+ def test_emulated_iterable(self):
+ def _iterable_fn():
+ for i in range(10):
+ yield i
+ iterable = sideinputs.EmulatedIterable(_iterable_fn)
+ # Check that multiple iterations are supported.
+ for _ in range(0, 5):
+ for i, j in enumerate(iterable):
+ self.assertEqual(i, j)
+
+ def test_large_iterable_values(self):
+ # Here, we create a large collection that would be too big for memory-
+ # constrained test environments, but should be under the memory limit if
+ # materialized one at a time.
+ def _iterable_fn():
+ for i in range(10):
+ yield ('%d' % i) * (200 * 1024 * 1024)
+ iterable = sideinputs.EmulatedIterable(_iterable_fn)
+ # Check that multiple iterations are supported.
+ for _ in range(0, 3):
+ for i, j in enumerate(iterable):
+ self.assertEqual(('%d' % i) * (200 * 1024 * 1024), j)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/statesampler.pyx
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/statesampler.pyx b/sdks/python/apache_beam/runners/worker/statesampler.pyx
new file mode 100644
index 0000000..3ff6c20
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/statesampler.pyx
@@ -0,0 +1,237 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""State sampler for tracking time spent in execution steps.
+
+The state sampler profiles the time spent in each step of a pipeline.
+Operations (defined in executor.py) which are executed as part of a MapTask are
+instrumented with context managers provided by StateSampler.scoped_state().
+These context managers change the internal state of the StateSampler during each
+relevant Operation's .start(), .process() and .finish() methods. State is
+sampled by a raw C thread that does not hold the Python Global Interpreter Lock and
+queries the StateSampler's internal state at a defined sampling frequency. In a
+common example, a ReadOperation during its .start() method reads an element and
+calls a DoOperation's .process() method, which can call a WriteOperation's
+.process() method. Each element processed causes the current state to
+transition between these states of different Operations. Each time the sampling
+thread queries the current state, the time spent since the previous sample is
+attributed to that state and accumulated. Over time, this allows a granular
+runtime profile to be produced.
+"""
+
+import threading
+import time
+
+
+from apache_beam.utils.counters import Counter
+
+
+cimport cython
+from cpython cimport pythread
+from libc.stdint cimport int32_t, int64_t
+
+cdef extern from "Python.h":
+ # This typically requires the GIL, but we synchronize the list modifications
+ # we use this on via our own lock.
+ cdef void* PyList_GET_ITEM(list, Py_ssize_t index) nogil
+
+cdef extern from "unistd.h" nogil:
+ void usleep(long)
+
+cdef extern from "<time.h>" nogil:
+ struct timespec:
+ long tv_sec # seconds
+ long tv_nsec # nanoseconds
+ int clock_gettime(int clock_id, timespec *result)
+
+cdef inline int64_t get_nsec_time() nogil:
+ """Get current time as microseconds since Unix epoch."""
+ cdef timespec current_time
+ # First argument value of 0 corresponds to CLOCK_REALTIME.
+ clock_gettime(0, &current_time)
+ return (
+ (<int64_t> current_time.tv_sec) * 1000000000 + # second to nanoseconds
+ current_time.tv_nsec)
+
+
+class StateSamplerInfo(object):
+ """Info for current state and transition statistics of StateSampler."""
+
+ def __init__(self, state_name, transition_count):
+ self.state_name = state_name
+ self.transition_count = transition_count
+
+ def __repr__(self):
+ return '<StateSamplerInfo %s %d>' % (self.state_name, self.transition_count)
+
+
+# Default period for sampling current state of pipeline execution.
+DEFAULT_SAMPLING_PERIOD_MS = 200
+
+
+cdef class StateSampler(object):
+ """Tracks time spent in states during pipeline execution."""
+
+ cdef object prefix
+ cdef object counter_factory
+ cdef int sampling_period_ms
+
+ cdef dict scoped_states_by_name
+ cdef list scoped_states_by_index
+
+ cdef bint started
+ cdef bint finished
+ cdef object sampling_thread
+
+ # This lock guards members that are shared between threads, specifically
+ # finished, scoped_states_by_index, and the nsecs field of each state therein.
+ cdef pythread.PyThread_type_lock lock
+
+ cdef public int64_t state_transition_count
+
+ cdef int32_t current_state_index
+
+ def __init__(self, prefix, counter_factory,
+ sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS):
+
+ self.prefix = prefix
+ self.counter_factory = counter_factory
+ self.sampling_period_ms = sampling_period_ms
+
+ self.lock = pythread.PyThread_allocate_lock()
+ self.scoped_states_by_name = {}
+
+ self.current_state_index = 0
+ unknown_state = ScopedState(self, 'unknown', self.current_state_index)
+ pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+ self.scoped_states_by_index = [unknown_state]
+ self.finished = False
+ pythread.PyThread_release_lock(self.lock)
+
+ # Assert that the compiler correctly aligned the current_state field. This
+ # is necessary for reads and writes to this variable to be atomic across
+ # threads without additional synchronization.
+ # States are referenced via an index rather than, say, a pointer because
+ # of better support for 32-bit atomic reads and writes.
+ assert (<int64_t> &self.current_state_index) % sizeof(int32_t) == 0, (
+ 'Address of StateSampler.current_state_index is not word-aligned.')
+
+ def __dealloc__(self):
+ pythread.PyThread_free_lock(self.lock)
+
+ def run(self):
+ cdef int64_t last_nsecs = get_nsec_time()
+ cdef int64_t elapsed_nsecs
+ with nogil:
+ while True:
+ usleep(self.sampling_period_ms * 1000)
+ pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+ try:
+ if self.finished:
+ break
+ elapsed_nsecs = get_nsec_time() - last_nsecs
+ # Take an address as we can't create a reference to the scope
+ # without the GIL.
+ nsecs_ptr = &(<ScopedState>PyList_GET_ITEM(
+ self.scoped_states_by_index, self.current_state_index)).nsecs
+ nsecs_ptr[0] += elapsed_nsecs
+ last_nsecs += elapsed_nsecs
+ finally:
+ pythread.PyThread_release_lock(self.lock)
+
+ def start(self):
+ assert not self.started
+ self.started = True
+ self.sampling_thread = threading.Thread(target=self.run)
+ self.sampling_thread.start()
+
+ def stop(self):
+ assert not self.finished
+ pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+ self.finished = True
+ pythread.PyThread_release_lock(self.lock)
+ # May have to wait up to sampling_period_ms, but the platform-independent
+ # pythread doesn't support conditions.
+ self.sampling_thread.join()
+
+ def stop_if_still_running(self):
+ if self.started and not self.finished:
+ self.stop()
+
+ def get_info(self):
+ """Returns StateSamplerInfo with transition statistics."""
+ return StateSamplerInfo(
+ self.scoped_states_by_index[self.current_state_index].name,
+ self.state_transition_count)
+
+ def scoped_state(self, name):
+ """Returns a context manager managing transitions for a given state."""
+ cdef ScopedState scoped_state = self.scoped_states_by_name.get(name, None)
+ if scoped_state is None:
+ output_counter = self.counter_factory.get_counter(
+ '%s%s-msecs' % (self.prefix, name), Counter.SUM)
+ new_state_index = len(self.scoped_states_by_index)
+ scoped_state = ScopedState(self, name, new_state_index, output_counter)
+ # Both scoped_states_by_index and scoped_state.nsecs are accessed
+ # by the sampling thread; initialize them under the lock.
+ pythread.PyThread_acquire_lock(self.lock, pythread.WAIT_LOCK)
+ self.scoped_states_by_index.append(scoped_state)
+ scoped_state.nsecs = 0
+ pythread.PyThread_release_lock(self.lock)
+ self.scoped_states_by_name[name] = scoped_state
+ return scoped_state
+
+ def commit_counters(self):
+ """Updates output counters with latest state statistics."""
+ for state in self.scoped_states_by_name.values():
+ state_msecs = int(1e-6 * state.nsecs)
+ state.counter.update(state_msecs - state.counter.value())
+
+
+cdef class ScopedState(object):
+ """Context manager class managing transitions for a given sampler state."""
+
+ cdef readonly StateSampler sampler
+ cdef readonly int32_t state_index
+ cdef readonly object counter
+ cdef readonly object name
+ cdef readonly int64_t nsecs
+ cdef int32_t old_state_index
+
+ def __init__(self, sampler, name, state_index, counter=None):
+ self.sampler = sampler
+ self.name = name
+ self.state_index = state_index
+ self.counter = counter
+
+ cpdef __enter__(self):
+ self.old_state_index = self.sampler.current_state_index
+ pythread.PyThread_acquire_lock(self.sampler.lock, pythread.WAIT_LOCK)
+ self.sampler.current_state_index = self.state_index
+ pythread.PyThread_release_lock(self.sampler.lock)
+ self.sampler.state_transition_count += 1
+
+ cpdef __exit__(self, unused_exc_type, unused_exc_value, unused_traceback):
+ pythread.PyThread_acquire_lock(self.sampler.lock, pythread.WAIT_LOCK)
+ self.sampler.current_state_index = self.old_state_index
+ pythread.PyThread_release_lock(self.sampler.lock)
+ self.sampler.state_transition_count += 1
+
+ def __repr__(self):
+ return "ScopedState[%s, %s, %s]" % (self.name, self.state_index, self.nsecs)
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/statesampler_fake.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fake.py b/sdks/python/apache_beam/runners/worker/statesampler_fake.py
new file mode 100644
index 0000000..efd7f2d
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/statesampler_fake.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class StateSampler(object):
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def scoped_state(self, name):
+ return _FakeScopedState()
+
+
+class _FakeScopedState(object):
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *unused_args):
+ pass
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/statesampler_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/statesampler_test.py b/sdks/python/apache_beam/runners/worker/statesampler_test.py
new file mode 100644
index 0000000..663cdec
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/statesampler_test.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for state sampler."""
+
+import logging
+import time
+import unittest
+
+from nose.plugins.skip import SkipTest
+
+from apache_beam.utils.counters import CounterFactory
+
+
+class StateSamplerTest(unittest.TestCase):
+
+ def setUp(self):
+ try:
+ # pylint: disable=global-variable-not-assigned
+ global statesampler
+ import statesampler
+ except ImportError:
+ raise SkipTest('State sampler not compiled.')
+ super(StateSamplerTest, self).setUp()
+
+ def test_basic_sampler(self):
+ # Set up state sampler.
+ counter_factory = CounterFactory()
+ sampler = statesampler.StateSampler('basic-', counter_factory,
+ sampling_period_ms=1)
+
+ # Run basic workload transitioning between 3 states.
+ sampler.start()
+ with sampler.scoped_state('statea'):
+ time.sleep(0.1)
+ with sampler.scoped_state('stateb'):
+ time.sleep(0.2 / 2)
+ with sampler.scoped_state('statec'):
+ time.sleep(0.3)
+ time.sleep(0.2 / 2)
+ sampler.stop()
+ sampler.commit_counters()
+
+ # Test that sampled state timings are close to their expected values.
+ expected_counter_values = {
+ 'basic-statea-msecs': 100,
+ 'basic-stateb-msecs': 200,
+ 'basic-statec-msecs': 300,
+ }
+ for counter in counter_factory.get_counters():
+ self.assertIn(counter.name, expected_counter_values)
+ expected_value = expected_counter_values[counter.name]
+ actual_value = counter.value()
+ self.assertGreater(actual_value, expected_value * 0.75)
+ self.assertLess(actual_value, expected_value * 1.25)
+
+ def test_sampler_transition_overhead(self):
+ # Set up state sampler.
+ counter_factory = CounterFactory()
+ sampler = statesampler.StateSampler('overhead-', counter_factory,
+ sampling_period_ms=10)
+
+ # Run basic workload transitioning between 3 states.
+ state_a = sampler.scoped_state('statea')
+ state_b = sampler.scoped_state('stateb')
+ state_c = sampler.scoped_state('statec')
+ start_time = time.time()
+ sampler.start()
+ for _ in range(100000):
+ with state_a:
+ with state_b:
+ for _ in range(10):
+ with state_c:
+ pass
+ sampler.stop()
+ elapsed_time = time.time() - start_time
+ state_transition_count = sampler.get_info().transition_count
+ overhead_us = 1000000.0 * elapsed_time / state_transition_count
+ logging.info('Overhead per transition: %fus', overhead_us)
+ # Conservative upper bound on overhead in microseconds (we expect this to
+ # take 0.17us when compiled in opt mode or 0.48us when compiled in
+ # debug mode).
+ self.assertLess(overhead_us, 10.0)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/generate_pydoc.sh
----------------------------------------------------------------------
diff --git a/sdks/python/generate_pydoc.sh b/sdks/python/generate_pydoc.sh
index b04e27a..6039942 100755
--- a/sdks/python/generate_pydoc.sh
+++ b/sdks/python/generate_pydoc.sh
@@ -44,6 +44,8 @@ python $(type -p sphinx-apidoc) -f -o target/docs/source apache_beam \
# Remove Cython modules from doc template; they won't load
sed -i -e '/.. automodule:: apache_beam.coders.stream/d' \
target/docs/source/apache_beam.coders.rst
+sed -i -e '/.. automodule:: apache_beam.runners.worker.statesampler/d' \
+ target/docs/source/apache_beam.runners.worker.rst
# Create the configuration and index files
cat > target/docs/source/conf.py <<'EOF'
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/setup.py
----------------------------------------------------------------------
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 615931b..f527362 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -121,12 +121,16 @@ setuptools.setup(
author=PACKAGE_AUTHOR,
author_email=PACKAGE_EMAIL,
packages=setuptools.find_packages(),
- package_data={'apache_beam': ['**/*.pyx', '**/*.pxd', 'tests/data/*']},
+ package_data={'apache_beam': [
+ '*/*.pyx', '*/*/*.pyx', '*/*.pxd', '*/*/*.pxd', 'tests/data/*']},
ext_modules=cythonize([
'**/*.pyx',
'apache_beam/coders/coder_impl.py',
- 'apache_beam/runners/common.py',
'apache_beam/metrics/execution.py',
+ 'apache_beam/runners/common.py',
+ 'apache_beam/runners/worker/logger.py',
+ 'apache_beam/runners/worker/opcounters.py',
+ 'apache_beam/runners/worker/operations.py',
'apache_beam/transforms/cy_combiners.py',
'apache_beam/utils/counters.py',
'apache_beam/utils/windowed_value.py',
[2/4] beam git commit: Fn API support for Python.
Posted by ro...@apache.org.
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger_test.py b/sdks/python/apache_beam/runners/worker/logger_test.py
new file mode 100644
index 0000000..cf3f692
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger_test.py
@@ -0,0 +1,182 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for worker logging utilities."""
+
+import json
+import logging
+import sys
+import threading
+import unittest
+
+from apache_beam.runners.worker import logger
+
+
+class PerThreadLoggingContextTest(unittest.TestCase):
+
+ def thread_check_attribute(self, name):
+ self.assertFalse(name in logger.per_thread_worker_data.get_data())
+ with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
+ self.assertEqual(
+ logger.per_thread_worker_data.get_data()[name], 'thread-value')
+ self.assertFalse(name in logger.per_thread_worker_data.get_data())
+
+ def test_per_thread_attribute(self):
+ self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+ with logger.PerThreadLoggingContext(xyz='value'):
+ self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+ thread = threading.Thread(
+ target=self.thread_check_attribute, args=('xyz',))
+ thread.start()
+ thread.join()
+ self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+ self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+
+ def test_set_when_undefined(self):
+ self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+ with logger.PerThreadLoggingContext(xyz='value'):
+ self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+ self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+
+ def test_set_when_already_defined(self):
+ self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+ with logger.PerThreadLoggingContext(xyz='value'):
+ self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+ with logger.PerThreadLoggingContext(xyz='value2'):
+ self.assertEqual(
+ logger.per_thread_worker_data.get_data()['xyz'], 'value2')
+ self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
+ self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
+
+
+class JsonLogFormatterTest(unittest.TestCase):
+
+ SAMPLE_RECORD = {
+ 'created': 123456.789, 'msecs': 789.654321,
+ 'msg': '%s:%d:%.2f', 'args': ('xyz', 4, 3.14),
+ 'levelname': 'WARNING',
+ 'process': 'pid', 'thread': 'tid',
+ 'name': 'name', 'filename': 'file', 'funcName': 'func',
+ 'exc_info': None}
+
+ SAMPLE_OUTPUT = {
+ 'timestamp': {'seconds': 123456, 'nanos': 789654321},
+ 'severity': 'WARN', 'message': 'xyz:4:3.14', 'thread': 'pid:tid',
+ 'job': 'jobid', 'worker': 'workerid', 'logger': 'name:file:func'}
+
+ def create_log_record(self, **kwargs):
+
+ class Record(object):
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.iteritems():
+ setattr(self, k, v)
+
+ return Record(**kwargs)
+
+ def test_basic_record(self):
+ formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+ record = self.create_log_record(**self.SAMPLE_RECORD)
+ self.assertEqual(json.loads(formatter.format(record)), self.SAMPLE_OUTPUT)
+
+ def execute_multiple_cases(self, test_cases):
+ record = self.SAMPLE_RECORD
+ output = self.SAMPLE_OUTPUT
+ formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+
+ for case in test_cases:
+ record['msg'] = case['msg']
+ record['args'] = case['args']
+ output['message'] = case['expected']
+
+ self.assertEqual(
+ json.loads(formatter.format(self.create_log_record(**record))),
+ output)
+
+ def test_record_with_format_character(self):
+ test_cases = [
+ {'msg': '%A', 'args': (), 'expected': '%A'},
+ {'msg': '%s', 'args': (), 'expected': '%s'},
+ {'msg': '%A%s', 'args': ('xy'), 'expected': '%A%s with args (xy)'},
+ {'msg': '%s%s', 'args': (1), 'expected': '%s%s with args (1)'},
+ ]
+
+ self.execute_multiple_cases(test_cases)
+
+ def test_record_with_arbitrary_messages(self):
+ test_cases = [
+ {'msg': ImportError('abc'), 'args': (), 'expected': 'abc'},
+ {'msg': TypeError('abc %s'), 'args': ('def'), 'expected': 'abc def'},
+ ]
+
+ self.execute_multiple_cases(test_cases)
+
+ def test_record_with_per_thread_info(self):
+ with logger.PerThreadLoggingContext(
+ work_item_id='workitem', stage_name='stage', step_name='step'):
+ formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+ record = self.create_log_record(**self.SAMPLE_RECORD)
+ log_output = json.loads(formatter.format(record))
+ expected_output = dict(self.SAMPLE_OUTPUT)
+ expected_output.update(
+ {'work': 'workitem', 'stage': 'stage', 'step': 'step'})
+ self.assertEqual(log_output, expected_output)
+
+ def test_nested_with_per_thread_info(self):
+ formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+ with logger.PerThreadLoggingContext(
+ work_item_id='workitem', stage_name='stage', step_name='step1'):
+ record = self.create_log_record(**self.SAMPLE_RECORD)
+ log_output1 = json.loads(formatter.format(record))
+
+ with logger.PerThreadLoggingContext(step_name='step2'):
+ record = self.create_log_record(**self.SAMPLE_RECORD)
+ log_output2 = json.loads(formatter.format(record))
+
+ record = self.create_log_record(**self.SAMPLE_RECORD)
+ log_output3 = json.loads(formatter.format(record))
+
+ record = self.create_log_record(**self.SAMPLE_RECORD)
+ log_output4 = json.loads(formatter.format(record))
+
+ self.assertEqual(log_output1, dict(
+ self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
+ self.assertEqual(log_output2, dict(
+ self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
+ self.assertEqual(log_output3, dict(
+ self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
+ self.assertEqual(log_output4, self.SAMPLE_OUTPUT)
+
+ def test_exception_record(self):
+ formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
+ try:
+ raise ValueError('Something')
+ except ValueError:
+ attribs = dict(self.SAMPLE_RECORD)
+ attribs.update({'exc_info': sys.exc_info()})
+ record = self.create_log_record(**attribs)
+ log_output = json.loads(formatter.format(record))
+ # Check that the exception type, its message, and stack trace information are present.
+ exn_output = log_output.pop('exception')
+ self.assertNotEqual(exn_output.find('ValueError: Something'), -1)
+ self.assertNotEqual(exn_output.find('logger_test.py'), -1)
+ self.assertEqual(log_output, self.SAMPLE_OUTPUT)
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters.pxd b/sdks/python/apache_beam/runners/worker/opcounters.pxd
new file mode 100644
index 0000000..5c1079f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters.pxd
@@ -0,0 +1,45 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+cimport libc.stdint
+
+from apache_beam.utils.counters cimport Counter
+
+
+cdef class SumAccumulator(object):
+ cdef libc.stdint.int64_t _value
+ cpdef update(self, libc.stdint.int64_t value)
+ cpdef libc.stdint.int64_t value(self)
+
+
+cdef class OperationCounters(object):
+ cdef public _counter_factory
+ cdef public Counter element_counter
+ cdef public Counter mean_byte_counter
+ cdef public coder_impl
+ cdef public SumAccumulator active_accumulator
+ cdef public libc.stdint.int64_t _sample_counter
+ cdef public libc.stdint.int64_t _next_sample
+
+ cpdef update_from(self, windowed_value)
+ cdef inline do_sample(self, windowed_value)
+ cpdef update_collect(self)
+
+ cdef libc.stdint.int64_t _compute_next_sample(self, libc.stdint.int64_t i)
+ cdef inline bint _should_sample(self)
+ cpdef bint should_sample(self)
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters.py b/sdks/python/apache_beam/runners/worker/opcounters.py
new file mode 100644
index 0000000..56ce0db
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""Counters collect the progress of the Worker for reporting to the service."""
+
+from __future__ import absolute_import
+import math
+import random
+
+from apache_beam.utils.counters import Counter
+
+
+class SumAccumulator(object):
+ """Accumulator for collecting byte counts."""
+
+ def __init__(self):
+ self._value = 0
+
+ def update(self, value):
+ self._value += value
+
+ def value(self):
+ return self._value
+
+
+class OperationCounters(object):
+ """The set of basic counters to attach to an Operation."""
+
+ def __init__(self, counter_factory, step_name, coder, output_index):
+ self._counter_factory = counter_factory
+ self.element_counter = counter_factory.get_counter(
+ '%s-out%d-ElementCount' % (step_name, output_index), Counter.SUM)
+ self.mean_byte_counter = counter_factory.get_counter(
+ '%s-out%d-MeanByteCount' % (step_name, output_index), Counter.MEAN)
+ self.coder_impl = coder.get_impl()
+ self.active_accumulator = None
+ self._sample_counter = 0
+ self._next_sample = 0
+
+ def update_from(self, windowed_value):
+ """Add one value to this counter."""
+ self.element_counter.update(1)
+ if self._should_sample():
+ self.do_sample(windowed_value)
+
+ def _observable_callback(self, inner_coder_impl, accumulator):
+ def _observable_callback_inner(value, is_encoded=False):
+ # TODO(ccy): If this stream is large, sample it as well.
+ # To do this, we'll need to compute the average size of elements
+ # in this stream to add the *total* size of this stream to accumulator.
+ # We'll also want make sure we sample at least some of this stream
+ # (as self.should_sample() may be sampling very sparsely by now).
+ if is_encoded:
+ size = len(value)
+ accumulator.update(size)
+ else:
+ accumulator.update(inner_coder_impl.estimate_size(value))
+ return _observable_callback_inner
+
+ def do_sample(self, windowed_value):
+ size, observables = (
+ self.coder_impl.get_estimated_size_and_observables(windowed_value))
+ if not observables:
+ self.mean_byte_counter.update(size)
+ else:
+ self.active_accumulator = SumAccumulator()
+ self.active_accumulator.update(size)
+ for observable, inner_coder_impl in observables:
+ observable.register_observer(
+ self._observable_callback(
+ inner_coder_impl, self.active_accumulator))
+
+ def update_collect(self):
+ """Collects the accumulated size estimates.
+
+ Now that the element has been processed, we ask our accumulator
+ for the total and store the result in a counter.
+ """
+ if self.active_accumulator is not None:
+ self.mean_byte_counter.update(self.active_accumulator.value())
+ self.active_accumulator = None
+
+ def _compute_next_sample(self, i):
+ # https://en.wikipedia.org/wiki/Reservoir_sampling#Fast_Approximation
+ gap = math.log(1.0 - random.random()) / math.log(1.0 - 10.0/i)
+ return i + math.floor(gap)
+
+ def _should_sample(self):
+ """Determines whether to sample the next element.
+
+ Size calculation can be expensive, so we don't do it for each element.
+ Because we need only an estimate of average size, we sample.
+
+ We always sample the first 10 elements, then the sampling rate
+ is approximately 10/N. After reading N elements, of the next N,
+ we will sample approximately 10*ln(2) (about 7) elements.
+
+ This algorithm samples at the same rate as Reservoir Sampling, but
+ it never throws away early results. (Because we keep only a
+ running accumulation, storage is not a problem, so there is no
+ need to discard earlier calculations.)
+
+ Because we accumulate and do not replace, our statistics are
+ biased toward early data. If the data are distributed uniformly,
+ this is not a problem. If the data change over time (i.e., the
+ element size tends to grow or shrink over time), our estimate will
+ show the bias. We could correct this by giving weight N to each
+ sample, since each sample is a stand-in for the N/(10*ln(2))
+ samples around it, which is proportional to N. Since we do not
+ expect biased data, for efficiency we omit the extra multiplication.
+ We could reduce the early-data bias by putting a lower bound on
+ the sampling rate.
+
+ Computing random.randint(1, self._sample_counter) for each element
+ is too slow, so when the sample size is big enough (we estimate 30
+ is big enough), we estimate the size of the gap after each sample.
+ This estimation allows us to call random much less often.
+
+ Returns:
+ True if it is time to compute another element's size.
+ """
+
+ self._sample_counter += 1
+ if self._next_sample == 0:
+ if random.randint(1, self._sample_counter) <= 10:
+ if self._sample_counter > 30:
+ self._next_sample = self._compute_next_sample(self._sample_counter)
+ return True
+ return False
+ elif self._sample_counter >= self._next_sample:
+ self._next_sample = self._compute_next_sample(self._sample_counter)
+ return True
+ return False
+
+ def should_sample(self):
+ # We create this separate method because the above "_should_sample()" method
+ # is marked as inline in Cython and thus can't be exposed to Python code.
+ return self._should_sample()
+
+ def __str__(self):
+ return '<%s [%s]>' % (self.__class__.__name__,
+ ', '.join([str(x) for x in self.__iter__()]))
+
+ def __repr__(self):
+ return '<%s %s at %s>' % (self.__class__.__name__,
+ [x for x in self.__iter__()], hex(id(self)))
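
The sampling schedule that _should_sample() documents can be sanity-checked in isolation. The following standalone sketch (plain Python; it mirrors the logic above rather than importing the class) counts how many of n elements would be sampled; for large n this comes out on the order of 10*ln(n), i.e. roughly a hundred samples for n = 100000:

    import math
    import random

    def simulate_sampling(n):
      sample_counter = 0
      next_sample = 0
      sampled = 0
      for _ in range(n):
        sample_counter += 1
        if next_sample == 0:
          if random.randint(1, sample_counter) <= 10:
            if sample_counter > 30:
              gap = math.log(1.0 - random.random()) / math.log(
                  1.0 - 10.0 / sample_counter)
              next_sample = sample_counter + math.floor(gap)
            sampled += 1
        elif sample_counter >= next_sample:
          gap = math.log(1.0 - random.random()) / math.log(
              1.0 - 10.0 / sample_counter)
          next_sample = sample_counter + math.floor(gap)
          sampled += 1
      return sampled
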
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/opcounters_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/opcounters_test.py b/sdks/python/apache_beam/runners/worker/opcounters_test.py
new file mode 100644
index 0000000..74561b8
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/opcounters_test.py
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import math
+import random
+import unittest
+
+from apache_beam import coders
+from apache_beam.runners.worker.opcounters import OperationCounters
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.utils.counters import CounterFactory
+
+
+# Classes to test that we can handle a variety of objects.
+# These have to be at top level so the pickler can find them.
+
+
+class OldClassThatDoesNotImplementLen: # pylint: disable=old-style-class
+
+ def __init__(self):
+ pass
+
+
+class ObjectThatDoesNotImplementLen(object):
+
+ def __init__(self):
+ pass
+
+
+class OperationCountersTest(unittest.TestCase):
+
+ def verify_counters(self, opcounts, expected_elements, expected_size=None):
+ self.assertEqual(expected_elements, opcounts.element_counter.value())
+ if expected_size is not None:
+ if math.isnan(expected_size):
+ self.assertTrue(math.isnan(opcounts.mean_byte_counter.value()))
+ else:
+ self.assertEqual(expected_size, opcounts.mean_byte_counter.value())
+
+ def test_update_int(self):
+ opcounts = OperationCounters(CounterFactory(), 'some-name',
+ coders.PickleCoder(), 0)
+ self.verify_counters(opcounts, 0)
+ opcounts.update_from(GlobalWindows.windowed_value(1))
+ self.verify_counters(opcounts, 1)
+
+ def test_update_str(self):
+ coder = coders.PickleCoder()
+ opcounts = OperationCounters(CounterFactory(), 'some-name',
+ coder, 0)
+ self.verify_counters(opcounts, 0, float('nan'))
+ value = GlobalWindows.windowed_value('abcde')
+ opcounts.update_from(value)
+ estimated_size = coder.estimate_size(value)
+ self.verify_counters(opcounts, 1, estimated_size)
+
+ def test_update_old_object(self):
+ coder = coders.PickleCoder()
+ opcounts = OperationCounters(CounterFactory(), 'some-name',
+ coder, 0)
+ self.verify_counters(opcounts, 0, float('nan'))
+ obj = OldClassThatDoesNotImplementLen()
+ value = GlobalWindows.windowed_value(obj)
+ opcounts.update_from(value)
+ estimated_size = coder.estimate_size(value)
+ self.verify_counters(opcounts, 1, estimated_size)
+
+ def test_update_new_object(self):
+ coder = coders.PickleCoder()
+ opcounts = OperationCounters(CounterFactory(), 'some-name',
+ coder, 0)
+ self.verify_counters(opcounts, 0, float('nan'))
+
+ obj = ObjectThatDoesNotImplementLen()
+ value = GlobalWindows.windowed_value(obj)
+ opcounts.update_from(value)
+ estimated_size = coder.estimate_size(value)
+ self.verify_counters(opcounts, 1, estimated_size)
+
+ def test_update_multiple(self):
+ coder = coders.PickleCoder()
+ total_size = 0
+ opcounts = OperationCounters(CounterFactory(), 'some-name',
+ coder, 0)
+ self.verify_counters(opcounts, 0, float('nan'))
+ value = GlobalWindows.windowed_value('abcde')
+ opcounts.update_from(value)
+ total_size += coder.estimate_size(value)
+ value = GlobalWindows.windowed_value('defghij')
+ opcounts.update_from(value)
+ total_size += coder.estimate_size(value)
+ self.verify_counters(opcounts, 2, float(total_size) / 2)
+ value = GlobalWindows.windowed_value('klmnop')
+ opcounts.update_from(value)
+ total_size += coder.estimate_size(value)
+ self.verify_counters(opcounts, 3, float(total_size) / 3)
+
+ def test_should_sample(self):
+ # Order of magnitude more buckets than highest constant in code under test.
+ buckets = [0] * 300
+ # The seed is arbitrary and exists just to ensure this test is robust.
+ # If you don't like this seed, try your own; the test should still pass.
+ random.seed(1717)
+ # Do enough runs that the expected hits even in the last buckets
+ # is big enough to expect some statistical smoothing.
+ total_runs = 10 * len(buckets)
+
+ # Fill the buckets.
+ for _ in xrange(total_runs):
+ opcounts = OperationCounters(CounterFactory(), 'some-name',
+ coders.PickleCoder(), 0)
+ for i in xrange(len(buckets)):
+ if opcounts.should_sample():
+ buckets[i] += 1
+
+ # Look at the buckets to see if they are likely.
+ for i in xrange(10):
+ self.assertEqual(total_runs, buckets[i])
+ for i in xrange(10, len(buckets)):
+ self.assertTrue(buckets[i] > 7 * total_runs / i,
+ 'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % (
+ i, buckets[i],
+ 10 * total_runs / i,
+ buckets[i] / (10.0 * total_runs / i)))
+ self.assertTrue(buckets[i] < 14 * total_runs / i,
+ 'i=%d, buckets[i]=%d, expected=%d, ratio=%f' % (
+ i, buckets[i],
+ 10 * total_runs / i,
+ buckets[i] / (10.0 * total_runs / i)))
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operation_specs.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py
new file mode 100644
index 0000000..977e165
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operation_specs.py
@@ -0,0 +1,368 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Worker utilities for representing MapTasks.
+
+Each MapTask represents a sequence of ParallelInstruction(s): read from a
+source, write to a sink, parallel do, etc.
+"""
+
+import collections
+
+from apache_beam import coders
+
+
+def build_worker_instruction(*args):
+ """Create an object representing a ParallelInstruction protobuf.
+
+ This will be a collections.namedtuple with a custom __str__ method.
+
+ Alas, this wrapper is not known to pylint, which thinks it creates
+ constants. You may have to put a disable=invalid-name pylint
+ annotation on any use of this, depending on your names.
+
+ Args:
+ *args: The first argument is the name of the type to create; it should
+ start with "Worker". The second argument is a list of the
+ attributes of this object.
+ Returns:
+ A new class, a subclass of tuple, that represents the protobuf.
+ """
+ tuple_class = collections.namedtuple(*args)
+ tuple_class.__str__ = worker_object_to_string
+ tuple_class.__repr__ = worker_object_to_string
+ return tuple_class
+
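+# An illustrative sketch (not part of this change): a hypothetical
+# instruction type would be declared and printed like this; note that
+# worker_printable_fields (below) hides the 'output_coders' field.
+#
+#   WorkerExample = build_worker_instruction(
+#       'WorkerExample', ['input', 'output_coders'])
+#   print(WorkerExample(input=(0, 0), output_coders=(coders.PickleCoder(),)))
+#   # -> WorkerExample(input=(0, 0))
+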
+
+def worker_printable_fields(workerproto):
+ """Returns the interesting fields of a Worker* object."""
+ return ['%s=%s' % (name, value)
+ # _asdict is the only accessor; we cannot subclass this generated class
+ # pylint: disable=protected-access
+ for name, value in workerproto._asdict().iteritems()
+ # We want to output a value of 0, but not None or [].
+ if (value or value == 0)
+ and name not in
+ ('coder', 'coders', 'output_coders',
+ 'elements',
+ 'combine_fn', 'serialized_fn', 'window_fn',
+ 'append_trailing_newlines', 'strip_trailing_newlines',
+ 'compression_type', 'context',
+ 'start_shuffle_position', 'end_shuffle_position',
+ 'shuffle_reader_config', 'shuffle_writer_config')]
+
+
+def worker_object_to_string(worker_object):
+ """Returns a string compactly representing a Worker* object."""
+ return '%s(%s)' % (worker_object.__class__.__name__,
+ ', '.join(worker_printable_fields(worker_object)))
+
+
+# All the following Worker* definitions will have these lint problems:
+# pylint: disable=invalid-name
+# pylint: disable=pointless-string-statement
+
+
+WorkerRead = build_worker_instruction(
+ 'WorkerRead', ['source', 'output_coders'])
+"""Worker details needed to read from a source.
+
+Attributes:
+ source: a source object.
+ output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerSideInputSource = build_worker_instruction(
+ 'WorkerSideInputSource', ['source', 'tag'])
+"""Worker details needed to read from a side input source.
+
+Attributes:
+ source: a source object.
+ tag: string tag for this side input.
+"""
+
+
+WorkerGroupingShuffleRead = build_worker_instruction(
+ 'WorkerGroupingShuffleRead',
+ ['start_shuffle_position', 'end_shuffle_position',
+ 'shuffle_reader_config', 'coder', 'output_coders'])
+"""Worker details needed to read from a grouping shuffle source.
+
+Attributes:
+ start_shuffle_position: An opaque string to be passed to the shuffle
+ source to indicate where to start reading.
+ end_shuffle_position: An opaque string to be passed to the shuffle
+ source to indicate where to stop reading.
+ shuffle_reader_config: An opaque string used to initialize the shuffle
+ reader. Contains things like connection endpoints for the shuffle
+ server appliance and various options.
+ coder: The KV coder used to decode shuffle entries.
+ output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerUngroupedShuffleRead = build_worker_instruction(
+ 'WorkerUngroupedShuffleRead',
+ ['start_shuffle_position', 'end_shuffle_position',
+ 'shuffle_reader_config', 'coder', 'output_coders'])
+"""Worker details needed to read from an ungrouped shuffle source.
+
+Attributes:
+ start_shuffle_position: An opaque string to be passed to the shuffle
+ source to indicate where to start reading.
+ end_shuffle_position: An opaque string to be passed to the shuffle
+ source to indicate where to stop reading.
+ shuffle_reader_config: An opaque string used to initialize the shuffle
+ reader. Contains things like connection endpoints for the shuffle
+ server appliance and various options.
+ coder: The value coder used to decode shuffle entries.
+ output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerWrite = build_worker_instruction(
+ 'WorkerWrite', ['sink', 'input', 'output_coders'])
+"""Worker details needed to write to a sink.
+
+Attributes:
+ sink: a sink object.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ output_coders: 1-tuple, coder to use to estimate bytes written.
+"""
+
+
+WorkerInMemoryWrite = build_worker_instruction(
+ 'WorkerInMemoryWrite',
+ ['output_buffer', 'write_windowed_values', 'input', 'output_coders'])
+"""Worker details needed to write to a in-memory sink.
+
+Used only for unit testing. It makes worker tests less cluttered with code like
+"write to a file and then check file contents".
+
+Attributes:
+ output_buffer: list to which output elements will be appended
+ write_windowed_values: whether to record the entire WindowedValue outputs,
+ or just the raw (unwindowed) value
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ output_coders: 1-tuple, coder to use to estimate bytes written.
+"""
+
+
+WorkerShuffleWrite = build_worker_instruction(
+ 'WorkerShuffleWrite',
+ ['shuffle_kind', 'shuffle_writer_config', 'input', 'output_coders'])
+"""Worker details needed to write to a shuffle sink.
+
+Attributes:
+ shuffle_kind: A string describing the shuffle kind. This can control the
+ way the worker interacts with the shuffle sink. The possible values are:
+ 'ungrouped', 'group_keys', and 'group_keys_and_sort_values'.
+ shuffle_writer_config: An opaque string used to initialize the shuffle
+ write. Contains things like connection endpoints for the shuffle
+ server appliance and various options.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ output_coders: 1-tuple of the coder for input elements. If the
+ shuffle_kind is grouping, this is expected to be a KV coder.
+"""
+
+
+WorkerDoFn = build_worker_instruction(
+ 'WorkerDoFn',
+ ['serialized_fn', 'output_tags', 'input', 'side_inputs', 'output_coders'])
+"""Worker details needed to run a DoFn.
+Attributes:
+ serialized_fn: A serialized DoFn object to be run for each input element.
+ output_tags: The string tags used to identify the outputs of a ParDo
+ operation. The tag is present even if the ParDo has just one output
+ (e.g., ['out']).
+ output_coders: array of coders, one for each output.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ side_inputs: A list of Worker...Read instances describing sources to be
+ used for getting values. The only type supported right now is
+ WorkerSideInputSource.
+"""
+
+
+WorkerReifyTimestampAndWindows = build_worker_instruction(
+ 'WorkerReifyTimestampAndWindows',
+ ['output_tags', 'input', 'output_coders'])
+"""Worker details needed to run a WindowInto.
+Attributes:
+ output_tags: The string tags used to identify the outputs of a ParDo
+ operation. The tag is present even if the ParDo has just one output
+ (e.g., ['out']).
+ output_coders: array of coders, one for each output.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+"""
+
+
+WorkerMergeWindows = build_worker_instruction(
+ 'WorkerMergeWindows',
+ ['window_fn', 'combine_fn', 'phase', 'output_tags', 'input', 'coders',
+ 'context', 'output_coders'])
+"""Worker details needed to run a MergeWindows (aka. GroupAlsoByWindows).
+Attributes:
+ window_fn: A serialized Windowing object representing the windowing strategy.
+ combine_fn: A serialized CombineFn object to be used after executing the
+ GroupAlsoByWindows operation. May be None if not a combining operation.
+ phase: Possible values are 'all', 'add', 'merge', and 'extract'.
+ A runner optimizer may split the user combiner in 3 separate
+ phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees
+ fit. The phase attribute dictates which DoFn is actually running in
+ the worker. May be None if not a combining operation.
+ output_tags: The string tags used to identify the outputs of a ParDo
+ operation. The tag is present even if the ParDo has just one output
+ (e.g., ['out']).
+ output_coders: array of coders, one for each output.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ coders: A 2-tuple of coders (key, value) to encode shuffle entries.
+ context: The ExecutionContext object for the current work item.
+"""
+
+
+WorkerCombineFn = build_worker_instruction(
+ 'WorkerCombineFn',
+ ['serialized_fn', 'phase', 'input', 'output_coders'])
+"""Worker details needed to run a CombineFn.
+Attributes:
+ serialized_fn: A serialized CombineFn object to be used.
+ phase: Possible values are 'all', 'add', 'merge', and 'extract'.
+ A runner optimizer may split the user combiner in 3 separate
+ phases (ADD, MERGE, and EXTRACT), on separate VMs, as it sees
+ fit. The phase attribute dictates which DoFn is actually running in
+ the worker.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerPartialGroupByKey = build_worker_instruction(
+ 'WorkerPartialGroupByKey',
+ ['combine_fn', 'input', 'output_coders'])
+"""Worker details needed to run a partial group-by-key.
+Attributes:
+ combine_fn: A serialized CombineFn object to be used.
+ input: A (producer index, output index) tuple representing the
+ ParallelInstruction operation whose output feeds into this operation.
+ The output index is 0 except for multi-output operations (like ParDo).
+ output_coders: 1-tuple of the coder for the output.
+"""
+
+
+WorkerFlatten = build_worker_instruction(
+ 'WorkerFlatten',
+ ['inputs', 'output_coders'])
+"""Worker details needed to run a Flatten.
+Attributes:
+ inputs: A list of tuples, each (producer index, output index), representing
+ the ParallelInstruction operations whose output feeds into this operation.
+ The output index is 0 unless the input is from a multi-output
+ operation (such as ParDo).
+ output_coders: 1-tuple of the coder for the output.
+"""
+
+
+def get_coder_from_spec(coder_spec):
+ """Return a coder instance from a coder spec.
+
+ Args:
+ coder_spec: A dict where the value of the '@type' key is a pickled instance
+ of a Coder instance.
+
+ Returns:
+ A coder instance (has encode/decode methods).
+ """
+ assert coder_spec is not None
+
+ # Ignore the wrappers in these encodings.
+ # TODO(silviuc): Make sure with all the renamings that names below are ok.
+ if coder_spec['@type'] in ignored_wrappers:
+ assert len(coder_spec['component_encodings']) == 1
+ coder_spec = coder_spec['component_encodings'][0]
+ return get_coder_from_spec(coder_spec)
+
+ # Handle a few well known types of coders.
+ if coder_spec['@type'] == 'kind:pair':
+ assert len(coder_spec['component_encodings']) == 2
+ component_coders = [
+ get_coder_from_spec(c) for c in coder_spec['component_encodings']]
+ return coders.TupleCoder(component_coders)
+ elif coder_spec['@type'] == 'kind:stream':
+ assert len(coder_spec['component_encodings']) == 1
+ return coders.IterableCoder(
+ get_coder_from_spec(coder_spec['component_encodings'][0]))
+ elif coder_spec['@type'] == 'kind:windowed_value':
+ assert len(coder_spec['component_encodings']) == 2
+ value_coder, window_coder = [
+ get_coder_from_spec(c) for c in coder_spec['component_encodings']]
+ return coders.WindowedValueCoder(value_coder, window_coder=window_coder)
+ elif coder_spec['@type'] == 'kind:interval_window':
+ assert ('component_encodings' not in coder_spec
+ or len(coder_spec['component_encodings']) == 0)
+ return coders.IntervalWindowCoder()
+ elif coder_spec['@type'] == 'kind:global_window':
+ assert ('component_encodings' not in coder_spec
+ or not coder_spec['component_encodings'])
+ return coders.GlobalWindowCoder()
+ elif coder_spec['@type'] == 'kind:length_prefix':
+ assert len(coder_spec['component_encodings']) == 1
+ return coders.LengthPrefixCoder(
+ get_coder_from_spec(coder_spec['component_encodings'][0]))
+
+ # We pass coders in the form "<coder_name>$<pickled_data>" to make the job
+ # description JSON more readable.
+ return coders.deserialize_coder(coder_spec['@type'])
+
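+# An illustrative sketch (not part of this change) of the spec format this
+# function consumes. The exact leaf encoding is produced elsewhere; here we
+# assume coders.serialize_coder is available to build one:
+#
+#   spec = {'@type': 'kind:windowed_value',
+#           'component_encodings': [
+#               {'@type': coders.serialize_coder(coders.VarIntCoder())},
+#               {'@type': 'kind:global_window'}]}
+#   windowed_coder = get_coder_from_spec(spec)  # a WindowedValueCoder
+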
+
+class MapTask(object):
+ """A map task decoded into operations and ready to be executed.
+
+ Attributes:
+ operations: A list of Worker* objects created by parsing the instructions
+ within the map task.
+ stage_name: The name of this map task execution stage.
+ system_names: The system names of the step corresponding to each map task
+ operation in the execution graph.
+ step_names: The names of the step corresponding to each map task operation.
+ original_names: The internal names of the steps in the original workflow
+ graph.
+ """
+
+ def __init__(
+ self, operations, stage_name, system_names, step_names, original_names):
+ self.operations = operations
+ self.stage_name = stage_name
+ self.system_names = system_names
+ self.step_names = step_names
+ self.original_names = original_names
+
+ def __str__(self):
+ return '<%s %s steps=%s>' % (self.__class__.__name__, self.stage_name,
+ '+'.join(self.step_names))
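+
+
+# An illustrative sketch (not part of this change): a minimal two-operation
+# map task, roughly as a worker test might build one. source_bundle and
+# some_coder are placeholders for a real iobase.SourceBundle and coder.
+#
+#   output_buffer = []
+#   map_task = MapTask(
+#       operations=[
+#           WorkerRead(source_bundle, output_coders=(some_coder,)),
+#           WorkerInMemoryWrite(output_buffer=output_buffer,
+#                               write_windowed_values=False,
+#                               input=(0, 0), output_coders=(some_coder,))],
+#       stage_name='S01', system_names=['read', 'write'],
+#       step_names=['Read', 'Write'], original_names=['s1', 's2'])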
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd b/sdks/python/apache_beam/runners/worker/operations.pxd
new file mode 100644
index 0000000..2b4e526
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport Receiver
+from apache_beam.runners.worker cimport opcounters
+from apache_beam.utils.windowed_value cimport WindowedValue
+from apache_beam.metrics.execution cimport ScopedMetricsContainer
+
+
+cdef WindowedValue _globally_windowed_value
+cdef type _global_window_type
+
+cdef class ConsumerSet(Receiver):
+ cdef list consumers
+ cdef opcounters.OperationCounters opcounter
+ cdef public step_name
+ cdef public output_index
+ cdef public coder
+
+ cpdef receive(self, WindowedValue windowed_value)
+ cpdef update_counters_start(self, WindowedValue windowed_value)
+ cpdef update_counters_finish(self)
+
+
+cdef class Operation(object):
+ cdef readonly operation_name
+ cdef readonly spec
+ cdef object consumers
+ cdef readonly counter_factory
+ cdef public metrics_container
+ cdef public ScopedMetricsContainer scoped_metrics_container
+ # Public for access by Fn harness operations.
+ # TODO(robertwb): Cythonize FnHarness.
+ cdef public list receivers
+ cdef readonly bint debug_logging_enabled
+
+ cdef public step_name # initialized lazily
+
+ cdef readonly object state_sampler
+
+ cdef readonly object scoped_start_state
+ cdef readonly object scoped_process_state
+ cdef readonly object scoped_finish_state
+
+ cpdef start(self)
+ cpdef process(self, WindowedValue windowed_value)
+ cpdef finish(self)
+
+ cpdef output(self, WindowedValue windowed_value, int output_index=*)
+
+cdef class ReadOperation(Operation):
+ @cython.locals(windowed_value=WindowedValue)
+ cpdef start(self)
+
+cdef class DoOperation(Operation):
+ cdef object dofn_runner
+ cdef Receiver dofn_receiver
+
+cdef class CombineOperation(Operation):
+ cdef object phased_combine_fn
+
+cdef class FlattenOperation(Operation):
+ pass
+
+cdef class PGBKCVOperation(Operation):
+ cdef public object combine_fn
+ cdef public object combine_fn_add_input
+ cdef dict table
+ cdef long max_keys
+ cdef long key_count
+
+ cpdef output_key(self, tuple wkey, value)
+
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/operations.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py
new file mode 100644
index 0000000..5dbe57e
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -0,0 +1,651 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# cython: profile=True
+
+"""Worker operations executor."""
+
+import collections
+import itertools
+import logging
+
+from apache_beam import pvalue
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsContainer
+from apache_beam.metrics.execution import ScopedMetricsContainer
+from apache_beam.runners import common
+from apache_beam.runners.common import Receiver
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.worker import logger
+from apache_beam.runners.worker import opcounters
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sideinputs
+from apache_beam.transforms import combiners
+from apache_beam.transforms import core
+from apache_beam.transforms import sideinputs as apache_sideinputs
+from apache_beam.transforms.combiners import curry_combine_fn
+from apache_beam.transforms.combiners import PhasedCombineFnExecutor
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.utils.windowed_value import WindowedValue
+
+# Allow some "pure mode" declarations.
+try:
+ import cython
+except ImportError:
+ class FakeCython(object):
+ @staticmethod
+ def cast(type, value):
+ return value
+ globals()['cython'] = FakeCython()
+
+
+_globally_windowed_value = GlobalWindows.windowed_value(None)
+_global_window_type = type(_globally_windowed_value.windows[0])
+
+
+class ConsumerSet(Receiver):
+ """A ConsumerSet represents a graph edge between two Operation nodes.
+
+ The ConsumerSet object collects information from the output of the
+ Operation at one end of its edge and the input of the Operation at
+ the other end.
+ ConsumerSets are attached to the outputting Operation.
+ """
+
+ def __init__(
+ self, counter_factory, step_name, output_index, consumers, coder):
+ self.consumers = consumers
+ self.opcounter = opcounters.OperationCounters(
+ counter_factory, step_name, coder, output_index)
+ # Used in repr.
+ self.step_name = step_name
+ self.output_index = output_index
+ self.coder = coder
+
+ def output(self, windowed_value): # For old SDKs.
+ self.receive(windowed_value)
+
+ def receive(self, windowed_value):
+ self.update_counters_start(windowed_value)
+ for consumer in self.consumers:
+ cython.cast(Operation, consumer).process(windowed_value)
+ self.update_counters_finish()
+
+ def update_counters_start(self, windowed_value):
+ self.opcounter.update_from(windowed_value)
+
+ def update_counters_finish(self):
+ self.opcounter.update_collect()
+
+ def __repr__(self):
+ return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % (
+ self.__class__.__name__, self.step_name, self.output_index, self.coder,
+ len(self.consumers))
+
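+# Minimal wiring sketch (illustrative only; downstream_op, counter_factory
+# and some_coder below are hypothetical placeholders):
+#
+#   edge = ConsumerSet(counter_factory, 'some-step', 0,
+#                      consumers=[downstream_op], coder=some_coder)
+#   edge.receive(GlobalWindows.windowed_value('element'))
+#   # -> downstream_op.process(...) is invoked and the element/size
+#   #    counters for 'some-step' are updated.
+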
+
+class Operation(object):
+ """An operation representing the live version of a work item specification.
+
+ An operation can have one or more outputs and for each output it can have
+ one or more receiver operations that consume it as input.
+ """
+
+ def __init__(self, operation_name, spec, counter_factory, state_sampler):
+ """Initializes a worker operation instance.
+
+ Args:
+ operation_name: The system name assigned by the runner for this
+ operation.
+ spec: An operation_specs.Worker* instance.
+ counter_factory: The CounterFactory to use for our counters.
+ state_sampler: The StateSampler for the current operation.
+ """
+ self.operation_name = operation_name
+ self.spec = spec
+ self.counter_factory = counter_factory
+ self.consumers = collections.defaultdict(list)
+
+ self.state_sampler = state_sampler
+ self.scoped_start_state = self.state_sampler.scoped_state(
+ self.operation_name + '-start')
+ self.scoped_process_state = self.state_sampler.scoped_state(
+ self.operation_name + '-process')
+ self.scoped_finish_state = self.state_sampler.scoped_state(
+ self.operation_name + '-finish')
+ # TODO(ccy): the '-abort' state can be added when the abort is supported in
+ # Operations.
+
+ def start(self):
+ """Start operation."""
+ self.debug_logging_enabled = logging.getLogger().isEnabledFor(
+ logging.DEBUG)
+ # Everything except WorkerSideInputSource, which is not a
+ # top-level operation, should have output_coders
+ if getattr(self.spec, 'output_coders', None):
+ self.receivers = [ConsumerSet(self.counter_factory, self.step_name,
+ i, self.consumers[i], coder)
+ for i, coder in enumerate(self.spec.output_coders)]
+
+ def finish(self):
+ """Finish operation."""
+ pass
+
+ def process(self, o):
+ """Process element in operation."""
+ pass
+
+ def output(self, windowed_value, output_index=0):
+ cython.cast(Receiver, self.receivers[output_index]).receive(windowed_value)
+
+ def add_receiver(self, operation, output_index=0):
+ """Adds a receiver operation for the specified output."""
+ self.consumers[output_index].append(operation)
+
+ def __str__(self):
+ """Generates a useful string for this object.
+
+ Compactly displays interesting fields. In particular, pickled
+ fields are not displayed. Note that we collapse the fields of the
+ contained Worker* object into this object, since there is a 1-1
+ mapping between Operation and operation_specs.Worker*.
+
+ Returns:
+ Compact string representing this object.
+ """
+ return self.str_internal()
+
+ def str_internal(self, is_recursive=False):
+ """Internal helper for __str__ that supports recursion.
+
+ When recursing on receivers, keep the output short.
+ Args:
+ is_recursive: whether to omit some details, particularly receivers.
+ Returns:
+ Compact string representing this object.
+ """
+ printable_name = self.__class__.__name__
+ if hasattr(self, 'step_name'):
+ printable_name += ' %s' % self.step_name
+ if is_recursive:
+ # If we have a step name, stop here, no more detail needed.
+ return '<%s>' % printable_name
+
+ if self.spec is None:
+ printable_fields = []
+ else:
+ printable_fields = operation_specs.worker_printable_fields(self.spec)
+
+ if not is_recursive and getattr(self, 'receivers', []):
+ printable_fields.append('receivers=[%s]' % ', '.join([
+ str(receiver) for receiver in self.receivers]))
+
+ return '<%s %s>' % (printable_name, ', '.join(printable_fields))
+
+
+class ReadOperation(Operation):
+
+ def start(self):
+ with self.scoped_start_state:
+ super(ReadOperation, self).start()
+ range_tracker = self.spec.source.source.get_range_tracker(
+ self.spec.source.start_position, self.spec.source.stop_position)
+ for value in self.spec.source.source.read(range_tracker):
+ if isinstance(value, WindowedValue):
+ windowed_value = value
+ else:
+ windowed_value = _globally_windowed_value.with_value(value)
+ self.output(windowed_value)
+
+
+class InMemoryWriteOperation(Operation):
+ """A write operation that will write to an in-memory sink."""
+
+ def process(self, o):
+ with self.scoped_process_state:
+ if self.debug_logging_enabled:
+ logging.debug('Processing [%s] in %s', o, self)
+ self.spec.output_buffer.append(
+ o if self.spec.write_windowed_values else o.value)
+
+
+class _TaggedReceivers(dict):
+
+ class NullReceiver(Receiver):
+
+ def receive(self, element):
+ pass
+
+ # For old SDKs.
+ def output(self, element):
+ pass
+
+ def __missing__(self, unused_key):
+ if not getattr(self, '_null_receiver', None):
+ self._null_receiver = _TaggedReceivers.NullReceiver()
+ return self._null_receiver
+
+
+class DoOperation(Operation):
+ """A Do operation that will execute a custom DoFn for each input element."""
+
+ def _read_side_inputs(self, tags_and_types):
+ """Generator reading side inputs in the order prescribed by tags_and_types.
+
+ Args:
+ tags_and_types: List of tuples (tag, type). Each side input has a string
+ tag that is specified in the worker instruction. The type is actually
+ a boolean which is True for singleton input (read just first value)
+ and False for collection input (read all values).
+
+ Yields:
+ With each iteration it yields the result of reading an entire side source
+ either in singleton or collection mode according to the tags_and_types
+ argument.
+ """
+ # We will read the side inputs in the order prescribed by the
+ # tags_and_types argument because this is exactly the order needed to
+ # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
+ # getting the side inputs.
+ #
+ # Note that for each tag there could be several read operations in the
+ # specification. This can happen for instance if the source has been
+ # sharded into several files.
+ for side_tag, view_class, view_options in tags_and_types:
+ sources = []
+ # Using the side_tag in the lambda below will trigger a pylint warning.
+ # However in this case it is fine because the lambda is used right away
+ # while the variable has the value assigned by the current iteration of
+ # the for loop.
+ # pylint: disable=cell-var-from-loop
+ for si in itertools.ifilter(
+ lambda o: o.tag == side_tag, self.spec.side_inputs):
+ if not isinstance(si, operation_specs.WorkerSideInputSource):
+ raise NotImplementedError('Unknown side input type: %r' % si)
+ sources.append(si.source)
+ iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
+
+ # Backwards compatibility for pre BEAM-733 SDKs.
+ if isinstance(view_options, tuple):
+ if view_class == pvalue.SingletonPCollectionView:
+ has_default, default = view_options
+ view_options = {'default': default} if has_default else {}
+ else:
+ view_options = {}
+
+ yield apache_sideinputs.SideInputMap(
+ view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
+
+ def start(self):
+ with self.scoped_start_state:
+ super(DoOperation, self).start()
+
+ # See fn_data in dataflow_runner.py
+ fn, args, kwargs, tags_and_types, window_fn = (
+ pickler.loads(self.spec.serialized_fn))
+
+ state = common.DoFnState(self.counter_factory)
+ state.step_name = self.step_name
+
+ # TODO(silviuc): What is the proper label here? PCollection being
+ # processed?
+ context = common.DoFnContext('label', state=state)
+ # Tag to output index map used to dispatch the side output values emitted
+ # by the DoFn function to the appropriate receivers. The main output is
+ # tagged with None and is associated with its corresponding index.
+ tagged_receivers = _TaggedReceivers()
+
+ output_tag_prefix = PropertyNames.OUT + '_'
+ for index, tag in enumerate(self.spec.output_tags):
+ if tag == PropertyNames.OUT:
+ original_tag = None
+ elif tag.startswith(output_tag_prefix):
+ original_tag = tag[len(output_tag_prefix):]
+ else:
+ raise ValueError('Unexpected output name for operation: %s' % tag)
+ tagged_receivers[original_tag] = self.receivers[index]
+
+ self.dofn_runner = common.DoFnRunner(
+ fn, args, kwargs, self._read_side_inputs(tags_and_types),
+ window_fn, context, tagged_receivers,
+ logger, self.step_name,
+ scoped_metrics_container=self.scoped_metrics_container)
+ self.dofn_receiver = (self.dofn_runner
+ if isinstance(self.dofn_runner, Receiver)
+ else DoFnRunnerReceiver(self.dofn_runner))
+
+ self.dofn_runner.start()
+
+ def finish(self):
+ with self.scoped_finish_state:
+ self.dofn_runner.finish()
+
+ def process(self, o):
+ with self.scoped_process_state:
+ self.dofn_receiver.receive(o)
+
+
+class DoFnRunnerReceiver(Receiver):
+
+ def __init__(self, dofn_runner):
+ self.dofn_runner = dofn_runner
+
+ def receive(self, windowed_value):
+ self.dofn_runner.process(windowed_value)
+
+
+class CombineOperation(Operation):
+ """A Combine operation executing a CombineFn for each input element."""
+
+ def __init__(self, operation_name, spec, counter_factory, state_sampler):
+ super(CombineOperation, self).__init__(
+ operation_name, spec, counter_factory, state_sampler)
+ # Combiners do not accept deferred side-inputs (the ignored fourth argument)
+ # and therefore the code to handle the extra args/kwargs is simpler than for
+ # the DoFns of ParDo.
+ fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
+ self.phased_combine_fn = (
+ PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
+
+ def finish(self):
+ logging.debug('Finishing %s', self)
+
+ def process(self, o):
+ if self.debug_logging_enabled:
+ logging.debug('Processing [%s] in %s', o, self)
+ key, values = o.value
+ with self.scoped_metrics_container:
+ self.output(
+ o.with_value((key, self.phased_combine_fn.apply(values))))
+
+
+def create_pgbk_op(step_name, spec, counter_factory, state_sampler):
+ if spec.combine_fn:
+ return PGBKCVOperation(step_name, spec, counter_factory, state_sampler)
+ else:
+ return PGBKOperation(step_name, spec, counter_factory, state_sampler)
+
+
+class PGBKOperation(Operation):
+ """Partial group-by-key operation.
+
+ This takes (windowed) input (key, value) tuples and outputs
+ (key, [value]) tuples, performing a best effort group-by-key for
+ values in this bundle, memory permitting.
+ """
+
+ def __init__(self, operation_name, spec, counter_factory, state_sampler):
+ super(PGBKOperation, self).__init__(
+ operation_name, spec, counter_factory, state_sampler)
+ assert not self.spec.combine_fn
+ self.table = collections.defaultdict(list)
+ self.size = 0
+ # TODO(robertwb) Make this configurable.
+ self.max_size = 10 * 1000
+
+ def process(self, o):
+ # TODO(robertwb): Structural (hashable) values.
+ key = o.value[0], tuple(o.windows)
+ self.table[key].append(o)
+ self.size += 1
+ if self.size > self.max_size:
+ self.flush(9 * self.max_size // 10)
+
+ def finish(self):
+ self.flush(0)
+
+ def flush(self, target):
+ limit = self.size - target
+ for ix, (kw, vs) in enumerate(self.table.items()):
+ if ix >= limit:
+ break
+ del self.table[kw]
+ key, windows = kw
+ output_value = [v.value[1] for v in vs]
+ windowed_value = WindowedValue(
+ (key, output_value),
+ vs[0].timestamp, windows)
+ self.output(windowed_value)
+
+
+class PGBKCVOperation(Operation):
+
+ def __init__(self, operation_name, spec, counter_factory, state_sampler):
+ super(PGBKCVOperation, self).__init__(
+ operation_name, spec, counter_factory, state_sampler)
+ # Combiners do not accept deferred side-inputs (the ignored fourth
+ # argument) and therefore the code to handle the extra args/kwargs is
+ # simpler than for the DoFns of ParDo.
+ fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
+ self.combine_fn = curry_combine_fn(fn, args, kwargs)
+ if (getattr(fn.add_input, 'im_func', None)
+ is core.CombineFn.add_input.im_func):
+ # Old versions of the SDK have CombineFns that don't implement add_input.
+ self.combine_fn_add_input = (
+ lambda a, e: self.combine_fn.add_inputs(a, [e]))
+ else:
+ self.combine_fn_add_input = self.combine_fn.add_input
+ # Optimization for the (known tiny accumulator, often wide keyspace)
+ # combine functions.
+ # TODO(b/36567833): Bound by in-memory size rather than key count.
+ self.max_keys = (
+ 1000 * 1000 if
+ isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
+ # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
+ # combiners to the short list above.
+ (isinstance(fn, core.CallableWrapperCombineFn) and
+ fn._fn in (min, max, sum)) else 100 * 1000) # pylint: disable=protected-access
+ self.key_count = 0
+ self.table = {}
+
+ def process(self, wkv):
+ key, value = wkv.value
+ # pylint: disable=unidiomatic-typecheck
+ # Optimization for the global window case.
+ if len(wkv.windows) == 1 and type(wkv.windows[0]) is _global_window_type:
+ wkey = 0, key
+ else:
+ wkey = tuple(wkv.windows), key
+ entry = self.table.get(wkey, None)
+ if entry is None:
+ if self.key_count >= self.max_keys:
+ target = self.key_count * 9 // 10
+ old_wkeys = []
+ # TODO(robertwb): Use an LRU cache?
+ for old_wkey, old_wvalue in self.table.iteritems():
+ old_wkeys.append(old_wkey) # Can't mutate while iterating.
+ self.output_key(old_wkey, old_wvalue[0])
+ self.key_count -= 1
+ if self.key_count <= target:
+ break
+ for old_wkey in reversed(old_wkeys):
+ del self.table[old_wkey]
+ self.key_count += 1
+ # We save the accumulator as a one element list so we can efficiently
+ # mutate when new values are added without searching the cache again.
+ entry = self.table[wkey] = [self.combine_fn.create_accumulator()]
+ entry[0] = self.combine_fn_add_input(entry[0], value)
+
+ def finish(self):
+ for wkey, value in self.table.iteritems():
+ self.output_key(wkey, value[0])
+ self.table = {}
+ self.key_count = 0
+
+ def output_key(self, wkey, value):
+ windows, key = wkey
+ if windows is 0:
+ self.output(_globally_windowed_value.with_value((key, value)))
+ else:
+ self.output(WindowedValue((key, value), windows[0].end, windows))
+
+
+class FlattenOperation(Operation):
+ """Flatten operation.
+
+ Receives elements from one or more producer operations and forwards
+ them all to a single output.
+ """
+
+ def process(self, o):
+ if self.debug_logging_enabled:
+ logging.debug('Processing [%s] in %s', o, self)
+ self.output(o)
+
+
+def create_operation(operation_name, spec, counter_factory, step_name,
+ state_sampler, test_shuffle_source=None,
+ test_shuffle_sink=None, is_streaming=False):
+ """Create Operation object for given operation specification."""
+ if isinstance(spec, operation_specs.WorkerRead):
+ if isinstance(spec.source, iobase.SourceBundle):
+ op = ReadOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ else:
+ from dataflow_worker.native_operations import NativeReadOperation
+ op = NativeReadOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerWrite):
+ from dataflow_worker.native_operations import NativeWriteOperation
+ op = NativeWriteOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerCombineFn):
+ op = CombineOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerPartialGroupByKey):
+ op = create_pgbk_op(operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerDoFn):
+ op = DoOperation(operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerGroupingShuffleRead):
+ from dataflow_worker.shuffle_operations import GroupedShuffleReadOperation
+ op = GroupedShuffleReadOperation(
+ operation_name, spec, counter_factory, state_sampler,
+ shuffle_source=test_shuffle_source)
+ elif isinstance(spec, operation_specs.WorkerUngroupedShuffleRead):
+ from dataflow_worker.shuffle_operations import UngroupedShuffleReadOperation
+ op = UngroupedShuffleReadOperation(
+ operation_name, spec, counter_factory, state_sampler,
+ shuffle_source=test_shuffle_source)
+ elif isinstance(spec, operation_specs.WorkerInMemoryWrite):
+ op = InMemoryWriteOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerShuffleWrite):
+ from dataflow_worker.shuffle_operations import ShuffleWriteOperation
+ op = ShuffleWriteOperation(
+ operation_name, spec, counter_factory, state_sampler,
+ shuffle_sink=test_shuffle_sink)
+ elif isinstance(spec, operation_specs.WorkerFlatten):
+ op = FlattenOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerMergeWindows):
+ from dataflow_worker.shuffle_operations import BatchGroupAlsoByWindowsOperation
+ from dataflow_worker.shuffle_operations import StreamingGroupAlsoByWindowsOperation
+ if is_streaming:
+ op = StreamingGroupAlsoByWindowsOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ else:
+ op = BatchGroupAlsoByWindowsOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ elif isinstance(spec, operation_specs.WorkerReifyTimestampAndWindows):
+ from dataflow_worker.shuffle_operations import ReifyTimestampAndWindowsOperation
+ op = ReifyTimestampAndWindowsOperation(
+ operation_name, spec, counter_factory, state_sampler)
+ else:
+ raise TypeError('Expected an instance of operation_specs.Worker* class '
+ 'instead of %s' % (spec,))
+ op.step_name = step_name
+ op.metrics_container = MetricsContainer(step_name)
+ op.scoped_metrics_container = ScopedMetricsContainer(op.metrics_container)
+ return op
+
+
+class SimpleMapTaskExecutor(object):
+ """An executor for map tasks.
+
+ Stores progress of the read operation that is the first operation of a map
+ task.
+ """
+
+ def __init__(
+ self, map_task, counter_factory, state_sampler,
+ test_shuffle_source=None, test_shuffle_sink=None):
+ """Initializes SimpleMapTaskExecutor.
+
+ Args:
+ map_task: The map task we are to run.
+ counter_factory: The CounterFactory instance for the work item.
+ state_sampler: The StateSampler tracking the execution step.
+ test_shuffle_source: Used during tests for dependency injection into
+ shuffle read operation objects.
+ test_shuffle_sink: Used during tests for dependency injection into
+ shuffle write operation objects.
+ """
+
+ self._map_task = map_task
+ self._counter_factory = counter_factory
+ self._ops = []
+ self._state_sampler = state_sampler
+ self._test_shuffle_source = test_shuffle_source
+ self._test_shuffle_sink = test_shuffle_sink
+
+ def operations(self):
+ return self._ops[:]
+
+ def execute(self):
+ """Executes all the operation_specs.Worker* instructions in a map task.
+
+ We update the map_task with the execution status, expressed as counters.
+
+ Raises:
+ RuntimeError: if we find more than one read instruction in the task spec.
+ TypeError: if the spec parameter is not an instance of the recognized
+ operation_specs.Worker* classes.
+ """
+
+ # operations is a list of operation_specs.Worker* instances.
+ # The order of the elements is important because the inputs use
+ # list indexes as references.
+
+ step_names = (
+ self._map_task.step_names or [None] * len(self._map_task.operations))
+ for ix, spec in enumerate(self._map_task.operations):
+ # This is used for logging and assigning names to counters.
+ operation_name = self._map_task.system_names[ix]
+ step_name = step_names[ix]
+ op = create_operation(
+ operation_name, spec, self._counter_factory, step_name,
+ self._state_sampler,
+ test_shuffle_source=self._test_shuffle_source,
+ test_shuffle_sink=self._test_shuffle_sink)
+ self._ops.append(op)
+
+ # Add receiver operations to the appropriate producers.
+ if hasattr(op.spec, 'input'):
+ producer, output_index = op.spec.input
+ self._ops[producer].add_receiver(op, output_index)
+ # Flatten has 'inputs', not 'input'
+ if hasattr(op.spec, 'inputs'):
+ for producer, output_index in op.spec.inputs:
+ self._ops[producer].add_receiver(op, output_index)
+
+ for ix, op in reversed(list(enumerate(self._ops))):
+ logging.debug('Starting op %d %s', ix, op)
+ with op.scoped_metrics_container:
+ op.start()
+ for op in self._ops:
+ with op.scoped_metrics_container:
+ op.finish()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
new file mode 100644
index 0000000..6907f6e
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -0,0 +1,451 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK harness for executing Python Fns via the Fn API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+import traceback
+import zlib
+
+import dill
+from google.protobuf import wrappers_pb2
+
+from apache_beam.coders import coder_impl
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.internal import pickler
+from apache_beam.runners.dataflow.native_io import iobase
+from apache_beam.utils import counters
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+ from apache_beam.runners.worker import statesampler
+except ImportError:
+ from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory
+
+
+DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1'
+DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1'
+IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1'
+PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1'
+PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1'
+# TODO(vikasrk): Fix this once runner sends appropriate python urns.
+PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1'
+PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'
+
+
+class RunnerIOOperation(operations.Operation):
+ """Common baseclass for runner harness IO operations."""
+
+ def __init__(self, operation_name, step_name, consumers, counter_factory,
+ state_sampler, windowed_coder, target, data_channel):
+ super(RunnerIOOperation, self).__init__(
+ operation_name, None, counter_factory, state_sampler)
+ self.windowed_coder = windowed_coder
+ self.step_name = step_name
+ # target represents the consumer for the bytes in the data plane for a
+ # DataInputOperation or a producer of these bytes for a DataOutputOperation.
+ self.target = target
+ self.data_channel = data_channel
+ for _, consumer_ops in consumers.items():
+ for consumer in consumer_ops:
+ self.add_receiver(consumer, 0)
+
+
+class DataOutputOperation(RunnerIOOperation):
+ """A sink-like operation that gathers outputs to be sent back to the runner.
+ """
+
+ def set_output_stream(self, output_stream):
+ self.output_stream = output_stream
+
+ def process(self, windowed_value):
+ self.windowed_coder.get_impl().encode_to_stream(
+ windowed_value, self.output_stream, True)
+
+ def finish(self):
+ self.output_stream.close()
+ super(DataOutputOperation, self).finish()
+
+
+class DataInputOperation(RunnerIOOperation):
+ """A source-like operation that gathers input from the runner.
+ """
+
+ def __init__(self, operation_name, step_name, consumers, counter_factory,
+ state_sampler, windowed_coder, input_target, data_channel):
+ super(DataInputOperation, self).__init__(
+ operation_name, step_name, consumers, counter_factory, state_sampler,
+ windowed_coder, target=input_target, data_channel=data_channel)
+ # We must do this manually as we don't have a spec or spec.output_coders.
+ self.receivers = [
+ operations.ConsumerSet(self.counter_factory, self.step_name, 0,
+ consumers.itervalues().next(),
+ self.windowed_coder)]
+
+ def process(self, windowed_value):
+ self.output(windowed_value)
+
+ def process_encoded(self, encoded_windowed_values):
+ input_stream = coder_impl.create_InputStream(encoded_windowed_values)
+ while input_stream.size() > 0:
+ decoded_value = self.windowed_coder.get_impl().decode_from_stream(
+ input_stream, True)
+ self.output(decoded_value)
+
+
+# TODO(robertwb): Revise side input API to not be in terms of native sources.
+# This will enable lookups, but there's an open question as to how to handle
+# custom sources without forcing intermediate materialization. This seems very
+# related to the desire to inject key and window preserving [Splittable]DoFns
+# into the view computation.
+class SideInputSource(iobase.NativeSource, iobase.NativeSourceReader):
+ """A 'source' for reading side inputs via state API calls.
+ """
+
+ def __init__(self, state_handler, state_key, coder):
+ self._state_handler = state_handler
+ self._state_key = state_key
+ self._coder = coder
+
+ def reader(self):
+ return self
+
+ @property
+ def returns_windowed_values(self):
+ return True
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exn_info):
+ pass
+
+ def __iter__(self):
+ # TODO(robertwb): Support pagination.
+ input_stream = coder_impl.create_InputStream(
+ self._state_handler.Get(self._state_key).data)
+ while input_stream.size() > 0:
+ yield self._coder.get_impl().decode_from_stream(input_stream, True)
+
+
+def unpack_and_deserialize_py_fn(function_spec):
+ """Returns unpacked and deserialized object from function spec proto."""
+ return pickler.loads(unpack_function_spec_data(function_spec))
+
+
+def unpack_function_spec_data(function_spec):
+ """Returns unpacked data from function spec proto."""
+ data = wrappers_pb2.BytesValue()
+ function_spec.data.Unpack(data)
+ return data.value
+
+
+# pylint: disable=redefined-builtin
+def serialize_and_pack_py_fn(fn, urn, id=None):
+ """Returns serialized and packed function in a function spec proto."""
+ return pack_function_spec_data(pickler.dumps(fn), urn, id)
+# pylint: enable=redefined-builtin
+
+
+# pylint: disable=redefined-builtin
+def pack_function_spec_data(value, urn, id=None):
+ """Returns packed data in a function spec proto."""
+ data = wrappers_pb2.BytesValue(value=value)
+ fn_proto = beam_fn_api_pb2.FunctionSpec(urn=urn)
+ fn_proto.data.Pack(data)
+ if id:
+ fn_proto.id = id
+ return fn_proto
+# pylint: enable=redefined-builtin
+
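+# Illustrative round trip (a sketch, not part of this change): packing a
+# pickled Python function into a FunctionSpec proto and unpacking it again.
+#
+#   spec = serialize_and_pack_py_fn(lambda x: x + 1, PYTHON_DOFN_URN, id='f1')
+#   fn = unpack_and_deserialize_py_fn(spec)
+#   assert fn(41) == 42
+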
+
+# TODO(vikasrk): move this method to ``coders.py`` in the SDK.
+def load_compressed(compressed_data):
+ """Returns a decompressed and deserialized python object."""
+ # Note: SDK uses ``pickler.dumps`` to serialize certain python objects
+ # (like sources), which involves serialization, compression and base64
+ # encoding. We cannot directly use ``pickler.loads`` for
+ # deserialization, as the runner would have already base64 decoded the
+ # data. So we only need to decompress and deserialize.
+
+ data = zlib.decompress(compressed_data)
+ try:
+ return dill.loads(data)
+ except Exception: # pylint: disable=broad-except
+ dill.dill._trace(True) # pylint: disable=protected-access
+ return dill.loads(data)
+ finally:
+ dill.dill._trace(False) # pylint: disable=protected-access
+
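+# Sketch of the round trip described above (illustrative only): the SDK
+# pickles, compresses and base64-encodes; the runner base64-decodes before
+# handing the payload to the harness.
+#
+#   import base64
+#   encoded = pickler.dumps(['some', 'source'])   # pickle + zlib + base64
+#   raw = base64.b64decode(encoded)               # as received from the runner
+#   assert load_compressed(raw) == ['some', 'source']
+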
+
+class SdkHarness(object):
+
+ def __init__(self, control_channel):
+ self._control_channel = control_channel
+ self._data_channel_factory = GrpcClientDataChannelFactory()
+
+ def run(self):
+ control_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel)
+ # TODO(robertwb): Wire up to new state api.
+ state_stub = None
+ self.worker = SdkWorker(state_stub, self._data_channel_factory)
+
+ responses = queue.Queue()
+ no_more_work = object()
+
+ def get_responses():
+ while True:
+ response = responses.get()
+ if response is no_more_work:
+ return
+ yield response
+
+ def process_requests():
+ for work_request in control_stub.Control(get_responses()):
+ logging.info('Got work %s', work_request.instruction_id)
+ try:
+ response = self.worker.do_instruction(work_request)
+ except Exception: # pylint: disable=broad-except
+ response = beam_fn_api_pb2.InstructionResponse(
+ instruction_id=work_request.instruction_id,
+ error=traceback.format_exc())
+ responses.put(response)
+ t = threading.Thread(target=process_requests)
+ t.start()
+ t.join()
+ # get_responses may be blocked on responses.get(), but we need to return
+ # control to its caller.
+ responses.put(no_more_work)
+ self._data_channel_factory.close()
+ logging.info('Done consuming work.')
+
+
+class SdkWorker(object):
+
+ def __init__(self, state_handler, data_channel_factory):
+ self.fns = {}
+ self.state_handler = state_handler
+ self.data_channel_factory = data_channel_factory
+
+ def do_instruction(self, request):
+ request_type = request.WhichOneof('request')
+ if request_type:
+ # E.g. if register is set, this will construct
+ # InstructionResponse(register=self.register(request.register))
+ return beam_fn_api_pb2.InstructionResponse(**{
+ 'instruction_id': request.instruction_id,
+ request_type: getattr(self, request_type)
+ (getattr(request, request_type), request.instruction_id)
+ })
+ else:
+ raise NotImplementedError
+
+ def register(self, request, unused_instruction_id=None):
+ for process_bundle_descriptor in request.process_bundle_descriptor:
+ self.fns[process_bundle_descriptor.id] = process_bundle_descriptor
+ for p_transform in list(process_bundle_descriptor.primitive_transform):
+ self.fns[p_transform.function_spec.id] = p_transform.function_spec
+ return beam_fn_api_pb2.RegisterResponse()
+
+ def initial_source_split(self, request, unused_instruction_id=None):
+ source_spec = self.fns[request.source_reference]
+ assert source_spec.urn == PYTHON_SOURCE_URN
+ source_bundle = unpack_and_deserialize_py_fn(
+ self.fns[request.source_reference])
+ splits = source_bundle.source.split(request.desired_bundle_size_bytes,
+ source_bundle.start_position,
+ source_bundle.stop_position)
+ response = beam_fn_api_pb2.InitialSourceSplitResponse()
+ response.splits.extend([
+ beam_fn_api_pb2.SourceSplit(
+ source=serialize_and_pack_py_fn(split, PYTHON_SOURCE_URN),
+ relative_size=split.weight,
+ )
+ for split in splits
+ ])
+ return response
+
+ def create_execution_tree(self, descriptor):
+ # TODO(vikasrk): Add an id field to Coder proto and use that instead.
+ coders = {coder.function_spec.id: operation_specs.get_coder_from_spec(
+ json.loads(unpack_function_spec_data(coder.function_spec)))
+ for coder in descriptor.coders}
+
+ counter_factory = counters.CounterFactory()
+ # TODO(robertwb): Figure out the correct prefix to use for output counters
+ # from StateSampler.
+ state_sampler = statesampler.StateSampler(
+ 'fnapi-step%s-' % descriptor.id, counter_factory)
+ consumers = collections.defaultdict(lambda: collections.defaultdict(list))
+ ops_by_id = {}
+ reversed_ops = []
+
+ for transform in reversed(descriptor.primitive_transform):
+ # TODO(robertwb): Figure out how to plumb through the operation name (e.g.
+ # "s3") from the service through the FnAPI so that msec counters can be
+ # reported and correctly plumbed through the service and the UI.
+ operation_name = 'fnapis%s' % transform.id
+
+ def only_element(iterable):
+ element, = iterable
+ return element
+
+ if transform.function_spec.urn == DATA_OUTPUT_URN:
+ target = beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform.id,
+ name=only_element(transform.outputs.keys()))
+
+ op = DataOutputOperation(
+ operation_name,
+ transform.step_name,
+ consumers[transform.id],
+ counter_factory,
+ state_sampler,
+ coders[only_element(transform.outputs.values()).coder_reference],
+ target,
+ self.data_channel_factory.create_data_channel(
+ transform.function_spec))
+
+ elif transform.function_spec.urn == DATA_INPUT_URN:
+ target = beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform.id,
+ name=only_element(transform.inputs.keys()))
+ op = DataInputOperation(
+ operation_name,
+ transform.step_name,
+ consumers[transform.id],
+ counter_factory,
+ state_sampler,
+ coders[only_element(transform.outputs.values()).coder_reference],
+ target,
+ self.data_channel_factory.create_data_channel(
+ transform.function_spec))
+
+ elif transform.function_spec.urn == PYTHON_DOFN_URN:
+ def create_side_input(tag, si):
+ # TODO(robertwb): Extract windows (and keys) out of element data.
+ return operation_specs.WorkerSideInputSource(
+ tag=tag,
+ source=SideInputSource(
+ self.state_handler,
+ beam_fn_api_pb2.StateKey(
+ function_spec_reference=si.view_fn.id),
+ coder=unpack_and_deserialize_py_fn(si.view_fn)))
+ output_tags = list(transform.outputs.keys())
+ spec = operation_specs.WorkerDoFn(
+ serialized_fn=unpack_function_spec_data(transform.function_spec),
+ output_tags=output_tags,
+ input=None,
+ side_inputs=[create_side_input(tag, si)
+ for tag, si in transform.side_inputs.items()],
+ output_coders=[coders[transform.outputs[out].coder_reference]
+ for out in output_tags])
+
+ op = operations.DoOperation(operation_name, spec, counter_factory,
+ state_sampler)
+ # TODO(robertwb): Move these to the constructor.
+ op.step_name = transform.step_name
+ for tag, op_consumers in consumers[transform.id].items():
+ for consumer in op_consumers:
+ op.add_receiver(
+ consumer, output_tags.index(tag))
+
+ elif transform.function_spec.urn == IDENTITY_DOFN_URN:
+ op = operations.FlattenOperation(operation_name, None, counter_factory,
+ state_sampler)
+ # TODO(robertwb): Move these to the constructor.
+ op.step_name = transform.step_name
+ for tag, op_consumers in consumers[transform.id].items():
+ for consumer in op_consumers:
+ op.add_receiver(consumer, 0)
+
+ elif transform.function_spec.urn == PYTHON_SOURCE_URN:
+ source = load_compressed(unpack_function_spec_data(
+ transform.function_spec))
+ # TODO(vikasrk): Remove this once custom source is implemented with
+ # splittable dofn via the data plane.
+ spec = operation_specs.WorkerRead(
+ iobase.SourceBundle(1.0, source, None, None),
+ [WindowedValueCoder(source.default_output_coder())])
+ op = operations.ReadOperation(operation_name, spec, counter_factory,
+ state_sampler)
+ op.step_name = transform.step_name
+ output_tags = list(transform.outputs.keys())
+ for tag, op_consumers in consumers[transform.id].items():
+ for consumer in op_consumers:
+ op.add_receiver(
+ consumer, output_tags.index(tag))
+
+ else:
+ raise NotImplementedError
+
+ # Record consumers.
+ for _, inputs in transform.inputs.items():
+ for target in inputs.target:
+ consumers[target.primitive_transform_reference][target.name].append(
+ op)
+
+ reversed_ops.append(op)
+ ops_by_id[transform.id] = op
+
+ return list(reversed(reversed_ops)), ops_by_id
+
+ def process_bundle(self, request, instruction_id):
+ ops, ops_by_id = self.create_execution_tree(
+ self.fns[request.process_bundle_descriptor_reference])
+
+ expected_inputs = []
+ for _, op in ops_by_id.items():
+ if isinstance(op, DataOutputOperation):
+ # TODO(robertwb): Is there a better way to pass the instruction id to
+ # the operation?
+ op.set_output_stream(op.data_channel.output_stream(
+ instruction_id, op.target))
+ elif isinstance(op, DataInputOperation):
+ # We must wait until we receive "end of stream" for each of these ops.
+ expected_inputs.append(op)
+
+ # Start all operations.
+ for op in reversed(ops):
+ logging.info('start %s', op)
+ op.start()
+
+ # Inject inputs from data plane.
+ for input_op in expected_inputs:
+ for data in input_op.data_channel.input_elements(
+ instruction_id, [input_op.target]):
+ # The input name on the target is ignored here.
+ target_op = ops_by_id[data.target.primitive_transform_reference]
+ # Non-input ops lack a coder here, so target_op is expected to be a
+ # DataInputOperation.
+ target_op.process_encoded(data.data)
+
+ # Finish all operations.
+ for op in ops:
+ logging.info('finish %s', op)
+ op.finish()
+
+ return beam_fn_api_pb2.ProcessBundleResponse()
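
For orientation, an SdkWorker (constructed with a state handler and a DataChannelFactory, as the FnApiRunner's DirectController does later in this change) is driven through the method above roughly as follows. This is only a sketch: the instruction ids are made up, and `worker` and `descriptor` stand in for a real SdkWorker and ProcessBundleDescriptor.

    from apache_beam.runners.api import beam_fn_api_pb2

    # Sketch only: 'worker' is an sdk_worker.SdkWorker and 'descriptor' is a
    # beam_fn_api_pb2.ProcessBundleDescriptor; both are placeholders here.
    worker.do_instruction(beam_fn_api_pb2.InstructionRequest(
        instruction_id='register-1',
        register=beam_fn_api_pb2.RegisterRequest(
            process_bundle_descriptor=[descriptor])))
    response = worker.do_instruction(beam_fn_api_pb2.InstructionRequest(
        instruction_id='bundle-1',
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_reference=descriptor.id)))

Registration only makes the descriptor known to the worker; the actual work, including opening the data-plane streams above, happens while the ProcessBundleRequest is handled.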
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
new file mode 100644
index 0000000..28828c3
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK Fn Harness entry point."""
+
+import logging
+import os
+import sys
+
+import grpc
+from google.protobuf import text_format
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler
+from apache_beam.runners.worker.sdk_worker import SdkHarness
+
+
+def main(unused_argv):
+ """Main entry point for SDK Fn Harness."""
+ logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+ text_format.Merge(os.environ['LOGGING_API_SERVICE_DESCRIPTOR'],
+ logging_service_descriptor)
+
+ # Send all logs to the runner.
+ fn_log_handler = FnApiLogRecordHandler(logging_service_descriptor)
+ # TODO(vikasrk): This should be picked up from pipeline options.
+ logging.getLogger().setLevel(logging.INFO)
+ logging.getLogger().addHandler(fn_log_handler)
+
+ try:
+ logging.info('Python sdk harness started.')
+ service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+ text_format.Merge(os.environ['CONTROL_API_SERVICE_DESCRIPTOR'],
+ service_descriptor)
+ # TODO(robertwb): Support credentials.
+ assert not service_descriptor.oauth2_client_credentials_grant.url
+ channel = grpc.insecure_channel(service_descriptor.url)
+ SdkHarness(channel).run()
+ logging.info('Python sdk harness exiting.')
+ except: # pylint: disable=broad-except
+ logging.exception('Python sdk harness failed: ')
+ raise
+ finally:
+ fn_log_handler.close()
+
+
+if __name__ == '__main__':
+ main(sys.argv)
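
Both service descriptors are read from the environment in protobuf text format, so the harness can be pointed at a runner by hand. A minimal sketch, where the endpoint URLs are placeholders for wherever the runner actually serves its logging and control APIs:

    import os
    import sys

    from apache_beam.runners.worker import sdk_worker_main

    # Placeholder endpoints; substitute the runner's real addresses.
    os.environ['LOGGING_API_SERVICE_DESCRIPTOR'] = 'url: "localhost:50051"'
    os.environ['CONTROL_API_SERVICE_DESCRIPTOR'] = 'url: "localhost:50052"'
    sdk_worker_main.main(sys.argv)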
[4/4] beam git commit: Closes #2644
Posted by ro...@apache.org.
Closes #2644
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/7c425b09
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/7c425b09
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/7c425b09
Branch: refs/heads/master
Commit: 7c425b097fdadcc160cdaa7d8992416b991e37c4
Parents: b8131fe a856fcf
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Mon May 1 17:44:30 2017 -0700
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Mon May 1 17:44:30 2017 -0700
----------------------------------------------------------------------
.../apache_beam/runners/portability/__init__.py | 16 +
.../runners/portability/fn_api_runner.py | 471 ++++++++++++++
.../runners/portability/fn_api_runner_test.py | 40 ++
.../portability/maptask_executor_runner.py | 468 +++++++++++++
.../portability/maptask_executor_runner_test.py | 204 ++++++
.../apache_beam/runners/worker/__init__.py | 16 +
.../apache_beam/runners/worker/data_plane.py | 288 ++++++++
.../runners/worker/data_plane_test.py | 139 ++++
.../apache_beam/runners/worker/log_handler.py | 100 +++
.../runners/worker/log_handler_test.py | 105 +++
.../apache_beam/runners/worker/logger.pxd | 25 +
.../python/apache_beam/runners/worker/logger.py | 173 +++++
.../apache_beam/runners/worker/logger_test.py | 182 ++++++
.../apache_beam/runners/worker/opcounters.pxd | 45 ++
.../apache_beam/runners/worker/opcounters.py | 162 +++++
.../runners/worker/opcounters_test.py | 149 +++++
.../runners/worker/operation_specs.py | 368 +++++++++++
.../apache_beam/runners/worker/operations.pxd | 89 +++
.../apache_beam/runners/worker/operations.py | 651 +++++++++++++++++++
.../apache_beam/runners/worker/sdk_worker.py | 451 +++++++++++++
.../runners/worker/sdk_worker_main.py | 62 ++
.../runners/worker/sdk_worker_test.py | 168 +++++
.../apache_beam/runners/worker/sideinputs.py | 166 +++++
.../runners/worker/sideinputs_test.py | 150 +++++
.../apache_beam/runners/worker/statesampler.pyx | 237 +++++++
.../runners/worker/statesampler_fake.py | 34 +
.../runners/worker/statesampler_test.py | 102 +++
sdks/python/generate_pydoc.sh | 2 +
sdks/python/setup.py | 8 +-
29 files changed, 5069 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
[3/4] beam git commit: Fn API support for Python.
Posted by ro...@apache.org.
Fn API support for Python.
Also added supporting worker code and a runner
that exercises this code directly.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a856fcf3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a856fcf3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a856fcf3
Branch: refs/heads/master
Commit: a856fcf3f9ca48530078e1d7610e234110cd5890
Parents: b8131fe
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Thu Apr 20 15:59:25 2017 -0500
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Mon May 1 17:44:29 2017 -0700
----------------------------------------------------------------------
.../apache_beam/runners/portability/__init__.py | 16 +
.../runners/portability/fn_api_runner.py | 471 ++++++++++++++
.../runners/portability/fn_api_runner_test.py | 40 ++
.../portability/maptask_executor_runner.py | 468 +++++++++++++
.../portability/maptask_executor_runner_test.py | 204 ++++++
.../apache_beam/runners/worker/__init__.py | 16 +
.../apache_beam/runners/worker/data_plane.py | 288 ++++++++
.../runners/worker/data_plane_test.py | 139 ++++
.../apache_beam/runners/worker/log_handler.py | 100 +++
.../runners/worker/log_handler_test.py | 105 +++
.../apache_beam/runners/worker/logger.pxd | 25 +
.../python/apache_beam/runners/worker/logger.py | 173 +++++
.../apache_beam/runners/worker/logger_test.py | 182 ++++++
.../apache_beam/runners/worker/opcounters.pxd | 45 ++
.../apache_beam/runners/worker/opcounters.py | 162 +++++
.../runners/worker/opcounters_test.py | 149 +++++
.../runners/worker/operation_specs.py | 368 +++++++++++
.../apache_beam/runners/worker/operations.pxd | 89 +++
.../apache_beam/runners/worker/operations.py | 651 +++++++++++++++++++
.../apache_beam/runners/worker/sdk_worker.py | 451 +++++++++++++
.../runners/worker/sdk_worker_main.py | 62 ++
.../runners/worker/sdk_worker_test.py | 168 +++++
.../apache_beam/runners/worker/sideinputs.py | 166 +++++
.../runners/worker/sideinputs_test.py | 150 +++++
.../apache_beam/runners/worker/statesampler.pyx | 237 +++++++
.../runners/worker/statesampler_fake.py | 34 +
.../runners/worker/statesampler_test.py | 102 +++
sdks/python/generate_pydoc.sh | 2 +
sdks/python/setup.py | 8 +-
29 files changed, 5069 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/__init__.py b/sdks/python/apache_beam/runners/portability/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
new file mode 100644
index 0000000..5802c17
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -0,0 +1,471 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A PipelineRunner using the SDK harness.
+"""
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+
+import grpc
+from concurrent import futures
+
+import apache_beam as beam
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.coders.coder_impl import create_InputStream
+from apache_beam.coders.coder_impl import create_OutputStream
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sdk_worker
+
+
+def streaming_rpc_handler(cls, method_name):
+ """Un-inverts the flow of control between the runner and the sdk harness."""
+
+ class StreamingRpcHandler(cls):
+
+ _DONE = object()
+
+ def __init__(self):
+ self._push_queue = queue.Queue()
+ self._pull_queue = queue.Queue()
+ setattr(self, method_name, self.run)
+ self._read_thread = threading.Thread(target=self._read)
+
+ def run(self, iterator, context):
+ self._inputs = iterator
+ # Note: We only support one client for now.
+ self._read_thread.start()
+ while True:
+ to_push = self._push_queue.get()
+ if to_push is self._DONE:
+ return
+ yield to_push
+
+ def _read(self):
+ for data in self._inputs:
+ self._pull_queue.put(data)
+
+ def push(self, item):
+ self._push_queue.put(item)
+
+ def pull(self, timeout=None):
+ return self._pull_queue.get(timeout=timeout)
+
+ def empty(self):
+ return self._pull_queue.empty()
+
+ def done(self):
+ self.push(self._DONE)
+ self._read_thread.join()
+
+ return StreamingRpcHandler()
+
+
+class OldeSourceSplittableDoFn(beam.DoFn):
+ """A DoFn that reads and emits an entire source.
+ """
+
+ # TODO(robertwb): Make this a full SDF with progress splitting, etc.
+ def process(self, source):
+ if isinstance(source, iobase.SourceBundle):
+ for value in source.source.read(source.source.get_range_tracker(
+ source.start_position, source.stop_position)):
+ yield value
+ else:
+ # Dataflow native source
+ with source.reader() as reader:
+ for value in reader:
+ yield value
+
+
+# See DataflowRunner._pardo_fn_data
+OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
+ (OldeSourceSplittableDoFn(), (), {}, [],
+ beam.transforms.core.Windowing(GlobalWindows())))
+
+
+class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
+
+ def __init__(self):
+ super(FnApiRunner, self).__init__()
+ self._last_uid = -1
+
+ def has_metrics_support(self):
+ return False
+
+ def _next_uid(self):
+ self._last_uid += 1
+ return str(self._last_uid)
+
+ def _map_task_registration(self, map_task, state_handler,
+ data_operation_spec):
+ input_data = {}
+ runner_sinks = {}
+ transforms = []
+ transform_index_to_id = {}
+
+ # Maps coders to new coder objects and references.
+ coders = {}
+
+ def coder_id(coder):
+ if coder not in coders:
+ coders[coder] = beam_fn_api_pb2.Coder(
+ function_spec=sdk_worker.pack_function_spec_data(
+ json.dumps(coder.as_cloud_object()),
+ sdk_worker.PYTHON_CODER_URN, id=self._next_uid()))
+
+ return coders[coder].function_spec.id
+
+ def output_tags(op):
+ return getattr(op, 'output_tags', ['out'])
+
+ def as_target(op_input):
+ input_op_index, input_output_index = op_input
+ input_op = map_task[input_op_index][1]
+ return {
+ 'ignored_input_tag':
+ beam_fn_api_pb2.Target.List(target=[
+ beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_index_to_id[
+ input_op_index],
+ name=output_tags(input_op)[input_output_index])
+ ])
+ }
+
+ def outputs(op):
+ return {
+ tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
+ for tag, coder in zip(output_tags(op), op.output_coders)
+ }
+
+ for op_ix, (stage_name, operation) in enumerate(map_task):
+ transform_id = transform_index_to_id[op_ix] = self._next_uid()
+ if isinstance(operation, operation_specs.WorkerInMemoryWrite):
+ # Write this data back to the runner.
+ fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_OUTPUT_URN,
+ id=self._next_uid())
+ if data_operation_spec:
+ fn.data.Pack(data_operation_spec)
+ inputs = as_target(operation.input)
+ side_inputs = {}
+ runner_sinks[(transform_id, 'out')] = operation
+
+ elif isinstance(operation, operation_specs.WorkerRead):
+ # A Read is either translated to a direct injection of windowed values
+ # into the sdk worker, or an injection of the source object into the
+ # sdk worker as data followed by an SDF that reads that source.
+ if (isinstance(operation.source.source,
+ maptask_executor_runner.InMemorySource)
+ and isinstance(operation.source.source.default_output_coder(),
+ WindowedValueCoder)):
+ output_stream = create_OutputStream()
+ element_coder = (
+ operation.source.source.default_output_coder().get_impl())
+ # Re-encode the elements in the nested context and
+ # concatenate them together
+ for element in operation.source.source.read(None):
+ element_coder.encode_to_stream(element, output_stream, True)
+ target_name = self._next_uid()
+ input_data[(transform_id, target_name)] = output_stream.get()
+ fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN,
+ id=self._next_uid())
+ if data_operation_spec:
+ fn.data.Pack(data_operation_spec)
+ inputs = {target_name: beam_fn_api_pb2.Target.List()}
+ side_inputs = {}
+ else:
+ # Read the source object from the runner.
+ source_coder = beam.coders.DillCoder()
+ input_transform_id = self._next_uid()
+ output_stream = create_OutputStream()
+ source_coder.get_impl().encode_to_stream(
+ GlobalWindows.windowed_value(operation.source),
+ output_stream,
+ True)
+ target_name = self._next_uid()
+ input_data[(input_transform_id, target_name)] = output_stream.get()
+ input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
+ id=input_transform_id,
+ function_spec=beam_fn_api_pb2.FunctionSpec(
+ urn=sdk_worker.DATA_INPUT_URN,
+ id=self._next_uid()),
+ # TODO(robertwb): Possible name collision.
+ step_name=stage_name + '/inject_source',
+ inputs={target_name: beam_fn_api_pb2.Target.List()},
+ outputs={
+ 'out':
+ beam_fn_api_pb2.PCollection(
+ coder_reference=coder_id(source_coder))
+ })
+ if data_operation_spec:
+ input_ptransform.function_spec.data.Pack(data_operation_spec)
+ transforms.append(input_ptransform)
+
+ # Read the elements out of the source.
+ fn = sdk_worker.pack_function_spec_data(
+ OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
+ sdk_worker.PYTHON_DOFN_URN,
+ id=self._next_uid())
+ inputs = {
+ 'ignored_input_tag':
+ beam_fn_api_pb2.Target.List(target=[
+ beam_fn_api_pb2.Target(
+ primitive_transform_reference=input_transform_id,
+ name='out')
+ ])
+ }
+ side_inputs = {}
+
+ elif isinstance(operation, operation_specs.WorkerDoFn):
+ fn = sdk_worker.pack_function_spec_data(
+ operation.serialized_fn,
+ sdk_worker.PYTHON_DOFN_URN,
+ id=self._next_uid())
+ inputs = as_target(operation.input)
+ # Store the contents of each side input for state access.
+ for si in operation.side_inputs:
+ assert isinstance(si.source, iobase.BoundedSource)
+ element_coder = si.source.default_output_coder()
+ view_id = self._next_uid()
+ # TODO(robertwb): Actually flesh out the ViewFn API.
+ side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
+ view_fn=sdk_worker.serialize_and_pack_py_fn(
+ element_coder, urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
+ id=view_id))
+ # Re-encode the elements in the nested context and
+ # concatenate them together
+ output_stream = create_OutputStream()
+ for element in si.source.read(
+ si.source.get_range_tracker(None, None)):
+ element_coder.get_impl().encode_to_stream(
+ element, output_stream, True)
+ elements_data = output_stream.get()
+ state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id)
+ state_handler.Clear(state_key)
+ state_handler.Append(
+ beam_fn_api_pb2.SimpleStateAppendRequest(
+ state_key=state_key, data=[elements_data]))
+
+ elif isinstance(operation, operation_specs.WorkerFlatten):
+ fn = sdk_worker.pack_function_spec_data(
+ operation.serialized_fn,
+ sdk_worker.IDENTITY_DOFN_URN,
+ id=self._next_uid())
+ inputs = {
+ 'ignored_input_tag':
+ beam_fn_api_pb2.Target.List(target=[
+ beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_index_to_id[
+ input_op_index],
+ name=output_tags(map_task[input_op_index][1])[
+ input_output_index])
+ for input_op_index, input_output_index in operation.inputs
+ ])
+ }
+ side_inputs = {}
+
+ else:
+ raise TypeError(operation)
+
+ ptransform = beam_fn_api_pb2.PrimitiveTransform(
+ id=transform_id,
+ function_spec=fn,
+ step_name=stage_name,
+ inputs=inputs,
+ side_inputs=side_inputs,
+ outputs=outputs(operation))
+ transforms.append(ptransform)
+
+ process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
+ id=self._next_uid(), coders=coders.values(),
+ primitive_transform=transforms)
+ return beam_fn_api_pb2.InstructionRequest(
+ instruction_id=self._next_uid(),
+ register=beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=[process_bundle_descriptor
+ ])), runner_sinks, input_data
+
+ def _run_map_task(
+ self, map_task, control_handler, state_handler, data_plane_handler,
+ data_operation_spec):
+ registration, sinks, input_data = self._map_task_registration(
+ map_task, state_handler, data_operation_spec)
+ control_handler.push(registration)
+ process_bundle = beam_fn_api_pb2.InstructionRequest(
+ instruction_id=self._next_uid(),
+ process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+ process_bundle_descriptor_reference=registration.register.
+ process_bundle_descriptor[0].id))
+
+ for (transform_id, name), elements in input_data.items():
+ data_out = data_plane_handler.output_stream(
+ process_bundle.instruction_id, beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_id, name=name))
+ data_out.write(elements)
+ data_out.close()
+
+ control_handler.push(process_bundle)
+ while True:
+ result = control_handler.pull()
+ if result.instruction_id == process_bundle.instruction_id:
+ if result.error:
+ raise RuntimeError(result.error)
+ expected_targets = [
+ beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+ name=output_name)
+ for (transform_id, output_name), _ in sinks.items()]
+ for output in data_plane_handler.input_elements(
+ process_bundle.instruction_id, expected_targets):
+ target_tuple = (
+ output.target.primitive_transform_reference, output.target.name)
+ if target_tuple not in sinks:
+ # Unconsumed output.
+ continue
+ sink_op = sinks[target_tuple]
+ coder = sink_op.output_coders[0]
+ input_stream = create_InputStream(output.data)
+ elements = []
+ while input_stream.size() > 0:
+ elements.append(coder.get_impl().decode_from_stream(
+ input_stream, True))
+ if not sink_op.write_windowed_values:
+ elements = [e.value for e in elements]
+ for e in elements:
+ sink_op.output_buffer.append(e)
+ return
+
+ def execute_map_tasks(self, ordered_map_tasks, direct=True):
+ if direct:
+ controller = FnApiRunner.DirectController()
+ else:
+ controller = FnApiRunner.GrpcController()
+
+ try:
+ for _, map_task in ordered_map_tasks:
+ logging.info('Running %s', map_task)
+ self._run_map_task(
+ map_task, controller.control_handler, controller.state_handler,
+ controller.data_plane_handler, controller.data_operation_spec())
+ finally:
+ controller.close()
+
+ class SimpleState(object): # TODO(robertwb): Inherit from GRPC servicer.
+
+ def __init__(self):
+ self._all = collections.defaultdict(list)
+
+ def Get(self, state_key):
+ return beam_fn_api_pb2.Elements.Data(
+ data=''.join(self._all[self._to_key(state_key)]))
+
+ def Append(self, append_request):
+ self._all[self._to_key(append_request.state_key)].extend(
+ append_request.data)
+
+ def Clear(self, state_key):
+ try:
+ del self._all[self._to_key(state_key)]
+ except KeyError:
+ pass
+
+ @staticmethod
+ def _to_key(state_key):
+ return (state_key.function_spec_reference, state_key.window,
+ state_key.key)
+
+ class DirectController(object):
+ """An in-memory controller for fn API control, state and data planes."""
+
+ def __init__(self):
+ self._responses = []
+ self.state_handler = FnApiRunner.SimpleState()
+ self.control_handler = self
+ self.data_plane_handler = data_plane.InMemoryDataChannel()
+ self.worker = sdk_worker.SdkWorker(
+ self.state_handler, data_plane.InMemoryDataChannelFactory(
+ self.data_plane_handler.inverse()))
+
+ def push(self, request):
+ logging.info('CONTROL REQUEST %s', request)
+ response = self.worker.do_instruction(request)
+ logging.info('CONTROL RESPONSE %s', response)
+ self._responses.append(response)
+
+ def pull(self):
+ return self._responses.pop(0)
+
+ def done(self):
+ pass
+
+ def close(self):
+ pass
+
+ def data_operation_spec(self):
+ return None
+
+ class GrpcController(object):
+ """An grpc based controller for fn API control, state and data planes."""
+
+ def __init__(self):
+ self.state_handler = FnApiRunner.SimpleState()
+ self.control_server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=10))
+ self.control_port = self.control_server.add_insecure_port('[::]:0')
+
+ self.data_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ self.data_port = self.data_server.add_insecure_port('[::]:0')
+
+ self.control_handler = streaming_rpc_handler(
+ beam_fn_api_pb2.BeamFnControlServicer, 'Control')
+ beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
+ self.control_handler, self.control_server)
+
+ self.data_plane_handler = data_plane.GrpcServerDataChannel()
+ beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+ self.data_plane_handler, self.data_server)
+
+ logging.info('starting control server on port %s', self.control_port)
+ logging.info('starting data server on port %s', self.data_port)
+ self.data_server.start()
+ self.control_server.start()
+
+ self.worker = sdk_worker.SdkHarness(
+ grpc.insecure_channel('localhost:%s' % self.control_port))
+ self.worker_thread = threading.Thread(target=self.worker.run)
+ logging.info('starting worker')
+ self.worker_thread.start()
+
+ def data_operation_spec(self):
+ url = 'localhost:%s' % self.data_port
+ remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+ remote_grpc_port.api_service_descriptor.url = url
+ return remote_grpc_port
+
+ def close(self):
+ self.control_handler.done()
+ self.worker_thread.join()
+ self.data_plane_handler.close()
+ self.control_server.stop(5).wait()
+ self.data_server.stop(5).wait()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
new file mode 100644
index 0000000..633602f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.runners.portability import fn_api_runner
+from apache_beam.runners.portability import maptask_executor_runner_test
+
+
+class FnApiRunnerTest(maptask_executor_runner_test.MapTaskExecutorRunnerTest):
+
+ def create_pipeline(self):
+ return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
+
+ def test_combine_per_key(self):
+ # TODO(robertwb): Implement PGBKCV operation.
+ pass
+
+ # Inherits all tests from maptask_executor_runner_test.MapTaskExecutorRunnerTest
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
new file mode 100644
index 0000000..d273e18
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
@@ -0,0 +1,468 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Beam runner for testing/profiling worker code directly.
+"""
+
+import collections
+import logging
+import time
+
+import apache_beam as beam
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.runners import DataflowRunner
+from apache_beam.runners.dataflow.internal.dependency import _dependency_file_copy
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.runners.runner import PipelineRunner
+from apache_beam.runners.runner import PipelineState
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+ from apache_beam.runners.worker import statesampler
+except ImportError:
+ from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.typehints import typehints
+from apache_beam.utils import pipeline_options
+from apache_beam.utils import profiler
+from apache_beam.utils.counters import CounterFactory
+
+
+class MapTaskExecutorRunner(PipelineRunner):
+ """Beam runner translating a pipeline into map tasks that are then executed.
+
+ Primarily intended for testing and profiling the worker code paths.
+ """
+
+ def __init__(self):
+ self.executors = []
+
+ def has_metrics_support(self):
+ """Returns whether this runner supports metrics or not.
+ """
+ return False
+
+ def run(self, pipeline):
+ MetricsEnvironment.set_metrics_supported(self.has_metrics_support())
+ # List of map tasks. Each map task is a list of
+ # (stage_name, operation_specs.WorkerOperation) instructions.
+ self.map_tasks = []
+
+ # Map of pvalues to
+ # (map_task_index, producer_operation_index, producer_output_index)
+ self.outputs = {}
+
+ # Unique mappings of PCollections to strings.
+ self.side_input_labels = collections.defaultdict(
+ lambda: str(len(self.side_input_labels)))
+
+ # Mapping of map task indices to all map tasks that must precede them.
+ self.dependencies = collections.defaultdict(set)
+
+ # Visit the graph, building up the map_tasks and their metadata.
+ super(MapTaskExecutorRunner, self).run(pipeline)
+
+ # Now run the tasks in topological order.
+ def compute_depth_map(deps):
+ memoized = {}
+
+ def compute_depth(x):
+ if x not in memoized:
+ memoized[x] = 1 + max([-1] + [compute_depth(y) for y in deps[x]])
+ return memoized[x]
+
+ return {x: compute_depth(x) for x in deps.keys()}
+
+ map_task_depths = compute_depth_map(self.dependencies)
+ ordered_map_tasks = sorted((map_task_depths.get(ix, -1), map_task)
+ for ix, map_task in enumerate(self.map_tasks))
+
+ profile_options = pipeline.options.view_as(
+ pipeline_options.ProfilingOptions)
+ if profile_options.profile_cpu:
+ with profiler.Profile(
+ profile_id='worker-runner',
+ profile_location=profile_options.profile_location,
+ log_results=True, file_copy_fn=_dependency_file_copy):
+ self.execute_map_tasks(ordered_map_tasks)
+ else:
+ self.execute_map_tasks(ordered_map_tasks)
+
+ return WorkerRunnerResult(PipelineState.UNKNOWN)
+
+ def metrics_containers(self):
+ return [op.metrics_container
+ for ex in self.executors
+ for op in ex.operations()]
+
+ def execute_map_tasks(self, ordered_map_tasks):
+ tt = time.time()
+ for ix, (_, map_task) in enumerate(ordered_map_tasks):
+ logging.info('Running %s', map_task)
+ t = time.time()
+ stage_names, all_operations = zip(*map_task)
+ # TODO(robertwb): The DataflowRunner worker receives system step names
+ # (e.g. "s3") that are used to label the output msec counters. We use the
+ # operation names here, but this is not the same scheme used by the
+ # DataflowRunner; the result is that the output msec counters are named
+ # differently.
+ system_names = stage_names
+ # Create the CounterFactory and StateSampler for this MapTask.
+ # TODO(robertwb): Output counters produced here are currently ignored.
+ counter_factory = CounterFactory()
+ state_sampler = statesampler.StateSampler('%s-' % ix, counter_factory)
+ map_executor = operations.SimpleMapTaskExecutor(
+ operation_specs.MapTask(
+ all_operations, 'S%02d' % ix,
+ system_names, stage_names, system_names),
+ counter_factory,
+ state_sampler)
+ self.executors.append(map_executor)
+ map_executor.execute()
+ logging.info(
+ 'Stage %s finished: %0.3f sec', stage_names[0], time.time() - t)
+ logging.info('Total time: %0.3f sec', time.time() - tt)
+
+ def run_Read(self, transform_node):
+ self._run_read_from(transform_node, transform_node.transform.source)
+
+ def _run_read_from(self, transform_node, source):
+ """Used when this operation is the result of reading source."""
+ if not isinstance(source, iobase.NativeSource):
+ source = iobase.SourceBundle(1.0, source, None, None)
+ output = transform_node.outputs[None]
+ element_coder = self._get_coder(output)
+ read_op = operation_specs.WorkerRead(source, output_coders=[element_coder])
+ self.outputs[output] = len(self.map_tasks), 0, 0
+ self.map_tasks.append([(transform_node.full_label, read_op)])
+ return len(self.map_tasks) - 1
+
+ def run_ParDo(self, transform_node):
+ transform = transform_node.transform
+ output = transform_node.outputs[None]
+ element_coder = self._get_coder(output)
+ map_task_index, producer_index, output_index = self.outputs[
+ transform_node.inputs[0]]
+
+ # If any of this ParDo's side inputs depend on outputs from this map_task,
+ # we can't continue growing this map task.
+ def is_reachable(leaf, root):
+ if leaf == root:
+ return True
+ else:
+ return any(is_reachable(x, root) for x in self.dependencies[leaf])
+
+ if any(is_reachable(self.outputs[side_input.pvalue][0], map_task_index)
+ for side_input in transform_node.side_inputs):
+ # Start a new map task.
+ input_element_coder = self._get_coder(transform_node.inputs[0])
+
+ output_buffer = OutputBuffer(input_element_coder)
+
+ fusion_break_write = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=True,
+ input=(producer_index, output_index),
+ output_coders=[input_element_coder])
+ self.map_tasks[map_task_index].append(
+ (transform_node.full_label + '/Write', fusion_break_write))
+
+ original_map_task_index = map_task_index
+ map_task_index, producer_index, output_index = len(self.map_tasks), 0, 0
+
+ fusion_break_read = operation_specs.WorkerRead(
+ output_buffer.source_bundle(),
+ output_coders=[input_element_coder])
+ self.map_tasks.append(
+ [(transform_node.full_label + '/Read', fusion_break_read)])
+
+ self.dependencies[map_task_index].add(original_map_task_index)
+
+ def create_side_read(side_input):
+ label = self.side_input_labels[side_input]
+ output_buffer = self.run_side_write(
+ side_input.pvalue, '%s/%s' % (transform_node.full_label, label))
+ return operation_specs.WorkerSideInputSource(
+ output_buffer.source(), label)
+
+ do_op = operation_specs.WorkerDoFn( #
+ serialized_fn=pickler.dumps(DataflowRunner._pardo_fn_data(
+ transform_node,
+ lambda side_input: self.side_input_labels[side_input])),
+ output_tags=[PropertyNames.OUT] + ['%s_%s' % (PropertyNames.OUT, tag)
+ for tag in transform.output_tags
+ ],
+ # Same assumption that DataflowRunner has about coders being compatible
+ # across outputs.
+ output_coders=[element_coder] * (len(transform.output_tags) + 1),
+ input=(producer_index, output_index),
+ side_inputs=[create_side_read(side_input)
+ for side_input in transform_node.side_inputs])
+
+ producer_index = len(self.map_tasks[map_task_index])
+ self.outputs[transform_node.outputs[None]] = (
+ map_task_index, producer_index, 0)
+ for ix, tag in enumerate(transform.output_tags):
+ self.outputs[transform_node.outputs[
+ tag]] = map_task_index, producer_index, ix + 1
+ self.map_tasks[map_task_index].append((transform_node.full_label, do_op))
+
+ for side_input in transform_node.side_inputs:
+ self.dependencies[map_task_index].add(self.outputs[side_input.pvalue][0])
+
+ def run_side_write(self, pcoll, label):
+ map_task_index, producer_index, output_index = self.outputs[pcoll]
+
+ windowed_element_coder = self._get_coder(pcoll)
+ output_buffer = OutputBuffer(windowed_element_coder)
+ write_sideinput_op = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=True,
+ input=(producer_index, output_index),
+ output_coders=[windowed_element_coder])
+ self.map_tasks[map_task_index].append(
+ (label, write_sideinput_op))
+ return output_buffer
+
+ def run_GroupByKeyOnly(self, transform_node):
+ map_task_index, producer_index, output_index = self.outputs[
+ transform_node.inputs[0]]
+ grouped_element_coder = self._get_coder(transform_node.outputs[None],
+ windowed=False)
+ windowed_ungrouped_element_coder = self._get_coder(transform_node.inputs[0])
+
+ output_buffer = GroupingOutputBuffer(grouped_element_coder)
+ shuffle_write = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=False,
+ input=(producer_index, output_index),
+ output_coders=[windowed_ungrouped_element_coder])
+ self.map_tasks[map_task_index].append(
+ (transform_node.full_label + '/Write', shuffle_write))
+
+ output_map_task_index = self._run_read_from(
+ transform_node, output_buffer.source())
+ self.dependencies[output_map_task_index].add(map_task_index)
+
+ def run_Flatten(self, transform_node):
+ output_buffer = OutputBuffer(self._get_coder(transform_node.outputs[None]))
+ output_map_task = self._run_read_from(transform_node,
+ output_buffer.source())
+
+ for input in transform_node.inputs:
+ map_task_index, producer_index, output_index = self.outputs[input]
+ element_coder = self._get_coder(input)
+ flatten_write = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=True,
+ input=(producer_index, output_index),
+ output_coders=[element_coder])
+ self.map_tasks[map_task_index].append(
+ (transform_node.full_label + '/Write', flatten_write))
+ self.dependencies[output_map_task].add(map_task_index)
+
+ def apply_CombinePerKey(self, transform, input):
+ # TODO(robertwb): Support side inputs.
+ assert not transform.args and not transform.kwargs
+ return (input
+ | PartialGroupByKeyCombineValues(transform.fn)
+ | beam.GroupByKey()
+ | MergeAccumulators(transform.fn)
+ | ExtractOutputs(transform.fn))
+
+ def run_PartialGroupByKeyCombineValues(self, transform_node):
+ element_coder = self._get_coder(transform_node.outputs[None])
+ _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+ combine_op = operation_specs.WorkerPartialGroupByKey(
+ combine_fn=pickler.dumps(
+ (transform_node.transform.combine_fn, (), {}, ())),
+ output_coders=[element_coder],
+ input=(producer_index, output_index))
+ self._run_as_op(transform_node, combine_op)
+
+ def run_MergeAccumulators(self, transform_node):
+ self._run_combine_transform(transform_node, 'merge')
+
+ def run_ExtractOutputs(self, transform_node):
+ self._run_combine_transform(transform_node, 'extract')
+
+ def _run_combine_transform(self, transform_node, phase):
+ transform = transform_node.transform
+ element_coder = self._get_coder(transform_node.outputs[None])
+ _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+ combine_op = operation_specs.WorkerCombineFn(
+ serialized_fn=pickler.dumps(
+ (transform.combine_fn, (), {}, ())),
+ phase=phase,
+ output_coders=[element_coder],
+ input=(producer_index, output_index))
+ self._run_as_op(transform_node, combine_op)
+
+ def _get_coder(self, pvalue, windowed=True):
+ # TODO(robertwb): This should be an attribute of the pvalue itself.
+ return DataflowRunner._get_coder(
+ pvalue.element_type or typehints.Any,
+ pvalue.windowing.windowfn.get_window_coder() if windowed else None)
+
+ def _run_as_op(self, transform_node, op):
+ """Single-output operation in the same map task as its input."""
+ map_task_index, _, _ = self.outputs[transform_node.inputs[0]]
+ op_index = len(self.map_tasks[map_task_index])
+ output = transform_node.outputs[None]
+ self.outputs[output] = map_task_index, op_index, 0
+ self.map_tasks[map_task_index].append((transform_node.full_label, op))
+
+
+class InMemorySource(iobase.BoundedSource):
+ """Source for reading an (as-yet unwritten) set of in-memory encoded elements.
+ """
+
+ def __init__(self, encoded_elements, coder):
+ self._encoded_elements = encoded_elements
+ self._coder = coder
+
+ def get_range_tracker(self, unused_start_position, unused_end_position):
+ return None
+
+ def read(self, unused_range_tracker):
+ for encoded_element in self._encoded_elements:
+ yield self._coder.decode(encoded_element)
+
+ def default_output_coder(self):
+ return self._coder
+
+
+class OutputBuffer(object):
+
+ def __init__(self, coder):
+ self.coder = coder
+ self.elements = []
+ self.encoded_elements = []
+
+ def source(self):
+ return InMemorySource(self.encoded_elements, self.coder)
+
+ def source_bundle(self):
+ return iobase.SourceBundle(
+ 1.0, InMemorySource(self.encoded_elements, self.coder), None, None)
+
+ def __repr__(self):
+ return 'OutputBuffer[%r]' % len(self.elements)
+
+ def append(self, value):
+ self.elements.append(value)
+ self.encoded_elements.append(self.coder.encode(value))
+
+
+class GroupingOutputBuffer(object):
+
+ def __init__(self, grouped_coder):
+ self.grouped_coder = grouped_coder
+ self.elements = collections.defaultdict(list)
+ self.frozen = False
+
+ def source(self):
+ return InMemorySource(self.encoded_elements, self.grouped_coder)
+
+ def __repr__(self):
+ return 'GroupingOutputBuffer[%r]' % len(self.elements)
+
+ def append(self, pair):
+ assert not self.frozen
+ k, v = pair
+ self.elements[k].append(v)
+
+ def freeze(self):
+ if not self.frozen:
+ self._encoded_elements = [self.grouped_coder.encode(kv)
+ for kv in self.elements.iteritems()]
+ self.frozen = True
+ return self._encoded_elements
+
+ @property
+ def encoded_elements(self):
+ return GroupedOutputBuffer(self)
+
+
+class GroupedOutputBuffer(object):
+
+ def __init__(self, buffer):
+ self.buffer = buffer
+
+ def __getitem__(self, ix):
+ return self.buffer.freeze()[ix]
+
+ def __iter__(self):
+ return iter(self.buffer.freeze())
+
+ def __len__(self):
+ return len(self.buffer.freeze())
+
+ def __nonzero__(self):
+ return True
+
+
+class PartialGroupByKeyCombineValues(beam.PTransform):
+
+ def __init__(self, combine_fn, native=True):
+ self.combine_fn = combine_fn
+ self.native = native
+
+ def expand(self, input):
+ if self.native:
+ return beam.pvalue.PCollection(input.pipeline)
+ else:
+ def to_accumulator(v):
+ return self.combine_fn.add_input(
+ self.combine_fn.create_accumulator(), v)
+ return input | beam.Map(lambda (k, v): (k, to_accumulator(v)))
+
+
+class MergeAccumulators(beam.PTransform):
+
+ def __init__(self, combine_fn, native=True):
+ self.combine_fn = combine_fn
+ self.native = native
+
+ def expand(self, input):
+ if self.native:
+ return beam.pvalue.PCollection(input.pipeline)
+ else:
+ merge_accumulators = self.combine_fn.merge_accumulators
+ return input | beam.Map(lambda (k, vs): (k, merge_accumulators(vs)))
+
+
+class ExtractOutputs(beam.PTransform):
+
+ def __init__(self, combine_fn, native=True):
+ self.combine_fn = combine_fn
+ self.native = native
+
+ def expand(self, input):
+ if self.native:
+ return beam.pvalue.PCollection(input.pipeline)
+ else:
+ extract_output = self.combine_fn.extract_output
+ return input | beam.Map(lambda (k, v): (k, extract_output(v)))
+
+
+class WorkerRunnerResult(PipelineResult):
+
+ def wait_until_finish(self, duration=None):
+ pass
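
apply_CombinePerKey above rewrites a CombinePerKey into the lifted form PartialGroupByKeyCombineValues | GroupByKey | MergeAccumulators | ExtractOutputs. As a worked sketch of what those phases compute, assuming MeanCombineFn's (sum, count) accumulator, which is what the ('a', 1.5) expectation in test_combine_per_key below relies on:

    import apache_beam as beam

    fn = beam.combiners.MeanCombineFn()
    # Before the GroupByKey: partial per-element accumulation under key 'a'.
    acc1 = fn.add_input(fn.create_accumulator(), 1)   # (1, 1)
    acc2 = fn.add_input(fn.create_accumulator(), 2)   # (2, 1)
    # After the GroupByKey: merge the accumulators grouped under 'a' ...
    merged = fn.merge_accumulators([acc1, acc2])      # (3, 2)
    # ... then extract the final per-key output.
    result = fn.extract_output(merged)                # 1.5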
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
new file mode 100644
index 0000000..6e13e73
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import tempfile
+import unittest
+
+import apache_beam as beam
+
+from apache_beam.metrics import Metrics
+from apache_beam.metrics.execution import MetricKey
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics.metricbase import MetricName
+
+from apache_beam.pvalue import AsList
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import BeamAssertException
+from apache_beam.transforms.util import equal_to
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.runners.portability import maptask_executor_runner
+
+
+class MapTaskExecutorRunnerTest(unittest.TestCase):
+
+ def create_pipeline(self):
+ return beam.Pipeline(runner=maptask_executor_runner.MapTaskExecutorRunner())
+
+ def test_assert_that(self):
+ with self.assertRaises(BeamAssertException):
+ with self.create_pipeline() as p:
+ assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+
+ def test_create(self):
+ with self.create_pipeline() as p:
+ assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
+
+ def test_pardo(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create(['a', 'bc'])
+ | beam.Map(lambda e: e * 2)
+ | beam.Map(lambda e: e + 'x'))
+ assert_that(res, equal_to(['aax', 'bcbcx']))
+
+ def test_pardo_metrics(self):
+
+ class MyDoFn(beam.DoFn):
+
+ def start_bundle(self):
+ self.count = Metrics.counter(self.__class__, 'elements')
+
+ def process(self, element):
+ self.count.inc(element)
+ return [element]
+
+ class MyOtherDoFn(beam.DoFn):
+
+ def start_bundle(self):
+ self.count = Metrics.counter(self.__class__, 'elementsplusone')
+
+ def process(self, element):
+ self.count.inc(element + 1)
+ return [element]
+
+ with self.create_pipeline() as p:
+ res = (p | beam.Create([1, 2, 3])
+ | 'mydofn' >> beam.ParDo(MyDoFn())
+ | 'myotherdofn' >> beam.ParDo(MyOtherDoFn()))
+ p.run()
+ if not MetricsEnvironment.METRICS_SUPPORTED:
+ self.skipTest('Metrics are not supported.')
+
+ counter_updates = [{'key': key, 'value': val}
+ for container in p.runner.metrics_containers()
+ for key, val in
+ container.get_updates().counters.items()]
+ counter_values = [update['value'] for update in counter_updates]
+ counter_keys = [update['key'] for update in counter_updates]
+ assert_that(res, equal_to([1, 2, 3]))
+ self.assertEqual(counter_values, [6, 9])
+ self.assertEqual(counter_keys, [
+ MetricKey('mydofn',
+ MetricName(__name__ + '.MyDoFn', 'elements')),
+ MetricKey('myotherdofn',
+ MetricName(__name__ + '.MyOtherDoFn', 'elementsplusone'))])
+
+ def test_pardo_side_outputs(self):
+ def tee(elem, *tags):
+ for tag in tags:
+ if tag in elem:
+ yield beam.pvalue.OutputValue(tag, elem)
+ with self.create_pipeline() as p:
+ xy = (p
+ | 'Create' >> beam.Create(['x', 'y', 'xy'])
+ | beam.FlatMap(tee, 'x', 'y').with_outputs())
+ assert_that(xy.x, equal_to(['x', 'xy']), label='x')
+ assert_that(xy.y, equal_to(['y', 'xy']), label='y')
+
+ def test_pardo_side_and_main_outputs(self):
+ def even_odd(elem):
+ yield elem
+ yield beam.pvalue.OutputValue('odd' if elem % 2 else 'even', elem)
+ with self.create_pipeline() as p:
+ ints = p | beam.Create([1, 2, 3])
+ named = ints | 'named' >> beam.FlatMap(
+ even_odd).with_outputs('even', 'odd', main='all')
+ assert_that(named.all, equal_to([1, 2, 3]), label='named.all')
+ assert_that(named.even, equal_to([2]), label='named.even')
+ assert_that(named.odd, equal_to([1, 3]), label='named.odd')
+
+ unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
+ unnamed[None] | beam.Map(id) # pylint: disable=expression-not-assigned
+ assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all')
+ assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
+ assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
+
+ def test_pardo_side_inputs(self):
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+ with self.create_pipeline() as p:
+ main = p | 'main' >> beam.Create(['a', 'b', 'c'])
+ side = p | 'side' >> beam.Create(['x', 'y'])
+ assert_that(main | beam.FlatMap(cross_product, AsList(side)),
+ equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
+ ('a', 'y'), ('b', 'y'), ('c', 'y')]))
+
+ def test_pardo_unfusable_side_inputs(self):
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+ with self.create_pipeline() as p:
+ pcoll = p | beam.Create(['a', 'b'])
+ assert_that(pcoll | beam.FlatMap(cross_product, AsList(pcoll)),
+ equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+ with self.create_pipeline() as p:
+ pcoll = p | beam.Create(['a', 'b'])
+ derived = ((pcoll,) | beam.Flatten()
+ | beam.Map(lambda x: (x, x))
+ | beam.GroupByKey()
+ | 'Unkey' >> beam.Map(lambda (x, _): x))
+ assert_that(
+ pcoll | beam.FlatMap(cross_product, AsList(derived)),
+ equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+ def test_group_by_key(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+ | beam.GroupByKey()
+ | beam.Map(lambda (k, vs): (k, sorted(vs))))
+ assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
+
+ def test_flatten(self):
+ with self.create_pipeline() as p:
+ res = (p | 'a' >> beam.Create(['a']),
+ p | 'bc' >> beam.Create(['b', 'c']),
+ p | 'd' >> beam.Create(['d'])) | beam.Flatten()
+ assert_that(res, equal_to(['a', 'b', 'c', 'd']))
+
+ def test_combine_per_key(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+ | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
+ assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
+
+ def test_read(self):
+ with tempfile.NamedTemporaryFile() as temp_file:
+ temp_file.write('a\nb\nc')
+ temp_file.flush()
+ with self.create_pipeline() as p:
+ assert_that(p | beam.io.ReadFromText(temp_file.name),
+ equal_to(['a', 'b', 'c']))
+
+ def test_windowing(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create([1, 2, 100, 101, 102])
+ | beam.Map(lambda t: TimestampedValue(('k', t), t))
+ | beam.WindowInto(beam.transforms.window.Sessions(10))
+ | beam.GroupByKey()
+ | beam.Map(lambda (k, vs): (k, sorted(vs))))
+ assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/__init__.py b/sdks/python/apache_beam/runners/worker/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
new file mode 100644
index 0000000..6425447
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implementation of DataChannels for communicating across the data plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import logging
+import Queue as queue
+import threading
+
+from apache_beam.coders import coder_impl
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class ClosableOutputStream(type(coder_impl.create_OutputStream())):
+ """A Outputstream for use with CoderImpls that has a close() method."""
+
+ def __init__(self, close_callback=None):
+ super(ClosableOutputStream, self).__init__()
+ self._close_callback = close_callback
+
+ def close(self):
+ if self._close_callback:
+ self._close_callback(self.get())
+
+
+class DataChannel(object):
+ """Represents a channel for reading and writing data over the data plane.
+
+ Read from this channel with the input_elements method::
+
+ for elements_data in data_channel.input_elements(instruction_id, targets):
+ [process elements_data]
+
+ Write to this channel using the output_stream method::
+
+ out1 = data_channel.output_stream(instruction_id, target1)
+ out1.write(...)
+ out1.close()
+
+ When all data for all instructions is written, close the channel::
+
+ data_channel.close()
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def input_elements(self, instruction_id, expected_targets):
+ """Returns an iterable of all Element.Data bundles for instruction_id.
+
+ This iterable terminates only once the full set of data has been received
+ for each of the expected targets. It may block waiting for more data.
+
+ Args:
+ instruction_id: which instruction the results must belong to
+ expected_targets: which targets to wait on for completion
+ """
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def output_stream(self, instruction_id, target):
+ """Returns an output stream writing elements to target.
+
+ Args:
+ instruction_id: which instruction this stream belongs to
+ target: the target of the returned stream
+ """
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def close(self):
+ """Closes this channel, indicating that all data has been written.
+
+ Data can continue to be read.
+
+ If this channel is shared by many instructions, it should only be called on
+ worker shutdown.
+ """
+ raise NotImplementedError(type(self))
+
+
+class InMemoryDataChannel(DataChannel):
+ """An in-memory implementation of a DataChannel.
+
+ This channel is two-sided. What is written to one side is read by the other.
+ The inverse() method returns the other side of this instance.
+ """
+
+ def __init__(self, inverse=None):
+ self._inputs = []
+ self._inverse = inverse or InMemoryDataChannel(self)
+
+ def inverse(self):
+ return self._inverse
+
+ def input_elements(self, instruction_id, unused_expected_targets=None):
+ for data in self._inputs:
+ if data.instruction_reference == instruction_id:
+ yield data
+
+ def output_stream(self, instruction_id, target):
+ def add_to_inverse_output(data):
+ self._inverse._inputs.append( # pylint: disable=protected-access
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference=instruction_id,
+ target=target,
+ data=data))
+ return ClosableOutputStream(add_to_inverse_output)
+
+ def close(self):
+ pass
+
+
+class _GrpcDataChannel(DataChannel):
+ """Base class for implementing a BeamFnData-based DataChannel."""
+
+ _WRITES_FINISHED = object()
+
+ def __init__(self):
+ self._to_send = queue.Queue()
+ self._received = collections.defaultdict(queue.Queue)
+ self._receive_lock = threading.Lock()
+ self._reads_finished = threading.Event()
+
+ def close(self):
+ self._to_send.put(self._WRITES_FINISHED)
+
+ def wait(self, timeout=None):
+ self._reads_finished.wait(timeout)
+
+ def _receiving_queue(self, instruction_id):
+ with self._receive_lock:
+ return self._received[instruction_id]
+
+ def input_elements(self, instruction_id, expected_targets):
+ received = self._receiving_queue(instruction_id)
+ done_targets = []
+ while len(done_targets) < len(expected_targets):
+ data = received.get()
+ # An empty data block signals that all data for this target has arrived.
+ if not data.data and data.target in expected_targets:
+ done_targets.append(data.target)
+ else:
+ assert data.target not in done_targets
+ yield data
+
+ def output_stream(self, instruction_id, target):
+ def add_to_send_queue(data):
+ self._to_send.put(
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference=instruction_id,
+ target=target,
+ data=data))
+ # An empty data block marks the end of this target's stream.
+ self._to_send.put(
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference=instruction_id,
+ target=target,
+ data=''))
+ return ClosableOutputStream(add_to_send_queue)
+
+ def _write_outputs(self):
+ done = False
+ while not done:
+ data = [self._to_send.get()]
+ try:
+ # Coalesce up to 100 other items.
+ for _ in range(100):
+ data.append(self._to_send.get_nowait())
+ except queue.Empty:
+ pass
+ # The sentinel put by close() marks the end of the stream of writes.
+ if data[-1] is self._WRITES_FINISHED:
+ done = True
+ data.pop()
+ if data:
+ yield beam_fn_api_pb2.Elements(data=data)
+
+ def _read_inputs(self, elements_iterator):
+ # TODO(robertwb): Pushback/throttling to avoid unbounded buffering.
+ try:
+ for elements in elements_iterator:
+ for data in elements.data:
+ self._receiving_queue(data.instruction_reference).put(data)
+ except: # pylint: disable=broad-except
+ logging.exception('Failed to read inputs in the data plane')
+ raise
+ finally:
+ self._reads_finished.set()
+
+ def _start_reader(self, elements_iterator):
+ reader = threading.Thread(
+ target=lambda: self._read_inputs(elements_iterator),
+ name='read_grpc_client_inputs')
+ reader.daemon = True
+ reader.start()
+
+
+class GrpcClientDataChannel(_GrpcDataChannel):
+ """A DataChannel wrapping the client side of a BeamFnData connection."""
+
+ def __init__(self, data_stub):
+ super(GrpcClientDataChannel, self).__init__()
+ self._start_reader(data_stub.Data(self._write_outputs()))
+
+
+class GrpcServerDataChannel(
+ beam_fn_api_pb2.BeamFnDataServicer, _GrpcDataChannel):
+ """A DataChannel wrapping the server side of a BeamFnData connection."""
+
+ def Data(self, elements_iterator, context):
+ self._start_reader(elements_iterator)
+ for elements in self._write_outputs():
+ yield elements
+
+
+class DataChannelFactory(object):
+ """An abstract factory for creating ``DataChannel``."""
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def create_data_channel(self, function_spec):
+ """Returns a ``DataChannel`` from the given function_spec."""
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def close(self):
+ """Close all channels that this factory owns."""
+ raise NotImplementedError(type(self))
+
+
+class GrpcClientDataChannelFactory(DataChannelFactory):
+ """A factory for ``GrpcClientDataChannel``.
+
+ Caches the created channels by ``data descriptor url``.
+ """
+
+ def __init__(self):
+ self._data_channel_cache = {}
+
+ def create_data_channel(self, function_spec):
+ remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+ function_spec.data.Unpack(remote_grpc_port)
+ url = remote_grpc_port.api_service_descriptor.url
+ if url not in self._data_channel_cache:
+ logging.info('Creating channel for %s', url)
+ grpc_channel = grpc.insecure_channel(url)
+ self._data_channel_cache[url] = GrpcClientDataChannel(
+ beam_fn_api_pb2.BeamFnDataStub(grpc_channel))
+ return self._data_channel_cache[url]
+
+ def close(self):
+ logging.info('Closing all cached grpc data channels.')
+ for _, channel in self._data_channel_cache.items():
+ channel.close()
+ self._data_channel_cache.clear()
+
+
+class InMemoryDataChannelFactory(DataChannelFactory):
+ """A singleton factory for ``InMemoryDataChannel``."""
+
+ def __init__(self, in_memory_data_channel):
+ self._in_memory_data_channel = in_memory_data_channel
+
+ def create_data_channel(self, unused_function_spec):
+ return self._in_memory_data_channel
+
+ def close(self):
+ pass
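
As a quick orientation to the DataChannel API above, a minimal sketch of the in-memory flavour follows; the instruction id '0' and the target are arbitrary illustration values (compare the test file that comes next).

    # Illustrative sketch only: exercise InMemoryDataChannel end to end.
    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import data_plane

    channel = data_plane.InMemoryDataChannel()
    target = beam_fn_api_pb2.Target(
        primitive_transform_reference='1', name='out')

    # Writer side: buffer bytes for instruction '0', then close() to hand
    # them to the other side of the channel.
    out = channel.output_stream('0', target)
    out.write('encoded elements')
    out.close()

    # Reader side: the inverse channel yields what was written above.
    for data in channel.inverse().input_elements('0', [target]):
        print(data.data)  # prints 'encoded elements'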
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py
new file mode 100644
index 0000000..7340789
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.runners.worker.data_plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import sys
+import threading
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import data_plane
+
+
+def timeout(timeout_secs):
+ def decorate(fn):
+ exc_info = []
+
+ def wrapper(*args, **kwargs):
+ def call_fn():
+ try:
+ fn(*args, **kwargs)
+ except: # pylint: disable=bare-except
+ exc_info[:] = sys.exc_info()
+ thread = threading.Thread(target=call_fn)
+ thread.daemon = True
+ thread.start()
+ thread.join(timeout_secs)
+ if exc_info:
+ t, v, tb = exc_info # pylint: disable=unbalanced-tuple-unpacking
+ raise t, v, tb
+ assert not thread.is_alive(), 'timed out after %s seconds' % timeout_secs
+ return wrapper
+ return decorate
+
+
+class DataChannelTest(unittest.TestCase):
+
+ @timeout(5)
+ def test_grpc_data_channel(self):
+ data_channel_service = data_plane.GrpcServerDataChannel()
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+ beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+ data_channel_service, server)
+ test_port = server.add_insecure_port('[::]:0')
+ server.start()
+
+ data_channel_stub = beam_fn_api_pb2.BeamFnDataStub(
+ grpc.insecure_channel('localhost:%s' % test_port))
+ data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)
+
+ try:
+ self._data_channel_test(data_channel_service, data_channel_client)
+ finally:
+ data_channel_client.close()
+ data_channel_service.close()
+ data_channel_client.wait()
+ data_channel_service.wait()
+
+ def test_in_memory_data_channel(self):
+ channel = data_plane.InMemoryDataChannel()
+ self._data_channel_test(channel, channel.inverse())
+
+ def _data_channel_test(self, server, client):
+ self._data_channel_test_one_direction(server, client)
+ self._data_channel_test_one_direction(client, server)
+
+ def _data_channel_test_one_direction(self, from_channel, to_channel):
+ def send(instruction_id, target, data):
+ stream = from_channel.output_stream(instruction_id, target)
+ stream.write(data)
+ stream.close()
+ target_1 = beam_fn_api_pb2.Target(
+ primitive_transform_reference='1',
+ name='out')
+ target_2 = beam_fn_api_pb2.Target(
+ primitive_transform_reference='2',
+ name='out')
+
+ # Single write.
+ send('0', target_1, 'abc')
+ self.assertEqual(
+ list(to_channel.input_elements('0', [target_1])),
+ [beam_fn_api_pb2.Elements.Data(
+ instruction_reference='0',
+ target=target_1,
+ data='abc')])
+
+ # Multiple interleaved writes to multiple instructions.
+ send('1', target_1, 'abc')
+ send('2', target_1, 'def')
+ self.assertEqual(
+ list(to_channel.input_elements('1', [target_1])),
+ [beam_fn_api_pb2.Elements.Data(
+ instruction_reference='1',
+ target=target_1,
+ data='abc')])
+ send('2', target_2, 'ghi')
+ self.assertEqual(
+ list(to_channel.input_elements('2', [target_1, target_2])),
+ [beam_fn_api_pb2.Elements.Data(
+ instruction_reference='2',
+ target=target_1,
+ data='def'),
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference='2',
+ target=target_2,
+ data='ghi')])
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
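
For context on how GrpcClientDataChannelFactory (defined in data_plane.py above) resolves its endpoint, here is a hedged sketch of the packing side; localhost:12345 is a hypothetical address, and a BeamFnData server would need to be listening there for the channel to actually carry data.

    # Illustrative sketch only: build a function_spec whose Any-typed data
    # field carries a RemoteGrpcPort, then let the factory unpack it.
    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import data_plane

    port = beam_fn_api_pb2.RemoteGrpcPort()
    port.api_service_descriptor.url = 'localhost:12345'  # hypothetical

    function_spec = beam_fn_api_pb2.FunctionSpec(id='1')
    function_spec.data.Pack(port)

    factory = data_plane.GrpcClientDataChannelFactory()
    # The factory unpacks the RemoteGrpcPort and caches one channel per url.
    channel = factory.create_data_channel(function_spec)
    factory.close()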
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py
new file mode 100644
index 0000000..b9e36ad
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Beam fn API log handler."""
+
+import logging
+import math
+import Queue as queue
+import threading
+
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class FnApiLogRecordHandler(logging.Handler):
+ """A handler that writes log records to the fn API."""
+
+ # Maximum number of log entries in a single stream request.
+ _MAX_BATCH_SIZE = 1000
+ # Used to indicate the end of stream.
+ _FINISHED = object()
+
+ # Mapping from logging levels to LogEntry levels.
+ LOG_LEVEL_MAP = {
+ logging.FATAL: beam_fn_api_pb2.LogEntry.CRITICAL,
+ logging.ERROR: beam_fn_api_pb2.LogEntry.ERROR,
+ logging.WARNING: beam_fn_api_pb2.LogEntry.WARN,
+ logging.INFO: beam_fn_api_pb2.LogEntry.INFO,
+ logging.DEBUG: beam_fn_api_pb2.LogEntry.DEBUG
+ }
+
+ def __init__(self, log_service_descriptor):
+ super(FnApiLogRecordHandler, self).__init__()
+ self._log_channel = grpc.insecure_channel(log_service_descriptor.url)
+ self._logging_stub = beam_fn_api_pb2.BeamFnLoggingStub(self._log_channel)
+ self._log_entry_queue = queue.Queue()
+
+ log_control_messages = self._logging_stub.Logging(self._write_log_entries())
+ self._reader = threading.Thread(
+ target=lambda: self._read_log_control_messages(log_control_messages),
+ name='read_log_control_messages')
+ self._reader.daemon = True
+ self._reader.start()
+
+ def emit(self, record):
+ log_entry = beam_fn_api_pb2.LogEntry()
+ log_entry.severity = self.LOG_LEVEL_MAP[record.levelno]
+ log_entry.message = self.format(record)
+ log_entry.thread = record.threadName
+ log_entry.log_location = record.module + '.' + record.funcName
+ # Split the creation timestamp into whole seconds and a fractional part.
+ (fraction, seconds) = math.modf(record.created)
+ nanoseconds = 1e9 * fraction
+ log_entry.timestamp.seconds = int(seconds)
+ log_entry.timestamp.nanos = int(nanoseconds)
+ self._log_entry_queue.put(log_entry)
+
+ def close(self):
+ """Flush out all existing log entries and unregister this handler."""
+ # Acquiring the handler lock ensures ``emit`` is not run until the lock is
+ # released.
+ self.acquire()
+ self._log_entry_queue.put(self._FINISHED)
+ # Wait for the server to close the stream.
+ self._reader.join()
+ self.release()
+ # Unregister this handler.
+ super(FnApiLogRecordHandler, self).close()
+
+ def _write_log_entries(self):
+ done = False
+ while not done:
+ log_entries = [self._log_entry_queue.get()]
+ try:
+ for _ in range(self._MAX_BATCH_SIZE):
+ log_entries.append(self._log_entry_queue.get_nowait())
+ except queue.Empty:
+ pass
+ if log_entries[-1] is self._FINISHED:
+ done = True
+ log_entries.pop()
+ if log_entries:
+ yield beam_fn_api_pb2.LogEntry.List(log_entries=log_entries)
+
+ def _read_log_control_messages(self, log_control_iterator):
+ # TODO(vikasrk): Handle control messages.
+ for _ in log_control_iterator:
+ pass
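
A minimal sketch of wiring this handler into the standard logging module, assuming a BeamFnLogging service at the hypothetical endpoint localhost:12345 (the test file that follows spins up exactly such a service in-process).

    # Illustrative sketch only: route standard logging through the Fn API.
    import logging

    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import log_handler

    descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
    descriptor.url = 'localhost:12345'  # hypothetical logging endpoint

    fn_handler = log_handler.FnApiLogRecordHandler(descriptor)
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(fn_handler)

    logging.info('Batched and streamed to the logging service.')

    # close() flushes any queued entries and waits for the stream to finish.
    fn_handler.close()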
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
new file mode 100644
index 0000000..565bedb
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -0,0 +1,105 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import logging
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import log_handler
+
+
+class BeamFnLoggingServicer(beam_fn_api_pb2.BeamFnLoggingServicer):
+
+ def __init__(self):
+ self.log_records_received = []
+
+ def Logging(self, request_iterator, context):
+
+ for log_record in request_iterator:
+ self.log_records_received.append(log_record)
+
+ yield beam_fn_api_pb2.LogControl()
+
+
+class FnApiLogRecordHandlerTest(unittest.TestCase):
+
+ def setUp(self):
+ self.test_logging_service = BeamFnLoggingServicer()
+ self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ beam_fn_api_pb2.add_BeamFnLoggingServicer_to_server(
+ self.test_logging_service, self.server)
+ self.test_port = self.server.add_insecure_port('[::]:0')
+ self.server.start()
+
+ self.logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+ self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
+ self.fn_log_handler = log_handler.FnApiLogRecordHandler(
+ self.logging_service_descriptor)
+ logging.getLogger().setLevel(logging.INFO)
+ logging.getLogger().addHandler(self.fn_log_handler)
+
+ def tearDown(self):
+ # Wait up to 5 seconds.
+ self.server.stop(5)
+
+ def _verify_fn_log_handler(self, num_log_entries):
+ msg = 'Testing fn logging'
+ logging.debug('Debug Message 1')
+ for idx in range(num_log_entries):
+ logging.info('%s: %s', msg, idx)
+ logging.debug('Debug Message 2')
+
+ # Wait for logs to be sent to server.
+ self.fn_log_handler.close()
+
+ num_received_log_entries = 0
+ for outer in self.test_logging_service.log_records_received:
+ for log_entry in outer.log_entries:
+ self.assertEqual(beam_fn_api_pb2.LogEntry.INFO, log_entry.severity)
+ self.assertEqual('%s: %s' % (msg, num_received_log_entries),
+ log_entry.message)
+ self.assertEqual(u'log_handler_test._verify_fn_log_handler',
+ log_entry.log_location)
+ self.assertGreater(log_entry.timestamp.seconds, 0)
+ self.assertGreater(log_entry.timestamp.nanos, 0)
+ num_received_log_entries += 1
+
+ self.assertEqual(num_received_log_entries, num_log_entries)
+
+
+# Test cases.
+data = {
+ 'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
+ 'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
+ 'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
+}
+
+
+def _create_test(name, num_logs):
+ setattr(FnApiLogRecordHandlerTest, 'test_%s' % name,
+ lambda self: self._verify_fn_log_handler(num_logs))
+
+
+if __name__ == '__main__':
+ for test_name, num_logs_entries in data.iteritems():
+ _create_test(test_name, num_logs_entries)
+
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.pxd b/sdks/python/apache_beam/runners/worker/logger.pxd
new file mode 100644
index 0000000..201daf4
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.pxd
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport LoggingContext
+
+
+cdef class PerThreadLoggingContext(LoggingContext):
+ cdef kwargs
+ cdef list stack
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py
new file mode 100644
index 0000000..217dc58
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Python worker logging."""
+
+import json
+import logging
+import threading
+import traceback
+
+from apache_beam.runners.common import LoggingContext
+
+
+# Per-thread worker information. This is used only for logging to set
+# context information that changes while work items get executed:
+# work_item_id, step_name, stage_name.
+class _PerThreadWorkerData(threading.local):
+
+ def __init__(self):
+ super(_PerThreadWorkerData, self).__init__()
+ # TODO(robertwb): Consider starting with an initial (ignored) ~20 elements
+ # in the list, as going up and down all the way to zero incurs several
+ # reallocations.
+ self.stack = []
+
+ def get_data(self):
+ all_data = {}
+ for datum in self.stack:
+ all_data.update(datum)
+ return all_data
+
+
+per_thread_worker_data = _PerThreadWorkerData()
+
+
+class PerThreadLoggingContext(LoggingContext):
+ """A context manager to add per thread attributes."""
+
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.stack = per_thread_worker_data.stack
+
+ def __enter__(self):
+ self.enter()
+
+ def enter(self):
+ self.stack.append(self.kwargs)
+
+ def __exit__(self, exn_type, exn_value, exn_traceback):
+ self.exit()
+
+ def exit(self):
+ self.stack.pop()
+
+
+class JsonLogFormatter(logging.Formatter):
+ """A JSON formatter class as expected by the logging standard module."""
+
+ def __init__(self, job_id, worker_id):
+ super(JsonLogFormatter, self).__init__()
+ self.job_id = job_id
+ self.worker_id = worker_id
+
+ def format(self, record):
+ """Returns a JSON string based on a LogRecord instance.
+
+ Args:
+ record: A LogRecord instance. See below for details.
+
+ Returns:
+ A JSON string representing the record.
+
+ A LogRecord instance has the following attributes and is used for
+ formatting the final message.
+
+ Attributes:
+ created: A double representing the timestamp for record creation
+ (e.g., 1438365207.624597). Note that the number also contains the msecs
+ and microsecs information. Part of this is also available in the
+ 'msecs' attribute.
+ msecs: A double representing the msecs part of the record creation
+ (e.g., 624.5970726013184).
+ msg: Logging message containing formatting instructions or an arbitrary
+ object. This is the first argument of a log call.
+ args: A tuple containing the positional arguments for the logging call.
+ levelname: A string. Possible values are: INFO, WARNING, ERROR, etc.
+ exc_info: None or a 3-tuple with exception information as it is
+ returned by a call to sys.exc_info().
+ name: Logger's name. Most logging is done using the default root logger
+ and therefore the name will be 'root'.
+ filename: Basename of the file where logging occurred.
+ funcName: Name of the function where logging occurred.
+ process: The PID of the process running the worker.
+ thread: An id for the thread where the record was logged. This is not a
+ real TID (the one provided by the OS) but rather the id (address) of a
+ Python thread object. Nevertheless, this value makes it possible to
+ filter log statements coming from one specific thread.
+ """
+ output = {}
+ output['timestamp'] = {
+ 'seconds': int(record.created),
+ 'nanos': int(record.msecs * 1000000)}
+ # The ERROR, INFO and DEBUG log levels translate into the same value for
+ # the severity property. WARNING becomes WARN.
+ output['severity'] = (
+ record.levelname if record.levelname != 'WARNING' else 'WARN')
+
+ # msg could be an arbitrary object, convert it to a string first.
+ record_msg = str(record.msg)
+
+ # Prepare the actual message using the message formatting string and the
+ # positional arguments as they have been used in the log call.
+ if record.args:
+ try:
+ output['message'] = record_msg % record.args
+ except (TypeError, ValueError):
+ output['message'] = '%s with args (%s)' % (record_msg, record.args)
+ else:
+ output['message'] = record_msg
+
+ # The thread ID is logged as a combination of the process ID and thread ID
+ # since workers can run in multiple processes.
+ output['thread'] = '%s:%s' % (record.process, record.thread)
+ # job ID and worker ID. These do not change during the lifetime of a worker.
+ output['job'] = self.job_id
+ output['worker'] = self.worker_id
+ # Stage, step and work item ID come from thread local storage since they
+ # change with every new work item leased for execution. If there is no
+ # work item ID then we make sure the step is undefined too.
+ data = per_thread_worker_data.get_data()
+ if 'work_item_id' in data:
+ output['work'] = data['work_item_id']
+ if 'stage_name' in data:
+ output['stage'] = data['stage_name']
+ if 'step_name' in data:
+ output['step'] = data['step_name']
+ # All logging happens using the root logger. We will add the basename of the
+ # file and the function name where the logging happened to make it easier
+ # to identify who generated the record.
+ output['logger'] = '%s:%s:%s' % (
+ record.name, record.filename, record.funcName)
+ # Add exception information if any is available.
+ if record.exc_info:
+ output['exception'] = ''.join(
+ traceback.format_exception(*record.exc_info))
+
+ return json.dumps(output)
+
+
+def initialize(job_id, worker_id, log_path):
+ """Initialize root logger so that we log JSON to a file and text to stdout."""
+
+ file_handler = logging.FileHandler(log_path)
+ file_handler.setFormatter(JsonLogFormatter(job_id, worker_id))
+ logging.getLogger().addHandler(file_handler)
+
+ # Set default level to INFO to avoid logging various DEBUG level log calls
+ # sprinkled throughout the code.
+ logging.getLogger().setLevel(logging.INFO)