Posted to commits@beam.apache.org by ta...@apache.org on 2017/07/13 03:06:54 UTC
[42/50] [abbrv] beam git commit: Split bundle processor into separate class.
Split bundle processor into separate class.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/4abd7141
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/4abd7141
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/4abd7141
Branch: refs/heads/DSL_SQL
Commit: 4abd7141673f4aead669efd4d2a87fc163764a2d
Parents: 6a61f15
Author: Robert Bradshaw <ro...@gmail.com>
Authored: Wed Jun 28 18:20:12 2017 -0700
Committer: Tyler Akidau <ta...@apache.org>
Committed: Wed Jul 12 20:01:02 2017 -0700
----------------------------------------------------------------------
.../runners/portability/fn_api_runner.py | 20 +-
.../runners/worker/bundle_processor.py | 426 +++++++++++++++++++
.../apache_beam/runners/worker/sdk_worker.py | 398 +----------------
3 files changed, 444 insertions(+), 400 deletions(-)
----------------------------------------------------------------------
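The net effect of this commit is that SdkWorker.process_bundle now delegates to the new BundleProcessor class instead of building and running the execution tree inline. A minimal sketch of the new call path, assuming only the names introduced in the diff below:

    # Sketch only: mirrors the new sdk_worker.py hunk further down. The
    # BundleProcessor class comes from the bundle_processor module added
    # by this commit.
    from apache_beam.runners.worker import bundle_processor

    def process_bundle(self, request, instruction_id):
      processor = bundle_processor.BundleProcessor(
          self.fns[request.process_bundle_descriptor_reference],
          self.state_handler,
          self.data_channel_factory)
      processor.process_bundle(instruction_id)
      return beam_fn_api_pb2.ProcessBundleResponse()
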
http://git-wip-us.apache.org/repos/asf/beam/blob/4abd7141/sdks/python/apache_beam/runners/portability/fn_api_runner.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index f522864..f88fe53 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -38,6 +38,7 @@ from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import maptask_executor_runner
+from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import sdk_worker
@@ -186,7 +187,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
target_name = only_element(get_inputs(operation).keys())
runner_sinks[(transform_id, target_name)] = operation
transform_spec = beam_runner_api_pb2.FunctionSpec(
- urn=sdk_worker.DATA_OUTPUT_URN,
+ urn=bundle_processor.DATA_OUTPUT_URN,
parameter=proto_utils.pack_Any(data_operation_spec))
elif isinstance(operation, operation_specs.WorkerRead):
@@ -200,7 +201,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
operation.source.source.read(None),
operation.source.source.default_output_coder())
transform_spec = beam_runner_api_pb2.FunctionSpec(
- urn=sdk_worker.DATA_INPUT_URN,
+ urn=bundle_processor.DATA_INPUT_URN,
parameter=proto_utils.pack_Any(data_operation_spec))
else:
@@ -209,7 +210,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
# The Dataflow runner harness strips the base64 encoding. Do the same
# here until we get the same thing back that we sent in.
transform_spec = beam_runner_api_pb2.FunctionSpec(
- urn=sdk_worker.PYTHON_SOURCE_URN,
+ urn=bundle_processor.PYTHON_SOURCE_URN,
parameter=proto_utils.pack_Any(
wrappers_pb2.BytesValue(
value=base64.b64decode(
@@ -223,21 +224,22 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner):
element_coder = si.source.default_output_coder()
# TODO(robertwb): Actually flesh out the ViewFn API.
side_input_extras.append((si.tag, element_coder))
- side_input_data[sdk_worker.side_input_tag(transform_id, si.tag)] = (
- self._reencode_elements(
- si.source.read(si.source.get_range_tracker(None, None)),
- element_coder))
+ side_input_data[
+ bundle_processor.side_input_tag(transform_id, si.tag)] = (
+ self._reencode_elements(
+ si.source.read(si.source.get_range_tracker(None, None)),
+ element_coder))
augmented_serialized_fn = pickler.dumps(
(operation.serialized_fn, side_input_extras))
transform_spec = beam_runner_api_pb2.FunctionSpec(
- urn=sdk_worker.PYTHON_DOFN_URN,
+ urn=bundle_processor.PYTHON_DOFN_URN,
parameter=proto_utils.pack_Any(
wrappers_pb2.BytesValue(value=augmented_serialized_fn)))
elif isinstance(operation, operation_specs.WorkerFlatten):
# Flatten is nice and simple.
transform_spec = beam_runner_api_pb2.FunctionSpec(
- urn=sdk_worker.IDENTITY_DOFN_URN)
+ urn=bundle_processor.IDENTITY_DOFN_URN)
else:
raise NotImplementedError(operation)
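
The fn_api_runner.py changes above are mechanical: every URN constant and helper that previously lived in sdk_worker (DATA_INPUT_URN, DATA_OUTPUT_URN, PYTHON_SOURCE_URN, PYTHON_DOFN_URN, IDENTITY_DOFN_URN, side_input_tag) now resolves through bundle_processor. A quick check of the equivalence, using the constant values defined in the new module below:

    # After this commit the same URN strings are exposed by bundle_processor;
    # the values are copied from the new file in this diff.
    from apache_beam.runners.worker import bundle_processor

    assert bundle_processor.DATA_INPUT_URN == 'urn:org.apache.beam:source:runner:0.1'
    assert bundle_processor.DATA_OUTPUT_URN == 'urn:org.apache.beam:sink:runner:0.1'
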
http://git-wip-us.apache.org/repos/asf/beam/blob/4abd7141/sdks/python/apache_beam/runners/worker/bundle_processor.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
new file mode 100644
index 0000000..2669bfc
--- /dev/null
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -0,0 +1,426 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""SDK harness for executing Python Fns via the Fn API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import base64
+import collections
+import json
+import logging
+
+from google.protobuf import wrappers_pb2
+
+from apache_beam.coders import coder_impl
+from apache_beam.coders import WindowedValueCoder
+from apache_beam.internal import pickler
+from apache_beam.io import iobase
+from apache_beam.portability.api import beam_fn_api_pb2
+from apache_beam.runners.dataflow.native_io import iobase as native_iobase
+from apache_beam.runners import pipeline_context
+from apache_beam.runners.worker import operation_specs
+from apache_beam.runners.worker import operations
+from apache_beam.utils import counters
+from apache_beam.utils import proto_utils
+
+# This module is experimental. No backwards-compatibility guarantees.
+
+
+try:
+ from apache_beam.runners.worker import statesampler
+except ImportError:
+ from apache_beam.runners.worker import statesampler_fake as statesampler
+
+
+DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1'
+DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1'
+IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1'
+PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1'
+PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1'
+# TODO(vikasrk): Fix this once runner sends appropriate python urns.
+PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1'
+PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'
+
+
+def side_input_tag(transform_id, tag):
+ return str("%d[%s][%s]" % (len(transform_id), transform_id, tag))
+
+
+class RunnerIOOperation(operations.Operation):
+ """Common baseclass for runner harness IO operations."""
+
+ def __init__(self, operation_name, step_name, consumers, counter_factory,
+ state_sampler, windowed_coder, target, data_channel):
+ super(RunnerIOOperation, self).__init__(
+ operation_name, None, counter_factory, state_sampler)
+ self.windowed_coder = windowed_coder
+ self.step_name = step_name
+ # target represents the consumer for the bytes in the data plane for a
+ # DataInputOperation or a producer of these bytes for a DataOutputOperation.
+ self.target = target
+ self.data_channel = data_channel
+ for _, consumer_ops in consumers.items():
+ for consumer in consumer_ops:
+ self.add_receiver(consumer, 0)
+
+
+class DataOutputOperation(RunnerIOOperation):
+ """A sink-like operation that gathers outputs to be sent back to the runner.
+ """
+
+ def set_output_stream(self, output_stream):
+ self.output_stream = output_stream
+
+ def process(self, windowed_value):
+ self.windowed_coder.get_impl().encode_to_stream(
+ windowed_value, self.output_stream, True)
+
+ def finish(self):
+ self.output_stream.close()
+ super(DataOutputOperation, self).finish()
+
+
+class DataInputOperation(RunnerIOOperation):
+ """A source-like operation that gathers input from the runner.
+ """
+
+ def __init__(self, operation_name, step_name, consumers, counter_factory,
+ state_sampler, windowed_coder, input_target, data_channel):
+ super(DataInputOperation, self).__init__(
+ operation_name, step_name, consumers, counter_factory, state_sampler,
+ windowed_coder, target=input_target, data_channel=data_channel)
+ # We must do this manually as we don't have a spec or spec.output_coders.
+ self.receivers = [
+ operations.ConsumerSet(self.counter_factory, self.step_name, 0,
+ consumers.itervalues().next(),
+ self.windowed_coder)]
+
+ def process(self, windowed_value):
+ self.output(windowed_value)
+
+ def process_encoded(self, encoded_windowed_values):
+ input_stream = coder_impl.create_InputStream(encoded_windowed_values)
+ while input_stream.size() > 0:
+ decoded_value = self.windowed_coder.get_impl().decode_from_stream(
+ input_stream, True)
+ self.output(decoded_value)
+
+
+# TODO(robertwb): Revise side input API to not be in terms of native sources.
+# This will enable lookups, but there's an open question as to how to handle
+# custom sources without forcing intermediate materialization. This seems very
+# related to the desire to inject key and window preserving [Splittable]DoFns
+# into the view computation.
+class SideInputSource(native_iobase.NativeSource,
+ native_iobase.NativeSourceReader):
+ """A 'source' for reading side inputs via state API calls.
+ """
+
+ def __init__(self, state_handler, state_key, coder):
+ self._state_handler = state_handler
+ self._state_key = state_key
+ self._coder = coder
+
+ def reader(self):
+ return self
+
+ @property
+ def returns_windowed_values(self):
+ return True
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exn_info):
+ pass
+
+ def __iter__(self):
+ # TODO(robertwb): Support pagination.
+ input_stream = coder_impl.create_InputStream(
+ self._state_handler.Get(self._state_key).data)
+ while input_stream.size() > 0:
+ yield self._coder.get_impl().decode_from_stream(input_stream, True)
+
+
+def memoize(func):
+ cache = {}
+ missing = object()
+
+ def wrapper(*args):
+ result = cache.get(args, missing)
+ if result is missing:
+ result = cache[args] = func(*args)
+ return result
+ return wrapper
+
+
+def only_element(iterable):
+ element, = iterable
+ return element
+
+
+class BundleProcessor(object):
+ """A class for processing bundles of elements.
+ """
+ def __init__(
+ self, process_bundle_descriptor, state_handler, data_channel_factory):
+ self.process_bundle_descriptor = process_bundle_descriptor
+ self.state_handler = state_handler
+ self.data_channel_factory = data_channel_factory
+
+ def create_execution_tree(self, descriptor):
+ # TODO(robertwb): Figure out the correct prefix to use for output counters
+ # from StateSampler.
+ counter_factory = counters.CounterFactory()
+ state_sampler = statesampler.StateSampler(
+ 'fnapi-step%s-' % descriptor.id, counter_factory)
+
+ transform_factory = BeamTransformFactory(
+ descriptor, self.data_channel_factory, counter_factory, state_sampler,
+ self.state_handler)
+
+ pcoll_consumers = collections.defaultdict(list)
+ for transform_id, transform_proto in descriptor.transforms.items():
+ for pcoll_id in transform_proto.inputs.values():
+ pcoll_consumers[pcoll_id].append(transform_id)
+
+ @memoize
+ def get_operation(transform_id):
+ transform_consumers = {
+ tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
+ for tag, pcoll_id
+ in descriptor.transforms[transform_id].outputs.items()
+ }
+ return transform_factory.create_operation(
+ transform_id, transform_consumers)
+
+ # Operations must be started (hence returned) in order.
+ @memoize
+ def topological_height(transform_id):
+ return 1 + max(
+ [0] +
+ [topological_height(consumer)
+ for pcoll in descriptor.transforms[transform_id].outputs.values()
+ for consumer in pcoll_consumers[pcoll]])
+
+ return [get_operation(transform_id)
+ for transform_id in sorted(
+ descriptor.transforms, key=topological_height, reverse=True)]
+
+ def process_bundle(self, instruction_id):
+ ops = self.create_execution_tree(self.process_bundle_descriptor)
+
+ expected_inputs = []
+ for op in ops:
+ if isinstance(op, DataOutputOperation):
+ # TODO(robertwb): Is there a better way to pass the instruction id to
+ # the operation?
+ op.set_output_stream(op.data_channel.output_stream(
+ instruction_id, op.target))
+ elif isinstance(op, DataInputOperation):
+ # We must wait until we receive "end of stream" for each of these ops.
+ expected_inputs.append(op)
+
+ # Start all operations.
+ for op in reversed(ops):
+ logging.info('start %s', op)
+ op.start()
+
+ # Inject inputs from data plane.
+ for input_op in expected_inputs:
+ for data in input_op.data_channel.input_elements(
+ instruction_id, [input_op.target]):
+ # ignores input name
+ input_op.process_encoded(data.data)
+
+ # Finish all operations.
+ for op in ops:
+ logging.info('finish %s', op)
+ op.finish()
+
+
+class BeamTransformFactory(object):
+ """Factory for turning transform_protos into executable operations."""
+ def __init__(self, descriptor, data_channel_factory, counter_factory,
+ state_sampler, state_handler):
+ self.descriptor = descriptor
+ self.data_channel_factory = data_channel_factory
+ self.counter_factory = counter_factory
+ self.state_sampler = state_sampler
+ self.state_handler = state_handler
+ self.context = pipeline_context.PipelineContext(descriptor)
+
+ _known_urns = {}
+
+ @classmethod
+ def register_urn(cls, urn, parameter_type):
+ def wrapper(func):
+ cls._known_urns[urn] = func, parameter_type
+ return func
+ return wrapper
+
+ def create_operation(self, transform_id, consumers):
+ transform_proto = self.descriptor.transforms[transform_id]
+ creator, parameter_type = self._known_urns[transform_proto.spec.urn]
+ parameter = proto_utils.unpack_Any(
+ transform_proto.spec.parameter, parameter_type)
+ return creator(self, transform_id, transform_proto, parameter, consumers)
+
+ def get_coder(self, coder_id):
+ coder_proto = self.descriptor.coders[coder_id]
+ if coder_proto.spec.spec.urn:
+ return self.context.coders.get_by_id(coder_id)
+ else:
+ # No URN, assume cloud object encoding json bytes.
+ return operation_specs.get_coder_from_spec(
+ json.loads(
+ proto_utils.unpack_Any(coder_proto.spec.spec.parameter,
+ wrappers_pb2.BytesValue).value))
+
+ def get_output_coders(self, transform_proto):
+ return {
+ tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
+ for tag, pcoll_id in transform_proto.outputs.items()
+ }
+
+ def get_only_output_coder(self, transform_proto):
+ return only_element(self.get_output_coders(transform_proto).values())
+
+ def get_input_coders(self, transform_proto):
+ return {
+ tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
+ for tag, pcoll_id in transform_proto.inputs.items()
+ }
+
+ def get_only_input_coder(self, transform_proto):
+ return only_element(self.get_input_coders(transform_proto).values())
+
+ # TODO(robertwb): Update all operations to take these in the constructor.
+ @staticmethod
+ def augment_oldstyle_op(op, step_name, consumers, tag_list=None):
+ op.step_name = step_name
+ for tag, op_consumers in consumers.items():
+ for consumer in op_consumers:
+ op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
+ return op
+
+
+@BeamTransformFactory.register_urn(
+ DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
+def create(factory, transform_id, transform_proto, grpc_port, consumers):
+ target = beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_id,
+ name=only_element(transform_proto.outputs.keys()))
+ return DataInputOperation(
+ transform_proto.unique_name,
+ transform_proto.unique_name,
+ consumers,
+ factory.counter_factory,
+ factory.state_sampler,
+ factory.get_only_output_coder(transform_proto),
+ input_target=target,
+ data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
+
+
+@BeamTransformFactory.register_urn(
+ DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
+def create(factory, transform_id, transform_proto, grpc_port, consumers):
+ target = beam_fn_api_pb2.Target(
+ primitive_transform_reference=transform_id,
+ name=only_element(transform_proto.inputs.keys()))
+ return DataOutputOperation(
+ transform_proto.unique_name,
+ transform_proto.unique_name,
+ consumers,
+ factory.counter_factory,
+ factory.state_sampler,
+ # TODO(robertwb): Perhaps this could be distinct from the input coder?
+ factory.get_only_input_coder(transform_proto),
+ target=target,
+ data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
+
+
+@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ # The Dataflow runner harness strips the base64 encoding.
+ source = pickler.loads(base64.b64encode(parameter.value))
+ spec = operation_specs.WorkerRead(
+ iobase.SourceBundle(1.0, source, None, None),
+ [WindowedValueCoder(source.default_output_coder())])
+ return factory.augment_oldstyle_op(
+ operations.ReadOperation(
+ transform_proto.unique_name,
+ spec,
+ factory.counter_factory,
+ factory.state_sampler),
+ transform_proto.unique_name,
+ consumers)
+
+
+@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue)
+def create(factory, transform_id, transform_proto, parameter, consumers):
+ dofn_data = pickler.loads(parameter.value)
+ if len(dofn_data) == 2:
+ # Has side input data.
+ serialized_fn, side_input_data = dofn_data
+ else:
+ # No side input data.
+ serialized_fn, side_input_data = parameter.value, []
+
+ def create_side_input(tag, coder):
+ # TODO(robertwb): Extract windows (and keys) out of element data.
+ # TODO(robertwb): Extract state key from ParDoPayload.
+ return operation_specs.WorkerSideInputSource(
+ tag=tag,
+ source=SideInputSource(
+ factory.state_handler,
+ beam_fn_api_pb2.StateKey.MultimapSideInput(
+ key=side_input_tag(transform_id, tag)),
+ coder=coder))
+ output_tags = list(transform_proto.outputs.keys())
+ output_coders = factory.get_output_coders(transform_proto)
+ spec = operation_specs.WorkerDoFn(
+ serialized_fn=serialized_fn,
+ output_tags=output_tags,
+ input=None,
+ side_inputs=[
+ create_side_input(tag, coder) for tag, coder in side_input_data],
+ output_coders=[output_coders[tag] for tag in output_tags])
+ return factory.augment_oldstyle_op(
+ operations.DoOperation(
+ transform_proto.unique_name,
+ spec,
+ factory.counter_factory,
+ factory.state_sampler),
+ transform_proto.unique_name,
+ consumers,
+ output_tags)
+
+
+@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
+def create(factory, transform_id, transform_proto, unused_parameter, consumers):
+ return factory.augment_oldstyle_op(
+ operations.FlattenOperation(
+ transform_proto.unique_name,
+ None,
+ factory.counter_factory,
+ factory.state_sampler),
+ transform_proto.unique_name,
+ consumers)
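
The register_urn decorator at the heart of BeamTransformFactory is a small URN-keyed registry: each registration stores a (creator, parameter_type) pair, and create_operation unpacks the transform's Any payload into parameter_type before dispatching to the creator, as the four create functions above show. A self-contained sketch of the same dispatch pattern; the URN and creator names here are invented for illustration and are not part of this commit:

    # Standalone illustration of the register_urn/create_operation pattern.
    # 'urn:example:increment:0.1' and make_increment_op are hypothetical.
    _known_urns = {}

    def register_urn(urn, parameter_type):
      def wrapper(func):
        _known_urns[urn] = (func, parameter_type)
        return func
      return wrapper

    @register_urn('urn:example:increment:0.1', None)
    def make_increment_op(parameter):
      def op(element):
        return element + 1
      return op

    creator, parameter_type = _known_urns['urn:example:increment:0.1']
    op = creator(None)  # parameter_type is None, so there is no payload
    assert op(41) == 42
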
http://git-wip-us.apache.org/repos/asf/beam/blob/4abd7141/sdks/python/apache_beam/runners/worker/sdk_worker.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index ae86830..6a23680 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -21,170 +21,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import base64
-import collections
-import json
import logging
import Queue as queue
import threading
import traceback
-from google.protobuf import wrappers_pb2
-
-from apache_beam.coders import coder_impl
-from apache_beam.coders import WindowedValueCoder
-from apache_beam.internal import pickler
-from apache_beam.io import iobase
from apache_beam.portability.api import beam_fn_api_pb2
-from apache_beam.runners.dataflow.native_io import iobase as native_iobase
-from apache_beam.runners import pipeline_context
-from apache_beam.runners.worker import operation_specs
-from apache_beam.runners.worker import operations
-from apache_beam.utils import counters
-from apache_beam.utils import proto_utils
-
-# This module is experimental. No backwards-compatibility guarantees.
-
-
-try:
- from apache_beam.runners.worker import statesampler
-except ImportError:
- from apache_beam.runners.worker import statesampler_fake as statesampler
-from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory
-
-
-DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1'
-DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1'
-IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1'
-PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1'
-PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1'
-# TODO(vikasrk): Fix this once runner sends appropriate python urns.
-PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1'
-PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1'
-
-
-def side_input_tag(transform_id, tag):
- return str("%d[%s][%s]" % (len(transform_id), transform_id, tag))
-
-
-class RunnerIOOperation(operations.Operation):
- """Common baseclass for runner harness IO operations."""
-
- def __init__(self, operation_name, step_name, consumers, counter_factory,
- state_sampler, windowed_coder, target, data_channel):
- super(RunnerIOOperation, self).__init__(
- operation_name, None, counter_factory, state_sampler)
- self.windowed_coder = windowed_coder
- self.step_name = step_name
- # target represents the consumer for the bytes in the data plane for a
- # DataInputOperation or a producer of these bytes for a DataOutputOperation.
- self.target = target
- self.data_channel = data_channel
- for _, consumer_ops in consumers.items():
- for consumer in consumer_ops:
- self.add_receiver(consumer, 0)
-
-
-class DataOutputOperation(RunnerIOOperation):
- """A sink-like operation that gathers outputs to be sent back to the runner.
- """
-
- def set_output_stream(self, output_stream):
- self.output_stream = output_stream
-
- def process(self, windowed_value):
- self.windowed_coder.get_impl().encode_to_stream(
- windowed_value, self.output_stream, True)
-
- def finish(self):
- self.output_stream.close()
- super(DataOutputOperation, self).finish()
-
-
-class DataInputOperation(RunnerIOOperation):
- """A source-like operation that gathers input from the runner.
- """
-
- def __init__(self, operation_name, step_name, consumers, counter_factory,
- state_sampler, windowed_coder, input_target, data_channel):
- super(DataInputOperation, self).__init__(
- operation_name, step_name, consumers, counter_factory, state_sampler,
- windowed_coder, target=input_target, data_channel=data_channel)
- # We must do this manually as we don't have a spec or spec.output_coders.
- self.receivers = [
- operations.ConsumerSet(self.counter_factory, self.step_name, 0,
- consumers.itervalues().next(),
- self.windowed_coder)]
-
- def process(self, windowed_value):
- self.output(windowed_value)
-
- def process_encoded(self, encoded_windowed_values):
- input_stream = coder_impl.create_InputStream(encoded_windowed_values)
- while input_stream.size() > 0:
- decoded_value = self.windowed_coder.get_impl().decode_from_stream(
- input_stream, True)
- self.output(decoded_value)
-
-
-# TODO(robertwb): Revise side input API to not be in terms of native sources.
-# This will enable lookups, but there's an open question as to how to handle
-# custom sources without forcing intermediate materialization. This seems very
-# related to the desire to inject key and window preserving [Splittable]DoFns
-# into the view computation.
-class SideInputSource(native_iobase.NativeSource,
- native_iobase.NativeSourceReader):
- """A 'source' for reading side inputs via state API calls.
- """
-
- def __init__(self, state_handler, state_key, coder):
- self._state_handler = state_handler
- self._state_key = state_key
- self._coder = coder
-
- def reader(self):
- return self
-
- @property
- def returns_windowed_values(self):
- return True
-
- def __enter__(self):
- return self
-
- def __exit__(self, *exn_info):
- pass
-
- def __iter__(self):
- # TODO(robertwb): Support pagination.
- input_stream = coder_impl.create_InputStream(
- self._state_handler.Get(self._state_key).data)
- while input_stream.size() > 0:
- yield self._coder.get_impl().decode_from_stream(input_stream, True)
-
-
-def memoize(func):
- cache = {}
- missing = object()
-
- def wrapper(*args):
- result = cache.get(args, missing)
- if result is missing:
- result = cache[args] = func(*args)
- return result
- return wrapper
-
-
-def only_element(iterable):
- element, = iterable
- return element
+from apache_beam.runners.worker import bundle_processor
+from apache_beam.runners.worker import data_plane
class SdkHarness(object):
def __init__(self, control_channel):
self._control_channel = control_channel
- self._data_channel_factory = GrpcClientDataChannelFactory()
+ self._data_channel_factory = data_plane.GrpcClientDataChannelFactory()
def run(self):
contol_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel)
@@ -251,245 +102,10 @@ class SdkWorker(object):
self.fns[process_bundle_descriptor.id] = process_bundle_descriptor
return beam_fn_api_pb2.RegisterResponse()
- def create_execution_tree(self, descriptor):
- # TODO(robertwb): Figure out the correct prefix to use for output counters
- # from StateSampler.
- counter_factory = counters.CounterFactory()
- state_sampler = statesampler.StateSampler(
- 'fnapi-step%s-' % descriptor.id, counter_factory)
-
- transform_factory = BeamTransformFactory(
- descriptor, self.data_channel_factory, counter_factory, state_sampler,
- self.state_handler)
-
- pcoll_consumers = collections.defaultdict(list)
- for transform_id, transform_proto in descriptor.transforms.items():
- for pcoll_id in transform_proto.inputs.values():
- pcoll_consumers[pcoll_id].append(transform_id)
-
- @memoize
- def get_operation(transform_id):
- transform_consumers = {
- tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
- for tag, pcoll_id
- in descriptor.transforms[transform_id].outputs.items()
- }
- return transform_factory.create_operation(
- transform_id, transform_consumers)
-
- # Operations must be started (hence returned) in order.
- @memoize
- def topological_height(transform_id):
- return 1 + max(
- [0] +
- [topological_height(consumer)
- for pcoll in descriptor.transforms[transform_id].outputs.values()
- for consumer in pcoll_consumers[pcoll]])
-
- return [get_operation(transform_id)
- for transform_id in sorted(
- descriptor.transforms, key=topological_height, reverse=True)]
-
def process_bundle(self, request, instruction_id):
- ops = self.create_execution_tree(
- self.fns[request.process_bundle_descriptor_reference])
-
- expected_inputs = []
- for op in ops:
- if isinstance(op, DataOutputOperation):
- # TODO(robertwb): Is there a better way to pass the instruction id to
- # the operation?
- op.set_output_stream(op.data_channel.output_stream(
- instruction_id, op.target))
- elif isinstance(op, DataInputOperation):
- # We must wait until we receive "end of stream" for each of these ops.
- expected_inputs.append(op)
-
- # Start all operations.
- for op in reversed(ops):
- logging.info('start %s', op)
- op.start()
-
- # Inject inputs from data plane.
- for input_op in expected_inputs:
- for data in input_op.data_channel.input_elements(
- instruction_id, [input_op.target]):
- # ignores input name
- input_op.process_encoded(data.data)
-
- # Finish all operations.
- for op in ops:
- logging.info('finish %s', op)
- op.finish()
+ bundle_processor.BundleProcessor(
+ self.fns[request.process_bundle_descriptor_reference],
+ self.state_handler,
+ self.data_channel_factory).process_bundle(instruction_id)
return beam_fn_api_pb2.ProcessBundleResponse()
-
-
-class BeamTransformFactory(object):
- """Factory for turning transform_protos into executable operations."""
- def __init__(self, descriptor, data_channel_factory, counter_factory,
- state_sampler, state_handler):
- self.descriptor = descriptor
- self.data_channel_factory = data_channel_factory
- self.counter_factory = counter_factory
- self.state_sampler = state_sampler
- self.state_handler = state_handler
- self.context = pipeline_context.PipelineContext(descriptor)
-
- _known_urns = {}
-
- @classmethod
- def register_urn(cls, urn, parameter_type):
- def wrapper(func):
- cls._known_urns[urn] = func, parameter_type
- return func
- return wrapper
-
- def create_operation(self, transform_id, consumers):
- transform_proto = self.descriptor.transforms[transform_id]
- creator, parameter_type = self._known_urns[transform_proto.spec.urn]
- parameter = proto_utils.unpack_Any(
- transform_proto.spec.parameter, parameter_type)
- return creator(self, transform_id, transform_proto, parameter, consumers)
-
- def get_coder(self, coder_id):
- coder_proto = self.descriptor.coders[coder_id]
- if coder_proto.spec.spec.urn:
- return self.context.coders.get_by_id(coder_id)
- else:
- # No URN, assume cloud object encoding json bytes.
- return operation_specs.get_coder_from_spec(
- json.loads(
- proto_utils.unpack_Any(coder_proto.spec.spec.parameter,
- wrappers_pb2.BytesValue).value))
-
- def get_output_coders(self, transform_proto):
- return {
- tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
- for tag, pcoll_id in transform_proto.outputs.items()
- }
-
- def get_only_output_coder(self, transform_proto):
- return only_element(self.get_output_coders(transform_proto).values())
-
- def get_input_coders(self, transform_proto):
- return {
- tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
- for tag, pcoll_id in transform_proto.inputs.items()
- }
-
- def get_only_input_coder(self, transform_proto):
- return only_element(self.get_input_coders(transform_proto).values())
-
- # TODO(robertwb): Update all operations to take these in the constructor.
- @staticmethod
- def augment_oldstyle_op(op, step_name, consumers, tag_list=None):
- op.step_name = step_name
- for tag, op_consumers in consumers.items():
- for consumer in op_consumers:
- op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
- return op
-
-
-@BeamTransformFactory.register_urn(
- DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
-def create(factory, transform_id, transform_proto, grpc_port, consumers):
- target = beam_fn_api_pb2.Target(
- primitive_transform_reference=transform_id,
- name=only_element(transform_proto.outputs.keys()))
- return DataInputOperation(
- transform_proto.unique_name,
- transform_proto.unique_name,
- consumers,
- factory.counter_factory,
- factory.state_sampler,
- factory.get_only_output_coder(transform_proto),
- input_target=target,
- data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
-
-
-@BeamTransformFactory.register_urn(
- DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
-def create(factory, transform_id, transform_proto, grpc_port, consumers):
- target = beam_fn_api_pb2.Target(
- primitive_transform_reference=transform_id,
- name=only_element(transform_proto.inputs.keys()))
- return DataOutputOperation(
- transform_proto.unique_name,
- transform_proto.unique_name,
- consumers,
- factory.counter_factory,
- factory.state_sampler,
- # TODO(robertwb): Perhaps this could be distinct from the input coder?
- factory.get_only_input_coder(transform_proto),
- target=target,
- data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
-
-
-@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue)
-def create(factory, transform_id, transform_proto, parameter, consumers):
- # The Dataflow runner harness strips the base64 encoding.
- source = pickler.loads(base64.b64encode(parameter.value))
- spec = operation_specs.WorkerRead(
- iobase.SourceBundle(1.0, source, None, None),
- [WindowedValueCoder(source.default_output_coder())])
- return factory.augment_oldstyle_op(
- operations.ReadOperation(
- transform_proto.unique_name,
- spec,
- factory.counter_factory,
- factory.state_sampler),
- transform_proto.unique_name,
- consumers)
-
-
-@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue)
-def create(factory, transform_id, transform_proto, parameter, consumers):
- dofn_data = pickler.loads(parameter.value)
- if len(dofn_data) == 2:
- # Has side input data.
- serialized_fn, side_input_data = dofn_data
- else:
- # No side input data.
- serialized_fn, side_input_data = parameter.value, []
-
- def create_side_input(tag, coder):
- # TODO(robertwb): Extract windows (and keys) out of element data.
- # TODO(robertwb): Extract state key from ParDoPayload.
- return operation_specs.WorkerSideInputSource(
- tag=tag,
- source=SideInputSource(
- factory.state_handler,
- beam_fn_api_pb2.StateKey.MultimapSideInput(
- key=side_input_tag(transform_id, tag)),
- coder=coder))
- output_tags = list(transform_proto.outputs.keys())
- output_coders = factory.get_output_coders(transform_proto)
- spec = operation_specs.WorkerDoFn(
- serialized_fn=serialized_fn,
- output_tags=output_tags,
- input=None,
- side_inputs=[
- create_side_input(tag, coder) for tag, coder in side_input_data],
- output_coders=[output_coders[tag] for tag in output_tags])
- return factory.augment_oldstyle_op(
- operations.DoOperation(
- transform_proto.unique_name,
- spec,
- factory.counter_factory,
- factory.state_sampler),
- transform_proto.unique_name,
- consumers,
- output_tags)
-
-
-@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
-def create(factory, transform_id, transform_proto, unused_parameter, consumers):
- return factory.augment_oldstyle_op(
- operations.FlattenOperation(
- transform_proto.unique_name,
- None,
- factory.counter_factory,
- factory.state_sampler),
- transform_proto.unique_name,
- consumers)
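
One detail worth calling out in the code that moved: BundleProcessor.create_execution_tree orders operations by a memoized topological height, so producers (largest height) come first in the returned list, and process_bundle then starts operations in reversed order so that every consumer is running before its producer. A toy sketch of that ordering; the graph and node names are invented for illustration:

    # Toy graph: each node maps to the nodes that consume its output,
    # standing in for pcoll_consumers in create_execution_tree.
    graph = {'read': ['parse'], 'parse': ['count'], 'count': []}

    def memoize(func):
      cache = {}
      def wrapper(*args):
        if args not in cache:
          cache[args] = func(*args)
        return cache[args]
      return wrapper

    @memoize
    def topological_height(node):
      return 1 + max([0] + [topological_height(c) for c in graph[node]])

    # Producers first, as create_execution_tree returns them; process_bundle
    # iterates reversed(ops), so 'count' would be started first.
    order = sorted(graph, key=topological_height, reverse=True)
    assert order == ['read', 'parse', 'count']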