Posted to commits@beam.apache.org by ro...@apache.org on 2017/05/02 00:45:04 UTC
[3/4] beam git commit: Fn API support for Python.
Fn API support for Python.
Also added supporting worker code and a runner
that exercises this code directly.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a856fcf3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a856fcf3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a856fcf3
Branch: refs/heads/master
Commit: a856fcf3f9ca48530078e1d7610e234110cd5890
Parents: b8131fe
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Thu Apr 20 15:59:25 2017 -0500
Committer: Robert Bradshaw <ro...@gmail.com>
Committed: Mon May 1 17:44:29 2017 -0700
----------------------------------------------------------------------
.../apache_beam/runners/portability/__init__.py | 16 +
.../runners/portability/fn_api_runner.py | 471 ++++++++++++++
.../runners/portability/fn_api_runner_test.py | 40 ++
.../portability/maptask_executor_runner.py | 468 +++++++++++++
.../portability/maptask_executor_runner_test.py | 204 ++++++
.../apache_beam/runners/worker/__init__.py | 16 +
.../apache_beam/runners/worker/data_plane.py | 288 ++++++++
.../runners/worker/data_plane_test.py | 139 ++++
.../apache_beam/runners/worker/log_handler.py | 100 +++
.../runners/worker/log_handler_test.py | 105 +++
.../apache_beam/runners/worker/logger.pxd | 25 +
.../python/apache_beam/runners/worker/logger.py | 173 +++++
.../apache_beam/runners/worker/logger_test.py | 182 ++++++
.../apache_beam/runners/worker/opcounters.pxd | 45 ++
.../apache_beam/runners/worker/opcounters.py | 162 +++++
.../runners/worker/opcounters_test.py | 149 +++++
.../runners/worker/operation_specs.py | 368 +++++++++++
.../apache_beam/runners/worker/operations.pxd | 89 +++
.../apache_beam/runners/worker/operations.py | 651 +++++++++++++++++++
.../apache_beam/runners/worker/sdk_worker.py | 451 +++++++++++++
.../runners/worker/sdk_worker_main.py | 62 ++
.../runners/worker/sdk_worker_test.py | 168 +++++
.../apache_beam/runners/worker/sideinputs.py | 166 +++++
.../runners/worker/sideinputs_test.py | 150 +++++
.../apache_beam/runners/worker/statesampler.pyx | 237 +++++++
.../runners/worker/statesampler_fake.py | 34 +
.../runners/worker/statesampler_test.py | 102 +++
sdks/python/generate_pydoc.sh | 2 +
sdks/python/setup.py | 8 +-
29 files changed, 5069 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
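For context, the runner added here can be exercised entirely in-process, like any other Beam runner. A minimal sketch of that usage (mirroring the create_pipeline helper in fn_api_runner_test.py below):

    import apache_beam as beam
    from apache_beam.runners.portability import fn_api_runner

    # Run a trivial pipeline directly against the new SDK harness code.
    # The pipeline executes when the with-block exits.
    with beam.Pipeline(runner=fn_api_runner.FnApiRunner()) as p:
      (p | beam.Create(['a', 'b'])
         | beam.Map(lambda e: e * 2))
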
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/__init__.py b/sdks/python/apache_beam/runners/portability/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
new file mode 100644
index 0000000..5802c17
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -0,0 +1,471 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""A PipelineRunner using the SDK harness.
+"""
+import collections
+import json
+import logging
+import Queue as queue
+import threading
+
+import grpc
+from concurrent import futures
+
+import apache_beam as beam
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.coders.coder_impl import create_InputStream
+from apache_beam.coders.coder_impl import create_OutputStream
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.transforms.window import GlobalWindows
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.worker import data_plane
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import sdk_worker
+
+
+def streaming_rpc_handler(cls, method_name):
+ """Un-inverts the flow of control between the runner and the sdk harness."""
+
+ class StreamingRpcHandler(cls):
+
+ _DONE = object()
+
+ def __init__(self):
+ self._push_queue = queue.Queue()
+ self._pull_queue = queue.Queue()
+ setattr(self, method_name, self.run)
+ self._read_thread = threading.Thread(target=self._read)
+
+ def run(self, iterator, context):
+ self._inputs = iterator
+ # Note: We only support one client for now.
+ self._read_thread.start()
+ while True:
+ to_push = self._push_queue.get()
+ if to_push is self._DONE:
+ return
+ yield to_push
+
+ def _read(self):
+ for data in self._inputs:
+ self._pull_queue.put(data)
+
+ def push(self, item):
+ self._push_queue.put(item)
+
+ def pull(self, timeout=None):
+ return self._pull_queue.get(timeout=timeout)
+
+ def empty(self):
+ return self._pull_queue.empty()
+
+ def done(self):
+ self.push(self._DONE)
+ self._read_thread.join()
+
+ return StreamingRpcHandler()
+
+
+class OldeSourceSplittableDoFn(beam.DoFn):
+ """A DoFn that reads and emits an entire source.
+ """
+
+ # TODO(robertwb): Make this a full SDF with progress splitting, etc.
+ def process(self, source):
+ if isinstance(source, iobase.SourceBundle):
+ for value in source.source.read(source.source.get_range_tracker(
+ source.start_position, source.stop_position)):
+ yield value
+ else:
+ # Dataflow native source
+ with source.reader() as reader:
+ for value in reader:
+ yield value
+
+
+# See DataflowRunner._pardo_fn_data
+OLDE_SOURCE_SPLITTABLE_DOFN_DATA = pickler.dumps(
+ (OldeSourceSplittableDoFn(), (), {}, [],
+ beam.transforms.core.Windowing(GlobalWindows())))
+
+
+class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
+
+ def __init__(self):
+ super(FnApiRunner, self).__init__()
+ self._last_uid = -1
+
+ def has_metrics_support(self):
+ return False
+
+ def _next_uid(self):
+ self._last_uid += 1
+ return str(self._last_uid)
+
+ def _map_task_registration(self, map_task, state_handler,
+ data_operation_spec):
+ input_data = {}
+ runner_sinks = {}
+ transforms = []
+ transform_index_to_id = {}
+
+ # Maps coders to new coder objects and references.
+ coders = {}
+
+ def coder_id(coder):
+ if coder not in coders:
+ coders[coder] = beam_fn_api_pb2.Coder(
+ function_spec=sdk_worker.pack_function_spec_data(
+ json.dumps(coder.as_cloud_object()),
+ sdk_worker.PYTHON_CODER_URN, id=self._next_uid()))
+
+ return coders[coder].function_spec.id
+
+ def output_tags(op):
+ return getattr(op, 'output_tags', ['out'])
+
+ def as_target(op_input):
+ input_op_index, input_output_index = op_input
+ input_op = map_task[input_op_index][1]
+ return {
+ 'ignored_input_tag':
+ beam_fn_api_pb2.Target.List(target=[
+ beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_index_to_id[
+ input_op_index],
+ name=output_tags(input_op)[input_output_index])
+ ])
+ }
+
+ def outputs(op):
+ return {
+ tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder))
+ for tag, coder in zip(output_tags(op), op.output_coders)
+ }
+
+ for op_ix, (stage_name, operation) in enumerate(map_task):
+ transform_id = transform_index_to_id[op_ix] = self._next_uid()
+ if isinstance(operation, operation_specs.WorkerInMemoryWrite):
+ # Write this data back to the runner.
+ fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_OUTPUT_URN,
+ id=self._next_uid())
+ if data_operation_spec:
+ fn.data.Pack(data_operation_spec)
+ inputs = as_target(operation.input)
+ side_inputs = {}
+ runner_sinks[(transform_id, 'out')] = operation
+
+ elif isinstance(operation, operation_specs.WorkerRead):
+ # A Read is either translated to a direct injection of windowed values
+ # into the sdk worker, or an injection of the source object into the
+ # sdk worker as data followed by an SDF that reads that source.
+ if (isinstance(operation.source.source,
+                     maptask_executor_runner.InMemorySource)
+ and isinstance(operation.source.source.default_output_coder(),
+ WindowedValueCoder)):
+ output_stream = create_OutputStream()
+ element_coder = (
+ operation.source.source.default_output_coder().get_impl())
+ # Re-encode the elements in the nested context and
+ # concatenate them together
+ for element in operation.source.source.read(None):
+ element_coder.encode_to_stream(element, output_stream, True)
+ target_name = self._next_uid()
+ input_data[(transform_id, target_name)] = output_stream.get()
+ fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN,
+ id=self._next_uid())
+ if data_operation_spec:
+ fn.data.Pack(data_operation_spec)
+ inputs = {target_name: beam_fn_api_pb2.Target.List()}
+ side_inputs = {}
+ else:
+ # Read the source object from the runner.
+ source_coder = beam.coders.DillCoder()
+ input_transform_id = self._next_uid()
+ output_stream = create_OutputStream()
+ source_coder.get_impl().encode_to_stream(
+ GlobalWindows.windowed_value(operation.source),
+ output_stream,
+ True)
+ target_name = self._next_uid()
+ input_data[(input_transform_id, target_name)] = output_stream.get()
+ input_ptransform = beam_fn_api_pb2.PrimitiveTransform(
+ id=input_transform_id,
+ function_spec=beam_fn_api_pb2.FunctionSpec(
+ urn=sdk_worker.DATA_INPUT_URN,
+ id=self._next_uid()),
+ # TODO(robertwb): Possible name collision.
+ step_name=stage_name + '/inject_source',
+ inputs={target_name: beam_fn_api_pb2.Target.List()},
+ outputs={
+ 'out':
+ beam_fn_api_pb2.PCollection(
+ coder_reference=coder_id(source_coder))
+ })
+ if data_operation_spec:
+ input_ptransform.function_spec.data.Pack(data_operation_spec)
+ transforms.append(input_ptransform)
+
+ # Read the elements out of the source.
+ fn = sdk_worker.pack_function_spec_data(
+ OLDE_SOURCE_SPLITTABLE_DOFN_DATA,
+ sdk_worker.PYTHON_DOFN_URN,
+ id=self._next_uid())
+ inputs = {
+ 'ignored_input_tag':
+ beam_fn_api_pb2.Target.List(target=[
+ beam_fn_api_pb2.Target(
+ primitive_transform_reference=input_transform_id,
+ name='out')
+ ])
+ }
+ side_inputs = {}
+
+ elif isinstance(operation, operation_specs.WorkerDoFn):
+ fn = sdk_worker.pack_function_spec_data(
+ operation.serialized_fn,
+ sdk_worker.PYTHON_DOFN_URN,
+ id=self._next_uid())
+ inputs = as_target(operation.input)
+ # Store the contents of each side input for state access.
+ for si in operation.side_inputs:
+ assert isinstance(si.source, iobase.BoundedSource)
+ element_coder = si.source.default_output_coder()
+ view_id = self._next_uid()
+ # TODO(robertwb): Actually flesh out the ViewFn API.
+ side_inputs[si.tag] = beam_fn_api_pb2.SideInput(
+ view_fn=sdk_worker.serialize_and_pack_py_fn(
+ element_coder, urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN,
+ id=view_id))
+ # Re-encode the elements in the nested context and
+ # concatenate them together
+ output_stream = create_OutputStream()
+ for element in si.source.read(
+ si.source.get_range_tracker(None, None)):
+ element_coder.get_impl().encode_to_stream(
+ element, output_stream, True)
+ elements_data = output_stream.get()
+ state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id)
+ state_handler.Clear(state_key)
+ state_handler.Append(
+ beam_fn_api_pb2.SimpleStateAppendRequest(
+ state_key=state_key, data=[elements_data]))
+
+ elif isinstance(operation, operation_specs.WorkerFlatten):
+ fn = sdk_worker.pack_function_spec_data(
+ operation.serialized_fn,
+ sdk_worker.IDENTITY_DOFN_URN,
+ id=self._next_uid())
+ inputs = {
+ 'ignored_input_tag':
+ beam_fn_api_pb2.Target.List(target=[
+ beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_index_to_id[
+ input_op_index],
+ name=output_tags(map_task[input_op_index][1])[
+ input_output_index])
+ for input_op_index, input_output_index in operation.inputs
+ ])
+ }
+ side_inputs = {}
+
+ else:
+ raise TypeError(operation)
+
+ ptransform = beam_fn_api_pb2.PrimitiveTransform(
+ id=transform_id,
+ function_spec=fn,
+ step_name=stage_name,
+ inputs=inputs,
+ side_inputs=side_inputs,
+ outputs=outputs(operation))
+ transforms.append(ptransform)
+
+ process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
+ id=self._next_uid(), coders=coders.values(),
+ primitive_transform=transforms)
+ return beam_fn_api_pb2.InstructionRequest(
+ instruction_id=self._next_uid(),
+ register=beam_fn_api_pb2.RegisterRequest(
+ process_bundle_descriptor=[process_bundle_descriptor
+ ])), runner_sinks, input_data
+
+ def _run_map_task(
+ self, map_task, control_handler, state_handler, data_plane_handler,
+ data_operation_spec):
+ registration, sinks, input_data = self._map_task_registration(
+ map_task, state_handler, data_operation_spec)
+ control_handler.push(registration)
+ process_bundle = beam_fn_api_pb2.InstructionRequest(
+ instruction_id=self._next_uid(),
+ process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
+ process_bundle_descriptor_reference=registration.register.
+ process_bundle_descriptor[0].id))
+
+ for (transform_id, name), elements in input_data.items():
+ data_out = data_plane_handler.output_stream(
+ process_bundle.instruction_id, beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_id, name=name))
+ data_out.write(elements)
+ data_out.close()
+
+ control_handler.push(process_bundle)
+ while True:
+ result = control_handler.pull()
+ if result.instruction_id == process_bundle.instruction_id:
+ if result.error:
+ raise RuntimeError(result.error)
+ expected_targets = [
+ beam_fn_api_pb2.Target(primitive_transform_reference=transform_id,
+ name=output_name)
+ for (transform_id, output_name), _ in sinks.items()]
+ for output in data_plane_handler.input_elements(
+ process_bundle.instruction_id, expected_targets):
+ target_tuple = (
+ output.target.primitive_transform_reference, output.target.name)
+ if target_tuple not in sinks:
+ # Unconsumed output.
+ continue
+ sink_op = sinks[target_tuple]
+ coder = sink_op.output_coders[0]
+ input_stream = create_InputStream(output.data)
+ elements = []
+ while input_stream.size() > 0:
+ elements.append(coder.get_impl().decode_from_stream(
+ input_stream, True))
+ if not sink_op.write_windowed_values:
+ elements = [e.value for e in elements]
+ for e in elements:
+ sink_op.output_buffer.append(e)
+ return
+
+ def execute_map_tasks(self, ordered_map_tasks, direct=True):
+ if direct:
+ controller = FnApiRunner.DirectController()
+ else:
+ controller = FnApiRunner.GrpcController()
+
+ try:
+ for _, map_task in ordered_map_tasks:
+ logging.info('Running %s', map_task)
+ self._run_map_task(
+ map_task, controller.control_handler, controller.state_handler,
+ controller.data_plane_handler, controller.data_operation_spec())
+ finally:
+ controller.close()
+
+ class SimpleState(object): # TODO(robertwb): Inherit from GRPC servicer.
+
+ def __init__(self):
+ self._all = collections.defaultdict(list)
+
+ def Get(self, state_key):
+ return beam_fn_api_pb2.Elements.Data(
+ data=''.join(self._all[self._to_key(state_key)]))
+
+ def Append(self, append_request):
+ self._all[self._to_key(append_request.state_key)].extend(
+ append_request.data)
+
+ def Clear(self, state_key):
+ try:
+ del self._all[self._to_key(state_key)]
+ except KeyError:
+ pass
+
+ @staticmethod
+ def _to_key(state_key):
+ return (state_key.function_spec_reference, state_key.window,
+ state_key.key)
+
+ class DirectController(object):
+ """An in-memory controller for fn API control, state and data planes."""
+
+ def __init__(self):
+ self._responses = []
+ self.state_handler = FnApiRunner.SimpleState()
+ self.control_handler = self
+ self.data_plane_handler = data_plane.InMemoryDataChannel()
+ self.worker = sdk_worker.SdkWorker(
+ self.state_handler, data_plane.InMemoryDataChannelFactory(
+ self.data_plane_handler.inverse()))
+
+ def push(self, request):
+ logging.info('CONTROL REQUEST %s', request)
+ response = self.worker.do_instruction(request)
+ logging.info('CONTROL RESPONSE %s', response)
+ self._responses.append(response)
+
+ def pull(self):
+ return self._responses.pop(0)
+
+ def done(self):
+ pass
+
+ def close(self):
+ pass
+
+ def data_operation_spec(self):
+ return None
+
+ class GrpcController(object):
+    """A gRPC-based controller for fn API control, state and data planes."""
+
+ def __init__(self):
+ self.state_handler = FnApiRunner.SimpleState()
+ self.control_server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=10))
+ self.control_port = self.control_server.add_insecure_port('[::]:0')
+
+ self.data_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ self.data_port = self.data_server.add_insecure_port('[::]:0')
+
+ self.control_handler = streaming_rpc_handler(
+ beam_fn_api_pb2.BeamFnControlServicer, 'Control')
+ beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
+ self.control_handler, self.control_server)
+
+ self.data_plane_handler = data_plane.GrpcServerDataChannel()
+ beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+ self.data_plane_handler, self.data_server)
+
+ logging.info('starting control server on port %s', self.control_port)
+ logging.info('starting data server on port %s', self.data_port)
+ self.data_server.start()
+ self.control_server.start()
+
+ self.worker = sdk_worker.SdkHarness(
+ grpc.insecure_channel('localhost:%s' % self.control_port))
+ self.worker_thread = threading.Thread(target=self.worker.run)
+ logging.info('starting worker')
+ self.worker_thread.start()
+
+ def data_operation_spec(self):
+ url = 'localhost:%s' % self.data_port
+ remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+ remote_grpc_port.api_service_descriptor.url = url
+ return remote_grpc_port
+
+ def close(self):
+ self.control_handler.done()
+ self.worker_thread.join()
+ self.data_plane_handler.close()
+ self.control_server.stop(5).wait()
+ self.data_server.stop(5).wait()
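The control and data planes above meet through a small handshake: the runner packs a RemoteGrpcPort into the DATA_INPUT/DATA_OUTPUT function specs (see data_operation_spec() above), and the SDK worker unpacks it to dial back (see GrpcClientDataChannelFactory in data_plane.py below). A minimal sketch of that round trip, using a hypothetical port number:

    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import sdk_worker

    # Runner side: advertise where the data plane is served.
    port = beam_fn_api_pb2.RemoteGrpcPort()
    port.api_service_descriptor.url = 'localhost:12345'  # hypothetical endpoint
    fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN, id='1')
    fn.data.Pack(port)

    # Worker side: recover the endpoint from the function spec.
    recovered = beam_fn_api_pb2.RemoteGrpcPort()
    fn.data.Unpack(recovered)
    assert recovered.api_service_descriptor.url == 'localhost:12345'
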
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
new file mode 100644
index 0000000..633602f
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import unittest
+
+import apache_beam as beam
+from apache_beam.runners.portability import fn_api_runner
+from apache_beam.runners.portability import maptask_executor_runner_test
+
+
+class FnApiRunnerTest(maptask_executor_runner_test.MapTaskExecutorRunnerTest):
+
+ def create_pipeline(self):
+ return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
+
+ def test_combine_per_key(self):
+ # TODO(robertwb): Implement PGBKCV operation.
+ pass
+
+  # Inherits all remaining tests from MapTaskExecutorRunnerTest.
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
new file mode 100644
index 0000000..d273e18
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner.py
@@ -0,0 +1,468 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Beam runner for testing/profiling worker code directly.
+"""
+
+import collections
+import logging
+import time
+
+import apache_beam as beam
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.runners import DataflowRunner
+from apache_beam.runners.dataflow.internal.dependency import _dependency_file_copy
+from apache_beam.runners.dataflow.internal.names import PropertyNames
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.runners.runner import PipelineRunner
+from apache_beam.runners.runner import PipelineState
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+try:
+ from apache_beam.runners.worker import statesampler
+except ImportError:
+ from apache_beam.runners.worker import statesampler_fake as statesampler
+from apache_beam.typehints import typehints
+from apache_beam.utils import pipeline_options
+from apache_beam.utils import profiler
+from apache_beam.utils.counters import CounterFactory
+
+
+class MapTaskExecutorRunner(PipelineRunner):
+ """Beam runner translating a pipeline into map tasks that are then executed.
+
+ Primarily intended for testing and profiling the worker code paths.
+ """
+
+ def __init__(self):
+ self.executors = []
+
+ def has_metrics_support(self):
+ """Returns whether this runner supports metrics or not.
+ """
+ return False
+
+ def run(self, pipeline):
+ MetricsEnvironment.set_metrics_supported(self.has_metrics_support())
+    # List of map tasks. Each map task is a list of
+ # (stage_name, operation_specs.WorkerOperation) instructions.
+ self.map_tasks = []
+
+ # Map of pvalues to
+ # (map_task_index, producer_operation_index, producer_output_index)
+ self.outputs = {}
+
+ # Unique mappings of PCollections to strings.
+ self.side_input_labels = collections.defaultdict(
+ lambda: str(len(self.side_input_labels)))
+
+    # Mapping of map task indices to all map tasks that must precede them.
+ self.dependencies = collections.defaultdict(set)
+
+ # Visit the graph, building up the map_tasks and their metadata.
+ super(MapTaskExecutorRunner, self).run(pipeline)
+
+ # Now run the tasks in topological order.
+ def compute_depth_map(deps):
+ memoized = {}
+
+ def compute_depth(x):
+ if x not in memoized:
+ memoized[x] = 1 + max([-1] + [compute_depth(y) for y in deps[x]])
+ return memoized[x]
+
+ return {x: compute_depth(x) for x in deps.keys()}
+
+ map_task_depths = compute_depth_map(self.dependencies)
+ ordered_map_tasks = sorted((map_task_depths.get(ix, -1), map_task)
+ for ix, map_task in enumerate(self.map_tasks))
+
+ profile_options = pipeline.options.view_as(
+ pipeline_options.ProfilingOptions)
+ if profile_options.profile_cpu:
+ with profiler.Profile(
+ profile_id='worker-runner',
+ profile_location=profile_options.profile_location,
+ log_results=True, file_copy_fn=_dependency_file_copy):
+ self.execute_map_tasks(ordered_map_tasks)
+ else:
+ self.execute_map_tasks(ordered_map_tasks)
+
+ return WorkerRunnerResult(PipelineState.UNKNOWN)
+
+ def metrics_containers(self):
+ return [op.metrics_container
+ for ex in self.executors
+ for op in ex.operations()]
+
+ def execute_map_tasks(self, ordered_map_tasks):
+ tt = time.time()
+ for ix, (_, map_task) in enumerate(ordered_map_tasks):
+ logging.info('Running %s', map_task)
+ t = time.time()
+ stage_names, all_operations = zip(*map_task)
+ # TODO(robertwb): The DataflowRunner worker receives system step names
+ # (e.g. "s3") that are used to label the output msec counters. We use the
+ # operation names here, but this is not the same scheme used by the
+ # DataflowRunner; the result is that the output msec counters are named
+ # differently.
+ system_names = stage_names
+ # Create the CounterFactory and StateSampler for this MapTask.
+ # TODO(robertwb): Output counters produced here are currently ignored.
+ counter_factory = CounterFactory()
+ state_sampler = statesampler.StateSampler('%s-' % ix, counter_factory)
+ map_executor = operations.SimpleMapTaskExecutor(
+ operation_specs.MapTask(
+ all_operations, 'S%02d' % ix,
+ system_names, stage_names, system_names),
+ counter_factory,
+ state_sampler)
+ self.executors.append(map_executor)
+ map_executor.execute()
+ logging.info(
+ 'Stage %s finished: %0.3f sec', stage_names[0], time.time() - t)
+ logging.info('Total time: %0.3f sec', time.time() - tt)
+
+ def run_Read(self, transform_node):
+ self._run_read_from(transform_node, transform_node.transform.source)
+
+ def _run_read_from(self, transform_node, source):
+    """Used when this operation is the result of reading the given source."""
+ if not isinstance(source, iobase.NativeSource):
+ source = iobase.SourceBundle(1.0, source, None, None)
+ output = transform_node.outputs[None]
+ element_coder = self._get_coder(output)
+ read_op = operation_specs.WorkerRead(source, output_coders=[element_coder])
+ self.outputs[output] = len(self.map_tasks), 0, 0
+ self.map_tasks.append([(transform_node.full_label, read_op)])
+ return len(self.map_tasks) - 1
+
+ def run_ParDo(self, transform_node):
+ transform = transform_node.transform
+ output = transform_node.outputs[None]
+ element_coder = self._get_coder(output)
+ map_task_index, producer_index, output_index = self.outputs[
+ transform_node.inputs[0]]
+
+ # If any of this ParDo's side inputs depend on outputs from this map_task,
+ # we can't continue growing this map task.
+ def is_reachable(leaf, root):
+ if leaf == root:
+ return True
+ else:
+ return any(is_reachable(x, root) for x in self.dependencies[leaf])
+
+ if any(is_reachable(self.outputs[side_input.pvalue][0], map_task_index)
+ for side_input in transform_node.side_inputs):
+      # Start a new map task.
+ input_element_coder = self._get_coder(transform_node.inputs[0])
+
+ output_buffer = OutputBuffer(input_element_coder)
+
+ fusion_break_write = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=True,
+ input=(producer_index, output_index),
+ output_coders=[input_element_coder])
+ self.map_tasks[map_task_index].append(
+ (transform_node.full_label + '/Write', fusion_break_write))
+
+ original_map_task_index = map_task_index
+ map_task_index, producer_index, output_index = len(self.map_tasks), 0, 0
+
+ fusion_break_read = operation_specs.WorkerRead(
+ output_buffer.source_bundle(),
+ output_coders=[input_element_coder])
+ self.map_tasks.append(
+ [(transform_node.full_label + '/Read', fusion_break_read)])
+
+ self.dependencies[map_task_index].add(original_map_task_index)
+
+ def create_side_read(side_input):
+ label = self.side_input_labels[side_input]
+ output_buffer = self.run_side_write(
+ side_input.pvalue, '%s/%s' % (transform_node.full_label, label))
+ return operation_specs.WorkerSideInputSource(
+ output_buffer.source(), label)
+
+ do_op = operation_specs.WorkerDoFn( #
+ serialized_fn=pickler.dumps(DataflowRunner._pardo_fn_data(
+ transform_node,
+ lambda side_input: self.side_input_labels[side_input])),
+ output_tags=[PropertyNames.OUT] + ['%s_%s' % (PropertyNames.OUT, tag)
+ for tag in transform.output_tags
+ ],
+ # Same assumption that DataflowRunner has about coders being compatible
+ # across outputs.
+ output_coders=[element_coder] * (len(transform.output_tags) + 1),
+ input=(producer_index, output_index),
+ side_inputs=[create_side_read(side_input)
+ for side_input in transform_node.side_inputs])
+
+ producer_index = len(self.map_tasks[map_task_index])
+ self.outputs[transform_node.outputs[None]] = (
+ map_task_index, producer_index, 0)
+ for ix, tag in enumerate(transform.output_tags):
+ self.outputs[transform_node.outputs[
+ tag]] = map_task_index, producer_index, ix + 1
+ self.map_tasks[map_task_index].append((transform_node.full_label, do_op))
+
+ for side_input in transform_node.side_inputs:
+ self.dependencies[map_task_index].add(self.outputs[side_input.pvalue][0])
+
+ def run_side_write(self, pcoll, label):
+ map_task_index, producer_index, output_index = self.outputs[pcoll]
+
+ windowed_element_coder = self._get_coder(pcoll)
+ output_buffer = OutputBuffer(windowed_element_coder)
+ write_sideinput_op = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=True,
+ input=(producer_index, output_index),
+ output_coders=[windowed_element_coder])
+ self.map_tasks[map_task_index].append(
+ (label, write_sideinput_op))
+ return output_buffer
+
+ def run_GroupByKeyOnly(self, transform_node):
+ map_task_index, producer_index, output_index = self.outputs[
+ transform_node.inputs[0]]
+ grouped_element_coder = self._get_coder(transform_node.outputs[None],
+ windowed=False)
+ windowed_ungrouped_element_coder = self._get_coder(transform_node.inputs[0])
+
+ output_buffer = GroupingOutputBuffer(grouped_element_coder)
+ shuffle_write = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=False,
+ input=(producer_index, output_index),
+ output_coders=[windowed_ungrouped_element_coder])
+ self.map_tasks[map_task_index].append(
+ (transform_node.full_label + '/Write', shuffle_write))
+
+ output_map_task_index = self._run_read_from(
+ transform_node, output_buffer.source())
+ self.dependencies[output_map_task_index].add(map_task_index)
+
+ def run_Flatten(self, transform_node):
+ output_buffer = OutputBuffer(self._get_coder(transform_node.outputs[None]))
+ output_map_task = self._run_read_from(transform_node,
+ output_buffer.source())
+
+ for input in transform_node.inputs:
+ map_task_index, producer_index, output_index = self.outputs[input]
+ element_coder = self._get_coder(input)
+ flatten_write = operation_specs.WorkerInMemoryWrite(
+ output_buffer=output_buffer,
+ write_windowed_values=True,
+ input=(producer_index, output_index),
+ output_coders=[element_coder])
+ self.map_tasks[map_task_index].append(
+ (transform_node.full_label + '/Write', flatten_write))
+ self.dependencies[output_map_task].add(map_task_index)
+
+ def apply_CombinePerKey(self, transform, input):
+ # TODO(robertwb): Support side inputs.
+ assert not transform.args and not transform.kwargs
+ return (input
+ | PartialGroupByKeyCombineValues(transform.fn)
+ | beam.GroupByKey()
+ | MergeAccumulators(transform.fn)
+ | ExtractOutputs(transform.fn))
+
+ def run_PartialGroupByKeyCombineValues(self, transform_node):
+ element_coder = self._get_coder(transform_node.outputs[None])
+ _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+ combine_op = operation_specs.WorkerPartialGroupByKey(
+ combine_fn=pickler.dumps(
+ (transform_node.transform.combine_fn, (), {}, ())),
+ output_coders=[element_coder],
+ input=(producer_index, output_index))
+ self._run_as_op(transform_node, combine_op)
+
+ def run_MergeAccumulators(self, transform_node):
+ self._run_combine_transform(transform_node, 'merge')
+
+ def run_ExtractOutputs(self, transform_node):
+ self._run_combine_transform(transform_node, 'extract')
+
+ def _run_combine_transform(self, transform_node, phase):
+ transform = transform_node.transform
+ element_coder = self._get_coder(transform_node.outputs[None])
+ _, producer_index, output_index = self.outputs[transform_node.inputs[0]]
+ combine_op = operation_specs.WorkerCombineFn(
+ serialized_fn=pickler.dumps(
+ (transform.combine_fn, (), {}, ())),
+ phase=phase,
+ output_coders=[element_coder],
+ input=(producer_index, output_index))
+ self._run_as_op(transform_node, combine_op)
+
+ def _get_coder(self, pvalue, windowed=True):
+ # TODO(robertwb): This should be an attribute of the pvalue itself.
+ return DataflowRunner._get_coder(
+ pvalue.element_type or typehints.Any,
+ pvalue.windowing.windowfn.get_window_coder() if windowed else None)
+
+ def _run_as_op(self, transform_node, op):
+ """Single-output operation in the same map task as its input."""
+ map_task_index, _, _ = self.outputs[transform_node.inputs[0]]
+ op_index = len(self.map_tasks[map_task_index])
+ output = transform_node.outputs[None]
+ self.outputs[output] = map_task_index, op_index, 0
+ self.map_tasks[map_task_index].append((transform_node.full_label, op))
+
+
+class InMemorySource(iobase.BoundedSource):
+ """Source for reading an (as-yet unwritten) set of in-memory encoded elements.
+ """
+
+ def __init__(self, encoded_elements, coder):
+ self._encoded_elements = encoded_elements
+ self._coder = coder
+
+ def get_range_tracker(self, unused_start_position, unused_end_position):
+ return None
+
+ def read(self, unused_range_tracker):
+ for encoded_element in self._encoded_elements:
+ yield self._coder.decode(encoded_element)
+
+ def default_output_coder(self):
+ return self._coder
+
+
+class OutputBuffer(object):
+
+ def __init__(self, coder):
+ self.coder = coder
+ self.elements = []
+ self.encoded_elements = []
+
+ def source(self):
+ return InMemorySource(self.encoded_elements, self.coder)
+
+ def source_bundle(self):
+ return iobase.SourceBundle(
+ 1.0, InMemorySource(self.encoded_elements, self.coder), None, None)
+
+ def __repr__(self):
+    return 'OutputBuffer[%r]' % len(self.elements)
+
+ def append(self, value):
+ self.elements.append(value)
+ self.encoded_elements.append(self.coder.encode(value))
+
+
+class GroupingOutputBuffer(object):
+
+ def __init__(self, grouped_coder):
+ self.grouped_coder = grouped_coder
+ self.elements = collections.defaultdict(list)
+ self.frozen = False
+
+ def source(self):
+ return InMemorySource(self.encoded_elements, self.grouped_coder)
+
+ def __repr__(self):
+ return 'GroupingOutputBuffer[%r]' % len(self.elements)
+
+ def append(self, pair):
+ assert not self.frozen
+ k, v = pair
+ self.elements[k].append(v)
+
+ def freeze(self):
+ if not self.frozen:
+ self._encoded_elements = [self.grouped_coder.encode(kv)
+ for kv in self.elements.iteritems()]
+ self.frozen = True
+ return self._encoded_elements
+
+ @property
+ def encoded_elements(self):
+ return GroupedOutputBuffer(self)
+
+
+class GroupedOutputBuffer(object):
+
+ def __init__(self, buffer):
+ self.buffer = buffer
+
+ def __getitem__(self, ix):
+ return self.buffer.freeze()[ix]
+
+ def __iter__(self):
+ return iter(self.buffer.freeze())
+
+ def __len__(self):
+ return len(self.buffer.freeze())
+
+ def __nonzero__(self):
+ return True
+
+
+class PartialGroupByKeyCombineValues(beam.PTransform):
+
+ def __init__(self, combine_fn, native=True):
+ self.combine_fn = combine_fn
+ self.native = native
+
+ def expand(self, input):
+ if self.native:
+ return beam.pvalue.PCollection(input.pipeline)
+ else:
+ def to_accumulator(v):
+ return self.combine_fn.add_input(
+ self.combine_fn.create_accumulator(), v)
+ return input | beam.Map(lambda (k, v): (k, to_accumulator(v)))
+
+
+class MergeAccumulators(beam.PTransform):
+
+ def __init__(self, combine_fn, native=True):
+ self.combine_fn = combine_fn
+ self.native = native
+
+ def expand(self, input):
+ if self.native:
+ return beam.pvalue.PCollection(input.pipeline)
+ else:
+ merge_accumulators = self.combine_fn.merge_accumulators
+ return input | beam.Map(lambda (k, vs): (k, merge_accumulators(vs)))
+
+
+class ExtractOutputs(beam.PTransform):
+
+ def __init__(self, combine_fn, native=True):
+ self.combine_fn = combine_fn
+ self.native = native
+
+ def expand(self, input):
+ if self.native:
+ return beam.pvalue.PCollection(input.pipeline)
+ else:
+ extract_output = self.combine_fn.extract_output
+ return input | beam.Map(lambda (k, v): (k, extract_output(v)))
+
+
+class WorkerRunnerResult(PipelineResult):
+
+ def wait_until_finish(self, duration=None):
+ pass
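As a point of reference for apply_CombinePerKey above: the combine is decomposed into three lifted phases, which the native=False branches of the transforms spell out against an ordinary CombineFn. A small worked sketch of those phases:

    import apache_beam as beam

    fn = beam.combiners.MeanCombineFn()
    acc = fn.add_input(fn.create_accumulator(), 3)  # PartialGroupByKeyCombineValues
    merged = fn.merge_accumulators([acc])           # MergeAccumulators (after the GroupByKey)
    print(fn.extract_output(merged))                # ExtractOutputs -> 3.0
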
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
new file mode 100644
index 0000000..6e13e73
--- /dev/null
+++ b/sdks/python/apache_beam/runners/portability/maptask_executor_runner_test.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import logging
+import tempfile
+import unittest
+
+import apache_beam as beam
+
+from apache_beam.metrics import Metrics
+from apache_beam.metrics.execution import MetricKey
+from apache_beam.metrics.execution import MetricsEnvironment
+from apache_beam.metrics.metricbase import MetricName
+
+from apache_beam.pvalue import AsList
+from apache_beam.transforms.util import assert_that
+from apache_beam.transforms.util import BeamAssertException
+from apache_beam.transforms.util import equal_to
+from apache_beam.transforms.window import TimestampedValue
+from apache_beam.runners.portability import maptask_executor_runner
+
+
+class MapTaskExecutorRunnerTest(unittest.TestCase):
+
+ def create_pipeline(self):
+ return beam.Pipeline(runner=maptask_executor_runner.MapTaskExecutorRunner())
+
+ def test_assert_that(self):
+ with self.assertRaises(BeamAssertException):
+ with self.create_pipeline() as p:
+ assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+
+ def test_create(self):
+ with self.create_pipeline() as p:
+ assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
+
+ def test_pardo(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create(['a', 'bc'])
+ | beam.Map(lambda e: e * 2)
+ | beam.Map(lambda e: e + 'x'))
+ assert_that(res, equal_to(['aax', 'bcbcx']))
+
+ def test_pardo_metrics(self):
+
+ class MyDoFn(beam.DoFn):
+
+ def start_bundle(self):
+ self.count = Metrics.counter(self.__class__, 'elements')
+
+ def process(self, element):
+ self.count.inc(element)
+ return [element]
+
+ class MyOtherDoFn(beam.DoFn):
+
+ def start_bundle(self):
+ self.count = Metrics.counter(self.__class__, 'elementsplusone')
+
+ def process(self, element):
+ self.count.inc(element + 1)
+ return [element]
+
+ with self.create_pipeline() as p:
+ res = (p | beam.Create([1, 2, 3])
+ | 'mydofn' >> beam.ParDo(MyDoFn())
+ | 'myotherdofn' >> beam.ParDo(MyOtherDoFn()))
+ p.run()
+ if not MetricsEnvironment.METRICS_SUPPORTED:
+ self.skipTest('Metrics are not supported.')
+
+ counter_updates = [{'key': key, 'value': val}
+ for container in p.runner.metrics_containers()
+ for key, val in
+ container.get_updates().counters.items()]
+ counter_values = [update['value'] for update in counter_updates]
+ counter_keys = [update['key'] for update in counter_updates]
+ assert_that(res, equal_to([1, 2, 3]))
+ self.assertEqual(counter_values, [6, 9])
+ self.assertEqual(counter_keys, [
+ MetricKey('mydofn',
+ MetricName(__name__ + '.MyDoFn', 'elements')),
+ MetricKey('myotherdofn',
+ MetricName(__name__ + '.MyOtherDoFn', 'elementsplusone'))])
+
+ def test_pardo_side_outputs(self):
+ def tee(elem, *tags):
+ for tag in tags:
+ if tag in elem:
+ yield beam.pvalue.OutputValue(tag, elem)
+ with self.create_pipeline() as p:
+ xy = (p
+ | 'Create' >> beam.Create(['x', 'y', 'xy'])
+ | beam.FlatMap(tee, 'x', 'y').with_outputs())
+ assert_that(xy.x, equal_to(['x', 'xy']), label='x')
+ assert_that(xy.y, equal_to(['y', 'xy']), label='y')
+
+ def test_pardo_side_and_main_outputs(self):
+ def even_odd(elem):
+ yield elem
+ yield beam.pvalue.OutputValue('odd' if elem % 2 else 'even', elem)
+ with self.create_pipeline() as p:
+ ints = p | beam.Create([1, 2, 3])
+ named = ints | 'named' >> beam.FlatMap(
+ even_odd).with_outputs('even', 'odd', main='all')
+ assert_that(named.all, equal_to([1, 2, 3]), label='named.all')
+ assert_that(named.even, equal_to([2]), label='named.even')
+ assert_that(named.odd, equal_to([1, 3]), label='named.odd')
+
+ unnamed = ints | 'unnamed' >> beam.FlatMap(even_odd).with_outputs()
+ unnamed[None] | beam.Map(id) # pylint: disable=expression-not-assigned
+ assert_that(unnamed[None], equal_to([1, 2, 3]), label='unnamed.all')
+ assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
+ assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
+
+ def test_pardo_side_inputs(self):
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+ with self.create_pipeline() as p:
+ main = p | 'main' >> beam.Create(['a', 'b', 'c'])
+ side = p | 'side' >> beam.Create(['x', 'y'])
+ assert_that(main | beam.FlatMap(cross_product, AsList(side)),
+ equal_to([('a', 'x'), ('b', 'x'), ('c', 'x'),
+ ('a', 'y'), ('b', 'y'), ('c', 'y')]))
+
+ def test_pardo_unfusable_side_inputs(self):
+ def cross_product(elem, sides):
+ for side in sides:
+ yield elem, side
+ with self.create_pipeline() as p:
+ pcoll = p | beam.Create(['a', 'b'])
+ assert_that(pcoll | beam.FlatMap(cross_product, AsList(pcoll)),
+ equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+ with self.create_pipeline() as p:
+ pcoll = p | beam.Create(['a', 'b'])
+ derived = ((pcoll,) | beam.Flatten()
+ | beam.Map(lambda x: (x, x))
+ | beam.GroupByKey()
+ | 'Unkey' >> beam.Map(lambda (x, _): x))
+ assert_that(
+ pcoll | beam.FlatMap(cross_product, AsList(derived)),
+ equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
+
+ def test_group_by_key(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+ | beam.GroupByKey()
+ | beam.Map(lambda (k, vs): (k, sorted(vs))))
+ assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
+
+ def test_flatten(self):
+ with self.create_pipeline() as p:
+ res = (p | 'a' >> beam.Create(['a']),
+ p | 'bc' >> beam.Create(['b', 'c']),
+ p | 'd' >> beam.Create(['d'])) | beam.Flatten()
+ assert_that(res, equal_to(['a', 'b', 'c', 'd']))
+
+ def test_combine_per_key(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create([('a', 1), ('a', 2), ('b', 3)])
+ | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
+ assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
+
+ def test_read(self):
+ with tempfile.NamedTemporaryFile() as temp_file:
+ temp_file.write('a\nb\nc')
+ temp_file.flush()
+ with self.create_pipeline() as p:
+ assert_that(p | beam.io.ReadFromText(temp_file.name),
+ equal_to(['a', 'b', 'c']))
+
+ def test_windowing(self):
+ with self.create_pipeline() as p:
+ res = (p
+ | beam.Create([1, 2, 100, 101, 102])
+ | beam.Map(lambda t: TimestampedValue(('k', t), t))
+ | beam.WindowInto(beam.transforms.window.Sessions(10))
+ | beam.GroupByKey()
+ | beam.Map(lambda (k, vs): (k, sorted(vs))))
+ assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/__init__.py b/sdks/python/apache_beam/runners/worker/__init__.py
new file mode 100644
index 0000000..cce3aca
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
new file mode 100644
index 0000000..6425447
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -0,0 +1,288 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Implementation of DataChannels for communicating across the data plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import logging
+import Queue as queue
+import threading
+
+from apache_beam.coders import coder_impl
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class ClosableOutputStream(type(coder_impl.create_OutputStream())):
+  """An OutputStream for use with CoderImpls that has a close() method."""
+
+ def __init__(self, close_callback=None):
+ super(ClosableOutputStream, self).__init__()
+ self._close_callback = close_callback
+
+ def close(self):
+ if self._close_callback:
+ self._close_callback(self.get())
+
+
+class DataChannel(object):
+ """Represents a channel for reading and writing data over the data plane.
+
+ Read from this channel with the input_elements method::
+
+ for elements_data in data_channel.input_elements(instruction_id, targets):
+ [process elements_data]
+
+ Write to this channel using the output_stream method::
+
+ out1 = data_channel.output_stream(instruction_id, target1)
+ out1.write(...)
+ out1.close()
+
+ When all data for all instructions is written, close the channel::
+
+ data_channel.close()
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def input_elements(self, instruction_id, expected_targets):
+    """Returns an iterable of all Elements.Data bundles for instruction_id.
+
+    This iterable terminates only once the full set of data has been received
+ for each of the expected targets. It may block waiting for more data.
+
+ Args:
+ instruction_id: which instruction the results must belong to
+ expected_targets: which targets to wait on for completion
+ """
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def output_stream(self, instruction_id, target):
+ """Returns an output stream writing elements to target.
+
+ Args:
+ instruction_id: which instruction this stream belongs to
+ target: the target of the returned stream
+ """
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def close(self):
+ """Closes this channel, indicating that all data has been written.
+
+ Data can continue to be read.
+
+    If this channel is shared by many instructions, it should only be called
+ worker shutdown.
+ """
+ raise NotImplementedError(type(self))
+
+
+class InMemoryDataChannel(DataChannel):
+ """An in-memory implementation of a DataChannel.
+
+ This channel is two-sided. What is written to one side is read by the other.
+  The inverse() method returns the other side of an instance.
+ """
+
+ def __init__(self, inverse=None):
+ self._inputs = []
+ self._inverse = inverse or InMemoryDataChannel(self)
+
+ def inverse(self):
+ return self._inverse
+
+ def input_elements(self, instruction_id, unused_expected_targets=None):
+ for data in self._inputs:
+ if data.instruction_reference == instruction_id:
+ yield data
+
+ def output_stream(self, instruction_id, target):
+ def add_to_inverse_output(data):
+ self._inverse._inputs.append( # pylint: disable=protected-access
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference=instruction_id,
+ target=target,
+ data=data))
+ return ClosableOutputStream(add_to_inverse_output)
+
+ def close(self):
+ pass
+
+
+class _GrpcDataChannel(DataChannel):
+ """Base class for implementing a BeamFnData-based DataChannel."""
+
+ _WRITES_FINISHED = object()
+
+ def __init__(self):
+ self._to_send = queue.Queue()
+ self._received = collections.defaultdict(queue.Queue)
+ self._receive_lock = threading.Lock()
+ self._reads_finished = threading.Event()
+
+ def close(self):
+ self._to_send.put(self._WRITES_FINISHED)
+
+ def wait(self, timeout=None):
+ self._reads_finished.wait(timeout)
+
+ def _receiving_queue(self, instruction_id):
+ with self._receive_lock:
+ return self._received[instruction_id]
+
+ def input_elements(self, instruction_id, expected_targets):
+ received = self._receiving_queue(instruction_id)
+ done_targets = []
+ while len(done_targets) < len(expected_targets):
+ data = received.get()
+ if not data.data and data.target in expected_targets:
+ done_targets.append(data.target)
+ else:
+ assert data.target not in done_targets
+ yield data
+
+ def output_stream(self, instruction_id, target):
+ def add_to_send_queue(data):
+ self._to_send.put(
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference=instruction_id,
+ target=target,
+ data=data))
+ self._to_send.put(
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference=instruction_id,
+ target=target,
+ data=''))
+ return ClosableOutputStream(add_to_send_queue)
+
+ def _write_outputs(self):
+ done = False
+ while not done:
+ data = [self._to_send.get()]
+ try:
+ # Coalesce up to 100 other items.
+ for _ in range(100):
+ data.append(self._to_send.get_nowait())
+ except queue.Empty:
+ pass
+ if data[-1] is self._WRITES_FINISHED:
+ done = True
+ data.pop()
+ if data:
+ yield beam_fn_api_pb2.Elements(data=data)
+
+ def _read_inputs(self, elements_iterator):
+ # TODO(robertwb): Pushback/throttling to avoid unbounded buffering.
+ try:
+ for elements in elements_iterator:
+ for data in elements.data:
+ self._receiving_queue(data.instruction_reference).put(data)
+ except: # pylint: disable=broad-except
+ logging.exception('Failed to read inputs in the data plane')
+ raise
+ finally:
+ self._reads_finished.set()
+
+ def _start_reader(self, elements_iterator):
+ reader = threading.Thread(
+ target=lambda: self._read_inputs(elements_iterator),
+ name='read_grpc_client_inputs')
+ reader.daemon = True
+ reader.start()
+
+
+class GrpcClientDataChannel(_GrpcDataChannel):
+ """A DataChannel wrapping the client side of a BeamFnData connection."""
+
+ def __init__(self, data_stub):
+ super(GrpcClientDataChannel, self).__init__()
+ self._start_reader(data_stub.Data(self._write_outputs()))
+
+
+class GrpcServerDataChannel(
+ beam_fn_api_pb2.BeamFnDataServicer, _GrpcDataChannel):
+ """A DataChannel wrapping the server side of a BeamFnData connection."""
+
+ def Data(self, elements_iterator, context):
+ self._start_reader(elements_iterator)
+ for elements in self._write_outputs():
+ yield elements
+
+
+class DataChannelFactory(object):
+ """An abstract factory for creating ``DataChannel``."""
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def create_data_channel(self, function_spec):
+ """Returns a ``DataChannel`` from the given function_spec."""
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def close(self):
+ """Close all channels that this factory owns."""
+ raise NotImplementedError(type(self))
+
+
+class GrpcClientDataChannelFactory(DataChannelFactory):
+ """A factory for ``GrpcClientDataChannel``.
+
+ Caches the created channels by ``data descriptor url``.
+ """
+
+ def __init__(self):
+ self._data_channel_cache = {}
+
+ def create_data_channel(self, function_spec):
+ remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
+ function_spec.data.Unpack(remote_grpc_port)
+ url = remote_grpc_port.api_service_descriptor.url
+ if url not in self._data_channel_cache:
+ logging.info('Creating channel for %s', url)
+ grpc_channel = grpc.insecure_channel(url)
+ self._data_channel_cache[url] = GrpcClientDataChannel(
+ beam_fn_api_pb2.BeamFnDataStub(grpc_channel))
+ return self._data_channel_cache[url]
+
+ def close(self):
+ logging.info('Closing all cached grpc data channels.')
+ for _, channel in self._data_channel_cache.items():
+ channel.close()
+ self._data_channel_cache.clear()
+
+
+class InMemoryDataChannelFactory(DataChannelFactory):
+ """A singleton factory for ``InMemoryDataChannel``."""
+
+ def __init__(self, in_memory_data_channel):
+ self._in_memory_data_channel = in_memory_data_channel
+
+ def create_data_channel(self, unused_function_spec):
+ return self._in_memory_data_channel
+
+ def close(self):
+ pass
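As a quick illustration of the DataChannel contract documented above, a minimal sketch using the in-memory implementation (the gRPC channel is exercised the same way in data_plane_test.py below):

    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import data_plane

    channel = data_plane.InMemoryDataChannel()
    target = beam_fn_api_pb2.Target(primitive_transform_reference='1', name='out')

    # Whatever is written to one side of the channel...
    out = channel.output_stream('instruction-0', target)
    out.write('abc')
    out.close()

    # ...is read back from its inverse, keyed by instruction id.
    for data in channel.inverse().input_elements('instruction-0', [target]):
      print(data.data)  # prints 'abc'
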
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/data_plane_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py
new file mode 100644
index 0000000..7340789
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Tests for apache_beam.runners.worker.data_plane."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import sys
+import threading
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import data_plane
+
+
+def timeout(timeout_secs):
+ def decorate(fn):
+ exc_info = []
+
+ def wrapper(*args, **kwargs):
+ def call_fn():
+ try:
+ fn(*args, **kwargs)
+ except: # pylint: disable=bare-except
+ exc_info[:] = sys.exc_info()
+ thread = threading.Thread(target=call_fn)
+ thread.daemon = True
+ thread.start()
+ thread.join(timeout_secs)
+ if exc_info:
+ t, v, tb = exc_info # pylint: disable=unbalanced-tuple-unpacking
+ raise t, v, tb
+ assert not thread.is_alive(), 'timed out after %s seconds' % timeout_secs
+ return wrapper
+ return decorate
+
+
+class DataChannelTest(unittest.TestCase):
+
+ @timeout(5)
+ def test_grpc_data_channel(self):
+ data_channel_service = data_plane.GrpcServerDataChannel()
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
+ beam_fn_api_pb2.add_BeamFnDataServicer_to_server(
+ data_channel_service, server)
+ test_port = server.add_insecure_port('[::]:0')
+ server.start()
+
+ data_channel_stub = beam_fn_api_pb2.BeamFnDataStub(
+ grpc.insecure_channel('localhost:%s' % test_port))
+ data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)
+
+ try:
+ self._data_channel_test(data_channel_service, data_channel_client)
+ finally:
+ data_channel_client.close()
+ data_channel_service.close()
+ data_channel_client.wait()
+ data_channel_service.wait()
+
+ def test_in_memory_data_channel(self):
+ channel = data_plane.InMemoryDataChannel()
+ self._data_channel_test(channel, channel.inverse())
+
+ def _data_channel_test(self, server, client):
+ self._data_channel_test_one_direction(server, client)
+ self._data_channel_test_one_direction(client, server)
+
+ def _data_channel_test_one_direction(self, from_channel, to_channel):
+ def send(instruction_id, target, data):
+ stream = from_channel.output_stream(instruction_id, target)
+ stream.write(data)
+ stream.close()
+ target_1 = beam_fn_api_pb2.Target(
+ primitive_transform_reference='1',
+ name='out')
+ target_2 = beam_fn_api_pb2.Target(
+ primitive_transform_reference='2',
+ name='out')
+
+ # Single write.
+ send('0', target_1, 'abc')
+ self.assertEqual(
+ list(to_channel.input_elements('0', [target_1])),
+ [beam_fn_api_pb2.Elements.Data(
+ instruction_reference='0',
+ target=target_1,
+ data='abc')])
+
+ # Multiple interleaved writes to multiple instructions.
+ target_2 = beam_fn_api_pb2.Target(
+ primitive_transform_reference='2',
+ name='out')
+
+ send('1', target_1, 'abc')
+ send('2', target_1, 'def')
+ self.assertEqual(
+ list(to_channel.input_elements('1', [target_1])),
+ [beam_fn_api_pb2.Elements.Data(
+ instruction_reference='1',
+ target=target_1,
+ data='abc')])
+ send('2', target_2, 'ghi')
+ self.assertEqual(
+ list(to_channel.input_elements('2', [target_1, target_2])),
+ [beam_fn_api_pb2.Elements.Data(
+ instruction_reference='2',
+ target=target_1,
+ data='def'),
+ beam_fn_api_pb2.Elements.Data(
+ instruction_reference='2',
+ target=target_2,
+ data='ghi')])
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py
new file mode 100644
index 0000000..b9e36ad
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Beam fn API log handler."""
+
+import logging
+import math
+import Queue as queue
+import threading
+
+from apache_beam.runners.api import beam_fn_api_pb2
+import grpc
+
+
+class FnApiLogRecordHandler(logging.Handler):
+ """A handler that writes log records to the fn API."""
+
+ # Maximum number of log entries in a single stream request.
+ _MAX_BATCH_SIZE = 1000
+ # Used to indicate the end of stream.
+ _FINISHED = object()
+
+ # Mapping from logging levels to LogEntry levels.
+ LOG_LEVEL_MAP = {
+ logging.FATAL: beam_fn_api_pb2.LogEntry.CRITICAL,
+ logging.ERROR: beam_fn_api_pb2.LogEntry.ERROR,
+ logging.WARNING: beam_fn_api_pb2.LogEntry.WARN,
+ logging.INFO: beam_fn_api_pb2.LogEntry.INFO,
+ logging.DEBUG: beam_fn_api_pb2.LogEntry.DEBUG
+ }
+
+ def __init__(self, log_service_descriptor):
+ super(FnApiLogRecordHandler, self).__init__()
+ self._log_channel = grpc.insecure_channel(log_service_descriptor.url)
+ self._logging_stub = beam_fn_api_pb2.BeamFnLoggingStub(self._log_channel)
+ self._log_entry_queue = queue.Queue()
+
+ log_control_messages = self._logging_stub.Logging(self._write_log_entries())
+ self._reader = threading.Thread(
+ target=lambda: self._read_log_control_messages(log_control_messages),
+ name='read_log_control_messages')
+ self._reader.daemon = True
+ self._reader.start()
+
+ def emit(self, record):
+ log_entry = beam_fn_api_pb2.LogEntry()
+ log_entry.severity = self.LOG_LEVEL_MAP[record.levelno]
+ log_entry.message = self.format(record)
+ log_entry.thread = record.threadName
+ log_entry.log_location = record.module + '.' + record.funcName
+ (fraction, seconds) = math.modf(record.created)
+ nanoseconds = 1e9 * fraction
+ log_entry.timestamp.seconds = int(seconds)
+ log_entry.timestamp.nanos = int(nanoseconds)
+ self._log_entry_queue.put(log_entry)
+
+ def close(self):
+ """Flush out all existing log entries and unregister this handler."""
+ # Acquiring the handler lock ensures ``emit`` is not run until the lock is
+ # released.
+ self.acquire()
+ self._log_entry_queue.put(self._FINISHED)
+    # Wait for the reader thread to finish once the server closes the stream.
+ self._reader.join()
+ self.release()
+ # Unregister this handler.
+ super(FnApiLogRecordHandler, self).close()
+
+ def _write_log_entries(self):
+ done = False
+ while not done:
+ log_entries = [self._log_entry_queue.get()]
+ try:
+ for _ in range(self._MAX_BATCH_SIZE):
+ log_entries.append(self._log_entry_queue.get_nowait())
+ except queue.Empty:
+ pass
+ if log_entries[-1] is self._FINISHED:
+ done = True
+ log_entries.pop()
+ if log_entries:
+ yield beam_fn_api_pb2.LogEntry.List(log_entries=log_entries)
+
+ def _read_log_control_messages(self, log_control_iterator):
+ # TODO(vikasrk): Handle control messages.
+ for _ in log_control_iterator:
+ pass
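
A minimal sketch of attaching this handler to the standard logging module,
mirroring the test that follows; the url below is a placeholder for wherever a
BeamFnLogging service is actually listening:

    import logging

    from apache_beam.runners.api import beam_fn_api_pb2
    from apache_beam.runners.worker import log_handler

    descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
    descriptor.url = 'localhost:12345'  # Placeholder logging service address.

    fn_log_handler = log_handler.FnApiLogRecordHandler(descriptor)
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(fn_log_handler)

    logging.info('This record is forwarded over the Fn API logging stream.')

    # close() flushes queued entries and waits for the log stream to finish.
    fn_log_handler.close()
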
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/log_handler_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py
new file mode 100644
index 0000000..565bedb
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py
@@ -0,0 +1,105 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import logging
+import unittest
+
+import grpc
+from concurrent import futures
+
+from apache_beam.runners.api import beam_fn_api_pb2
+from apache_beam.runners.worker import log_handler
+
+
+class BeamFnLoggingServicer(beam_fn_api_pb2.BeamFnLoggingServicer):
+
+ def __init__(self):
+ self.log_records_received = []
+
+ def Logging(self, request_iterator, context):
+
+ for log_record in request_iterator:
+ self.log_records_received.append(log_record)
+
+ yield beam_fn_api_pb2.LogControl()
+
+
+class FnApiLogRecordHandlerTest(unittest.TestCase):
+
+ def setUp(self):
+ self.test_logging_service = BeamFnLoggingServicer()
+ self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ beam_fn_api_pb2.add_BeamFnLoggingServicer_to_server(
+ self.test_logging_service, self.server)
+ self.test_port = self.server.add_insecure_port('[::]:0')
+ self.server.start()
+
+ self.logging_service_descriptor = beam_fn_api_pb2.ApiServiceDescriptor()
+ self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
+ self.fn_log_handler = log_handler.FnApiLogRecordHandler(
+ self.logging_service_descriptor)
+ logging.getLogger().setLevel(logging.INFO)
+ logging.getLogger().addHandler(self.fn_log_handler)
+
+ def tearDown(self):
+    # Allow the server up to 5 seconds to shut down.
+ self.server.stop(5)
+
+ def _verify_fn_log_handler(self, num_log_entries):
+ msg = 'Testing fn logging'
+ logging.debug('Debug Message 1')
+ for idx in range(num_log_entries):
+ logging.info('%s: %s', msg, idx)
+ logging.debug('Debug Message 2')
+
+ # Wait for logs to be sent to server.
+ self.fn_log_handler.close()
+
+ num_received_log_entries = 0
+ for outer in self.test_logging_service.log_records_received:
+ for log_entry in outer.log_entries:
+ self.assertEqual(beam_fn_api_pb2.LogEntry.INFO, log_entry.severity)
+ self.assertEqual('%s: %s' % (msg, num_received_log_entries),
+ log_entry.message)
+ self.assertEqual(u'log_handler_test._verify_fn_log_handler',
+ log_entry.log_location)
+ self.assertGreater(log_entry.timestamp.seconds, 0)
+ self.assertGreater(log_entry.timestamp.nanos, 0)
+ num_received_log_entries += 1
+
+ self.assertEqual(num_received_log_entries, num_log_entries)
+
+
+# Test cases.
+data = {
+ 'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
+ 'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
+ 'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
+}
+
+
+def _create_test(name, num_logs):
+ setattr(FnApiLogRecordHandlerTest, 'test_%s' % name,
+ lambda self: self._verify_fn_log_handler(num_logs))
+
+
+if __name__ == '__main__':
+ for test_name, num_logs_entries in data.iteritems():
+ _create_test(test_name, num_logs_entries)
+
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.pxd
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.pxd b/sdks/python/apache_beam/runners/worker/logger.pxd
new file mode 100644
index 0000000..201daf4
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.pxd
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+cimport cython
+
+from apache_beam.runners.common cimport LoggingContext
+
+
+cdef class PerThreadLoggingContext(LoggingContext):
+ cdef kwargs
+ cdef list stack
http://git-wip-us.apache.org/repos/asf/beam/blob/a856fcf3/sdks/python/apache_beam/runners/worker/logger.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/logger.py b/sdks/python/apache_beam/runners/worker/logger.py
new file mode 100644
index 0000000..217dc58
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/logger.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Python worker logging."""
+
+import json
+import logging
+import threading
+import traceback
+
+from apache_beam.runners.common import LoggingContext
+
+
+# Per-thread worker information. This is used only for logging to set
+# context information that changes while work items get executed:
+# work_item_id, step_name, stage_name.
+class _PerThreadWorkerData(threading.local):
+
+ def __init__(self):
+ super(_PerThreadWorkerData, self).__init__()
+ # TODO(robertwb): Consider starting with an initial (ignored) ~20 elements
+ # in the list, as going up and down all the way to zero incurs several
+ # reallocations.
+ self.stack = []
+
+ def get_data(self):
+ all_data = {}
+ for datum in self.stack:
+ all_data.update(datum)
+ return all_data
+
+
+per_thread_worker_data = _PerThreadWorkerData()
+
+
+class PerThreadLoggingContext(LoggingContext):
+ """A context manager to add per thread attributes."""
+
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.stack = per_thread_worker_data.stack
+
+ def __enter__(self):
+ self.enter()
+
+ def enter(self):
+ self.stack.append(self.kwargs)
+
+ def __exit__(self, exn_type, exn_value, exn_traceback):
+ self.exit()
+
+ def exit(self):
+ self.stack.pop()
+
+
+class JsonLogFormatter(logging.Formatter):
+ """A JSON formatter class as expected by the logging standard module."""
+
+ def __init__(self, job_id, worker_id):
+ super(JsonLogFormatter, self).__init__()
+ self.job_id = job_id
+ self.worker_id = worker_id
+
+ def format(self, record):
+ """Returns a JSON string based on a LogRecord instance.
+
+ Args:
+ record: A LogRecord instance. See below for details.
+
+ Returns:
+ A JSON string representing the record.
+
+ A LogRecord instance has the following attributes and is used for
+ formatting the final message.
+
+ Attributes:
+      created: A double representing the timestamp for record creation
+        (e.g., 1438365207.624597). Note that the number also carries msec
+        and microsec information. Part of this is also available in the
+        'msecs' attribute.
+ msecs: A double representing the msecs part of the record creation
+ (e.g., 624.5970726013184).
+ msg: Logging message containing formatting instructions or an arbitrary
+ object. This is the first argument of a log call.
+ args: A tuple containing the positional arguments for the logging call.
+ levelname: A string. Possible values are: INFO, WARNING, ERROR, etc.
+ exc_info: None or a 3-tuple with exception information as it is
+ returned by a call to sys.exc_info().
+ name: Logger's name. Most logging is done using the default root logger
+ and therefore the name will be 'root'.
+ filename: Basename of the file where logging occurred.
+ funcName: Name of the function where logging occurred.
+ process: The PID of the process running the worker.
+      thread: An id for the thread where the record was logged. This is not a
+        real TID (the one provided by the OS) but rather the id (address) of a
+        Python thread object. Nevertheless, having this value makes it
+        possible to filter log statements down to one specific thread.
+ """
+ output = {}
+ output['timestamp'] = {
+ 'seconds': int(record.created),
+ 'nanos': int(record.msecs * 1000000)}
+    # The ERROR, INFO and DEBUG log levels translate directly into the
+    # severity property; WARNING becomes WARN.
+ output['severity'] = (
+ record.levelname if record.levelname != 'WARNING' else 'WARN')
+
+ # msg could be an arbitrary object, convert it to a string first.
+ record_msg = str(record.msg)
+
+ # Prepare the actual message using the message formatting string and the
+ # positional arguments as they have been used in the log call.
+ if record.args:
+ try:
+ output['message'] = record_msg % record.args
+ except (TypeError, ValueError):
+ output['message'] = '%s with args (%s)' % (record_msg, record.args)
+ else:
+ output['message'] = record_msg
+
+ # The thread ID is logged as a combination of the process ID and thread ID
+ # since workers can run in multiple processes.
+ output['thread'] = '%s:%s' % (record.process, record.thread)
+    # Job ID and worker ID. These do not change during the lifetime of a worker.
+ output['job'] = self.job_id
+ output['worker'] = self.worker_id
+    # Stage, step and work item ID come from thread-local storage since they
+    # change with every new work item leased for execution. Each field is
+    # included only when it is present in the per-thread data.
+ data = per_thread_worker_data.get_data()
+ if 'work_item_id' in data:
+ output['work'] = data['work_item_id']
+ if 'stage_name' in data:
+ output['stage'] = data['stage_name']
+ if 'step_name' in data:
+ output['step'] = data['step_name']
+ # All logging happens using the root logger. We will add the basename of the
+ # file and the function name where the logging happened to make it easier
+ # to identify who generated the record.
+ output['logger'] = '%s:%s:%s' % (
+ record.name, record.filename, record.funcName)
+ # Add exception information if any is available.
+ if record.exc_info:
+ output['exception'] = ''.join(
+ traceback.format_exception(*record.exc_info))
+
+ return json.dumps(output)
+
+
+def initialize(job_id, worker_id, log_path):
+ """Initialize root logger so that we log JSON to a file and text to stdout."""
+
+ file_handler = logging.FileHandler(log_path)
+ file_handler.setFormatter(JsonLogFormatter(job_id, worker_id))
+ logging.getLogger().addHandler(file_handler)
+
+ # Set default level to INFO to avoid logging various DEBUG level log calls
+ # sprinkled throughout the code.
+ logging.getLogger().setLevel(logging.INFO)
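
A minimal sketch of how these pieces combine in a worker, using hypothetical
job/worker ids and log path: initialize() installs the JsonLogFormatter-backed
file handler on the root logger, and PerThreadLoggingContext scopes the
work_item_id/stage_name/step_name values that format() picks up per record.

    import logging

    from apache_beam.runners.worker import logger

    # Hypothetical identifiers and path, for illustration only.
    logger.initialize('job-1234', 'worker-0', '/tmp/worker.log')

    with logger.PerThreadLoggingContext(
        work_item_id='work-1', stage_name='S01', step_name='s2'):
      # The resulting JSON record carries work/stage/step fields.
      logging.info('Processing work item %s', 'work-1')
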