You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by pa...@apache.org on 2020/04/28 00:27:51 UTC
[beam] 01/01: Revert "[BEAM-9639][BEAM-9608] Improvements for FnApiRunner"
This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch revert-11270-fn-ref-more
in repository https://gitbox.apache.org/repos/asf/beam.git
commit e5e52694bcaf9a4cdc9fd4a130f8cca4dcc6fe6a
Author: Pablo <pa...@users.noreply.github.com>
AuthorDate: Mon Apr 27 17:27:34 2020 -0700
Revert "[BEAM-9639][BEAM-9608] Improvements for FnApiRunner"
---
.../runners/portability/fn_api_runner/execution.py | 179 +--------
.../runners/portability/fn_api_runner/fn_runner.py | 406 +++++++++++++--------
.../portability/fn_api_runner/fn_runner_test.py | 24 --
.../portability/fn_api_runner/translations.py | 11 -
4 files changed, 255 insertions(+), 365 deletions(-)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 2f29515..e62d8a8 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -22,29 +22,17 @@ from __future__ import absolute_import
import collections
import itertools
from typing import TYPE_CHECKING
-from typing import Any
-from typing import DefaultDict
-from typing import Dict
-from typing import Iterator
-from typing import List
from typing import MutableMapping
-from typing import Optional
-from typing import Tuple
from typing_extensions import Protocol
from apache_beam import coders
-from apache_beam.coders import BytesCoder
from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
-from apache_beam.coders.coders import GlobalWindowCoder
-from apache_beam.coders.coders import WindowedValueCoder
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
-from apache_beam.runners.portability.fn_api_runner import translations
-from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import only_element
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import unique_name
@@ -57,13 +45,8 @@ from apache_beam.utils import windowed_value
if TYPE_CHECKING:
from apache_beam.coders.coder_impl import CoderImpl
+ from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner import worker_handlers
- from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
- from apache_beam.transforms.window import BoundedWindow
-
-ENCODED_IMPULSE_VALUE = WindowedValueCoder(
- BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested(
- GlobalWindows.windowed_value(b''))
class Buffer(Protocol):
@@ -221,7 +204,7 @@ class WindowGroupingBuffer(object):
def __init__(
self,
access_pattern,
- coder # type: WindowedValueCoder
+ coder # type: coders.WindowedValueCoder
):
# type: (...) -> None
# Here's where we would use a different type of partitioning
@@ -268,12 +251,11 @@ class WindowGroupingBuffer(object):
class FnApiRunnerExecutionContext(object):
"""
- :var pcoll_buffers: (dict): Mapping of
+ :var pcoll_buffers: (collections.defaultdict of str: list): Mapping of
PCollection IDs to list that functions as buffer for the
``beam.PCollection``.
"""
def __init__(self,
- stages, # type: List[translations.Stage]
worker_handler_manager, # type: worker_handlers.WorkerHandlerManager
pipeline_components, # type: beam_runner_api_pb2.Components
safe_coders,
@@ -286,9 +268,6 @@ class FnApiRunnerExecutionContext(object):
:param safe_coders:
:param data_channel_coders:
"""
- self.stages = stages
- self.side_input_descriptors_by_stage = (
- self._build_data_side_inputs_map(stages))
self.pcoll_buffers = {} # type: MutableMapping[bytes, PartitionableBuffer]
self.timer_buffers = {} # type: MutableMapping[bytes, ListBuffer]
self.worker_handler_manager = worker_handler_manager
@@ -301,63 +280,6 @@ class FnApiRunnerExecutionContext(object):
iterable_state_write=self._iterable_state_write)
self._last_uid = -1
- @staticmethod
- def _build_data_side_inputs_map(stages):
- # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput]
-
- """Builds an index mapping stages to side input descriptors.
-
- A side input descriptor is a map of side input IDs to side input access
- patterns for all of the outputs of a stage that will be consumed as a
- side input.
- """
- transform_consumers = collections.defaultdict(
- list) # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]]
- stage_consumers = collections.defaultdict(
- list) # type: DefaultDict[str, List[translations.Stage]]
-
- def get_all_side_inputs():
- # type: () -> Set[str]
- all_side_inputs = set() # type: Set[str]
- for stage in stages:
- for transform in stage.transforms:
- for input in transform.inputs.values():
- transform_consumers[input].append(transform)
- stage_consumers[input].append(stage)
- for si in stage.side_inputs():
- all_side_inputs.add(si)
- return all_side_inputs
-
- all_side_inputs = frozenset(get_all_side_inputs())
- data_side_inputs_by_producing_stage = {}
-
- producing_stages_by_pcoll = {}
-
- for s in stages:
- data_side_inputs_by_producing_stage[s.name] = {}
- for transform in s.transforms:
- for o in transform.outputs.values():
- if o in s.side_inputs():
- continue
- producing_stages_by_pcoll[o] = s
-
- for side_pc in all_side_inputs:
- for consuming_transform in transform_consumers[side_pc]:
- if consuming_transform.spec.urn not in translations.PAR_DO_URNS:
- continue
- producing_stage = producing_stages_by_pcoll[side_pc]
- payload = proto_utils.parse_Bytes(
- consuming_transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
- for si_tag in payload.side_inputs:
- if consuming_transform.inputs[si_tag] == side_pc:
- side_input_id = (consuming_transform.unique_name, si_tag)
- data_side_inputs_by_producing_stage[
- producing_stage.name][side_input_id] = (
- translations.create_buffer_id(side_pc),
- payload.side_inputs[si_tag].access_pattern)
-
- return data_side_inputs_by_producing_stage
-
@property
def state_servicer(self):
# TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
@@ -379,43 +301,6 @@ class FnApiRunnerExecutionContext(object):
out.get())
return token
- def commit_side_inputs_to_state(
- self,
- data_side_input, # type: DataSideInput
- ):
- # type: (...) -> None
- for (consuming_transform_id, tag), (buffer_id,
- func_spec) in data_side_input.items():
- _, pcoll_id = split_buffer_id(buffer_id)
- value_coder = self.pipeline_context.coders[self.safe_coders[
- self.data_channel_coders[pcoll_id]]]
- elements_by_window = WindowGroupingBuffer(func_spec, value_coder)
- if buffer_id not in self.pcoll_buffers:
- self.pcoll_buffers[buffer_id] = ListBuffer(
- coder_impl=value_coder.get_impl())
- for element_data in self.pcoll_buffers[buffer_id]:
- elements_by_window.append(element_data)
-
- if func_spec.urn == common_urns.side_inputs.ITERABLE.urn:
- for _, window, elements_data in elements_by_window.encoded_items():
- state_key = beam_fn_api_pb2.StateKey(
- iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
- transform_id=consuming_transform_id,
- side_input_id=tag,
- window=window))
- self.state_servicer.append_raw(state_key, elements_data)
- elif func_spec.urn == common_urns.side_inputs.MULTIMAP.urn:
- for key, window, elements_data in elements_by_window.encoded_items():
- state_key = beam_fn_api_pb2.StateKey(
- multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
- transform_id=consuming_transform_id,
- side_input_id=tag,
- window=window,
- key=key))
- self.state_servicer.append_raw(state_key, elements_data)
- else:
- raise ValueError("Unknown access pattern: '%s'" % func_spec.urn)
-
class BundleContextManager(object):
@@ -482,64 +367,6 @@ class BundleContextManager(object):
state_api_service_descriptor=self.state_api_service_descriptor(),
timer_api_service_descriptor=self.data_api_service_descriptor())
- def extract_bundle_inputs_and_outputs(self):
- # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]]
-
- """Returns maps of transform names to PCollection identifiers.
-
- Also mutates IO stages to point to the data ApiServiceDescriptor.
-
- Returns:
- A tuple of (data_input, data_output, expected_timer_output) dictionaries.
- `data_input` is a dictionary mapping (transform_name, output_name) to a
- PCollection buffer; `data_output` is a dictionary mapping
- (transform_name, output_name) to a PCollection ID.
- `expected_timer_output` is a dictionary mapping transform_id and
- timer family ID to a buffer id for timers.
- """
- data_input = {} # type: Dict[str, PartitionableBuffer]
- data_output = {} # type: DataOutput
- # A mapping of {(transform_id, timer_family_id) : buffer_id}
- expected_timer_output = {} # type: Dict[Tuple[str, str], str]
- for transform in self.stage.transforms:
- if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
- bundle_processor.DATA_OUTPUT_URN):
- pcoll_id = transform.spec.payload
- if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
- coder_id = self.execution_context.data_channel_coders[only_element(
- transform.outputs.values())]
- coder = self.execution_context.pipeline_context.coders[
- self.execution_context.safe_coders.get(coder_id, coder_id)]
- if pcoll_id == translations.IMPULSE_BUFFER:
- data_input[transform.unique_name] = ListBuffer(
- coder_impl=coder.get_impl())
- data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
- else:
- if pcoll_id not in self.execution_context.pcoll_buffers:
- self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
- coder_impl=coder.get_impl())
- data_input[transform.unique_name] = (
- self.execution_context.pcoll_buffers[pcoll_id])
- elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
- data_output[transform.unique_name] = pcoll_id
- coder_id = self.execution_context.data_channel_coders[only_element(
- transform.inputs.values())]
- else:
- raise NotImplementedError
- data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
- data_api_service_descriptor = self.data_api_service_descriptor()
- if data_api_service_descriptor:
- data_spec.api_service_descriptor.url = (
- data_api_service_descriptor.url)
- transform.spec.payload = data_spec.SerializeToString()
- elif transform.spec.urn in translations.PAR_DO_URNS:
- payload = proto_utils.parse_Bytes(
- transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
- for timer_family_id in payload.timer_family_specs.keys():
- expected_timer_output[(transform.unique_name, timer_family_id)] = (
- create_buffer_id(timer_family_id, 'timers'))
- return data_input, data_output, expected_timer_output
-
def get_input_coder_impl(self, transform_id):
# type: (str) -> CoderImpl
coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString(
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 1f162ac..8e7ba2d 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -40,9 +40,11 @@ from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
+from typing import Sequence
from typing import Tuple
from typing import TypeVar
+import apache_beam as beam # pylint: disable=ungrouped-imports
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.metrics import metric
from apache_beam.metrics import monitoring_infos
@@ -58,9 +60,12 @@ from apache_beam.runners.portability import portable_metrics
from apache_beam.runners.portability.fn_api_runner import execution
from apache_beam.runners.portability.fn_api_runner import translations
from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
+from apache_beam.runners.portability.fn_api_runner.execution import WindowGroupingBuffer
from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import only_element
+from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager
+from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import environments
from apache_beam.utils import profiler
from apache_beam.utils import proto_utils
@@ -68,6 +73,7 @@ from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor
if TYPE_CHECKING:
from apache_beam.pipeline import Pipeline
+ from apache_beam.coders.coder_impl import CoderImpl
from apache_beam.portability.api import metrics_pb2
_LOGGER = logging.getLogger(__name__)
@@ -82,6 +88,11 @@ BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse,
# This module is experimental. No backwards-compatibility guarantees.
+ENCODED_IMPULSE_VALUE = beam.coders.WindowedValueCoder(
+ beam.coders.BytesCoder(),
+ beam.coders.coders.GlobalWindowCoder()).get_impl().encode_nested(
+ beam.transforms.window.GlobalWindows.windowed_value(b''))
+
class FnApiRunner(runner.PipelineRunner):
@@ -314,7 +325,6 @@ class FnApiRunner(runner.PipelineRunner):
monitoring_infos_by_stage = {}
runner_execution_context = execution.FnApiRunnerExecutionContext(
- stages,
worker_handler_manager,
stage_context.components,
stage_context.safe_coders,
@@ -325,7 +335,6 @@ class FnApiRunner(runner.PipelineRunner):
for stage in stages:
bundle_context_manager = execution.BundleContextManager(
runner_execution_context, stage, self._num_workers)
-
stage_results = self._run_stage(
runner_execution_context,
bundle_context_manager,
@@ -336,14 +345,54 @@ class FnApiRunner(runner.PipelineRunner):
worker_handler_manager.close_all()
return RunnerResult(runner.PipelineState.DONE, monitoring_infos_by_stage)
+ def _store_side_inputs_in_state(self,
+ runner_execution_context, # type: execution.FnApiRunnerExecutionContext
+ data_side_input, # type: DataSideInput
+ ):
+ # type: (...) -> None
+ for (transform_id, tag), (buffer_id, si) in data_side_input.items():
+ _, pcoll_id = split_buffer_id(buffer_id)
+ value_coder = runner_execution_context.pipeline_context.coders[
+ runner_execution_context.safe_coders[
+ runner_execution_context.data_channel_coders[pcoll_id]]]
+ elements_by_window = WindowGroupingBuffer(si, value_coder)
+ if buffer_id not in runner_execution_context.pcoll_buffers:
+ runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer(
+ coder_impl=value_coder.get_impl())
+ for element_data in runner_execution_context.pcoll_buffers[buffer_id]:
+ elements_by_window.append(element_data)
+
+ if si.urn == common_urns.side_inputs.ITERABLE.urn:
+ for _, window, elements_data in elements_by_window.encoded_items():
+ state_key = beam_fn_api_pb2.StateKey(
+ iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
+ transform_id=transform_id, side_input_id=tag, window=window))
+ (
+ runner_execution_context.worker_handler_manager.state_servicer.
+ append_raw(state_key, elements_data))
+ elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
+ for key, window, elements_data in elements_by_window.encoded_items():
+ state_key = beam_fn_api_pb2.StateKey(
+ multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
+ transform_id=transform_id,
+ side_input_id=tag,
+ window=window,
+ key=key))
+ (
+ runner_execution_context.worker_handler_manager.state_servicer.
+ append_raw(state_key, elements_data))
+ else:
+ raise ValueError("Unknown access pattern: '%s'" % si.urn)
+
def _run_bundle_multiple_times_for_testing(
self,
runner_execution_context, # type: execution.FnApiRunnerExecutionContext
- bundle_manager, # type: BundleManager
+ bundle_context_manager, # type: execution.BundleContextManager
data_input,
data_output, # type: DataOutput
fired_timers,
expected_output_timers,
+ cache_token_generator
):
# type: (...) -> None
@@ -354,12 +403,18 @@ class FnApiRunner(runner.PipelineRunner):
for _ in range(self._bundle_repeat):
try:
runner_execution_context.state_servicer.checkpoint()
- bundle_manager.process_bundle(
- data_input,
- data_output,
- fired_timers,
- expected_output_timers,
- dry_run=True)
+ testing_bundle_manager = ParallelBundleManager(
+ bundle_context_manager.worker_handlers,
+ lambda pcoll_id,
+ transform_id: ListBuffer(
+ coder_impl=bundle_context_manager.get_input_coder_impl),
+ bundle_context_manager.get_input_coder_impl,
+ bundle_context_manager.process_bundle_descriptor,
+ self._progress_frequency,
+ num_workers=self._num_workers,
+ cache_token_generator=cache_token_generator)
+ testing_bundle_manager.process_bundle(
+ data_input, data_output, fired_timers, expected_output_timers)
finally:
runner_execution_context.state_servicer.restore()
@@ -389,17 +444,6 @@ class FnApiRunner(runner.PipelineRunner):
fired_timers[(transform_id, timer_family_id)].append(out.get())
written_timers.clear()
- def _add_sdk_delayed_applications_to_deferred_inputs(
- self, bundle_context_manager, bundle_result, deferred_inputs):
- for delayed_application in bundle_result.process_bundle.residual_roots:
- name = bundle_context_manager.input_for(
- delayed_application.application.transform_id,
- delayed_application.application.input_id)
- if name not in deferred_inputs:
- deferred_inputs[name] = ListBuffer(
- coder_impl=bundle_context_manager.get_input_coder_impl(name))
- deferred_inputs[name].append(delayed_application.application.element)
-
def _add_residuals_and_channel_splits_to_deferred_inputs(
self,
splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
@@ -463,115 +507,164 @@ class FnApiRunner(runner.PipelineRunner):
Args:
runner_execution_context (execution.FnApiRunnerExecutionContext): An
object containing execution information for the pipeline.
- bundle_context_manager (execution.BundleContextManager): A description of
- the stage to execute, and its context.
+ stage (translations.Stage): A description of the stage to execute.
"""
- data_input, data_output, expected_timer_output = (
- bundle_context_manager.extract_bundle_inputs_and_outputs())
- input_timers = {}
-
+ worker_handler_list = bundle_context_manager.worker_handlers
worker_handler_manager = runner_execution_context.worker_handler_manager
_LOGGER.info('Running %s', bundle_context_manager.stage.name)
+ (data_input, data_side_input, data_output,
+ expected_timer_output) = self._extract_endpoints(
+ bundle_context_manager, runner_execution_context)
worker_handler_manager.register_process_bundle_descriptor(
bundle_context_manager.process_bundle_descriptor)
- # We create the bundle manager here, as it can be reused for bundles of the
- # same stage, but it may have to be created by-bundle later on.
+ # Store the required side inputs into state so it is accessible for the
+ # worker when it runs this bundle.
+ self._store_side_inputs_in_state(runner_execution_context, data_side_input)
+
+ # Change cache token across bundle repeats
cache_token_generator = FnApiRunner.get_cache_token_generator(static=False)
- bundle_manager = ParallelBundleManager(
+
+ self._run_bundle_multiple_times_for_testing(
+ runner_execution_context,
bundle_context_manager,
+ data_input,
+ data_output, {},
+ expected_timer_output,
+ cache_token_generator=cache_token_generator)
+
+ bundle_manager = ParallelBundleManager(
+ worker_handler_list,
+ bundle_context_manager.get_buffer,
+ bundle_context_manager.get_input_coder_impl,
+ bundle_context_manager.process_bundle_descriptor,
self._progress_frequency,
+ num_workers=self._num_workers,
cache_token_generator=cache_token_generator)
- final_result = None
+ # For the first time of processing, we don't have fired timers as inputs.
+ result, splits = bundle_manager.process_bundle(data_input,
+ data_output,
+ {},
+ expected_timer_output)
- def merge_results(last_result):
- """ Merge the latest result with other accumulated results. """
- return (
- last_result
- if final_result is None else beam_fn_api_pb2.InstructionResponse(
- process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
- monitoring_infos=monitoring_infos.consolidate(
- itertools.chain(
- final_result.process_bundle.monitoring_infos,
- last_result.process_bundle.monitoring_infos))),
- error=final_result.error or last_result.error))
+ last_result = result
+ last_sent = data_input
+ # We cannot split deferred_input until we include residual_roots to
+ # merged results. Without residual_roots, pipeline stops earlier and we
+ # may miss some data.
+ # We also don't partition fired timer inputs for the same reason.
+ bundle_manager._num_workers = 1
while True:
- last_result, deferred_inputs, fired_timers = self._run_bundle(
- runner_execution_context,
- bundle_context_manager,
- data_input,
- data_output,
- input_timers,
- expected_timer_output,
- bundle_manager)
-
- final_result = merge_results(last_result)
- if not deferred_inputs and not fired_timers:
- break
+ deferred_inputs = {} # type: Dict[str, PartitionableBuffer]
+ fired_timers = {}
+
+ self._collect_written_timers_and_add_to_fired_timers(
+ bundle_context_manager, fired_timers)
+ # Queue any process-initiated delayed bundle applications.
+ for delayed_application in last_result.process_bundle.residual_roots:
+ name = bundle_context_manager.input_for(
+ delayed_application.application.transform_id,
+ delayed_application.application.input_id)
+ if name not in deferred_inputs:
+ deferred_inputs[name] = ListBuffer(
+ coder_impl=bundle_context_manager.get_input_coder_impl(name))
+ deferred_inputs[name].append(delayed_application.application.element)
+ # Queue any runner-initiated delayed bundle applications.
+ self._add_residuals_and_channel_splits_to_deferred_inputs(
+ splits, bundle_context_manager, last_sent, deferred_inputs)
+
+ if deferred_inputs or fired_timers:
+ # The worker will be waiting on these inputs as well.
+ for other_input in data_input:
+ if other_input not in deferred_inputs:
+ deferred_inputs[other_input] = ListBuffer(
+ coder_impl=bundle_context_manager.get_input_coder_impl(
+ other_input))
+ # TODO(robertwb): merge results
+ last_result, splits = bundle_manager.process_bundle(
+ deferred_inputs, data_output, fired_timers, expected_timer_output)
+ last_sent = deferred_inputs
+ result = beam_fn_api_pb2.InstructionResponse(
+ process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
+ monitoring_infos=monitoring_infos.consolidate(
+ itertools.chain(
+ result.process_bundle.monitoring_infos,
+ last_result.process_bundle.monitoring_infos))),
+ error=result.error or last_result.error)
else:
- data_input = deferred_inputs
- input_timers = fired_timers
- bundle_manager._registered = True
+ break
- # Store the required downstream side inputs into state so it is accessible
- # for the worker when it runs bundles that consume this stage's output.
- data_side_input = (
- runner_execution_context.side_input_descriptors_by_stage.get(
- bundle_context_manager.stage.name, {}))
- runner_execution_context.commit_side_inputs_to_state(data_side_input)
+ return result
- return final_result
+ @staticmethod
+ def _extract_endpoints(bundle_context_manager, # type: execution.BundleContextManager
+ runner_execution_context, # type: execution.FnApiRunnerExecutionContext
+ ):
+ # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput]
- def _run_bundle(
- self,
- runner_execution_context,
- bundle_context_manager,
- data_input,
- data_output,
- input_timers,
- expected_timer_output,
- bundle_manager):
- """Execute a bundle, and return a result object, and deferred inputs."""
- self._run_bundle_multiple_times_for_testing(
- runner_execution_context,
- bundle_manager,
- data_input,
- data_output,
- input_timers,
- expected_timer_output)
-
- result, splits = bundle_manager.process_bundle(
- data_input, data_output, input_timers, expected_timer_output)
- # Now we collect all the deferred inputs remaining from bundle execution.
- # Deferred inputs can be:
- # - timers
- # - SDK-initiated deferred applications of root elements
- # - Runner-initiated deferred applications of root elements
- deferred_inputs = {} # type: Dict[str, execution.PartitionableBuffer]
- fired_timers = {}
-
- self._collect_written_timers_and_add_to_fired_timers(
- bundle_context_manager, fired_timers)
-
- self._add_sdk_delayed_applications_to_deferred_inputs(
- bundle_context_manager, result, deferred_inputs)
-
- self._add_residuals_and_channel_splits_to_deferred_inputs(
- splits, bundle_context_manager, data_input, deferred_inputs)
-
- # After collecting deferred inputs, we 'pad' the structure with empty
- # buffers for other expected inputs.
- if deferred_inputs or fired_timers:
- # The worker will be waiting on these inputs as well.
- for other_input in data_input:
- if other_input not in deferred_inputs:
- deferred_inputs[other_input] = ListBuffer(
- coder_impl=bundle_context_manager.get_input_coder_impl(
- other_input))
-
- return result, deferred_inputs, fired_timers
+ """Returns maps of transform names to PCollection identifiers.
+
+ Also mutates IO stages to point to the data ApiServiceDescriptor.
+
+ Args:
+ stage (translations.Stage): The stage to extract endpoints
+ for.
+ data_api_service_descriptor: A GRPC endpoint descriptor for data plane.
+ Returns:
+ A tuple of (data_input, data_side_input, data_output) dictionaries.
+ `data_input` is a dictionary mapping (transform_name, output_name) to a
+ PCollection buffer; `data_output` is a dictionary mapping
+ (transform_name, output_name) to a PCollection ID.
+ """
+ data_input = {} # type: Dict[str, PartitionableBuffer]
+ data_side_input = {} # type: DataSideInput
+ data_output = {} # type: DataOutput
+ # A mapping of {(transform_id, timer_family_id) : buffer_id}
+ expected_timer_output = {} # type: Dict[Tuple(str, str), str]
+ for transform in bundle_context_manager.stage.transforms:
+ if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
+ bundle_processor.DATA_OUTPUT_URN):
+ pcoll_id = transform.spec.payload
+ if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
+ coder_id = runner_execution_context.data_channel_coders[only_element(
+ transform.outputs.values())]
+ coder = runner_execution_context.pipeline_context.coders[
+ runner_execution_context.safe_coders.get(coder_id, coder_id)]
+ if pcoll_id == translations.IMPULSE_BUFFER:
+ data_input[transform.unique_name] = ListBuffer(
+ coder_impl=coder.get_impl())
+ data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
+ else:
+ if pcoll_id not in runner_execution_context.pcoll_buffers:
+ runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
+ coder_impl=coder.get_impl())
+ data_input[transform.unique_name] = (
+ runner_execution_context.pcoll_buffers[pcoll_id])
+ elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
+ data_output[transform.unique_name] = pcoll_id
+ coder_id = runner_execution_context.data_channel_coders[only_element(
+ transform.inputs.values())]
+ else:
+ raise NotImplementedError
+ data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+ data_api_service_descriptor = (
+ bundle_context_manager.data_api_service_descriptor())
+ if data_api_service_descriptor:
+ data_spec.api_service_descriptor.url = (
+ data_api_service_descriptor.url)
+ transform.spec.payload = data_spec.SerializeToString()
+ elif transform.spec.urn in translations.PAR_DO_URNS:
+ payload = proto_utils.parse_Bytes(
+ transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+ for tag, si in payload.side_inputs.items():
+ data_side_input[transform.unique_name, tag] = (
+ create_buffer_id(transform.inputs[tag]), si.access_pattern)
+ for timer_family_id in payload.timer_family_specs.keys():
+ expected_timer_output[(transform.unique_name, timer_family_id)] = (
+ create_buffer_id(timer_family_id, 'timers'))
+ return data_input, data_side_input, data_output, expected_timer_output
@staticmethod
def get_cache_token_generator(static=True):
@@ -672,18 +765,28 @@ class BundleManager(object):
_lock = threading.Lock()
def __init__(self,
- bundle_context_manager, # type: execution.BundleContextManager
+ worker_handler_list, # type: Sequence[WorkerHandler]
+ get_buffer, # type: Callable[[bytes, str], PartitionableBuffer]
+ get_input_coder_impl, # type: Callable[[str], CoderImpl]
+ bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
progress_frequency=None,
cache_token_generator=FnApiRunner.get_cache_token_generator()
):
"""Set up a bundle manager.
Args:
+ worker_handler_list
+ get_buffer (Callable[[str], list])
+ get_input_coder_impl (Callable[[str], Coder])
+ bundle_descriptor (beam_fn_api_pb2.ProcessBundleDescriptor)
progress_frequency
"""
- self.bundle_context_manager = bundle_context_manager # type: execution.BundleContextManager
+ self._worker_handler_list = worker_handler_list
+ self._get_buffer = get_buffer
+ self._get_input_coder_impl = get_input_coder_impl
+ self._bundle_descriptor = bundle_descriptor
self._progress_frequency = progress_frequency
- self._worker_handler = None # type: Optional[execution.WorkerHandler]
+ self._worker_handler = None # type: Optional[WorkerHandler]
self._cache_token_generator = cache_token_generator
def _send_input_to_worker(self,
@@ -711,8 +814,7 @@ class BundleManager(object):
def _select_split_manager(self):
"""TODO(pabloem) WHAT DOES THIS DO"""
unique_names = set(
- t.unique_name for t in self.bundle_context_manager.
- process_bundle_descriptor.transforms.values())
+ t.unique_name for t in self._bundle_descriptor.transforms.values())
for stage_name, candidate in reversed(_split_managers):
if (stage_name in unique_names or
(stage_name + '/Process') in unique_names):
@@ -733,8 +835,8 @@ class BundleManager(object):
byte_stream = b''.join(buffer_data)
num_elements = len(
list(
- self.bundle_context_manager.get_input_coder_impl(
- read_transform_id).decode_all(byte_stream)))
+ self._get_input_coder_impl(read_transform_id).decode_all(
+ byte_stream)))
# Start the split manager in case it wants to set any breakpoints.
split_manager_generator = split_manager(num_elements)
@@ -787,20 +889,18 @@ class BundleManager(object):
return split_results
def process_bundle(self,
- inputs, # type: Mapping[str, execution.PartitionableBuffer]
+ inputs, # type: Mapping[str, PartitionableBuffer]
expected_outputs, # type: DataOutput
- fired_timers, # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
- expected_output_timers, # type: Dict[Tuple[str, str], str]
- dry_run=False,
+ fired_timers, # type: Mapping[Tuple[str, str], PartitionableBuffer]
+ expected_output_timers # type: Dict[str, Dict[str, str]]
):
# type: (...) -> BundleProcessResult
# Unique id for the instruction processing this bundle.
with BundleManager._lock:
BundleManager._uid_counter += 1
process_bundle_id = 'bundle_%s' % BundleManager._uid_counter
- self._worker_handler = self.bundle_context_manager.worker_handlers[
- BundleManager._uid_counter %
- len(self.bundle_context_manager.worker_handlers)]
+ self._worker_handler = self._worker_handler_list[
+ BundleManager._uid_counter % len(self._worker_handler_list)]
split_manager = self._select_split_manager()
if not split_manager:
@@ -820,8 +920,7 @@ class BundleManager(object):
process_bundle_req = beam_fn_api_pb2.InstructionRequest(
instruction_id=process_bundle_id,
process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
- process_bundle_descriptor_id=self.bundle_context_manager.
- process_bundle_descriptor.id,
+ process_bundle_descriptor_id=self._bundle_descriptor.id,
cache_tokens=[next(self._cache_token_generator)]))
result_future = self._worker_handler.control_conn.push(process_bundle_req)
@@ -843,15 +942,15 @@ class BundleManager(object):
expect_reads,
abort_callback=lambda:
(result_future.is_done() and result_future.get().error)):
- if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
+ if isinstance(output, beam_fn_api_pb2.Elements.Timers):
with BundleManager._lock:
- self.bundle_context_manager.get_buffer(
+ self._get_buffer(
expected_output_timers[(
output.transform_id, output.timer_family_id)],
output.transform_id).append(output.timers)
- if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run:
+ if isinstance(output, beam_fn_api_pb2.Elements.Data):
with BundleManager._lock:
- self.bundle_context_manager.get_buffer(
+ self._get_buffer(
expected_outputs[output.transform_id],
output.transform_id).append(output.data)
@@ -874,32 +973,32 @@ class ParallelBundleManager(BundleManager):
def __init__(
self,
- bundle_context_manager, # type: execution.BundleContextManager
+ worker_handler_list, # type: Sequence[WorkerHandler]
+ get_buffer, # type: Callable[[bytes, str], PartitionableBuffer]
+ get_input_coder_impl, # type: Callable[[str], CoderImpl]
+ bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
progress_frequency=None,
cache_token_generator=None,
**kwargs):
# type: (...) -> None
super(ParallelBundleManager, self).__init__(
- bundle_context_manager,
+ worker_handler_list,
+ get_buffer,
+ get_input_coder_impl,
+ bundle_descriptor,
progress_frequency,
cache_token_generator=cache_token_generator)
- self._num_workers = bundle_context_manager.num_workers
+ self._num_workers = kwargs.pop('num_workers', 1)
def process_bundle(self,
- inputs, # type: Mapping[str, execution.PartitionableBuffer]
+ inputs, # type: Mapping[str, PartitionableBuffer]
expected_outputs, # type: DataOutput
- fired_timers, # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
- expected_output_timers, # type: Dict[Tuple[str, str], str]
- dry_run=False,
- ):
+ fired_timers, # type: Mapping[Tuple[str, str], PartitionableBuffer]
+ expected_output_timers # type: Dict[Tuple[str, str], str]
+ ):
# type: (...) -> BundleProcessResult
part_inputs = [{} for _ in range(self._num_workers)
] # type: List[Dict[str, List[bytes]]]
- # Timers are only executed on the first worker
- # TODO(BEAM-9741): Split timers to multiple workers
- timer_inputs = [
- fired_timers if i == 0 else {} for i in range(self._num_workers)
- ]
for name, input in inputs.items():
for ix, part in enumerate(input.partition(self._num_workers)):
part_inputs[ix][name] = part
@@ -908,23 +1007,21 @@ class ParallelBundleManager(BundleManager):
split_result_list = [
] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
- def execute(part_map_input_timers):
+ def execute(part_map):
# type: (...) -> BundleProcessResult
- part_map, input_timers = part_map_input_timers
bundle_manager = BundleManager(
- self.bundle_context_manager,
+ self._worker_handler_list,
+ self._get_buffer,
+ self._get_input_coder_impl,
+ self._bundle_descriptor,
self._progress_frequency,
cache_token_generator=self._cache_token_generator)
return bundle_manager.process_bundle(
- part_map,
- expected_outputs,
- input_timers,
- expected_output_timers,
- dry_run)
+ part_map, expected_outputs, fired_timers, expected_output_timers)
with UnboundedThreadPoolExecutor() as executor:
- for result, split_result in executor.map(execute, zip(part_inputs, # pylint: disable=zip-builtin-not-iterating
- timer_inputs)):
+ for result, split_result in executor.map(execute, part_inputs):
+
split_result_list += split_result
if merged_result is None:
merged_result = result
@@ -937,6 +1034,7 @@ class ParallelBundleManager(BundleManager):
merged_result.process_bundle.monitoring_infos))),
error=result.error or merged_result.error)
assert merged_result is not None
+
return merged_result, split_result_list
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
index ce99d87..0ff0853 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
@@ -240,30 +240,6 @@ class FnApiRunnerTest(unittest.TestCase):
lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
equal_to([('a', [1, 3]), ('b', [2])]))
- def test_multimap_multiside_input(self):
- # A test where two transforms in the same stage consume the same PCollection
- # twice as side input.
- with self.create_pipeline() as p:
- main = p | 'main' >> beam.Create(['a', 'b'])
- side = (
- p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)])
- # TODO(BEAM-4782): Obviate the need for this map.
- | beam.Map(lambda kv: (kv[0], kv[1])))
- assert_that(
- main | 'first map' >> beam.Map(
- lambda k,
- d,
- l: (k, sorted(d[k]), sorted([e[1] for e in l])),
- beam.pvalue.AsMultiMap(side),
- beam.pvalue.AsList(side))
- | 'second map' >> beam.Map(
- lambda k,
- d,
- l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])),
- beam.pvalue.AsMultiMap(side),
- beam.pvalue.AsList(side)),
- equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])]))
-
def test_multimap_side_input_type_coercion(self):
with self.create_pipeline() as p:
main = p | 'main' >> beam.Create(['a', 'b'])
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index 235aec8..5d18d29 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -75,17 +75,6 @@ PAR_DO_URNS = frozenset([
IMPULSE_BUFFER = b'impulse'
-# SideInputId is identified by a consumer ParDo + tag.
-SideInputId = Tuple[str, str]
-SideInputAccessPattern = beam_runner_api_pb2.FunctionSpec
-
-DataOutput = Dict[str, bytes]
-
-# DataSideInput maps SideInputIds to a tuple of the encoded bytes of the side
-# input content, and a payload specification regarding the type of side input
-# (MultiMap / Iterable).
-DataSideInput = Dict[SideInputId, Tuple[bytes, SideInputAccessPattern]]
-
class Stage(object):
"""A set of Transforms that can be sent to the worker for processing."""